Source code for recbole.model.sequential_recommender.hrm

# -*- coding: utf-8 -*-
# @Time     : 2020/11/22 12:08
# @Author   : Shao Weiqi
# @Reviewer : Lin Kun
# @Email    : shaoweiqi@ruc.edu.cn

r"""
HRM
################################################

Reference:
    Pengfei Wang et al. "Learning Hierarchical Representation Model for Next Basket Recommendation." in SIGIR 2015.

Reference code:
    https://github.com/wubinzzu/NeuRec

"""

import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss


class HRM(SequentialRecommender):
    r"""
    HRM can well capture both sequential behavior and users’ general taste by involving
    transaction and user representations in prediction.

    HRM uses max- and average-pooling as helpers to build the hierarchical representation.
    """

    def __init__(self, config, dataset):
        super(HRM, self).__init__(config, dataset)

        # load the dataset information
        self.n_user = dataset.num(self.USER_ID)
        self.device = config["device"]

        # load the parameters information
        self.embedding_size = config["embedding_size"]
        self.pooling_type_layer_1 = config["pooling_type_layer_1"]
        self.pooling_type_layer_2 = config["pooling_type_layer_2"]
        self.high_order = config["high_order"]
        assert (
            self.high_order <= self.max_seq_length
        ), "high_order can't be longer than the max_seq_length"
        self.reg_weight = config["reg_weight"]
        self.dropout_prob = config["dropout_prob"]

        # define the layers and loss function
        self.item_embedding = nn.Embedding(
            self.n_items, self.embedding_size, padding_idx=0
        )
        self.user_embedding = nn.Embedding(self.n_user, self.embedding_size)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.loss_type = config["loss_type"]
        if self.loss_type == "BPR":
            self.loss_fct = BPRLoss()
        elif self.loss_type == "CE":
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # init the parameters of the model
        self.apply(self._init_weights)
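    # For reference, a hedged example of the hyperparameters read in __init__
    # (the values below are illustrative, not mandated by the original code):
    #
    #   embedding_size: 64
    #   high_order: 2
    #   pooling_type_layer_1: "max"   # or "average"
    #   pooling_type_layer_2: "max"   # or "average"
    #   reg_weight: 0.0
    #   dropout_prob: 0.2
    #   loss_type: "CE"               # or "BPR"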
    def inverse_seq_item(self, seq_item, seq_item_len):
        """
        Inverse the seq_item, like this:
            [1,2,3,0,0,0,0] -- after inverse -->> [0,0,0,0,1,2,3]
        """
        seq_item = seq_item.cpu().numpy()
        seq_item_len = seq_item_len.cpu().numpy()
        new_seq_item = []
        for items, length in zip(seq_item, seq_item_len):
            item = list(items[:length])
            zeros = list(items[length:])
            seqs = zeros + item
            new_seq_item.append(seqs)
        seq_item = torch.tensor(new_seq_item, dtype=torch.long, device=self.device)

        return seq_item
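    # A vectorized alternative to the Python loop above; this sketch is not
    # part of the original RecBole code and the method name is hypothetical.
    # Rotating each row left by `length` positions is equivalent to moving the
    # real items to the tail, and it avoids the GPU -> CPU round trip.
    def _inverse_seq_item_vectorized(self, seq_item, seq_item_len):
        seq_len = seq_item.size(1)
        # index[b][j] = (j + length_b) % seq_len, so for length 3 and
        # seq_len 7: [1,2,3,0,0,0,0] -> [0,0,0,0,1,2,3]
        index = (
            torch.arange(seq_len, device=seq_item.device).unsqueeze(0)
            + seq_item_len.unsqueeze(1)
        ) % seq_len
        return seq_item.gather(1, index)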
    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
    def forward(self, seq_item, user, seq_item_len):
        # move the padding to the front so the last positions hold the real items
        seq_item = self.inverse_seq_item(seq_item, seq_item_len)

        seq_item_embedding = self.item_embedding(seq_item)
        # batch_size * seq_len * embedding_size

        high_order_item_embedding = seq_item_embedding[:, -self.high_order :, :]
        # batch_size * high_order * embedding_size

        user_embedding = self.dropout(self.user_embedding(user))
        # batch_size * embedding_size

        # layer 1
        if self.pooling_type_layer_1 == "max":
            high_order_item_embedding = torch.max(
                high_order_item_embedding, dim=1
            ).values
            # batch_size * embedding_size
        else:
            # average pooling: sum over the high-order window (matching the
            # max-pooling branch) and divide by the number of real items in
            # it, capped at high_order; clamp avoids mutating seq_item_len
            pooling_len = torch.clamp(seq_item_len, max=self.high_order)
            high_order_item_embedding = torch.sum(high_order_item_embedding, dim=1)
            high_order_item_embedding = torch.div(
                high_order_item_embedding, pooling_len.unsqueeze(1).float()
            )
            # batch_size * embedding_size

        hybrid_user_embedding = self.dropout(
            torch.cat(
                [
                    user_embedding.unsqueeze(dim=1),
                    high_order_item_embedding.unsqueeze(dim=1),
                ],
                dim=1,
            )
        )
        # batch_size * 2 * embedding_size

        # layer 2
        if self.pooling_type_layer_2 == "max":
            hybrid_user_embedding = torch.max(hybrid_user_embedding, dim=1).values
            # batch_size * embedding_size
        else:
            hybrid_user_embedding = torch.mean(hybrid_user_embedding, dim=1)
            # batch_size * embedding_size

        return hybrid_user_embedding
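    # A worked toy example of the two-layer pooling above (illustrative
    # numbers, not part of the original code). With embedding_size = 2,
    # high_order = 2, and window embeddings [[1, 4], [3, 2]] for one user:
    #   layer 1, max pooling     : element-wise max       -> [3, 4]
    #   layer 1, average pooling : [4, 6] / window length -> [2, 3]
    # Layer 2 then pools the pair (user embedding, layer-1 output) the same
    # way, producing the final batch_size * embedding_size representation.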
    def calculate_loss(self, interaction):
        seq_item = interaction[self.ITEM_SEQ]
        seq_item_len = interaction[self.ITEM_SEQ_LEN]
        user = interaction[self.USER_ID]
        seq_output = self.forward(seq_item, user, seq_item_len)
        pos_items = interaction[self.POS_ITEM_ID]
        pos_items_emb = self.item_embedding(pos_items)
        if self.loss_type == "BPR":
            neg_items = interaction[self.NEG_ITEM_ID]
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight.t()
            logits = torch.matmul(seq_output, test_item_emb)
            loss = self.loss_fct(logits, pos_items)
            return loss
    def predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        seq_item_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        user = interaction[self.USER_ID]
        seq_output = self.forward(item_seq, user, seq_item_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)

        return scores
    def full_sort_predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        seq_item_len = interaction[self.ITEM_SEQ_LEN]
        user = interaction[self.USER_ID]
        seq_output = self.forward(item_seq, user, seq_item_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))

        return scores
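
# A minimal end-to-end usage sketch (not part of the original module). It
# follows RecBole's standard quickstart API; "ml-100k" and the direct
# HRM(config, dataset) construction are illustrative choices.
if __name__ == "__main__":
    from recbole.config import Config
    from recbole.data import create_dataset, data_preparation
    from recbole.trainer import Trainer

    config = Config(model="HRM", dataset="ml-100k")
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    model = HRM(config, dataset).to(config["device"])
    trainer = Trainer(config, model)
    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data)
    print(best_valid_result)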