# Source code for recbole.data.dataloader.sequential_dataloader

# @Time   : 2020/7/7
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE
# @Time   : 2020/10/6, 2020/9/17
# @Author : Yupeng Hou, Yushuo Chen
# @email  : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn

"""
recbole.data.dataloader.sequential_dataloader
################################################
"""

import numpy as np
import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.dataloader.neg_sample_mixin import NegSampleByMixin, NegSampleMixin
from recbole.data.interaction import Interaction, cat_interactions
from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType


class SequentialDataLoader(AbstractDataLoader):
    """:class:`SequentialDataLoader` is used for sequential models. It performs data augmentation on the
    original data, and each returned batch contains the following:

        - user id
        - history items list
        - history items' interaction time list
        - item to be predicted
        - the interaction time of item to be predicted
        - history list length
        - other interaction information of item to be predicted

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader.
            Defaults to :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        shuffle (bool, optional): Whether the dataloader will shuffle after a round. Defaults to ``False``.
    """
    dl_type = DataLoaderType.ORIGIN

    def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
        self.uid_field = dataset.uid_field
        self.iid_field = dataset.iid_field
        self.time_field = dataset.time_field
        self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH']

        # Register a "<field><LIST_SUFFIX>" history field on the dataset for every
        # interaction field except the user id; its name is cached on ``self`` so
        # ``augmentation`` can look it up later.
        list_suffix = config['LIST_SUFFIX']
        for field in dataset.inter_feat:
            if field == self.uid_field:
                continue
            list_field = field + list_suffix
            setattr(self, f'{field}_list_field', list_field)
            ftype = dataset.field2type[field]

            # Token(-seq) fields become token sequences; everything else becomes a float sequence.
            if ftype in (FeatureType.TOKEN, FeatureType.TOKEN_SEQ):
                list_ftype = FeatureType.TOKEN_SEQ
            else:
                list_ftype = FeatureType.FLOAT_SEQ

            # A sequence-valued field gains an extra dimension: (max list length, per-item seq length).
            if ftype in (FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ):
                list_len = (self.max_item_list_len, dataset.field2seqlen[field])
            else:
                list_len = self.max_item_list_len

            dataset.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len)

        self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
        dataset.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)

        self.uid_list = dataset.uid_list
        self.item_list_index = dataset.item_list_index
        self.target_index = dataset.target_index
        self.item_list_length = dataset.item_list_length
        # Filled by ``data_preprocess`` when augmentation is done ahead of time.
        self.pre_processed_data = None

        super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

    def data_preprocess(self):
        """Do data augmentation before training/evaluation.

        The whole augmented dataset is stored in ``self.pre_processed_data``.
        """
        self.pre_processed_data = self.augmentation(self.item_list_index, self.target_index, self.item_list_length)

    @property
    def pr_end(self):
        """int: Number of sequential samples (the length of ``uid_list``)."""
        return len(self.uid_list)

    def _shuffle(self):
        """Shuffle the samples for the next round.

        In real-time mode (``self.real_time``, presumably set by ``AbstractDataLoader`` — confirm),
        the raw index arrays are permuted with one shared permutation. Otherwise, the pre-built
        augmented data is shuffled in place.
        """
        if self.real_time:
            new_index = torch.randperm(self.pr_end)
            self.uid_list = self.uid_list[new_index]
            self.item_list_index = self.item_list_index[new_index]
            self.target_index = self.target_index[new_index]
            self.item_list_length = self.item_list_length[new_index]
        else:
            self.pre_processed_data.shuffle()

    def _next_batch_data(self):
        """Return the next ``self.step`` samples and advance the read pointer ``self.pr``."""
        cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
        self.pr += self.step
        return cur_data

    def _get_processed_data(self, index):
        """Return the augmented samples selected by ``index``.

        In real-time mode, the selected samples are augmented on the fly. Otherwise, they are
        sliced out of ``self.pre_processed_data``.
        """
        if self.real_time:
            cur_data = self.augmentation(
                self.item_list_index[index], self.target_index[index], self.item_list_length[index]
            )
        else:
            cur_data = self.pre_processed_data[index]
        return cur_data

    def augmentation(self, item_list_index, target_index, item_list_length):
        """Data augmentation.

        Args:
            item_list_index (numpy.ndarray): per-sample index into ``dataset.inter_feat`` selecting the
                history rows; each entry is used as ``value[index]`` (presumably a slice — TODO confirm).
            target_index (numpy.ndarray): the index of items to be predicted in interaction.
            item_list_length (numpy.ndarray): history list length of each sample.

        Returns:
            Interaction: the rows of ``dataset.inter_feat`` selected by ``target_index``, updated with the
            history-length field and one history tensor per non-user field. Each history is written into
            the leading ``length`` positions of a zero-initialised tensor (int64 for token lists,
            float64 otherwise).
        """
        new_length = len(item_list_index)
        new_data = self.dataset.inter_feat[target_index]
        new_dict = {
            self.item_list_length_field: torch.tensor(item_list_length),
        }

        for field in self.dataset.inter_feat:
            if field == self.uid_field:
                continue
            list_field = getattr(self, f'{field}_list_field')
            list_len = self.dataset.field2seqlen[list_field]
            # ``list_len`` is an int for scalar fields and a tuple for sequence-valued fields.
            shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len
            list_ftype = self.dataset.field2type[list_field]
            dtype = torch.int64 if list_ftype in (FeatureType.TOKEN, FeatureType.TOKEN_SEQ) else torch.float64
            new_dict[list_field] = torch.zeros(shape, dtype=dtype)

            value = self.dataset.inter_feat[field]
            for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
                new_dict[list_field][i][:length] = value[index]

        new_data.update(Interaction(new_dict))
        return new_data
