Customize DataLoaders

This section explains how to develop a new DataLoader and plug it into RecBole. A new DataLoader is needed when a new model has special requirements for how its data is loaded.

Abstract DataLoader

This project provides two abstract dataloaders: AbstractDataLoader and NegSampleDataLoader.

In general, a new dataloader should inherit from one of these two abstract classes. If you only need to modify an existing DataLoader, you can inherit from that DataLoader instead. See the dataloader documentation: recbole.data.dataloader

AbstractDataLoader

AbstractDataLoader is the most basic abstract class. It has three important attributes: pr, batch_size and step. pr is the pointer of the dataloader, that is, the position of the next record to read. batch_size is the upper bound on the number of interactions in a single batch. step is the amount by which pr is incremented for each batch.
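
As a rough illustration of how pr and step interact, the sketch below walks a pointer over a dataset in plain Python. The function name batch_ranges is hypothetical and not part of RecBole; it only assumes that pr starts at 0 and grows by step after each batch, as described above.

```python
def batch_ranges(pr_end, step):
    """Yield (start, stop) index pairs, advancing the pointer by `step` each batch."""
    pr = 0  # the pointer starts at the beginning of the data
    while pr < pr_end:
        # The last batch is shorter when pr_end is not a multiple of step.
        yield pr, min(pr + step, pr_end)
        pr += step


# 10 records with step 4 give three batches; the last one holds 2 records.
print(list(batch_ranges(pr_end=10, step=4)))  # [(0, 4), (4, 8), (8, 10)]
```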

AbstractDataLoader declares four members that subclasses must implement: _init_batch_size_and_step(), pr_end, _shuffle() and _next_batch_data(). _init_batch_size_and_step() initializes batch_size and step. pr_end is a property that returns the maximum value of pr plus 1. _shuffle() permutes the dataset and is invoked by __iter__() when the shuffle parameter is True. _next_batch_data() loads the next batch, returns it as an Interaction, and is invoked by __next__().
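
The way these four members fit into the iterator protocol can be sketched with a simplified stand-in. ToyAbstractLoader and ListLoader below are hypothetical classes written for illustration; they use plain Python lists instead of Interaction objects, take a plain dict as config, and are not RecBole's actual implementation.

```python
import random


class ToyAbstractLoader:
    """Simplified stand-in for an abstract dataloader (illustration only)."""

    def __init__(self, config, data, shuffle=False):
        self.config = config
        self.data = list(data)
        self.shuffle = shuffle
        self.pr = 0
        self.batch_size = self.step = None
        self._init_batch_size_and_step()

    # The four members a concrete loader has to provide.
    def _init_batch_size_and_step(self):
        raise NotImplementedError

    @property
    def pr_end(self):
        raise NotImplementedError

    def _shuffle(self):
        raise NotImplementedError

    def _next_batch_data(self):
        raise NotImplementedError

    def __iter__(self):
        if self.shuffle:
            self._shuffle()  # permute the data once per pass
        return self

    def __next__(self):
        if self.pr >= self.pr_end:
            self.pr = 0  # reset so the loader can be iterated again
            raise StopIteration
        return self._next_batch_data()


class ListLoader(ToyAbstractLoader):
    """Concrete loader that returns slices of a list."""

    def _init_batch_size_and_step(self):
        self.batch_size = self.config['train_batch_size']
        self.step = self.batch_size

    @property
    def pr_end(self):
        return len(self.data)

    def _shuffle(self):
        random.shuffle(self.data)

    def _next_batch_data(self):
        batch = self.data[self.pr:self.pr + self.step]
        self.pr += self.step
        return batch


loader = ListLoader({'train_batch_size': 2}, range(5))
print(list(loader))  # [[0, 1], [2, 3], [4]]
```

Because __next__() resets pr before raising StopIteration, the same loader object can be iterated once per epoch without being rebuilt.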

NegSampleDataLoader

NegSampleDataLoader inherits from AbstractDataLoader and is used for negative sampling. It adds four functions to its parent class: _set_neg_sample_args(), _neg_sampling(), _neg_sample_by_pair_wise_sampling(), and _neg_sample_by_point_wise_sampling(). These four functions do not need to be implemented by subclasses; they are auxiliary functions that NegSampleDataLoader already provides.

Current studies use only two sampling strategies: pair-wise sampling and point-wise sampling. _neg_sample_by_pair_wise_sampling() and _neg_sample_by_point_wise_sampling() implement these two strategies respectively.

_set_neg_sample_args() sets the negative sampling arguments, such as the sampling strategy and the sampling function. _neg_sampling() performs negative sampling: it generates negative items and then invokes _neg_sample_by_pair_wise_sampling() or _neg_sample_by_point_wise_sampling(), depending on the sampling strategy.
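
The two strategies differ mainly in how the negatives are laid out in a batch. The sketch below shows one common layout for each, using plain Python lists. The helper sample_negative and the field names neg_item_id and label are illustrative assumptions, not RecBole's actual names or implementation.

```python
import random


def sample_negative(user, interacted, n_items, rng):
    """Draw one item id the user has not interacted with (rejection sampling)."""
    while True:
        item = rng.randrange(n_items)
        if item not in interacted[user]:
            return item


def pair_wise(pairs, interacted, n_items, rng):
    """Attach one negative item to every positive (user, item) pair as an extra column."""
    return [
        {"user_id": u, "item_id": i,
         "neg_item_id": sample_negative(u, interacted, n_items, rng)}
        for u, i in pairs
    ]


def point_wise(pairs, interacted, n_items, rng):
    """Append negatives as extra rows and mark every row with a 0/1 label."""
    positives = [{"user_id": u, "item_id": i, "label": 1} for u, i in pairs]
    negatives = [
        {"user_id": u, "item_id": sample_negative(u, interacted, n_items, rng), "label": 0}
        for u, _ in pairs
    ]
    return positives + negatives


rng = random.Random(0)
pairs = [(0, 1), (1, 2)]            # observed (user, item) interactions
interacted = {0: {1}, 1: {2}}       # items each user has already seen
print(pair_wise(pairs, interacted, n_items=5, rng=rng))   # 2 rows, each with a neg_item_id
print(point_wise(pairs, interacted, n_items=5, rng=rng))  # 4 rows: 2 positives, then 2 negatives
```

Pair-wise output keeps one row per interaction, which suits ranking losses that compare a positive with a negative. Point-wise output grows the number of rows, which is why a loader using it has to account for the expansion when choosing its step.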

Example

Here, we take UserDataLoader as an example. This dataloader returns user ids, which are used to train user representations.

Implement __init__()

__init__() initializes the necessary parameters. Here, we only need to record uid_field and generate user_list, which contains all user ids. Because of training requirements, shuffle is always set to True.

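def __init__(self, config, dataset, sampler, shuffle=False):
    # The parent __init__ has not run yet, so create the logger here
    # (getLogger comes from the standard logging module).
    self.logger = getLogger()
    if shuffle is False:
        shuffle = True
        self.logger.warning('UserDataLoader must shuffle the data.')

    self.uid_field = dataset.uid_field
    # One entry per user id, wrapped in an Interaction so it can be sliced and shuffled.
    self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

    super().__init__(config, dataset, sampler, shuffle=shuffle)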

Implement _init_batch_size_and_step()

Because UserDataLoader does not need negative sampling, batch_size and step can both be set to self.config['train_batch_size'].

def _init_batch_size_and_step(self):
    batch_size = self.config['train_batch_size']
    self.step = batch_size
    self.set_batch_size(batch_size)
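
For contrast, a loader that performs point-wise negative sampling would usually need a step smaller than batch_size, because each interaction read from the dataset expands into several rows. The helper below is a hypothetical sketch of that arithmetic, assuming neg_sample_num negatives are added per positive interaction; it is not taken from the RecBole source.

```python
def batch_size_and_step(train_batch_size, neg_sample_num=0):
    """Return (batch_size, step) for a loader that expands each interaction.

    Assumption: every positive interaction becomes 1 + neg_sample_num rows.
    The batch holds at most train_batch_size rows, unless a single expanded
    interaction is already larger than that.
    """
    times = 1 + neg_sample_num
    step = max(train_batch_size // times, 1)  # interactions consumed per batch
    batch_size = step * times                 # rows actually produced per batch
    return batch_size, step


print(batch_size_and_step(2048))                    # (2048, 2048): no sampling, as in UserDataLoader
print(batch_size_and_step(2048, neg_sample_num=3))  # (2048, 512)
print(batch_size_and_step(10, neg_sample_num=3))    # (8, 2)
```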

Implement pr_end() and _shuffle()

Since this dataloader only returns user ids, these two functions are straightforward to implement.

@property
def pr_end(self):
    return len(self.user_list)

def _shuffle(self):
    self.user_list.shuffle()

Implement _next_batch_data()

This function only needs to return user ids from user_list, so we select the corresponding slice of user_list and return it.

def _next_batch_data(self):
    cur_data = self.user_list[self.pr:self.pr + self.step]
    self.pr += self.step
    return cur_data

Complete Code

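# Import paths follow the RecBole source layout and may differ between versions.
from logging import getLogger

import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.interaction import Interaction
from recbole.utils import DataLoaderType


class UserDataLoader(AbstractDataLoader):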
    """:class:`UserDataLoader` will return a batch of data which only contains user-id when it is iterated.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.

    Attributes:
        shuffle (bool): Whether the dataloader will be shuffled after a round.
            However, in :class:`UserDataLoader`, it's guaranteed to be ``True``.
    """

    dl_type = DataLoaderType.ORIGIN

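    def __init__(self, config, dataset, sampler, shuffle=False):
        # The parent __init__ has not run yet, so create the logger here
        # (getLogger comes from the standard logging module).
        self.logger = getLogger()
        if shuffle is False:
            shuffle = True
            self.logger.warning('UserDataLoader must shuffle the data.')

        self.uid_field = dataset.uid_field
        # One entry per user id, wrapped in an Interaction so it can be sliced and shuffled.
        self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

        super().__init__(config, dataset, sampler, shuffle=shuffle)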

    def _init_batch_size_and_step(self):
        batch_size = self.config['train_batch_size']
        self.step = batch_size
        self.set_batch_size(batch_size)

    @property
    def pr_end(self):
        return len(self.user_list)

    def _shuffle(self):
        self.user_list.shuffle()

    def _next_batch_data(self):
        cur_data = self.user_list[self.pr:self.pr + self.step]
        self.pr += self.step
        return cur_data

For more complex dataloaders, refer to the source code.