Source code for recbole.model.general_recommender.lightgcn
# -*- coding: utf-8 -*-# @Time : 2020/8/31# @Author : Changxin Tian# @Email : cx.tian@outlook.com# UPDATE:# @Time : 2020/9/16, 2021/12/22# @Author : Shanlei Mu, Gaowei Zhang# @Email : slmu@ruc.edu.cn, 1462034631@qq.comr"""LightGCN################################################Reference: Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020.Reference code: https://github.com/kuandeng/LightGCN"""importnumpyasnpimportscipy.sparseasspimporttorchfromrecbole.model.abstract_recommenderimportGeneralRecommenderfromrecbole.model.initimportxavier_uniform_initializationfromrecbole.model.lossimportBPRLoss,EmbLossfromrecbole.utilsimportInputType
class LightGCN(GeneralRecommender):
    r"""LightGCN is a GCN-based recommender model.

    LightGCN includes only the most essential component in GCN — neighborhood
    aggregation — for collaborative filtering. Specifically, LightGCN learns
    user and item embeddings by linearly propagating them on the user-item
    interaction graph, and uses the weighted sum of the embeddings learned at
    all layers as the final embedding.

    We implement the model following the original author with a pairwise
    training mode.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        """Build the embeddings, losses and normalized adjacency matrix.

        Args:
            config: Mapping-like configuration; the keys ``embedding_size``,
                ``n_layers``, ``reg_weight`` and ``require_pow`` are read.
            dataset: Dataset object; ``dataset.inter_matrix(form='coo')`` is
                used to obtain the user-item interaction matrix.
        """
        super().__init__(config, dataset)

        # load dataset info
        self.interaction_matrix = dataset.inter_matrix(form='coo').astype(np.float32)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of lightGCN
        self.n_layers = config['n_layers']  # int type: the layer num of lightGCN
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']

        # define layers and loss
        # NOTE(review): n_users / n_items / device are not assigned in this
        # class; presumably the base-class constructor sets them — confirm.
        self.user_embedding = torch.nn.Embedding(
            num_embeddings=self.n_users, embedding_dim=self.latent_dim
        )
        self.item_embedding = torch.nn.Embedding(
            num_embeddings=self.n_items, embedding_dim=self.latent_dim
        )
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # generate intermediate data
        self.norm_adj_matrix = self.get_norm_adj_mat().to(self.device)

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_norm_adj_mat(self):
        r"""Get the normalized interaction matrix of users and items.

        Construct the square (n_users + n_items) adjacency matrix from the
        training data and normalize it symmetrically by node degree.

        .. math::
            \hat{A} = D^{-0.5} \times A \times D^{-0.5}

        Returns:
            Sparse float32 COO tensor of the normalized interaction matrix.
        """
        n_nodes = self.n_users + self.n_items
        inter = self.interaction_matrix

        # User nodes keep their ids; item nodes are shifted past the users.
        user_nodes = np.asarray(inter.row, dtype=np.int64)
        item_nodes = np.asarray(inter.col, dtype=np.int64) + self.n_users

        # Symmetric bipartite graph: every interaction yields the edges
        # (user -> item) and (item -> user).
        rows = np.concatenate([user_nodes, item_nodes])
        cols = np.concatenate([item_nodes, user_nodes])

        # Build the matrix directly from index arrays rather than through the
        # private ``dok_matrix._update`` API and a Python dict of tuples.
        adj = sp.coo_matrix(
            (np.ones(rows.shape[0], dtype=np.float32), (rows, cols)),
            shape=(n_nodes, n_nodes),
        ).tocsr()
        # The COO -> CSR conversion sums duplicate entries; reset every stored
        # value to 1 so repeated interactions still give a binary edge.
        adj.data[:] = 1.0

        # Degree = number of neighbours per node. Add epsilon so isolated
        # nodes do not trigger a divide-by-zero warning.
        degree = adj.getnnz(axis=1).astype(np.float64) + 1e-7
        d_inv_sqrt = np.power(degree, -0.5)

        # D^{-0.5} * A * D^{-0.5}, applied entry-wise to the stored edges.
        adj = adj.tocoo()
        values = d_inv_sqrt[adj.row] * adj.data * d_inv_sqrt[adj.col]

        # convert norm_adj matrix to tensor
        indices = torch.from_numpy(np.vstack([adj.row, adj.col]).astype(np.int64))
        values = torch.from_numpy(values.astype(np.float32))
        return torch.sparse_coo_tensor(indices, values, torch.Size(adj.shape))

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.

        Returns:
            Tensor of the embedding matrix, user rows first and item rows
            after them. Shape of [n_users+n_items, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        return torch.cat([user_embeddings, item_embeddings], dim=0)

    def full_sort_predict(self, interaction):
        """Score every item for the users in ``interaction``.

        The user and item embeddings returned by ``forward`` are cached in
        ``restore_user_e`` / ``restore_item_e`` and reused on later calls.

        Args:
            interaction: Indexable by ``self.USER_ID``; yields the user
                indices to score.

        Returns:
            The user-by-item score matrix flattened to a 1-D tensor.
        """
        user = interaction[self.USER_ID]
        # NOTE(review): ``forward`` is not defined in the visible part of this
        # class; it is expected to return (user_embeddings, item_embeddings)
        # — confirm against the full file.
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)