Source code for recbole.model.general_recommender.pop

# -*- coding: utf-8 -*-
# @Time   : 2020/8/11 9:57
# @Author : Zihan Lin
# @Email  : linzihan.super@foxmail.com
# UPDATE
# @Time   : 2020/11/9
# @Author : Zihan Lin
# @Email  : zhlin@ruc.edu.cn
r"""
Pop
################################################

"""

import torch

from recbole.utils import InputType, ModelType
from recbole.model.abstract_recommender import GeneralRecommender


[docs]class Pop(GeneralRecommender): r"""Pop is an fundamental model that always recommend the most popular item. """ input_type = InputType.POINTWISE type = ModelType.TRADITIONAL def __init__(self, config, dataset): super(Pop, self).__init__(config, dataset) self.item_cnt = torch.zeros(self.n_items, 1, dtype=torch.long, device=self.device, requires_grad=False) self.max_cnt = None self.fake_loss = torch.nn.Parameter(torch.zeros(1))
[docs] def forward(self): pass
[docs] def calculate_loss(self, interaction): item = interaction[self.ITEM_ID] self.item_cnt[item, :] = self.item_cnt[item, :] + 1 self.max_cnt = torch.max(self.item_cnt, dim=0)[0] return torch.nn.Parameter(torch.zeros(1))
[docs] def predict(self, interaction): item = interaction[self.ITEM_ID] result = torch.true_divide(self.item_cnt[item, :], self.max_cnt) return result.squeeze()
[docs] def full_sort_predict(self, interaction): batch_user_num = interaction[self.USER_ID].shape[0] result = self.item_cnt.to(torch.float64) / self.max_cnt.to(torch.float64) result = torch.repeat_interleave(result.unsqueeze(0), batch_user_num, dim=0) return result.view(-1)