# Source code for recbole.model.context_aware_recommender.deepfm
# -*- coding: utf-8 -*-
# @Time   : 2020/7/8
# @Author : Shanlei Mu
# @Email  : slmu@ruc.edu.cn
# @File   : deepfm.py

# UPDATE:
# @Time   : 2020/8/14
# @Author : Zihan Lin
# @Email  : linzihan.super@foxmain.com

r"""
DeepFM
################################################
Reference:
    Huifeng Guo et al. "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction." in IJCAI 2017.
"""

import torch.nn as nn
from torch.nn.init import xavier_normal_, constant_

from recbole.model.abstract_recommender import ContextRecommender
from recbole.model.layers import BaseFactorizationMachine, MLPLayers
class DeepFM(ContextRecommender):
    """DeepFM is a DNN-enhanced FM that uses both a DNN and an FM to model feature interactions.

    DeepFM can also be seen as a combination of FNN and FM.
    """

    def __init__(self, config, dataset):
        """Build the FM component, the deep (MLP) component and the loss.

        Args:
            config: configuration mapping. This method reads ``'mlp_hidden_size'``
                (a sequence of hidden-layer widths) and ``'dropout_prob'``.
            dataset: dataset object, forwarded unchanged to ``ContextRecommender.__init__``.
        """
        super().__init__(config, dataset)

        # load parameters info
        # Copy into a list so that a tuple (or any other sequence) coming from the
        # config can be concatenated with the input width when building size_list.
        self.mlp_hidden_size = list(config['mlp_hidden_size'])
        self.dropout_prob = config['dropout_prob']

        # define layers and loss
        self.fm = BaseFactorizationMachine(reduce_sum=True)
        # First MLP width is embedding_size * num_feature_field. This is presumably
        # the flattened per-field embeddings, and both attributes are presumably set
        # by ContextRecommender.__init__ (TODO: confirm against forward()).
        size_list = [self.embedding_size * self.num_feature_field] + self.mlp_hidden_size
        self.mlp_layers = MLPLayers(size_list, self.dropout_prob)
        # Linear projection of the MLP output to a single score.
        # size_list[-1] equals mlp_hidden_size[-1] whenever hidden layers are
        # configured. It also avoids the IndexError that indexing an empty
        # mlp_hidden_size would raise.
        self.deep_predict_layer = nn.Linear(size_list[-1], 1)
        self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCELoss()

        # parameters initialization: _init_weights is applied to every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise the parameters of a single submodule (used via ``self.apply``).

        ``nn.Embedding`` and ``nn.Linear`` weights receive Xavier-normal
        initialisation. ``nn.Linear`` biases, when present, are set to 0. Other
        module types are left unchanged.

        Args:
            module: the submodule currently visited by ``nn.Module.apply``.
        """
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)