Source code for recbole.model.context_aware_recommender.nfm

# -*- coding: utf-8 -*-
# @Time   : 2020/7/14
# @Author : Zihan Lin
# @Email  : linzihan.super@foxmail.com
# @File   : nfm.py

r"""
NFM
################################################
Reference:
    He X, Chua T S. "Neural factorization machines for sparse predictive analytics" in SIGIR 2017
"""

import torch.nn as nn
from torch.nn.init import xavier_normal_, constant_

from recbole.model.abstract_recommender import ContextRecommender
from recbole.model.layers import BaseFactorizationMachine, MLPLayers


class NFM(ContextRecommender):
    """Neural Factorization Machine for context-aware recommendation.

    The second-order feature-interaction signal produced by a factorization
    machine layer is batch-normalized, passed through an MLP and projected to
    a single score; a first-order linear term over the same interaction is
    then added to that score.
    """

    def __init__(self, config, dataset):
        """Build the NFM layers from the run configuration.

        Args:
            config: Configuration object; the keys ``"mlp_hidden_size"``
                (a list of hidden layer widths) and ``"dropout_prob"`` are
                read here.
            dataset: Dataset object, forwarded unchanged to the base class.
        """
        super().__init__(config, dataset)

        # Hyper-parameters taken from the configuration.
        self.mlp_hidden_size = config["mlp_hidden_size"]
        self.dropout_prob = config["dropout_prob"]

        # The MLP input width equals the embedding width, followed by the
        # configured hidden widths.
        # NOTE(review): ``self.embedding_size`` is presumably set by
        # ``ContextRecommender.__init__`` - confirm against the base class.
        layer_sizes = [self.embedding_size] + self.mlp_hidden_size

        # Layers and loss.
        self.fm = BaseFactorizationMachine(reduce_sum=False)
        self.bn = nn.BatchNorm1d(num_features=self.embedding_size)
        self.mlp_layers = MLPLayers(
            layer_sizes,
            self.dropout_prob,
            activation="sigmoid",
            bn=True,
        )
        last_hidden_width = self.mlp_hidden_size[-1]
        self.predict_layer = nn.Linear(last_hidden_width, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
        # ``forward`` returns raw logits, so the loss applies the sigmoid.
        self.loss = nn.BCEWithLogitsLoss()

        # Re-initialize the parameters of every registered sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one sub-module; used as the callback for ``self.apply``.

        Embedding and linear weights get Xavier-normal initialization; a
        linear layer's bias, when present, is set to zero. Any other module
        type is left untouched.

        Args:
            module: The sub-module currently being visited.
        """
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
            return
        if not isinstance(module, nn.Linear):
            return
        xavier_normal_(module.weight.data)
        if module.bias is not None:
            constant_(module.bias.data, 0)

    def forward(self, interaction):
        """Compute the raw (pre-sigmoid) score for each row of ``interaction``.

        Args:
            interaction: Batch of input fields, as accepted by
                ``concat_embed_input_fields`` and ``first_order_linear``.

        Returns:
            Tensor of logits with the trailing singleton dimension removed.
        """
        # [batch_size, num_field, embed_dim]
        field_embeddings = self.concat_embed_input_fields(interaction)

        # Deep part: FM interaction -> batch norm -> MLP -> scalar projection.
        pooled = self.fm(field_embeddings)
        normalized = self.bn(pooled)
        hidden = self.mlp_layers(normalized)
        deep_score = self.predict_layer(hidden)

        # Shallow part: first-order linear term over the same interaction.
        linear_score = self.first_order_linear(interaction)

        logits = deep_score + linear_score
        return logits.squeeze(-1)

    def calculate_loss(self, interaction):
        """Return the binary cross-entropy (with logits) training loss.

        Args:
            interaction: Batch holding the input fields and the label stored
                under ``self.LABEL``.

        Returns:
            The loss tensor produced by ``nn.BCEWithLogitsLoss``.
        """
        target = interaction[self.LABEL]
        logits = self.forward(interaction)
        return self.loss(logits, target)

    def predict(self, interaction):
        """Return sigmoid-activated scores for each row of ``interaction``.

        Args:
            interaction: Batch of input fields, passed to ``forward``.

        Returns:
            Tensor holding the sigmoid of the ``forward`` logits.
        """
        logits = self.forward(interaction)
        return self.sigmoid(logits)