def _init_weights(self, module):
    """Initialize the weights of *module* in place.

    Linear/Embedding weights are drawn from N(0, initializer_range**2);
    LayerNorm is reset to bias = 0 and weight = 1; any Linear bias is
    zeroed. Intended to be passed to ``nn.Module.apply``.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.initializer_range)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    # Run for Linear regardless of the branch above, so a freshly
    # re-initialized Linear also gets its bias cleared.
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
class CORE(SequentialRecommender):
    r"""CORE is a simple and effective framework, which unifies the
    representation space for both the encoding and decoding processes
    in session-based recommendation.
    """

    def __init__(self, config, dataset):
        super(CORE, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.loss_type = config['loss_type']
        self.dnn_type = config['dnn_type']
        self.sess_dropout = nn.Dropout(config['sess_dropout'])
        self.item_dropout = nn.Dropout(config['item_dropout'])
        # Scales the logits before the CE loss / during prediction.
        self.temperature = config['temperature']

        # item embedding; index 0 is the padding item
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)

        # DNN encoder: a Transformer ('trm') or simple averaging ('ave')
        if self.dnn_type == 'trm':
            self.net = TransNet(config, dataset)
        elif self.dnn_type == 'ave':
            self.net = self.ave_net
        else:
            raise ValueError(f'dnn_type should be either trm or ave, but have [{self.dnn_type}].')

        if self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        """Uniformly initialize every parameter in [-1/sqrt(d), 1/sqrt(d)]."""
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def full_sort_predict(self, interaction):
        """Score every item for each session in *interaction*.

        Returns a (batch, n_items) tensor of cosine similarities between
        the session representation and all item embeddings, divided by
        the temperature. No dropout is applied at evaluation time.
        """
        item_seq = interaction[self.ITEM_SEQ]
        seq_output = self.forward(item_seq)
        test_item_emb = self.item_embedding.weight
        # no dropout for evaluation; normalize so the dot product below
        # is a cosine similarity (forward() yields a normalized output —
        # TODO confirm against the forward implementation)
        test_item_emb = F.normalize(test_item_emb, dim=-1)
        scores = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) / self.temperature
        return scores