# Source code for recbole.model.context_aware_recommender.autoint

# -*- coding: utf-8 -*-
# @Time   : 2020/09/01
# @Author : Shuqing Bian
# @Email  : shuqingbian@gmail.com
# @File   : autoint.py

r"""
AutoInt
################################################
Reference:
    Weiping Song et al. "AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks"
    in CIKM 2019.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_, constant_

from recbole.model.abstract_recommender import ContextRecommender
from recbole.model.layers import MLPLayers


class AutoInt(ContextRecommender):
    """AutoInt is a novel CTR prediction model based on self-attention mechanism,
    which can automatically learn high-order feature interactions in an explicit fashion.

    The predicted logit is the sum of three terms:

    * the first-order linear term (``first_order_linear``),
    * a linear projection of the multi-head self-attention interaction output,
    * a linear projection of an MLP applied to the flattened field embeddings.
    """

    def __init__(self, config, dataset):
        super(AutoInt, self).__init__(config, dataset)

        # load parameters info
        self.attention_size = config["attention_size"]
        # dropout_probs[0]: attention dropout, [1]: MLP dropout,
        # [2]: rate of ``self.dropout_layer`` (see note below).
        self.dropout_probs = config["dropout_probs"]
        self.n_layers = config["n_layers"]
        self.num_heads = config["num_heads"]
        self.mlp_hidden_size = config["mlp_hidden_size"]
        self.has_residual = config["has_residual"]

        # define layers and loss
        # Projects each field embedding (embedding_size) into the attention space.
        self.att_embedding = nn.Linear(self.embedding_size, self.attention_size)
        # Flattened sizes of the raw embeddings and of the attention output.
        self.embed_output_dim = self.num_feature_field * self.embedding_size
        self.atten_output_dim = self.num_feature_field * self.attention_size
        size_list = [self.embed_output_dim] + self.mlp_hidden_size
        self.mlp_layers = MLPLayers(size_list, dropout=self.dropout_probs[1])

        # multi-head self-attention network, stacked n_layers deep
        self.self_attns = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    self.attention_size, self.num_heads, dropout=self.dropout_probs[0]
                )
                for _ in range(self.n_layers)
            ]
        )
        self.attn_fc = nn.Linear(self.atten_output_dim, 1)
        self.deep_predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
        if self.has_residual:
            # Maps raw embeddings to attention_size so they can be added to the
            # attention output.
            self.v_res_embedding = nn.Linear(self.embedding_size, self.attention_size)

        # NOTE(review): dropout_layer is not referenced by any method of this
        # class, so dropout_probs[2] currently has no effect here. Kept so the
        # attribute and config contract stay unchanged; confirm whether it was
        # meant to be applied to the interaction output.
        self.dropout_layer = nn.Dropout(p=self.dropout_probs[2])
        self.sigmoid = nn.Sigmoid()
        # forward() returns raw logits, hence the with-logits loss.
        self.loss = nn.BCEWithLogitsLoss()

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-normal init for embedding/linear weights; zero linear biases."""
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)

    def autoint_layer(self, infeature):
        """Get the attention-based feature interaction score.

        Args:
            infeature (torch.FloatTensor): input feature embedding tensor,
                shape of [batch_size, field_size, embed_dim].

        Returns:
            torch.FloatTensor: interaction score, shape of [batch_size, 1].
        """
        batch_size = infeature.shape[0]

        # [batch_size, field_size, attention_size]
        att_infeature = self.att_embedding(infeature)

        # The attention modules are built without batch_first, so they take
        # [seq_len, batch, dim]; fields play the role of the sequence axis.
        cross_term = att_infeature.transpose(0, 1)
        for self_attn in self.self_attns:
            # Attention weights are discarded, so skip computing them.
            cross_term, _ = self_attn(
                cross_term, cross_term, cross_term, need_weights=False
            )
        cross_term = cross_term.transpose(0, 1)

        # Residual connection. Out-of-place so the attention output (of which
        # cross_term is a view) is never mutated in place under autograd.
        if self.has_residual:
            cross_term = cross_term + self.v_res_embedding(infeature)

        # Interacting layer: flatten all fields into one vector per sample.
        cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim)

        # Deep part: MLP over the flattened raw field embeddings.
        deep_output = self.mlp_layers(infeature.view(batch_size, -1))

        return self.attn_fc(cross_term) + self.deep_predict_layer(deep_output)

    def forward(self, interaction):
        """Compute raw logits for a batch.

        Returns the sum of the first-order term and the AutoInt interaction
        score, squeezed on dim 1 (shape [batch_size], assuming
        ``first_order_linear`` yields [batch_size, 1] — TODO confirm).
        """
        autoint_all_embeddings = self.concat_embed_input_fields(
            interaction
        )  # [batch_size, num_field, embed_dim]
        output = self.first_order_linear(interaction) + self.autoint_layer(
            autoint_all_embeddings
        )
        return output.squeeze(1)

    def calculate_loss(self, interaction):
        """Binary cross-entropy (with logits) between forward() and the label."""
        label = interaction[self.LABEL]
        output = self.forward(interaction)
        return self.loss(output, label)

    def predict(self, interaction):
        """Return sigmoid probabilities of the forward() logits."""
        return self.sigmoid(self.forward(interaction))