KGAT

Reference:

Xiang Wang et al. “KGAT: Knowledge Graph Attention Network for Recommendation.” In SIGKDD 2019.

Reference code:

https://github.com/xiangwang1223/knowledge_graph_attention_network

class recbole.model.knowledge_aware_recommender.kgat.Aggregator(input_dim, output_dim, dropout, aggregator_type)

Bases: torch.nn.modules.module.Module

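GNN aggregator layer. It combines each node's own (ego) embedding with the embeddings propagated from its neighbors in the collaborative knowledge graph.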

forward(norm_matrix, ego_embeddings)

Defines the computation performed at every call: propagates ego_embeddings along the sparse norm_matrix and combines the propagated neighborhood embeddings with ego_embeddings according to aggregator_type.

Note

Although the forward pass is defined in this function, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, while calling forward() directly silently skips them.
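The KGAT paper describes three aggregators: GCN, GraphSAGE and Bi-Interaction. A minimal sketch of such a layer follows. The class name AggregatorSketch and the aggregator_type strings "gcn", "graphsage" and "bi" are assumptions for illustration, not the RecBole source.

```python
import torch
import torch.nn as nn


class AggregatorSketch(nn.Module):
    """Hypothetical re-implementation of the KGAT aggregators (GCN, GraphSAGE, Bi-Interaction)."""

    def __init__(self, input_dim, output_dim, dropout, aggregator_type):
        super().__init__()
        self.aggregator_type = aggregator_type
        if aggregator_type == "gcn":
            self.W = nn.Linear(input_dim, output_dim)
        elif aggregator_type == "graphsage":
            self.W = nn.Linear(input_dim * 2, output_dim)
        elif aggregator_type == "bi":
            self.W1 = nn.Linear(input_dim, output_dim)
            self.W2 = nn.Linear(input_dim, output_dim)
        else:
            raise ValueError(f"unknown aggregator_type: {aggregator_type}")
        self.activation = nn.LeakyReLU()
        self.message_dropout = nn.Dropout(dropout)

    def forward(self, norm_matrix, ego_embeddings):
        # Neighborhood ("side") embeddings: weighted sum of neighbor embeddings.
        side = torch.sparse.mm(norm_matrix, ego_embeddings)
        if self.aggregator_type == "gcn":
            # GCN: LeakyReLU(W (e_h + e_N))
            out = self.activation(self.W(ego_embeddings + side))
        elif self.aggregator_type == "graphsage":
            # GraphSAGE: LeakyReLU(W (e_h || e_N))
            out = self.activation(self.W(torch.cat([ego_embeddings, side], dim=1)))
        else:
            # Bi-Interaction: LeakyReLU(W1 (e_h + e_N)) + LeakyReLU(W2 (e_h * e_N))
            out = self.activation(self.W1(ego_embeddings + side)) + self.activation(
                self.W2(ego_embeddings * side)
            )
        return self.message_dropout(out)


# Toy usage: 4 nodes, a row-normalized sparse adjacency, 8-dim embeddings.
indices = torch.tensor([[0, 0, 1, 2, 3], [1, 2, 0, 3, 2]])
values = torch.tensor([0.5, 0.5, 1.0, 1.0, 1.0])
norm_matrix = torch.sparse_coo_tensor(indices, values, (4, 4)).coalesce()
ego = torch.randn(4, 8)
agg = AggregatorSketch(8, 16, dropout=0.1, aggregator_type="bi")
agg.eval()  # disable dropout for a deterministic forward pass
out = agg(norm_matrix, ego)  # call the instance so hooks run
print(out.shape)  # torch.Size([4, 16])
```

Keeping norm_matrix sparse makes each propagation step cost proportional to the number of edges times the embedding size, instead of materializing a dense N×N matrix.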

class recbole.model.knowledge_aware_recommender.kgat.KGAT(config, dataset)

Bases: recbole.model.abstract_recommender.KnowledgeRecommender

KGAT is a knowledge-based recommendation model. It combines the knowledge graph and the user-item interaction graph into a new graph called the collaborative knowledge graph (CKG), and learns user and item representations by exploiting the structure of the CKG. It adopts a GNN-based architecture and defines attention over the CKG.
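Merging the two graphs amounts to placing users and entities in one node id space and treating each interaction as a triple (u, Interact, i). The hypothetical helper build_ckg_edges below illustrates this under stated assumptions: item i maps to entity i, users take node ids [0, n_users), entities are shifted by n_users, and relation 0 is reserved for Interact. Inverse edges, which a full implementation may also add, are omitted.

```python
import torch


def build_ckg_edges(inter_users, inter_items, kg_heads, kg_relations, kg_tails, n_users):
    """Hypothetical helper: merge user-item interactions and KG triples into one CKG edge list.

    Assumptions (not taken from RecBole): item i is entity i, users occupy node ids
    [0, n_users), entities are shifted to start at n_users, relation 0 means "interact"
    and KG relation ids are shifted by +1.
    """
    # User-item edges: (user, interact, item-as-entity)
    ui_heads = inter_users
    ui_rels = torch.zeros_like(inter_users)
    ui_tails = inter_items + n_users
    # KG edges: (head entity, relation, tail entity), shifted into the shared id space
    kg_h = kg_heads + n_users
    kg_r = kg_relations + 1
    kg_t = kg_tails + n_users
    heads = torch.cat([ui_heads, kg_h])
    rels = torch.cat([ui_rels, kg_r])
    tails = torch.cat([ui_tails, kg_t])
    return heads, rels, tails


users = torch.tensor([0, 0, 1])
items = torch.tensor([0, 1, 1])
h = torch.tensor([0, 1])
r = torch.tensor([0, 1])
t = torch.tensor([2, 2])
heads, rels, tails = build_ckg_edges(users, items, h, r, t, n_users=2)
print(heads.tolist(), rels.tolist(), tails.tolist())
# [0, 0, 1, 2, 3] [0, 0, 0, 1, 2] [2, 3, 3, 4, 4]
```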

calculate_kg_loss(interaction)

Calculate the training loss for a batch of KG triples.

Parameters

interaction (Interaction) – Interaction class of the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor
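The KGAT paper trains the KG embeddings with TransR and a pairwise ranking loss: g(h, r, t) = ||W_r e_h + e_r - W_r e_t||^2, and L_KG = sum of -ln σ(g(h, r, t') - g(h, r, t)) over valid triples (h, r, t) and corrupted triples (h, r, t'). A sketch of that loss for one batch follows, using the hypothetical helper transr_kg_loss; the L2 regularization term is omitted.

```python
import torch
import torch.nn.functional as F


def transr_kg_loss(h_e, r_e, pos_t_e, neg_t_e, W_r):
    """Pairwise TransR loss as described in the KGAT paper (hypothetical helper).

    h_e, pos_t_e, neg_t_e: [batch, d] entity embeddings
    r_e: [batch, k] relation embeddings
    W_r: [batch, d, k] per-triple relation projection matrices
    """
    # Project entities into the relation space (row-vector form: e @ W_r).
    h_p = torch.bmm(h_e.unsqueeze(1), W_r).squeeze(1)
    pos_p = torch.bmm(pos_t_e.unsqueeze(1), W_r).squeeze(1)
    neg_p = torch.bmm(neg_t_e.unsqueeze(1), W_r).squeeze(1)
    # g(h, r, t) = ||W_r e_h + e_r - W_r e_t||^2; lower means more plausible.
    pos_score = ((h_p + r_e - pos_p) ** 2).sum(dim=1)
    neg_score = ((h_p + r_e - neg_p) ** 2).sum(dim=1)
    # -ln sigmoid(g(h, r, t') - g(h, r, t)), averaged into a scalar of shape [].
    return -F.logsigmoid(neg_score - pos_score).mean()


torch.manual_seed(0)
B, d, k = 3, 4, 2
loss = transr_kg_loss(
    torch.randn(B, d), torch.randn(B, k), torch.randn(B, d), torch.randn(B, d), torch.randn(B, d, k)
)
print(loss.shape)  # torch.Size([])
```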

calculate_loss(interaction)

Calculate the training loss for a batch of interaction data.

Parameters

interaction (Interaction) – Interaction class of the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor
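For the collaborative-filtering part, the KGAT paper uses the BPR loss -ln σ(ŷ(u, i) - ŷ(u, j)), where i is an observed item, j is a sampled negative item and ŷ is the inner product of the final user and item representations. A sketch with the hypothetical helper bpr_loss follows; the paper's joint objective also adds the KG loss and L2 regularization, which are omitted here.

```python
import torch
import torch.nn.functional as F


def bpr_loss(u_e, pos_e, neg_e):
    """BPR loss for (user, positive item, negative item) triples (hypothetical helper).

    u_e, pos_e, neg_e: [batch, d] final user / item representations.
    """
    pos_scores = (u_e * pos_e).sum(dim=1)  # y(u, i): inner product with the observed item
    neg_scores = (u_e * neg_e).sum(dim=1)  # y(u, j): inner product with the sampled negative
    return -F.logsigmoid(pos_scores - neg_scores).mean()  # scalar, shape []


torch.manual_seed(0)
loss = bpr_loss(torch.randn(5, 8), torch.randn(5, 8), torch.randn(5, 8))
print(loss.shape)  # torch.Size([])
```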

forward()

Defines the computation performed at every call: propagates the initial user and entity embeddings through the stacked Aggregator layers over the CKG attention matrix and returns the final user and entity representations.

Note

Although the forward pass is defined in this function, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, while calling forward() directly silently skips them.
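In the KGAT paper, the final representation of a node concatenates its outputs from every propagation layer: e* = e(0) || e(1) || ... || e(L). A sketch of that stacking follows, using a hypothetical propagate helper. The per-layer L2 normalization is an assumption borrowed from common KGAT implementations, and layers can be any callables with the (norm_matrix, ego_embeddings) signature of Aggregator.forward.

```python
import torch
import torch.nn.functional as F


def propagate(ego_embeddings, A, layers, n_users):
    """Hypothetical sketch of a KGAT-style forward pass.

    ego_embeddings: [n_users + n_entities, d] initial embeddings
    A: sparse [N, N] attention matrix
    layers: callables (A, embeddings) -> embeddings, applied in sequence
    """
    outputs = [ego_embeddings]
    for layer in layers:
        ego_embeddings = layer(A, ego_embeddings)
        # Assumed: L2-normalize each layer's output before concatenation.
        outputs.append(F.normalize(ego_embeddings, p=2, dim=1))
    final = torch.cat(outputs, dim=1)  # e* = e(0) || e(1) || ... || e(L)
    return final[:n_users], final[n_users:]  # split back into users and entities


# Toy usage: identity "attention" and two plain propagation layers.
A = torch.eye(5).to_sparse()
layers = [lambda adj, e: torch.sparse.mm(adj, e)] * 2
users_out, entities_out = propagate(torch.randn(5, 4), A, layers, n_users=2)
print(users_out.shape, entities_out.shape)  # torch.Size([2, 12]) torch.Size([3, 12])
```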

full_sort_predict(interaction)

Full-sort prediction function. Given a batch of users, calculate the scores between each user and all candidate items.

Parameters

interaction (Interaction) – Interaction class of the batch.

Returns

Predicted scores for given users and all candidate items, shape: [n_batch_users * n_candidate_items]

Return type

torch.Tensor
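Full-sort scoring reduces to one matrix product between the batch users' representations and all item representations, flattened row by row to the documented shape [n_batch_users * n_candidate_items]. A sketch with the hypothetical helper full_sort_scores, assuming inner-product scoring:

```python
import torch


def full_sort_scores(user_e, item_e):
    """Hypothetical sketch: score every batch user against every candidate item.

    user_e: [n_batch_users, d]; item_e: [n_candidate_items, d]
    Returns a flat tensor of shape [n_batch_users * n_candidate_items].
    """
    scores = torch.matmul(user_e, item_e.t())  # [n_batch_users, n_candidate_items]
    return scores.view(-1)  # row-major flattening


user_e = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
item_e = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(full_sort_scores(user_e, item_e).tolist())  # [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
```

In the flattened output, the score of user u for item i sits at index u * n_candidate_items + i.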

generate_transE_score(hs, ts, r)

Calculate scores for triples in the KG.

Parameters
  • hs (torch.Tensor) – head entities

  • ts (torch.Tensor) – tail entities

  • r (int) – the relation id between hs and ts

Returns

the scores of (hs, r, ts)

Return type

torch.Tensor
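The KGAT paper scores a triple for attention as π(h, r, t) = (W_r e_t)^T tanh(W_r e_h + e_r). A sketch computing that score for a batch of triples sharing one relation r follows, using the hypothetical helper relation_attention_scores and assumed per-relation projection matrices W. The description above does not state which formula generate_transE_score uses, so treat this as an illustration of the paper's definition.

```python
import torch


def relation_attention_scores(entity_emb, hs, ts, r, rel_emb, W):
    """Hypothetical sketch of the KGAT relational attention pi(h, r, t) for one relation r.

    entity_emb: [N, d]; hs, ts: [m] node ids; r: int relation id
    rel_emb: [n_relations, k]; W: [n_relations, d, k] projection matrices
    """
    W_r = W[r]                  # [d, k]
    h_p = entity_emb[hs] @ W_r  # W_r e_h, shape [m, k]
    t_p = entity_emb[ts] @ W_r  # W_r e_t, shape [m, k]
    # pi(h, r, t) = (W_r e_t)^T tanh(W_r e_h + e_r), one score per triple
    return (t_p * torch.tanh(h_p + rel_emb[r])).sum(dim=1)


torch.manual_seed(0)
scores = relation_attention_scores(
    torch.randn(5, 4), torch.tensor([0, 1, 2]), torch.tensor([3, 4, 3]), 1,
    torch.randn(2, 3), torch.randn(2, 4, 3),
)
print(scores.shape)  # torch.Size([3])
```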

init_graph()

Build the initial attention matrix from the collaborative knowledge graph.

Returns

Sparse tensor of the attention matrix

Return type

torch.sparse.FloatTensor
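A simple choice for the initial attention matrix, assumed here rather than taken from RecBole, is a row-normalized adjacency in which every edge (h, t) gets weight 1 / out-degree(h). A sketch with the hypothetical helper init_attention_matrix, using torch.sparse_coo_tensor (the current constructor for a sparse float tensor):

```python
import torch


def init_attention_matrix(heads, tails, n_nodes):
    """Hypothetical sketch: build a row-normalized sparse adjacency from CKG edges.

    heads, tails: [E] LongTensors of node ids; each edge h -> t gets weight 1 / outdeg(h).
    """
    ones = torch.ones(heads.numel())
    out_degree = torch.zeros(n_nodes).index_add_(0, heads, ones)  # edges leaving each node
    values = 1.0 / out_degree[heads]
    indices = torch.stack([heads, tails])
    return torch.sparse_coo_tensor(indices, values, (n_nodes, n_nodes)).coalesce()


A = init_attention_matrix(torch.tensor([0, 0, 1]), torch.tensor([1, 2, 0]), n_nodes=3)
print(A.to_dense().tolist())  # [[0.0, 0.5, 0.5], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```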

input_type = 2 (InputType.PAIRWISE): the model is trained on (user, positive item, negative item) samples.

predict(interaction)

Predict the scores between users and items.

Parameters

interaction (Interaction) – Interaction class of the batch.

Returns

Predicted scores for given users and items, shape: [batch_size]

Return type

torch.Tensor
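predict scores aligned (user, item) pairs, so instead of a full matrix product it needs only one inner product per row. A sketch with the hypothetical helper pointwise_scores, assuming scores are inner products of the final representations as in the KGAT paper:

```python
import torch


def pointwise_scores(user_e, item_e):
    """Hypothetical sketch: score aligned (user, item) pairs.

    user_e, item_e: [batch_size, d]; row k of each tensor belongs to the k-th pair.
    """
    return (user_e * item_e).sum(dim=1)  # inner product per pair, shape [batch_size]


u = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
i = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
print(pointwise_scores(u, i).tolist())  # [17.0, 53.0]
```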

update_attentive_A()

Update the attention matrix using the updated embeddings.
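In the KGAT paper, the attention scores π(h, r, t) are normalized with a softmax over all triples that share the head h. On a sparse matrix this is a row-wise softmax over the stored entries, which torch.sparse.softmax provides. A sketch with the hypothetical helper normalize_attention:

```python
import torch


def normalize_attention(heads, tails, scores, n_nodes):
    """Hypothetical sketch: turn raw triple scores into a row-wise softmax attention matrix.

    The softmax runs over each head's stored neighbors only.
    """
    indices = torch.stack([heads, tails])
    A = torch.sparse_coo_tensor(indices, scores, (n_nodes, n_nodes)).coalesce()
    return torch.sparse.softmax(A, dim=1)


A = normalize_attention(
    torch.tensor([0, 0, 1]), torch.tensor([1, 2, 0]), torch.tensor([0.0, 0.0, 5.0]), n_nodes=3
)
print(A.to_dense().tolist())  # [[0.0, 0.5, 0.5], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Entries that are not stored are treated as absent rather than as zero scores, so each row's weights sum to 1 over that head's actual neighbors.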