# Source code for recbole.model.general_recommender.fism

# -*- encoding: utf-8 -*-
# @Time    :   2020/09/28
# @Author  :   Kaiyuan Li
# @email   :   tsotfsk@outlook.com

"""
FISM
#######################################
Reference:
    S. Kabbur et al. "FISM: Factored item similarity models for top-n recommender systems" in KDD 2013

Reference code:
    https://github.com/AaronHeee/Neural-Attentive-Item-Similarity-Model
"""

import torch
import torch.nn as nn
from torch.nn.init import normal_

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import InputType


class FISM(GeneralRecommender):
    """FISM is an item-based model for generating top-N recommendations that learns the
    item-item similarity matrix as the product of two low dimensional latent factor matrices.
    These matrices are learned using a structural equation modeling approach, wherein the value
    being estimated is not used for its own estimation.

    A score is ``sigmoid(n_u ** -alpha * sum_j <p_j, q_i> + b_u + b_i)`` where ``j`` runs over
    the user's history items, ``p`` is the source item embedding, ``q`` the destination item
    embedding, ``n_u`` the history length and ``b_u`` / ``b_i`` the user / item biases.
    """

    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        """Read the hyper-parameters from ``config`` and build embeddings, biases and loss.

        Args:
            config: configuration object indexed by key; the keys read here are
                ``LABEL_FIELD``, ``embedding_size``, ``reg_weights``, ``alpha`` and ``split_to``.
            dataset: train dataset; it must provide ``history_item_matrix()``.
        """
        super(FISM, self).__init__(config, dataset)

        # load dataset info
        self.LABEL = config['LABEL_FIELD']

        # get all users' history interaction information. the history item
        # matrix is padded to the maximum number of a user's interactions
        self.history_item_matrix, self.history_lens, self.mask_mat = self.get_history_info(dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.reg_weights = config['reg_weights']
        self.alpha = config['alpha']
        self.split_to = config['split_to']

        # split the item ids into `split_to` pieces so that full-sort evaluation can score
        # the items chunk by chunk; `self.group` only exists when `split_to` is positive
        if self.split_to > 0:
            self.group = torch.chunk(torch.arange(self.n_items).to(self.device), self.split_to)
        else:
            self.logger.warning(
                'Pay Attetion!! the `split_to` is set to 0. If you catch a OMM error in this case, '
                'you need to increase it \n\t\t\tuntil the error disappears. For example, '
                'you can append it in the command line such as `--split_to=5`'
            )

        # define layers and loss
        # construct source and destination item embedding matrix
        self.item_src_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.item_dst_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.user_bias = nn.Parameter(torch.zeros(self.n_users))
        self.item_bias = nn.Parameter(torch.zeros(self.n_items))
        self.bceloss = nn.BCELoss()

        # parameters initialization
        self.apply(self._init_weights)

    def get_history_info(self, dataset):
        """get the user history interaction information

        Args:
            dataset (DataSet): train dataset

        Returns:
            tuple: (history_item_matrix, history_lens, mask_mat), all moved to ``self.device``.
            ``mask_mat`` is a float matrix holding 1.0 at every column index smaller than the
            row's history length and 0.0 elsewhere.
        """
        history_item_matrix, _, history_lens = dataset.history_item_matrix()
        history_item_matrix = history_item_matrix.to(self.device)
        history_lens = history_lens.to(self.device)
        arange_tensor = torch.arange(history_item_matrix.shape[1]).to(self.device)
        # broadcast compare: column index < history length of the row
        mask_mat = (arange_tensor < history_lens.unsqueeze(1)).float()
        return history_item_matrix, history_lens, mask_mat

    def reg_loss(self):
        """calculate the reg loss for embedding layers

        Returns:
            torch.Tensor: weighted sum of the 2-norms of the source and destination
            item embedding matrices, weighted by the two entries of ``reg_weights``.
        """
        reg_1, reg_2 = self.reg_weights
        loss_1 = reg_1 * self.item_src_embedding.weight.norm(2)
        loss_2 = reg_2 * self.item_dst_embedding.weight.norm(2)
        return loss_1 + loss_2

    def _init_weights(self, module):
        """Initialize the module's parameters

        Note:
            It's a little different from the reference code, which uses a truncated normal
            distribution; here every embedding weight is drawn from a plain normal
            distribution with mean 0 and standard deviation 0.01.
            The whole weight tensor is overwritten, including the ``padding_idx`` row.
        """
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data, 0, 0.01)

    def _history_coeff(self, item_num):
        """Return ``item_num ** -alpha`` as a float tensor.

        History lengths are clamped to at least 1 so an empty history gives a coefficient
        of 1 instead of ``inf``. The similarity sum of an empty history is 0, so the score
        becomes the bias-only prediction rather than ``inf * 0 = NaN``.
        The cast to float happens before the power so an integer ``alpha`` also works.
        """
        return torch.pow(item_num.clamp(min=1).float(), -self.alpha)

    def inter_forward(self, user, item):
        """forward the model by interaction

        Args:
            user (torch.Tensor): user ids of the batch
            item (torch.Tensor): target item ids of the batch, aligned with ``user``

        Returns:
            torch.Tensor: one score in (0, 1) per (user, item) pair
        """
        user_inter = self.history_item_matrix[user]
        item_num = self.history_lens[user]
        batch_mask_mat = self.mask_mat[user]
        user_history = self.item_src_embedding(user_inter)  # batch_size x max_len x embedding_size
        target = self.item_dst_embedding(item)  # batch_size x embedding_size
        user_bias = self.user_bias[user]  # batch_size
        item_bias = self.item_bias[item]  # batch_size
        similarity = torch.bmm(user_history, target.unsqueeze(2)).squeeze(2)  # batch_size x max_len
        # zero out the positions that are padding in the history matrix
        similarity = batch_mask_mat * similarity
        coeff = self._history_coeff(item_num)
        scores = torch.sigmoid(coeff * torch.sum(similarity, dim=1) + user_bias + item_bias)
        return scores

    def user_forward(self, user_input, item_num, user_bias, repeats=None, pred_slc=None):
        """forward the model by user

        Args:
            user_input (torch.Tensor): history item ids of one user (padding already cut off)
            item_num (torch.Tensor): user history interaction lens
            user_bias (torch.Tensor): bias of this user
            repeats (int, optional): the number of items to be evaluated; if None it is taken
                from ``pred_slc``, or is the number of all items when ``pred_slc`` is None too
            pred_slc (torch.Tensor, optional): continuous index which controls the current
                evaluation items, if pred_slc is None, it will evaluate all items

        Returns:
            torch.Tensor: one score per evaluated item
        """
        if repeats is None:
            repeats = self.n_items if pred_slc is None else len(pred_slc)
        item_num = item_num.repeat(repeats, 1)
        user_history = self.item_src_embedding(user_input)  # inter_num x embedding_size
        user_history = user_history.repeat(repeats, 1, 1)  # target_items x inter_num x embedding_size
        if pred_slc is None:
            targets = self.item_dst_embedding.weight  # target_items x embedding_size
            item_bias = self.item_bias
        else:
            targets = self.item_dst_embedding(pred_slc)
            item_bias = self.item_bias[pred_slc]
        similarity = torch.bmm(user_history, targets.unsqueeze(2)).squeeze(2)  # target_items x inter_num
        coeff = self._history_coeff(item_num.squeeze(1))
        scores = torch.sigmoid(coeff * torch.sum(similarity, dim=1) + user_bias + item_bias)
        return scores

    def forward(self, user, item):
        """Score each (user, item) pair; delegates to :meth:`inter_forward`."""
        return self.inter_forward(user, item)

    def calculate_loss(self, interaction):
        """Return the binary cross entropy between the predicted scores and the labels
        of ``interaction``, plus the embedding regularization term of :meth:`reg_loss`.
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]
        output = self.forward(user, item)
        loss = self.bceloss(output, label) + self.reg_loss()
        return loss

    def full_sort_predict(self, interaction):
        """Score every item for every user of ``interaction``.

        Returns:
            torch.Tensor: the per-user score vectors over all items, concatenated into one
            flat tensor in the order the users appear in the batch.
        """
        user = interaction[self.USER_ID]
        batch_user_bias = self.user_bias[user]
        user_inters = self.history_item_matrix[user]
        item_nums = self.history_lens[user]
        scores = []

        # test users one by one, if the number of items is too large, we will split it to some pieces
        for user_input, item_num, user_bias in zip(user_inters, item_nums.unsqueeze(1), batch_user_bias):
            # keep only the real history items, dropping the padding at the end of the row
            valid_history = user_input[:item_num]
            if self.split_to <= 0:
                output = self.user_forward(valid_history, item_num, user_bias, repeats=self.n_items)
            else:
                output = []
                for mask in self.group:
                    tmp_output = self.user_forward(
                        valid_history, item_num, user_bias, repeats=len(mask), pred_slc=mask
                    )
                    output.append(tmp_output)
                output = torch.cat(output, dim=0)
            scores.append(output)
        result = torch.cat(scores, dim=0)
        return result

    def predict(self, interaction):
        """Return the scores of the (user, item) pairs contained in ``interaction``."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        output = self.forward(user, item)
        return output