KGAT
- Reference:
Xiang Wang et al. “KGAT: Knowledge Graph Attention Network for Recommendation.” in SIGKDD 2019.
- Reference code:
https://github.com/xiangwang1223/knowledge_graph_attention_network
- class recbole.model.knowledge_aware_recommender.kgat.Aggregator(input_dim, output_dim, dropout, aggregator_type)
Bases: torch.nn.modules.module.Module
GNN aggregator layer.
- forward(norm_matrix, ego_embeddings)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
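As an illustration of what such an aggregator layer can look like, here is a minimal sketch of the bi-interaction aggregator described in the KGAT paper. The class name `BiInteractionAggregator` and the toy graph are hypothetical, and the sketch ignores the `aggregator_type` switch; it is not RecBole's implementation.

```python
import torch
import torch.nn as nn


class BiInteractionAggregator(nn.Module):
    """Hypothetical sketch of a bi-interaction aggregator; not RecBole's class."""

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.w_sum = nn.Linear(input_dim, output_dim)
        self.w_prod = nn.Linear(input_dim, output_dim)
        self.activation = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, norm_matrix, ego_embeddings):
        # Each node gathers its neighbors' embeddings, weighted by the
        # sparse normalized attention matrix.
        side_embeddings = torch.sparse.mm(norm_matrix, ego_embeddings)
        # Combine the node's own embedding with the neighborhood message
        # through an additive and an element-wise multiplicative path.
        sum_part = self.activation(self.w_sum(ego_embeddings + side_embeddings))
        prod_part = self.activation(self.w_prod(ego_embeddings * side_embeddings))
        return self.dropout(sum_part + prod_part)


# Toy graph with 3 nodes; each row of the attention matrix sums to 1.
indices = torch.tensor([[0, 0, 1, 2], [1, 2, 0, 0]])
values = torch.tensor([0.5, 0.5, 1.0, 1.0])
norm_matrix = torch.sparse_coo_tensor(indices, values, (3, 3))
ego = torch.randn(3, 8)

layer = BiInteractionAggregator(input_dim=8, output_dim=4, dropout=0.1)
out = layer(norm_matrix, ego)  # call the module itself, not layer.forward(...)
```

Calling `layer(...)` rather than `layer.forward(...)` is what the note above refers to: only the former runs any registered hooks.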
- training: bool
- class recbole.model.knowledge_aware_recommender.kgat.KGAT(config, dataset)
Bases: recbole.model.abstract_recommender.KnowledgeRecommender
KGAT is a knowledge-based recommendation model. It merges the knowledge graph and the user-item interaction graph into a single graph called the collaborative knowledge graph (CKG), and learns the representations of users and items by exploiting the structure of the CKG. It adopts a GNN-based architecture and defines an attention mechanism over the CKG.
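To make the CKG idea concrete, the following sketch merges a toy interaction list and a toy set of KG triples into one triple list. The ids, the helper `build_ckg`, the extra "interact" relation id, the node numbering (users shifted past entity ids), and the choice to add edges in both directions are all assumptions made for illustration; RecBole's own graph construction may differ.

```python
# Toy inputs (made-up ids): user-item interactions and KG triples.
interactions = [(0, 10), (0, 11), (1, 10)]   # (user, item)
kg_triples = [(10, 0, 20), (11, 1, 20)]      # (head, relation, tail)

n_entities = 21             # items and other KG entities share ids 0..20
n_kg_relations = 2
INTERACT = n_kg_relations   # extra relation id for user-item edges (assumption)


def build_ckg(interactions, kg_triples, n_entities, interact_rel):
    """Merge both graphs into one triple list; users are shifted past entity ids."""
    ckg = list(kg_triples)
    for user, item in interactions:
        user_node = n_entities + user
        ckg.append((user_node, interact_rel, item))
        ckg.append((item, interact_rel, user_node))  # reverse edge (assumption)
    return ckg


ckg = build_ckg(interactions, kg_triples, n_entities, INTERACT)
```

After the merge, users, items, and other entities live in one id space, so a single GNN can propagate information across both interaction and knowledge edges.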
- calculate_kg_loss(interaction)
Calculate the training loss for a batch of KG data.
- Parameters
interaction (Interaction) – Interaction class of the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
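The KGAT paper trains the KG embeddings with a pairwise ranking loss over TransR-style distances, where a valid triple should be closer than a corrupted one. The sketch below shows that idea only; the helper names, the single shared projection matrix, the use of a batch mean, and the omission of any regularization term are simplifying assumptions, not a description of RecBole's code.

```python
import torch
import torch.nn.functional as F


def transr_distance(h, r, t, w_r):
    """Squared L2 distance ||h W_r + r - t W_r||^2 for each triple in the batch."""
    return ((h @ w_r + r - t @ w_r) ** 2).sum(dim=1)


def kg_pairwise_loss(h, r, pos_t, neg_t, w_r):
    pos = transr_distance(h, r, pos_t, w_r)
    neg = transr_distance(h, r, neg_t, w_r)
    # -log(sigmoid(neg - pos)) == softplus(pos - neg)
    return F.softplus(pos - neg).mean()


torch.manual_seed(0)
batch, ent_dim, rel_dim = 4, 8, 6
h, pos_t, neg_t = (torch.randn(batch, ent_dim) for _ in range(3))
r = torch.randn(batch, rel_dim)
w_r = torch.randn(ent_dim, rel_dim)  # one shared relation matrix, for brevity
kg_loss = kg_pairwise_loss(h, r, pos_t, neg_t, w_r)  # 0-dim tensor, shape []
```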
- calculate_loss(interaction)
Calculate the training loss for a batch of data.
- Parameters
interaction (Interaction) – Interaction class of the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
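For the recommendation side, the KGAT paper uses a BPR-style pairwise loss that pushes an observed item's score above a sampled negative item's score. A minimal sketch, assuming inner-product scores and leaving out regularization (the function name `bpr_loss` and the random embeddings are illustrative only):

```python
import torch
import torch.nn.functional as F


def bpr_loss(user_e, pos_item_e, neg_item_e):
    """Mean of -log(sigmoid(pos_score - neg_score)) over the batch."""
    pos_scores = (user_e * pos_item_e).sum(dim=1)
    neg_scores = (user_e * neg_item_e).sum(dim=1)
    # -log(sigmoid(x)) == softplus(-x), which is numerically more stable
    return F.softplus(neg_scores - pos_scores).mean()


torch.manual_seed(0)
user_e = torch.randn(4, 8)
pos_item_e = torch.randn(4, 8)
neg_item_e = torch.randn(4, 8)
cf_loss = bpr_loss(user_e, pos_item_e, neg_item_e)   # 0-dim tensor, shape []
tie_loss = bpr_loss(user_e, pos_item_e, pos_item_e)  # equal scores give log(2)
```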
- forward()
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- full_sort_predict(interaction)
Full-sort prediction function. Given a batch of users, calculate the scores between each user and all candidate items.
- Parameters
interaction (Interaction) – Interaction class of the batch.
- Returns
Predicted scores for given users and all candidate items, shape: [n_batch_users * n_candidate_items]
- Return type
torch.Tensor
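The flattened return shape can be read as a user-by-item score matrix laid out row by row. The sketch below assumes inner-product scoring on already-computed embeddings; the tensor names and sizes are made up for illustration.

```python
import torch

torch.manual_seed(0)
user_e = torch.randn(3, 8)      # embeddings of the 3 users in the batch
item_e = torch.randn(5, 8)      # embeddings of all 5 candidate items

scores = user_e @ item_e.t()    # [3, 5], one row of item scores per user
flat_scores = scores.view(-1)   # [15], user-major order

# The score of user u for item i sits at index u * n_items + i.
u, i, n_items = 1, 2, item_e.shape[0]
same = flat_scores[u * n_items + i].item() == scores[u, i].item()
```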
- generate_transE_score(hs, ts, r)
Calculate scores for triples in the KG.
- Parameters
hs (torch.Tensor) – head entities
ts (torch.Tensor) – tail entities
r (int) – the relation id between hs and ts
- Returns
the scores of (hs, r, ts)
- Return type
torch.Tensor
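For orientation, the KGAT paper scores a triple for attention purposes by projecting head and tail into the relation space and taking the inner product of the projected tail with tanh(projected head + relation embedding). The sketch below implements that formula for a single relation; the function name `attention_score` and all tensors are hypothetical, and whether this method computes exactly the same quantity is an assumption not confirmed by the text above.

```python
import torch


def attention_score(h_e, t_e, r_e, w_r):
    """Score (t W_r) . tanh(h W_r + r) for each (head, tail) pair of one relation."""
    h_proj = h_e @ w_r   # [n, rel_dim]
    t_proj = t_e @ w_r   # [n, rel_dim]
    return (t_proj * torch.tanh(h_proj + r_e)).sum(dim=1)


torch.manual_seed(0)
n, ent_dim, rel_dim = 4, 8, 6
h_e = torch.randn(n, ent_dim)        # head entity embeddings
t_e = torch.randn(n, ent_dim)        # tail entity embeddings
r_e = torch.randn(rel_dim)           # embedding of the single relation r
w_r = torch.randn(ent_dim, rel_dim)  # projection matrix of relation r
triple_scores = attention_score(h_e, t_e, r_e, w_r)   # one score per triple
```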
- init_graph()
Get the initial attention matrix from the collaborative knowledge graph.
- Returns
Sparse tensor of the attention matrix
- Return type
torch.sparse.FloatTensor
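A sparse attention matrix of this kind can be assembled from edge lists and normalized per head node. The sketch below uses made-up raw scores and assumes a row-wise softmax over each head's neighbors, as in the paper's attention; how the initial scores are actually obtained is not stated here.

```python
import torch

# Toy CKG edges (head -> tail) with made-up raw attention scores.
heads = torch.tensor([0, 0, 1, 2])
tails = torch.tensor([1, 2, 0, 0])
raw_scores = torch.tensor([1.0, 1.0, 0.3, -2.0])
n_nodes = 3

raw = torch.sparse_coo_tensor(
    torch.stack([heads, tails]), raw_scores, (n_nodes, n_nodes)
).coalesce()

# Softmax over the stored entries of each row, so that every head's
# outgoing attention weights sum to 1; missing edges stay absent.
attention = torch.sparse.softmax(raw, dim=1)
row_sums = attention.to_dense().sum(dim=1)
```

Keeping the matrix sparse matters because a CKG typically has far fewer edges than node pairs.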
- input_type = 2
- predict(interaction)
Predict the scores between users and items.
- Parameters
interaction (Interaction) – Interaction class of the batch.
- Returns
Predicted scores for given users and items, shape: [batch_size]
- Return type
torch.Tensor
- training: bool