# -*- coding: utf-8 -*-
# @Time : 2022/3/25 13:38
# @Author : HaoJun Qin
# @Email : 18697951462@163.com
r"""
SimpleX
################################################
Reference:
Kelong Mao et al. "SimpleX: A Simple and Strong Baseline for Collaborative Filtering." in CIKM 2021.
Reference code:
https://github.com/xue-pai/TwoToweRS
"""
import torch
from torch import nn
import torch.nn.functional as F
from recbole.model.init import xavier_normal_initialization
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.loss import EmbLoss
from recbole.utils import InputType

class SimpleX(GeneralRecommender):
r"""SimpleX is a simple, unified collaborative filtering model.
SimpleX presents a simple and easy-to-understand model. Its advantage lies
in its loss function, which uses a larger number of negative samples and
sets a threshold to filter out less informative samples, it also uses
relative weights to control the balance of positive-sample loss
and negative-sample loss.
We implement the model following the original author with a pairwise training mode.
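
    Example (a minimal sketch of the cosine contrastive loss (CCL) on toy
    scores; the numbers are illustrative, not taken from the paper)::

        >>> import torch
        >>> pos_cos = torch.tensor([[0.8]])        # cos(u, i+) for one user
        >>> neg_cos = torch.tensor([[0.7, 0.1]])   # cos(u, i-) for two negatives
        >>> margin, negative_weight = 0.5, 10
        >>> pos_loss = torch.relu(1 - pos_cos)
        >>> neg_loss = torch.relu(neg_cos - margin).mean(1, keepdim=True) * negative_weight
        >>> (pos_loss + neg_loss).mean()
        tensor(1.2000)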
"""
input_type = InputType.PAIRWISE
def __init__(self, config, dataset):
super(SimpleX, self).__init__(config, dataset)
# Get user history interacted items
self.history_item_id, _, self.history_item_len = dataset.history_item_matrix(
max_history_len=config["history_len"]
)
self.history_item_id = self.history_item_id.to(self.device)
self.history_item_len = self.history_item_len.to(self.device)
# load parameters info
self.embedding_size = config["embedding_size"]
self.margin = config["margin"]
self.negative_weight = config["negative_weight"]
self.gamma = config["gamma"]
self.neg_seq_len = config["train_neg_sample_args"]["sample_num"]
self.reg_weight = config["reg_weight"]
self.aggregator = config["aggregator"]
if self.aggregator not in ["mean", "user_attention", "self_attention"]:
raise ValueError(
"aggregator must be mean, user_attention or self_attention"
)
        # maximum history length over all users (torch.max with dim returns a
        # (values, indices) namedtuple, so keep only the values)
        self.history_len = torch.max(self.history_item_len, dim=0).values
# user embedding matrix
self.user_emb = nn.Embedding(self.n_users, self.embedding_size)
# item embedding matrix
self.item_emb = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
# feature space mapping matrix of user and item
self.UI_map = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
if self.aggregator in ["user_attention", "self_attention"]:
self.W_k = nn.Sequential(
nn.Linear(self.embedding_size, self.embedding_size), nn.Tanh()
)
if self.aggregator == "self_attention":
self.W_q = nn.Linear(self.embedding_size, 1, bias=False)
# dropout
self.dropout = nn.Dropout(0.1)
self.require_pow = config["require_pow"]
# l2 regularization loss
self.reg_loss = EmbLoss()
# parameters initialization
self.apply(xavier_normal_initialization)
        # zero the padding item's embedding so padded history positions can be
        # masked out in the aggregators
        self.item_emb.weight.data[0, :] = 0

    def get_UI_aggregation(self, user_e, history_item_e, history_len):
r"""Get the combined vector of user and historically interacted items
Args:
user_e (torch.Tensor): User's feature vector, shape: [user_num, embedding_size]
history_item_e (torch.Tensor): History item's feature vector,
shape: [user_num, max_history_len, embedding_size]
history_len (torch.Tensor): User's history length, shape: [user_num]
Returns:
torch.Tensor: Combined vector of user and item sequences, shape: [user_num, embedding_size]
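
        Example (a minimal sketch of the ``mean`` aggregator on toy tensors;
        the all-zero row plays the role of a padded position)::

            >>> import torch
            >>> history_item_e = torch.tensor([[[1., 1.], [3., 3.], [0., 0.]]])
            >>> history_len = torch.tensor([2])
            >>> history_item_e.sum(dim=1) / (history_len + 1.0e-10).unsqueeze(1)
            tensor([[2., 2.]])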
"""
if self.aggregator == "mean":
pos_item_sum = history_item_e.sum(dim=1)
# [user_num, embedding_size]
out = pos_item_sum / (history_len + 1.0e-10).unsqueeze(1)
elif self.aggregator in ["user_attention", "self_attention"]:
# [user_num, max_history_len, embedding_size]
key = self.W_k(history_item_e)
if self.aggregator == "user_attention":
# [user_num, max_history_len]
attention = torch.matmul(key, user_e.unsqueeze(2)).squeeze(2)
elif self.aggregator == "self_attention":
# [user_num, max_history_len]
attention = self.W_q(key).squeeze(2)
e_attention = torch.exp(attention)
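            # mask out padded history positions (their embeddings are all zeros,
            # since the padding row of item_emb is zeroed in __init__)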
mask = (history_item_e.sum(dim=-1) != 0).int()
e_attention = e_attention * mask
# [user_num, max_history_len]
attention_weight = e_attention / (
e_attention.sum(dim=1, keepdim=True) + 1.0e-10
)
# [user_num, embedding_size]
out = torch.matmul(attention_weight.unsqueeze(1), history_item_e).squeeze(1)
# Combined vector of user and item sequences
out = self.UI_map(out)
g = self.gamma
UI_aggregation_e = g * user_e + (1 - g) * out
return UI_aggregation_e

    def get_cos(self, user_e, item_e):
r"""Get the cosine similarity between user and item
Args:
user_e (torch.Tensor): User's feature vector, shape: [user_num, embedding_size]
item_e (torch.Tensor): Item's feature vector,
shape: [user_num, item_num, embedding_size]
Returns:
torch.Tensor: Cosine similarity between user and item, shape: [user_num, item_num]
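
        Example (a minimal sketch on toy vectors; an aligned and an opposite
        item give cosine scores near ``1`` and ``-1``)::

            >>> import torch
            >>> import torch.nn.functional as F
            >>> user_e = F.normalize(torch.tensor([[3., 4.]]), dim=1)
            >>> item_e = F.normalize(torch.tensor([[[3., 4.], [-3., -4.]]]), dim=2)
            >>> torch.matmul(item_e, user_e.unsqueeze(2)).squeeze(2).shape
            torch.Size([1, 2])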
"""
user_e = F.normalize(user_e, dim=1)
# [user_num, embedding_size, 1]
user_e = user_e.unsqueeze(2)
item_e = F.normalize(item_e, dim=2)
UI_cos = torch.matmul(item_e, user_e)
return UI_cos.squeeze(2)

    def forward(self, user, pos_item, history_item, history_len, neg_item_seq):
r"""Get the loss
Args:
user (torch.Tensor): User's id, shape: [user_num]
pos_item (torch.Tensor): Positive item's id, shape: [user_num]
            history_item (torch.Tensor): Ids of history items, shape: [user_num, max_history_len]
history_len (torch.Tensor): History item's length, shape: [user_num]
neg_item_seq (torch.Tensor): Negative item seq's id, shape: [user_num, neg_seq_len]
Returns:
torch.Tensor: Loss, shape: []
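
        The objective implemented below is the cosine contrastive loss (CCL):

        .. math::
            \mathcal{L}_{CCL} = \mathbb{E}\left[\left(1 - \cos(u, i^{+})\right)_{+}
            + \frac{w}{|\mathcal{N}|}\sum_{j \in \mathcal{N}}
            \left(\cos(u, j) - m\right)_{+}\right]

        where :math:`m` is ``margin`` and :math:`w` is ``negative_weight``.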
"""
# [user_num, embedding_size]
user_e = self.user_emb(user)
# [user_num, embedding_size]
pos_item_e = self.item_emb(pos_item)
# [user_num, max_history_len, embedding_size]
history_item_e = self.item_emb(history_item)
        # [user_num, neg_seq_len, embedding_size]
neg_item_seq_e = self.item_emb(neg_item_seq)
# [user_num, embedding_size]
UI_aggregation_e = self.get_UI_aggregation(user_e, history_item_e, history_len)
UI_aggregation_e = self.dropout(UI_aggregation_e)
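        # cosine scores: pos_cos is [user_num, 1], neg_cos is [user_num, neg_seq_len]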
pos_cos = self.get_cos(UI_aggregation_e, pos_item_e.unsqueeze(1))
neg_cos = self.get_cos(UI_aggregation_e, neg_item_seq_e)
# CCL loss
pos_loss = torch.relu(1 - pos_cos)
neg_loss = torch.relu(neg_cos - self.margin)
neg_loss = neg_loss.mean(1, keepdim=True) * self.negative_weight
CCL_loss = (pos_loss + neg_loss).mean()
# l2 regularization loss
reg_loss = self.reg_loss(
user_e,
pos_item_e,
history_item_e,
neg_item_seq_e,
require_pow=self.require_pow,
)
loss = CCL_loss + self.reg_weight * reg_loss.sum()
return loss

    def calculate_loss(self, interaction):
r"""Data processing and call function forward(), return loss
To use SimpleX, a user must have a historical transaction record,
a pos item and a sequence of neg items. Based on the RecBole
framework, the data in the interaction object is ordered, so
we can get the data quickly.
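
        Example (a sketch of how the grouped negatives are recovered; the ids
        are illustrative, assuming ``neg_seq_len = 2`` and two users)::

            >>> import torch
            >>> neg_item = torch.tensor([11, 21, 12, 22])  # [u1_n1, u2_n1, u1_n2, u2_n2]
            >>> neg_item.reshape((2, -1)).T
            tensor([[11, 12],
                    [21, 22]])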
"""
user = interaction[self.USER_ID]
pos_item = interaction[self.ITEM_ID]
neg_item = interaction[self.NEG_ITEM_ID]
        # recover the per-user negative sequences: [user_num, neg_seq_len]
        neg_item_seq = neg_item.reshape((self.neg_seq_len, -1))
        neg_item_seq = neg_item_seq.T
        user_number = int(len(user) / self.neg_seq_len)
# user's id
user = user[0:user_number]
# historical transaction record
history_item = self.history_item_id[user]
# positive item's id
pos_item = pos_item[0:user_number]
# history_len
history_len = self.history_item_len[user]
loss = self.forward(user, pos_item, history_item, history_len, neg_item_seq)
return loss

    def predict(self, interaction):
user = interaction[self.USER_ID]
history_item = self.history_item_id[user]
history_len = self.history_item_len[user]
test_item = interaction[self.ITEM_ID]
# [user_num, embedding_size]
user_e = self.user_emb(user)
# [user_num, embedding_size]
test_item_e = self.item_emb(test_item)
# [user_num, max_history_len, embedding_size]
history_item_e = self.item_emb(history_item)
# [user_num, embedding_size]
UI_aggregation_e = self.get_UI_aggregation(user_e, history_item_e, history_len)
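        # cosine similarity with the single test item: [user_num, 1]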
UI_cos = self.get_cos(UI_aggregation_e, test_item_e.unsqueeze(1))
return UI_cos.squeeze(1)

    def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
history_item = self.history_item_id[user]
history_len = self.history_item_len[user]
# [user_num, embedding_size]
user_e = self.user_emb(user)
# [user_num, max_history_len, embedding_size]
history_item_e = self.item_emb(history_item)
# [user_num, embedding_size]
UI_aggregation_e = self.get_UI_aggregation(user_e, history_item_e, history_len)
UI_aggregation_e = F.normalize(UI_aggregation_e, dim=1)
all_item_emb = self.item_emb.weight
all_item_emb = F.normalize(all_item_emb, dim=1)
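        # cosine similarity against every item: [user_num, n_items]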
UI_cos = torch.matmul(UI_aggregation_e, all_item_emb.T)
return UI_cos