# @Time : 2020/6/26
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# UPDATE:
# @Time : 2020/8/7
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
"""
recbole.model.loss
#######################
Common loss functions used in recommender systems
"""
import torch
import torch.nn as nn
class BPRLoss(nn.Module):
    """ BPRLoss, based on Bayesian Personalized Ranking

    Computes ``-mean(log(gamma + sigmoid(pos_score - neg_score)))``.

    Args:
        - gamma(float): Small constant added inside the logarithm so the loss
          stays finite when ``sigmoid(pos_score - neg_score)`` evaluates to 0.

    Shape:
        - Pos_score: (N)
        - Neg_score: (N), same shape as the Pos_score
        - Output: scalar.

    Examples::

        >>> loss = BPRLoss()
        >>> pos_score = torch.randn(3, requires_grad=True)
        >>> neg_score = torch.randn(3, requires_grad=True)
        >>> output = loss(pos_score, neg_score)
        >>> output.backward()
    """

    def __init__(self, gamma=1e-10):
        super().__init__()
        self.gamma = gamma

    def forward(self, pos_score, neg_score):
        margin = pos_score - neg_score
        # Modelled probability that the positive item is ranked above the negative one.
        prob = torch.sigmoid(margin)
        log_likelihood = torch.log(prob + self.gamma)
        return -log_likelihood.mean()
class RegLoss(nn.Module):
    """ RegLoss, L2-norm regularization on model parameters

    The loss is the sum of the (non-squared) L2 norms of the given tensors.
    Each tensor's norm is taken over all of its elements, which is the
    Frobenius norm for a matrix.

    Shape:
        - Parameters: an iterable of tensors of any shape
          (e.g. ``model.parameters()``).
        - Output: scalar. A 0-dim zero tensor if the iterable is empty.

    Examples::

        >>> loss = RegLoss()
        >>> w = torch.randn(3, 4, requires_grad=True)
        >>> output = loss([w])
        >>> output.backward()
    """

    def __init__(self):
        super(RegLoss, self).__init__()

    def forward(self, parameters):
        reg_loss = None
        for W in parameters:
            norm = W.norm(2)
            # Start from the first norm rather than from 0 so the result keeps
            # the parameters' dtype and device.
            reg_loss = norm if reg_loss is None else reg_loss + norm
        if reg_loss is None:
            # No tensors were supplied. Return a zero tensor instead of None
            # so callers can still add the result to their loss.
            return torch.tensor(0.)
        return reg_loss
class EmbLoss(nn.Module):
    """ EmbLoss, norm regularization on embeddings

    Sums the ``norm``-norm of every given tensor (each norm taken over all of
    the tensor's elements). The sum is then divided by the size of dimension 0
    of the last tensor passed in.

    Args:
        - norm(int or float): order ``p`` passed to ``torch.norm``.

    Shape:
        - Embeddings: one or more tensors; the last must have at least one dim.
        - Output: (1), a float32 tensor on the last tensor's device.
    """

    def __init__(self, norm=2):
        super().__init__()
        self.norm = norm

    def forward(self, *embeddings):
        reference = embeddings[-1]
        total = torch.zeros(1, device=reference.device)
        for emb in embeddings:
            total += torch.norm(emb, p=self.norm)
        # Normalise by the leading dimension (batch size) of the last tensor.
        total /= reference.shape[0]
        return total
class EmbMarginLoss(nn.Module):
    """ EmbMarginLoss, margin regularization on embeddings

    For every row of every given tensor, sums its elements raised to
    ``power`` along dim 1 (the squared L2 norm when ``power == 2``). Only the
    amount by which that value exceeds 1 is penalized:
    ``sum(max(sum_j e_ij ** power - 1, 0))`` accumulated over all tensors.

    Args:
        - power(int or float): exponent applied element-wise before the row sum.

    Shape:
        - Embeddings: one or more tensors with at least 2 dims; the reduction
          is over dim 1.
        - Output: scalar, float32, on the last tensor's device.
    """

    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, *embeddings):
        device = embeddings[-1].device
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        total = torch.tensor(0., device=device)
        for emb in embeddings:
            row_size = (emb ** self.power).sum(dim=1, keepdim=True)
            # Element-wise max against zero acts as a hinge on (row_size - 1).
            excess = torch.max(row_size - one, zero)
            total += excess.sum()
        return total