Source code for recbole.model.knowledge_aware_recommender.ktup

# @Time   : 2020/8/6
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

r"""
KTUP
##################################################
Reference:
    Yixin Cao et al. "Unifying Knowledge Graph Learning and Recommendation: Towards a Better Understanding
    of User Preferences." in WWW 2019.

Reference code:
    https://github.com/TaoMiner/joint-kg-recommender
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from recbole.utils import InputType
from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.loss import BPRLoss, EmbMarginLoss
from recbole.model.init import xavier_uniform_initialization


class KTUP(KnowledgeRecommender):
    r"""KTUP is a knowledge-based recommendation model. It adopts the strategy of multi-task learning to jointly learn
    recommendation and KG-related tasks, with the goal of understanding the reasons that a user interacts with an item.
    This method utilizes an attention mechanism to combine all preferences into a single-vector representation.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(KTUP, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.L1_flag = config['L1_flag']
        self.use_st_gumbel = config['use_st_gumbel']
        self.kg_weight = config['kg_weight']
        self.align_weight = config['align_weight']
        self.margin = config['margin']

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.pref_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.pref_norm_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.relation_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.relation_norm_embedding = nn.Embedding(self.n_relations, self.embedding_size)
        self.rec_loss = BPRLoss()
        self.kg_loss = nn.MarginRankingLoss(margin=self.margin)
        self.reg_loss = EmbMarginLoss()

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        normalize_user_emb = F.normalize(self.user_embedding.weight.data, p=2, dim=1)
        normalize_item_emb = F.normalize(self.item_embedding.weight.data, p=2, dim=1)
        normalize_pref_emb = F.normalize(self.pref_embedding.weight.data, p=2, dim=1)
        normalize_pref_norm_emb = F.normalize(self.pref_norm_embedding.weight.data, p=2, dim=1)
        normalize_entity_emb = F.normalize(self.entity_embedding.weight.data, p=2, dim=1)
        normalize_rel_emb = F.normalize(self.relation_embedding.weight.data, p=2, dim=1)
        normalize_rel_norm_emb = F.normalize(self.relation_norm_embedding.weight.data, p=2, dim=1)
        self.user_embedding.weight.data = normalize_user_emb
        self.item_embedding.weight.data = normalize_item_emb
        self.pref_embedding.weight.data = normalize_pref_emb
        self.pref_norm_embedding.weight.data = normalize_pref_norm_emb
        self.entity_embedding.weight.data = normalize_entity_emb
        self.relation_embedding.weight.data = normalize_rel_emb
        self.relation_norm_embedding.weight.data = normalize_rel_norm_emb

    def _masked_softmax(self, logits):
        probs = F.softmax(logits, dim=len(logits.shape) - 1)
        return probs
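For orientation, every config key read in __init__ above comes from RecBole's configuration system. Below is a minimal, illustrative sketch of training KTUP through RecBole's quick-start entry point; the dataset name and hyperparameter values are placeholders, not recommended settings, and a real run may need additional KTUP-specific options from the model's default config.

    from recbole.quick_start import run_recbole

    run_recbole(
        model='KTUP',
        dataset='ml-100k',          # placeholder: any dataset with a linked knowledge graph
        config_dict={
            'embedding_size': 64,
            'L1_flag': False,
            'use_st_gumbel': True,
            'kg_weight': 1.0,
            'align_weight': 1.0,
            'margin': 1.0,
        },
    )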
    def convert_to_one_hot(self, indices, num_classes):
        r"""
        Args:
            indices (Variable): A vector containing indices, whose size is (batch_size,).
            num_classes (int): The number of classes, which would be the second dimension of the resulting one-hot matrix.

        Returns:
            torch.Tensor: The one-hot matrix of size (batch_size, num_classes).
        """
        old_shape = indices.shape
        new_shape = torch.Size([i for i in old_shape] + [num_classes])
        indices = indices.unsqueeze(len(old_shape))

        one_hot = Variable(indices.data.new(new_shape).zero_().scatter_(len(old_shape), indices.data, 1))
        return one_hot
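As a quick illustration of what convert_to_one_hot produces, the scatter along the trailing dimension turns a batch of indices into one-hot rows. The sketch below replays the same logic on assumed values, independent of the class.

    import torch

    indices = torch.tensor([1, 0, 2])                               # shape (batch_size,) = (3,)
    one_hot = torch.zeros(3, 3).scatter_(1, indices.unsqueeze(1), 1)
    # tensor([[0., 1., 0.],
    #         [1., 0., 0.],
    #         [0., 0., 1.]])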
    def st_gumbel_softmax(self, logits, temperature=1.0):
        r"""Return the result of Straight-Through Gumbel-Softmax Estimation. It approximates the discrete sampling
        via the Gumbel-Softmax trick and applies the biased ST estimator. In the forward propagation, it emits the
        discrete one-hot result, and in the backward propagation it approximates the categorical distribution via
        the smooth Gumbel-Softmax distribution.

        Args:
            logits (Variable): Un-normalized probability values (logits) of size (batch_size, num_classes).
            temperature (float): A temperature parameter. The higher the value is, the smoother the distribution is.

        Returns:
            torch.Tensor: The sampled output, which has the property explained above.
        """
        eps = 1e-20
        u = logits.data.new(*logits.size()).uniform_()
        gumbel_noise = Variable(-torch.log(-torch.log(u + eps) + eps))
        y = logits + gumbel_noise
        y = self._masked_softmax(logits=y / temperature)
        y_argmax = y.max(len(y.shape) - 1)[1]
        y_hard = self.convert_to_one_hot(indices=y_argmax, num_classes=y.size(len(y.shape) - 1)).float()
        y = (y_hard - y).detach() + y
        return y
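The key step above is y = (y_hard - y).detach() + y: the forward value equals the discrete one-hot y_hard, while gradients flow through the smooth y. A standalone sketch of that straight-through trick on assumed logits (recent PyTorch also ships a built-in equivalent, torch.nn.functional.gumbel_softmax with hard=True):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 0.5, 0.1]], requires_grad=True)
    y_soft = F.softmax(logits, dim=-1)                                 # smooth, differentiable
    y_hard = F.one_hot(y_soft.argmax(dim=-1), num_classes=3).float()   # discrete one-hot
    y = (y_hard - y_soft).detach() + y_soft                            # forward: y_hard; backward: y_soft

    loss = (y * torch.tensor([[1.0, 2.0, 3.0]])).sum()                 # arbitrary downstream objective
    loss.backward()
    print(y)            # one-hot output
    print(logits.grad)  # non-zero: gradients reached the logits through the soft path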
    def _get_preferences(self, user_e, item_e, use_st_gumbel=False):
        pref_probs = torch.matmul(
            user_e + item_e,
            torch.t(self.pref_embedding.weight + self.relation_embedding.weight)
        ) / 2
        if use_st_gumbel:
            # TODO: different torch versions may cause st_gumbel_softmax to raise errors; still to be tested
            pref_probs = self.st_gumbel_softmax(pref_probs)
        relation_e = torch.matmul(pref_probs, self.pref_embedding.weight + self.relation_embedding.weight) / 2
        norm_e = torch.matmul(pref_probs, self.pref_norm_embedding.weight + self.relation_norm_embedding.weight) / 2
        return pref_probs, relation_e, norm_e

    @staticmethod
    def _transH_projection(original, norm):
        return original - torch.sum(original * norm, dim=len(original.size()) - 1, keepdim=True) * norm

    def _get_score(self, h_e, r_e, t_e):
        if self.L1_flag:
            score = -torch.sum(torch.abs(h_e + r_e - t_e), 1)
        else:
            score = -torch.sum((h_e + r_e - t_e) ** 2, 1)
        return score
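_transH_projection is the TransH-style hyperplane projection e - (e·w)w, and _get_score is the negated translational distance -||h + r - t|| (L1 or squared L2). The short sketch below, on random values, simply checks that the projection is orthogonal to the (unit) normal vector:

    import torch
    import torch.nn.functional as F

    e = torch.randn(4, 8)                              # a batch of embeddings
    w = F.normalize(torch.randn(4, 8), p=2, dim=1)     # unit normal vectors of the hyperplanes
    proj = e - torch.sum(e * w, dim=1, keepdim=True) * w
    print(torch.sum(proj * w, dim=1))                  # ~0 for every row: proj lies on the hyperplane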
    def forward(self, user, item):
        user_e = self.user_embedding(user)
        item_e = self.item_embedding(item)
        entity_e = self.entity_embedding(item)
        item_e = item_e + entity_e

        _, relation_e, norm_e = self._get_preferences(user_e, item_e, use_st_gumbel=self.use_st_gumbel)

        proj_user_e = self._transH_projection(user_e, norm_e)
        proj_item_e = self._transH_projection(item_e, norm_e)

        return proj_user_e, relation_e, proj_item_e
    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        proj_pos_user_e, pos_relation_e, proj_pos_item_e = self.forward(user, pos_item)
        proj_neg_user_e, neg_relation_e, proj_neg_item_e = self.forward(user, neg_item)

        pos_item_score = self._get_score(proj_pos_user_e, pos_relation_e, proj_pos_item_e)
        neg_item_score = self._get_score(proj_neg_user_e, neg_relation_e, proj_neg_item_e)

        rec_loss = self.rec_loss(pos_item_score, neg_item_score)
        orthogonal_loss = orthogonalLoss(self.pref_embedding.weight, self.pref_norm_embedding.weight)
        item = torch.cat([pos_item, neg_item])
        align_loss = self.align_weight * alignLoss(self.item_embedding(item), self.entity_embedding(item), self.L1_flag)

        return rec_loss, orthogonal_loss, align_loss
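BPRLoss (from recbole.model.loss) implements the Bayesian Personalized Ranking objective, roughly -log(sigmoid(pos_score - neg_score)) averaged over the batch. A standalone sketch of that objective with assumed scores, not the library class itself:

    import torch

    pos_score = torch.tensor([0.8, 0.2])
    neg_score = torch.tensor([0.1, 0.5])
    bpr = -torch.log(torch.sigmoid(pos_score - neg_score)).mean()
    print(bpr)   # smaller when positive items score above their paired negatives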
    def calculate_kg_loss(self, interaction):
        r"""Calculate the training loss for a batch data of KG.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        h = interaction[self.HEAD_ENTITY_ID]
        r = interaction[self.RELATION_ID]
        pos_t = interaction[self.TAIL_ENTITY_ID]
        neg_t = interaction[self.NEG_TAIL_ENTITY_ID]

        h_e = self.entity_embedding(h)
        pos_t_e = self.entity_embedding(pos_t)
        neg_t_e = self.entity_embedding(neg_t)
        r_e = self.relation_embedding(r)
        norm_e = self.relation_norm_embedding(r)

        proj_h_e = self._transH_projection(h_e, norm_e)
        proj_pos_t_e = self._transH_projection(pos_t_e, norm_e)
        proj_neg_t_e = self._transH_projection(neg_t_e, norm_e)

        pos_tail_score = self._get_score(proj_h_e, r_e, proj_pos_t_e)
        neg_tail_score = self._get_score(proj_h_e, r_e, proj_neg_t_e)

        kg_loss = self.kg_loss(pos_tail_score, neg_tail_score, torch.ones(h.size(0)).to(self.device))
        orthogonal_loss = orthogonalLoss(r_e, norm_e)
        reg_loss = self.reg_loss(h_e, pos_t_e, neg_t_e, r_e)
        loss = self.kg_weight * (kg_loss + orthogonal_loss + reg_loss)
        entity = torch.cat([h, pos_t, neg_t])
        entity = entity[entity < self.n_items]
        align_loss = self.align_weight * alignLoss(self.item_embedding(entity), self.entity_embedding(entity), self.L1_flag)

        return loss, align_loss
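nn.MarginRankingLoss with a target of +1, as used above, asks each positive-tail score to exceed the matching negative-tail score by at least margin. A standalone illustration on assumed scores:

    import torch
    import torch.nn as nn

    loss_fn = nn.MarginRankingLoss(margin=1.0)
    pos = torch.tensor([0.8, 0.2])
    neg = torch.tensor([0.1, 0.5])
    target = torch.ones(2)              # +1 means pos should rank above neg
    print(loss_fn(pos, neg, target))    # mean(max(0, margin - (pos - neg))) = mean([0.3, 1.3]) = 0.8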
    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        proj_user_e, relation_e, proj_item_e = self.forward(user, item)
        return self._get_score(proj_user_e, relation_e, proj_item_e)
def orthogonalLoss(rel_embeddings, norm_embeddings):
    return torch.sum(
        torch.sum(norm_embeddings * rel_embeddings, dim=1, keepdim=True) ** 2
        / torch.sum(rel_embeddings ** 2, dim=1, keepdim=True)
    )
def alignLoss(emb1, emb2, L1_flag=False):
    if L1_flag:
        distance = torch.sum(torch.abs(emb1 - emb2), 1)
    else:
        distance = torch.sum((emb1 - emb2) ** 2, 1)
    return distance.mean()
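Both module-level helpers mirror terms of the joint objective: orthogonalLoss penalizes the component of each relation embedding that lies along its hyperplane normal (the TransH soft orthogonality constraint), and alignLoss pulls the item embedding and the entity embedding of the same item together. A tiny sanity check with hand-picked values, reusing the functions defined above:

    import torch

    rel = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
    norm = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
    print(orthogonalLoss(rel, norm))            # 0 + (2**2)/(2**2) = tensor(1.): only the parallel row is penalized
    print(alignLoss(rel, norm, L1_flag=True))   # mean row-wise L1 distance = (2 + 1) / 2 = tensor(1.5000)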