# -*- coding: utf-8 -*-
# @Time : 2023/1/20
# @Author : Wanli Yang
# @Email : 2013774@mail.nankai.edu.cn
r"""
KD_DAGFM
################################################
Reference:
Zhen Tian et al. "Directed Acyclic Graph Factorization Machines for CTR Prediction via Knowledge Distillation."
in WSDM 2023.
Reference code:
https://github.com/chenyuwuxin/DAGFM
"""
import torch
from torch import nn
from torch.nn.init import xavier_normal_
from copy import deepcopy
from recbole.model.init import xavier_normal_initialization
from recbole.model.abstract_recommender import ContextRecommender
class KD_DAGFM(ContextRecommender):
r"""KD_DAGFM is a context-based recommendation model. The model is based on directed acyclic graph and knowledge
distillation. It can learn arbitrary feature interactions from the complex teacher networks and achieve
approximately lossless model performance. It can also greatly reduce the computational resource costs.
"""
def __init__(self, config, dataset):
super(KD_DAGFM, self).__init__(config, dataset)
# load parameters info
self.phase = config["phase"]
self.alpha = config["alpha"]
self.beta = config["beta"]
        # expose the field number in the config so the teacher and student networks can be built
        config["feature_num"] = self.num_feature_field
        # initialize the teacher and student networks
self.student_network = DAGFM(config)
self.teacher_network = eval(f"{config['teacher']}")(
self.get_teacher_config(config)
)
# initialize loss function
self.loss_fn = nn.BCELoss()
        # load warm-up (pretrained) parameters unless the teacher is being trained from scratch
        if self.phase != "teacher_training":
            if "warm_up" not in config:
                raise ValueError("Must provide a warm-up checkpoint when phase is not 'teacher_training'!")
else:
save_info = torch.load(config["warm_up"])
self.load_state_dict(save_info["state_dict"])
else:
self.apply(xavier_normal_initialization)
    # build the teacher network's config: copy the student config and strip the "t_" prefix
    # from teacher-specific keys
    def get_teacher_config(self, config):
teacher_cfg = deepcopy(config)
for key in config.final_config_dict:
if key.startswith("t_"):
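                # e.g. a "t_depth" entry in the user's config becomes "depth" for the teacher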
teacher_cfg[key[2:]] = config[key]
return teacher_cfg
    def FeatureInteraction(self, feature):
if self.phase == "teacher_training":
return self.teacher_network.FeatureInteraction(feature)
elif self.phase == "distillation" or self.phase == "finetuning":
return self.student_network.FeatureInteraction(feature)
else:
            raise ValueError("Phase invalid!")
    def forward(self, interaction):
dagfm_all_embeddings = self.concat_embed_input_fields(
interaction
) # [batch_size, num_field, embed_dim]
if self.phase == "teacher_training" or self.phase == "finetuning":
return self.FeatureInteraction(dagfm_all_embeddings)
elif self.phase == "distillation":
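            # detach the embeddings so distillation updates only the interaction parameters, not the embedding table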
dagfm_all_embeddings = dagfm_all_embeddings.data
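            # during training, run the teacher once per batch so its logits are available for the distillation loss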
if self.training:
self.t_pred = self.teacher_network(dagfm_all_embeddings)
return self.FeatureInteraction(dagfm_all_embeddings)
else:
raise ValueError("Phase invalid!")
    def calculate_loss(self, interaction):
if self.phase == "teacher_training" or self.phase == "finetuning":
prediction = self.forward(interaction)
loss = self.loss_fn(
prediction.squeeze(-1),
interaction[self.LABEL].squeeze(-1).to(self.device),
)
elif self.phase == "distillation":
self.teacher_network.eval()
s_pred = self.forward(interaction)
ctr_loss = self.loss_fn(
s_pred.squeeze(-1), interaction[self.LABEL].squeeze(-1).to(self.device)
)
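            # distillation loss: mean squared error between the detached teacher logits and the student logits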
kd_loss = torch.mean(
(self.teacher_network.logits.data - self.student_network.logits) ** 2
)
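            # weighted sum: alpha scales the CTR (BCE) term, beta scales the distillation term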
loss = self.alpha * ctr_loss + self.beta * kd_loss
else:
raise ValueError("Phase invalid!")
return loss
    def predict(self, interaction):
return self.forward(interaction)
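
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). KD_DAGFM is trained
# in phases controlled by the ``phase`` key read in ``__init__`` above. The dataset
# name, checkpoint paths, and hyper-parameter values below are assumptions;
# ``run_recbole`` is RecBole's quick-start entry point.
#
#   from recbole.quick_start import run_recbole
#
#   # 1. Train the teacher network (e.g. CrossNet) from scratch.
#   run_recbole(model="KD_DAGFM", dataset="criteo",
#               config_dict={"phase": "teacher_training", "teacher": "CrossNet"})
#
#   # 2. Distill into the DAGFM student, warm-starting from the saved teacher checkpoint.
#   run_recbole(model="KD_DAGFM", dataset="criteo",
#               config_dict={"phase": "distillation", "teacher": "CrossNet",
#                            "warm_up": "saved/KD_DAGFM-teacher.pth",
#                            "alpha": 0.9, "beta": 10.0})
#
#   # 3. Optionally fine-tune the student on the CTR objective alone.
#   run_recbole(model="KD_DAGFM", dataset="criteo",
#               config_dict={"phase": "finetuning", "teacher": "CrossNet",
#                            "warm_up": "saved/KD_DAGFM-student.pth"})
# ---------------------------------------------------------------------------
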
class DAGFM(nn.Module):
def __init__(self, config):
super(DAGFM, self).__init__()
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
# load parameters info
self.type = config["type"]
self.depth = config["depth"]
field_num = config["feature_num"]
embedding_size = config["embedding_size"]
# initialize parameters according to the type
if self.type == "inner":
self.p = nn.ParameterList(
[
nn.Parameter(torch.randn(field_num, field_num, embedding_size))
for _ in range(self.depth)
]
)
            for i in range(self.depth):
                xavier_normal_(self.p[i], gain=1.414)
elif self.type == "outer":
self.p = nn.ParameterList(
[
nn.Parameter(torch.randn(field_num, field_num, embedding_size))
for _ in range(self.depth)
]
)
self.q = nn.ParameterList(
[
nn.Parameter(torch.randn(field_num, field_num, embedding_size))
for _ in range(self.depth)
]
)
            for i in range(self.depth):
                xavier_normal_(self.p[i], gain=1.414)
                xavier_normal_(self.q[i], gain=1.414)
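        # upper-triangular adjacency mask: field i only passes messages to fields j >= i,
        # which makes the interaction graph directed and acyclic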
self.adj_matrix = torch.zeros(field_num, field_num, embedding_size).to(
self.device
)
for i in range(field_num):
for j in range(i, field_num):
self.adj_matrix[i, j, :] += 1
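        # learnable field-mixing matrix applied to the input embeddings, initialized to the identity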
self.connect_layer = nn.Parameter(torch.eye(field_num).float())
self.linear = nn.Linear(field_num * (self.depth + 1), 1)
    def FeatureInteraction(self, feature):
init_state = self.connect_layer @ feature
h0, ht = init_state, init_state
state = [torch.sum(init_state, dim=-1)]
for i in range(self.depth):
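            # one propagation step along the DAG: aggregate messages from ancestor fields through the
            # masked interaction weights, then gate the result with the initial state h0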
if self.type == "inner":
aggr = torch.einsum("bfd,fsd->bsd", ht, self.p[i] * self.adj_matrix)
ht = h0 * aggr
elif self.type == "outer":
term = torch.einsum("bfd,fsd->bfs", ht, self.p[i] * self.adj_matrix)
aggr = torch.einsum("bfs,fsd->bsd", term, self.q[i])
ht = h0 * aggr
state.append(torch.sum(ht, dim=-1))
state = torch.cat(state, dim=-1)
self.logits = self.linear(state)
self.outputs = torch.sigmoid(self.logits)
return self.outputs
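
# Shape-check sketch (illustrative, not part of the original file): the student maps a
# (batch, field, dim) embedding tensor to a (batch, 1) CTR probability. A plain dict
# stands in for RecBole's Config object; on a CUDA machine the module (and the input)
# must first be moved to ``student.device``, where the adjacency mask already lives.
#
#   student = DAGFM({"type": "inner", "depth": 2, "feature_num": 4, "embedding_size": 8})
#   probs = student.FeatureInteraction(torch.randn(32, 4, 8))   # -> shape (32, 1)
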
# teacher network: CrossNet, the cross network with full weight matrices (as in DCN)
class CrossNet(nn.Module):
def __init__(self, config):
super(CrossNet, self).__init__()
# load parameters info
self.depth = config["depth"]
self.embedding_size = config["embedding_size"]
self.feature_num = config["feature_num"]
self.in_feature_num = self.feature_num * self.embedding_size
self.cross_layer_w = nn.ParameterList(
nn.Parameter(torch.randn(self.in_feature_num, self.in_feature_num))
for _ in range(self.depth)
)
self.bias = nn.ParameterList(
nn.Parameter(torch.zeros(self.in_feature_num, 1)) for _ in range(self.depth)
)
self.linear = nn.Linear(self.in_feature_num, 1)
nn.init.normal_(self.linear.weight)
    def FeatureInteraction(self, x_0):
x_0 = x_0.reshape(x_0.shape[0], -1)
x_0 = x_0.unsqueeze(dim=2)
x_l = x_0 # (batch_size, in_feature_num, 1)
for i in range(self.depth):
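            # cross layer with a full weight matrix: x_{l+1} = x_0 * (W_l @ x_l + b_l) + x_l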
xl_w = torch.matmul(self.cross_layer_w[i], x_l)
xl_w = xl_w + self.bias[i]
xl_dot = torch.mul(x_0, xl_w)
x_l = xl_dot + x_l
x_l = x_l.squeeze(dim=2)
self.logits = self.linear(x_l)
self.outputs = torch.sigmoid(self.logits)
return self.outputs
    def forward(self, feature):
return self.FeatureInteraction(feature)
class CINComp(nn.Module):
def __init__(self, indim, outdim, config):
super(CINComp, self).__init__()
basedim = config["feature_num"]
self.conv = nn.Conv1d(indim * basedim, outdim, 1)
    def forward(self, feature, base):
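        # outer product between the current feature maps and the raw field embeddings,
        # flattened over the pair dimension and compressed by a 1x1 convolution (one CIN layer)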
return self.conv(
(feature[:, :, None, :] * base[:, None, :, :]).reshape(
feature.shape[0], feature.shape[1] * base.shape[1], -1
)
)
# teacher network: CIN, the compressed interaction network (as in xDeepFM)
class CIN(nn.Module):
def __init__(self, config):
super().__init__()
self.cinlist = [config["feature_num"]] + config["cin"]
self.cin = nn.ModuleList(
[
CINComp(self.cinlist[i], self.cinlist[i + 1], config)
for i in range(0, len(self.cinlist) - 1)
]
)
self.linear = nn.Parameter(torch.zeros(sum(self.cinlist) - self.cinlist[0], 1))
nn.init.normal_(self.linear, mean=0, std=0.01)
self.backbone = ["cin", "linear"]
self.loss_fn = nn.BCELoss()
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
    def FeatureInteraction(self, feature):
base = feature
x = feature
p = []
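        # run the stacked CIN layers; sum-pool each layer's feature maps over the embedding dimension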
for comp in self.cin:
x = comp(x, base)
p.append(torch.sum(x, dim=-1))
p = torch.cat(p, dim=-1)
self.logits = p @ self.linear
self.outputs = torch.sigmoid(self.logits)
return self.outputs
    def forward(self, feature):
return self.FeatureInteraction(feature)
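
# Note (an assumption based on ``get_teacher_config`` above): when ``teacher`` is "CIN",
# ``config["cin"]`` holds the list of per-layer feature-map sizes for the teacher and is
# typically supplied through a ``t_cin`` entry in the user's config, e.g. t_cin: [100, 100]
# stacks two CINComp layers on top of the raw field embeddings.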