Source code for recbole.model.knowledge_aware_recommender.kgnnls
# -*- coding: utf-8 -*-# @Time : 2020/10/3# @Author : Changxin Tian# @Email : cx.tian@outlook.comr"""KGNNLS################################################Reference: Hongwei Wang et al. "Knowledge-aware Graph Neural Networks with Label Smoothness Regularization for Recommender Systems." in KDD 2019.Reference code: https://github.com/hwwang55/KGNN-LS"""importrandomimportnumpyasnpimporttorchimporttorch.nnasnnfromrecbole.model.abstract_recommenderimportKnowledgeRecommenderfromrecbole.model.initimportxavier_normal_initializationfromrecbole.model.lossimportEmbLossfromrecbole.utilsimportInputType
[docs]classKGNNLS(KnowledgeRecommender):r"""KGNN-LS is a knowledge-based recommendation model. KGNN-LS transforms the knowledge graph into a user-specific weighted graph and then apply a graph neural network to compute personalized item embeddings. To provide better inductive bias, KGNN-LS relies on label smoothness assumption, which posits that adjacent items in the knowledge graph are likely to have similar user relevance labels/scores. Label smoothness provides regularization over the edge weights and it is equivalent to a label propagation scheme on a graph. """input_type=InputType.PAIRWISEdef__init__(self,config,dataset):super(KGNNLS,self).__init__(config,dataset)# load parameters infoself.embedding_size=config['embedding_size']self.neighbor_sample_size=config['neighbor_sample_size']self.aggregator_class=config['aggregator']# which aggregator to use# number of iterations when computing entity representationself.n_iter=config['n_iter']self.reg_weight=config['reg_weight']# weight of l2 regularization# weight of label Smoothness regularizationself.ls_weight=config['ls_weight']# define embeddingself.user_embedding=nn.Embedding(self.n_users,self.embedding_size)self.entity_embedding=nn.Embedding(self.n_entities,self.embedding_size)self.relation_embedding=nn.Embedding(self.n_relations+1,self.embedding_size)# sample neighbors and construct interaction tablekg_graph=dataset.kg_graph(form='coo',value_field='relation_id')adj_entity,adj_relation=self.construct_adj(kg_graph)self.adj_entity,self.adj_relation=adj_entity.to(self.device),adj_relation.to(self.device)inter_feat=dataset.inter_featpos_users=inter_feat[dataset.uid_field]pos_items=inter_feat[dataset.iid_field]pos_label=torch.ones(pos_items.shape)pos_interaction_table,self.offset=self.get_interaction_table(pos_users,pos_items,pos_label)self.interaction_table=self.sample_neg_interaction(pos_interaction_table,self.offset)# define functionself.softmax=nn.Softmax(dim=-1)self.linear_layers=torch.nn.ModuleList()foriinrange(self.n_iter):self.linear_layers.append(nn.Linear(self.embedding_sizeifnotself.aggregator_class=='concat'elseself.embedding_size*2,self.embedding_size))self.ReLU=nn.ReLU()self.Tanh=nn.Tanh()self.bce_loss=nn.BCEWithLogitsLoss()self.l2_loss=EmbLoss()# parameters initializationself.apply(xavier_normal_initialization)self.other_parameter_name=['adj_entity','adj_relation']
[docs]defget_interaction_table(self,user_id,item_id,y):r"""Get interaction_table that is used for fetching user-item interaction label in LS regularization. Args: user_id(torch.Tensor): the user id in user-item interactions, shape: [n_interactions, 1] item_id(torch.Tensor): the item id in user-item interactions, shape: [n_interactions, 1] y(torch.Tensor): the label in user-item interactions, shape: [n_interactions, 1] Returns: tuple: - interaction_table(dict): key: user_id * 10^offset + item_id; value: y_{user_id, item_id} - offset(int): The offset that is used for calculating the key(index) in interaction_table """offset=len(str(self.n_entities))offset=10**offsetkeys=user_id*offset+item_idkeys=keys.int().cpu().numpy().tolist()values=y.float().cpu().numpy().tolist()interaction_table=dict(zip(keys,values))returninteraction_table,offset
[docs]defsample_neg_interaction(self,pos_interaction_table,offset):r"""Sample neg_interaction to construct train data. Args: pos_interaction_table(dict): the interaction_table that only contains pos_interaction. offset(int): The offset that is used for calculating the key(index) in interaction_table Returns: interaction_table(dict): key: user_id * 10^offset + item_id; value: y_{user_id, item_id} """pos_num=len(pos_interaction_table)neg_num=0neg_interaction_table={}whileneg_num<pos_num:user_id=random.randint(0,self.n_users)item_id=random.randint(0,self.n_items)keys=user_id*offset+item_idifkeysnotinpos_interaction_table:neg_interaction_table[keys]=0.neg_num+=1interaction_table={**pos_interaction_table,**neg_interaction_table}returninteraction_table
[docs]defconstruct_adj(self,kg_graph):r"""Get neighbors and corresponding relations for each entity in the KG. Args: kg_graph(scipy.sparse.coo_matrix): an undirected graph Returns: tuple: - adj_entity (torch.LongTensor): each line stores the sampled neighbor entities for a given entity, shape: [n_entities, neighbor_sample_size] - adj_relation (torch.LongTensor): each line stores the corresponding sampled neighbor relations, shape: [n_entities, neighbor_sample_size] """# self.logger.info('constructing knowledge graph ...')# treat the KG as an undirected graphkg_dict=dict()fortripleinzip(kg_graph.row,kg_graph.data,kg_graph.col):head=triple[0]relation=triple[1]tail=triple[2]ifheadnotinkg_dict:kg_dict[head]=[]kg_dict[head].append((tail,relation))iftailnotinkg_dict:kg_dict[tail]=[]kg_dict[tail].append((head,relation))# self.logger.info('constructing adjacency matrix ...')# each line of adj_entity stores the sampled neighbor entities for a given entity# each line of adj_relation stores the corresponding sampled neighbor relationsentity_num=kg_graph.shape[0]adj_entity=np.zeros([entity_num,self.neighbor_sample_size],dtype=np.int64)adj_relation=np.zeros([entity_num,self.neighbor_sample_size],dtype=np.int64)forentityinrange(entity_num):ifentitynotinkg_dict.keys():adj_entity[entity]=np.array([entity]*self.neighbor_sample_size)adj_relation[entity]=np.array([0]*self.neighbor_sample_size)continueneighbors=kg_dict[entity]n_neighbors=len(neighbors)ifn_neighbors>=self.neighbor_sample_size:sampled_indices=np.random.choice(list(range(n_neighbors)),size=self.neighbor_sample_size,replace=False)else:sampled_indices=np.random.choice(list(range(n_neighbors)),size=self.neighbor_sample_size,replace=True)adj_entity[entity]=np.array([neighbors[i][0]foriinsampled_indices])adj_relation[entity]=np.array([neighbors[i][1]foriinsampled_indices])returntorch.from_numpy(adj_entity),torch.from_numpy(adj_relation)
[docs]defget_neighbors(self,items):r"""Get neighbors and corresponding relations for each entity in items from adj_entity and adj_relation. Args: items(torch.LongTensor): The input tensor that contains item's id, shape: [batch_size, ] Returns: tuple: - entities(list): Entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items. dimensions of entities: {[batch_size, 1], [batch_size, n_neighbor], [batch_size, n_neighbor^2], ..., [batch_size, n_neighbor^n_iter]} - relations(list): Relations is a list of i-iter (i = 0, 1, ..., n_iter) corresponding relations for entities. Relations have the same shape as entities. """items=torch.unsqueeze(items,dim=1)entities=[items]relations=[]foriinrange(self.n_iter):index=torch.flatten(entities[i])neighbor_entities=torch.index_select(self.adj_entity,0,index).reshape(self.batch_size,-1)neighbor_relations=torch.index_select(self.adj_relation,0,index).reshape(self.batch_size,-1)entities.append(neighbor_entities)relations.append(neighbor_relations)returnentities,relations
[docs]defaggregate(self,user_embeddings,entities,relations):r"""For each item, aggregate the entity representation and its neighborhood representation into a single vector. Args: user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size, embedding_size] entities(list): entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items. dimensions of entities: {[batch_size, 1], [batch_size, n_neighbor], [batch_size, n_neighbor^2], ..., [batch_size, n_neighbor^n_iter]} relations(list): relations is a list of i-iter (i = 0, 1, ..., n_iter) corresponding relations for entities. relations have the same shape as entities. Returns: item_embeddings(torch.FloatTensor): The embeddings of items, shape: [batch_size, embedding_size] """entity_vectors=[self.entity_embedding(i)foriinentities]relation_vectors=[self.relation_embedding(i)foriinrelations]foriinrange(self.n_iter):entity_vectors_next_iter=[]forhopinrange(self.n_iter-i):shape=(self.batch_size,-1,self.neighbor_sample_size,self.embedding_size)self_vectors=entity_vectors[hop]neighbor_vectors=entity_vectors[hop+1].reshape(shape)neighbor_relations=relation_vectors[hop].reshape(shape)# mix_neighbor_vectorsuser_embeddings=user_embeddings.reshape(self.batch_size,1,1,self.embedding_size)# [batch_size, 1, 1, dim]user_relation_scores=torch.mean(user_embeddings*neighbor_relations,dim=-1)# [batch_size, -1, n_neighbor]user_relation_scores_normalized=torch.unsqueeze(self.softmax(user_relation_scores),dim=-1)# [batch_size, -1, n_neighbor, 1]neighbors_agg=torch.mean(user_relation_scores_normalized*neighbor_vectors,dim=2)# [batch_size, -1, dim]ifself.aggregator_class=='sum':output=(self_vectors+neighbors_agg).reshape(-1,self.embedding_size)# [-1, dim]elifself.aggregator_class=='neighbor':output=neighbors_agg.reshape(-1,self.embedding_size)# [-1, dim]elifself.aggregator_class=='concat':# [batch_size, -1, dim * 2]output=torch.cat([self_vectors,neighbors_agg],dim=-1)output=output.reshape(-1,self.embedding_size*2)# [-1, dim * 2]else:raiseException("Unknown aggregator: "+self.aggregator_class)output=self.linear_layers[i](output)# [batch_size, -1, dim]output=output.reshape(self.batch_size,-1,self.embedding_size)ifi==self.n_iter-1:vector=self.Tanh(output)else:vector=self.ReLU(output)entity_vectors_next_iter.append(vector)entity_vectors=entity_vectors_next_iterres=entity_vectors[0].reshape(self.batch_size,self.embedding_size)returnres
[docs]deflabel_smoothness_predict(self,user_embeddings,user,entities,relations):r"""Predict the label of items by label smoothness. Args: user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size*2, embedding_size], user(torch.FloatTensor): the index of users, shape: [batch_size*2] entities(list): entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items. dimensions of entities: {[batch_size*2, 1], [batch_size*2, n_neighbor], [batch_size*2, n_neighbor^2], ..., [batch_size*2, n_neighbor^n_iter]} relations(list): relations is a list of i-iter (i = 0, 1, ..., n_iter) corresponding relations for entities. relations have the same shape as entities. Returns: predicted_labels(torch.FloatTensor): The predicted label of items, shape: [batch_size*2] """# calculate initial labels; calculate updating masks for label propagationentity_labels=[]# True means the label of this item is reset to initial value during label propagationreset_masks=[]holdout_item_for_user=Noneforentities_per_iterinentities:users=torch.unsqueeze(user,dim=1)# [batch_size, 1]user_entity_concat=users*self.offset+entities_per_iter# [batch_size, n_neighbor^i]# the first one in entities is the items to be held outifholdout_item_for_userisNone:holdout_item_for_user=user_entity_concatdeflookup_interaction_table(x,_):x=int(x)label=self.interaction_table.setdefault(x,0.5)returnlabelinitial_label=user_entity_concat.clone().cpu().double()initial_label.map_(initial_label,lookup_interaction_table)initial_label=initial_label.float().to(self.device)# False if the item is held outholdout_mask=(holdout_item_for_user-user_entity_concat).bool()# True if the entity is a labeled itemreset_mask=(initial_label-0.5).bool()reset_mask=torch.logical_and(reset_mask,holdout_mask)# remove held-out itemsinitial_label=holdout_mask.float()*initial_label+ \
torch.logical_not(holdout_mask).float()*0.5# label initializationreset_masks.append(reset_mask)entity_labels.append(initial_label)# we do not need the reset_mask for the last iterationreset_masks=reset_masks[:-1]# label propagationrelation_vectors=[self.relation_embedding(i)foriinrelations]foriinrange(self.n_iter):entity_labels_next_iter=[]forhopinrange(self.n_iter-i):masks=reset_masks[hop]self_labels=entity_labels[hop]neighbor_labels=entity_labels[hop+1].reshape(self.batch_size,-1,self.neighbor_sample_size)neighbor_relations=relation_vectors[hop].reshape(self.batch_size,-1,self.neighbor_sample_size,self.embedding_size)# mix_neighbor_labelsuser_embeddings=user_embeddings.reshape(self.batch_size,1,1,self.embedding_size)# [batch_size, 1, 1, dim]user_relation_scores=torch.mean(user_embeddings*neighbor_relations,dim=-1)# [batch_size, -1, n_neighbor]user_relation_scores_normalized=self.softmax(user_relation_scores)# [batch_size, -1, n_neighbor]neighbors_aggregated_label=torch.mean(user_relation_scores_normalized*neighbor_labels,dim=2)# [batch_size, -1, dim] # [batch_size, -1]output=masks.float()*self_labels+ \
torch.logical_not(masks).float()*neighbors_aggregated_labelentity_labels_next_iter.append(output)entity_labels=entity_labels_next_iterpredicted_labels=entity_labels[0].squeeze(-1)returnpredicted_labels
[docs]defforward(self,user,item):self.batch_size=item.shape[0]# [batch_size, dim]user_e=self.user_embedding(user)# entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items. dimensions of entities:# {[batch_size, 1], [batch_size, n_neighbor], [batch_size, n_neighbor^2], ..., [batch_size, n_neighbor^n_iter]}entities,relations=self.get_neighbors(item)# [batch_size, dim]item_e=self.aggregate(user_e,entities,relations)returnuser_e,item_e
[docs]defcalculate_ls_loss(self,user,item,target):r"""Calculate label smoothness loss. Args: user(torch.FloatTensor): the index of users, shape: [batch_size*2], item(torch.FloatTensor): the index of items, shape: [batch_size*2], target(torch.FloatTensor): the label of user-item, shape: [batch_size*2], Returns: ls_loss: label smoothness loss """user_e=self.user_embedding(user)entities,relations=self.get_neighbors(item)predicted_labels=self.label_smoothness_predict(user_e,user,entities,relations)ls_loss=self.bce_loss(predicted_labels,target)returnls_loss