# -*- coding: utf-8 -*-
# @Time : 2020/9/8 19:24
# @Author : Yujie Lu
# @Email : yujielu1998@gmail.com
# UPDATE
# @Time : 2020/10/2
# @Author : Yujie Lu
# @Email : yujielu1998@gmail.com
r"""
STAMP
################################################
Reference:
    Qiao Liu et al. "STAMP: Short-Term Attention/Memory Priority Model for Session-based Recommendation." In KDD 2018.

"""

import torch
from torch import nn
from torch.nn.init import normal_

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss


class STAMP(SequentialRecommender):
    r"""STAMP is capable of capturing users’ general interests from the long-term memory of a session context,
    while taking into account users’ current interests from the short-term memory of the last clicks.

    Note:
        According to the test results, we made a small modification to the score function described in the paper
        and did not use the final sigmoid activation function.
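
        Concretely, the score of a candidate item :math:`i` is computed directly from its embedding
        :math:`e_i` and the session representation returned by :meth:`forward`:

        .. math::
            \hat{y}_i = \langle h_s \odot h_t,\; e_i \rangle

        as implemented in :meth:`calculate_loss` and :meth:`full_sort_predict`.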
"""

    def __init__(self, config, dataset):
        super(STAMP, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config["embedding_size"]

        # define layers and loss
        self.item_embedding = nn.Embedding(
            self.n_items, self.embedding_size, padding_idx=0
        )
        self.w1 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w2 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w3 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w0 = nn.Linear(self.embedding_size, 1, bias=False)
        self.b_a = nn.Parameter(torch.zeros(self.embedding_size), requires_grad=True)
        self.mlp_a = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.mlp_b = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.loss_type = config["loss_type"]
        if self.loss_type == "BPR":
            self.loss_fct = BPRLoss()
        elif self.loss_type == "CE":
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data, 0, 0.002)
        elif isinstance(module, nn.Linear):
            normal_(module.weight.data, 0, 0.05)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, item_seq, item_seq_len):
        item_seq_emb = self.item_embedding(item_seq)
        # embedding of the last clicked item in each session, shape of [batch_size, emb]
        last_inputs = self.gather_indexes(item_seq_emb, item_seq_len - 1)
        org_memory = item_seq_emb
        # m_s: mean of the item embeddings in the session (the user's general interest)
        ms = torch.div(torch.sum(org_memory, dim=1), item_seq_len.unsqueeze(1).float())
        # attention weights over the session items, shape of [batch_size, time_steps]
        alpha = self.count_alpha(org_memory, last_inputs, ms)
        # attention-weighted sum of the session memory, combined with the average m_s
        vec = torch.matmul(alpha.unsqueeze(1), org_memory)
        ma = vec.squeeze(1) + ms
        hs = self.tanh(self.mlp_a(ma))
        ht = self.tanh(self.mlp_b(last_inputs))
        # session representation: element-wise product of the two hidden states
        seq_output = hs * ht
        return seq_output

    def count_alpha(self, context, aspect, output):
        r"""Calculate the attention weights over the items in the session.

        Args:
            context (torch.FloatTensor): Item list embedding matrix, shape of [batch_size, time_steps, emb]
            aspect (torch.FloatTensor): The embedding of the last clicked item, shape of [batch_size, emb]
            output (torch.FloatTensor): The average of the context, shape of [batch_size, emb]

        Returns:
            torch.Tensor: Attention weights, shape of [batch_size, time_steps]
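
        In symbols, the weight of the :math:`i`-th item in the session is

        .. math::
            \alpha_i = w_0^\top \sigma(W_1 x_i + W_2 x_t + W_3 m_s + b_a)

        where :math:`x_t` is the last-click embedding (``aspect``), :math:`m_s` the session average
        (``output``), and :math:`\sigma` the sigmoid function, matching the layers ``w1``, ``w2``,
        ``w3``, ``w0`` and the bias ``b_a`` defined in :meth:`__init__`.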
"""
timesteps = context.size(1)
aspect_3dim = aspect.repeat(1, timesteps).view(
-1, timesteps, self.embedding_size
)
output_3dim = output.repeat(1, timesteps).view(
-1, timesteps, self.embedding_size
)
res_ctx = self.w1(context)
res_asp = self.w2(aspect_3dim)
res_output = self.w3(output_3dim)
res_sum = res_ctx + res_asp + res_output + self.b_a
res_act = self.w0(self.sigmoid(res_sum))
alpha = res_act.squeeze(2)
return alpha

    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == "BPR":
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))
        return scores
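

if __name__ == "__main__":
    # Minimal usage sketch (not part of the model definition): train and evaluate
    # STAMP through RecBole's quick-start entry point. The dataset name below is
    # an assumption; substitute any sequential dataset configured for RecBole.
    from recbole.quick_start import run_recbole

    run_recbole(model="STAMP", dataset="ml-100k")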