Source code for recbole.data.utils

# @Time   : 2020/7/21
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE:
# @Time   : 2020/10/19, 2020/9/17, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
# @Email  : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com

"""
recbole.data.utils
########################
"""

import copy
import os
import importlib
from logging import getLogger

from recbole.config import EvalSetting
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType
from recbole.data.dataloader import *


def create_dataset(config):
    """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.
    """
    try:
        return getattr(importlib.import_module('recbole.data.dataset'), config['model'] + 'Dataset')(config)
    except AttributeError:
        model_type = config['MODEL_TYPE']
        if model_type == ModelType.SEQUENTIAL:
            from .dataset import SequentialDataset
            return SequentialDataset(config)
        elif model_type == ModelType.KNOWLEDGE:
            from .dataset import KnowledgeBasedDataset
            return KnowledgeBasedDataset(config)
        elif model_type == ModelType.SOCIAL:
            from .dataset import SocialDataset
            return SocialDataset(config)
        else:
            from .dataset import Dataset
            return Dataset(config)

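# Illustrative sketch, not part of the module: a typical call of `create_dataset`.
# The model name 'BPR' and dataset name 'ml-100k' are example values; since no
# 'BPRDataset' class exists in `recbole.data.dataset`, the AttributeError branch
# above falls back to the general `Dataset` for a GENERAL-type model.
def _example_create_dataset():
    from recbole.config import Config

    config = Config(model='BPR', dataset='ml-100k')
    return create_dataset(config)
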
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct`
    to create the corresponding dataloaders.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_datasets` to save the split dataset.
            Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']

    es_str = [_.strip() for _ in config['eval_setting'].split(',')]
    es = EvalSetting(config)

    kwargs = {}
    if 'RS' in es_str[0]:
        kwargs['ratios'] = config['split_ratio']
        if kwargs['ratios'] is None:
            raise ValueError('`ratios` should be set if `RS` is set')
    if 'LS' in es_str[0]:
        kwargs['leave_one_num'] = config['leave_one_num']
        if kwargs['leave_one_num'] is None:
            raise ValueError('`leave_one_num` should be set if `LS` is set')
    kwargs['group_by_user'] = config['group_by_user']
    getattr(es, es_str[0])(**kwargs)

    if es.split_args['strategy'] != 'loo' and model_type == ModelType.SEQUENTIAL:
        raise ValueError('Sequential models require "loo" split strategy.')

    builded_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = builded_datasets
    phases = ['train', 'valid', 'test']

    if save:
        save_datasets(config['checkpoint_dir'], name=phases, dataset=builded_datasets)

    kwargs = {}
    if config['training_neg_sample_num']:
        es.neg_sample_by(config['training_neg_sample_num'])
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset, es.neg_sample_args['distribution'])
        kwargs['sampler'] = sampler.set_phase('train')
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
    if model_type == ModelType.KNOWLEDGE:
        kg_sampler = KGSampler(dataset, es.neg_sample_args['distribution'])
        kwargs['kg_sampler'] = kg_sampler

    train_data = dataloader_construct(
        name='train',
        config=config,
        eval_setting=es,
        dataset=train_dataset,
        dl_format=config['MODEL_INPUT_TYPE'],
        batch_size=config['train_batch_size'],
        shuffle=True,
        **kwargs
    )

    kwargs = {}
    if len(es_str) > 1 and getattr(es, es_str[1], None):
        getattr(es, es_str[1])()
        if 'sampler' not in locals():
            sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
        kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
        kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
    valid_data, test_data = dataloader_construct(
        name='evaluation',
        config=config,
        eval_setting=es,
        dataset=[valid_dataset, test_dataset],
        batch_size=config['eval_batch_size'],
        **kwargs
    )

    return train_data, valid_data, test_data

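# Illustrative sketch, not part of the module: end-to-end use of `data_preparation`
# right after `create_dataset`. The 'BPR' / 'ml-100k' names are assumptions for the
# example; splitting and negative sampling follow config['eval_setting'].
def _example_data_preparation():
    from recbole.config import Config

    config = Config(model='BPR', dataset='ml-100k')
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset, save=False)
    return train_data, valid_data, test_data
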
def dataloader_construct(name, config, eval_setting, dataset,
                         dl_format=InputType.POINTWISE, batch_size=1, shuffle=False, **kwargs):
    """Get a correct dataloader class by calling :func:`get_data_loader` to construct dataloader.

    Args:
        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
        config (Config): An instance object of Config, used to record parameter information.
        eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.
        dataset (Dataset or list of Dataset): The split dataset for constructing dataloader.
        dl_format (InputType, optional): The input type of dataloader. Defaults to
            :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a round. Defaults to ``False``.
        **kwargs: Other input args of dataloader, such as :attr:`sampler`, :attr:`kg_sampler`
            and :attr:`neg_sample_args`. The meaning of these args is the same as in the corresponding dataloaders.

    Returns:
        AbstractDataLoader or list of AbstractDataLoader: Constructed dataloader(s) for the split dataset(s).
    """
    if not isinstance(dataset, list):
        dataset = [dataset]

    if not isinstance(batch_size, list):
        batch_size = [batch_size] * len(dataset)

    if len(dataset) != len(batch_size):
        raise ValueError('dataset {} and batch_size {} should have the same length'.format(dataset, batch_size))

    kwargs_list = [{} for _ in range(len(dataset))]
    for key, value in kwargs.items():
        key = [key] * len(dataset)
        if not isinstance(value, list):
            value = [value] * len(dataset)
        if len(dataset) != len(value):
            raise ValueError('dataset {} and {} {} should have the same length'.format(dataset, key, value))
        for kw, k, w in zip(kwargs_list, key, value):
            kw[k] = w

    model_type = config['MODEL_TYPE']
    logger = getLogger()
    logger.info('Build [{}] DataLoader for [{}] with format [{}]'.format(model_type, name, dl_format))
    logger.info(eval_setting)
    logger.info('batch_size = [{}], shuffle = [{}]\n'.format(batch_size, shuffle))

    DataLoader = get_data_loader(name, config, eval_setting)

    try:
        ret = [
            DataLoader(
                config=config,
                dataset=ds,
                batch_size=bs,
                dl_format=dl_format,
                shuffle=shuffle,
                **kw
            ) for ds, bs, kw in zip(dataset, batch_size, kwargs_list)
        ]
    except TypeError:
        raise ValueError('training_neg_sample_num should be 0')

    if len(ret) == 1:
        return ret[0]
    else:
        return ret

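# Illustrative sketch, not part of the module: `dataloader_construct` broadcasts scalar
# keyword arguments over the list of datasets. The snippet below reproduces only that
# broadcasting step with placeholder values, showing how one batch size and a per-split
# sampler list end up paired with each split.
def _example_broadcast_kwargs():
    datasets = ['valid_split', 'test_split']                 # placeholders for Dataset objects
    batch_size = [2048] * len(datasets)                      # a scalar batch size, replicated per split
    kwargs = {'sampler': ['valid_sampler', 'test_sampler']}  # already a list: used as-is

    kwargs_list = [{} for _ in range(len(datasets))]
    for key, value in kwargs.items():
        if not isinstance(value, list):
            value = [value] * len(datasets)
        for kw, v in zip(kwargs_list, value):
            kw[key] = v
    return list(zip(datasets, batch_size, kwargs_list))
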
def save_datasets(save_path, name, dataset):
    """Save split datasets.

    Args:
        save_path (str): The path of directory for saving.
        name (str or list of str): The name(s) of the split dataset(s), e.g. ``['train', 'valid', 'test']``.
        dataset (Dataset or list of Dataset): The split datasets.
    """
    if (not isinstance(name, list)) and (not isinstance(dataset, list)):
        name = [name]
        dataset = [dataset]
    if len(name) != len(dataset):
        raise ValueError('len of name {} should equal to len of dataset {}'.format(name, dataset))
    for i, d in enumerate(dataset):
        cur_path = os.path.join(save_path, name[i])
        if not os.path.isdir(cur_path):
            os.makedirs(cur_path)
        d.save(cur_path)

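# Illustrative sketch, not part of the module: saving three split datasets, one
# sub-directory per name. The './saved' path is an example value; in `data_preparation`
# the directory comes from config['checkpoint_dir'].
def _example_save_datasets(train_dataset, valid_dataset, test_dataset):
    save_datasets('./saved', name=['train', 'valid', 'test'],
                  dataset=[train_dataset, valid_dataset, test_dataset])
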
def get_data_loader(name, config, eval_setting):
    """Return a dataloader class according to :attr:`config` and :attr:`eval_setting`.

    Args:
        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
        config (Config): An instance object of Config, used to record parameter information.
        eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
    """
    register_table = {
        'DIN': _get_DIN_data_loader
    }

    if config['model'] in register_table:
        return register_table[config['model']](name, config, eval_setting)

    model_type = config['MODEL_TYPE']
    neg_sample_strategy = eval_setting.neg_sample_args['strategy']
    if model_type == ModelType.GENERAL or model_type == ModelType.TRADITIONAL:
        if neg_sample_strategy == 'none':
            return GeneralDataLoader
        elif neg_sample_strategy == 'by':
            return GeneralNegSampleDataLoader
        elif neg_sample_strategy == 'full':
            return GeneralFullDataLoader
    elif model_type == ModelType.CONTEXT:
        if neg_sample_strategy == 'none':
            return ContextDataLoader
        elif neg_sample_strategy == 'by':
            return ContextNegSampleDataLoader
        elif neg_sample_strategy == 'full':
            raise NotImplementedError('context model\'s full_sort has not been implemented')
    elif model_type == ModelType.SEQUENTIAL:
        if neg_sample_strategy == 'none':
            return SequentialDataLoader
        elif neg_sample_strategy == 'by':
            return SequentialNegSampleDataLoader
        elif neg_sample_strategy == 'full':
            return SequentialFullDataLoader
    elif model_type == ModelType.KNOWLEDGE:
        if neg_sample_strategy == 'by':
            if name == 'train':
                return KnowledgeBasedDataLoader
            else:
                return GeneralNegSampleDataLoader
        elif neg_sample_strategy == 'full':
            return GeneralFullDataLoader
        elif neg_sample_strategy == 'none':
            # return GeneralDataLoader
            # TODO: can 'none' also be used for training? The general dataloaders seem to allow it for both stages.
            raise NotImplementedError('The use of external negative sampling for knowledge model '
                                      'has not been implemented')
    else:
        raise NotImplementedError('model_type [{}] has not been implemented'.format(model_type))

def _get_DIN_data_loader(name, config, eval_setting):
    """Customized function for DIN to get correct dataloader class.

    Args:
        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
        config (Config): An instance object of Config, used to record parameter information.
        eval_setting (EvalSetting): An instance object of EvalSetting, used to record evaluation settings.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
    """
    neg_sample_strategy = eval_setting.neg_sample_args['strategy']
    if neg_sample_strategy == 'none':
        return SequentialDataLoader
    elif neg_sample_strategy == 'by':
        return SequentialNegSampleDataLoader
    elif neg_sample_strategy == 'full':
        return SequentialFullDataLoader

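# Illustrative sketch, not part of the module: models with special dataloader needs are
# plugged in through `register_table` in `get_data_loader`, exactly like DIN above.
# 'MyModel' and `_get_MyModel_data_loader` are hypothetical names; supporting such a model
# would additionally require adding 'MyModel': _get_MyModel_data_loader to the table.
def _get_MyModel_data_loader(name, config, eval_setting):
    # Hypothetical example: always fall back to the general dataloaders.
    neg_sample_strategy = eval_setting.neg_sample_args['strategy']
    if neg_sample_strategy == 'none':
        return GeneralDataLoader
    elif neg_sample_strategy == 'by':
        return GeneralNegSampleDataLoader
    elif neg_sample_strategy == 'full':
        return GeneralFullDataLoader
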
class DLFriendlyAPI(object):
    """A decorator class, which helps copy :class:`Dataset` methods to :class:`DataLoader`.

    These methods are called *DataLoader Friendly APIs*.

    E.g. if ``train_data`` is an object of :class:`DataLoader`,
    and :meth:`~recbole.data.dataset.dataset.Dataset.num` is a method of
    :class:`~recbole.data.dataset.dataset.Dataset`,
    then because it has been decorated, :meth:`~recbole.data.dataset.dataset.Dataset.num`
    can be called directly through ``train_data``.

    See the example of :meth:`set` for details.

    Attributes:
        dataloader_apis (set): Register table that saves all the method names of DataLoader Friendly APIs.
    """

    def __init__(self):
        self.dataloader_apis = set()

    def __iter__(self):
        return iter(self.dataloader_apis)
    def set(self):
        """
        Example:
            .. code:: python

                from recbole.data.utils import dlapi

                @dlapi.set()
                def dataset_meth():
                    ...
        """
        def decorator(f):
            self.dataloader_apis.add(f.__name__)
            return f
        return decorator

dlapi = DLFriendlyAPI()
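
# Illustrative sketch, not part of the module: one way a dataloader could forward the
# registered DataLoader Friendly APIs to its underlying dataset. RecBole's own forwarding
# lives in its dataloader classes; this stand-alone proxy only demonstrates, under that
# assumption, how `dlapi.dataloader_apis` can be consumed.
class _DLFriendlyProxyExample:
    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, name):
        # Forward only the method names registered through @dlapi.set().
        if name in dlapi.dataloader_apis:
            return getattr(self.dataset, name)
        raise AttributeError('{} is not a DataLoader Friendly API'.format(name))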