# -*- encoding: utf-8 -*-
# @Time : 2020/09/01
# @Author : Kaiyuan Li
# @email : tsotfsk@outlook.com
# UPDATE:
# @Time : 2020/10/14
# @Author : Kaiyuan Li
# @Email : tsotfsk@outlook.com
"""
NAIS
######################################
Reference:
Xiangnan He et al. "NAIS: Neural Attentive Item Similarity Model for Recommendation." in TKDE 2018.
Reference code:
https://github.com/AaronHeee/Neural-Attentive-Item-Similarity-Model
"""
import torch
import torch.nn as nn
from torch.nn.init import constant_, normal_, xavier_normal_
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.layers import MLPLayers
from recbole.utils import InputType
[docs]class NAIS(GeneralRecommender):
"""NAIS is an attention network, which is capable of distinguishing which historical items
in a user profile are more important for a prediction. We just implement the model following
the original author with a pointwise training mode.
Note:
instead of forming a minibatch as all training instances of a randomly sampled user which is
mentioned in the original paper, we still train the model by a randomly sampled interactions.
"""
input_type = InputType.POINTWISE
def __init__(self, config, dataset):
super(NAIS, self).__init__(config, dataset)
# load dataset info
self.LABEL = config["LABEL_FIELD"]
# get all users' history interaction information.the history item
# matrix is padding by the maximum number of a user's interactions
(
self.history_item_matrix,
self.history_lens,
self.mask_mat,
) = self.get_history_info(dataset)
# load parameters info
self.embedding_size = config["embedding_size"]
self.weight_size = config["weight_size"]
self.algorithm = config["algorithm"]
self.reg_weights = config["reg_weights"]
self.alpha = config["alpha"]
self.beta = config["beta"]
self.split_to = config["split_to"]
self.pretrain_path = config["pretrain_path"]
# split the too large dataset into the specified pieces
if self.split_to > 0:
self.logger.info("split the n_items to {} pieces".format(self.split_to))
self.group = torch.chunk(
torch.arange(self.n_items).to(self.device), self.split_to
)
else:
self.logger.warning(
"Pay Attetion!! the `split_to` is set to 0. If you catch a OMM error in this case, "
+ "you need to increase it \n\t\t\tuntil the error disappears. For example, "
+ "you can append it in the command line such as `--split_to=5`"
)
# define layers and loss
# construct source and destination item embedding matrix
self.item_src_embedding = nn.Embedding(
self.n_items, self.embedding_size, padding_idx=0
)
self.item_dst_embedding = nn.Embedding(
self.n_items, self.embedding_size, padding_idx=0
)
self.bias = nn.Parameter(torch.zeros(self.n_items))
if self.algorithm == "concat":
self.mlp_layers = MLPLayers([self.embedding_size * 2, self.weight_size])
elif self.algorithm == "prod":
self.mlp_layers = MLPLayers([self.embedding_size, self.weight_size])
else:
raise ValueError(
"NAIS just support attention type in ['concat', 'prod'] but get {}".format(
self.algorithm
)
)
self.weight_layer = nn.Parameter(torch.ones(self.weight_size, 1))
self.bceloss = nn.BCEWithLogitsLoss()
# parameters initialization
if self.pretrain_path is not None:
self.logger.info("use pretrain from [{}]...".format(self.pretrain_path))
self._load_pretrain()
else:
self.logger.info("unused pretrain...")
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize the module's parameters
Note:
It's a little different from the source code, because pytorch has no function to initialize
the parameters by truncated normal distribution, so we replace it with xavier normal distribution
"""
if isinstance(module, nn.Embedding):
normal_(module.weight.data, 0, 0.01)
elif isinstance(module, nn.Linear):
xavier_normal_(module.weight.data)
if module.bias is not None:
constant_(module.bias.data, 0)
def _load_pretrain(self):
"""A simple implementation of loading pretrained parameters."""
fism = torch.load(self.pretrain_path)["state_dict"]
self.item_src_embedding.weight.data.copy_(fism["item_src_embedding.weight"])
self.item_dst_embedding.weight.data.copy_(fism["item_dst_embedding.weight"])
for name, parm in self.mlp_layers.named_parameters():
if name.endswith("weight"):
xavier_normal_(parm.data)
elif name.endswith("bias"):
constant_(parm.data, 0)
[docs] def get_history_info(self, dataset):
"""get the user history interaction information
Args:
dataset (DataSet): train dataset
Returns:
tuple: (history_item_matrix, history_lens, mask_mat)
"""
history_item_matrix, _, history_lens = dataset.history_item_matrix()
history_item_matrix = history_item_matrix.to(self.device)
history_lens = history_lens.to(self.device)
arange_tensor = torch.arange(history_item_matrix.shape[1]).to(self.device)
mask_mat = (arange_tensor < history_lens.unsqueeze(1)).float()
return history_item_matrix, history_lens, mask_mat
[docs] def reg_loss(self):
"""calculate the reg loss for embedding layers and mlp layers
Returns:
torch.Tensor: reg loss
"""
reg_1, reg_2, reg_3 = self.reg_weights
loss_1 = reg_1 * self.item_src_embedding.weight.norm(2)
loss_2 = reg_2 * self.item_dst_embedding.weight.norm(2)
loss_3 = 0
for name, parm in self.mlp_layers.named_parameters():
if name.endswith("weight"):
loss_3 = loss_3 + reg_3 * parm.norm(2)
return loss_1 + loss_2 + loss_3
[docs] def attention_mlp(self, inter, target):
"""layers of attention which support `prod` and `concat`
Args:
inter (torch.Tensor): the embedding of history items
target (torch.Tensor): the embedding of target items
Returns:
torch.Tensor: the result of attention
"""
if self.algorithm == "prod":
mlp_input = inter * target.unsqueeze(
1
) # batch_size x max_len x embedding_size
else:
mlp_input = torch.cat(
[inter, target.unsqueeze(1).expand_as(inter)], dim=2
) # batch_size x max_len x embedding_size*2
mlp_output = self.mlp_layers(mlp_input) # batch_size x max_len x weight_size
logits = torch.matmul(mlp_output, self.weight_layer).squeeze(
2
) # batch_size x max_len
return logits
[docs] def mask_softmax(self, similarity, logits, bias, item_num, batch_mask_mat):
"""softmax the unmasked user history items and get the final output
Args:
similarity (torch.Tensor): the similarity between the history items and target items
logits (torch.Tensor): the initial weights of the history items
item_num (torch.Tensor): user history interaction lengths
bias (torch.Tensor): bias
batch_mask_mat (torch.Tensor): the mask of user history interactions
Returns:
torch.Tensor: final output
"""
exp_logits = torch.exp(logits) # batch_size x max_len
exp_logits = batch_mask_mat * exp_logits # batch_size x max_len
exp_sum = torch.sum(exp_logits, dim=1, keepdim=True)
exp_sum = torch.pow(exp_sum, self.beta)
weights = torch.div(exp_logits, exp_sum)
coeff = torch.pow(item_num.squeeze(1), -self.alpha)
output = coeff.float() * torch.sum(weights * similarity, dim=1) + bias
return output
[docs] def softmax(self, similarity, logits, item_num, bias):
"""softmax the user history features and get the final output
Args:
similarity (torch.Tensor): the similarity between the history items and target items
logits (torch.Tensor): the initial weights of the history items
item_num (torch.Tensor): user history interaction lengths
bias (torch.Tensor): bias
Returns:
torch.Tensor: final output
"""
exp_logits = torch.exp(logits) # batch_size x max_len
exp_sum = torch.sum(exp_logits, dim=1, keepdim=True)
exp_sum = torch.pow(exp_sum, self.beta)
weights = torch.div(exp_logits, exp_sum)
coeff = torch.pow(item_num.squeeze(1), -self.alpha)
output = torch.sigmoid(
coeff.float() * torch.sum(weights * similarity, dim=1) + bias
)
return output
[docs] def inter_forward(self, user, item):
"""forward the model by interaction"""
user_inter = self.history_item_matrix[user]
item_num = self.history_lens[user].unsqueeze(1)
batch_mask_mat = self.mask_mat[user]
user_history = self.item_src_embedding(
user_inter
) # batch_size x max_len x embedding_size
target = self.item_dst_embedding(item) # batch_size x embedding_size
bias = self.bias[item] # batch_size x 1
similarity = torch.bmm(user_history, target.unsqueeze(2)).squeeze(
2
) # batch_size x max_len
logits = self.attention_mlp(user_history, target)
scores = self.mask_softmax(similarity, logits, bias, item_num, batch_mask_mat)
return scores
[docs] def user_forward(self, user_input, item_num, repeats=None, pred_slc=None):
"""forward the model by user
Args:
user_input (torch.Tensor): user input tensor
item_num (torch.Tensor): user history interaction lens
repeats (int, optional): the number of items to be evaluated
pred_slc (torch.Tensor, optional): continuous index which controls the current evaluation items,
if pred_slc is None, it will evaluate all items
Returns:
torch.Tensor: result
"""
item_num = item_num.repeat(repeats, 1)
user_history = self.item_src_embedding(user_input) # inter_num x embedding_size
user_history = user_history.repeat(
repeats, 1, 1
) # target_items x inter_num x embedding_size
if pred_slc is None:
targets = self.item_dst_embedding.weight # target_items x embedding_size
bias = self.bias
else:
targets = self.item_dst_embedding(pred_slc)
bias = self.bias[pred_slc]
similarity = torch.bmm(user_history, targets.unsqueeze(2)).squeeze(
2
) # inter_num x target_items
logits = self.attention_mlp(user_history, targets)
scores = self.softmax(similarity, logits, item_num, bias)
return scores
[docs] def forward(self, user, item):
return self.inter_forward(user, item)
[docs] def calculate_loss(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
label = interaction[self.LABEL]
output = self.forward(user, item)
loss = self.bceloss(output, label) + self.reg_loss()
return loss
[docs] def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
user_inters = self.history_item_matrix[user]
item_nums = self.history_lens[user]
scores = []
# test users one by one, if the number of items is too large, we will split it to some pieces
for user_input, item_num in zip(user_inters, item_nums.unsqueeze(1)):
if self.split_to <= 0:
output = self.user_forward(
user_input[:item_num], item_num, repeats=self.n_items
)
else:
output = []
for mask in self.group:
tmp_output = self.user_forward(
user_input[:item_num],
item_num,
repeats=len(mask),
pred_slc=mask,
)
output.append(tmp_output)
output = torch.cat(output, dim=0)
scores.append(output)
result = torch.cat(scores, dim=0)
return result
[docs] def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
output = torch.sigmoid(self.forward(user, item))
return output