# -*- coding: utf-8 -*-
# @Time : 2020/10/6
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
KGCN
################################################
Reference:
Hongwei Wang et al. "Knowledge Graph Convolutional Networks for Recommender Systems." in WWW 2019.
Reference code:
https://github.com/hwwang55/KGCN
"""
import torch
import torch.nn as nn
import numpy as np
from recbole.utils import InputType
from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.model.init import xavier_normal_initialization
class KGCN(KnowledgeRecommender):
    r"""KGCN is a knowledge-based recommendation model that captures inter-item relatedness effectively by mining their
    associated attributes on the KG. To automatically discover both high-order structure information and semantic
    information of the KG, we treat KG as an undirected graph and sample from the neighbors for each entity in the KG
    as their receptive field, then combine neighborhood information with bias when calculating the representation of a
    given entity.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(KGCN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        # number of iterations (receptive-field depth) when computing entity representation
        self.n_iter = config['n_iter']
        # which aggregator to use; aggregate() accepts 'sum', 'neighbor' or 'concat'
        self.aggregator_class = config['aggregator']
        self.reg_weight = config['reg_weight']  # weight of l2 regularization
        self.neighbor_sample_size = config['neighbor_sample_size']

        # define embedding
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.relation_embedding = nn.Embedding(self.n_relations + 1, self.embedding_size)

        # sample neighbors once here; the sampled adjacency is reused by every forward pass
        kg_graph = dataset.kg_graph(form='coo', value_field='relation_id')
        adj_entity, adj_relation = self.construct_adj(kg_graph)
        self.adj_entity, self.adj_relation = adj_entity.to(self.device), adj_relation.to(self.device)

        # define function
        self.softmax = nn.Softmax(dim=-1)
        # the 'concat' aggregator feeds [self || neighbors] into the layer, so its input is twice as wide
        linear_in_size = self.embedding_size * 2 if self.aggregator_class == 'concat' else self.embedding_size
        self.linear_layers = torch.nn.ModuleList(
            [nn.Linear(linear_in_size, self.embedding_size) for _ in range(self.n_iter)])
        self.ReLU = nn.ReLU()
        self.Tanh = nn.Tanh()

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.l2_loss = EmbLoss()

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def construct_adj(self, kg_graph):
        r"""Get neighbors and corresponding relations for each entity in the KG.

        Every triple contributes an edge in both directions (the KG is treated as undirected). For each entity,
        ``neighbor_sample_size`` neighbors are drawn uniformly at random, without replacement when the entity has
        enough neighbors and with replacement otherwise. An entity with no neighbors gets itself as every neighbor,
        with relation id 0.

        Args:
            kg_graph(scipy.sparse.coo_matrix): the KG; ``row`` holds heads, ``col`` holds tails and ``data`` holds
                relation ids. Its first dimension is used as the number of entities.

        Returns:
            tuple:
                - adj_entity(torch.LongTensor): each line stores the sampled neighbor entities for a given entity,
                  shape: [n_entities, neighbor_sample_size]
                - adj_relation(torch.LongTensor): each line stores the corresponding sampled neighbor relations,
                  shape: [n_entities, neighbor_sample_size]
        """
        # treat the KG as an undirected graph: (head, relation, tail) makes each end a neighbor of the other
        kg_dict = dict()
        for head, relation, tail in zip(kg_graph.row, kg_graph.data, kg_graph.col):
            kg_dict.setdefault(head, []).append((tail, relation))
            kg_dict.setdefault(tail, []).append((head, relation))

        # each line of adj_entity stores the sampled neighbor entities for a given entity
        # each line of adj_relation stores the corresponding sampled neighbor relations
        entity_num = kg_graph.shape[0]
        adj_entity = np.zeros([entity_num, self.neighbor_sample_size], dtype=np.int64)
        adj_relation = np.zeros([entity_num, self.neighbor_sample_size], dtype=np.int64)
        for entity in range(entity_num):
            neighbors = kg_dict.get(entity)
            if not neighbors:
                # isolated entity: self-loop in every slot; its adj_relation row is already all zeros
                adj_entity[entity] = entity
                continue
            n_neighbors = len(neighbors)
            # sample without replacement when possible, otherwise fill the row by sampling with replacement
            sampled_indices = np.random.choice(n_neighbors, size=self.neighbor_sample_size,
                                               replace=n_neighbors < self.neighbor_sample_size)
            adj_entity[entity] = [neighbors[i][0] for i in sampled_indices]
            adj_relation[entity] = [neighbors[i][1] for i in sampled_indices]

        return torch.from_numpy(adj_entity), torch.from_numpy(adj_relation)

    def get_neighbors(self, items):
        r"""Get neighbors and corresponding relations for each entity in items from adj_entity and adj_relation.

        Args:
            items(torch.LongTensor): The input tensor that contains item's id, shape: [batch_size, ]

        Returns:
            tuple:
                - entities(list): Entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items.
                  dimensions of entities: {[batch_size, 1],
                  [batch_size, n_neighbor],
                  [batch_size, n_neighbor^2],
                  ...,
                  [batch_size, n_neighbor^n_iter]}
                - relations(list): Relations is a list of n_iter tensors; relations[i] holds the relations linking
                  entities[i] to entities[i + 1] and has the same shape as entities[i + 1].
        """
        # derive the batch size from the input instead of relying on state set by forward()
        batch_size = items.shape[0]
        entities = [torch.unsqueeze(items, dim=1)]
        relations = []
        for i in range(self.n_iter):
            index = torch.flatten(entities[i])
            neighbor_entities = torch.index_select(self.adj_entity, 0, index).reshape(batch_size, -1)
            neighbor_relations = torch.index_select(self.adj_relation, 0, index).reshape(batch_size, -1)
            entities.append(neighbor_entities)
            relations.append(neighbor_relations)
        return entities, relations

    def mix_neighbor_vectors(self, neighbor_vectors, neighbor_relations, user_embeddings):
        r"""Mix neighbor vectors on user-specific graph.

        Each neighbor is weighted by a softmax (over the neighbor axis) of the user-relation score. The score is the
        mean of the elementwise product of the user and relation embeddings. The weighted neighbor vectors are then
        averaged over the neighbor axis, as in the reference implementation.

        Args:
            neighbor_vectors(torch.FloatTensor): The embeddings of neighbor entities(items),
                shape: [batch_size, -1, neighbor_sample_size, embedding_size]
            neighbor_relations(torch.FloatTensor): The embeddings of neighbor relations,
                shape: [batch_size, -1, neighbor_sample_size, embedding_size]
            user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size, embedding_size]

        Returns:
            neighbors_aggregated(torch.FloatTensor): The neighbors aggregated embeddings,
            shape: [batch_size, -1, embedding_size]
        """
        batch_size = user_embeddings.shape[0]
        # [batch_size, 1, 1, dim] so the user broadcasts over hops and neighbors
        user_embeddings = torch.reshape(user_embeddings, (batch_size, 1, 1, self.embedding_size))
        # [batch_size, -1, n_neighbor]
        user_relation_scores = torch.mean(user_embeddings * neighbor_relations, dim=-1)
        # [batch_size, -1, n_neighbor, 1]
        user_relation_scores_normalized = torch.unsqueeze(self.softmax(user_relation_scores), dim=-1)
        # [batch_size, -1, dim]
        return torch.mean(user_relation_scores_normalized * neighbor_vectors, dim=2)

    def aggregate(self, user_embeddings, entities, relations):
        r"""For each item, aggregate the entity representation and its neighborhood representation into a single vector.

        Args:
            user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size, embedding_size]
            entities(list): entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items.
                dimensions of entities: {[batch_size, 1],
                [batch_size, n_neighbor],
                [batch_size, n_neighbor^2],
                ...,
                [batch_size, n_neighbor^n_iter]}
            relations(list): relations is a list of n_iter tensors; relations[i] holds the relations linking
                entities[i] to entities[i + 1] and has the same shape as entities[i + 1].

        Returns:
            item_embeddings(torch.FloatTensor): The embeddings of items, shape: [batch_size, embedding_size]

        Raises:
            ValueError: if ``self.aggregator_class`` is not 'sum', 'neighbor' or 'concat'.
        """
        batch_size = user_embeddings.shape[0]
        entity_vectors = [self.entity_embedding(i) for i in entities]
        relation_vectors = [self.relation_embedding(i) for i in relations]
        neighbor_shape = (batch_size, -1, self.neighbor_sample_size, self.embedding_size)

        for i in range(self.n_iter):
            # each iteration folds hop (h + 1) into hop h, so one fewer hop remains afterwards
            entity_vectors_next_iter = []
            for hop in range(self.n_iter - i):
                self_vectors = entity_vectors[hop]
                neighbor_vectors = torch.reshape(entity_vectors[hop + 1], neighbor_shape)
                neighbor_relations = torch.reshape(relation_vectors[hop], neighbor_shape)
                neighbors_agg = self.mix_neighbor_vectors(neighbor_vectors, neighbor_relations,
                                                          user_embeddings)  # [batch_size, -1, dim]

                if self.aggregator_class == 'sum':
                    output = torch.reshape(self_vectors + neighbors_agg, (-1, self.embedding_size))  # [-1, dim]
                elif self.aggregator_class == 'neighbor':
                    output = torch.reshape(neighbors_agg, (-1, self.embedding_size))  # [-1, dim]
                elif self.aggregator_class == 'concat':
                    # [batch_size, -1, dim * 2]
                    output = torch.cat([self_vectors, neighbors_agg], dim=-1)
                    output = torch.reshape(output, (-1, self.embedding_size * 2))  # [-1, dim * 2]
                else:
                    raise ValueError(f"Unknown aggregator: {self.aggregator_class}")

                output = self.linear_layers[i](output)
                # [batch_size, -1, dim]
                output = torch.reshape(output, [batch_size, -1, self.embedding_size])
                # the last iteration uses tanh, earlier iterations use ReLU
                vector = self.Tanh(output) if i == self.n_iter - 1 else self.ReLU(output)
                entity_vectors_next_iter.append(vector)
            entity_vectors = entity_vectors_next_iter

        return torch.reshape(entity_vectors[0], (batch_size, self.embedding_size))

    def forward(self, user, item):
        r"""Return (user embeddings, user-specific item embeddings), each of shape [batch_size, embedding_size]."""
        # kept for backward compatibility with code that reads self.batch_size after a forward pass
        self.batch_size = item.shape[0]
        # [batch_size, dim]
        user_e = self.user_embedding(user)
        # entities is a list of i-iter (i = 0, 1, ..., n_iter) neighbors for the batch of items. dimensions of entities:
        # {[batch_size, 1], [batch_size, n_neighbor], [batch_size, n_neighbor^2], ..., [batch_size, n_neighbor^n_iter]}
        entities, relations = self.get_neighbors(item)
        # [batch_size, dim]
        item_e = self.aggregate(user_e, entities, relations)
        return user_e, item_e

    def calculate_loss(self, interaction):
        r"""BCE loss (positives labelled 1, negatives 0) on dot-product scores, plus weighted L2 on embeddings."""
        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_e, pos_item_e = self.forward(user, pos_item)
        user_e, neg_item_e = self.forward(user, neg_item)

        pos_item_score = torch.mul(user_e, pos_item_e).sum(dim=1)
        neg_item_score = torch.mul(user_e, neg_item_e).sum(dim=1)

        predict = torch.cat((pos_item_score, neg_item_score))
        target = torch.zeros(len(user) * 2, dtype=torch.float32, device=self.device)
        target[:len(user)] = 1
        rec_loss = self.bce_loss(predict, target)

        l2_loss = self.l2_loss(user_e, pos_item_e, neg_item_e)
        return rec_loss + self.reg_weight * l2_loss

    def predict(self, interaction):
        r"""Score each (user, item) pair in the interaction by the dot product of their embeddings."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        user_e, item_e = self.forward(user, item)
        return torch.mul(user_e, item_e).sum(dim=1)

    def full_sort_predict(self, interaction):
        r"""Score every item for every user in the interaction.

        Returns:
            torch.FloatTensor: flat scores of shape [n_batch_users * n_items], laid out user-major
            (all items of the first user, then all items of the second user, ...).
        """
        user_index = interaction[self.USER_ID]
        item_index = torch.arange(self.n_items, device=self.device)
        # build the full user x item cross product, user-major
        user = user_index.repeat_interleave(item_index.shape[0])
        item = item_index.repeat(user_index.shape[0])

        user_e, item_e = self.forward(user, item)
        score = torch.mul(user_e, item_e).sum(dim=1)
        return score.view(-1)