# -*- coding: utf-8 -*-
# @Time   : 2021/10/12
# @Author : Tian Zhen
# @Email  : chenyuwuxinn@gmail.com

r"""
SGL
################################################
Reference:
    Jiancan Wu et al. "Self-supervised Graph Learning for Recommendation" in SIGIR 2021.

Reference code:
    https://github.com/wujcan/SGL
"""

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType


class SGL(GeneralRecommender):
    r"""SGL is a GCN-based recommender model.

    SGL supplements the classical supervised task of recommendation with an
    auxiliary self-supervised task, which reinforces node representation
    learning via self-discrimination. Specifically, SGL generates multiple
    views of a node, maximizing the agreement between different views of the
    same node compared to that of other nodes. SGL devises three operators to
    generate the views — node dropout, edge dropout, and random walk — that
    change the graph structure in different manners.

    We implement the model following the original author with a pairwise
    training mode.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SGL, self).__init__(config, dataset)

        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        self.embed_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.type = config["type"]
        self.drop_ratio = config["drop_ratio"]
        self.ssl_tau = config["ssl_tau"]
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]

        self.user_embedding = torch.nn.Embedding(self.n_users, self.embed_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.embed_dim)
        self.reg_loss = EmbLoss()

        self.train_graph = self.csr2tensor(self.create_adjust_matrix(is_sub=False))
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ["restore_user_e", "restore_item_e"]
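    # Note (added commentary, not in the original source): ``type`` selects
    # the augmentation operator from the SGL paper: "ND" (node dropout),
    # "ED" (edge dropout), or "RW" (random walk, implemented as edge dropout
    # applied independently per layer). ``drop_ratio`` is the fraction of
    # nodes/edges discarded when sampling a subgraph.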
    def graph_construction(self):
        r"""Devise three operators to generate the views — node dropout, edge
        dropout, and random walk of a node.
        """
        self.sub_graph1 = []
        if self.type == "ND" or self.type == "ED":
            self.sub_graph1 = self.csr2tensor(self.create_adjust_matrix(is_sub=True))
        elif self.type == "RW":
            for i in range(self.n_layers):
                _g = self.csr2tensor(self.create_adjust_matrix(is_sub=True))
                self.sub_graph1.append(_g)

        self.sub_graph2 = []
        if self.type == "ND" or self.type == "ED":
            self.sub_graph2 = self.csr2tensor(self.create_adjust_matrix(is_sub=True))
        elif self.type == "RW":
            for i in range(self.n_layers):
                _g = self.csr2tensor(self.create_adjust_matrix(is_sub=True))
                self.sub_graph2.append(_g)
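    # Added commentary: the two subgraphs built above are the two "views"
    # contrasted in calc_ssl_loss. For "ND"/"ED" each view is a single graph
    # shared by all propagation layers; for "RW" each view is a list holding
    # one freshly sampled graph per layer, which forward() consumes layer by
    # layer.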
    def rand_sample(self, high, size=None, replace=True):
        r"""Randomly discard some points or edges.

        Args:
            high (int): Upper limit of index value
            size (int): Array size after sampling

        Returns:
            numpy.ndarray: Array index after sampling, shape: [size]
        """
        a = np.arange(high)
        sample = np.random.choice(a, size=size, replace=replace)
        return sample
    def create_adjust_matrix(self, is_sub: bool):
        r"""Get the normalized interaction matrix of users and items.

        Construct the square matrix from the training data and normalize it
        using the Laplace matrix. If it is a subgraph, it may be processed by
        node dropout or edge dropout.

        .. math::
            A_{hat} = D^{-0.5} \times A \times D^{-0.5}

        Returns:
            csr_matrix of the normalized interaction matrix.
        """
        matrix = None
        if not is_sub:
            ratings = np.ones_like(self._user, dtype=np.float32)
            matrix = sp.csr_matrix(
                (ratings, (self._user, self._item + self.n_users)),
                shape=(self.n_users + self.n_items, self.n_users + self.n_items),
            )
        else:
            if self.type == "ND":
                drop_user = self.rand_sample(
                    self.n_users,
                    size=int(self.n_users * self.drop_ratio),
                    replace=False,
                )
                drop_item = self.rand_sample(
                    self.n_items,
                    size=int(self.n_items * self.drop_ratio),
                    replace=False,
                )
                R_user = np.ones(self.n_users, dtype=np.float32)
                R_user[drop_user] = 0.0
                R_item = np.ones(self.n_items, dtype=np.float32)
                R_item[drop_item] = 0.0
                R_user = sp.diags(R_user)
                R_item = sp.diags(R_item)

                R_G = sp.csr_matrix(
                    (
                        np.ones_like(self._user, dtype=np.float32),
                        (self._user, self._item),
                    ),
                    shape=(self.n_users, self.n_items),
                )
                res = R_user.dot(R_G)
                res = res.dot(R_item)

                user, item = res.nonzero()
                ratings = res.data
                matrix = sp.csr_matrix(
                    (ratings, (user, item + self.n_users)),
                    shape=(self.n_users + self.n_items, self.n_users + self.n_items),
                )
            elif self.type == "ED" or self.type == "RW":
                keep_item = self.rand_sample(
                    len(self._user),
                    size=int(len(self._user) * (1 - self.drop_ratio)),
                    replace=False,
                )
                user = self._user[keep_item]
                item = self._item[keep_item]

                matrix = sp.csr_matrix(
                    (np.ones_like(user), (user, item + self.n_users)),
                    shape=(self.n_users + self.n_items, self.n_users + self.n_items),
                )

        matrix = matrix + matrix.T
        D = np.array(matrix.sum(axis=1)) + 1e-7
        D = np.power(D, -0.5).flatten()
        D = sp.diags(D)
        return D.dot(matrix).dot(D)
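    # A tiny worked example of the symmetric normalization above (values are
    # illustrative, not from the original source): if user u has degree 2 and
    # item i has degree 1, the entry for edge (u, i) in D^{-0.5} A D^{-0.5}
    # becomes 1 / (sqrt(2) * sqrt(1)) ≈ 0.7071.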
    def csr2tensor(self, matrix: sp.csr_matrix):
        r"""Convert csr_matrix to tensor.

        Args:
            matrix (scipy.csr_matrix): Sparse matrix to be converted.

        Returns:
            torch.sparse.FloatTensor: Transformed sparse matrix.
        """
        matrix = matrix.tocoo()
        x = torch.sparse.FloatTensor(
            torch.LongTensor(np.array([matrix.row, matrix.col])),
            torch.FloatTensor(matrix.data.astype(np.float32)),
            matrix.shape,
        ).to(self.device)
        return x
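    # Added note: ``torch.sparse.FloatTensor`` is the legacy constructor. On
    # recent PyTorch versions an equivalent call (a sketch, not part of the
    # original source) would be:
    #
    #     torch.sparse_coo_tensor(
    #         torch.LongTensor(np.array([matrix.row, matrix.col])),
    #         torch.FloatTensor(matrix.data.astype(np.float32)),
    #         matrix.shape,
    #     ).to(self.device)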
    def forward(self, graph):
        main_ego = torch.cat([self.user_embedding.weight, self.item_embedding.weight])
        all_ego = [main_ego]
        if isinstance(graph, list):
            for sub_graph in graph:
                main_ego = torch.sparse.mm(sub_graph, main_ego)
                all_ego.append(main_ego)
        else:
            for i in range(self.n_layers):
                main_ego = torch.sparse.mm(graph, main_ego)
                all_ego.append(main_ego)
        all_ego = torch.stack(all_ego, dim=1)
        all_ego = torch.mean(all_ego, dim=1, keepdim=False)
        user_emd, item_emd = torch.split(all_ego, [self.n_users, self.n_items], dim=0)

        return user_emd, item_emd
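    # Added commentary: forward() is LightGCN-style propagation. Starting
    # from the ego embeddings E^(0), it applies E^(l+1) = A_hat * E^(l) for
    # n_layers steps and returns the layer-wise mean,
    # E = mean(E^(0), ..., E^(L)). Passing a list of graphs (the "RW" case)
    # simply uses a different A_hat at each layer.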
    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]

        user_emd, item_emd = self.forward(self.train_graph)
        user_sub1, item_sub1 = self.forward(self.sub_graph1)
        user_sub2, item_sub2 = self.forward(self.sub_graph2)
        total_loss = self.calc_bpr_loss(
            user_emd, item_emd, user_list, pos_item_list, neg_item_list
        ) + self.calc_ssl_loss(
            user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2
        )
        return total_loss
    def calc_bpr_loss(
        self, user_emd, item_emd, user_list, pos_item_list, neg_item_list
    ):
        r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss
        and parameter regularization loss.

        Args:
            user_emd (torch.Tensor): Ego embedding of all users after forwarding.
            item_emd (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.

        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = user_emd[user_list]
        pi_e = item_emd[pos_item_list]
        ni_e = item_emd[neg_item_list]
        p_scores = torch.mul(u_e, pi_e).sum(dim=1)
        n_scores = torch.mul(u_e, ni_e).sum(dim=1)
        l1 = torch.sum(-F.logsigmoid(p_scores - n_scores))

        u_e_p = self.user_embedding(user_list)
        pi_e_p = self.item_embedding(pos_item_list)
        ni_e_p = self.item_embedding(neg_item_list)
        l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p)

        return l1 + l2 * self.reg_weight
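    # Added commentary: the ranking term above is the standard BPR objective,
    #
    #     L_bpr = -sum_{(u,i,j)} log sigmoid(y_ui - y_uj),
    #
    # plus an L2-style embedding penalty (EmbLoss) on the raw ego embeddings,
    # scaled by reg_weight.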
    def calc_ssl_loss(
        self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2
    ):
        r"""Calculate the loss of self-supervised tasks.

        Args:
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
            user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding.
            user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding.
            item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding.
            item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding.

        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """
        u_emd1 = F.normalize(user_sub1[user_list], dim=1)
        u_emd2 = F.normalize(user_sub2[user_list], dim=1)
        all_user2 = F.normalize(user_sub2, dim=1)
        v1 = torch.sum(u_emd1 * u_emd2, dim=1)
        v2 = u_emd1.matmul(all_user2.T)
        v1 = torch.exp(v1 / self.ssl_tau)
        v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1)
        ssl_user = -torch.sum(torch.log(v1 / v2))

        i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1)
        i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1)
        all_item2 = F.normalize(item_sub2, dim=1)
        v3 = torch.sum(i_emd1 * i_emd2, dim=1)
        v4 = i_emd1.matmul(all_item2.T)
        v3 = torch.exp(v3 / self.ssl_tau)
        v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1)
        ssl_item = -torch.sum(torch.log(v3 / v4))

        return (ssl_item + ssl_user) * self.ssl_weight
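    # Added commentary: each term above is an InfoNCE loss. For a user u with
    # views z'_u and z''_u (rows of user_sub1/user_sub2; the F.normalize
    # calls make the dot products cosine similarities s(.,.)):
    #
    #     L_ssl_user = -sum_u log( exp(s(z'_u, z''_u) / ssl_tau)
    #                            / sum_v exp(s(z'_u, z''_v) / ssl_tau) )
    #
    # and analogously for the positive items; the total is scaled by
    # ssl_weight.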
    def predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(self.train_graph)

        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)
    def full_sort_predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(self.train_graph)

        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)
    def train(self, mode: bool = True):
        r"""Override the train method of the base class. The subgraphs are
        reconstructed each time it is called.
        """
        T = super().train(mode=mode)
        if mode:
            self.graph_construction()
        return T
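
# A minimal usage sketch (not part of the original file). The config keys
# mirror what __init__ reads above; the hyperparameter values here are
# illustrative assumptions, not RecBole's shipped defaults.
if __name__ == "__main__":
    from recbole.quick_start import run_recbole

    run_recbole(
        model="SGL",
        dataset="ml-100k",
        config_dict={
            "embedding_size": 64,
            "n_layers": 3,
            "type": "ED",  # one of "ND", "ED", "RW"
            "drop_ratio": 0.1,
            "ssl_tau": 0.5,
            "reg_weight": 1e-5,
            "ssl_weight": 0.05,
        },
    )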