# -*- coding: utf-8 -*-
# @Time : 2020/9/18 11:32
# @Author : Hui Wang
# @Email : hui.wang@ruc.edu.cn
r"""
SASRecF
################################################
"""
import torch
from torch import nn
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder, FeatureSeqEmbLayer
from recbole.model.loss import BPRLoss


class SASRecF(SequentialRecommender):
"""This is an extension of SASRec, which concatenates item representations and item attribute representations
as the input to the model.
"""
def __init__(self, config, dataset):
super(SASRecF, self).__init__(config, dataset)
# load parameters info
self.n_layers = config['n_layers']
self.n_heads = config['n_heads']
self.hidden_size = config['hidden_size'] # same as embedding_size
        self.inner_size = config['inner_size']  # dimensionality of the feed-forward layer
self.hidden_dropout_prob = config['hidden_dropout_prob']
self.attn_dropout_prob = config['attn_dropout_prob']
self.hidden_act = config['hidden_act']
self.layer_norm_eps = config['layer_norm_eps']
self.selected_features = config['selected_features']
self.pooling_mode = config['pooling_mode']
self.device = config['device']
self.num_feature_field = len(config['selected_features'])
self.initializer_range = config['initializer_range']
self.loss_type = config['loss_type']
# define layers and loss
self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)
self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
self.feature_embed_layer = FeatureSeqEmbLayer(
dataset, self.hidden_size, self.selected_features, self.pooling_mode, self.device
)
self.trm_encoder = TransformerEncoder(
n_layers=self.n_layers,
n_heads=self.n_heads,
hidden_size=self.hidden_size,
inner_size=self.inner_size,
hidden_dropout_prob=self.hidden_dropout_prob,
attn_dropout_prob=self.attn_dropout_prob,
hidden_act=self.hidden_act,
layer_norm_eps=self.layer_norm_eps
)
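        # projects the concatenation of the item embedding and all feature
        # embeddings, of width (1 + num_feature_field) * hidden_size, back to
        # hidden_size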
self.concat_layer = nn.Linear(self.hidden_size * (1 + self.num_feature_field), self.hidden_size)
self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.dropout = nn.Dropout(self.hidden_dropout_prob)
if self.loss_type == 'BPR':
self.loss_fct = BPRLoss()
elif self.loss_type == 'CE':
self.loss_fct = nn.CrossEntropyLoss()
else:
raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
# parameters initialization
self.apply(self._init_weights)
self.other_parameter_name = ['feature_embed_layer']

    def _init_weights(self, module):
        """Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
"""Generate left-to-right uni-directional attention mask for multi-head attention."""
attention_mask = (item_seq > 0).long()
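        # e.g. item_seq = [[5, 7, 0]] -> attention_mask = [[1, 1, 0]]:
        # padding positions (item id 0) are masked out entirely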
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64
# mask for left-to-right unidirectional
max_len = attention_mask.size(-1)
attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
subsequent_mask = subsequent_mask.long().to(item_seq.device)
extended_attention_mask = extended_attention_mask * subsequent_mask
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
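        # turn the 0/1 mask into an additive bias: allowed positions get 0.0,
        # masked positions get -10000.0, which drives their attention weights
        # to ~0 after the softmax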
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask

    def forward(self, item_seq, item_seq_len):
item_emb = self.item_embedding(item_seq)
# position embedding
position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
position_embedding = self.position_embedding(position_ids)
sparse_embedding, dense_embedding = self.feature_embed_layer(None, item_seq)
sparse_embedding = sparse_embedding['item']
dense_embedding = dense_embedding['item']
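        # each embedding, when present, has shape [B, L, num_field, H] after
        # pooling; either may be None if no feature of that type was selected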
        # concatenate the sparse and dense feature embeddings along the field dimension
feature_table = []
if sparse_embedding is not None:
feature_table.append(sparse_embedding)
if dense_embedding is not None:
feature_table.append(dense_embedding)
feature_table = torch.cat(feature_table, dim=-2)
table_shape = feature_table.shape
feat_num, embedding_size = table_shape[-2], table_shape[-1]
feature_emb = feature_table.view(table_shape[:-2] + (feat_num * embedding_size,))
        input_concat = torch.cat((item_emb, feature_emb), -1)  # [B, L, (1 + num_feature_field) * H]
input_emb = self.concat_layer(input_concat)
input_emb = input_emb + position_embedding
input_emb = self.LayerNorm(input_emb)
input_emb = self.dropout(input_emb)
extended_attention_mask = self.get_attention_mask(item_seq)
trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
output = trm_output[-1]
seq_output = self.gather_indexes(output, item_seq_len - 1)
return seq_output # [B H]

    def calculate_loss(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
seq_output = self.forward(item_seq, item_seq_len)
pos_items = interaction[self.POS_ITEM_ID]
if self.loss_type == 'BPR':
neg_items = interaction[self.NEG_ITEM_ID]
pos_items_emb = self.item_embedding(pos_items)
neg_items_emb = self.item_embedding(neg_items)
pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B]
neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B]
loss = self.loss_fct(pos_score, neg_score)
return loss
        else:  # self.loss_type == 'CE'
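            # CE treats next-item prediction as classification over the full
            # item set: one logit per item, target index = pos_items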
test_item_emb = self.item_embedding.weight
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
loss = self.loss_fct(logits, pos_items)
return loss

    def predict(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
test_item = interaction[self.ITEM_ID]
seq_output = self.forward(item_seq, item_seq_len)
test_item_emb = self.item_embedding(test_item)
scores = torch.mul(seq_output, test_item_emb).sum(dim=1)
return scores

    def full_sort_predict(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
seq_output = self.forward(item_seq, item_seq_len)
test_items_emb = self.item_embedding.weight
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, item_num]
return scores
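

# A minimal usage sketch, assuming RecBole's standard quick-start entry point
# and the bundled 'ml-100k' dataset, whose item file provides a 'class'
# feature. The config values below (selected_features, pooling_mode, load_col)
# are illustrative assumptions, not part of this module.
if __name__ == '__main__':
    from recbole.quick_start import run_recbole

    run_recbole(
        model='SASRecF',
        dataset='ml-100k',
        config_dict={
            'selected_features': ['class'],  # item attributes fed to FeatureSeqEmbLayer
            'pooling_mode': 'sum',           # how multi-valued feature fields are pooled
            'loss_type': 'CE',
            'load_col': {
                'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
                'item': ['item_id', 'class'],
            },
        },
    )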