
# -*- coding: utf-8 -*-
# @Time   : 2023/10/6
# @Author : Enze Liu
# @Email  : enzeeliu@foxmail.com

r"""
LDiffRec
################################################
Reference:
    Wenjie Wang et al. "Diffusion Recommender Model." in SIGIR 2023.

Reference code:
    https://github.com/YiyanXu/DiffRec
"""

import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from recbole.model.init import xavier_normal_initialization
from recbole.model.layers import MLPLayers
from recbole.model.general_recommender.diffrec import (
    DiffRec,
    DNN,
    ModelMeanType,
    mean_flat,
)


class AutoEncoder(nn.Module):
    r"""Group-wise Variational Autoencoder that compresses the interaction
    vector over each item cluster into a low-dimensional latent vector,
    used by L-DiffRec for large-scale recommendation.
    """

    def __init__(
        self,
        item_emb,
        n_cate,
        in_dims,
        out_dims,
        device,
        act_func,
        reparam=True,
        dropout=0.1,
    ):
        super(AutoEncoder, self).__init__()

        self.item_emb = item_emb
        self.n_cate = n_cate
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.act_func = act_func
        self.n_item = len(item_emb)
        self.reparam = reparam
        self.dropout = nn.Dropout(dropout)

        if n_cate == 1:  # no clustering
            in_dims_temp = (
                [self.n_item + 1] + self.in_dims[:-1] + [self.in_dims[-1] * 2]
            )
            out_dims_temp = [self.in_dims[-1]] + self.out_dims + [self.n_item + 1]

            self.encoder = MLPLayers(in_dims_temp, activation=self.act_func)
            self.decoder = MLPLayers(
                out_dims_temp, activation=self.act_func, last_activation=False
            )
        else:
            from kmeans_pytorch import kmeans

            self.cluster_ids, _ = kmeans(
                X=item_emb, num_clusters=n_cate, distance="euclidean", device=device
            )  # cluster_ids (labels): [0, 1, 2, 2, 1, 0, 0, ...]

            category_idx = []
            for i in range(n_cate):
                idx = np.argwhere(self.cluster_ids.numpy() == i).flatten().tolist()
                category_idx.append(torch.tensor(idx, dtype=int) + 1)
            self.category_idx = category_idx  # [cate1: [iid1, iid2, ...], cate2: [iid3, iid4, ...], cate3: [iid5, iid6, ...]]
            self.category_map = torch.cat(tuple(category_idx), dim=-1)  # map
            self.category_len = [
                len(self.category_idx[i]) for i in range(n_cate)
            ]  # item num in each category
            print("category length: ", self.category_len)
            assert sum(self.category_len) == self.n_item

            ##### Build the Encoder and Decoder #####
            encoders = []
            decode_dim = []
            for i in range(n_cate):
                if i == n_cate - 1:
                    # the last cluster takes the remaining dims so the sum is exact
                    latent_dims = list(self.in_dims - np.array(decode_dim).sum(axis=0))
                else:
                    latent_dims = [
                        int(self.category_len[i] / self.n_item * self.in_dims[j])
                        for j in range(len(self.in_dims))
                    ]
                    latent_dims = [
                        latent_dims[j] if latent_dims[j] != 0 else 1
                        for j in range(len(self.in_dims))
                    ]
                in_dims_temp = (
                    [self.category_len[i]] + latent_dims[:-1] + [latent_dims[-1] * 2]
                )
                encoders.append(MLPLayers(in_dims_temp, activation=self.act_func))
                decode_dim.append(latent_dims)
            self.encoder = nn.ModuleList(encoders)

            print("Latent dims of each category: ", decode_dim)

            self.decode_dim = [decode_dim[i][::-1] for i in range(len(decode_dim))]

            if len(out_dims) == 0:  # one-layer decoder: [encoder_dim_sum, n_item]
                out_dim = self.in_dims[-1]
                self.decoder = MLPLayers([out_dim, self.n_item], activation=None)
            else:  # multi-layer decoder: [encoder_dim, hidden_size, cate_num]
                decoders = []
                for i in range(n_cate):
                    out_dims_temp = self.decode_dim[i] + [self.category_len[i]]
                    decoders.append(
                        MLPLayers(
                            out_dims_temp,
                            activation=self.act_func,
                            last_activation=False,
                        )
                    )
                self.decoder = nn.ModuleList(decoders)

        self.apply(xavier_normal_initialization)

    def Encode(self, batch):
        batch = self.dropout(batch)
        if self.n_cate == 1:
            hidden = self.encoder(batch)
            mu = hidden[:, : self.in_dims[-1]]
            logvar = hidden[:, self.in_dims[-1] :]

            if self.training and self.reparam:
                latent = self.reparameterization(mu, logvar)
            else:
                latent = mu

            # KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian posterior
            kl_divergence = -0.5 * torch.mean(
                torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
            )

            return batch, latent, kl_divergence
        else:
            batch_cate = []
            for i in range(self.n_cate):
                batch_cate.append(batch[:, self.category_idx[i]])
            # [batch_size, n_items] -> [[batch_size, n1_items], [batch_size, n2_items], [batch_size, n3_items]]
            latent_mu = []
            latent_logvar = []
            for i in range(self.n_cate):
                hidden = self.encoder[i](batch_cate[i])
                latent_mu.append(hidden[:, : self.decode_dim[i][0]])
                latent_logvar.append(hidden[:, self.decode_dim[i][0] :])
            # latent: [[batch_size, latent_size1], [batch_size, latent_size2], [batch_size, latent_size3]]

            mu = torch.cat(tuple(latent_mu), dim=-1)
            logvar = torch.cat(tuple(latent_logvar), dim=-1)
            if self.training and self.reparam:
                latent = self.reparameterization(mu, logvar)
            else:
                latent = mu

            kl_divergence = -0.5 * torch.mean(
                torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
            )

            return torch.cat(tuple(batch_cate), dim=-1), latent, kl_divergence

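    # Shape sketch for Encode with hypothetical sizes (n_cate = 3): the
    # [batch_size, n_item + 1] input is split by category_idx into
    # [batch_size, n_i] chunks; encoder i emits [batch_size, 2 * latent_i]
    # (mu and logvar side by side), and the concatenated mu has width
    # latent_1 + latent_2 + latent_3 == in_dims[-1], since the last cluster
    # absorbs the rounding remainder in __init__.
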
    def reparameterization(self, mu, logvar):
        # z = mu + std * eps, with std = exp(0.5 * logvar) and eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def Decode(self, batch):
        if len(self.out_dims) == 0 or self.n_cate == 1:  # one-layer decoder
            return self.decoder(batch)
        else:
            batch_cate = []
            start = 0
            for i in range(self.n_cate):
                end = start + self.decode_dim[i][0]
                batch_cate.append(batch[:, start:end])
                start = end
            pred_cate = []
            for i in range(self.n_cate):
                pred_cate.append(self.decoder[i](batch_cate[i]))
            pred = torch.cat(tuple(pred_cate), dim=-1)
            return pred

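# A minimal usage sketch (illustrative only, not part of RecBole): round-trip a
# batch through the single-cluster autoencoder. All sizes here are hypothetical.
#
#   >>> item_emb = torch.randn(100, 64)       # 100 items; any embedding works
#   >>> ae = AutoEncoder(item_emb, n_cate=1, in_dims=[300, 200], out_dims=[],
#   ...                  device="cpu", act_func="tanh")
#   >>> batch = torch.rand(8, 101)            # n_item + 1 (padding id 0)
#   >>> batch, latent, kl = ae.Encode(batch)  # latent: [8, 200] == in_dims[-1]
#   >>> recon = ae.Decode(latent)             # recon:  [8, 101]
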
class LDiffRec(DiffRec):
    r"""L-DiffRec clusters items into groups, compresses the interaction vector
    over each group into a low-dimensional latent vector via a group-specific VAE,
    and conducts the forward and reverse diffusion processes in the latent space.
    """

    def __init__(self, config, dataset):
        super(LDiffRec, self).__init__(config, dataset)
        self.n_cate = config["n_cate"]
        self.reparam = config["reparam"]
        self.ae_act_func = config["ae_act_func"]
        self.in_dims = config["in_dims"]
        self.out_dims = config["out_dims"]

        # control loss in training
        self.update_count = 0
        self.update_count_vae = 0
        self.lamda = config["lamda"]
        self.anneal_cap = config["anneal_cap"]
        self.anneal_steps = config["anneal_steps"]
        self.vae_anneal_cap = config["vae_anneal_cap"]
        self.vae_anneal_steps = config["vae_anneal_steps"]

        out_dims = self.out_dims
        in_dims = self.in_dims[::-1]
        emb_path = os.path.join(dataset.dataset_path, "item_emb.npy")
        if self.n_cate > 1:
            if not os.path.exists(emb_path):
                # logging alone would not stop training; fail fast instead
                self.logger.error(
                    "The item embedding file must be given when n_cate > 1."
                )
                raise FileNotFoundError(emb_path)
            item_emb = torch.from_numpy(np.load(emb_path, allow_pickle=True))
        else:
            item_emb = torch.zeros((self.n_items - 1, 64))

        self.autoencoder = AutoEncoder(
            item_emb,
            self.n_cate,
            in_dims,
            out_dims,
            self.device,
            self.ae_act_func,
            self.reparam,
        ).to(self.device)

        self.latent_size = in_dims[-1]
        dims = [self.latent_size] + config["dims_dnn"] + [self.latent_size]
        self.mlp = DNN(
            dims=dims,
            emb_size=self.emb_size,
            time_type="cat",
            norm=self.norm,
            act_func=self.mlp_act_func,
        ).to(self.device)

    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        batch = self.get_rating_matrix(user)
        batch_cate, batch_latent, vae_kl = self.autoencoder.Encode(batch)

        # calculate loss in diffusion
        batch_size, device = batch_latent.size(0), batch_latent.device
        ts, pt = self.sample_timesteps(batch_size, device, "importance")
        noise = torch.randn_like(batch_latent)
        if self.noise_scale != 0.0:
            x_t = self.q_sample(batch_latent, ts, noise)
        else:
            x_t = batch_latent

        model_output = self.mlp(x_t, ts)

        target = {
            ModelMeanType.START_X: batch_latent,
            ModelMeanType.EPSILON: noise,
        }[self.mean_type]

        assert model_output.shape == target.shape == batch_latent.shape

        mse = mean_flat((target - model_output) ** 2)
        reloss = self.reweight_loss(
            batch_latent, x_t, mse, ts, target, model_output, device
        )

        if self.mean_type == ModelMeanType.START_X:
            batch_latent_recon = model_output
        else:
            batch_latent_recon = self._predict_xstart_from_eps(x_t, ts, model_output)

        self.update_Lt_history(ts, reloss)
        diff_loss = (reloss / pt).mean()

        batch_recon = self.autoencoder.Decode(batch_latent_recon)

        if self.anneal_steps > 0:
            lamda = max(
                (1.0 - self.update_count / self.anneal_steps) * self.lamda,
                self.anneal_cap,
            )
        else:
            lamda = max(self.lamda, self.anneal_cap)

        if self.vae_anneal_steps > 0:
            anneal = min(
                self.vae_anneal_cap,
                1.0 * self.update_count_vae / self.vae_anneal_steps,
            )
        else:
            anneal = self.vae_anneal_cap

        self.update_count_vae += 1
        self.update_count += 1

        vae_loss = compute_loss(batch_recon, batch_cate) + anneal * vae_kl
        loss = lamda * diff_loss + vae_loss
        return loss

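    # calculate_loss above optimizes L = lamda * L_diff + L_vae, where L_diff is
    # the reweighted diffusion MSE in the latent space (timesteps drawn by
    # importance sampling) and L_vae is the multinomial log-likelihood of the
    # decoded interactions plus the annealed KL term. lamda decays from
    # config["lamda"] down to anneal_cap, while the KL weight grows from 0 up
    # to vae_anneal_cap.
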
    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        batch = self.get_rating_matrix(user)
        _, batch_latent, _ = self.autoencoder.Encode(batch)
        batch_latent_recon = super(LDiffRec, self).p_sample(batch_latent)
        prediction = self.autoencoder.Decode(
            batch_latent_recon
        )  # [batch_size, n1_items + n2_items + n3_items]
        if self.n_cate > 1:
            transform = torch.zeros((prediction.shape[0], prediction.shape[1] + 1)).to(
                prediction.device
            )
            transform[:, self.autoencoder.category_map] = prediction
        else:
            transform = prediction
        return transform

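    # Mapping sketch (hypothetical clusters): with category_idx = [[2, 5], [1, 3, 4]],
    # the decoder output is ordered [i2, i5, i1, i3, i4]; the scatter
    # transform[:, category_map] = prediction above restores the original
    # item-id ordering, leaving column 0 (the padding item) at zero.
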
    def predict(self, interaction):
        item = interaction[self.ITEM_ID]
        x_t = self.full_sort_predict(interaction)
        # pick the score of each aligned (user, item) pair in the batch
        scores = x_t[torch.arange(len(item)).to(x_t.device), item]
        return scores

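# A hypothetical configuration sketch for LDiffRec (values are illustrative,
# not verified defaults); these keys are read in __init__ above, on top of the
# DiffRec hyperparameters:
#
#   n_cate: 2              # item clusters; > 1 requires <dataset_path>/item_emb.npy
#   reparam: True          # use the reparameterization trick during training
#   ae_act_func: 'tanh'    # activation of the VAE encoder/decoder
#   in_dims: [300]         # encoder dims (reversed internally)
#   out_dims: []           # extra decoder dims; [] means a single decoder layer
#   dims_dnn: [300]        # hidden dims of the denoising DNN in the latent space
#   lamda: 0.03            # initial weight of the diffusion loss
#   anneal_steps: 1000     # steps to anneal lamda down to anneal_cap
#   anneal_cap: 0.005
#   vae_anneal_steps: 200  # steps to anneal the KL weight up to vae_anneal_cap
#   vae_anneal_cap: 0.3
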
def compute_loss(recon_x, x):
    # multinomial log-likelihood used in MultVAE
    return -torch.mean(torch.sum(F.log_softmax(recon_x, 1) * x, -1))
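
# Worked example (illustrative): a user interacted only with item 0, and the
# decoder favors that item.
#
#   >>> recon = torch.tensor([[2.0, 0.0, 0.0]])
#   >>> x = torch.tensor([[1.0, 0.0, 0.0]])
#   >>> compute_loss(recon, x)  # -log softmax(recon)[0]
#   tensor(0.2395)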