Source code for recbole.model.sequential_recommender.fdsa

# -*- coding: utf-8 -*-
# @Time    : 2020/9/18 11:27
# @Author  : Hui Wang
# @Email   : hui.wang@ruc.edu.cn

r"""
FDSA
################################################

Reference:
    Tingting Zhang et al. "Feature-level Deeper Self-Attention Network for Sequential Recommendation."
    In IJCAI 2019

"""

import torch
from torch import nn

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder, FeatureSeqEmbLayer, VanillaAttention
from recbole.model.loss import BPRLoss


[docs]class FDSA(SequentialRecommender): r""" FDSA is similar with the GRU4RecF implemented in RecBole, which uses two different Transformer encoders to encode items and features respectively and concatenates the two subparts' outputs as the final output. """ def __init__(self, config, dataset): super(FDSA, self).__init__(config, dataset) # load parameters info self.n_layers = config['n_layers'] self.n_heads = config['n_heads'] self.hidden_size = config['hidden_size'] # same as embedding_size self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer self.hidden_dropout_prob = config['hidden_dropout_prob'] self.attn_dropout_prob = config['attn_dropout_prob'] self.hidden_act = config['hidden_act'] self.layer_norm_eps = config['layer_norm_eps'] self.selected_features = config['selected_features'] self.pooling_mode = config['pooling_mode'] self.device = config['device'] self.num_feature_field = len(config['selected_features']) self.initializer_range = config['initializer_range'] self.loss_type = config['loss_type'] # define layers and loss self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0) self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size) self.feature_embed_layer = FeatureSeqEmbLayer( dataset, self.hidden_size, self.selected_features, self.pooling_mode, self.device ) self.item_trm_encoder = TransformerEncoder( n_layers=self.n_layers, n_heads=self.n_heads, hidden_size=self.hidden_size, inner_size=self.inner_size, hidden_dropout_prob=self.hidden_dropout_prob, attn_dropout_prob=self.attn_dropout_prob, hidden_act=self.hidden_act, layer_norm_eps=self.layer_norm_eps ) self.feature_att_layer = VanillaAttention(self.hidden_size, self.hidden_size) # For simplicity, we use same architecture for item_trm and feature_trm self.feature_trm_encoder = TransformerEncoder( n_layers=self.n_layers, n_heads=self.n_heads, hidden_size=self.hidden_size, inner_size=self.inner_size, hidden_dropout_prob=self.hidden_dropout_prob, attn_dropout_prob=self.attn_dropout_prob, hidden_act=self.hidden_act, layer_norm_eps=self.layer_norm_eps ) self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.dropout = nn.Dropout(self.hidden_dropout_prob) self.concat_layer = nn.Linear(self.hidden_size * 2, self.hidden_size) if self.loss_type == 'BPR': self.loss_fct = BPRLoss() elif self.loss_type == 'CE': self.loss_fct = nn.CrossEntropyLoss() else: raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") # parameters initialization self.apply(self._init_weights) self.other_parameter_name = ['feature_embed_layer'] def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_()
[docs] def get_attention_mask(self, item_seq): """Generate left-to-right uni-directional attention mask for multi-head attention.""" attention_mask = (item_seq > 0).long() extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 # mask for left-to-right unidirectional max_len = attention_mask.size(-1) attn_shape = (1, max_len, max_len) subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8 subsequent_mask = (subsequent_mask == 0).unsqueeze(1) subsequent_mask = subsequent_mask.long().to(item_seq.device) extended_attention_mask = extended_attention_mask * subsequent_mask extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask
[docs] def forward(self, item_seq, item_seq_len): item_emb = self.item_embedding(item_seq) position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device) position_ids = position_ids.unsqueeze(0).expand_as(item_seq) position_embedding = self.position_embedding(position_ids) # get item_trm_input # item position add position embedding item_emb = item_emb + position_embedding item_emb = self.LayerNorm(item_emb) item_trm_input = self.dropout(item_emb) sparse_embedding, dense_embedding = self.feature_embed_layer(None, item_seq) sparse_embedding = sparse_embedding['item'] dense_embedding = dense_embedding['item'] # concat the sparse embedding and float embedding feature_table = [] if sparse_embedding is not None: feature_table.append(sparse_embedding) if dense_embedding is not None: feature_table.append(dense_embedding) # [batch len num_features hidden_size] feature_table = torch.cat(feature_table, dim=-2) # feature_emb [batch len hidden] # weight [batch len num_features] # if only one feature, the weight would be 1.0 feature_emb, attn_weight = self.feature_att_layer(feature_table) # feature position add position embedding feature_emb = feature_emb + position_embedding feature_emb = self.LayerNorm(feature_emb) feature_trm_input = self.dropout(feature_emb) extended_attention_mask = self.get_attention_mask(item_seq) item_trm_output = self.item_trm_encoder(item_trm_input, extended_attention_mask, output_all_encoded_layers=True) item_output = item_trm_output[-1] feature_trm_output = self.feature_trm_encoder( feature_trm_input, extended_attention_mask, output_all_encoded_layers=True ) # [B Len H] feature_output = feature_trm_output[-1] item_output = self.gather_indexes(item_output, item_seq_len - 1) # [B H] feature_output = self.gather_indexes(feature_output, item_seq_len - 1) # [B H] output_concat = torch.cat((item_output, feature_output), -1) # [B 2*H] output = self.concat_layer(output_concat) output = self.LayerNorm(output) seq_output = self.dropout(output) return seq_output # [B H]
[docs] def calculate_loss(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] seq_output = self.forward(item_seq, item_seq_len) pos_items = interaction[self.POS_ITEM_ID] if self.loss_type == 'BPR': neg_items = interaction[self.NEG_ITEM_ID] pos_items_emb = self.item_embedding(pos_items) neg_items_emb = self.item_embedding(neg_items) pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] loss = self.loss_fct(pos_score, neg_score) return loss else: # self.loss_type = 'CE' test_item_emb = self.item_embedding.weight logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) loss = self.loss_fct(logits, pos_items) return loss
[docs] def predict(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] test_item = interaction[self.ITEM_ID] seq_output = self.forward(item_seq, item_seq_len) test_item_emb = self.item_embedding(test_item) scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] return scores
[docs] def full_sort_predict(self, interaction): item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] seq_output = self.forward(item_seq, item_seq_len) test_items_emb = self.item_embedding.weight scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] return scores