Source code for recbole.model.general_recommender.spectralcf

# -*- coding: utf-8 -*-
# @Time   : 2020/10/2
# @Author : Changxin Tian
# @Email  : cx.tian@outlook.com

"""
SpectralCF
################################################

Reference:
    Lei Zheng et al. "Spectral collaborative filtering." in RecSys 2018.

Reference code:
    https://github.com/lzheng21/SpectralCF
"""

import numpy as np
import scipy.sparse as sp
import torch

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType


class SpectralCF(GeneralRecommender):
    r"""SpectralCF is a spectral convolution model that directly learns latent factors of users and items from
    the spectral domain for recommendation.

    The spectral convolution operation with C input channels and F filters is shown as the following:

    .. math::
        \left[\begin{array} {c} X_{new}^{u} \\
        X_{new}^{i} \end{array}\right]=\sigma\left(\left(U U^{\top}+U \Lambda U^{\top}\right)
        \left[\begin{array}{c} X^{u} \\
        X^{i} \end{array}\right] \Theta^{\prime}\right)

    where :math:`X_{new}^{u} \in R^{n_{users} \times F}` and :math:`X_{new}^{i} \in R^{n_{items} \times F}`
    denote convolution results learned with F filters from the spectral domain for users and items, respectively;
    :math:`\sigma` denotes the logistic sigmoid function.

    Note:

        Our implementation is an improved version which is different from the original paper.
        For better stability, we replace :math:`U U^T` with the identity matrix :math:`I` and
        replace :math:`U \Lambda U^T` with the Laplacian matrix :math:`L`.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # load parameters info
        self.n_layers = config["n_layers"]
        self.emb_dim = config["embedding_size"]
        self.reg_weight = config["reg_weight"]

        # generate intermediate data
        # "A_hat = I + L" stands in for "A_hat = U U^T + U \Lambda U^T" (see class Note)
        self.interaction_matrix = dataset.inter_matrix(form="coo").astype(np.float32)
        eye = self.get_eye_mat(self.n_items + self.n_users)
        laplacian = self.get_laplacian_matrix()
        self.A_hat = (eye + laplacian).to(self.device)

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(
            num_embeddings=self.n_users, embedding_dim=self.emb_dim
        )
        self.item_embedding = torch.nn.Embedding(
            num_embeddings=self.n_items, embedding_dim=self.emb_dim
        )
        # one (emb_dim x emb_dim) filter matrix Theta_k per convolution layer
        self.filters = torch.nn.ParameterList(
            [
                torch.nn.Parameter(
                    torch.normal(
                        mean=0.01, std=0.02, size=(self.emb_dim, self.emb_dim)
                    ),
                    requires_grad=True,
                )
                for _ in range(self.n_layers)
            ]
        )

        self.sigmoid = torch.nn.Sigmoid()
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # cache of the final user/item embeddings, filled lazily by full_sort_predict
        # and cleared by calculate_loss
        self.restore_user_e = None
        self.restore_item_e = None

        # non-Parameter attributes listed for the base class (presumably saved/restored
        # with the model state — confirm against GeneralRecommender)
        self.other_parameter_name = ["restore_user_e", "restore_item_e"]

        # parameters initialization
        self.apply(xavier_uniform_initialization)

    def get_laplacian_matrix(self):
        r"""Get the Laplacian matrix of users and items.

        .. math::
            L = I - D^{-1} \times A

        where :math:`A` is the binary, symmetric user-item bipartite adjacency matrix of shape
        ``(n_users + n_items, n_users + n_items)`` (users first, then items) and :math:`D` is its
        diagonal degree matrix.

        Returns:
            Sparse tensor of the Laplacian matrix.
        """
        n_nodes = self.n_users + self.n_items
        inter_M = self.interaction_matrix

        # Build the adjacency matrix directly in COO form: user u -> item i is stored at
        # (u, n_users + i) and its mirror at (n_users + i, u). This avoids the private
        # ``dok_matrix._update`` helper, which no longer exists in recent SciPy releases.
        row = np.concatenate([inter_M.row, inter_M.col + self.n_users])
        col = np.concatenate([inter_M.col + self.n_users, inter_M.row])
        data = np.ones(row.shape[0], dtype=np.float32)
        A = sp.coo_matrix((data, (row, col)), shape=(n_nodes, n_nodes)).tocsr()
        # tocsr() sums duplicate (row, col) pairs; reset every stored entry to 1 so the
        # adjacency stays binary even if the same interaction appears more than once.
        A.data = np.ones_like(A.data)

        # row-normalise: D^{-1} A. The epsilon keeps isolated nodes from dividing by zero
        # (their rows in A are empty, so the huge inverse degree has no effect).
        degree = np.asarray((A > 0).sum(axis=1)).flatten() + 1e-7
        D = sp.diags(np.power(degree, -1))
        A_tilde_coo = sp.coo_matrix(D @ A)

        # convert the normalised adjacency matrix to a torch sparse tensor
        indices = torch.from_numpy(
            np.vstack([A_tilde_coo.row, A_tilde_coo.col]).astype(np.int64)
        )
        values = torch.from_numpy(A_tilde_coo.data.astype(np.float32))
        A_tilde = torch.sparse_coo_tensor(indices, values, A_tilde_coo.shape)

        # generate Laplacian matrix
        return self.get_eye_mat(n_nodes) - A_tilde

    def get_eye_mat(self, num):
        r"""Construct a sparse identity matrix of size ``num x num``.

        Args:
            num (int): number of rows (and columns) of the square matrix; this class calls it
                with ``n_items + n_users``.

        Returns:
            Sparse float32 tensor of the identity matrix, shape ``(num, num)``.
        """
        diag_idx = torch.arange(num, dtype=torch.long)
        indices = torch.stack([diag_idx, diag_idx])
        values = torch.ones(num, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values, (num, num))

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.

        Returns:
            Tensor of the embedding matrix. Shape of (n_users+n_items, embedding_dim); user rows first.
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        return torch.cat([user_embeddings, item_embeddings], dim=0)

    def forward(self):
        r"""Run all spectral convolution layers.

        Each layer computes :math:`\sigma(\hat{A} X \Theta_k)`. The input embeddings and every
        layer's output are concatenated along the feature dimension.

        Returns:
            Tuple ``(user_all_embeddings, item_all_embeddings)`` with shapes
            ``(n_users, emb_dim * (n_layers + 1))`` and ``(n_items, emb_dim * (n_layers + 1))``.
        """
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]

        for theta in self.filters:
            all_embeddings = torch.sparse.mm(self.A_hat, all_embeddings)
            all_embeddings = self.sigmoid(torch.mm(all_embeddings, theta))
            embeddings_list.append(all_embeddings)

        new_embeddings = torch.cat(embeddings_list, dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(
            new_embeddings, [self.n_users, self.n_items]
        )
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        r"""Compute the BPR loss plus weighted embedding regularisation for a batch.

        Args:
            interaction: batch providing ``USER_ID``, ``ITEM_ID`` and ``NEG_ITEM_ID`` fields.

        Returns:
            Scalar loss tensor ``mf_loss + reg_weight * reg_loss``.
        """
        # parameters are about to change, so the cached full-sort embeddings become stale
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)
        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)

        return mf_loss + self.reg_weight * reg_loss

    def predict(self, interaction):
        r"""Score the given (user, item) pairs by the dot product of their final embeddings.

        Returns:
            1-D tensor with one score per pair in the batch.
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        return torch.mul(u_embeddings, i_embeddings).sum(dim=1)

    def full_sort_predict(self, interaction):
        r"""Score every item for each user in the batch.

        The final embeddings are computed once and cached in ``restore_user_e`` /
        ``restore_item_e`` until ``calculate_loss`` clears them.

        Returns:
            Flattened tensor of shape ``(batch_size * n_items,)``.
        """
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        u_embeddings = self.restore_user_e[user]

        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))
        return scores.view(-1)