# -*- coding: utf-8 -*-
# @Time : 2020/10/3
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
KGNNLS
################################################
Reference:
Hongwei Wang et al. "Knowledge-aware Graph Neural Networks with Label Smoothness Regularization
for Recommender Systems." in KDD 2019.
Reference code:
https://github.com/hwwang55/KGNN-LS
"""
import random
import numpy as np
import torch
import torch.nn as nn
from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
class KGNNLS(KnowledgeRecommender):
r"""KGNN-LS is a knowledge-based recommendation model.
KGNN-LS transforms the knowledge graph into a user-specific weighted graph and then applies a graph neural network
to compute personalized item embeddings. To provide a better inductive bias, KGNN-LS relies on the label smoothness
assumption, which posits that adjacent items in the knowledge graph are likely to have similar user relevance
labels/scores. Label smoothness provides regularization over the edge weights and is equivalent to a label
propagation scheme on a graph.
"""
input_type = InputType.PAIRWISE
def __init__(self, config, dataset):
super(KGNNLS, self).__init__(config, dataset)
# load parameters info
self.embedding_size = config["embedding_size"]
self.neighbor_sample_size = config["neighbor_sample_size"]
self.aggregator_class = config["aggregator"] # which aggregator to use
# number of iterations when computing entity representation
self.n_iter = config["n_iter"]
self.reg_weight = config["reg_weight"] # weight of l2 regularization
# weight of label smoothness regularization
self.ls_weight = config["ls_weight"]
# define embedding
self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
self.relation_embedding = nn.Embedding(
self.n_relations + 1, self.embedding_size
)
# sample neighbors and construct interaction table
kg_graph = dataset.kg_graph(form="coo", value_field="relation_id")
adj_entity, adj_relation = self.construct_adj(kg_graph)
self.adj_entity, self.adj_relation = adj_entity.to(
self.device
), adj_relation.to(self.device)
inter_feat = dataset.inter_feat
pos_users = inter_feat[dataset.uid_field]
pos_items = inter_feat[dataset.iid_field]
pos_label = torch.ones(pos_items.shape)
pos_interaction_table, self.offset = self.get_interaction_table(
pos_users, pos_items, pos_label
)
self.interaction_table = self.sample_neg_interaction(
pos_interaction_table, self.offset
)
# define function
self.softmax = nn.Softmax(dim=-1)
self.linear_layers = torch.nn.ModuleList()
for i in range(self.n_iter):
self.linear_layers.append(
nn.Linear(
self.embedding_size
if not self.aggregator_class == "concat"
else self.embedding_size * 2,
self.embedding_size,
)
)
self.ReLU = nn.ReLU()
self.Tanh = nn.Tanh()
self.bce_loss = nn.BCEWithLogitsLoss()
self.l2_loss = EmbLoss()
# parameters initialization
self.apply(xavier_normal_initialization)
self.other_parameter_name = ["adj_entity", "adj_relation"]
def get_interaction_table(self, user_id, item_id, y):
r"""Get interaction_table that is used for fetching user-item interaction label in LS regularization.
Args:
user_id(torch.Tensor): the user id in user-item interactions, shape: [n_interactions, 1]
item_id(torch.Tensor): the item id in user-item interactions, shape: [n_interactions, 1]
y(torch.Tensor): the label in user-item interactions, shape: [n_interactions, 1]
Returns:
tuple:
- interaction_table(dict): key: user_id * offset + item_id; value: y_{user_id, item_id}
- offset(int): the multiplier (10 to the number of digits of n_entities) used to compose the key (index) in interaction_table
"""
offset = len(str(self.n_entities))
offset = 10**offset
keys = user_id * offset + item_id
keys = keys.long().cpu().numpy().tolist()  # keep int64 precision: keys can exceed the int32 range
values = y.float().cpu().numpy().tolist()
interaction_table = dict(zip(keys, values))
return interaction_table, offset
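# Worked example for get_interaction_table (illustrative numbers): with n_entities = 2500,
# offset becomes 10 ** len("2500") = 10000, so the pair (user_id=7, item_id=42) is stored
# under key 7 * 10000 + 42 = 70042 and can be decoded again with divmod(70042, 10000).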
def sample_neg_interaction(self, pos_interaction_table, offset):
r"""Sample neg_interaction to construct train data.
Args:
pos_interaction_table(dict): the interaction_table that contains only positive interactions.
offset(int): the multiplier used to compose the key (index) in interaction_table
Returns:
interaction_table(dict): key: user_id * offset + item_id; value: y_{user_id, item_id}
"""
pos_num = len(pos_interaction_table)
neg_num = 0
neg_interaction_table = {}
while neg_num < pos_num:
# random.randint is inclusive on both ends, so sample in [0, n - 1] to stay within valid ids
user_id = random.randint(0, self.n_users - 1)
item_id = random.randint(0, self.n_items - 1)
keys = user_id * offset + item_id
if keys not in pos_interaction_table:
neg_interaction_table[keys] = 0.0
neg_num += 1
interaction_table = {**pos_interaction_table, **neg_interaction_table}
return interaction_table
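# Illustrative note for sample_neg_interaction: with 1,000 positive pairs, 1,000 random
# (user, item) keys that do not collide with a positive pair are added with label 0.0,
# yielding a roughly balanced interaction table for the LS regularizer.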
def construct_adj(self, kg_graph):
r"""Get neighbors and corresponding relations for each entity in the KG.
Args:
kg_graph(scipy.sparse.coo_matrix): an undirected graph
Returns:
tuple:
- adj_entity (torch.LongTensor): each line stores the sampled neighbor entities for a given entity,
shape: [n_entities, neighbor_sample_size]
- adj_relation (torch.LongTensor): each line stores the corresponding sampled neighbor relations,
shape: [n_entities, neighbor_sample_size]
"""
# self.logger.info('constructing knowledge graph ...')
# treat the KG as an undirected graph
kg_dict = dict()
for triple in zip(kg_graph.row, kg_graph.data, kg_graph.col):
head = triple[0]
relation = triple[1]
tail = triple[2]
if head not in kg_dict:
kg_dict[head] = []
kg_dict[head].append((tail, relation))
if tail not in kg_dict:
kg_dict[tail] = []
kg_dict[tail].append((head, relation))
# self.logger.info('constructing adjacency matrix ...')
# each line of adj_entity stores the sampled neighbor entities for a given entity
# each line of adj_relation stores the corresponding sampled neighbor relations
entity_num = kg_graph.shape[0]
adj_entity = np.zeros([entity_num, self.neighbor_sample_size], dtype=np.int64)
adj_relation = np.zeros([entity_num, self.neighbor_sample_size], dtype=np.int64)
for entity in range(entity_num):
if entity not in kg_dict.keys():
adj_entity[entity] = np.array([entity] * self.neighbor_sample_size)
adj_relation[entity] = np.array([0] * self.neighbor_sample_size)
continue
neighbors = kg_dict[entity]
n_neighbors = len(neighbors)
if n_neighbors >= self.neighbor_sample_size:
sampled_indices = np.random.choice(
list(range(n_neighbors)),
size=self.neighbor_sample_size,
replace=False,
)
else:
sampled_indices = np.random.choice(
list(range(n_neighbors)),
size=self.neighbor_sample_size,
replace=True,
)
adj_entity[entity] = np.array([neighbors[i][0] for i in sampled_indices])
adj_relation[entity] = np.array([neighbors[i][1] for i in sampled_indices])
return torch.from_numpy(adj_entity), torch.from_numpy(adj_relation)
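# Illustrative note for construct_adj: with neighbor_sample_size = 4, an entity with ten KG
# neighbors keeps four of them (sampled without replacement), an entity with two neighbors
# repeats them (sampled with replacement) to fill the four slots, and an isolated entity
# points to itself through relation id 0.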
def get_neighbors(self, items):
r"""Get neighbors and corresponding relations for each entity in items from adj_entity and adj_relation.
Args:
items(torch.LongTensor): The input tensor that contains item's id, shape: [batch_size, ]
Returns:
tuple:
- entities(list): a list whose i-th element (i = 0, 1, ..., n_iter) stores the i-hop neighbor entities of the batch of items.
dimensions of entities: {[batch_size, 1],
[batch_size, n_neighbor],
[batch_size, n_neighbor^2],
...,
[batch_size, n_neighbor^n_iter]}
- relations(list): a list whose i-th element stores the relations on the edges that produced entities[i + 1];
relations[i] has the same shape as entities[i + 1].
"""
items = torch.unsqueeze(items, dim=1)
entities = [items]
relations = []
for i in range(self.n_iter):
index = torch.flatten(entities[i])
neighbor_entities = torch.index_select(self.adj_entity, 0, index).reshape(
self.batch_size, -1
)
neighbor_relations = torch.index_select(
self.adj_relation, 0, index
).reshape(self.batch_size, -1)
entities.append(neighbor_entities)
relations.append(neighbor_relations)
return entities, relations
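# Shape sketch for get_neighbors (illustrative): with n_iter = 2 and neighbor_sample_size = 4,
# entities holds tensors of shape [batch_size, 1], [batch_size, 4], [batch_size, 16] and
# relations holds tensors of shape [batch_size, 4], [batch_size, 16].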
def aggregate(self, user_embeddings, entities, relations):
r"""For each item, aggregate the entity representation and its neighborhood representation into a single vector.
Args:
user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size, embedding_size]
entities(list): a list whose i-th element (i = 0, 1, ..., n_iter) stores the i-hop neighbor entities of the batch of items.
dimensions of entities: {[batch_size, 1],
[batch_size, n_neighbor],
[batch_size, n_neighbor^2],
...,
[batch_size, n_neighbor^n_iter]}
relations(list): a list whose i-th element stores the relations on the edges that produced entities[i + 1];
relations[i] has the same shape as entities[i + 1].
Returns:
item_embeddings(torch.FloatTensor): The embeddings of items, shape: [batch_size, embedding_size]
"""
entity_vectors = [self.entity_embedding(i) for i in entities]
relation_vectors = [self.relation_embedding(i) for i in relations]
for i in range(self.n_iter):
entity_vectors_next_iter = []
for hop in range(self.n_iter - i):
shape = (
self.batch_size,
-1,
self.neighbor_sample_size,
self.embedding_size,
)
self_vectors = entity_vectors[hop]
neighbor_vectors = entity_vectors[hop + 1].reshape(shape)
neighbor_relations = relation_vectors[hop].reshape(shape)
# mix_neighbor_vectors
user_embeddings = user_embeddings.reshape(
self.batch_size, 1, 1, self.embedding_size
) # [batch_size, 1, 1, dim]
user_relation_scores = torch.mean(
user_embeddings * neighbor_relations, dim=-1
) # [batch_size, -1, n_neighbor]
user_relation_scores_normalized = torch.unsqueeze(
self.softmax(user_relation_scores), dim=-1
) # [batch_size, -1, n_neighbor, 1]
neighbors_agg = torch.mean(
user_relation_scores_normalized * neighbor_vectors, dim=2
) # [batch_size, -1, dim]
if self.aggregator_class == "sum":
output = (self_vectors + neighbors_agg).reshape(
-1, self.embedding_size
) # [-1, dim]
elif self.aggregator_class == "neighbor":
output = neighbors_agg.reshape(-1, self.embedding_size) # [-1, dim]
elif self.aggregator_class == "concat":
# [batch_size, -1, dim * 2]
output = torch.cat([self_vectors, neighbors_agg], dim=-1)
output = output.reshape(
-1, self.embedding_size * 2
) # [-1, dim * 2]
else:
raise Exception("Unknown aggregator: " + self.aggregator_class)
output = self.linear_layers[i](output)
# [batch_size, -1, dim]
output = output.reshape(self.batch_size, -1, self.embedding_size)
if i == self.n_iter - 1:
vector = self.Tanh(output)
else:
vector = self.ReLU(output)
entity_vectors_next_iter.append(vector)
entity_vectors = entity_vectors_next_iter
res = entity_vectors[0].reshape(self.batch_size, self.embedding_size)
return res
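# Formula sketch for aggregate (illustrative, "sum" aggregator): each hop computes
#   e_v <- sigma( W_i * ( e_v + mean_k( pi_k * e_k ) ) + b_i ),
# where pi_k = softmax_k( mean( u * r_k ) ) is the user-personalized score of relation r_k,
# and sigma is ReLU on intermediate hops and Tanh on the final hop.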
def label_smoothness_predict(self, user_embeddings, user, entities, relations):
r"""Predict the label of items by label smoothness.
Args:
user_embeddings(torch.FloatTensor): The embeddings of users, shape: [batch_size*2, embedding_size],
user(torch.LongTensor): the indices of users, shape: [batch_size*2]
entities(list): a list whose i-th element (i = 0, 1, ..., n_iter) stores the i-hop neighbor entities of the batch of items.
dimensions of entities: {[batch_size*2, 1],
[batch_size*2, n_neighbor],
[batch_size*2, n_neighbor^2],
...,
[batch_size*2, n_neighbor^n_iter]}
relations(list): a list whose i-th element stores the relations on the edges that produced entities[i + 1];
relations[i] has the same shape as entities[i + 1].
Returns:
predicted_labels(torch.FloatTensor): The predicted label of items, shape: [batch_size*2]
"""
# calculate initial labels; calculate updating masks for label propagation
entity_labels = []
# True means the label of this item is reset to initial value during label propagation
reset_masks = []
holdout_item_for_user = None
for entities_per_iter in entities:
users = torch.unsqueeze(user, dim=1) # [batch_size, 1]
user_entity_concat = (
users * self.offset + entities_per_iter
) # [batch_size, n_neighbor^i]
# the first one in entities is the items to be held out
if holdout_item_for_user is None:
holdout_item_for_user = user_entity_concat
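# look up the (user, entity) key in the interaction table; unseen pairs fall back to the
# neutral label 0.5 (setdefault also caches that default for later lookups)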
def lookup_interaction_table(x, _):
x = int(x)
label = self.interaction_table.setdefault(x, 0.5)
return label
initial_label = user_entity_concat.clone().cpu().double()
initial_label.map_(initial_label, lookup_interaction_table)
initial_label = initial_label.float().to(self.device)
# False if the item is held out
holdout_mask = (holdout_item_for_user - user_entity_concat).bool()
# True if the entity is a labeled item
reset_mask = (initial_label - 0.5).bool()
reset_mask = torch.logical_and(
reset_mask, holdout_mask
) # remove held-out items
initial_label = (
holdout_mask.float() * initial_label
+ torch.logical_not(holdout_mask).float() * 0.5
) # label initialization
reset_masks.append(reset_mask)
entity_labels.append(initial_label)
# we do not need the reset_mask for the last iteration
reset_masks = reset_masks[:-1]
# label propagation
relation_vectors = [self.relation_embedding(i) for i in relations]
for i in range(self.n_iter):
entity_labels_next_iter = []
for hop in range(self.n_iter - i):
masks = reset_masks[hop]
self_labels = entity_labels[hop]
neighbor_labels = entity_labels[hop + 1].reshape(
self.batch_size, -1, self.neighbor_sample_size
)
neighbor_relations = relation_vectors[hop].reshape(
self.batch_size, -1, self.neighbor_sample_size, self.embedding_size
)
# mix_neighbor_labels
user_embeddings = user_embeddings.reshape(
self.batch_size, 1, 1, self.embedding_size
) # [batch_size, 1, 1, dim]
user_relation_scores = torch.mean(
user_embeddings * neighbor_relations, dim=-1
) # [batch_size, -1, n_neighbor]
user_relation_scores_normalized = self.softmax(
user_relation_scores
) # [batch_size, -1, n_neighbor]
neighbors_aggregated_label = torch.mean(
user_relation_scores_normalized * neighbor_labels, dim=2
) # [batch_size, -1]
output = (
masks.float() * self_labels
+ torch.logical_not(masks).float() * neighbors_aggregated_label
)
entity_labels_next_iter.append(output)
entity_labels = entity_labels_next_iter
predicted_labels = entity_labels[0].squeeze(-1)
return predicted_labels
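# Propagation sketch for label_smoothness_predict (illustrative): at every hop, entities whose
# reset_mask is True keep their observed label, while the remaining entities take the
# relation-weighted mean of their neighbors' labels, l_v <- mean_k( pi_k * l_k ) with
# pi_k = softmax_k( mean( u * r_k ) ); after n_iter hops the label left on the root item is
# returned as the LS prediction.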
def forward(self, user, item):
self.batch_size = item.shape[0]
# [batch_size, dim]
user_e = self.user_embedding(user)
# entities is a list whose i-th element (i = 0, 1, ..., n_iter) stores the i-hop neighbors of the batch of items:
# {[batch_size, 1], [batch_size, n_neighbor], [batch_size, n_neighbor^2], ..., [batch_size, n_neighbor^n_iter]}
entities, relations = self.get_neighbors(item)
# [batch_size, dim]
item_e = self.aggregate(user_e, entities, relations)
return user_e, item_e
def calculate_ls_loss(self, user, item, target):
r"""Calculate label smoothness loss.
Args:
user(torch.LongTensor): the indices of users, shape: [batch_size*2]
item(torch.LongTensor): the indices of items, shape: [batch_size*2]
target(torch.FloatTensor): the labels of the user-item pairs, shape: [batch_size*2]
Returns:
ls_loss: label smoothness loss
"""
user_e = self.user_embedding(user)
entities, relations = self.get_neighbors(item)
predicted_labels = self.label_smoothness_predict(
user_e, user, entities, relations
)
ls_loss = self.bce_loss(predicted_labels, target)
return ls_loss
def calculate_loss(self, interaction):
user = interaction[self.USER_ID]
pos_item = interaction[self.ITEM_ID]
neg_item = interaction[self.NEG_ITEM_ID]
target = torch.zeros(len(user) * 2, dtype=torch.float32).to(self.device)
target[: len(user)] = 1
users = torch.cat((user, user))
items = torch.cat((pos_item, neg_item))
user_e, item_e = self.forward(users, items)
predict = torch.mul(user_e, item_e).sum(dim=1)
rec_loss = self.bce_loss(predict, target)
ls_loss = self.calculate_ls_loss(users, items, target)
l2_loss = self.l2_loss(user_e, item_e)
loss = rec_loss + self.ls_weight * ls_loss + self.reg_weight * l2_loss
return loss
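# Objective sketch for calculate_loss: loss = BCE recommendation loss over the pos/neg pairs
# + ls_weight * label-smoothness loss + reg_weight * L2 penalty on the batch embeddings.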
def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
user_e, item_e = self.forward(user, item)
return torch.mul(user_e, item_e).sum(dim=1)
def full_sort_predict(self, interaction):
user_index = interaction[self.USER_ID]
item_index = torch.arange(self.n_items, device=self.device)
user = torch.unsqueeze(user_index, dim=1).repeat(1, item_index.shape[0])
user = torch.flatten(user)
item = torch.unsqueeze(item_index, dim=0).repeat(user_index.shape[0], 1)
item = torch.flatten(item)
user_e, item_e = self.forward(user, item)
score = torch.mul(user_e, item_e).sum(dim=1)
return score.view(-1)
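# Usage sketch (illustrative, hedged): train KGNNLS through RecBole's quick-start runner.
# The dataset name and hyper-parameter values below are assumptions for demonstration, not
# defaults taken from this file; any RecBole dataset that provides a knowledge graph works.
if __name__ == "__main__":
    from recbole.quick_start import run_recbole

    run_recbole(
        model="KGNNLS",
        dataset="ml-100k",  # assumed example dataset with bundled KG files
        config_dict={
            "embedding_size": 64,
            "neighbor_sample_size": 4,
            "aggregator": "sum",  # one of "sum", "neighbor", "concat"
            "n_iter": 1,
            "reg_weight": 1e-7,
            "ls_weight": 0.5,
        },
    )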