Source code for recbole.model.knowledge_aware_recommender.kgin

# -*- coding: utf-8 -*-
# @Time   : 2021/3/25
# @Author : Wenqi Sun
# @Email  : wenqisun@pku.edu.cn

# UPDATE:
# @Time   : 2022/8/31
# @Author : Bowen Zheng
# @Email  : 18735382001@163.com

r"""
KGIN
##################################################
Reference:
    Xiang Wang et al. "Learning Intents behind Interactions with Knowledge Graph for Recommendation." in WWW 2021.
Reference code:
    https://github.com/huangtinglin/Knowledge_Graph_based_Intent_Network
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp

from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.layers import SparseDropout
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType


class Aggregator(nn.Module):
    """Relational Path-aware Convolution Network."""

    def __init__(self):
        super(Aggregator, self).__init__()
    def forward(
        self,
        entity_emb,
        user_emb,
        latent_emb,
        relation_emb,
        edge_index,
        edge_type,
        interact_mat,
        disen_weight_att,
    ):
        # Lazy import: torch_scatter is an optional dependency of RecBole.
        from torch_scatter import scatter_mean

        n_entities = entity_emb.shape[0]

        """KG aggregation"""
        head, tail = edge_index
        edge_relation_emb = relation_emb[edge_type]
        neigh_relation_emb = (
            entity_emb[tail] * edge_relation_emb
        )  # [-1, embedding_size]
        entity_agg = scatter_mean(
            src=neigh_relation_emb, index=head, dim_size=n_entities, dim=0
        )

        """calculate user -> latent factor attention"""
        score_ = torch.mm(user_emb, latent_emb.t())
        score = nn.Softmax(dim=1)(score_)  # [n_users, n_factors]

        """user aggregation"""
        user_agg = torch.sparse.mm(
            interact_mat, entity_emb
        )  # [n_users, embedding_size]
        disen_weight = torch.mm(
            nn.Softmax(dim=-1)(disen_weight_att), relation_emb
        )  # [n_factors, embedding_size]
        user_agg = (
            torch.mm(score, disen_weight)
        ) * user_agg + user_agg  # [n_users, embedding_size]

        return entity_agg, user_agg
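
# A minimal sketch (not part of the model) of the `scatter_mean` call used in the
# KG aggregation above: relation-weighted messages from tail entities are averaged
# into their head entities, and entities with no incoming edge stay zero. The toy
# tensors here are illustrative assumptions, not RecBole data.
#
#   import torch
#   from torch_scatter import scatter_mean
#
#   head = torch.tensor([0, 0, 1])                  # two edges into entity 0, one into entity 1
#   messages = torch.tensor([[2.0], [4.0], [6.0]])  # one message per edge
#   agg = scatter_mean(src=messages, index=head, dim_size=3, dim=0)
#   # agg -> [[3.0], [6.0], [0.0]]: entity 0 averages its two messages,
#   # entity 2 receives none and remains zero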

class GraphConv(nn.Module):
    """Graph Convolutional Network."""

    def __init__(
        self,
        embedding_size,
        n_hops,
        n_users,
        n_factors,
        n_relations,
        edge_index,
        edge_type,
        interact_mat,
        ind,
        tmp,
        device,
        node_dropout_rate=0.5,
        mess_dropout_rate=0.1,
    ):
        super(GraphConv, self).__init__()
        self.embedding_size = embedding_size
        self.n_hops = n_hops
        self.n_relations = n_relations
        self.n_users = n_users
        self.n_factors = n_factors
        self.edge_index = edge_index
        self.edge_type = edge_type
        self.interact_mat = interact_mat
        self.node_dropout_rate = node_dropout_rate
        self.mess_dropout_rate = mess_dropout_rate
        self.ind = ind
        self.temperature = tmp
        self.device = device

        # define layers
        self.relation_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        disen_weight_att = nn.init.xavier_uniform_(torch.empty(n_factors, n_relations))
        self.disen_weight_att = nn.Parameter(disen_weight_att)
        self.convs = nn.ModuleList()
        for i in range(self.n_hops):
            self.convs.append(Aggregator())
        self.node_dropout = SparseDropout(p=self.node_dropout_rate)  # node dropout
        self.mess_dropout = nn.Dropout(p=self.mess_dropout_rate)  # message dropout

        # parameters initialization
        self.apply(xavier_uniform_initialization)
    def edge_sampling(self, edge_index, edge_type, rate=0.5):
        # edge_index: [2, -1]
        # edge_type: [-1]
        # Randomly sample a `rate` fraction of the KG edges; the rest are dropped
        # for this forward pass.
        n_edges = edge_index.shape[1]
        random_indices = np.random.choice(
            n_edges, size=int(n_edges * rate), replace=False
        )
        return edge_index[:, random_indices], edge_type[random_indices]
    def forward(self, user_emb, entity_emb, latent_emb):
        """node dropout"""
        if self.node_dropout_rate > 0.0:
            edge_index, edge_type = self.edge_sampling(
                self.edge_index, self.edge_type, self.node_dropout_rate
            )
            interact_mat = self.node_dropout(self.interact_mat)
        else:
            edge_index, edge_type = self.edge_index, self.edge_type
            interact_mat = self.interact_mat

        entity_res_emb = entity_emb  # [n_entities, embedding_size]
        user_res_emb = user_emb  # [n_users, embedding_size]
        relation_emb = self.relation_embedding.weight  # [n_relations, embedding_size]
        for i in range(len(self.convs)):
            entity_emb, user_emb = self.convs[i](
                entity_emb,
                user_emb,
                latent_emb,
                relation_emb,
                edge_index,
                edge_type,
                interact_mat,
                self.disen_weight_att,
            )

            """message dropout"""
            if self.mess_dropout_rate > 0.0:
                entity_emb = self.mess_dropout(entity_emb)
                user_emb = self.mess_dropout(user_emb)
            entity_emb = F.normalize(entity_emb)
            user_emb = F.normalize(user_emb)

            """result emb"""
            entity_res_emb = torch.add(entity_res_emb, entity_emb)
            user_res_emb = torch.add(user_res_emb, user_emb)

        return (
            entity_res_emb,
            user_res_emb,
            self.calculate_cor_loss(self.disen_weight_att),
        )
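
    # Note on the residual accumulation above: with e^{(0)} the input embedding and
    # e^{(l)} the (dropout-regularized, L2-normalized) output of hop l, the returned
    # representation is the layer sum
    #
    #     e_final = e^{(0)} + e^{(1)} + ... + e^{(n_hops)},
    #
    # so every propagation depth contributes to the final user/entity embedding.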
    def calculate_cor_loss(self, tensors):
        def CosineSimilarity(tensor_1, tensor_2):
            # tensor_1, tensor_2: [channel]
            normalized_tensor_1 = F.normalize(tensor_1, dim=0)
            normalized_tensor_2 = F.normalize(tensor_2, dim=0)
            return (normalized_tensor_1 * normalized_tensor_2).sum(
                dim=0
            ) ** 2  # no negative

        def DistanceCorrelation(tensor_1, tensor_2):
            # tensor_1, tensor_2: [channel]
            # ref: https://en.wikipedia.org/wiki/Distance_correlation
            channel = tensor_1.shape[0]
            zeros = torch.zeros(channel, channel).to(tensor_1.device)
            zero = torch.zeros(1).to(tensor_1.device)
            tensor_1, tensor_2 = tensor_1.unsqueeze(-1), tensor_2.unsqueeze(-1)
            """calculate distance matrix"""
            a_, b_ = (
                torch.matmul(tensor_1, tensor_1.t()) * 2,
                torch.matmul(tensor_2, tensor_2.t()) * 2,
            )  # [channel, channel]
            tensor_1_square, tensor_2_square = tensor_1**2, tensor_2**2
            a, b = torch.sqrt(
                torch.max(tensor_1_square - a_ + tensor_1_square.t(), zeros) + 1e-8
            ), torch.sqrt(
                torch.max(tensor_2_square - b_ + tensor_2_square.t(), zeros) + 1e-8
            )  # [channel, channel]
            """calculate distance correlation"""
            A = a - a.mean(dim=0, keepdim=True) - a.mean(dim=1, keepdim=True) + a.mean()
            B = b - b.mean(dim=0, keepdim=True) - b.mean(dim=1, keepdim=True) + b.mean()
            dcov_AB = torch.sqrt(torch.max((A * B).sum() / channel**2, zero) + 1e-8)
            dcov_AA = torch.sqrt(torch.max((A * A).sum() / channel**2, zero) + 1e-8)
            dcov_BB = torch.sqrt(torch.max((B * B).sum() / channel**2, zero) + 1e-8)
            return dcov_AB / torch.sqrt(dcov_AA * dcov_BB + 1e-8)

        def MutualInformation(tensors):
            # tensors: [n_factors, dimension]
            # normalized_tensors: [n_factors, dimension]
            normalized_tensors = F.normalize(tensors, dim=1)
            scores = torch.mm(normalized_tensors, normalized_tensors.t())
            scores = torch.exp(scores / self.temperature)
            cor_loss = -torch.sum(torch.log(scores.diag() / scores.sum(1)))
            return cor_loss

        """calculate similarity for each pair of latent factor weights"""
        if self.ind == "mi":
            return MutualInformation(tensors)
        elif self.ind == "distance":
            cor_loss = 0.0
            for i in range(self.n_factors):
                for j in range(i + 1, self.n_factors):
                    cor_loss += DistanceCorrelation(tensors[i], tensors[j])
        elif self.ind == "cosine":
            cor_loss = 0.0
            for i in range(self.n_factors):
                for j in range(i + 1, self.n_factors):
                    cor_loss += CosineSimilarity(tensors[i], tensors[j])
        else:
            raise NotImplementedError(
                f"The independence loss type [{self.ind}] has not been supported."
            )
        return cor_loss
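
# For reference, the "mi" branch of calculate_cor_loss is an InfoNCE-style
# objective over the rows of disen_weight_att: with s_ij the cosine similarity
# between intent factors i and j and tau the temperature,
#
#     L_ind = - sum_i log( exp(s_ii / tau) / sum_j exp(s_ij / tau) )
#
# Since s_ii = 1 after row normalization, minimizing L_ind drives down the
# off-diagonal similarities s_ij (i != j), encouraging the intent factors to be
# independent of one another; the "distance" and "cosine" branches penalize
# pairwise dependence directly.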

class KGIN(KnowledgeRecommender):
    r"""KGIN is a knowledge-aware recommendation model. It combines the knowledge graph and the
    user-item interaction graph into a new graph called the collaborative knowledge graph (CKG).
    The model explores the intents behind each user-item interaction by using auxiliary item knowledge.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(KGIN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config["embedding_size"]
        self.n_factors = config["n_factors"]
        self.context_hops = config["context_hops"]
        self.node_dropout_rate = config["node_dropout_rate"]
        self.mess_dropout_rate = config["mess_dropout_rate"]
        self.ind = config["ind"]
        self.sim_decay = config["sim_regularity"]
        self.reg_weight = config["reg_weight"]
        self.temperature = config["temperature"]

        # load dataset info
        self.inter_matrix = dataset.inter_matrix(form="coo").astype(
            np.float32
        )  # [n_users, n_items]
        # interact_mat: [n_users, n_entities]; inter_graph: [n_users + n_entities, n_users + n_entities]
        self.interact_mat, _ = self.get_norm_inter_matrix(mode="si")
        self.kg_graph = dataset.kg_graph(
            form="coo", value_field="relation_id"
        )  # [n_entities, n_entities]
        # edge_index: [2, -1]; edge_type: [-1]
        self.edge_index, self.edge_type = self.get_edges(self.kg_graph)

        # define layers and loss
        self.n_nodes = self.n_users + self.n_entities
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.latent_embedding = nn.Embedding(self.n_factors, self.embedding_size)
        self.gcn = GraphConv(
            embedding_size=self.embedding_size,
            n_hops=self.context_hops,
            n_users=self.n_users,
            n_relations=self.n_relations,
            n_factors=self.n_factors,
            edge_index=self.edge_index,
            edge_type=self.edge_type,
            interact_mat=self.interact_mat,
            ind=self.ind,
            tmp=self.temperature,
            device=self.device,
            node_dropout_rate=self.node_dropout_rate,
            mess_dropout_rate=self.mess_dropout_rate,
        )
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()
        self.restore_user_e = None
        self.restore_entity_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
    def get_norm_inter_matrix(self, mode="bi"):
        """Get the normalized interaction matrix of users and items."""

        def _bi_norm_lap(A):
            # D^{-1/2}AD^{-1/2}
            rowsum = np.array(A.sum(1))
            d_inv_sqrt = np.power(rowsum, -0.5).flatten()
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
            bi_lap = d_mat_inv_sqrt.dot(A).dot(d_mat_inv_sqrt)
            return bi_lap.tocoo()

        def _si_norm_lap(A):
            # D^{-1}A
            rowsum = np.array(A.sum(1))
            d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.0
            d_mat_inv = sp.diags(d_inv)
            norm_adj = d_mat_inv.dot(A)
            return norm_adj.tocoo()

        # build adj matrix
        A = sp.dok_matrix(
            (self.n_users + self.n_entities, self.n_users + self.n_entities),
            dtype=np.float32,
        )
        inter_M = self.inter_matrix
        inter_M_t = self.inter_matrix.transpose()
        data_dict = dict(
            zip(zip(inter_M.row, inter_M.col + self.n_users), [1] * inter_M.nnz)
        )
        data_dict.update(
            dict(
                zip(
                    zip(inter_M_t.row + self.n_users, inter_M_t.col),
                    [1] * inter_M_t.nnz,
                )
            )
        )
        # batch update of the dok_matrix (much faster than per-item assignment)
        A._update(data_dict)

        # norm adj matrix
        if mode == "bi":
            L = _bi_norm_lap(A)
        elif mode == "si":
            L = _si_norm_lap(A)
        else:
            raise NotImplementedError(
                f"Normalize mode [{mode}] has not been implemented."
            )

        # convert norm_inter_graph to tensor
        i = torch.LongTensor(np.array([L.row, L.col]))
        data = torch.FloatTensor(L.data)
        norm_graph = torch.sparse.FloatTensor(i, data, L.shape)

        # interaction: user -> item, [n_users, n_entities]
        L_ = L.tocsr()[: self.n_users, self.n_users :].tocoo()
        # convert norm_inter_matrix to tensor
        i_ = torch.LongTensor(np.array([L_.row, L_.col]))
        data_ = torch.FloatTensor(L_.data)
        norm_matrix = torch.sparse.FloatTensor(i_, data_, L_.shape)

        return norm_matrix.to(self.device), norm_graph.to(self.device)
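
    # A small worked example (illustrative values only, not RecBole data) of the
    # "si" normalization D^{-1}A used above: each row of the joint adjacency matrix
    # is divided by its degree, so aggregation averages rather than sums neighbors.
    #
    #   import numpy as np
    #   import scipy.sparse as sp
    #
    #   A = sp.coo_matrix(np.array([[0.0, 1.0, 1.0],
    #                               [1.0, 0.0, 0.0],
    #                               [1.0, 0.0, 0.0]]))
    #   rowsum = np.array(A.sum(1)).flatten()   # degrees: [2., 1., 1.]
    #   d_inv = np.power(rowsum, -1.0)
    #   d_inv[np.isinf(d_inv)] = 0.0
    #   norm = sp.diags(d_inv).dot(A).toarray()
    #   # norm[0] == [0.0, 0.5, 0.5]: a node with two neighbors averages them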
    def get_edges(self, graph):
        index = torch.LongTensor(np.array([graph.row, graph.col]))
        edge_type = torch.LongTensor(np.array(graph.data))
        return index.to(self.device), edge_type.to(self.device)
    def forward(self):
        user_embeddings = self.user_embedding.weight
        entity_embeddings = self.entity_embedding.weight
        latent_embeddings = self.latent_embedding.weight
        # entity_gcn_emb: [n_entities, embedding_size]
        # user_gcn_emb: [n_users, embedding_size]
        entity_gcn_emb, user_gcn_emb, cor_loss = self.gcn(
            user_embeddings, entity_embeddings, latent_embeddings
        )
        return user_gcn_emb, entity_gcn_emb, cor_loss
    def calculate_loss(self, interaction):
        r"""Calculate the training loss for a batch of data.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        if self.restore_user_e is not None or self.restore_entity_e is not None:
            self.restore_user_e, self.restore_entity_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, entity_all_embeddings, cor_loss = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = entity_all_embeddings[pos_item]
        neg_embeddings = entity_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)
        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)
        cor_loss = self.sim_decay * cor_loss
        loss = mf_loss + self.reg_weight * reg_loss + cor_loss
        return loss
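
    # In symbols, the objective assembled above is
    #
    #     L = L_BPR(y_pos, y_neg) + reg_weight * ||Theta||_2^2 + sim_regularity * L_ind,
    #
    # where L_BPR is the pairwise ranking loss over (user, pos_item, neg_item)
    # triples, Theta are the embeddings of the batch, and L_ind is the
    # factor-independence loss returned by GraphConv.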
    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        user_all_embeddings, entity_all_embeddings, _ = self.forward()
        u_embeddings = user_all_embeddings[user]
        i_embeddings = entity_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores
    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_entity_e is None:
            self.restore_user_e, self.restore_entity_e, _ = self.forward()
        u_embeddings = self.restore_user_e[user]
        i_embeddings = self.restore_entity_e[: self.n_items]
        scores = torch.matmul(u_embeddings, i_embeddings.transpose(0, 1))
        return scores.view(-1)
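
# A minimal usage sketch. `run_recbole` is RecBole's quick-start entry point; the
# dataset name and hyperparameter values below are illustrative assumptions, not
# prescribed defaults.
#
#   from recbole.quick_start import run_recbole
#
#   run_recbole(
#       model="KGIN",
#       dataset="ml-100k",
#       config_dict={
#           "n_factors": 4,        # number of latent intents
#           "context_hops": 3,     # propagation depth over the CKG
#           "ind": "distance",     # independence loss: "mi", "distance", or "cosine"
#       },
#   )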