Source code for recbole.model.sequential_recommender.bert4rec
# -*- coding: utf-8 -*-
# @Time   : 2020/9/18 12:08
# @Author : Hui Wang
# @Email  : hui.wang@ruc.edu.cn

# UPDATE
# @Time   : 2023/9/4
# @Author : Enze Liu
# @Email  : enzeeliu@foxmail.com

r"""
BERT4Rec
################################################

Reference:
    Fei Sun et al. "BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer." In CIKM 2019.

Reference code:
    The authors' tensorflow implementation https://github.com/FeiSun/BERT4Rec

"""

import random

import torch
from torch import nn

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder
class BERT4Rec(SequentialRecommender):
    def __init__(self, config, dataset):
        super(BERT4Rec, self).__init__(config, dataset)

        # load parameters info (an illustrative config for these keys is sketched after the class)
        self.n_layers = config["n_layers"]
        self.n_heads = config["n_heads"]
        self.hidden_size = config["hidden_size"]  # same as embedding_size
        self.inner_size = config["inner_size"]  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config["hidden_dropout_prob"]
        self.attn_dropout_prob = config["attn_dropout_prob"]
        self.hidden_act = config["hidden_act"]
        self.layer_norm_eps = config["layer_norm_eps"]

        self.mask_ratio = config["mask_ratio"]

        self.MASK_ITEM_SEQ = config["MASK_ITEM_SEQ"]
        self.POS_ITEMS = config["POS_ITEMS"]
        self.NEG_ITEMS = config["NEG_ITEMS"]
        self.MASK_INDEX = config["MASK_INDEX"]

        self.loss_type = config["loss_type"]
        self.initializer_range = config["initializer_range"]

        # load dataset info
        self.mask_token = self.n_items
        self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

        # define layers and loss
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.hidden_size, padding_idx=0
        )  # one extra embedding row for the mask token
        self.position_embedding = nn.Embedding(
            self.max_seq_length, self.hidden_size
        )  # the mask token is appended at the last position
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps,
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        self.output_ffn = nn.Linear(self.hidden_size, self.hidden_size)
        self.output_gelu = nn.GELU()
        self.output_ln = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.output_bias = nn.Parameter(torch.zeros(self.n_items))

        # we only need to compute the loss at the masked positions
        try:
            assert self.loss_type in ["BPR", "CE"]
        except AssertionError:
            raise AssertionError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    def reconstruct_test_data(self, item_seq, item_seq_len):
        """
        Add a mask token at the last position according to the lengths of item_seq.
        """
        # (a toy walkthrough of this shift-and-mask step is sketched after the class)
        padding = torch.zeros(item_seq.size(0), dtype=torch.long, device=item_seq.device)  # [B]
        item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)  # [B max_len+1]
        for batch_id, last_position in enumerate(item_seq_len):
            item_seq[batch_id][last_position] = self.mask_token
        item_seq = item_seq[:, 1:]  # drop the first slot so the sequence keeps length max_len
        return item_seq
    def forward(self, item_seq):
        position_ids = torch.arange(
            item_seq.size(1), dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        # bidirectional (cloze-style) attention: every position may attend to both sides
        extended_attention_mask = self.get_attention_mask(item_seq, bidirectional=True)

        trm_output = self.trm_encoder(
            input_emb, extended_attention_mask, output_all_encoded_layers=True
        )
        ffn_output = self.output_ffn(trm_output[-1])
        ffn_output = self.output_gelu(ffn_output)
        output = self.output_ln(ffn_output)
        # (how these hidden states are scored against the item set is sketched after the class)
        return output  # [B L H]
    def multi_hot_embed(self, masked_index, max_length):
        """
        To save memory, the loss is only computed at the masked positions.
        Generate a multi-hot vector that indicates the masked positions of the masked
        sequence; it is then used to gather the hidden representations at those positions.

        Examples:
            sequence: [1 2 3 4 5]
            masked_sequence: [1 mask 3 mask 5]
            masked_index: [1, 3]
            max_length: 5

            multi_hot_embed: [[0 1 0 0 0], [0 0 0 1 0]]
        """
        # (a runnable version of this example, including the gather step, follows the class)
        masked_index = masked_index.view(-1)
        multi_hot = torch.zeros(
            masked_index.size(0), max_length, device=masked_index.device
        )
        multi_hot[torch.arange(masked_index.size(0)), masked_index] = 1
        return multi_hot
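
The hyperparameters read in __init__ map directly onto RecBole config keys. Below is a minimal sketch of running the model through RecBole's quick-start entry point with an explicit config_dict; the numeric values and the ml-100k dataset name are illustrative assumptions, not prescribed defaults, and the field-name keys (MASK_ITEM_SEQ, POS_ITEMS, NEG_ITEMS, MASK_INDEX) are left to the model's default configuration.

# Minimal sketch: values are illustrative assumptions, not the shipped defaults;
# only keys read by BERT4Rec.__init__ above are listed.
from recbole.quick_start import run_recbole

config_dict = {
    "n_layers": 2,
    "n_heads": 2,
    "hidden_size": 64,           # same as embedding_size
    "inner_size": 256,           # feed-forward dimensionality
    "hidden_dropout_prob": 0.5,
    "attn_dropout_prob": 0.5,
    "hidden_act": "gelu",
    "layer_norm_eps": 1e-12,
    "initializer_range": 0.02,
    "mask_ratio": 0.2,           # fraction of max_seq_length reserved for mask tokens
    "loss_type": "CE",           # or "BPR"
}

run_recbole(model="BERT4Rec", dataset="ml-100k", config_dict=config_dict)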
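
A standalone toy walkthrough of reconstruct_test_data (the tensor values and mask_token = 10 are made up for illustration): the mask token is written into the slot just after the last real item, then every row is shifted left by one so the result keeps length max_seq_length, leaving the mask at index item_seq_len - 1.

import torch

mask_token = 10                                # assumed n_items for this toy example
item_seq = torch.tensor([[4, 7, 9, 0, 0]])     # [B=1, max_len=5], right-padded with 0
item_seq_len = torch.tensor([3])               # three real items

padding = torch.zeros(item_seq.size(0), dtype=torch.long)        # [B]
item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)  # [B, max_len+1]
for batch_id, last_position in enumerate(item_seq_len):
    item_seq[batch_id][last_position] = mask_token               # place mask after the last item
item_seq = item_seq[:, 1:]                                       # back to [B, max_len]

print(item_seq)  # tensor([[ 7,  9, 10,  0,  0]])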
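
At evaluation time the hidden state at that mask position is what ranks candidate items: in the model this corresponds to gather_indexes (inherited from SequentialRecommender) on forward()'s output, followed by a matmul against item_embedding.weight[: n_items] plus output_bias. A self-contained, pure-tensor sketch of that math under toy shape assumptions:

import torch

# Toy shapes; these tensors stand in for the model's real buffers and outputs.
B, L, H, n_items = 2, 5, 8, 10
seq_output = torch.randn(B, L, H)              # stand-in for forward()'s [B, L, H] output
item_seq_len = torch.tensor([3, 5])            # lengths before the mask was appended
item_embedding = torch.randn(n_items + 1, H)   # stand-in for item_embedding.weight
output_bias = torch.zeros(n_items)             # stand-in for self.output_bias

# gather the hidden state at the mask position (index item_seq_len - 1 in each row)
gather_index = (item_seq_len - 1).view(-1, 1, 1).expand(-1, -1, H)     # [B, 1, H]
mask_state = seq_output.gather(dim=1, index=gather_index).squeeze(1)   # [B, H]

test_items_emb = item_embedding[:n_items]      # exclude the extra mask-token row
scores = mask_state @ test_items_emb.T + output_bias                   # [B, n_items]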
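
Finally, a runnable version of the multi_hot_embed docstring example with toy shapes, followed by the batched-matmul gather this map is built for (collecting only the hidden states at the masked positions). The shapes and the bmm step are an illustration of the intended use, not a verbatim excerpt of the training code.

import torch

B, L, H = 1, 5, 4                          # toy batch size, sequence length, hidden size
masked_index = torch.tensor([[1, 3]])      # positions 1 and 3 are masked, as in the docstring
seq_output = torch.randn(B, L, H)          # stand-in for forward()'s [B, L, H] output

flat = masked_index.view(-1)
multi_hot = torch.zeros(flat.size(0), L)
multi_hot[torch.arange(flat.size(0)), flat] = 1                # [[0,1,0,0,0], [0,0,0,1,0]]

pred_index_map = multi_hot.view(B, masked_index.size(1), -1)   # [B, mask_len, L]
masked_states = torch.bmm(pred_index_map, seq_output)          # [B, mask_len, H]

assert torch.allclose(masked_states[0, 0], seq_output[0, 1])   # state at masked position 1
assert torch.allclose(masked_states[0, 1], seq_output[0, 3])   # state at masked position 3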