Source code for recbole.model.knowledge_aware_recommender.cke

# -*- coding: utf-8 -*-
# @Time   : 2020/8/6
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn

r"""
CKE
##################################################
Reference:
    Fuzheng Zhang et al. "Collaborative Knowledge Base Embedding for Recommender Systems." in SIGKDD 2016.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.abstract_recommender import KnowledgeRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType


class CKE(KnowledgeRecommender):
    r"""CKE is a knowledge-based recommendation model. It can incorporate a KG and other information,
    such as corresponding images, to enrich the item representations used for recommendation.

    Note:
        In the original paper, CKE used structural knowledge, textual knowledge and visual knowledge.
        In our implementation, we only use structural knowledge. Meanwhile, our version uses a simpler
        regularization scheme, which achieves almost the same (or even better) results as the original one.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.kg_embedding_size = config['kg_embedding_size']
        # reg_weights[0] scales the recommendation-side regularizer and
        # reg_weights[1] the KG-side regularizer (see calculate_loss).
        self.reg_weights = config['reg_weights']

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        # Entities share the item embedding size so an item's entity vector can be
        # added directly to its collaborative vector (see _get_item_embedding).
        self.entity_embedding = nn.Embedding(self.n_entities, self.embedding_size)
        self.relation_embedding = nn.Embedding(self.n_relations, self.kg_embedding_size)
        # One flattened (embedding_size x kg_embedding_size) projection matrix per relation;
        # reshaped in _get_kg_embedding.
        self.trans_w = nn.Embedding(self.n_relations, self.embedding_size * self.kg_embedding_size)
        self.rec_loss = BPRLoss()
        self.kg_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def _get_item_embedding(self, item):
        """Return the final item representation: collaborative embedding + entity embedding.

        NOTE(review): ``entity_embedding`` is indexed with item ids, which assumes item ids
        coincide with their entity ids -- presumably guaranteed by the KG dataset; verify.
        """
        return self.item_embedding(item) + self.entity_embedding(item)

    def _get_kg_embedding(self, h, r, pos_t, neg_t):
        """Project a batch of KG triples into relation space and L2-normalize the results.

        Args:
            h, r, pos_t, neg_t: index tensors of head entities, relations, positive tails and
                negative tails; assumed 1-D of the same length B (``bmm`` below needs 3-D inputs).

        Returns:
            tuple: ``(h_e, r_e, pos_t_e, neg_t_e, r_trans_w)``. The first four have shape
            ``[B, kg_embedding_size]`` and are L2-normalized along dim 1; ``r_trans_w`` is the
            per-triple projection matrix of shape ``[B, embedding_size, kg_embedding_size]``.
        """
        # Shape [B, 1, embedding_size] so each row can be batch-multiplied by its relation's matrix.
        h_e = self.entity_embedding(h).unsqueeze(1)
        pos_t_e = self.entity_embedding(pos_t).unsqueeze(1)
        neg_t_e = self.entity_embedding(neg_t).unsqueeze(1)
        r_e = self.relation_embedding(r)
        r_trans_w = self.trans_w(r).view(r.size(0), self.embedding_size, self.kg_embedding_size)

        # [B, 1, embedding_size] x [B, embedding_size, kg_embedding_size] -> [B, kg_embedding_size]
        h_e = torch.bmm(h_e, r_trans_w).squeeze(1)
        pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze(1)
        neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze(1)

        r_e = F.normalize(r_e, p=2, dim=1)
        h_e = F.normalize(h_e, p=2, dim=1)
        pos_t_e = F.normalize(pos_t_e, p=2, dim=1)
        neg_t_e = F.normalize(neg_t_e, p=2, dim=1)

        return h_e, r_e, pos_t_e, neg_t_e, r_trans_w

    def forward(self, user, item):
        """Score (user, item) pairs by the inner product of user and final item embeddings."""
        u_e = self.user_embedding(user)
        i_e = self._get_item_embedding(item)
        score = torch.mul(u_e, i_e).sum(dim=1)
        return score

    def _get_rec_loss(self, user_e, pos_e, neg_e):
        """Pairwise ranking loss on inner-product scores of positive vs. negative items."""
        pos_score = torch.mul(user_e, pos_e).sum(dim=1)
        neg_score = torch.mul(user_e, neg_e).sum(dim=1)
        rec_loss = self.rec_loss(pos_score, neg_score)
        return rec_loss

    def _get_kg_loss(self, h_e, r_e, pos_e, neg_e):
        """Pairwise loss on squared translation distances ``||h + r - t||^2``.

        The scores are distances (lower means a more plausible triple), so they are passed to
        ``self.kg_loss`` as (negative, positive). This assumes ``BPRLoss(a, b)`` rewards ``a > b``,
        i.e. it pushes positive-tail distances below negative-tail distances.
        """
        pos_tail_score = ((h_e + r_e - pos_e) ** 2).sum(dim=1)
        neg_tail_score = ((h_e + r_e - neg_e) ** 2).sum(dim=1)
        kg_loss = self.kg_loss(neg_tail_score, pos_tail_score)
        return kg_loss

    def calculate_loss(self, interaction):
        """Compute the training losses for one batch.

        Returns:
            tuple: ``(rec_loss, kg_loss, reg_loss)``. ``reg_loss`` combines an ``EmbLoss`` over the
            user/positive/negative item representations (weighted by ``reg_weights[0]``) and over the
            projected KG embeddings (weighted by ``reg_weights[1]``).
        """
        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        h = interaction[self.HEAD_ENTITY_ID]
        r = interaction[self.RELATION_ID]
        pos_t = interaction[self.TAIL_ENTITY_ID]
        neg_t = interaction[self.NEG_TAIL_ENTITY_ID]

        user_e = self.user_embedding(user)
        pos_item_final_e = self._get_item_embedding(pos_item)
        neg_item_final_e = self._get_item_embedding(neg_item)
        rec_loss = self._get_rec_loss(user_e, pos_item_final_e, neg_item_final_e)

        # The projection matrices are not regularized here, so they are discarded.
        h_e, r_e, pos_t_e, neg_t_e, _ = self._get_kg_embedding(h, r, pos_t, neg_t)
        kg_loss = self._get_kg_loss(h_e, r_e, pos_t_e, neg_t_e)

        reg_loss = (
            self.reg_weights[0] * self.reg_loss(user_e, pos_item_final_e, neg_item_final_e)
            + self.reg_weights[1] * self.reg_loss(h_e, r_e, pos_t_e, neg_t_e)
        )

        return rec_loss, kg_loss, reg_loss

    def predict(self, interaction):
        """Score the (user, item) pairs contained in ``interaction``."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        return self.forward(user, item)

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch.

        Returns a flattened tensor of length ``batch_size * n_items`` (user-major order).
        NOTE(review): slicing the first ``n_items`` entity rows assumes items occupy entity ids
        ``0 .. n_items - 1`` -- consistent with _get_item_embedding; verify against the dataset.
        """
        user = interaction[self.USER_ID]
        user_e = self.user_embedding(user)
        # Same representation as _get_item_embedding, but over all items at once without a gather.
        all_item_e = self.item_embedding.weight + self.entity_embedding.weight[:self.n_items]
        score = torch.matmul(user_e, all_item_e.transpose(0, 1))
        return score.view(-1)