# -*- coding: utf-8 -*-
# @Time : 2020/6/27
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# UPDATE:
# @Time : 2020/8/22,
# @Author : Zihan Lin
# @Email : linzihan.super@foxmain.com
r"""
NeuMF
################################################
Reference:
    Xiangnan He et al. "Neural Collaborative Filtering." in WWW 2017.
"""
import torch
import torch.nn as nn
from torch.nn.init import normal_
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.layers import MLPLayers
from recbole.utils import InputType
class NeuMF(GeneralRecommender):
    r"""NeuMF is a neural network enhanced matrix factorization model.
    It replaces the dot product with an MLP to model user-item interactions more precisely.

    Note:
        Our implementation only contains a rough pretraining function.
    """

    input_type = InputType.POINTWISE

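    # Hedged sketch of the score this class computes (following He et al., WWW 2017):
    # with GMF embeddings p_u, q_i and MLP embeddings p'_u, q'_i, forward() returns
    # the logit
    #     h^T [ p_u * q_i ; MLP([p'_u ; q'_i]) ]
    # where * is the element-wise product, ; denotes concatenation, and h is the
    # weight of predict_layer; predict() applies a sigmoid on top of this logit.
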
    def __init__(self, config, dataset):
        super(NeuMF, self).__init__(config, dataset)

        # load dataset info
        self.LABEL = config["LABEL_FIELD"]

        # load parameters info
        self.mf_embedding_size = config["mf_embedding_size"]
        self.mlp_embedding_size = config["mlp_embedding_size"]
        self.mlp_hidden_size = config["mlp_hidden_size"]
        self.dropout_prob = config["dropout_prob"]
        self.mf_train = config["mf_train"]
        self.mlp_train = config["mlp_train"]
        self.use_pretrain = config["use_pretrain"]
        self.mf_pretrain_path = config["mf_pretrain_path"]
        self.mlp_pretrain_path = config["mlp_pretrain_path"]

        # define layers and loss
        self.user_mf_embedding = nn.Embedding(self.n_users, self.mf_embedding_size)
        self.item_mf_embedding = nn.Embedding(self.n_items, self.mf_embedding_size)
        self.user_mlp_embedding = nn.Embedding(self.n_users, self.mlp_embedding_size)
        self.item_mlp_embedding = nn.Embedding(self.n_items, self.mlp_embedding_size)
        self.mlp_layers = MLPLayers(
            [2 * self.mlp_embedding_size] + self.mlp_hidden_size, self.dropout_prob
        )
        self.mlp_layers.logger = None  # remove logger to use torch.save()
        if self.mf_train and self.mlp_train:
            self.predict_layer = nn.Linear(
                self.mf_embedding_size + self.mlp_hidden_size[-1], 1
            )
        elif self.mf_train:
            self.predict_layer = nn.Linear(self.mf_embedding_size, 1)
        elif self.mlp_train:
            self.predict_layer = nn.Linear(self.mlp_hidden_size[-1], 1)
        self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()

        # parameters initialization
        if self.use_pretrain:
            self.load_pretrain()
        else:
            self.apply(self._init_weights)

    def load_pretrain(self):
        r"""A simple implementation of loading pretrained parameters."""
        mf = torch.load(self.mf_pretrain_path, map_location="cpu")
        mlp = torch.load(self.mlp_pretrain_path, map_location="cpu")
        mf = mf if "state_dict" not in mf else mf["state_dict"]
        mlp = mlp if "state_dict" not in mlp else mlp["state_dict"]
        self.user_mf_embedding.weight.data.copy_(mf["user_mf_embedding.weight"])
        self.item_mf_embedding.weight.data.copy_(mf["item_mf_embedding.weight"])
        self.user_mlp_embedding.weight.data.copy_(mlp["user_mlp_embedding.weight"])
        self.item_mlp_embedding.weight.data.copy_(mlp["item_mlp_embedding.weight"])

        mlp_layers = list(self.mlp_layers.state_dict().keys())
        index = 0
        for layer in self.mlp_layers.mlp_layers:
            if isinstance(layer, nn.Linear):
                weight_key = "mlp_layers." + mlp_layers[index]
                bias_key = "mlp_layers." + mlp_layers[index + 1]
                assert (
                    layer.weight.shape == mlp[weight_key].shape
                ), "mlp layer parameter shape mismatch"
                assert (
                    layer.bias.shape == mlp[bias_key].shape
                ), "mlp layer parameter shape mismatch"
                layer.weight.data.copy_(mlp[weight_key])
                layer.bias.data.copy_(mlp[bias_key])
                index += 2
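        # Fuse the two pretrained prediction layers into a single one: their
        # weights are concatenated along the input dimension and their biases are
        # averaged (the 0.5 factor gives an equal trade-off between the GMF and
        # MLP branches, roughly mirroring the alpha = 0.5 setting in the NCF
        # pretraining scheme).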
        predict_weight = torch.cat(
            [mf["predict_layer.weight"], mlp["predict_layer.weight"]], dim=1
        )
        predict_bias = mf["predict_layer.bias"] + mlp["predict_layer.bias"]

        self.predict_layer.weight.data.copy_(predict_weight)
        self.predict_layer.bias.data.copy_(0.5 * predict_bias)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            normal_(module.weight.data, mean=0.0, std=0.01)

    def forward(self, user, item):
        user_mf_e = self.user_mf_embedding(user)
        item_mf_e = self.item_mf_embedding(item)
        user_mlp_e = self.user_mlp_embedding(user)
        item_mlp_e = self.item_mlp_embedding(item)
        if self.mf_train:
            mf_output = torch.mul(user_mf_e, item_mf_e)  # [batch_size, embedding_size]
        if self.mlp_train:
            mlp_output = self.mlp_layers(
                torch.cat((user_mlp_e, item_mlp_e), -1)
            )  # [batch_size, layers[-1]]
        if self.mf_train and self.mlp_train:
            output = self.predict_layer(torch.cat((mf_output, mlp_output), -1))
        elif self.mf_train:
            output = self.predict_layer(mf_output)
        elif self.mlp_train:
            output = self.predict_layer(mlp_output)
        else:
            raise RuntimeError(
                "mf_train and mlp_train can not be False at the same time"
            )
        return output.squeeze(-1)

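    # Note: forward() returns raw logits; calculate_loss() feeds them directly to
    # BCEWithLogitsLoss below, while predict() applies the sigmoid explicitly.
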
    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        label = interaction[self.LABEL]

        output = self.forward(user, item)
        return self.loss(output, label)

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        predict = self.sigmoid(self.forward(user, item))
        return predict

    def dump_parameters(self):
        r"""A simple implementation of dumping model parameters for pretrain."""
        if self.mf_train and not self.mlp_train:
            save_path = self.mf_pretrain_path
            torch.save(self, save_path)
        elif self.mlp_train and not self.mf_train:
            save_path = self.mlp_pretrain_path
            torch.save(self, save_path)
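
# A rough sketch of the pretraining workflow that dump_parameters() and
# load_pretrain() support (the exact training invocation is left to the user):
#   1. train with mf_train=True, mlp_train=False and call dump_parameters() to
#      write the GMF checkpoint to mf_pretrain_path;
#   2. train with mf_train=False, mlp_train=True and call dump_parameters() to
#      write the MLP checkpoint to mlp_pretrain_path;
#   3. train the full model with mf_train=True, mlp_train=True and
#      use_pretrain=True so that load_pretrain() restores both branches.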