# Source code for recbole.model.sequential_recommender.stamp

# -*- coding: utf-8 -*-
# @Time   : 2020/9/8 19:24
# @Author : Yujie Lu
# @Email  : yujielu1998@gmail.com

# UPDATE
# @Time   : 2020/10/2
# @Author : Yujie Lu
# @Email  : yujielu1998@gmail.com

r"""
STAMP
################################################

Reference:
    Qiao Liu et al. "STAMP: Short-Term Attention/Memory Priority Model for Session-based Recommendation." in KDD 2018.

"""

import torch
from torch import nn
from torch.nn.init import normal_

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss


class STAMP(SequentialRecommender):
    r"""STAMP is capable of capturing users’ general interests from the long-term memory of a session context,
    whilst taking into account users’ current interests from the short-term memory of the last-clicks.

    Note:
        According to the test results, we made a little modification to the score function mentioned in the paper,
        and did not use the final sigmoid activation function.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # load parameters info
        self.embedding_size = config["embedding_size"]

        # define layers and loss
        self.item_embedding = nn.Embedding(
            self.n_items, self.embedding_size, padding_idx=0
        )
        # attention projections: W1 (context items), W2 (last click), W3 (session mean), W0 (score)
        self.w1 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w2 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w3 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        self.w0 = nn.Linear(self.embedding_size, 1, bias=False)
        self.b_a = nn.Parameter(torch.zeros(self.embedding_size), requires_grad=True)
        # MLP cells applied to the attended memory (mlp_a) and to the last click (mlp_b)
        self.mlp_a = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.mlp_b = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.loss_type = config["loss_type"]
        if self.loss_type == "BPR":
            self.loss_fct = BPRLoss()
        elif self.loss_type == "CE":
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        r"""Initialize a single submodule; applied recursively through ``self.apply``.

        Embeddings get N(0, 0.002) weights, Linear layers get N(0, 0.05) weights and zero biases.

        Args:
            module (nn.Module): the submodule being initialized.
        """
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data, 0, 0.002)
            # normal_ overwrites the whole table, including the padding row that
            # nn.Embedding zero-initializes. That row never receives a gradient,
            # so it would keep this random value for the whole run and leak into
            # the sums over time steps in ``forward``. Restore it to zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            normal_(module.weight.data, 0, 0.05)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, item_seq, item_seq_len):
        r"""Encode each item sequence into a single session representation.

        Args:
            item_seq (torch.LongTensor): item id sequences, shape of [batch_size, time_steps].
            item_seq_len (torch.LongTensor): number of valid items per sequence, shape of [batch_size].

        Returns:
            torch.FloatTensor: session representation ``hs * ht``, shape of [batch_size, emb].
        """
        item_seq_emb = self.item_embedding(item_seq)  # [B, T, emb]
        # embedding at position item_seq_len - 1, i.e. the last click x_t
        # (presumably valid items are right-aligned to the start; gather_indexes is
        # defined on the base class — confirm)
        last_inputs = self.gather_indexes(item_seq_emb, item_seq_len - 1)
        org_memory = item_seq_emb
        # m_s: sum over all T positions divided by the valid length; padded positions
        # contribute nothing only because the padding embedding row is zero
        ms = torch.div(torch.sum(org_memory, dim=1), item_seq_len.unsqueeze(1).float())
        alpha = self.count_alpha(org_memory, last_inputs, ms)  # [B, T]
        vec = torch.matmul(alpha.unsqueeze(1), org_memory)  # [B, 1, emb]
        # attention-weighted memory plus the session mean
        ma = vec.squeeze(1) + ms
        hs = self.tanh(self.mlp_a(ma))
        ht = self.tanh(self.mlp_b(last_inputs))
        seq_output = hs * ht
        return seq_output

    def count_alpha(self, context, aspect, output):
        r"""This is a function that count the attention weights

        Args:
            context(torch.FloatTensor): Item list embedding matrix, shape of [batch_size, time_steps, emb]
            aspect(torch.FloatTensor): The embedding matrix of the last click item, shape of [batch_size, emb]
            output(torch.FloatTensor): The average of the context, shape of [batch_size, emb]

        Returns:
            torch.Tensor:attention weights, shape of [batch_size, time_steps]
        """
        res_ctx = self.w1(context)  # [B, T, emb]
        # W2·aspect and W3·output are the same for every time step, so project
        # once per session and broadcast over the time axis instead of repeating
        # the inputs T times before the projection.
        res_asp = self.w2(aspect).unsqueeze(1)  # [B, 1, emb]
        res_output = self.w3(output).unsqueeze(1)  # [B, 1, emb]
        res_sum = res_ctx + res_asp + res_output + self.b_a  # [B, T, emb]
        res_act = self.w0(self.sigmoid(res_sum))  # [B, T, 1]
        alpha = res_act.squeeze(2)
        return alpha

    def calculate_loss(self, interaction):
        r"""Compute the training loss for a batch.

        With ``loss_type == "BPR"`` the positive and sampled negative items are scored by
        dot product and passed to ``BPRLoss``. With ``"CE"`` the session representation is
        scored against every item embedding and cross-entropy is taken against the positive item.

        Args:
            interaction: batch indexed by field name (ITEM_SEQ, ITEM_SEQ_LEN, POS_ITEM_ID
                and, for BPR, NEG_ITEM_ID).

        Returns:
            torch.Tensor: the loss returned by ``self.loss_fct``.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == "BPR":
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type == 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))  # [B, n_items]
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        r"""Score one candidate item per sequence.

        Args:
            interaction: batch indexed by field name (ITEM_SEQ, ITEM_SEQ_LEN, ITEM_ID).

        Returns:
            torch.Tensor: dot-product score of each session with its candidate item, shape of [batch_size].
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        r"""Score every item in the embedding table for each sequence.

        Args:
            interaction: batch indexed by field name (ITEM_SEQ, ITEM_SEQ_LEN).

        Returns:
            torch.Tensor: dot-product scores, shape of [batch_size, n_items].
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores