# -*- coding: utf-8 -*-

r"""
CORE
################################################
Reference:
    Yupeng Hou, Binbin Hu, Zhiqiang Zhang, Wayne Xin Zhao. "CORE: Simple and Effective Session-based Recommendation within Consistent Representation Space." in SIGIR 2022.

    https://github.com/RUCAIBox/CORE
"""

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder


class TransNet(nn.Module):

    def __init__(self, config, dataset):
        super().__init__()

        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['embedding_size']
        self.inner_size = config['inner_size']
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']
        self.initializer_range = config['initializer_range']

        self.position_embedding = nn.Embedding(
            dataset.field2seqlen[config['ITEM_ID_FIELD'] + config['LIST_SUFFIX']],
            self.hidden_size
        )
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )
        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
        self.fn = nn.Linear(self.hidden_size, 1)

        self.apply(self._init_weights)

    def get_attention_mask(self, item_seq, bidirectional=False):
        """Generate left-to-right uni-directional or bidirectional attention mask for multi-head attention."""
        attention_mask = (item_seq != 0)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.bool
        if not bidirectional:
            extended_attention_mask = torch.tril(
                extended_attention_mask.expand((-1, -1, item_seq.size(-1), -1))
            )
        extended_attention_mask = torch.where(extended_attention_mask, 0., -10000.)
        return extended_attention_mask
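
    # A small worked example of the mask above (toy values, not part of the
    # original source). For a single session [4, 8, 0], where 0 is padding,
    # position j is visible from position i only when j <= i and item j is
    # not padding:
    #
    #   >>> item_seq = torch.tensor([[4, 8, 0]])
    #   >>> mask = (item_seq != 0).unsqueeze(1).unsqueeze(2)
    #   >>> mask = torch.tril(mask.expand((-1, -1, 3, -1)))
    #   >>> torch.where(mask, 0., -10000.).squeeze()
    #   tensor([[     0., -10000., -10000.],
    #           [     0.,      0., -10000.],
    #           [     0.,      0., -10000.]])
    #
    # Adding -10000. to the pre-softmax scores effectively removes attention
    # to the masked positions.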

    def forward(self, item_seq, item_emb):
        mask = item_seq.gt(0)
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)
        trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
        output = trm_output[-1]

        alpha = self.fn(output).to(torch.double)
        alpha = torch.where(mask.unsqueeze(-1), alpha, -9e15)
        alpha = torch.softmax(alpha, dim=1, dtype=torch.float)
        return alpha

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
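

# A minimal sketch of exercising TransNet on its own. The config dict and
# FakeDataset below are hypothetical stand-ins for RecBole's Config/Dataset,
# mirroring only the fields read in __init__:
#
#   >>> config = {
#   ...     'n_layers': 1, 'n_heads': 2, 'embedding_size': 8, 'inner_size': 16,
#   ...     'hidden_dropout_prob': 0.1, 'attn_dropout_prob': 0.1,
#   ...     'hidden_act': 'gelu', 'layer_norm_eps': 1e-12,
#   ...     'initializer_range': 0.02,
#   ...     'ITEM_ID_FIELD': 'item_id', 'LIST_SUFFIX': '_list',
#   ... }
#   >>> class FakeDataset:
#   ...     field2seqlen = {'item_id_list': 50}
#   >>> net = TransNet(config, FakeDataset())
#   >>> alpha = net(torch.randint(1, 100, (2, 50)), torch.randn(2, 50, 8))
#   >>> alpha.shape   # one attention weight per position
#   torch.Size([2, 50, 1])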


class CORE(SequentialRecommender):
    r"""CORE is a simple and effective framework, which unifies the representation space
    for both the encoding and decoding processes in session-based recommendation.
    """

    def __init__(self, config, dataset):
        super(CORE, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.loss_type = config['loss_type']
        self.dnn_type = config['dnn_type']
        self.sess_dropout = nn.Dropout(config['sess_dropout'])
        self.item_dropout = nn.Dropout(config['item_dropout'])
        self.temperature = config['temperature']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)

        # DNN
        if self.dnn_type == 'trm':
            self.net = TransNet(config, dataset)
        elif self.dnn_type == 'ave':
            self.net = self.ave_net
        else:
            raise ValueError(f"dnn_type should be either 'trm' or 'ave', but got [{self.dnn_type}].")

        if self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    @staticmethod
    def ave_net(item_seq, item_emb):
        mask = item_seq.gt(0)
        # uniform weights over the non-padding positions, i.e. mean pooling
        alpha = mask.to(torch.float) / mask.sum(dim=-1, keepdim=True)
        return alpha.unsqueeze(-1)
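
    # For example (toy values, not from the original source), with two
    # right-padded sessions the weights reduce to a uniform average over
    # each session's real items:
    #
    #   >>> item_seq = torch.tensor([[3, 7, 0, 0],
    #   ...                          [5, 2, 9, 1]])
    #   >>> CORE.ave_net(item_seq, None).squeeze(-1)
    #   tensor([[0.5000, 0.5000, 0.0000, 0.0000],
    #           [0.2500, 0.2500, 0.2500, 0.2500]])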

    def forward(self, item_seq):
        x = self.item_embedding(item_seq)
        x = self.sess_dropout(x)
        # Representation-Consistent Encoder (RCE)
        alpha = self.net(item_seq, x)
        seq_output = torch.sum(alpha * x, dim=1)
        seq_output = F.normalize(seq_output, dim=-1)
        return seq_output
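
    # A sketch of the RCE step above (hypothetical shapes): with weights
    # alpha of shape (B, L, 1) summing to 1 over positions, and item
    # embeddings x of shape (B, L, D), the session representation is a
    # convex combination of the item embeddings, L2-normalized onto the
    # same unit sphere the item embeddings are scored on:
    #
    #   >>> x = torch.randn(1, 3, 4)
    #   >>> alpha = torch.tensor([[[0.5], [0.5], [0.0]]])
    #   >>> out = F.normalize(torch.sum(alpha * x, dim=1), dim=-1)
    #   >>> out.norm(dim=-1)   # ≈ tensor([1.])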

    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        seq_output = self.forward(item_seq)
        pos_items = interaction[self.POS_ITEM_ID]
        all_item_emb = self.item_embedding.weight
        all_item_emb = self.item_dropout(all_item_emb)
        # Robust Distance Measuring (RDM)
        all_item_emb = F.normalize(all_item_emb, dim=-1)
        logits = torch.matmul(seq_output, all_item_emb.transpose(0, 1)) / self.temperature
        loss = self.loss_fct(logits, pos_items)
        return loss
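
    # Since both seq_output and all_item_emb are L2-normalized, each logit
    # above is a cosine similarity scaled by 1/temperature, so the loss is
    # cross-entropy over temperature-scaled cosine similarities. Toy sketch
    # (hypothetical sizes and temperature):
    #
    #   >>> seq_output = F.normalize(torch.randn(2, 4), dim=-1)
    #   >>> all_item_emb = F.normalize(torch.randn(10, 4), dim=-1)
    #   >>> logits = seq_output @ all_item_emb.T / 0.07
    #   >>> loss = F.cross_entropy(logits, torch.tensor([3, 7]))  # scalar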

    def predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1) / self.temperature
        return scores

    def full_sort_predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        seq_output = self.forward(item_seq)
        test_item_emb = self.item_embedding.weight  # no dropout for evaluation
        test_item_emb = F.normalize(test_item_emb, dim=-1)
        scores = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) / self.temperature
        return scores
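

if __name__ == '__main__':
    # Minimal end-to-end usage sketch via RecBole's quick-start entry point
    # (not part of the original module). 'ml-100k' is fetched automatically;
    # the config values below are illustrative overrides, and CORE's full
    # defaults ship with RecBole's model properties.
    from recbole.quick_start import run_recbole

    run_recbole(
        model='CORE',
        dataset='ml-100k',
        config_dict={
            'dnn_type': 'trm',    # or 'ave' for the non-parametric encoder
            'temperature': 0.07,  # illustrative value
        },
    )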