# -*- coding: utf-8 -*-
# @Time : 2020/10/04
# @Author : Xinyan Fan
# @Email : xinyan.fan@ruc.edu.cn
# @File : ffm.py
# UPDATE:
# @Time : 2022/7/16
# @Author : Zhen Tian
# @Email : chenyuwuxinn@gmail.com
r"""
FFM
#####################################################
Reference:
Yuchin Juan et al. "Field-aware Factorization Machines for CTR Prediction" in RecSys 2016.
Reference code:
https://github.com/rixwew/pytorch-fm
"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_, constant_
from recbole.model.abstract_recommender import ContextRecommender
class FFM(ContextRecommender):
r"""FFM is a context-based recommendation model. It aims to model the different feature interactions
between different fields. Each feature has several latent vectors :math:`v_{i,F(j)}`,
which depend on the field of other features, and one of them is used to do the inner product.
The model defines as follows:
.. math::
y = w_0 + \sum_{i=1}^{m}x_{i}w_{i} + \sum_{i=1}^{m}\sum_{j=i+1}^{m}x_{i}x_{j}<v_{i,F(j)}, v_{j,F(i)}>
"""
def __init__(self, config, dataset):
super(FFM, self).__init__(config, dataset)
# load parameters info
self.fields = config["fields"] # a dict; key: field_id; value: feature_list
self.sigmoid = nn.Sigmoid()
self.feature2id = {}
self.feature2field = {}
self.feature_names = (
self.token_field_names,
self.float_field_names,
self.token_seq_field_names,
self.float_seq_field_names,
)
self.feature_dims = (
self.token_field_dims,
self.float_field_dims,
self.token_seq_field_dims,
self.float_seq_field_dims,
)
self._get_feature2field()
self.num_fields = len(set(self.feature2field.values())) # the number of fields
self.ffm = FieldAwareFactorizationMachine(
self.feature_names,
self.feature_dims,
self.feature2id,
self.feature2field,
self.num_fields,
self.embedding_size,
self.device,
)
self.loss = nn.BCEWithLogitsLoss()
# parameters initialization
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Embedding):
xavier_normal_(module.weight.data)
elif isinstance(module, nn.Linear):
xavier_normal_(module.weight.data)
if module.bias is not None:
constant_(module.bias.data, 0)
def _get_feature2field(self):
r"""Create a mapping between features and fields."""
fea_id = 0
for names in self.feature_names:
if names is not None:
for name in names:
self.feature2id[name] = fea_id
fea_id += 1
        if self.fields is None:
            # without a `fields` config, each feature forms its own field
            field_id = 0
            for feature_id in self.feature2id.values():
                self.feature2field[feature_id] = field_id
                field_id += 1
        else:
            for field_id, feature_list in self.fields.items():
                for feature_name in feature_list:
                    try:
                        self.feature2field[self.feature2id[feature_name]] = field_id
                    except KeyError:
                        # skip configured features that are absent from the dataset
                        pass
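    # Example of the mapping built above (toy names, not a real config): with
    # self.fields = {0: ["user_id", "age"], 1: ["item_id"]} and
    # self.feature2id = {"user_id": 0, "item_id": 1, "age": 2}, the result is
    # self.feature2field = {0: 0, 2: 0, 1: 1}; with self.fields = None, every
    # feature becomes its own field: {0: 0, 1: 1, 2: 2}.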
    def forward(self, interaction):
ffm_input = self.get_ffm_input(interaction)
ffm_output = torch.sum(
torch.sum(self.ffm(ffm_input), dim=1), dim=1, keepdim=True
)
output = self.first_order_linear(interaction) + ffm_output
return output.squeeze(-1)
    def calculate_loss(self, interaction):
label = interaction[self.LABEL]
output = self.forward(interaction)
return self.loss(output, label)
    def predict(self, interaction):
return self.sigmoid(self.forward(interaction))
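# A minimal usage sketch (hedged; in practice RecBole's Trainer drives these
# calls with Interaction batches, and `config`/`dataset` come from
# recbole.config and recbole.data):
#
#     model = FFM(config, dataset)
#     loss = model.calculate_loss(interaction)   # BCE-with-logits on raw scores
#     loss.backward()
#     scores = model.predict(interaction)        # sigmoid(logits) in (0, 1)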
class FieldAwareFactorizationMachine(nn.Module):
r"""This is Field-Aware Factorization Machine Module for FFM."""
def __init__(
self,
feature_names,
feature_dims,
feature2id,
feature2field,
num_fields,
embed_dim,
device,
):
super(FieldAwareFactorizationMachine, self).__init__()
self.token_feature_names = feature_names[0]
self.float_feature_names = feature_names[1]
self.token_seq_feature_names = feature_names[2]
self.float_seq_feature_names = feature_names[3]
self.token_feature_dims = feature_dims[0]
self.float_feature_dims = feature_dims[1]
self.token_seq_feature_dims = feature_dims[2]
self.float_seq_feature_dims = feature_dims[3]
self.feature2id = feature2id
self.feature2field = feature2field
self.num_features = (
len(self.token_feature_names)
+ len(self.float_feature_names)
+ len(self.token_seq_feature_names)
+ len(self.float_seq_feature_names)
)
self.num_fields = num_fields
self.embed_dim = embed_dim
self.device = device
        # init token field-aware embeddings if there are token-type features.
if len(self.token_feature_names) > 0:
self.num_token_features = len(self.token_feature_names)
self.token_embeddings = torch.nn.ModuleList(
[
nn.Embedding(sum(self.token_feature_dims), self.embed_dim)
for _ in range(self.num_fields)
]
)
            self.token_offsets = np.array(
                (0, *np.cumsum(self.token_feature_dims)[:-1]), dtype=np.int64
            )
for embedding in self.token_embeddings:
nn.init.xavier_uniform_(embedding.weight.data)
        # init float field-aware embeddings if there are float-type features.
if len(self.float_feature_names) > 0:
self.num_float_features = len(self.float_feature_names)
            self.float_offsets = np.array(
                (0, *np.cumsum(self.float_feature_dims)[:-1]), dtype=np.int64
            )
self.float_embeddings = torch.nn.ModuleList(
[
nn.Embedding(sum(self.float_feature_dims), self.embed_dim)
for _ in range(self.num_fields)
]
)
for embedding in self.float_embeddings:
nn.init.xavier_uniform_(embedding.weight.data)
        # init token_seq field-aware embeddings if there are token_seq-type features.
        if len(self.token_seq_feature_names) > 0:
            self.num_token_seq_features = len(self.token_seq_feature_names)
            self.token_seq_embeddings = torch.nn.ModuleList()
            for i in range(self.num_fields):
                # build a fresh ModuleList per field so fields do not share embeddings
                token_seq_embedding = torch.nn.ModuleList()
                for token_seq_feature_dim in self.token_seq_feature_dims:
                    token_seq_embedding.append(
                        nn.Embedding(token_seq_feature_dim, self.embed_dim)
                    )
                for embedding in token_seq_embedding:
                    nn.init.xavier_uniform_(embedding.weight.data)
                self.token_seq_embeddings.append(token_seq_embedding)
        # init float_seq field-aware embeddings if there are float_seq-type features.
        if len(self.float_seq_feature_names) > 0:
            self.num_float_seq_features = len(self.float_seq_feature_names)
            self.float_seq_embeddings = torch.nn.ModuleList()
            for i in range(self.num_fields):
                # build a fresh ModuleList per field so fields do not share embeddings
                float_seq_embedding = torch.nn.ModuleList()
                for float_seq_feature_dim in self.float_seq_feature_dims:
                    float_seq_embedding.append(
                        nn.Embedding(float_seq_feature_dim, self.embed_dim)
                    )
                for embedding in float_seq_embedding:
                    nn.init.xavier_uniform_(embedding.weight.data)
                self.float_seq_embeddings.append(float_seq_embedding)
    def forward(self, input_x):
r"""Model the different interaction strengths of different field pairs.
Args:
input_x (a tuple): (token_ffm_input, float_ffm_input, token_seq_ffm_input)
token_ffm_input (torch.cuda.FloatTensor): [batch_size, num_token_features] or None
float_ffm_input (torch.cuda.FloatTensor): [batch_size, num_float_features] or None
token_seq_ffm_input (list): length is num_token_seq_features or 0
Returns:
torch.cuda.FloatTensor: The results of all features' field-aware interactions.
shape: [batch_size, num_fields, emb_dim]
"""
token_ffm_input, float_ffm_input, token_seq_ffm_input, float_seq_ffm_input = (
input_x[0],
input_x[1],
input_x[2],
input_x[3],
)
token_input_x_emb = self._emb_token_ffm_input(token_ffm_input)
float_input_x_emb = self._emb_float_ffm_input(float_ffm_input)
token_seq_input_x_emb = self._emb_token_seq_ffm_input(token_seq_ffm_input)
float_seq_input_x_emb = self._emb_float_seq_ffm_input(float_seq_ffm_input)
input_x_emb = self._get_input_x_emb(
token_input_x_emb,
float_input_x_emb,
token_seq_input_x_emb,
float_seq_input_x_emb,
)
output = list()
for i in range(self.num_features - 1):
for j in range(i + 1, self.num_features):
output.append(
input_x_emb[self.feature2field[j]][:, i]
* input_x_emb[self.feature2field[i]][:, j]
)
        output = torch.stack(output, dim=1)  # [batch_size, num_feature_pairs, emb_dim]
return output
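    # Pair-loop example: with num_features = 4 the (i, j) pairs are (0,1),
    # (0,2), (0,3), (1,2), (1,3), (2,3), so `output` stacks 4*3/2 = 6
    # element-wise products into shape [batch_size, 6, emb_dim]. FFM.forward
    # then sums over dims 1 and 2 to obtain the scalar second-order term.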
def _get_input_x_emb(
self,
token_input_x_emb,
float_input_x_emb,
token_seq_input_x_emb,
float_seq_input_x_emb,
):
# merge different types of field-aware embeddings
        input_x_emb = []  # [num_fields: [batch_size, num_features, emb_dim]]
zip_args = []
if len(self.token_feature_names) > 0:
zip_args.append(token_input_x_emb)
if len(self.float_feature_names) > 0:
zip_args.append(float_input_x_emb)
if len(self.token_seq_feature_names) > 0:
zip_args.append(token_seq_input_x_emb)
if len(self.float_seq_feature_names) > 0:
zip_args.append(float_seq_input_x_emb)
for tensors in zip(*zip_args):
input_x_emb.append(torch.cat(tensors, dim=1))
return input_x_emb
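    # Merge example (toy counts): with 2 token features and 1 float feature,
    # each field f holds torch.cat([token_emb_f, float_emb_f], dim=1) of
    # shape [batch_size, 3, emb_dim], so every feature can be addressed by a
    # single position index in the pair loop of forward().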
def _emb_token_ffm_input(self, token_ffm_input):
# get token field-aware embeddings
token_input_x_emb = []
if len(self.token_feature_names) > 0:
token_input_x = token_ffm_input + token_ffm_input.new_tensor(
self.token_offsets
).unsqueeze(0)
token_input_x_emb = [
self.token_embeddings[i](token_input_x) for i in range(self.num_fields)
] # [num_fields: [batch_size, num_token_features, emb_dim]]
return token_input_x_emb
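    # Offset example (toy dims): token_feature_dims = [3, 5, 2] gives
    # token_offsets = [0, 3, 8], so the three token features occupy the
    # disjoint row ranges [0, 3), [3, 8) and [8, 10) of one shared embedding
    # table of size sum(dims) = 10 (one such table per field).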
def _emb_float_ffm_input(self, float_ffm_input):
# get float field-aware embeddings
float_input_x_emb = []
if len(self.float_feature_names) > 0:
base, index = torch.split(float_ffm_input, [1, 1], dim=-1)
index = index.squeeze(-1).long()
index = index + index.new_tensor(self.float_offsets).unsqueeze(0)
float_input_x_emb = [
self.float_embeddings[i](index) * base for i in range(self.num_fields)
] # [num_fields: [batch_size, num_float_features, emb_dim]]
return float_input_x_emb
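    # Float input example (assumed (value, index) layout, as implied by the
    # split above): float_ffm_input[b, k] = (0.7, 12) looks up embedding row
    # 12 (plus the feature's offset) and scales it by 0.7, making the
    # embedding linear in the raw float value.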
def _emb_token_seq_ffm_input(self, token_seq_ffm_input):
# get token_seq field-aware embeddings
token_seq_input_x_emb = []
if len(self.token_seq_feature_names) > 0:
for i in range(self.num_fields):
token_seq_result = []
for j, token_seq in enumerate(token_seq_ffm_input):
embedding_table = self.token_seq_embeddings[i][j]
mask = token_seq != 0 # [batch_size, seq_len]
mask = mask.float()
value_cnt = torch.sum(mask, dim=1, keepdim=True) # [batch_size, 1]
token_seq_embedding = embedding_table(
token_seq
) # [batch_size, seq_len, embed_dim]
mask = mask.unsqueeze(2).expand_as(
token_seq_embedding
) # [batch_size, seq_len, embed_dim]
                    # masked mean pooling over the sequence dimension
                    masked_token_seq_embedding = token_seq_embedding * mask
result = torch.sum(
masked_token_seq_embedding, dim=1
) # [batch_size, embed_dim]
eps = torch.FloatTensor([1e-8]).to(self.device)
result = torch.div(
result, value_cnt + eps
) # [batch_size, embed_dim]
result = result.unsqueeze(1) # [batch_size, 1, embed_dim]
token_seq_result.append(result)
token_seq_input_x_emb.append(
torch.cat(token_seq_result, dim=1)
) # [num_fields: batch_size, num_token_seq_features, embed_dim]
return token_seq_input_x_emb
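    # Masked-mean example (toy sequence): token_seq = [5, 9, 0, 0] with 0 as
    # padding gives mask = [1, 1, 0, 0] and value_cnt = 2, so the pooled
    # vector is (e_5 + e_9) / (2 + 1e-8); the eps only guards against
    # all-padding sequences.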
def _emb_float_seq_ffm_input(self, float_seq_ffm_input):
# get float_seq field-aware embeddings
float_seq_input_x_emb = []
if len(self.float_seq_feature_names) > 0:
for i in range(self.num_fields):
float_seq_result = []
for j, float_seq in enumerate(float_seq_ffm_input):
embedding_table = self.float_seq_embeddings[i][j]
base, index = torch.split(float_seq, [1, 1], dim=-1)
index = index.squeeze(-1)
mask = index != 0 # [batch_size, seq_len]
mask = mask.float()
value_cnt = torch.sum(mask, dim=1, keepdim=True) # [batch_size, 1]
float_seq_embedding = base * embedding_table(
index.long()
) # [batch_size, seq_len, embed_dim]
mask = mask.unsqueeze(2).expand_as(
float_seq_embedding
) # [batch_size, seq_len, embed_dim]
                    # masked mean pooling over the sequence dimension
                    masked_float_seq_embedding = float_seq_embedding * mask
result = torch.sum(
masked_float_seq_embedding, dim=1
) # [batch_size, embed_dim]
eps = torch.FloatTensor([1e-8]).to(self.device)
result = torch.div(
result, value_cnt + eps
) # [batch_size, embed_dim]
result = result.unsqueeze(1) # [batch_size, 1, embed_dim]
float_seq_result.append(result)
float_seq_input_x_emb.append(
torch.cat(float_seq_result, dim=1)
                )  # [num_fields: batch_size, num_float_seq_features, embed_dim]
return float_seq_input_x_emb
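

# ---------------------------------------------------------------------------
# Minimal smoke test with token features only (a sketch under assumed toy
# shapes; names and dims are made up and this block is not part of the
# RecBole module):
if __name__ == "__main__":
    feature_names = (["user_id", "item_id"], [], [], [])
    feature_dims = ([3, 5], [], [], [])
    ffm = FieldAwareFactorizationMachine(
        feature_names,
        feature_dims,
        feature2id={"user_id": 0, "item_id": 1},
        feature2field={0: 0, 1: 1},
        num_fields=2,
        embed_dim=4,
        device="cpu",
    )
    token_input = torch.tensor([[0, 2], [1, 4]])  # [batch_size=2, num_token_features=2]
    out = ffm((token_input, None, [], []))
    print(out.shape)  # torch.Size([2, 1, 4]): one feature pair, emb_dim=4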