Source code for recbole.model.context_aware_recommender.fignn

# -*- coding: utf-8 -*-
# @Time   : 2022/10/27
# @Author : Yuyan Zhang
# @Email  : 2019308160102@cau.edu.cn
# @File   : fignn.py

r"""
FiGNN
################################################
Reference:
    Li, Zekun, et al.  "Fi-GNN: Modeling Feature Interactions via Graph Neural Networks for CTR Prediction"
    in CIKM 2019.

Reference code:
    - https://github.com/CRIPAC-DIG/GraphCTR
    - https://github.com/xue-pai/FuxiCTR
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
from itertools import product

from recbole.utils import InputType
from recbole.model.abstract_recommender import ContextRecommender


class GraphLayer(nn.Module):
    """The implementations of the GraphLayer part and the Attentional Edge Weights
    part are adapted from https://github.com/xue-pai/FuxiCTR.
    """

    def __init__(self, num_fields, embedding_size):
        super(GraphLayer, self).__init__()
        self.W_in = nn.Parameter(
            torch.Tensor(num_fields, embedding_size, embedding_size)
        )
        self.W_out = nn.Parameter(
            torch.Tensor(num_fields, embedding_size, embedding_size)
        )
        xavier_normal_(self.W_in)
        xavier_normal_(self.W_out)
        self.bias_p = nn.Parameter(torch.zeros(embedding_size))
    def forward(self, g, h):
        # transform each node state with its node-specific outgoing weight matrix
        h_out = torch.matmul(self.W_out, h.unsqueeze(-1)).squeeze(-1)
        # aggregate neighbor states weighted by the attentional adjacency matrix g
        aggr = torch.bmm(g, h_out)
        # transform the aggregated message with the node-specific incoming weight matrix
        a = torch.matmul(self.W_in, aggr.unsqueeze(-1)).squeeze(-1) + self.bias_p
        return a
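
# The following helper is an illustrative sketch added for this page and is not
# part of the original RecBole module. It only demonstrates the tensor shapes
# GraphLayer expects, matching how it is called from FiGNN.fignn_layer below;
# the sizes are arbitrary placeholder values.
def _graph_layer_shape_demo(batch_size=4, num_fields=6, att_dim=16):
    layer = GraphLayer(num_fields, att_dim)
    # attentional adjacency over feature-field nodes: one row-normalized
    # [num_fields, num_fields] matrix per sample
    g = F.softmax(torch.randn(batch_size, num_fields, num_fields), dim=-1)
    # one state vector per feature-field node
    h = torch.randn(batch_size, num_fields, att_dim)
    a = layer(g, h)  # aggregated messages, same shape as h
    assert a.shape == (batch_size, num_fields, att_dim)
    return a
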
class FiGNN(ContextRecommender):
    """FiGNN is a CTR prediction model based on GGNN, which can model sophisticated
    interactions among feature fields on the graph-structured features.
    """

    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(FiGNN, self).__init__(config, dataset)

        # load parameters info
        self.attention_size = config["attention_size"]
        self.n_layers = config["n_layers"]
        self.num_heads = config["num_heads"]
        self.hidden_dropout_prob = config["hidden_dropout_prob"]
        self.attn_dropout_prob = config["attn_dropout_prob"]

        # define layers and loss
        self.dropout_layer = nn.Dropout(p=self.hidden_dropout_prob)
        self.att_embedding = nn.Linear(self.embedding_size, self.attention_size)
        # multi-head self-attention network
        self.self_attn = nn.MultiheadAttention(
            self.attention_size,
            self.num_heads,
            dropout=self.attn_dropout_prob,
            batch_first=True,
        )
        self.v_res_embedding = torch.nn.Linear(self.embedding_size, self.attention_size)
        # FiGNN
        self.src_nodes, self.dst_nodes = zip(
            *list(product(range(self.num_feature_field), repeat=2))
        )
        self.gnn = nn.ModuleList(
            [
                GraphLayer(self.num_feature_field, self.attention_size)
                for _ in range(self.n_layers - 1)
            ]
        )
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
        self.W_attn = nn.Linear(self.attention_size * 2, 1, bias=False)
        self.gru_cell = nn.GRUCell(self.attention_size, self.attention_size)
        # Attentional Scoring Layer
        self.mlp1 = nn.Linear(self.attention_size, 1, bias=False)
        self.mlp2 = nn.Linear(
            self.num_feature_field * self.attention_size,
            self.num_feature_field,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()

        # parameters initialization
        self.apply(self._init_weights)
    def fignn_layer(self, in_feature):
        emb_feature = self.att_embedding(in_feature)
        emb_feature = self.dropout_layer(emb_feature)

        # multi-head self-attention network
        att_feature, _ = self.self_attn(
            emb_feature, emb_feature, emb_feature
        )  # [batch_size, num_field, att_dim]
        # Residual connection
        v_res = self.v_res_embedding(in_feature)
        att_feature += v_res
        att_feature = F.relu(att_feature).contiguous()

        # init graph: attentional edge weights
        # alpha_ij = LeakyReLU(W_attn [e_i || e_j]); the diagonal is masked out and
        # a row-wise softmax yields the adjacency matrix of the feature graph
        src_emb = att_feature[:, self.src_nodes, :]
        dst_emb = att_feature[:, self.dst_nodes, :]
        concat_emb = torch.cat([src_emb, dst_emb], dim=-1)
        alpha = self.leaky_relu(self.W_attn(concat_emb))
        alpha = alpha.view(-1, self.num_feature_field, self.num_feature_field)
        mask = torch.eye(self.num_feature_field).to(self.device)
        alpha = alpha.masked_fill(mask.bool(), float("-inf"))
        self.graph = F.softmax(alpha, dim=-1)

        # message passing: each step aggregates neighbor states through the graph,
        # updates node states with a shared GRU cell, and adds a residual connection
        # to the initial node states
        if self.n_layers > 1:
            h = att_feature
            for i in range(self.n_layers - 1):
                a = self.gnn[i](self.graph, h)
                a = a.view(-1, self.attention_size)
                h = h.view(-1, self.attention_size)
                h = self.gru_cell(a, h)
                h = h.view(-1, self.num_feature_field, self.attention_size)
                h += att_feature
        else:
            h = att_feature

        # Attentional Scoring Layer: a per-field score weighted by field-level
        # attention, summed into a single logit
        score = self.mlp1(h).squeeze(-1)
        weight = self.mlp2(h.flatten(start_dim=1))
        logit = (weight * score).sum(dim=1).unsqueeze(-1)
        return logit
    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)
        elif isinstance(module, nn.GRU):
            xavier_uniform_(module.weight_hh_l0)
            xavier_uniform_(module.weight_ih_l0)
    def forward(self, interaction):
        fignn_all_embeddings = self.concat_embed_input_fields(
            interaction
        )  # [batch_size, num_field, embed_dim]
        output = self.fignn_layer(fignn_all_embeddings)
        return output.squeeze(1)
    def calculate_loss(self, interaction):
        label = interaction[self.LABEL]
        output = self.forward(interaction)
        return self.loss(output, label)
    def predict(self, interaction):
        return self.sigmoid(self.forward(interaction))
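
# A minimal usage sketch (not part of the original module), assuming RecBole's
# standard quick-start entry point `recbole.quick_start.run_recbole`. The keys in
# `config_dict` are exactly the hyperparameters read in FiGNN.__init__ above; the
# values are placeholders, and "fignn_ctr.yaml" is a hypothetical config file that
# would supply the dataset, label, and feature-loading settings a CTR model needs.
if __name__ == "__main__":
    from recbole.quick_start import run_recbole

    run_recbole(
        model="FiGNN",
        dataset="ml-100k",
        config_file_list=["fignn_ctr.yaml"],
        config_dict={
            "attention_size": 16,
            "n_layers": 2,
            "num_heads": 2,
            "hidden_dropout_prob": 0.2,
            "attn_dropout_prob": 0.2,
        },
    )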