Customize DataLoaders
======================
Here, we present how to develop a new DataLoader and apply it in our tool.
If a new model has special requirements for loading data, we need to design a new DataLoader.


Abstract DataLoader
--------------------------
In this project, there are two abstract dataloaders:
:class:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader` and
:class:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader`.

In general, a new dataloader should inherit from one of these two abstract classes.
If you only need to modify an existing DataLoader, you can also inherit from it directly.
The documentation of dataloader: :doc:`../../recbole/recbole.data.dataloader`


AbstractDataLoader
^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader` is the most basic abstract class.
It includes three important attributes:
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr`,
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.batch_size` and
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.step`.
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr` is the pointer of this dataloader,
that is, the current position in the data.
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.batch_size` is the upper bound
on the number of interactions in a single batch.
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.step` is the increment of
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr` for each batch.

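The pointer arithmetic described above can be sketched in plain Python.
This is only an illustration under simplifying assumptions: ``batch_bounds`` is a hypothetical helper,
not part of RecBole, and ``pr_end`` stands for the end position that is introduced in the next paragraph.
In the simplest case, such as the :class:`UserDataLoader` example later in this document,
``step`` is equal to ``batch_size``.

.. code:: python

    def batch_bounds(pr_end, step):
        """Yield (start, stop) index pairs as a pointer-based loader walks its data.

        ``pr_end`` and ``step`` mirror the names used in this document; the
        function itself is a hypothetical helper and not part of RecBole.
        """
        pr = 0  # the pointer starts at the beginning of the data
        while pr < pr_end:
            # the last batch may be shorter than ``step``
            yield pr, min(pr + step, pr_end)
            pr += step  # advance the pointer by one step


    print(list(batch_bounds(pr_end=10, step=4)))  # [(0, 4), (4, 8), (8, 10)]

Each pair is the slice of data that one batch would cover; iteration ends once the pointer reaches ``pr_end``.
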
:class:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader` also declares four functions that must be implemented:
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._init_batch_size_and_step`,
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr_end`,
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._shuffle` and
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._next_batch_data`.

:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._init_batch_size_and_step` initializes
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.batch_size` and
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.step`.
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr_end` is the maximum value of
:attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.pr` plus 1.
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._shuffle` permutes the dataset.
It is invoked by :meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.__iter__`
if the parameter :attr:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.shuffle` is ``True``.
:meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader._next_batch_data` loads the next batch of data
and returns it as an :class:`~recbole.data.interaction.Interaction`.
It is invoked in :meth:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader.__next__`.


NegSampleDataLoader
^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader` inherits from
:class:`~recbole.data.dataloader.abstract_dataloader.AbstractDataLoader` and is used for negative sampling.
It adds four functions to those of its parent class:
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._set_neg_sample_args`,
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sampling`,
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_pair_wise_sampling` and
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_point_wise_sampling`.
These four functions do not need to be implemented in a subclass. They are auxiliary functions already provided by
:class:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader`.

In current studies, there are only two sampling strategies: ``pair-wise sampling`` and ``point-wise sampling``.
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_pair_wise_sampling` and
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_point_wise_sampling`
implement these two strategies respectively.

:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._set_neg_sample_args` sets the
negative sampling arguments, such as the sampling strategy and the sampling function.
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sampling` performs negative sampling:
it generates negative items and then invokes
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_pair_wise_sampling` or
:meth:`~recbole.data.dataloader.abstract_dataloader.NegSampleDataLoader._neg_sample_by_point_wise_sampling`,
depending on the sampling strategy.


Example
--------------------------
Here, we take :class:`~recbole.data.dataloader.user_dataloader.UserDataLoader` as the example.
This dataloader returns user ids, which are used to train user representations.

Implement __init__()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:meth:`__init__` is used to initialize the necessary parameters.
Here, we just need to record :attr:`uid_field` and generate :attr:`user_list`, which contains all user ids.
Because of training requirements, :attr:`shuffle` must be ``True``:
if ``False`` is passed, it is overridden and a warning is logged.

.. code:: python

    def __init__(self, config, dataset, sampler, shuffle=False):
        # Create the logger first, in case the base class only assigns it
        # inside super().__init__() (requires ``from logging import getLogger``).
        self.logger = getLogger()
        if shuffle is False:
            shuffle = True
            self.logger.warning('UserDataLoader must shuffle the data.')

        self.uid_field = dataset.uid_field
        # One entry per user id, wrapped in an Interaction.
        self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

        super().__init__(config, dataset, sampler, shuffle=shuffle)

Implement _init_batch_size_and_step()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Because :class:`~recbole.data.dataloader.user_dataloader.UserDataLoader` doesn't need negative sampling,
both :attr:`batch_size` and :attr:`step` can be set to ``self.config['train_batch_size']``.

.. code:: python

    def _init_batch_size_and_step(self):
        batch_size = self.config['train_batch_size']
        self.step = batch_size
        self.set_batch_size(batch_size)

Implement pr_end() and _shuffle()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Since this dataloader only returns user ids, these functions can be implemented readily.

.. code:: python

    @property
    def pr_end(self):
        return len(self.user_list)

    def _shuffle(self):
        self.user_list.shuffle()

Implement _next_batch_data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This function only needs to return user ids from :attr:`user_list`.
We select the corresponding slice of :attr:`user_list`, advance the pointer, and return the slice.

.. code:: python

    def _next_batch_data(self):
        cur_data = self.user_list[self.pr:self.pr + self.step]
        self.pr += self.step
        return cur_data

Complete Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python

    from logging import getLogger

    import torch

    from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
    from recbole.data.interaction import Interaction
    from recbole.utils import DataLoaderType  # assumed import path; check your RecBole version


    class UserDataLoader(AbstractDataLoader):
        """:class:`UserDataLoader` will return a batch of data which only contains user-id when it is iterated.

        Args:
            config (Config): The config of dataloader.
            dataset (Dataset): The dataset of dataloader.
            sampler (Sampler): The sampler of dataloader.
            shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.

        Attributes:
            shuffle (bool): Whether the dataloader will be shuffled after a round.
                However, in :class:`UserDataLoader`, it's guaranteed to be ``True``.
        """
        dl_type = DataLoaderType.ORIGIN

        def __init__(self, config, dataset, sampler, shuffle=False):
            # Create the logger first, in case the base class only assigns it
            # inside super().__init__().
            self.logger = getLogger()
            if shuffle is False:
                shuffle = True
                self.logger.warning('UserDataLoader must shuffle the data.')

            self.uid_field = dataset.uid_field
            self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

            super().__init__(config, dataset, sampler, shuffle=shuffle)

        def _init_batch_size_and_step(self):
            batch_size = self.config['train_batch_size']
            self.step = batch_size
            self.set_batch_size(batch_size)

        @property
        def pr_end(self):
            return len(self.user_list)

        def _shuffle(self):
            self.user_list.shuffle()

        def _next_batch_data(self):
            cur_data = self.user_list[self.pr:self.pr + self.step]
            self.pr += self.step
            return cur_data

For the development of more complex DataLoaders, please refer to the source code.
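
To see how these functions cooperate during iteration, the following RecBole-free sketch may help.
``ToyUserLoader`` is a hypothetical stand-in: plain Python lists replace :class:`Interaction`,
and the ``__iter__`` and ``__next__`` bodies are an assumption about how the base class drives iteration
(shuffle at the start of a pass, stop and rewind the pointer at the end), not a copy of the library code.

.. code:: python

    import random


    class ToyUserLoader:
        """Hypothetical stand-in that mimics the iteration protocol of this guide.

        It does not import RecBole; a plain list replaces ``Interaction``.
        """

        def __init__(self, user_num, batch_size, shuffle=True, seed=0):
            self.user_list = list(range(user_num))
            self.step = batch_size
            self.shuffle = shuffle
            self.pr = 0
            # Seeded only to keep the demonstration reproducible.
            self._rng = random.Random(seed)

        @property
        def pr_end(self):
            return len(self.user_list)

        def _shuffle(self):
            self._rng.shuffle(self.user_list)

        def _next_batch_data(self):
            cur_data = self.user_list[self.pr:self.pr + self.step]
            self.pr += self.step
            return cur_data

        def __iter__(self):
            # Assumed base-class behaviour: shuffle once per pass over the data.
            if self.shuffle:
                self._shuffle()
            return self

        def __next__(self):
            # Assumed base-class behaviour: rewind the pointer and stop at the end.
            if self.pr >= self.pr_end:
                self.pr = 0
                raise StopIteration
            return self._next_batch_data()


    loader = ToyUserLoader(user_num=7, batch_size=3)
    batches = list(loader)
    print([len(b) for b in batches])              # [3, 3, 1]
    print(sorted(u for b in batches for u in b))  # [0, 1, 2, 3, 4, 5, 6]

Every user id appears exactly once per pass, and the last batch is shorter when the number of users
is not a multiple of the batch size.
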