# -*- encoding: utf-8 -*-
# @Time : 2020/09/01
# @Author : Kaiyuan Li
# @email : tsotfsk@outlook.com
# UPDATE
# @Time : 2020/10/14
# @Author : Kaiyuan Li
# @Email : tsotfsk@outlook.com
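r"""
NAIS
######################################
Reference:
    Xiangnan He et al. "NAIS: Neural Attentive Item Similarity Model for Recommendation." in TKDE 2018.

Reference code:
    https://github.com/AaronHeee/Neural-Attentive-Item-Similarity-Model
"""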
from logging import getLogger
import torch
import torch.nn as nn
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.layers import MLPLayers
from recbole.utils import InputType
from torch.nn.init import constant_, normal_, xavier_normal_
class NAIS(GeneralRecommender):
    """NAIS is an attention network, which is capable of distinguishing which historical items
    in a user profile are more important for a prediction. We just implement the model following
    the original author with a pointwise training mode.

    Note:
        Instead of forming a minibatch from all training instances of a randomly sampled user, as
        described in the original paper, we train the model on randomly sampled interactions.

    """
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(NAIS, self).__init__(config, dataset)

        # load dataset info
        self.LABEL = config['LABEL_FIELD']
        self.logger = getLogger()

        # get all users' history interaction information. The history item
        # matrix is padded to the maximum number of a user's interactions.
        self.history_item_matrix, self.history_lens, self.mask_mat = self.get_history_info(dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.weight_size = config['weight_size']
        self.algorithm = config['algorithm']
        self.reg_weights = config['reg_weights']
        self.alpha = config['alpha']
        self.beta = config['beta']
        self.split_to = config['split_to']
        self.pretrain_path = config['pretrain_path']

        # split the too large dataset into the specified pieces
        if self.split_to > 0:
            self.logger.info('split the n_items to {} pieces'.format(self.split_to))
            self.group = torch.chunk(torch.arange(self.n_items, device=self.device), self.split_to)
        else:
            self.logger.warning('Pay Attention!! the `split_to` is set to 0. If you catch an OOM error in this case, '
                                'you need to increase it \n\t\t\tuntil the error disappears. For example, '
                                'you can append it in the command line such as `--split_to=5`')

        # define layers and loss
        # construct source and destination item embedding matrix
        self.item_src_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.item_dst_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.bias = nn.Parameter(torch.zeros(self.n_items))
        if self.algorithm == 'concat':
            self.mlp_layers = MLPLayers([self.embedding_size * 2, self.weight_size])
        elif self.algorithm == 'prod':
            self.mlp_layers = MLPLayers([self.embedding_size, self.weight_size])
        else:
            raise ValueError("NAIS just support attention type in ['concat', 'prod'] but get {}".format(self.algorithm))
        self.weight_layer = nn.Parameter(torch.ones(self.weight_size, 1))
        self.bceloss = nn.BCELoss()

        # parameters initialization: either warm-start from a checkpoint or initialize from scratch
        if self.pretrain_path is not None:
            self.logger.info('use pretrain from [{}]...'.format(self.pretrain_path))
            self._load_pretrain()
        else:
            self.logger.info('unuse pretrain...')
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the module's parameters.

        Embedding weights are drawn from a normal distribution with mean 0 and std 0.01.
        Linear weights use xavier normal initialization and linear biases are zeroed.

        Note:
            The original implementation initializes with a truncated normal distribution;
            this port uses xavier normal for the linear layers instead.

        """
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data, 0, 0.01)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)

    def _load_pretrain(self):
        """Load pretrained item embeddings and re-initialize the attention MLP.

        The checkpoint at ``self.pretrain_path`` must contain a ``state_dict`` holding
        ``item_src_embedding.weight`` and ``item_dst_embedding.weight`` tensors shaped like this
        model's embeddings. NOTE(review): presumably produced by a FISM model — confirm.
        The MLP weights are xavier-normal initialized and its biases are zeroed.

        """
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        # torch.load unpickles the file, so only load checkpoints from trusted sources.
        fism = torch.load(self.pretrain_path, map_location=self.device)['state_dict']
        self.item_src_embedding.weight.data.copy_(fism['item_src_embedding.weight'])
        self.item_dst_embedding.weight.data.copy_(fism['item_dst_embedding.weight'])
        for name, parm in self.mlp_layers.named_parameters():
            if name.endswith('weight'):
                xavier_normal_(parm.data)
            elif name.endswith('bias'):
                constant_(parm.data, 0)

    def get_history_info(self, dataset):
        """Get the user history interaction information.

        Args:
            dataset (DataSet): train dataset

        Returns:
            tuple: (history_item_matrix, history_lens, mask_mat), all moved to ``self.device``.
            ``mask_mat`` is a float matrix with 1 where a column index is below the user's
            history length and 0 for padding.

        """
        history_item_matrix, _, history_lens = dataset.history_item_matrix()
        history_item_matrix = history_item_matrix.to(self.device)
        history_lens = history_lens.to(self.device)
        arange_tensor = torch.arange(history_item_matrix.shape[1], device=self.device)
        mask_mat = (arange_tensor < history_lens.unsqueeze(1)).float()
        return history_item_matrix, history_lens, mask_mat

    def reg_loss(self):
        """Calculate the regularization loss for the embedding layers and the MLP weights.

        Uses the (non-squared) L2 norm of each weight tensor, scaled by ``reg_weights``.

        Returns:
            torch.Tensor: reg loss

        """
        reg_1, reg_2, reg_3 = self.reg_weights
        loss_1 = reg_1 * self.item_src_embedding.weight.norm(2)
        loss_2 = reg_2 * self.item_dst_embedding.weight.norm(2)
        loss_3 = 0
        for name, parm in self.mlp_layers.named_parameters():
            if name.endswith('weight'):
                loss_3 = loss_3 + reg_3 * parm.norm(2)
        return loss_1 + loss_2 + loss_3

    def attention_mlp(self, inter, target):
        """Layers of attention which support `prod` and `concat`.

        Args:
            inter (torch.Tensor): the embedding of history items, batch_size x max_len x embedding_size
            target (torch.Tensor): the embedding of target items, batch_size x embedding_size

        Returns:
            torch.Tensor: the attention logits, batch_size x max_len

        """
        if self.algorithm == 'prod':
            mlp_input = inter * target.unsqueeze(1)  # batch_size x max_len x embedding_size
        else:
            # batch_size x max_len x embedding_size*2
            mlp_input = torch.cat([inter, target.unsqueeze(1).expand_as(inter)], dim=2)
        mlp_output = self.mlp_layers(mlp_input)  # batch_size x max_len x weight_size

        logits = torch.matmul(mlp_output, self.weight_layer).squeeze(2)  # batch_size x max_len
        return logits

    def _attention_weights(self, logits, mask=None):
        """Compute NAIS smoothed-softmax attention weights in log space.

        Mathematically equal to ``exp(logits) / sum(exp(logits)) ** beta`` (sums restricted to
        unmasked positions). It is evaluated as ``exp(logits - beta * logsumexp(logits))`` so that
        large logits cannot overflow to ``inf`` and turn the weights into ``inf / inf = nan``.

        Args:
            logits (torch.Tensor): attention logits, rows x positions
            mask (torch.Tensor, optional): same shape as ``logits``; 0 marks padded positions

        Returns:
            torch.Tensor: weights with the same shape as ``logits``; masked positions are exactly 0

        """
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))
        log_norm = torch.logsumexp(logits, dim=1, keepdim=True)
        # a row with no valid position has log_norm == -inf; give it all-zero weights instead of NaN
        log_norm = log_norm.masked_fill(log_norm == float('-inf'), 0.0)
        return torch.exp(logits - self.beta * log_norm)

    def mask_softmax(self, similarity, logits, bias, item_num, batch_mask_mat):
        """Softmax the unmasked user history items and get the final output.

        Args:
            similarity (torch.Tensor): the similarity between the history items and target items
            logits (torch.Tensor): the initial weights of the history items
            bias (torch.Tensor): bias
            item_num (torch.Tensor): user history interaction lengths
            batch_mask_mat (torch.Tensor): the mask of user history interactions

        Returns:
            torch.Tensor: final output

        """
        weights = self._attention_weights(logits, batch_mask_mat)  # batch_size x max_len
        coeff = torch.pow(item_num.squeeze(1), -self.alpha)
        output = torch.sigmoid(coeff.float() * torch.sum(weights * similarity, dim=1) + bias)
        return output

    def softmax(self, similarity, logits, item_num, bias):
        """Softmax the user history features and get the final output.

        Args:
            similarity (torch.Tensor): the similarity between the history items and target items
            logits (torch.Tensor): the initial weights of the history items
            item_num (torch.Tensor): user history interaction lengths
            bias (torch.Tensor): bias

        Returns:
            torch.Tensor: final output

        """
        weights = self._attention_weights(logits)  # target_items x inter_num
        coeff = torch.pow(item_num.squeeze(1), -self.alpha)
        output = torch.sigmoid(coeff.float() * torch.sum(weights * similarity, dim=1) + bias)
        return output

    def inter_forward(self, user, item):
        """Forward the model by interaction (one target item per user in the batch)."""
        user_inter = self.history_item_matrix[user]
        item_num = self.history_lens[user].unsqueeze(1)
        batch_mask_mat = self.mask_mat[user]
        user_history = self.item_src_embedding(user_inter)  # batch_size x max_len x embedding_size
        target = self.item_dst_embedding(item)  # batch_size x embedding_size
        bias = self.bias[item]  # batch_size
        similarity = torch.bmm(user_history, target.unsqueeze(2)).squeeze(2)  # batch_size x max_len
        logits = self.attention_mlp(user_history, target)
        scores = self.mask_softmax(similarity, logits, bias, item_num, batch_mask_mat)
        return scores

    def user_forward(self, user_input, item_num, repeats=None, pred_slc=None):
        """Forward the model by user (score many target items for one user).

        Args:
            user_input (torch.Tensor): the user's (unpadded) history item ids
            item_num (torch.Tensor): user history interaction length
            repeats (int, optional): the number of items to be evaluated; when None it is inferred
                as ``len(pred_slc)``, or ``n_items`` if ``pred_slc`` is None
            pred_slc (torch.Tensor, optional): continuous index which controls the current evaluation items,
                if pred_slc is None, it will evaluate all items

        Returns:
            torch.Tensor: result, one score per evaluated item

        """
        if repeats is None:
            repeats = self.n_items if pred_slc is None else len(pred_slc)
        item_num = item_num.repeat(repeats, 1)
        user_history = self.item_src_embedding(user_input)  # inter_num x embedding_size
        user_history = user_history.repeat(repeats, 1, 1)  # target_items x inter_num x embedding_size
        if pred_slc is None:
            targets = self.item_dst_embedding.weight  # target_items x embedding_size
            bias = self.bias
        else:
            targets = self.item_dst_embedding(pred_slc)
            bias = self.bias[pred_slc]
        similarity = torch.bmm(user_history, targets.unsqueeze(2)).squeeze(2)  # target_items x inter_num
        logits = self.attention_mlp(user_history, targets)
        scores = self.softmax(similarity, logits, item_num, bias)
        return scores

    def forward(self, user, item):
        return self.inter_forward(user, item)

    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]
        output = self.forward(user, item)
        loss = self.bceloss(output, label) + self.reg_loss()
        return loss

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch.

        Users are processed one at a time; when ``split_to > 0`` the item catalogue is scored in
        ``self.group`` chunks to bound memory use.

        Returns:
            torch.Tensor: scores concatenated user by user (n_users_in_batch * n_items values)

        """
        user = interaction[self.USER_ID]
        user_inters = self.history_item_matrix[user]
        item_nums = self.history_lens[user]
        scores = []

        # test users one by one, if the number of items is too large, we will split it to some pieces
        for user_input, item_num in zip(user_inters, item_nums.unsqueeze(1)):
            if self.split_to <= 0:
                output = self.user_forward(user_input[:item_num], item_num, repeats=self.n_items)
            else:
                output = []
                for mask in self.group:
                    tmp_output = self.user_forward(user_input[:item_num], item_num, repeats=len(mask), pred_slc=mask)
                    output.append(tmp_output)
                output = torch.cat(output, dim=0)
            scores.append(output)
        result = torch.cat(scores, dim=0)
        return result

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        output = self.forward(user, item)
        return output