Source code for recbole.model.general_recommender.recvae

# -*- coding: utf-8 -*-
# @Time   : 2021/2/28
# @Author : Lanling Xu
# @Email  : xulanling_sherry@163.com

r"""
RecVAE
################################################
Reference:
    Shenbin, Ilya, et al. "RecVAE: A new variational autoencoder for Top-N recommendations with implicit feedback." In WSDM 2020.

Reference code:
    https://github.com/ilya-shenbin/RecVAE
"""

import numpy as np
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.abstract_recommender import GeneralRecommender, AutoEncoderMixin
from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType


def swish(x):
    r"""Swish activation function:

    .. math::
        \text{Swish}(x) = \frac{x}{1 + \exp(-x)}
    """
    return x.mul(torch.sigmoid(x))

def log_norm_pdf(x, mu, logvar):
    r"""Element-wise log-density of a Gaussian:

    .. math::
        \log \mathcal{N}(x \mid \mu, \sigma^2) = -\frac{1}{2}\left(\log \sigma^2 + \log 2\pi + \frac{(x - \mu)^2}{\sigma^2}\right)
    """
    return -0.5 * (logvar + np.log(2 * np.pi) + (x - mu).pow(2) / logvar.exp())

class CompositePrior(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, mixture_weights):
        super(CompositePrior, self).__init__()

        self.mixture_weights = mixture_weights

        self.mu_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.mu_prior.data.fill_(0)

        self.logvar_prior = nn.Parameter(
            torch.Tensor(1, latent_dim), requires_grad=False
        )
        self.logvar_prior.data.fill_(0)

        self.logvar_uniform_prior = nn.Parameter(
            torch.Tensor(1, latent_dim), requires_grad=False
        )
        self.logvar_uniform_prior.data.fill_(10)

        # Frozen copy of the encoder from the previous training phase,
        # refreshed via RecVAE.update_prior().
        self.encoder_old = Encoder(hidden_dim, latent_dim, input_dim)
        self.encoder_old.requires_grad_(False)
    def forward(self, x, z):
        post_mu, post_logvar = self.encoder_old(x, 0)

        stnd_prior = log_norm_pdf(z, self.mu_prior, self.logvar_prior)
        post_prior = log_norm_pdf(z, post_mu, post_logvar)
        unif_prior = log_norm_pdf(z, self.mu_prior, self.logvar_uniform_prior)

        gaussians = [stnd_prior, post_prior, unif_prior]
        gaussians = [g.add(np.log(w)) for g, w in zip(gaussians, self.mixture_weights)]

        density_per_gaussian = torch.stack(gaussians, dim=-1)

        return torch.logsumexp(density_per_gaussian, dim=-1)

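# Note: CompositePrior evaluates, in log space for numerical stability, the
# density of a three-component Gaussian mixture,
#
#     log p(z) = logsumexp_k( log w_k + log N(z | mu_k, sigma_k^2) ),
#
# whose components are a standard normal, the posterior produced by the frozen
# encoder copy, and a broad "uniform-like" Gaussian (log-variance 10), weighted
# by ``mixture_weights``.
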
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, eps=1e-1):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.ln4 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.ln5 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
    def forward(self, x, dropout_prob):
        x = F.normalize(x)
        x = F.dropout(x, dropout_prob, training=self.training)

        # Densely connected residual blocks: each layer receives the sum of
        # all previous hidden states.
        h1 = self.ln1(swish(self.fc1(x)))
        h2 = self.ln2(swish(self.fc2(h1) + h1))
        h3 = self.ln3(swish(self.fc3(h2) + h1 + h2))
        h4 = self.ln4(swish(self.fc4(h3) + h1 + h2 + h3))
        h5 = self.ln5(swish(self.fc5(h4) + h1 + h2 + h3 + h4))
        return self.fc_mu(h5), self.fc_logvar(h5)

class RecVAE(GeneralRecommender, AutoEncoderMixin):
    r"""RecVAE is a variational autoencoder based recommendation model for
    top-N recommendation with implicit feedback. It extends Mult-VAE with a
    composite prior and alternating training of the encoder and decoder.

    We implement the model following the original author's implementation.
    """

    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(RecVAE, self).__init__(config, dataset)

        self.hidden_dim = config["hidden_dimension"]
        self.latent_dim = config["latent_dimension"]
        self.dropout_prob = config["dropout_prob"]
        self.beta = config["beta"]
        self.mixture_weights = config["mixture_weights"]
        self.gamma = config["gamma"]

        # Note: the method name is spelled "histroy" in AutoEncoderMixin.
        self.build_histroy_items(dataset)

        self.encoder = Encoder(self.hidden_dim, self.latent_dim, self.n_items)
        self.prior = CompositePrior(
            self.hidden_dim, self.latent_dim, self.n_items, self.mixture_weights
        )
        self.decoder = nn.Linear(self.latent_dim, self.n_items)

        # parameters initialization
        self.apply(xavier_normal_initialization)
    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            epsilon = torch.zeros_like(std).normal_(mean=0, std=0.01)
            return mu + epsilon * std
        else:
            return mu
    def forward(self, rating_matrix, dropout_prob):
        mu, logvar = self.encoder(rating_matrix, dropout_prob=dropout_prob)
        z = self.reparameterize(mu, logvar)
        x_pred = self.decoder(z)
        return x_pred, mu, logvar, z
    def calculate_loss(self, interaction, encoder_flag):
        user = interaction[self.USER_ID]
        rating_matrix = self.get_rating_matrix(user)

        # Input dropout is only applied while the encoder is being trained.
        if encoder_flag:
            dropout_prob = self.dropout_prob
        else:
            dropout_prob = 0
        x_pred, mu, logvar, z = self.forward(rating_matrix, dropout_prob)

        if self.gamma:
            norm = rating_matrix.sum(dim=-1)
            kl_weight = self.gamma * norm
        else:
            kl_weight = self.beta

        mll = (F.log_softmax(x_pred, dim=-1) * rating_matrix).sum(dim=-1).mean()
        kld = (
            (log_norm_pdf(z, mu, logvar) - self.prior(rating_matrix, z))
            .sum(dim=-1)
            .mul(kl_weight)
            .mean()
        )
        negative_elbo = -(mll - kld)

        return negative_elbo

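    # Note: calculate_loss returns the negative ELBO,
    #
    #     -L(x) = -( E_q[log p(x|z)] - kl_weight * KL(q(z|x) || p(z)) ),
    #
    # where the KL term is estimated from a single sample as
    # log q(z|x) - log p(z), and, following the paper, kl_weight is
    # gamma * |X_u| (the user's interaction count) when gamma is set and the
    # fixed beta otherwise.
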
    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob)

        return scores[[torch.arange(len(item)).to(self.device), item]]
    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]

        rating_matrix = self.get_rating_matrix(user)

        scores, _, _, _ = self.forward(rating_matrix, self.dropout_prob)

        return scores.view(-1)
    def update_prior(self):
        # Snapshot the current encoder weights into the prior's frozen encoder;
        # called between the encoder and decoder training phases.
        self.prior.encoder_old.load_state_dict(deepcopy(self.encoder.state_dict()))

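# Training sketch (illustrative, not part of this module). Following the
# paper, the encoder and decoder are trained in alternating phases, with the
# composite prior refreshed in between; in RecBole this schedule is handled by
# the dedicated RecVAETrainer. A minimal sketch, assuming ``model``,
# ``train_data``, ``optimizer_encoder`` and ``optimizer_decoder`` exist:
#
#     for epoch in range(n_epochs):
#         for batch in train_data:  # encoder phase (with input dropout)
#             loss = model.calculate_loss(batch, encoder_flag=True)
#             optimizer_encoder.zero_grad()
#             loss.backward()
#             optimizer_encoder.step()
#         model.update_prior()  # snapshot the encoder into the prior
#         for batch in train_data:  # decoder phase (no dropout)
#             loss = model.calculate_loss(batch, encoder_flag=False)
#             optimizer_decoder.zero_grad()
#             loss.backward()
#             optimizer_decoder.step()


if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercise Encoder and
    # CompositePrior on a toy binary rating matrix. All sizes and the mixture
    # weights [0.15, 0.75, 0.1] are arbitrary assumptions.
    torch.manual_seed(0)
    n_users, n_items, hidden_dim, latent_dim = 4, 10, 16, 8
    x = torch.randint(0, 2, (n_users, n_items)).float()
    encoder = Encoder(hidden_dim, latent_dim, n_items)
    prior = CompositePrior(hidden_dim, latent_dim, n_items, [0.15, 0.75, 0.1])
    mu, logvar = encoder(x, 0.5)
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
    print(prior(x, z).shape)  # torch.Size([4, 8]): per-dimension log-density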