Source code for recbole.model.sequential_recommender.sasreccpr

# -*- coding: utf-8 -*-
# @Time    : 2020/9/18 11:33
# @Author  : Hui Wang
# @Email   : hui.wang@ruc.edu.cn

# UPDATE:
# @Time   : 2023/11/24
# @Author : Haw-Shiuan Chang
# @Email  : ken77921@gmail.com

"""
SASRec + Softmax-CPR
################################################

Reference:
    Wang-Cheng Kang et al. "Self-Attentive Sequential Recommendation." in ICDM 2018.
    Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders." in WSDM 2024.

Reference:
    https://github.com/kang205/SASRec
    https://arxiv.org/pdf/2310.14079.pdf

"""

import sys
import torch
from torch import nn
import torch.nn.functional as F
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder

# from recbole.model.loss import BPRLoss
import math


def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (
        1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
    )
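
# Note: gelu() above is the tanh approximation of GELU (Hendrycks & Gimpel, 2016).
# A minimal equivalence check, assuming a PyTorch version (1.12+) where F.gelu
# accepts approximate="tanh"; the check itself is illustrative, not repo code:
#
#     x = torch.randn(4)
#     assert torch.allclose(gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)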
class SASRecCPR(SequentialRecommender):
    r"""SASRec with a Softmax-CPR output layer. SASRec is the first sequential
    recommender based on the self-attention mechanism.

    NOTE:
        In the authors' implementation, the Point-Wise Feed-Forward Network (PFFN)
        is implemented by a CNN with a 1x1 kernel. Here, we follow the original
        BERT implementation and use fully connected layers to implement the PFFN.
    """

    def __init__(self, config, dataset):
        super(SASRecCPR, self).__init__(config, dataset)

        # load parameters info
        self.n_layers = config["n_layers"]
        self.n_heads = config["n_heads"]
        self.hidden_size = config["hidden_size"]  # same as embedding_size
        self.inner_size = config["inner_size"]  # the dimensionality of the feed-forward layer
        self.hidden_dropout_prob = config["hidden_dropout_prob"]
        self.attn_dropout_prob = config["attn_dropout_prob"]
        self.hidden_act = config["hidden_act"]
        self.layer_norm_eps = config["layer_norm_eps"]
        self.initializer_range = config["initializer_range"]
        self.loss_type = config["loss_type"]
        self.n_facet_all = config["n_facet_all"]  # added for mfs
        self.n_facet = config["n_facet"]  # added for mfs
        self.n_facet_window = config["n_facet_window"]  # added for mfs
        self.n_facet_hidden = min(
            config["n_facet_hidden"], config["n_layers"]
        )  # added for mfs
        self.n_facet_MLP = config["n_facet_MLP"]  # added for mfs
        self.n_facet_context = config["n_facet_context"]  # added for dynamic partitioning
        self.n_facet_reranker = config["n_facet_reranker"]  # added for dynamic partitioning
        self.n_facet_emb = config["n_facet_emb"]  # added for dynamic partitioning
        self.weight_mode = config["weight_mode"]  # added for mfs
        self.context_norm = config["context_norm"]  # added for mfs
        self.post_remove_context = config["post_remove_context"]  # added for mfs
        self.partition_merging_mode = config["partition_merging_mode"]  # added for mfs
        self.reranker_merging_mode = config["reranker_merging_mode"]  # added for mfs
        self.reranker_CAN_NUM = [
            int(x) for x in str(config["reranker_CAN_NUM"]).split(",")
        ]
        self.candidates_from_previous_reranker = True
        if self.weight_mode == "max_logits":
            self.n_facet_effective = 1
        else:
            self.n_facet_effective = self.n_facet

        assert (
            self.n_facet
            + self.n_facet_context
            + self.n_facet_reranker * len(self.reranker_CAN_NUM)
            + self.n_facet_emb
            == self.n_facet_all
        )
        assert self.n_facet_emb == 0 or self.n_facet_emb == 2
        assert self.n_facet_MLP <= 0  # -1 or 0
        assert self.n_facet_window <= 0
        self.n_facet_window = -self.n_facet_window
        self.n_facet_MLP = -self.n_facet_MLP
        self.softmax_nonlinear = "None"  # added for mfs
        self.use_proj_bias = config["use_proj_bias"]  # added for mfs

        hidden_state_input_ratio = 1 + self.n_facet_MLP  # 1 + 1
        self.MLP_linear = nn.Linear(
            self.hidden_size * (self.n_facet_hidden * (self.n_facet_window + 1)),
            self.hidden_size * self.n_facet_MLP,
        )  # (hid_dim*2) -> (hid_dim)
        total_lin_dim = self.hidden_size * hidden_state_input_ratio
        self.project_arr = nn.ModuleList(
            [
                nn.Linear(total_lin_dim, self.hidden_size, bias=self.use_proj_bias)
                for i in range(self.n_facet_all)
            ]
        )
        self.project_emb = nn.Linear(
            self.hidden_size, self.hidden_size, bias=self.use_proj_bias
        )
        if len(self.weight_mode) > 0:
            self.weight_facet_decoder = nn.Linear(
                self.hidden_size * hidden_state_input_ratio, self.n_facet_effective
            )
            self.weight_global = nn.Parameter(torch.ones(self.n_facet_effective))
        self.output_probs = True

        self.item_embedding = nn.Embedding(
            self.n_items, self.hidden_size, padding_idx=0
        )
        self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps,
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        if self.loss_type == "BPR":
            print("current softmax-cpr code does not support BPR loss")
            sys.exit(0)
        elif self.loss_type == "CE":
            self.loss_fct = nn.NLLLoss(reduction="none", ignore_index=0)  # modified for mfs
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)
        small_value = 0.0001
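
    # Layout of the n_facet_all projections in self.project_arr (see the assert
    # in __init__): indices [0, n_facet) are the global softmax facets, followed
    # by n_facet_reranker * len(reranker_CAN_NUM) reranker partitions, then
    # n_facet_context context partitions, and finally n_facet_emb (0 or 2)
    # projections for the context-dependent (pointer) item embeddings.
    # get_facet_emb(input_emb, i) simply applies the i-th projection.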
    def get_facet_emb(self, input_emb, i):
        return self.project_arr[i](input_emb)
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal
            # for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    def forward(self, item_seq, item_seq_len):
        position_ids = torch.arange(
            item_seq.size(1), dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)

        trm_output = self.trm_encoder(
            input_emb, extended_attention_mask, output_all_encoded_layers=True
        )
        return trm_output
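
    # Unlike vanilla SASRec, forward() returns the encoder outputs of *all*
    # layers (output_all_encoded_layers=True): calculate_loss_prob() below reads
    # the top n_facet_hidden of them to build the multi-input hidden state.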
    def calculate_loss_prob(self, interaction, only_compute_prob=False):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        all_hidden_states = self.forward(item_seq, item_seq_len)
        if self.loss_type != "CE":
            print(
                "current softmax-cpr code does not support BPR or losses other than cross entropy"
            )
            sys.exit(0)
        else:  # self.loss_type == 'CE'
            test_item_emb = self.item_embedding.weight

            # mfs code starts
            device = all_hidden_states[0].device

            ## Multi-input hidden states: generate q_ct from hidden states
            # list of hidden-state embeddings taken as input
            hidden_emb_arr = []
            # n_facet_hidden -> H, n_facet_window -> W (here 1 and 0)
            for i in range(self.n_facet_hidden):
                hidden_states = all_hidden_states[-(i + 1)]  # i-th hidden state from the top
                device = hidden_states.device
                hidden_emb_arr.append(hidden_states)
                for j in range(self.n_facet_window):
                    bsz, seq_len, hidden_size = hidden_states.size()
                    if j + 1 < hidden_states.size(1):
                        shifted_hidden = torch.cat(
                            (
                                torch.zeros((bsz, (j + 1), hidden_size), device=device),
                                hidden_states[:, : -(j + 1), :],
                            ),
                            dim=1,
                        )
                    else:
                        shifted_hidden = torch.zeros(
                            (bsz, hidden_states.size(1), hidden_size), device=device
                        )
                    hidden_emb_arr.append(shifted_hidden)
            # hidden_emb_arr -> (W*H, bsz, seq_len, hidden_size)

            # n_facet_MLP -> 1
            if self.n_facet_MLP > 0:
                stacked_hidden_emb_raw_arr = torch.cat(
                    hidden_emb_arr, dim=-1
                )  # (bsz, seq_len, W*H*hidden_size)
                # self.MLP_linear maps hidden_size * (n_facet_hidden * (n_facet_window + 1))
                # -> hidden_size * n_facet_MLP; the "+1" counts the unshifted hidden
                # state in addition to its W shifted copies.
                hidden_emb_MLP = self.MLP_linear(
                    stacked_hidden_emb_raw_arr
                )  # (bsz, seq_len, hidden_size)
                stacked_hidden_emb_arr_raw = torch.cat(
                    [hidden_emb_arr[0], gelu(hidden_emb_MLP)], dim=-1
                )  # (bsz, seq_len, 2*hidden_size)
            else:
                stacked_hidden_emb_arr_raw = hidden_emb_arr[0]

            # Only use the hidden state corresponding to the last item,
            # so seq_len = 1 in the following code.
            stacked_hidden_emb_arr = stacked_hidden_emb_arr_raw[:, -1, :].unsqueeze(dim=1)

            # list of linear projections per facet
            projected_emb_arr = []
            # list of final logits per facet
            facet_lm_logits_arr = []
            facet_lm_logits_real_arr = []  # logits for the original facets

            reranker_candidate_token_ids_arr = []
            for i in range(self.n_facet):
                # linear projection
                projected_emb = self.get_facet_emb(
                    stacked_hidden_emb_arr, i
                )  # (bsz, seq_len, hidden_dim)
                projected_emb_arr.append(projected_emb)
                # logits for all items in the vocabulary
                lm_logits = F.linear(projected_emb, self.item_embedding.weight, None)
                facet_lm_logits_arr.append(lm_logits)
                if (
                    i < self.n_facet_reranker
                    and not self.candidates_from_previous_reranker
                ):
                    candidate_token_ids = []
                    for j in range(len(self.reranker_CAN_NUM)):
                        _, candidate_token_ids_ = torch.topk(
                            lm_logits, self.reranker_CAN_NUM[j]
                        )
                        candidate_token_ids.append(candidate_token_ids_)
                    reranker_candidate_token_ids_arr.append(candidate_token_ids)

            for i in range(self.n_facet_reranker):
                for j in range(len(self.reranker_CAN_NUM)):
                    projected_emb = self.get_facet_emb(
                        stacked_hidden_emb_arr,
                        self.n_facet + i * len(self.reranker_CAN_NUM) + j,
                    )  # (bsz, seq_len, hidden_dim)
                    projected_emb_arr.append(projected_emb)

            for i in range(self.n_facet_context):
                projected_emb = self.get_facet_emb(
                    stacked_hidden_emb_arr,
                    self.n_facet + self.n_facet_reranker * len(self.reranker_CAN_NUM) + i,
                )  # (bsz, seq_len, hidden_dim)
                projected_emb_arr.append(projected_emb)

            # generate context-based embeddings for the items in the input
            for i in range(self.n_facet_emb):
                projected_emb = self.get_facet_emb(
                    stacked_hidden_emb_arr_raw,
                    self.n_facet
                    + self.n_facet_context
                    + self.n_facet_reranker * len(self.reranker_CAN_NUM)
                    + i,
                )  # (bsz, seq_len, hidden_dim)
                projected_emb_arr.append(projected_emb)

            for i in range(self.n_facet_reranker):
                bsz, seq_len, hidden_size = projected_emb_arr[i].size()
                for j in range(len(self.reranker_CAN_NUM)):
                    if self.candidates_from_previous_reranker:
                        _, candidate_token_ids = torch.topk(
                            facet_lm_logits_arr[i], self.reranker_CAN_NUM[j]
                        )  # (bsz, seq_len, topk)
                    else:
                        candidate_token_ids = reranker_candidate_token_ids_arr[i][j]
                    logit_hidden_reranker_topn = (
                        projected_emb_arr[
                            self.n_facet + i * len(self.reranker_CAN_NUM) + j
                        ]
                        .unsqueeze(dim=2)
                        .expand(bsz, seq_len, self.reranker_CAN_NUM[j], hidden_size)
                        * self.item_embedding.weight[candidate_token_ids, :]
                    ).sum(
                        dim=-1
                    )  # (bsz, seq_len, topk, emb_size) * (bsz, seq_len, topk, emb_size) -> (bsz, seq_len, topk)
                    if self.reranker_merging_mode == "add":
                        facet_lm_logits_arr[i].scatter_add_(
                            2, candidate_token_ids, logit_hidden_reranker_topn
                        )  # (bsz, seq_len, vocab_size) <- (bsz, seq_len, topk)
                    else:
                        facet_lm_logits_arr[i].scatter_(
                            2, candidate_token_ids, logit_hidden_reranker_topn
                        )  # (bsz, seq_len, vocab_size) <- (bsz, seq_len, topk)

            for i in range(self.n_facet_context):
                bsz, seq_len_1, hidden_size = projected_emb_arr[i].size()
                bsz, seq_len_2 = item_seq.size()
                logit_hidden_context = (
                    projected_emb_arr[
                        self.n_facet
                        + self.n_facet_reranker * len(self.reranker_CAN_NUM)
                        + i
                    ]
                    .unsqueeze(dim=2)
                    .expand(-1, -1, seq_len_2, -1)
                    * self.item_embedding.weight[item_seq, :]
                    .unsqueeze(dim=1)
                    .expand(-1, seq_len_1, -1, -1)
                ).sum(dim=-1)
                logit_hidden_pointer = 0
                if self.n_facet_emb == 2:
                    logit_hidden_pointer = (
                        projected_emb_arr[-2][:, -1, :]
                        .unsqueeze(dim=1)
                        .unsqueeze(dim=1)
                        .expand(-1, seq_len_1, seq_len_2, -1)
                        * projected_emb_arr[-1]
                        .unsqueeze(dim=1)
                        .expand(-1, seq_len_1, -1, -1)
                    ).sum(dim=-1)

                item_seq_expand = item_seq.unsqueeze(dim=1).expand(-1, seq_len_1, -1)
                only_new_logits = torch.zeros_like(facet_lm_logits_arr[i])
                if self.context_norm:
                    only_new_logits.scatter_add_(
                        dim=2,
                        index=item_seq_expand,
                        src=logit_hidden_context + logit_hidden_pointer,
                    )
                    item_count = torch.zeros_like(only_new_logits) + 1e-15
                    item_count.scatter_add_(
                        dim=2,
                        index=item_seq_expand,
                        src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype),
                    )
                    only_new_logits = only_new_logits / item_count
                else:
                    only_new_logits.scatter_add_(
                        dim=2, index=item_seq_expand, src=logit_hidden_context
                    )
                    item_count = torch.zeros_like(only_new_logits) + 1e-15
                    item_count.scatter_add_(
                        dim=2,
                        index=item_seq_expand,
                        src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype),
                    )
                    only_new_logits = only_new_logits / item_count
                    only_new_logits.scatter_add_(
                        dim=2, index=item_seq_expand, src=logit_hidden_pointer
                    )

                if self.partition_merging_mode == "replace":
                    facet_lm_logits_arr[i].scatter_(
                        dim=2,
                        index=item_seq_expand,
                        src=torch.zeros_like(item_seq_expand).to(
                            dtype=facet_lm_logits_arr[i].dtype
                        ),
                    )
                facet_lm_logits_arr[i] = facet_lm_logits_arr[i] + only_new_logits

            weight = None
            if self.weight_mode == "dynamic":
                weight = self.weight_facet_decoder(stacked_hidden_emb_arr).softmax(
                    dim=-1
                )  # hidden_dim * hidden_state_input_ratio -> n_facet_effective
            elif self.weight_mode == "static":
                weight = self.weight_global.softmax(dim=-1)  # torch.ones(n_facet_effective)
            elif self.weight_mode == "max_logits":
                stacked_facet_lm_logits = torch.stack(facet_lm_logits_arr, dim=0)
                facet_lm_logits_arr = [stacked_facet_lm_logits.amax(dim=0)]

            prediction_prob = 0
            for i in range(self.n_facet_effective):
                facet_lm_logits = facet_lm_logits_arr[i]
                if self.softmax_nonlinear == "sigsoftmax":  # 'None' here
                    facet_lm_logits_sig = torch.exp(
                        facet_lm_logits - facet_lm_logits.max(dim=-1, keepdim=True)[0]
                    ) * (1e-20 + torch.sigmoid(facet_lm_logits))
                    facet_lm_logits_softmax = facet_lm_logits_sig / facet_lm_logits_sig.sum(
                        dim=-1, keepdim=True
                    )
                elif self.softmax_nonlinear == "None":
                    facet_lm_logits_softmax = facet_lm_logits.softmax(
                        dim=-1
                    )  # softmax over the final logits
                if self.weight_mode == "dynamic":
                    prediction_prob += facet_lm_logits_softmax * weight[
                        :, :, i
                    ].unsqueeze(-1)
                elif self.weight_mode == "static":
                    prediction_prob += facet_lm_logits_softmax * weight[i]
                else:
                    prediction_prob += (
                        facet_lm_logits_softmax / self.n_facet_effective
                    )  # uniform average over facets

            if not only_compute_prob:
                inp = torch.log(prediction_prob.view(-1, self.n_items) + 1e-8)
                pos_items = interaction[self.POS_ITEM_ID]
                loss_raw = self.loss_fct(inp, pos_items.view(-1))
                loss = loss_raw.mean()
            else:
                loss = None
        # return loss, prediction_prob.squeeze()
        return loss, prediction_prob.squeeze(dim=1)
    def calculate_loss(self, interaction):
        loss, prediction_prob = self.calculate_loss_prob(interaction)
        return loss
    def predict(self, interaction):
        print(
            "Current softmax-cpr code, just like RepeatNet, does not support negative sampling in an efficient way.",
            file=sys.stderr,
        )
        assert False  # If you can accept a slow running time, comment out this line
        loss, prediction_prob = self.calculate_loss_prob(
            interaction, only_compute_prob=True
        )
        if self.post_remove_context:
            item_seq = interaction[self.ITEM_SEQ]
            prediction_prob.scatter_(1, item_seq, 0)
        test_item = interaction[self.ITEM_ID]
        prediction_prob = prediction_prob.unsqueeze(-1)  # (batch_size, num_items, 1)
        scores = self.gather_indexes(prediction_prob, test_item).squeeze(-1)
        return scores
    def full_sort_predict(self, interaction):
        loss, prediction_prob = self.calculate_loss_prob(interaction)
        if self.post_remove_context:
            item_seq = interaction[self.ITEM_SEQ]
            prediction_prob.scatter_(1, item_seq, 0)
        return prediction_prob
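
# ---------------------------------------------------------------------------
# A minimal quick-start sketch (an illustrative addition, not part of the
# original file). run_recbole is RecBole's standard entry point; the config keys
# below are the ones read in __init__ above, but every concrete value and the
# "ml-100k" dataset choice are assumptions, and the remaining Softmax-CPR keys
# (weight_mode, context_norm, use_proj_bias, ...) are assumed to come from the
# model's default yaml or your own config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from recbole.quick_start import run_recbole

    config_dict = {
        "loss_type": "CE",        # BPR is rejected by this model
        "n_facet": 1,             # one global softmax facet
        "n_facet_context": 1,     # one context partition
        "n_facet_reranker": 1,    # one reranker partition
        "reranker_CAN_NUM": "100",
        "n_facet_emb": 2,         # must be 0 or 2
        "n_facet_all": 5,         # = n_facet + context + reranker * |CAN_NUM| + emb
        "n_facet_hidden": 1,
        "n_facet_window": 0,      # must be <= 0
        "n_facet_MLP": 0,         # must be <= 0
    }
    run_recbole(model="SASRecCPR", dataset="ml-100k", config_dict=config_dict)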