# -*- coding: utf-8 -*-
# @Time : 2020/8/25 19:56
# @Author : Yujie Lu
# @Email : yujielu1998@gmail.com
# UPDATE
# @Time : 2020/9/15, 2020/10/2
# @Author : Yupeng Hou, Yujie Lu
# @Email : houyupeng@ruc.edu.cn, yujielu1998@gmail.com
r"""
NARM
################################################
Reference:
Jing Li et al. "Neural Attentive Session-based Recommendation." in CIKM 2017.
Reference code:
https://github.com/Wang-Shuo/Neural-Attentive-Session-Based-Recommendation-PyTorch
"""
import torch
from torch import nn
from torch.nn.init import xavier_normal_, constant_
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss
class NARM(SequentialRecommender):
    r"""NARM explores a hybrid encoder with an attention mechanism to model the user's sequential behavior,
    and capture the user's main purpose in the current session.

    A GRU encodes the embedded item sequence. The GRU output at the last real item is the
    global session representation ``c_global``. An additive attention over all non-padding GRU
    outputs, queried by that same hidden state, gives the local representation ``c_local``.
    The concatenation ``[c_local; c_global]`` is projected to ``embedding_size`` so that candidate
    items are scored by a dot product with their item embeddings.
    """

    def __init__(self, config, dataset):
        super(NARM, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.n_layers = config['n_layers']
        # dropout_probs[0]: dropout on item embeddings; dropout_probs[1]: dropout on [c_local; c_global]
        self.dropout_probs = config['dropout_probs']
        self.device = config['device']

        # define layers and loss
        # padding_idx=0: item id 0 is the padding token; its row receives no gradient updates.
        # NOTE(review): _init_weights re-initialises the whole weight (row 0 included) with
        # xavier_normal_, so the padding row is not zero after construction.
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.emb_dropout = nn.Dropout(self.dropout_probs[0])
        self.gru = nn.GRU(self.embedding_size, self.hidden_size, self.n_layers, bias=False, batch_first=True)
        # attention parameters: alpha_j = v_t^T sigmoid(a_1(h_j) + a_2(h_last))
        self.a_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.a_2 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_t = nn.Linear(self.hidden_size, 1, bias=False)
        self.ct_dropout = nn.Dropout(self.dropout_probs[1])
        # projects [c_local; c_global] back into the item-embedding space for dot-product scoring
        self.b = nn.Linear(2 * self.hidden_size, self.embedding_size, bias=False)
        self.loss_type = config['loss_type']
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Embedding/Linear weights with Xavier-normal and Linear biases with 0.

        Applied recursively to every submodule via ``self.apply``; other module types
        (e.g. the GRU) keep PyTorch's default initialisation.
        """
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)

    def forward(self, item_seq, item_seq_len):
        """Encode a batch of item sequences into session representations.

        Args:
            item_seq: padded item-id sequences in which id 0 is padding (it is the embedding's
                ``padding_idx`` and is masked out of the attention). Presumably shaped
                (batch, max_len) with real items first -- TODO confirm against the dataloader.
            item_seq_len: number of real items per sequence; ``item_seq_len - 1`` indexes the
                last real item.

        Returns:
            Session representation whose last dimension is ``embedding_size``, ready to be
            dot-multiplied with item embeddings.
        """
        item_seq_emb = self.item_embedding(item_seq)
        item_seq_emb_dropout = self.emb_dropout(item_seq_emb)
        # batch_first=True -> gru_out is (batch, seq_len, hidden_size)
        gru_out, _ = self.gru(item_seq_emb_dropout)

        # fetch the last hidden state of last timestamp (global encoder output, also the attention query)
        c_global = ht = self.gather_indexes(gru_out, item_seq_len - 1)

        # avoid the influence of padding: True where the item id is a real (non-zero) item
        mask = item_seq.gt(0).unsqueeze(2).expand_as(gru_out)
        q1 = self.a_1(gru_out)
        q2 = self.a_2(ht)
        q2_expand = q2.unsqueeze(1).expand_as(q1)

        # calculate weighted factors alpha; v_t has no bias, so masked positions score exactly 0
        alpha = self.v_t(mask * torch.sigmoid(q1 + q2_expand))
        # local encoder output: attention-weighted sum of GRU outputs over the time dimension
        c_local = torch.sum(alpha.expand_as(gru_out) * gru_out, 1)

        c_t = torch.cat([c_local, c_global], 1)
        c_t = self.ct_dropout(c_t)
        seq_output = self.b(c_t)
        return seq_output

    def calculate_loss(self, interaction):
        """Compute the training loss for one batch.

        With ``loss_type == 'BPR'`` the positive and negative items are scored by dot product
        and passed to ``BPRLoss``. Otherwise (``'CE'``, validated in ``__init__``) every item in
        the vocabulary is scored, and cross-entropy is taken against the positive item id.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            # logits over the whole item vocabulary (padding id 0 included)
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        """Score one candidate item (``interaction[self.ITEM_ID]``) per sequence via dot product."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every item in the embedding table for each sequence; result is (batch, n_items)."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
        return scores