Source code for recbole.model.knowledge_aware_recommender.ktup
# @Time   : 2020/8/6
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

r"""
KTUP
##################################################
Reference:
    Yixin Cao et al. "Unifying Knowledge Graph Learning and Recommendation:
    Towards a Better Understanding of User Preferences." in WWW 2019.

Reference code:
    https://github.com/TaoMiner/joint-kg-recommender
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbMarginLoss
from recbole.utils import InputType
class KTUP(KnowledgeRecommender):
    r"""KTUP is a knowledge-based recommendation model. It adopts the strategy of
    multi-task learning to jointly learn recommendation and KG-related tasks, with the
    goal of understanding the reasons that a user interacts with an item. This method
    utilizes an attention mechanism to combine all preferences into a single-vector
    representation.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(KTUP, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config["embedding_size"]
        self.L1_flag = config["L1_flag"]
        self.use_st_gumbel = config["use_st_gumbel"]
        self.kg_weight = config["kg_weight"]
        self.align_weight = config["align_weight"]
        self.margin = config["margin"]

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.pref_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.pref_norm_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.relation_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.relation_norm_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.rec_loss = BPRLoss()
        self.kg_loss = nn.MarginRankingLoss(margin=self.margin)
        self.reg_loss = EmbMarginLoss()

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        normalize_user_emb = F.normalize(self.user_embedding.weight.data, p=2, dim=1)
        normalize_item_emb = F.normalize(self.item_embedding.weight.data, p=2, dim=1)
        normalize_pref_emb = F.normalize(self.pref_embedding.weight.data, p=2, dim=1)
        normalize_pref_norm_emb = F.normalize(self.pref_norm_embedding.weight.data, p=2, dim=1)
        normalize_entity_emb = F.normalize(self.entity_embedding.weight.data, p=2, dim=1)
        normalize_rel_emb = F.normalize(self.relation_embedding.weight.data, p=2, dim=1)
        normalize_rel_norm_emb = F.normalize(self.relation_norm_embedding.weight.data, p=2, dim=1)
        self.user_embedding.weight.data = normalize_user_emb
        self.item_embedding.weight.data = normalize_item_emb
        self.pref_embedding.weight.data = normalize_pref_emb
        self.pref_norm_embedding.weight.data = normalize_pref_norm_emb
        self.entity_embedding.weight.data = normalize_entity_emb
        self.relation_embedding.weight.data = normalize_rel_emb
        self.relation_norm_embedding.weight.data = normalize_rel_norm_emb

    def _masked_softmax(self, logits):
        # a plain softmax over the last dimension (no masking is applied here)
        probs = F.softmax(logits, dim=len(logits.shape) - 1)
        return probs
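    # A minimal usage sketch (illustrative, not part of the original file; it
    # assumes RecBole's quick-start entry point and the bundled ml-100k dataset,
    # and the config keys mirror the ones read in __init__ above):
    #
    #   from recbole.quick_start import run_recbole
    #   run_recbole(model="KTUP", dataset="ml-100k",
    #               config_dict={"embedding_size": 64, "L1_flag": False,
    #                            "use_st_gumbel": True, "kg_weight": 1.0,
    #                            "align_weight": 1.0, "margin": 1.0})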
    def convert_to_one_hot(self, indices, num_classes):
        r"""
        Args:
            indices (Variable): A vector containing indices,
                whose size is (batch_size,).
            num_classes (Variable): The number of classes, which would be
                the second dimension of the resulting one-hot matrix.

        Returns:
            torch.Tensor: The one-hot matrix of size (batch_size, num_classes).
        """
        old_shape = indices.shape
        new_shape = torch.Size([i for i in old_shape] + [num_classes])
        indices = indices.unsqueeze(len(old_shape))
        one_hot = Variable(
            indices.data.new(new_shape).zero_().scatter_(len(old_shape), indices.data, 1)
        )
        return one_hot
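    # A worked example of convert_to_one_hot (illustrative values, not from the
    # original file): each index selects the column set to one in its row.
    #
    #   >>> indices = torch.tensor([0, 2])
    #   >>> model.convert_to_one_hot(indices, num_classes=3)
    #   tensor([[1, 0, 0],
    #           [0, 0, 1]])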
    def st_gumbel_softmax(self, logits, temperature=1.0):
        r"""Return the result of Straight-Through Gumbel-Softmax Estimation. It approximates
        the discrete sampling via the Gumbel-Softmax trick and applies the biased ST estimator.
        In the forward propagation, it emits the discrete one-hot result, and in the backward
        propagation it approximates the categorical distribution via the smooth Gumbel-Softmax
        distribution.

        Args:
            logits (Variable): Un-normalized probability values,
                of size (batch_size, num_classes).
            temperature (float): A temperature parameter. The higher
                the value is, the smoother the distribution is.

        Returns:
            torch.Tensor: The sampled output, which has the property explained above.
        """
        eps = 1e-20
        u = logits.data.new(*logits.size()).uniform_()
        gumbel_noise = Variable(-torch.log(-torch.log(u + eps) + eps))
        y = logits + gumbel_noise
        y = self._masked_softmax(logits=y / temperature)
        y_argmax = y.max(len(y.shape) - 1)[1]
        y_hard = self.convert_to_one_hot(
            indices=y_argmax, num_classes=y.size(len(y.shape) - 1)
        ).float()
        # straight-through estimator: the forward pass emits the one-hot y_hard,
        # while gradients flow through the soft y in the backward pass
        y = (y_hard - y).detach() + y
        return y
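    # Note (not from the original file): recent PyTorch versions ship an
    # equivalent built-in; assuming such a version, the body above roughly
    # corresponds to
    #
    #   y = F.gumbel_softmax(logits, tau=temperature, hard=True)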
    def _get_preferences(self, user_e, item_e, use_st_gumbel=False):
        # induce a preference distribution from the user-item pair; preference
        # and relation embeddings are averaged, following the KTUP paper
        pref_probs = (
            torch.matmul(
                user_e + item_e,
                torch.t(self.pref_embedding.weight + self.relation_embedding.weight),
            )
            / 2
        )
        if use_st_gumbel:
            # todo: different torch versions may cause st_gumbel_softmax to report errors; still to be tested
            pref_probs = self.st_gumbel_softmax(pref_probs)
        relation_e = (
            torch.matmul(
                pref_probs, self.pref_embedding.weight + self.relation_embedding.weight
            )
            / 2
        )
        norm_e = (
            torch.matmul(
                pref_probs,
                self.pref_norm_embedding.weight + self.relation_norm_embedding.weight,
            )
            / 2
        )
        return pref_probs, relation_e, norm_e

    @staticmethod
    def _transH_projection(original, norm):
        # project onto the relation-specific hyperplane: e - (e . w) * w
        return (
            original
            - torch.sum(original * norm, dim=len(original.size()) - 1, keepdim=True)
            * norm
        )

    def _get_score(self, h_e, r_e, t_e):
        # translational score -||h + r - t|| with an L1 or squared-L2 norm
        if self.L1_flag:
            score = -torch.sum(torch.abs(h_e + r_e - t_e), 1)
        else:
            score = -torch.sum((h_e + r_e - t_e) ** 2, 1)
        return score
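    # Sanity-check sketch for _transH_projection (illustrative, not from the
    # original file; assumes unit-norm hyperplane normals, which the
    # normalization in __init__ encourages):
    #
    #   e = torch.randn(4, 64)
    #   w = F.normalize(torch.randn(4, 64), p=2, dim=1)
    #   proj = KTUP._transH_projection(e, w)
    #   # each projected row is orthogonal to its hyperplane normal
    #   assert torch.allclose((proj * w).sum(dim=1), torch.zeros(4), atol=1e-5)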
    def calculate_kg_loss(self, interaction):
        r"""Calculate the training loss for a batch data of KG.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        h = interaction[self.HEAD_ENTITY_ID]
        r = interaction[self.RELATION_ID]
        pos_t = interaction[self.TAIL_ENTITY_ID]
        neg_t = interaction[self.NEG_TAIL_ENTITY_ID]

        h_e = self.entity_embedding(h)
        pos_t_e = self.entity_embedding(pos_t)
        neg_t_e = self.entity_embedding(neg_t)
        r_e = self.relation_embedding(r)
        norm_e = self.relation_norm_embedding(r)

        # project head and tails onto the relation hyperplane (TransH)
        proj_h_e = self._transH_projection(h_e, norm_e)
        proj_pos_t_e = self._transH_projection(pos_t_e, norm_e)
        proj_neg_t_e = self._transH_projection(neg_t_e, norm_e)

        pos_tail_score = self._get_score(proj_h_e, r_e, proj_pos_t_e)
        neg_tail_score = self._get_score(proj_h_e, r_e, proj_neg_t_e)

        kg_loss = self.kg_loss(
            pos_tail_score, neg_tail_score, torch.ones(h.size(0)).to(self.device)
        )
        orthogonal_loss = orthogonalLoss(r_e, norm_e)
        reg_loss = self.reg_loss(h_e, pos_t_e, neg_t_e, r_e)
        loss = self.kg_weight * (kg_loss + orthogonal_loss + reg_loss)

        # align item embeddings with the embeddings of their corresponding entities
        entity = torch.cat([h, pos_t, neg_t])
        entity = entity[entity < self.n_items]
        align_loss = self.align_weight * alignLoss(
            self.item_embedding(entity), self.entity_embedding(entity), self.L1_flag
        )

        return loss, align_loss
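# The module-level helpers orthogonalLoss and alignLoss used above are defined
# elsewhere in this file. A sketch consistent with the TransH and KTUP
# objectives (the exact bodies here are an assumption):


def orthogonalLoss(rel_embeddings, norm_embeddings):
    # penalize relation vectors that stray off their hyperplane:
    # (w . r)^2 / ||r||^2 should be close to zero
    return torch.sum(
        torch.sum(norm_embeddings * rel_embeddings, dim=1, keepdim=True) ** 2
        / torch.sum(rel_embeddings ** 2, dim=1, keepdim=True)
    )


def alignLoss(emb1, emb2, L1_flag=False):
    # distance between item and entity embeddings, matching _get_score's norm choice
    if L1_flag:
        distance = torch.sum(torch.abs(emb1 - emb2), 1)
    else:
        distance = torch.sum((emb1 - emb2) ** 2, 1)
    return distance.mean()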