# -*- coding: utf-8 -*-
# @Time : 2020/9/30 14:07
# @Author : Yujie Lu
# @Email : yujielu1998@gmail.com
r"""
SRGNN
################################################
Reference:
Shu Wu et al. "Session-based Recommendation with Graph Neural Networks." in AAAI 2019.
Reference code:
https://github.com/CRIPAC-DIG/SR-GNN
"""
import math
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.loss import BPRLoss
class GNN(nn.Module):
    r"""Graph neural networks are well-suited for session-based recommendation,
    because they can automatically extract features of session graphs while taking
    rich node connections into account.

    Each propagation step is a GRU-style update. Its input is the concatenation of
    messages aggregated along incoming and outgoing edges.

    Args:
        embedding_size (int): Dimension of the node (item) hidden states.
        step (int): Number of propagation steps applied by :meth:`forward`.

    Note:
        ``w_ih``, ``w_hh``, ``b_ih``, ``b_hh``, ``b_iah`` and ``b_ioh`` are allocated
        with ``torch.Tensor(...)``, so their storage is uninitialized. ``SRGNN`` re-draws
        every parameter in its ``_reset_parameters``. Initialize them yourself when using
        this module on its own.
    """

    def __init__(self, embedding_size, step=1):
        super(GNN, self).__init__()
        self.step = step
        self.embedding_size = embedding_size
        # GRU input = [incoming message ; outgoing message]
        self.input_size = embedding_size * 2
        # reset, update and candidate gates stacked along the output dimension
        self.gate_size = embedding_size * 3
        self.w_ih = Parameter(torch.Tensor(self.gate_size, self.input_size))
        self.w_hh = Parameter(torch.Tensor(self.gate_size, self.embedding_size))
        self.b_ih = Parameter(torch.Tensor(self.gate_size))
        self.b_hh = Parameter(torch.Tensor(self.gate_size))
        # biases added to the aggregated incoming / outgoing messages
        self.b_iah = Parameter(torch.Tensor(self.embedding_size))
        self.b_ioh = Parameter(torch.Tensor(self.embedding_size))

        self.linear_edge_in = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.linear_edge_out = nn.Linear(self.embedding_size, self.embedding_size, bias=True)

    def GNNCell(self, A, hidden):
        r"""Obtain latent vectors of nodes via one step of graph propagation.

        Args:
            A (torch.FloatTensor): The connection matrix, shape of
                [batch_size, max_session_len, 2 * max_session_len]. The first half of the
                last axis is applied to ``linear_edge_in`` and the second half to
                ``linear_edge_out``.
            hidden (torch.FloatTensor): The item node embedding matrix, shape of
                [batch_size, max_session_len, embedding_size].

        Returns:
            torch.FloatTensor: Latent vectors of nodes, shape of
            [batch_size, max_session_len, embedding_size].
        """
        n_nodes = A.size(1)
        input_in = torch.matmul(A[:, :, :n_nodes], self.linear_edge_in(hidden)) + self.b_iah
        input_out = torch.matmul(A[:, :, n_nodes:2 * n_nodes], self.linear_edge_out(hidden)) + self.b_ioh
        # [batch_size, max_session_len, embedding_size * 2]
        inputs = torch.cat([input_in, input_out], 2)

        # gi and gh both have shape [batch_size, max_session_len, embedding_size * 3]
        gi = F.linear(inputs, self.w_ih, self.b_ih)
        gh = F.linear(hidden, self.w_hh, self.b_hh)
        # each chunk: [batch_size, max_session_len, embedding_size]
        i_r, i_z, i_n = gi.chunk(3, 2)
        h_r, h_z, h_n = gh.chunk(3, 2)

        reset_gate = torch.sigmoid(i_r + h_r)
        update_gate = torch.sigmoid(i_z + h_z)
        new_gate = torch.tanh(i_n + reset_gate * h_n)
        # interpolate between the previous state and the candidate state
        hy = (1 - update_gate) * hidden + update_gate * new_gate
        return hy

    def forward(self, A, hidden):
        r"""Apply :meth:`GNNCell` ``self.step`` times, reusing the same connection matrix.

        Args:
            A (torch.FloatTensor): shape of [batch_size, max_session_len, 2 * max_session_len].
            hidden (torch.FloatTensor): shape of [batch_size, max_session_len, embedding_size].

        Returns:
            torch.FloatTensor: updated node states, same shape as ``hidden``.
            ``hidden`` itself is returned unchanged when ``step <= 0``.
        """
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden
class SRGNN(SequentialRecommender):
    r"""SRGNN models each session as a directed graph.

    In addition to the connection between an item and the item right after it,
    it also considers the connections with the other items in the session.

    For example, for the session sequence (item1, item2, item3, item2, item4),
    the connection matrix A consists of the following two parts.

    Outgoing edges:
        === ===== ===== ===== =====
         \    1     2     3     4
        === ===== ===== ===== =====
         1    0     1     0     0
         2    0     0    1/2   1/2
         3    0     1     0     0
         4    0     0     0     0
        === ===== ===== ===== =====

    Incoming edges:
        === ===== ===== ===== =====
         \    1     2     3     4
        === ===== ===== ===== =====
         1    0     0     0     0
         2   1/2    0    1/2    0
         3    0     1     0     0
         4    0     1     0     0
        === ===== ===== ===== =====
    """

    def __init__(self, config, dataset):
        super(SRGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.step = config['step']
        self.device = config['device']
        self.loss_type = config['loss_type']

        # define layers and loss
        # NOTE: creation order is kept as-is; it fixes the parameter order that
        # _reset_parameters iterates over (and therefore the random draws).
        # item embedding; id 0 is the padding item
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.gnn = GNN(self.embedding_size, self.step)
        # soft-attention readout used in forward()
        self.linear_one = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.linear_two = nn.Linear(self.embedding_size, self.embedding_size, bias=True)
        self.linear_three = nn.Linear(self.embedding_size, 1, bias=False)
        # fuses [global preference ; last-item state] into the session representation
        self.linear_transform = nn.Linear(self.embedding_size * 2, self.embedding_size, bias=True)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self._reset_parameters()

    def _reset_parameters(self):
        r"""Re-draw every parameter (including the GNN's and the embedding table)
        from U(-1/sqrt(embedding_size), 1/sqrt(embedding_size))."""
        stdv = 1.0 / math.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def _get_slice(self, item_seq):
        r"""Convert a batch of padded item sequences into session graphs.

        Args:
            item_seq (torch.LongTensor): item ids, shape of [batch_size, max_session_len].
                Id 0 is treated as padding.

        Returns:
            tuple: ``(alias_inputs, A, items, mask)``.

            - alias_inputs (torch.LongTensor): index of each sequence position inside
              ``items`` of the same row, shape of [batch_size, max_session_len].
            - A (torch.FloatTensor): degree-normalized connection matrix, shape of
              [batch_size, max_session_len, 2 * max_session_len]. The first half of the
              last axis holds incoming edges and the second half outgoing edges.
            - items (torch.LongTensor): sorted unique item ids of each session,
              right-padded with 0, shape of [batch_size, max_session_len].
            - mask (torch.BoolTensor): ``item_seq > 0``, shape of [batch_size, max_session_len].
        """
        # Mask matrix, shape of [batch_size, max_session_len]
        mask = item_seq.gt(0)
        batch_size, max_n_node = item_seq.size(0), item_seq.size(1)
        item_seq = item_seq.cpu().numpy()

        # Fill preallocated arrays and convert each to a tensor once: building a tensor
        # from a Python list of ndarrays is slow and triggers a PyTorch UserWarning.
        items = np.zeros((batch_size, max_n_node), dtype=np.int64)
        alias_inputs = np.zeros((batch_size, max_n_node), dtype=np.int64)
        A = np.zeros((batch_size, max_n_node, 2 * max_n_node), dtype=np.float32)
        for b, u_input in enumerate(item_seq):
            # node: sorted unique ids of the session; alias[k] is the index of u_input[k] in node
            node, alias = np.unique(u_input, return_inverse=True)
            items[b, :len(node)] = node
            alias_inputs[b] = alias

            # u_A[u, v] = 1 iff the session contains the transition u -> v (repeats count once);
            # the scan stops at the first padding position.
            u_A = np.zeros((max_n_node, max_n_node))
            for i in range(len(u_input) - 1):
                if u_input[i + 1] == 0:
                    break
                u_A[alias[i], alias[i + 1]] = 1

            # Normalize by in-/out-degree; nodes without edges keep a divisor of 1.
            in_degree = np.maximum(u_A.sum(axis=0), 1)
            out_degree = np.maximum(u_A.sum(axis=1), 1)
            # incoming: A_in[i, j] = u_A[j, i] / in_degree[i]
            A[b, :, :max_n_node] = u_A.T / in_degree[:, None]
            # outgoing: A_out[i, j] = u_A[i, j] / out_degree[i]
            A[b, :, max_n_node:] = u_A / out_degree[:, None]

        # The relative coordinates of the item node, shape of [batch_size, max_session_len]
        alias_inputs = torch.from_numpy(alias_inputs).to(self.device)
        # The connecting matrix, shape of [batch_size, max_session_len, 2 * max_session_len]
        A = torch.from_numpy(A).to(self.device)
        # The unique item nodes, shape of [batch_size, max_session_len]
        items = torch.from_numpy(items).to(self.device)
        return alias_inputs, A, items, mask

    def forward(self, item_seq, item_seq_len):
        r"""Encode each session into one vector of size ``embedding_size``.

        Args:
            item_seq (torch.LongTensor): shape of [batch_size, max_session_len], 0 = padding.
            item_seq_len (torch.LongTensor): number of real items per row, shape of [batch_size].

        Returns:
            torch.FloatTensor: session representations, shape of [batch_size, embedding_size].
        """
        alias_inputs, A, items, mask = self._get_slice(item_seq)
        hidden = self.item_embedding(items)
        hidden = self.gnn(A, hidden)
        # map node states back to sequence order: [batch_size, max_session_len, embedding_size]
        alias_inputs = alias_inputs.view(-1, alias_inputs.size(1), 1).expand(-1, -1, self.embedding_size)
        seq_hidden = torch.gather(hidden, dim=1, index=alias_inputs)
        # fetch the hidden state of the last real position (gather_indexes comes from the base class)
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)
        # soft attention of every position w.r.t. the last item
        q1 = self.linear_one(ht).view(ht.size(0), 1, ht.size(1))
        q2 = self.linear_two(seq_hidden)
        alpha = self.linear_three(torch.sigmoid(q1 + q2))
        # attention-weighted sum that ignores padding positions
        a = torch.sum(alpha * seq_hidden * mask.view(mask.size(0), -1, 1).float(), 1)
        seq_output = self.linear_transform(torch.cat([a, ht], dim=1))
        return seq_output

    def calculate_loss(self, interaction):
        r"""Training loss: BPR over (positive, negative) pairs, or cross-entropy over all items."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        r"""Score one candidate item per session (dot product), shape of [batch_size]."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        r"""Score every item for every session, shape of [batch_size, n_items]."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores