Customize DataLoaders
Here, we present how to develop a new DataLoader and apply it in our tool. If we have a new model with special requirements for loading data, we need to design a new DataLoader.
Abstract DataLoader
In this project, there are two abstract dataloaders: AbstractDataLoader and NegSampleDataLoader. In general, a new dataloader should inherit from the abstract classes above. If you only need to modify an existing DataLoader, you can also inherit from that one. The documentation of the dataloaders can be found in recbole.data.dataloader.
AbstractDataLoader
AbstractDataLoader is the most basic abstract class and includes three important attributes: pr, batch_size and step. pr is the pointer of this dataloader, batch_size is the upper bound of the number of interactions in a single batch, and step is the increment of pr for each batch.
AbstractDataLoader also declares four functions to be implemented: _init_batch_size_and_step(), pr_end(), _shuffle() and _next_batch_data(). _init_batch_size_and_step() initializes batch_size and step. pr_end() is the maximum pr plus 1, i.e. the total number of samples to iterate over. _shuffle() permutes the dataset; it is invoked by __iter__() if the parameter shuffle is True. _next_batch_data() loads the next batch and returns it in Interaction format; it is invoked by __next__().
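To make the interplay of pr, step and pr_end concrete, the following toy dataloader over a plain Python list (a self-contained sketch, not RecBole code) mimics how the four hooks are typically used during iteration:

import random

# A toy dataloader illustrating how ``pr``, ``step`` and ``pr_end`` cooperate
# during iteration. This is only a sketch, not the real AbstractDataLoader.
class ToyDataLoader:
    def __init__(self, data, batch_size, shuffle=False):
        self.data = list(data)
        self.batch_size = batch_size   # upper bound of one batch
        self.step = batch_size         # increment of ``pr`` per batch
        self.shuffle = shuffle
        self.pr = 0                    # pointer into the dataset

    @property
    def pr_end(self):
        return len(self.data)          # iteration stops once ``pr`` reaches this

    def _shuffle(self):
        random.shuffle(self.data)

    def _next_batch_data(self):
        batch = self.data[self.pr:self.pr + self.step]
        self.pr += self.step
        return batch

    def __iter__(self):
        if self.shuffle:
            self._shuffle()
        return self

    def __next__(self):
        if self.pr >= self.pr_end:
            self.pr = 0
            raise StopIteration()
        return self._next_batch_data()

# list(ToyDataLoader(range(5), batch_size=2)) -> [[0, 1], [2, 3], [4]]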
NegSampleDataLoader
NegSampleDataLoader inherits from AbstractDataLoader and is used for negative sampling. It adds four functions on top of its parent class: _set_neg_sample_args(), _neg_sampling(), _neg_sample_by_pair_wise_sampling() and _neg_sample_by_point_wise_sampling(). These four functions do not need to be implemented by subclasses; they are auxiliary functions of NegSampleDataLoader. Currently, there are only two sampling strategies: pair-wise sampling and point-wise sampling. _neg_sample_by_pair_wise_sampling() and _neg_sample_by_point_wise_sampling() implement these two strategies. _set_neg_sample_args() sets the negative-sampling arguments, such as the sampling strategy and the sampling function. _neg_sampling() performs the negative sampling: it generates negative items and invokes _neg_sample_by_pair_wise_sampling() or _neg_sample_by_point_wise_sampling() according to the sampling strategy.
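To illustrate the difference between the two strategies, the following sketch shows, with plain Python dictionaries and illustrative field names, what a batch roughly looks like after each kind of sampling (the real functions operate on Interaction objects):

# Two positive interactions and one sampled negative item for each of them.
# Field names ('user_id', 'item_id', 'neg_item_id', 'label') are illustrative.

# Pair-wise sampling keeps one row per positive interaction and attaches the
# sampled negative item as an additional field.
pair_wise_batch = {
    'user_id':     [1, 2],
    'item_id':     [10, 20],
    'neg_item_id': [55, 66],
}

# Point-wise sampling appends the negative interactions as extra rows and adds
# a label field: 1 for observed (positive), 0 for sampled (negative) rows.
point_wise_batch = {
    'user_id': [1, 2, 1, 2],
    'item_id': [10, 20, 55, 66],
    'label':   [1.0, 1.0, 0.0, 0.0],
}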
Example
Here, we take UserDataLoader as an example. This dataloader returns user ids, which are used to train the user representations.
Implement __init__()
__init__() can be used to initialize some of the necessary parameters. Here, we just need to record uid_field and generate user_list, which contains all user ids. Because of the training requirements, shuffle should be set to True.
def __init__(self, config, dataset, sampler, shuffle=False):
    if shuffle is False:
        shuffle = True
        self.logger.warning('UserDataLoader must shuffle the data.')

    self.uid_field = dataset.uid_field
    self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

    super().__init__(config, dataset, sampler, shuffle=shuffle)
Implement _init_batch_size_and_step()
Because UserDataLoader does not need negative sampling, batch_size and step can both be set to self.config['train_batch_size'].
def _init_batch_size_and_step(self):
    batch_size = self.config['train_batch_size']
    self.step = batch_size
    self.set_batch_size(batch_size)
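For comparison, a dataloader that does perform point-wise negative sampling usually has to shrink step so that each positive interaction plus its sampled negatives still fits into train_batch_size. A rough sketch of that arithmetic (illustrative only, not the exact RecBole implementation):

# Rough sketch: with point-wise sampling and ``neg_sample_ratio`` negatives
# per positive, each positive interaction expands into (1 + ratio) rows.
def sketch_batch_size_and_step(train_batch_size, neg_sample_ratio):
    times = 1 + neg_sample_ratio
    step = max(train_batch_size // times, 1)   # positives consumed per batch
    batch_size = step * times                  # rows actually returned per batch
    return batch_size, step

# sketch_batch_size_and_step(2048, 1) -> (2048, 1024)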
Implement pr_end() and _shuffle()
Since this dataloader only returns user ids, these functions can be implemented readily.
@property
def pr_end(self):
    return len(self.user_list)

def _shuffle(self):
    self.user_list.shuffle()
Implement _next_batch_data()
This function only needs to return user ids from user_list, so we just select the corresponding slice of user_list and return it.
def _next_batch_data(self):
    cur_data = self.user_list[self.pr:self.pr + self.step]
    self.pr += self.step
    return cur_data
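Note that slicing an Interaction yields a new Interaction over the same fields, so the returned slice is already in the expected batch format. A minimal, standalone illustration (assuming RecBole is installed; the field name is arbitrary):

import torch
from recbole.data.interaction import Interaction

user_list = Interaction({'user_id': torch.arange(5)})
print(len(user_list))    # 5 users in total
batch = user_list[0:2]   # a new Interaction holding user_id = tensor([0, 1])
print(batch['user_id'])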
Complete Code

import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.interaction import Interaction
from recbole.utils import DataLoaderType  # import paths may differ slightly across RecBole versions


class UserDataLoader(AbstractDataLoader):
    """:class:`UserDataLoader` will return a batch of data which only contains user-id when it is iterated.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.

    Attributes:
        shuffle (bool): Whether the dataloader will be shuffled after a round.
            However, in :class:`UserDataLoader`, it's guaranteed to be ``True``.
    """

    dl_type = DataLoaderType.ORIGIN

    def __init__(self, config, dataset, sampler, shuffle=False):
        if shuffle is False:
            shuffle = True
            self.logger.warning('UserDataLoader must shuffle the data.')

        self.uid_field = dataset.uid_field
        self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)})

        super().__init__(config, dataset, sampler, shuffle=shuffle)

    def _init_batch_size_and_step(self):
        batch_size = self.config['train_batch_size']
        self.step = batch_size
        self.set_batch_size(batch_size)

    @property
    def pr_end(self):
        return len(self.user_list)

    def _shuffle(self):
        self.user_list.shuffle()

    def _next_batch_data(self):
        cur_data = self.user_list[self.pr:self.pr + self.step]
        self.pr += self.step
        return cur_data
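As a rough usage sketch, once config, dataset and sampler have been prepared with RecBole's usual data utilities (these objects are assumed here, not created by this snippet), the dataloader can be iterated like any other RecBole dataloader:

# Usage sketch: ``config``, ``dataset`` and ``sampler`` are assumed to be
# prepared elsewhere with RecBole's standard data utilities.
user_loader = UserDataLoader(config, dataset, sampler, shuffle=True)

for batch in user_loader:
    user_ids = batch[dataset.uid_field]   # 1-D tensor of user ids in this batch
    # ... feed ``user_ids`` to the model that learns user representations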
For the development of other, more complex dataloaders, please refer to the source code.