# -*- coding: utf-8 -*-
# @Time : 2020/9/15
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
r"""
KGAT
##################################################
Reference:
Xiang Wang et al. "KGAT: Knowledge Graph Attention Network for Recommendation." in SIGKDD 2019.
Reference code:
https://github.com/xiangwang1223/knowledge_graph_attention_network
"""
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType
class Aggregator(nn.Module):
    """GNN aggregator layer.

    Propagates neighbour information through a sparse matrix and fuses the
    result with each node's own embedding using one of three schemes:

    * ``'gcn'``: ``act(W(ego + side))``
    * ``'graphsage'``: ``act(W([ego ; side]))``
    * ``'bi'``: ``act(W1(ego + side)) + act(W2(ego * side))``

    where ``side = norm_matrix @ ego`` and ``act`` is a LeakyReLU. Dropout is
    applied to the fused output.

    Args:
        input_dim (int): size of the incoming node embeddings.
        output_dim (int): size of the produced node embeddings.
        dropout (float): dropout probability applied to the layer output.
        aggregator_type (str): one of ``'gcn'``, ``'graphsage'`` or ``'bi'``.

    Raises:
        NotImplementedError: if ``aggregator_type`` is not one of the
            supported values.
    """

    def __init__(self, input_dim, output_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type

        self.message_dropout = nn.Dropout(dropout)

        if self.aggregator_type == 'gcn':
            self.W = nn.Linear(self.input_dim, self.output_dim)
        elif self.aggregator_type == 'graphsage':
            # graphsage concatenates ego and side embeddings, hence 2 * input_dim
            self.W = nn.Linear(self.input_dim * 2, self.output_dim)
        elif self.aggregator_type == 'bi':
            self.W1 = nn.Linear(self.input_dim, self.output_dim)
            self.W2 = nn.Linear(self.input_dim, self.output_dim)
        else:
            raise NotImplementedError(
                "Unsupported aggregator_type: {!r}".format(self.aggregator_type)
            )

        self.activation = nn.LeakyReLU()

    def forward(self, norm_matrix, ego_embeddings):
        """Run one round of message passing.

        Args:
            norm_matrix (torch.Tensor): sparse matrix that is multiplied with
                ``ego_embeddings`` via ``torch.sparse.mm``.
            ego_embeddings (torch.Tensor): dense 2-D node embedding matrix.

        Returns:
            torch.Tensor: the aggregated embeddings after activation and
            message dropout.

        Raises:
            NotImplementedError: if ``aggregator_type`` was changed to an
                unsupported value after construction.
        """
        side_embeddings = torch.sparse.mm(norm_matrix, ego_embeddings)

        if self.aggregator_type == 'gcn':
            ego_embeddings = self.activation(self.W(ego_embeddings + side_embeddings))
        elif self.aggregator_type == 'graphsage':
            ego_embeddings = self.activation(
                self.W(torch.cat([ego_embeddings, side_embeddings], dim=1))
            )
        elif self.aggregator_type == 'bi':
            add_embeddings = ego_embeddings + side_embeddings
            sum_embeddings = self.activation(self.W1(add_embeddings))
            # element-wise interaction between a node and its aggregated neighbourhood
            bi_embeddings = torch.mul(ego_embeddings, side_embeddings)
            bi_embeddings = self.activation(self.W2(bi_embeddings))
            ego_embeddings = bi_embeddings + sum_embeddings
        else:
            raise NotImplementedError(
                "Unsupported aggregator_type: {!r}".format(self.aggregator_type)
            )

        ego_embeddings = self.message_dropout(ego_embeddings)
        return ego_embeddings
class KGAT(KnowledgeRecommender):
    r"""KGAT is a knowledge-based recommendation model. It combines the knowledge graph and the user-item interaction
    graph into a new graph called the collaborative knowledge graph (CKG). This model learns the representations of
    users and items by exploiting the structure of the CKG. It adopts a GNN-based architecture and defines the
    attention on the CKG.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(KGAT, self).__init__(config, dataset)

        # load dataset info
        self.ckg = dataset.ckg_graph(form='dgl', value_field='relation_id')
        # Build the COO view of the CKG once and reuse it for heads, tails and relation ids.
        ckg_coo = dataset.ckg_graph(form='coo', value_field='relation_id')
        self.all_hs = torch.LongTensor(ckg_coo.row).to(self.device)
        self.all_ts = torch.LongTensor(ckg_coo.col).to(self.device)
        self.all_rs = torch.LongTensor(ckg_coo.data).to(self.device)
        # users and entities share one node index space: users first, then entities
        self.matrix_size = torch.Size([self.n_users + self.n_entities, self.n_users + self.n_entities])

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.kg_embedding_size = config['kg_embedding_size']
        self.layers = [self.embedding_size] + config['layers']
        self.aggregator_type = config['aggregator_type']
        self.mess_dropout = config['mess_dropout']
        self.reg_weight = config['reg_weight']

        # generate intermediate data
        self.A_in = self.init_graph()  # init the attention matrix by the structure of ckg

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.relation_embedding = nn.Embedding(self.n_relations, self.kg_embedding_size)
        # one flattened (embedding_size x kg_embedding_size) projection matrix per relation
        self.trans_w = nn.Embedding(self.n_relations, self.embedding_size * self.kg_embedding_size)
        self.aggregator_layers = nn.ModuleList()
        for input_dim, output_dim in zip(self.layers[:-1], self.layers[1:]):
            self.aggregator_layers.append(Aggregator(input_dim, output_dim, self.mess_dropout, self.aggregator_type))
        self.tanh = nn.Tanh()
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()
        # cached full-graph embeddings used by full_sort_predict; reset whenever a loss is computed
        self.restore_user_e = None
        self.restore_entity_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_entity_e']

    def init_graph(self):
        r"""Get the initial attention matrix through the collaborative knowledge graph

        For every relation id in ``[1, n_relations)`` the adjacency matrix of the edges carrying that relation is
        row-normalized, and the normalized matrices are summed.

        Returns:
            torch.Tensor: Sparse COO tensor of the attention matrix, moved to ``self.device``
        """
        import dgl
        adj_list = []
        # NOTE(review): relation id 0 is skipped — presumably a padding id; confirm against the dataset.
        for rel_type in range(1, self.n_relations, 1):
            edge_idxs = self.ckg.filter_edges(lambda edge: edge.data['relation_id'] == rel_type)
            sub_graph = dgl.edge_subgraph(self.ckg, edge_idxs, preserve_nodes=True). \
                adjacency_matrix(transpose=False, scipy_fmt='coo').astype('float')
            rowsum = np.array(sub_graph.sum(1))
            # Rows without edges give 1/0 = inf; silence the warning, the infs are zeroed right below.
            with np.errstate(divide='ignore'):
                d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)
            norm_adj = d_mat_inv.dot(sub_graph).tocoo()
            adj_list.append(norm_adj)

        final_adj_matrix = sum(adj_list).tocoo()
        # Stack in numpy first: building a tensor from a list of ndarrays is a slow element-wise conversion.
        indices = torch.from_numpy(np.vstack((final_adj_matrix.row, final_adj_matrix.col))).long()
        values = torch.FloatTensor(final_adj_matrix.data)
        adj_matrix_tensor = torch.sparse_coo_tensor(indices, values, self.matrix_size)
        return adj_matrix_tensor.to(self.device)

    def _get_ego_embeddings(self):
        r"""Concatenate the user and entity embedding tables row-wise (users first)."""
        user_embeddings = self.user_embedding.weight
        entity_embeddings = self.entity_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, entity_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        r"""Propagate embeddings through all aggregator layers.

        The initial embeddings and the L2-normalized output of every layer are concatenated along the feature
        dimension.

        Returns:
            tuple(torch.Tensor, torch.Tensor): user embeddings (``n_users`` rows) and entity embeddings
            (``n_entities`` rows).
        """
        ego_embeddings = self._get_ego_embeddings()
        embeddings_list = [ego_embeddings]
        for aggregator in self.aggregator_layers:
            ego_embeddings = aggregator(self.A_in, ego_embeddings)
            norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1)
            embeddings_list.append(norm_embeddings)
        kgat_all_embeddings = torch.cat(embeddings_list, dim=1)
        user_all_embeddings, entity_all_embeddings = torch.split(kgat_all_embeddings, [self.n_users, self.n_entities])
        return user_all_embeddings, entity_all_embeddings

    def _get_kg_embedding(self, h, r, pos_t, neg_t):
        r"""Look up and project the embeddings of a batch of KG triples.

        Head and tail entity embeddings are projected with the per-relation matrix stored in ``trans_w``;
        the relation embedding is returned as looked up.

        Returns:
            tuple: ``(h_e, r_e, pos_t_e, neg_t_e)``
        """
        h_e = self.entity_embedding(h).unsqueeze(1)
        pos_t_e = self.entity_embedding(pos_t).unsqueeze(1)
        neg_t_e = self.entity_embedding(neg_t).unsqueeze(1)
        r_e = self.relation_embedding(r)
        r_trans_w = self.trans_w(r).view(r.size(0), self.embedding_size, self.kg_embedding_size)

        h_e = torch.bmm(h_e, r_trans_w).squeeze(1)
        pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze(1)
        neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze(1)

        return h_e, r_e, pos_t_e, neg_t_e

    def calculate_loss(self, interaction):
        r"""Calculate the recommendation training loss for a batch.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: BPR loss plus ``reg_weight`` times the embedding regularization loss.
        """
        # training invalidates the embeddings cached for full-sort evaluation
        if self.restore_user_e is not None or self.restore_entity_e is not None:
            self.restore_user_e, self.restore_entity_e = None, None

        # get loss for training rs
        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, entity_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = entity_all_embeddings[pos_item]
        neg_embeddings = entity_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)
        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def calculate_kg_loss(self, interaction):
        r"""Calculate the training loss for a batch data of KG.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        # training invalidates the embeddings cached for full-sort evaluation
        if self.restore_user_e is not None or self.restore_entity_e is not None:
            self.restore_user_e, self.restore_entity_e = None, None

        # get loss for training kg
        h = interaction[self.HEAD_ENTITY_ID]
        r = interaction[self.RELATION_ID]
        pos_t = interaction[self.TAIL_ENTITY_ID]
        neg_t = interaction[self.NEG_TAIL_ENTITY_ID]

        h_e, r_e, pos_t_e, neg_t_e = self._get_kg_embedding(h, r, pos_t, neg_t)
        # squared distance between (head + relation) and the tail; smaller for the positive tail is better
        pos_tail_score = ((h_e + r_e - pos_t_e) ** 2).sum(dim=1)
        neg_tail_score = ((h_e + r_e - neg_t_e) ** 2).sum(dim=1)
        kg_loss = F.softplus(pos_tail_score - neg_tail_score).mean()
        kg_reg_loss = self.reg_loss(h_e, r_e, pos_t_e, neg_t_e)
        loss = kg_loss + self.reg_weight * kg_reg_loss

        return loss

    def generate_transE_score(self, hs, ts, r):
        r"""Calculating scores for triples in KG.

        Args:
            hs (torch.Tensor): head entities
            ts (torch.Tensor): tail entities
            r (int): the relation id between hs and ts

        Returns:
            torch.Tensor: the scores of (hs, r, ts)
        """
        all_embeddings = self._get_ego_embeddings()
        h_e = all_embeddings[hs]
        t_e = all_embeddings[ts]
        r_e = self.relation_embedding.weight[r]
        r_trans_w = self.trans_w.weight[r].view(self.embedding_size, self.kg_embedding_size)

        h_e = torch.matmul(h_e, r_trans_w)
        t_e = torch.matmul(t_e, r_trans_w)

        kg_score = torch.mul(t_e, self.tanh(h_e + r_e)).sum(dim=1)

        return kg_score

    def update_attentive_A(self):
        r"""Update the attention matrix using the updated embedding matrix

        Scores every CKG triple with :meth:`generate_transE_score`, applies a row-wise sparse softmax and stores
        the result in ``self.A_in``.
        """
        kg_score_list, row_list, col_list = [], [], []
        # To reduce the GPU memory consumption, we calculate the scores of KG triples according to the type of relation
        for rel_idx in range(1, self.n_relations, 1):
            triple_index = torch.where(self.all_rs == rel_idx)
            rel_hs = self.all_hs[triple_index]
            rel_ts = self.all_ts[triple_index]
            kg_score = self.generate_transE_score(rel_hs, rel_ts, rel_idx)
            row_list.append(rel_hs)
            col_list.append(rel_ts)
            kg_score_list.append(kg_score)
        kg_score = torch.cat(kg_score_list, dim=0)
        row = torch.cat(row_list, dim=0)
        col = torch.cat(col_list, dim=0)
        indices = torch.stack([row, col], dim=0)
        # Current PyTorch version does not support softmax on SparseCUDA, temporarily move to CPU to calculate softmax
        A_in = torch.sparse_coo_tensor(indices, kg_score, self.matrix_size).cpu()
        A_in = torch.sparse.softmax(A_in, dim=1).to(self.device)
        self.A_in = A_in

    def predict(self, interaction):
        r"""Score the (user, item) pairs of a batch.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: one inner-product score per (user, item) pair.
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, entity_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = entity_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        r"""Score every item for each user of a batch.

        The propagated embeddings are computed once and cached in ``restore_user_e`` / ``restore_entity_e``
        until the next loss computation clears them.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: flattened score matrix of the batch users against the first ``n_items`` entities.
        """
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_entity_e is None:
            self.restore_user_e, self.restore_entity_e = self.forward()
        u_embeddings = self.restore_user_e[user]
        # only the first n_items rows of the entity table are scored
        i_embeddings = self.restore_entity_e[:self.n_items]

        scores = torch.matmul(u_embeddings, i_embeddings.transpose(0, 1))

        return scores.view(-1)