Source code for recbole.model.general_recommender.diffrec

# -*- coding: utf-8 -*-
# @Time   : 2023/10/6
# @Author : Enze Liu
# @Email  : enzeeliu@foxmail.com

r"""
DiffRec
################################################
Reference:
    Wenjie Wang et al. "Diffusion Recommender Model." in SIGIR 2023.

Reference code:
    https://github.com/YiyanXu/DiffRec
"""

import enum
import math
import copy
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from recbole.model.init import xavier_normal_initialization
from recbole.utils.enum_type import InputType
from recbole.model.abstract_recommender import AutoEncoderMixin, GeneralRecommender
from recbole.model.layers import MLPLayers
import typing
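
# Usage sketch (added for illustration; not part of the original RecBole module).
# It shows how DiffRec is typically launched through RecBole's quick-start API,
# together with the config keys this module reads from `config`. The dataset name
# and all hyper-parameter values below are assumptions, not the library defaults;
# any key left out is expected to come from the model's default configuration.
def _example_run_diffrec():
    from recbole.quick_start import run_recbole

    run_recbole(
        model="DiffRec",
        dataset="ml-100k",  # assumed example dataset
        config_dict={
            "mean_type": "x0",  # ModelMeanType.START_X; "eps" selects EPSILON
            "noise_schedule": "linear",  # see DiffRec.get_betas for alternatives
            "noise_scale": 0.1,
            "noise_min": 0.0001,
            "noise_max": 0.02,
            "steps": 100,  # number of diffusion steps T
            "sampling_steps": 0,  # reverse steps used at inference (<= steps)
            "dims_dnn": [300],  # hidden sizes of the denoising DNN
        },
    )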


class ModelMeanType(enum.Enum):
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon

class DNN(nn.Module):
    r"""
    A deep neural network for the reverse diffusion process.
    """

    def __init__(
        self,
        dims: typing.List,
        emb_size: int,
        time_type="cat",
        act_func="tanh",
        norm=False,
        dropout=0.5,
    ):
        super(DNN, self).__init__()
        self.dims = dims
        self.time_type = time_type
        self.time_emb_dim = emb_size
        self.norm = norm

        self.emb_layer = nn.Linear(self.time_emb_dim, self.time_emb_dim)

        if self.time_type == "cat":
            # Concatenate timestep embedding with input
            self.dims[0] += self.time_emb_dim
        else:
            raise ValueError(
                "Unimplemented timestep embedding type %s" % self.time_type
            )

        self.mlp_layers = MLPLayers(
            layers=self.dims, dropout=0, activation=act_func, last_activation=False
        )
        self.drop = nn.Dropout(dropout)

        self.apply(xavier_normal_initialization)

    def forward(self, x, timesteps):
        time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device)
        emb = self.emb_layer(time_emb)
        if self.norm:
            x = F.normalize(x)
        x = self.drop(x)
        h = torch.cat([x, emb], dim=-1)
        h = self.mlp_layers(h)
        return h

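
# Shape sketch (added for illustration; not part of the original RecBole module).
# It runs the denoising DNN above on random data to show how the timestep
# embedding is concatenated with the input; all sizes are arbitrary assumptions.
def _example_dnn_shapes():
    n_items, emb_size, batch_size = 40, 8, 4
    dnn = DNN(dims=[n_items, 32, n_items], emb_size=emb_size)

    x = torch.rand(batch_size, n_items)  # a batch of interaction vectors
    t = torch.randint(0, 10, (batch_size,))  # one timestep per sample
    out = dnn(x, t)  # internally: cat([x, emb(t)], dim=-1) -> MLP

    assert out.shape == (batch_size, n_items)
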
class DiffRec(GeneralRecommender, AutoEncoderMixin):
    r"""
    DiffRec is a generative recommender model which infers users' interaction probabilities in a denoising manner.
    Note that DiffRec simultaneously ranks all items for each user.
    We implement the DiffRec model with only the user dataloader.
    """

    input_type = InputType.LISTWISE

    def __init__(self, config, dataset):
        super(DiffRec, self).__init__(config, dataset)

        if config["mean_type"] == "x0":
            self.mean_type = ModelMeanType.START_X
        elif config["mean_type"] == "eps":
            self.mean_type = ModelMeanType.EPSILON
        else:
            raise ValueError("Unimplemented mean type %s" % config["mean_type"])

        self.time_aware = config["time-aware"]
        self.w_max = config["w_max"]
        self.w_min = config["w_min"]
        self.build_histroy_items(dataset)

        self.noise_schedule = config["noise_schedule"]
        self.noise_scale = config["noise_scale"]
        self.noise_min = config["noise_min"]
        self.noise_max = config["noise_max"]
        self.steps = config["steps"]
        self.beta_fixed = config["beta_fixed"]

        self.emb_size = config["embedding_size"]
        self.norm = config["norm"]  # True or False
        self.reweight = config["reweight"]  # reweight the loss for different timesteps
        self.sampling_noise = config[
            "sampling_noise"
        ]  # whether to sample noise during prediction
        self.sampling_steps = config["sampling_steps"]
        self.mlp_act_func = config["mlp_act_func"]
        assert self.sampling_steps <= self.steps, "Too many steps in inference."

        self.history_num_per_term = config["history_num_per_term"]
        self.Lt_history = torch.zeros(
            self.steps, self.history_num_per_term, dtype=torch.float64
        ).to(self.device)
        self.Lt_count = torch.zeros(self.steps, dtype=int).to(self.device)

        dims = [self.n_items] + config["dims_dnn"] + [self.n_items]

        self.mlp = DNN(
            dims=dims,
            emb_size=self.emb_size,
            time_type="cat",
            norm=self.norm,
            act_func=self.mlp_act_func,
        ).to(self.device)

        if self.noise_scale != 0.0:
            self.betas = torch.tensor(self.get_betas(), dtype=torch.float64).to(
                self.device
            )
            if self.beta_fixed:
                # Deep Unsupervised Learning using Nonequilibrium Thermodynamics, Sec. 2.4.1:
                # the variance \beta_1 of the first step is fixed to a small constant
                # to prevent overfitting.
                self.betas[0] = 0.00001
            assert len(self.betas.shape) == 1, "betas must be 1-D"
            assert (
                len(self.betas) == self.steps
            ), "num of betas must equal to diffusion steps"
            assert (self.betas > 0).all() and (
                self.betas <= 1
            ).all(), "betas out of range"

            self.calculate_for_diffusion()

    def build_histroy_items(self, dataset):
        r"""
        Add time-aware reweighting to the original user-item interaction matrix when config['time-aware'] is True.
        """
        if not self.time_aware:
            super().build_histroy_items(dataset)
        else:
            inter_feat = copy.deepcopy(dataset.inter_feat)
            inter_feat.sort(dataset.time_field)
            user_ids, item_ids = (
                inter_feat[dataset.uid_field].numpy(),
                inter_feat[dataset.iid_field].numpy(),
            )

            w_max = self.w_max
            w_min = self.w_min
            values = np.zeros(len(inter_feat))
            row_num = dataset.user_num
            row_ids, col_ids = user_ids, item_ids

            for uid in range(1, row_num + 1):
                uindex = np.argwhere(user_ids == uid).flatten()
                int_num = len(uindex)
                weight = np.linspace(w_min, w_max, int_num)
                values[uindex] = weight

            history_len = np.zeros(row_num, dtype=np.int64)
            for row_id in row_ids:
                history_len[row_id] += 1

            max_inter_num = np.max(history_len)
            col_num = max_inter_num

            history_matrix = np.zeros((row_num, col_num), dtype=np.int64)
            history_value = np.zeros((row_num, col_num))
            history_len[:] = 0

            for row_id, value, col_id in zip(row_ids, values, col_ids):
                if history_len[row_id] >= col_num:
                    continue
                history_matrix[row_id, history_len[row_id]] = col_id
                history_value[row_id, history_len[row_id]] = value
                history_len[row_id] += 1

            self.history_item_id = torch.LongTensor(history_matrix)
            self.history_item_value = torch.FloatTensor(history_value)
            self.history_item_id = self.history_item_id.to(self.device)
            self.history_item_value = self.history_item_value.to(self.device)

    def get_betas(self):
        r"""
        Given the schedule name, create the betas for the diffusion process.
        """
        if self.noise_schedule == "linear" or self.noise_schedule == "linear-var":
            start = self.noise_scale * self.noise_min
            end = self.noise_scale * self.noise_max
            if self.noise_schedule == "linear":
                return np.linspace(start, end, self.steps, dtype=np.float64)
            else:
                return betas_from_linear_variance(
                    self.steps, np.linspace(start, end, self.steps, dtype=np.float64)
                )
        elif self.noise_schedule == "cosine":
            return betas_for_alpha_bar(
                self.steps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
            )  # Deep Unsupervised Learning using Nonequilibrium Thermodynamics, Sec. 2.4.1
        elif self.noise_schedule == "binomial":
            ts = np.arange(self.steps)
            betas = [1 / (self.steps - t + 1) for t in ts]
            return betas
        else:
            raise NotImplementedError(f"unknown beta schedule: {self.noise_schedule}!")

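    # Illustration (added; not part of the original source): with noise_scale = 0.1,
    # noise_min = 0.0001, noise_max = 0.02 and steps = 5 (assumed values), the
    # "linear" schedule above is np.linspace(1e-05, 2e-03, 5), i.e.
    # [1.0e-05, 5.075e-04, 1.005e-03, 1.5025e-03, 2.0e-03].
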
    def calculate_for_diffusion(self):
        r"""
        Calculate the coefficients for the diffusion process.
        """
        alphas = 1.0 - self.betas
        # [\bar{\alpha}_1, ..., \bar{\alpha}_T] with \bar{\alpha}_t = \alpha_1 * ... * \alpha_t, shape (steps,)
        self.alphas_cumprod = torch.cumprod(alphas, axis=0).to(self.device)
        # \bar{\alpha}_{t-1}
        self.alphas_cumprod_prev = torch.cat(
            [torch.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]]
        ).to(self.device)
        # \bar{\alpha}_{t+1}
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:], torch.tensor([0.0]).to(self.device)]
        ).to(self.device)
        assert self.alphas_cumprod_prev.shape == (self.steps,)

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.log(
            torch.cat(
                [self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]
            )
        )
        # Eq.10 coef for x_theta
        self.posterior_mean_coef1 = (
            self.betas
            * torch.sqrt(self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod)
        )
        # Eq.10 coef for x_t
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

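    # Note (added; not part of the original source): the coefficients above define the
    # Gaussian posterior of the forward process,
    #   q(x_{t-1} | x_t, x_0)
    #       = N(posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t, posterior_variance),
    # with posterior_variance = beta_t * (1 - \bar{alpha}_{t-1}) / (1 - \bar{alpha}_t)
    # and \bar{alpha}_t = prod_{s <= t} (1 - beta_s).
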
    def p_sample(self, x_start):
        r"""
        Generate users' interaction probabilities in a denoising manner.

        Args:
            x_start (torch.FloatTensor): the input tensor that contains user's history interaction matrix,
                for DiffRec shape: [batch_size, n_items]
                for LDiffRec shape: [batch_size, hidden_size]

        Returns:
            torch.FloatTensor: the interaction probabilities,
                for DiffRec shape: [batch_size, n_items]
                for LDiffRec shape: [batch_size, hidden_size]
        """
        steps = self.sampling_steps
        if steps == 0:
            x_t = x_start
        else:
            t = torch.tensor([steps - 1] * x_start.shape[0]).to(x_start.device)
            x_t = self.q_sample(x_start, t)

        indices = list(range(self.steps))[::-1]

        if self.noise_scale == 0.0:
            for i in indices:
                t = torch.tensor([i] * x_t.shape[0]).to(x_start.device)
                x_t = self.mlp(x_t, t)
            return x_t

        for i in indices:
            t = torch.tensor([i] * x_t.shape[0]).to(x_start.device)
            out = self.p_mean_variance(x_t, t)
            if self.sampling_noise:
                noise = torch.randn_like(x_t)
                nonzero_mask = (
                    (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
                )  # no noise when t == 0
                x_t = (
                    out["mean"]
                    + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
                )
            else:
                x_t = out["mean"]
        return x_t

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        x_start = self.get_rating_matrix(user)
        scores = self.p_sample(x_start)
        return scores

    def predict(self, interaction):
        item = interaction[self.ITEM_ID]
        x_t = self.full_sort_predict(interaction)
        scores = x_t[:, item]
        return scores

    def calculate_loss(self, interaction):
        user = interaction[self.USER_ID]
        x_start = self.get_rating_matrix(user)

        batch_size, device = x_start.size(0), x_start.device
        ts, pt = self.sample_timesteps(batch_size, device, "importance")
        noise = torch.randn_like(x_start)
        if self.noise_scale != 0.0:
            x_t = self.q_sample(x_start, ts, noise)
        else:
            x_t = x_start

        model_output = self.mlp(x_t, ts)
        target = {
            ModelMeanType.START_X: x_start,
            ModelMeanType.EPSILON: noise,
        }[self.mean_type]

        assert model_output.shape == target.shape == x_start.shape

        mse = mean_flat((target - model_output) ** 2)

        reloss = self.reweight_loss(
            x_start, x_t, mse, ts, target, model_output, device
        )

        self.update_Lt_history(ts, reloss)

        # importance sampling
        reloss /= pt
        mean_loss = reloss.mean()
        return mean_loss

    def reweight_loss(self, x_start, x_t, mse, ts, target, model_output, device):
        if self.reweight:
            if self.mean_type == ModelMeanType.START_X:
                # Eq.11
                weight = self.SNR(ts - 1) - self.SNR(ts)
                # Eq.12
                weight = torch.where((ts == 0), 1.0, weight)
                loss = mse
            elif self.mean_type == ModelMeanType.EPSILON:
                weight = (1 - self.alphas_cumprod[ts]) / (
                    (1 - self.alphas_cumprod_prev[ts]) ** 2 * (1 - self.betas[ts])
                )
                weight = torch.where((ts == 0), 1.0, weight)
                likelihood = mean_flat(
                    (x_start - self._predict_xstart_from_eps(x_t, ts, model_output))
                    ** 2
                    / 2.0
                )
                loss = torch.where((ts == 0), likelihood, mse)
        else:
            weight = torch.tensor([1.0] * len(target)).to(device)
            loss = mse
        reloss = weight * loss
        return reloss

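    # Illustration (added; not part of the original source): with
    # SNR(t) = \bar{alpha}_t / (1 - \bar{alpha}_t) and \bar{alpha} = [0.9, 0.72, 0.504]
    # (assumed values), SNR = [9.0, ~2.57, ~1.02], so the Eq.11 weight SNR(t-1) - SNR(t)
    # is ~6.43 at t = 1 and ~1.56 at t = 2; Eq.12 replaces the weight by 1 at t = 0.
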
    def update_Lt_history(self, ts, reloss):
        # update Lt_history & Lt_count
        for t, loss in zip(ts, reloss):
            if self.Lt_count[t] == self.history_num_per_term:
                Lt_history_old = self.Lt_history.clone()
                self.Lt_history[t, :-1] = Lt_history_old[t, 1:]
                self.Lt_history[t, -1] = loss.detach()
            else:
                try:
                    self.Lt_history[t, self.Lt_count[t]] = loss.detach()
                    self.Lt_count[t] += 1
                except:
                    print(t)
                    print(self.Lt_count[t])
                    print(loss)
                    raise ValueError

    def sample_timesteps(
        self, batch_size, device, method="uniform", uniform_prob=0.001
    ):
        if method == "importance":  # importance sampling
            if not (self.Lt_count == self.history_num_per_term).all():
                return self.sample_timesteps(batch_size, device, method="uniform")

            Lt_sqrt = torch.sqrt(torch.mean(self.Lt_history**2, axis=-1))
            pt_all = Lt_sqrt / torch.sum(Lt_sqrt)
            pt_all *= 1 - uniform_prob
            pt_all += uniform_prob / len(pt_all)  # ensure the least prob > uniform_prob

            assert pt_all.sum(-1) - 1.0 < 1e-5

            t = torch.multinomial(pt_all, num_samples=batch_size, replacement=True)
            pt = pt_all.gather(dim=0, index=t) * len(pt_all)

            return t, pt
        elif method == "uniform":  # uniform sampling
            t = torch.randint(0, self.steps, (batch_size,), device=device).long()
            pt = torch.ones_like(t).float()

            return t, pt
        else:
            raise ValueError

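    # Illustration (added; not part of the original source): if the recorded losses give
    # Lt_sqrt = [4.0, 1.0] (assumed values), then pt_all is roughly [0.8, 0.2] before the
    # small uniform mixing, so timestep 0 is drawn about four times as often as timestep 1;
    # dividing the loss by pt in calculate_loss keeps the training objective unbiased.
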
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
            * x_start
            + self._extract_into_tensor(
                self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
            )
            * noise
        )

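    # Illustration (added; not part of the original source): with \bar{alpha}_t = 0.72
    # (assumed value), a user's interaction vector x_0 is perturbed as
    #   x_t = sqrt(0.72) * x_0 + sqrt(0.28) * eps  ~=  0.85 * x_0 + 0.53 * eps,
    # so larger t pushes x_t toward pure Gaussian noise.
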
    def q_posterior_mean_variance(self, x_start, x_t, t):
        r"""
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape)
            * x_start
            + self._extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract_into_tensor(
            self.posterior_variance, t, x_t.shape
        )
        posterior_log_variance_clipped = self._extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t):
        r"""
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = self.mlp(x, t)

        model_variance = self.posterior_variance
        model_log_variance = self.posterior_log_variance_clipped

        model_variance = self._extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = self._extract_into_tensor(model_log_variance, t, x.shape)

        if self.mean_type == ModelMeanType.START_X:
            pred_xstart = model_output
        elif self.mean_type == ModelMeanType.EPSILON:
            pred_xstart = self._predict_xstart_from_eps(x, t, eps=model_output)
        else:
            raise NotImplementedError(self.mean_type)

        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )

        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
            * x_t
            - self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * eps
        )

    def SNR(self, t):
        r"""
        Compute the signal-to-noise ratio for a single timestep.
        """
        self.alphas_cumprod = self.alphas_cumprod.to(t.device)
        return self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t])

    def _extract_into_tensor(self, arr, timesteps, broadcast_shape):
        r"""
        Extract values from a 1-D torch tensor for a batch of indices.

        Args:
            arr (torch.Tensor): the 1-D torch tensor.
            timesteps (torch.Tensor): a tensor of indices into the array to extract.
            broadcast_shape (torch.Size): a larger shape of K dimensions with the batch
                dimension equal to the length of timesteps.

        Returns:
            torch.Tensor: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
        """
        # res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
        arr = arr.to(timesteps.device)
        res = arr[timesteps].float()
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)

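
# Broadcasting sketch (added for illustration; not part of the original RecBole
# module). It reproduces what DiffRec._extract_into_tensor does with plain torch
# ops: gather one scalar per sample, then expand it to the full batch shape.
def _example_extract_into_tensor():
    arr = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)  # one value per timestep
    timesteps = torch.tensor([2, 0])  # one timestep per sample
    broadcast_shape = (2, 4)  # e.g. [batch_size, n_items]

    res = arr[timesteps].float()  # tensor([0.3, 0.1])
    res = res[..., None].expand(broadcast_shape)  # shape (2, 4)
    assert res.shape == broadcast_shape
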
def betas_from_linear_variance(steps, variance, max_beta=0.999):
    alpha_bar = 1 - variance
    betas = []
    betas.append(1 - alpha_bar[0])
    for i in range(1, steps):
        betas.append(min(1 - alpha_bar[i] / alpha_bar[i - 1], max_beta))
    return np.array(betas)

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    r"""
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    Args:
        num_diffusion_timesteps (int): the number of betas to produce.
        alpha_bar (Callable): a lambda that takes an argument t from 0 to 1 and produces
            the cumulative product of (1-beta) up to that part of the diffusion process.
        max_beta (float): the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        np.ndarray: a 1-D array of beta values.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

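
# Usage sketch (added for illustration; not part of the original RecBole module).
# The "cosine" schedule in DiffRec.get_betas calls betas_for_alpha_bar with the
# lambda below; the resulting betas are positive and capped by max_beta.
def _example_cosine_betas():
    betas = betas_for_alpha_bar(
        10, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    )
    assert betas.shape == (10,)
    assert (betas > 0).all() and (betas <= 0.999).all()
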
def normal_kl(mean1, logvar1, mean2, logvar2):
    r"""
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

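
# Sanity-check sketch (added for illustration; not part of the original RecBole
# module). For two univariate Gaussians the KL divergence is
# 0.5 * (log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1), which the
# broadcasting implementation above reproduces.
def _example_normal_kl():
    mean1, logvar1 = torch.tensor(0.0), torch.tensor(0.0)  # N(0, 1)
    mean2, logvar2 = torch.tensor(1.0), torch.tensor(0.0)  # N(1, 1)

    kl = normal_kl(mean1, logvar1, mean2, logvar2)
    assert torch.isclose(kl, torch.tensor(0.5))  # 0.5 * (log 1 + (1 + 1) / 1 - 1) = 0.5
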
def mean_flat(tensor):
    r"""
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

def timestep_embedding(timesteps, dim, max_period=10000):
    r"""
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. (N,)
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(
        timesteps.device
    )  # shape (dim//2,)
    args = timesteps[:, None].float() * freqs[None]  # (N, dim//2)
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (N, (dim//2)*2)
    if dim % 2:
        # zero pad in the last dimension to ensure shape (N, dim)
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding

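
# Shape sketch (added for illustration; not part of the original RecBole module).
# The embedding concatenates cos and sin features over dim // 2 frequencies, and a
# zero column is appended when dim is odd.
def _example_timestep_embedding():
    t = torch.tensor([0, 1, 2, 3])

    even = timestep_embedding(t, dim=8)
    odd = timestep_embedding(t, dim=9)
    assert even.shape == (4, 8) and odd.shape == (4, 9)
    assert torch.all(odd[:, -1] == 0)  # the zero-padded column
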