class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader):
    """:class:`SequentialNegSampleDataLoader` is a sequential dataloader that also performs negative sampling.

    Like :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`, every positive
    interaction in a batch is guaranteed to have its sampled negative interactions in the same batch.
    Besides this, in the evaluation stage with a top-k-like evaluator (signalled here by
    ``self.user_inter_in_one_batch``), all interactions of one user are kept in the same batch, and the
    positive interaction comes before its negative interactions.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        neg_sample_args (dict): The neg_sample_args of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader.
            Defaults to :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        shuffle (bool, optional): Whether the dataloader will shuffle after a round. Defaults to ``False``.
    """

    def __init__(
        self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
    ):
        super().__init__(
            config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
        )

    def _batch_size_adaptation(self):
        # Each positive sample expands to ``self.times`` rows, so read whole multiples
        # (at least one positive) and round the effective batch size accordingly.
        positives_per_batch = max(self.batch_size // self.times, 1)
        self.step = positives_per_batch
        self.upgrade_batch_size(positives_per_batch * self.times)

    def _next_batch_data(self):
        window = slice(self.pr, self.pr + self.step)
        batch = self._neg_sampling(self._get_processed_data(window))
        self.pr += self.step

        if not self.user_inter_in_one_batch:
            return batch

        # One positive per user, followed by ``self.times - 1`` negatives.
        user_count = len(batch[self.uid_field]) // self.times
        positives = np.ones(user_count, dtype=np.int64)
        batch.set_additional_info(list(positives), list(positives * self.times))
        return batch

    def _neg_sampling(self, data):
        uid_column = data[self.uid_field]

        if not self.user_inter_in_one_batch:
            sampled = self.sampler.sample_by_user_ids(uid_column, self.neg_sample_by)
            return self.sampling_func(data, sampled)

        # Sample user by user so each user's positive and negatives stay contiguous.
        pieces = []
        for row in range(len(uid_column)):
            sampled = self.sampler.sample_by_user_ids(uid_column[row:row + 1], self.neg_sample_by)
            pieces.append(self.sampling_func(data[row:row + 1], sampled))
        return cat_interactions(pieces)

    def _neg_sample_by_pair_wise_sampling(self, data, neg_iids):
        # Every copy of the data is paired with its sampled negative item ids.
        expanded = data.repeat(self.times)
        expanded.update(Interaction({self.neg_item_id: neg_iids}))
        return expanded

    def _neg_sample_by_point_wise_sampling(self, data, neg_iids):
        # The first copy keeps the positives (label 1); later copies get the negative items (label 0).
        n_pos = len(data)
        expanded = data.repeat(self.times)
        expanded[self.iid_field][n_pos:] = neg_iids
        label_tensor = torch.zeros(n_pos * self.times)
        label_tensor[:n_pos] = 1.0
        expanded.update(Interaction({self.label_field: label_tensor}))
        return expanded

    def get_pos_len_list(self):
        """
        Returns:
            numpy.ndarray: Number of positive items for each user in a training/evaluating epoch
            (always one per sequential sample).
        """
        return np.full(self.pr_end, 1, dtype=np.int64)

    def get_user_len_list(self):
        """
        Returns:
            numpy.ndarray: Number of all items for each user in a training/evaluating epoch
            (``self.times`` per sequential sample).
        """
        return np.full(shape=self.pr_end, fill_value=self.times)
class SequentialFullDataLoader(NegSampleMixin, SequentialDataLoader):
    """:class:`SequentialFullDataLoader` is a sequential dataloader with full sort.

    To speed up calculation, this dataloader only returns the user part of interactions, the positive
    items and the used items. It does not return negative items.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        neg_sample_args (dict): The neg_sample_args of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader.
            Defaults to :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        shuffle (bool, optional): Whether the dataloader will shuffle after a round. Defaults to ``False``.
    """
    dl_type = DataLoaderType.FULL

    def __init__(
        self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
    ):
        super().__init__(
            config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
        )

    def _batch_size_adaptation(self):
        # Full sort does no negative-sample expansion, so the batch size is used as-is.
        pass

    def _neg_sampling(self, inter_feat):
        # No negative sampling is performed for full-sort evaluation.
        pass

    def _shuffle(self):
        # Fixed: the original called the non-existent ``logger.warnning``, which raised
        # AttributeError instead of emitting the warning.
        self.logger.warning('SequentialFullDataLoader can\'t shuffle')

    def _next_batch_data(self):
        """Return the next batch together with index tensors over the full item-score matrix.

        Returns:
            tuple: ``(interaction, None, scores_row, scores_col_after, scores_col_before)``.

            - ``scores_row`` is ``[0, ..., n-1, 0, ..., n-1]``.
            - ``scores_col_after`` is ``n`` zeros followed by the positive item ids.
            - ``scores_col_before`` is the positive item ids followed by ``n`` zeros.

            Here ``n`` is the number of interactions in the batch. These presumably let the evaluator
            move the positive item's score into column 0 of the score matrix — confirm against the
            evaluator.
        """
        interaction = super()._next_batch_data()
        inter_num = len(interaction)
        # One positive per sample; every item in the catalogue is a candidate.
        pos_len_list = np.ones(inter_num, dtype=np.int64)
        user_len_list = np.full(inter_num, self.item_num)
        interaction.set_additional_info(pos_len_list, user_len_list)
        scores_row = torch.arange(inter_num).repeat(2)
        padding_idx = torch.zeros(inter_num, dtype=torch.int64)
        positive_idx = interaction[self.iid_field]
        scores_col_after = torch.cat((padding_idx, positive_idx))
        scores_col_before = torch.cat((positive_idx, padding_idx))
        return interaction, None, scores_row, scores_col_after, scores_col_before

    def get_pos_len_list(self):
        """
        Returns:
            numpy.ndarray or list: Number of positive items for each user in a training/evaluating epoch.
        """
        return np.ones(self.pr_end, dtype=np.int64)

    def get_user_len_list(self):
        """
        Returns:
            numpy.ndarray: Number of all items for each user in a training/evaluating epoch
            (``self.item_num`` per sequential sample).
        """
        return np.full(self.pr_end, self.item_num)