# @Time : 2020/7/21
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn
# UPDATE:
# @Time : 2020/10/19, 2020/9/17, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com
"""
recbole.data.utils
########################
"""
import copy
import importlib
import os
from logging import getLogger
from recbole.config import EvalSetting
from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType, ensure_dir

def create_dataset(config):
"""Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.
Args:
config (Config): An instance object of Config, used to record parameter information.
Returns:
Dataset: Constructed dataset.
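
    Example:
        A minimal usage sketch; ``'BPR'`` and ``'ml-100k'`` below are only illustrative
        model and dataset names, not requirements of this function:

        .. code:: python

            from recbole.config import Config
            from recbole.data import create_dataset

            config = Config(model='BPR', dataset='ml-100k')
            dataset = create_dataset(config)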
"""
dataset_module = importlib.import_module('recbole.data.dataset')
if hasattr(dataset_module, config['model'] + 'Dataset'):
return getattr(dataset_module, config['model'] + 'Dataset')(config)
else:
model_type = config['MODEL_TYPE']
if model_type == ModelType.SEQUENTIAL:
from .dataset import SequentialDataset
return SequentialDataset(config)
elif model_type == ModelType.KNOWLEDGE:
from .dataset import KnowledgeBasedDataset
return KnowledgeBasedDataset(config)
elif model_type == ModelType.SOCIAL:
from .dataset import SocialDataset
return SocialDataset(config)
elif model_type == ModelType.XGBOOST:
from .dataset import XgboostDataset
return XgboostDataset(config)
else:
from .dataset import Dataset
return Dataset(config)

def data_preparation(config, dataset, save=False):
"""Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
corresponding dataloader.
Args:
config (Config): An instance object of Config, used to record parameter information.
dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_datasets` to save the split datasets.
Defaults to ``False``.
Returns:
tuple:
- train_data (AbstractDataLoader): The dataloader for training.
- valid_data (AbstractDataLoader): The dataloader for validation.
- test_data (AbstractDataLoader): The dataloader for testing.
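
    Example:
        A minimal sketch of the usual workflow; the model and dataset names are only
        illustrative:

        .. code:: python

            from recbole.config import Config
            from recbole.data import create_dataset, data_preparation

            config = Config(model='BPR', dataset='ml-100k')
            dataset = create_dataset(config)
            train_data, valid_data, test_data = data_preparation(config, dataset)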
"""
model_type = config['MODEL_TYPE']
es_str = [_.strip() for _ in config['eval_setting'].split(',')]
es = EvalSetting(config)
es.set_ordering_and_splitting(es_str[0])
built_datasets = dataset.build(es)
train_dataset, valid_dataset, test_dataset = built_datasets
phases = ['train', 'valid', 'test']
sampler = None
if save:
save_datasets(config['checkpoint_dir'], name=phases, dataset=built_datasets)
kwargs = {}
if config['training_neg_sample_num']:
if dataset.label_field in dataset.inter_feat:
raise ValueError(
f'`training_neg_sample_num` should be 0 '
                f'if inter_feat has label_field [{dataset.label_field}].'
)
train_distribution = config['training_neg_sample_distribution'] or 'uniform'
es.neg_sample_by(by=config['training_neg_sample_num'], distribution=train_distribution)
if model_type != ModelType.SEQUENTIAL:
sampler = Sampler(phases, built_datasets, es.neg_sample_args['distribution'])
else:
sampler = RepeatableSampler(phases, dataset, es.neg_sample_args['distribution'])
kwargs['sampler'] = sampler.set_phase('train')
kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
if model_type == ModelType.KNOWLEDGE:
kg_sampler = KGSampler(dataset, es.neg_sample_args['distribution'])
kwargs['kg_sampler'] = kg_sampler
train_data = dataloader_construct(
name='train',
config=config,
eval_setting=es,
dataset=train_dataset,
dl_format=config['MODEL_INPUT_TYPE'],
batch_size=config['train_batch_size'],
shuffle=True,
**kwargs
)
kwargs = {}
if len(es_str) > 1 and getattr(es, es_str[1], None):
if dataset.label_field in dataset.inter_feat:
raise ValueError(
                f'It cannot validate with `{es_str[1]}` '
                f'when inter_feat has label_field [{dataset.label_field}].'
)
getattr(es, es_str[1])()
if sampler is None:
if model_type != ModelType.SEQUENTIAL:
sampler = Sampler(phases, built_datasets, es.neg_sample_args['distribution'])
else:
sampler = RepeatableSampler(phases, dataset, es.neg_sample_args['distribution'])
sampler.set_distribution(es.neg_sample_args['distribution'])
kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
valid_data, test_data = dataloader_construct(
name='evaluation',
config=config,
eval_setting=es,
dataset=[valid_dataset, test_dataset],
batch_size=config['eval_batch_size'],
**kwargs
)
return train_data, valid_data, test_data

def dataloader_construct(
name, config, eval_setting, dataset, dl_format=InputType.POINTWISE, batch_size=1, shuffle=False, **kwargs
):
"""Get a correct dataloader class by calling :func:`get_data_loader` to construct dataloader.
Args:
name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
config (Config): An instance object of Config, used to record parameter information.
eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.
dataset (Dataset or list of Dataset): The split dataset for constructing dataloader.
dl_format (InputType, optional): The input type of dataloader. Defaults to
:obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        shuffle (bool, optional): Whether the dataloader will be shuffled after each round. Defaults to ``False``.
        **kwargs: Other input args of dataloader, such as :attr:`sampler`, :attr:`kg_sampler`
            and :attr:`neg_sample_args`. The meaning of these args is the same as in the corresponding dataloaders.
Returns:
        AbstractDataLoader or list of AbstractDataLoader: Constructed dataloader(s) over the split dataset(s).
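
    Example:
        A sketch of how :func:`data_preparation` calls this function for the training phase
        when training negative sampling is enabled; ``es``, ``sampler`` and ``train_dataset``
        are local variables of that function, not a public API:

        .. code:: python

            train_data = dataloader_construct(
                name='train',
                config=config,
                eval_setting=es,
                dataset=train_dataset,
                dl_format=config['MODEL_INPUT_TYPE'],
                batch_size=config['train_batch_size'],
                shuffle=True,
                sampler=sampler.set_phase('train'),
                neg_sample_args=copy.deepcopy(es.neg_sample_args),
            )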
"""
if not isinstance(dataset, list):
dataset = [dataset]
if not isinstance(batch_size, list):
batch_size = [batch_size] * len(dataset)
if len(dataset) != len(batch_size):
raise ValueError(f'Dataset {dataset} and batch_size {batch_size} should have the same length.')
kwargs_list = [{} for _ in range(len(dataset))]
for key, value in kwargs.items():
key = [key] * len(dataset)
if not isinstance(value, list):
value = [value] * len(dataset)
if len(dataset) != len(value):
raise ValueError(f'Dataset {dataset} and {key} {value} should have the same length.')
for kw, k, w in zip(kwargs_list, key, value):
kw[k] = w
model_type = config['MODEL_TYPE']
logger = getLogger()
logger.info(f'Build [{model_type}] DataLoader for [{name}] with format [{dl_format}]')
logger.info(eval_setting)
logger.info(f'batch_size = [{batch_size}], shuffle = [{shuffle}]\n')
dataloader = get_data_loader(name, config, eval_setting)
try:
ret = [
dataloader(config=config, dataset=ds, batch_size=bs, dl_format=dl_format, shuffle=shuffle, **kw)
for ds, bs, kw in zip(dataset, batch_size, kwargs_list)
]
    except TypeError:
        # A TypeError here usually means that negative-sampling kwargs were passed to a
        # dataloader class that does not accept them.
        raise ValueError('training_neg_sample_num should be 0')
if len(ret) == 1:
return ret[0]
else:
return ret

def save_datasets(save_path, name, dataset):
"""Save split datasets.
Args:
save_path (str): The path of directory for saving.
        name (str or list of str): The name(s) of the split dataset(s), e.g. ``['train', 'valid', 'test']``.
dataset (Dataset or list of Dataset): The split datasets.
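
    Example:
        A sketch mirroring the call in :func:`data_preparation`:

        .. code:: python

            save_datasets(config['checkpoint_dir'], name=['train', 'valid', 'test'],
                          dataset=[train_dataset, valid_dataset, test_dataset])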
"""
if (not isinstance(name, list)) and (not isinstance(dataset, list)):
name = [name]
dataset = [dataset]
if len(name) != len(dataset):
        raise ValueError(f'Length of name {name} should equal the length of dataset {dataset}.')
for i, d in enumerate(dataset):
cur_path = os.path.join(save_path, name[i])
ensure_dir(cur_path)
d.save(cur_path)

def get_data_loader(name, config, eval_setting):
"""Return a dataloader class according to :attr:`config` and :attr:`eval_setting`.
Args:
name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
config (Config): An instance object of Config, used to record parameter information.
eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.
Returns:
type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
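
    Example:
        A sketch, assuming a general recommender (``ModelType.GENERAL``) whose evaluation
        setting uses the negative-sampling strategy ``'by'`` (``es`` stands for an
        :class:`EvalSetting` instance):

        .. code:: python

            dataloader_class = get_data_loader('train', config, es)
            # dataloader_class is GeneralNegSampleDataLoader in this setting;
            # dataloader_construct then instantiates it with the split dataset.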
"""
register_table = {
'DIN': _get_DIN_data_loader,
"MultiDAE": _get_AE_data_loader,
"MultiVAE": _get_AE_data_loader,
'MacridVAE': _get_AE_data_loader,
'CDAE': _get_AE_data_loader,
'ENMF': _get_AE_data_loader
}
if config['model'] in register_table:
return register_table[config['model']](name, config, eval_setting)
model_type = config['MODEL_TYPE']
neg_sample_strategy = eval_setting.neg_sample_args['strategy']
if model_type == ModelType.GENERAL or model_type == ModelType.TRADITIONAL:
if neg_sample_strategy == 'none':
return GeneralDataLoader
elif neg_sample_strategy == 'by':
return GeneralNegSampleDataLoader
elif neg_sample_strategy == 'full':
return GeneralFullDataLoader
elif model_type == ModelType.CONTEXT:
if neg_sample_strategy == 'none':
return ContextDataLoader
elif neg_sample_strategy == 'by':
return ContextNegSampleDataLoader
elif neg_sample_strategy == 'full':
return ContextFullDataLoader
elif model_type == ModelType.SEQUENTIAL:
if neg_sample_strategy == 'none':
return SequentialDataLoader
elif neg_sample_strategy == 'by':
return SequentialNegSampleDataLoader
elif neg_sample_strategy == 'full':
return SequentialFullDataLoader
elif model_type == ModelType.XGBOOST:
if neg_sample_strategy == 'none':
return XgboostDataLoader
elif neg_sample_strategy == 'by':
return XgboostNegSampleDataLoader
elif neg_sample_strategy == 'full':
return XgboostFullDataLoader
elif model_type == ModelType.KNOWLEDGE:
if neg_sample_strategy == 'by':
if name == 'train':
return KnowledgeBasedDataLoader
else:
return GeneralNegSampleDataLoader
elif neg_sample_strategy == 'full':
return GeneralFullDataLoader
elif neg_sample_strategy == 'none':
# return GeneralDataLoader
            # TODO: can training also use 'none'? Judging from the general-model logic, it seems both could be 'none'.
            raise NotImplementedError(
                'The use of external negative sampling for knowledge-based models has not been implemented'
)
else:
raise NotImplementedError(f'Model_type [{model_type}] has not been implemented.')
def _get_DIN_data_loader(name, config, eval_setting):
"""Customized function for DIN to get correct dataloader class.
Args:
name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
config (Config): An instance object of Config, used to record parameter information.
eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.
Returns:
type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
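
    Example:
        A sketch; with the negative-sampling strategy ``'by'`` this helper returns
        ``SequentialNegSampleDataLoader`` (``es`` stands for an :class:`EvalSetting` instance):

        .. code:: python

            dataloader_class = _get_DIN_data_loader('train', config, es)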
"""
neg_sample_strategy = eval_setting.neg_sample_args['strategy']
if neg_sample_strategy == 'none':
return SequentialDataLoader
elif neg_sample_strategy == 'by':
return SequentialNegSampleDataLoader
elif neg_sample_strategy == 'full':
return SequentialFullDataLoader
def _get_AE_data_loader(name, config, eval_setting):
"""Customized function for Multi-DAE and Multi-VAE to get correct dataloader class.
Args:
name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
config (Config): An instance object of Config, used to record parameter information.
eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.
Returns:
type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
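
    Example:
        A sketch; the training phase always maps to ``UserDataLoader``, while the evaluation
        phase follows the negative-sampling strategy (``es`` stands for an :class:`EvalSetting`
        instance):

        .. code:: python

            train_cls = _get_AE_data_loader('train', config, es)       # UserDataLoader
            eval_cls = _get_AE_data_loader('evaluation', config, es)   # e.g. GeneralFullDataLoader for 'full'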
"""
neg_sample_strategy = eval_setting.neg_sample_args['strategy']
if name == "train":
return UserDataLoader
else:
if neg_sample_strategy == 'none':
return GeneralDataLoader
elif neg_sample_strategy == 'by':
return GeneralNegSampleDataLoader
elif neg_sample_strategy == 'full':
return GeneralFullDataLoader

class DLFriendlyAPI(object):
"""A Decorator class, which helps copying :class:`Dataset` methods to :class:`DataLoader`.
These methods are called *DataLoader Friendly APIs*.
E.g. if ``train_data`` is an object of :class:`DataLoader`,
and :meth:`~recbole.data.dataset.dataset.Dataset.num` is a method of :class:`~recbole.data.dataset.dataset.Dataset`,
    then, because it has been decorated, :meth:`~recbole.data.dataset.dataset.Dataset.num` can be called directly on
    ``train_data``.
See the example of :meth:`set` for details.
Attributes:
dataloader_apis (set): Register table that saves all the method names of DataLoader Friendly APIs.
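
    Example:
        A sketch: :meth:`~recbole.data.dataset.dataset.Dataset.num` is decorated with
        ``@dlapi.set()``, so it can also be called through a dataloader
        (``config['ITEM_ID_FIELD']`` gives the item-id field name):

        .. code:: python

            n_items = train_data.num(config['ITEM_ID_FIELD'])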
"""
def __init__(self):
self.dataloader_apis = set()

    def __iter__(self):
        # Iterate over the names of the registered DataLoader Friendly APIs.
        return iter(self.dataloader_apis)

    def set(self):
"""
Example:
            .. code:: python

                from recbole.data.utils import dlapi
@dlapi.set()
def dataset_meth():
...
"""
def decorator(f):
self.dataloader_apis.add(f.__name__)
return f
return decorator
dlapi = DLFriendlyAPI()