Source code for recbole.data.utils

# @Time   : 2020/7/21
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE:
# @Time   : 2020/10/19, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan
# @Email  : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn

"""
recbole.data.utils
########################
"""

import copy
import importlib
import os
import pickle

from recbole.config import EvalSetting
from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType, ensure_dir, get_local_time
from recbole.utils.utils import set_color


def create_dataset(config):
    """Build the dataset object that matches the configured model.

    A class named ``<model>Dataset`` found in :mod:`recbole.data.dataset` takes
    priority. Otherwise the dataset class is picked from
    :attr:`config['MODEL_TYPE']`, falling back to the plain ``Dataset``.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.
    """
    dataset_module = importlib.import_module('recbole.data.dataset')
    custom_name = config['model'] + 'Dataset'

    if hasattr(dataset_module, custom_name):
        # A model-specific dataset class wins over the generic per-type one.
        dataset_cls = getattr(dataset_module, custom_name)
    else:
        model_type = config['MODEL_TYPE']
        # The imports stay lazy so that only the selected class is resolved.
        if model_type == ModelType.SEQUENTIAL:
            from .dataset import SequentialDataset as dataset_cls
        elif model_type == ModelType.KNOWLEDGE:
            from .dataset import KnowledgeBasedDataset as dataset_cls
        elif model_type == ModelType.SOCIAL:
            from .dataset import SocialDataset as dataset_cls
        elif model_type == ModelType.DECISIONTREE:
            from .dataset import DecisionTreeDataset as dataset_cls
        else:
            from .dataset import Dataset as dataset_cls

    return dataset_cls(config)
def _build_neg_sampler(model_type, phases, built_datasets, dataset, distribution):
    """Create the negative sampler: ``RepeatableSampler`` for sequential models, ``Sampler`` otherwise."""
    if model_type != ModelType.SEQUENTIAL:
        return Sampler(phases, built_datasets, distribution)
    return RepeatableSampler(phases, dataset, distribution)


def data_preparation(config, dataset, save=False):
    """Split the dataset and build the train / valid / test dataloaders.

    The dataset is split with ``dataset.build`` using an :class:`EvalSetting`
    created from :attr:`config`. The dataloader classes are chosen by
    :func:`get_data_loader`, once for training and once for evaluation.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_split_dataloaders` to save
            the split dataloaders. Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.

    Raises:
        ValueError: If negative sampling is requested (for training or evaluation)
            while ``dataset.inter_feat`` already contains ``dataset.label_field``.
    """
    model_type = config['MODEL_TYPE']

    es = EvalSetting(config)

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']
    sampler = None
    logger = getLogger()
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = es.neg_sample_args

    # Training
    train_kwargs = {
        'config': config,
        'dataset': train_dataset,
        'batch_size': config['train_batch_size'],
        'dl_format': config['MODEL_INPUT_TYPE'],
        'shuffle': True,
    }
    if train_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'`training_neg_sample_num` should be 0 '
                f'if inter_feat have label_field [{dataset.label_field}].'
            )
        sampler = _build_neg_sampler(
            model_type, phases, built_datasets, dataset, train_neg_sample_args['distribution']
        )
        # Each dataloader receives whatever ``set_phase`` returns for its own phase.
        train_kwargs['sampler'] = sampler.set_phase('train')
        train_kwargs['neg_sample_args'] = train_neg_sample_args
        if model_type == ModelType.KNOWLEDGE:
            train_kwargs['kg_sampler'] = KGSampler(dataset, train_neg_sample_args['distribution'])

    dataloader = get_data_loader('train', config, train_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[train]', 'yellow') + ' with format ' + set_color(f'[{train_kwargs["dl_format"]}]', 'yellow')
    )
    if train_neg_sample_args['strategy'] != 'none':
        logger.info(
            set_color('[train]', 'pink') + set_color(' Negative Sampling', 'blue') + f': {train_neg_sample_args}'
        )
    else:
        logger.info(set_color('[train]', 'pink') + set_color(' No Negative Sampling', 'yellow'))
    logger.info(
        set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow')
    )
    train_data = dataloader(**train_kwargs)

    # Evaluation
    eval_kwargs = {
        'config': config,
        'batch_size': config['eval_batch_size'],
        'dl_format': InputType.POINTWISE,
        'shuffle': False,
    }
    valid_kwargs = {'dataset': valid_dataset}
    test_kwargs = {'dataset': test_dataset}
    if eval_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'It can not validate with `{es.es_str[1]}` '
                f'when inter_feat have label_field [{dataset.label_field}].'
            )
        if sampler is None:
            # No training sampler was created, so build one for evaluation only.
            sampler = _build_neg_sampler(
                model_type, phases, built_datasets, dataset, eval_neg_sample_args['distribution']
            )
        else:
            # Reuse the training sampler, switched to the evaluation distribution.
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        eval_kwargs['neg_sample_args'] = eval_neg_sample_args
        valid_kwargs['sampler'] = sampler.set_phase('valid')
        test_kwargs['sampler'] = sampler.set_phase('test')
    valid_kwargs.update(eval_kwargs)
    test_kwargs.update(eval_kwargs)

    dataloader = get_data_loader('evaluation', config, eval_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[evaluation]', 'yellow') + ' with format ' + set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow')
    )
    logger.info(es)
    logger.info(
        set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow')
    )

    valid_data = dataloader(**valid_kwargs)
    test_data = dataloader(**test_kwargs)

    if save:
        save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))

    return train_data, valid_data, test_data
def save_split_dataloaders(config, dataloaders):
    """Save split dataloaders.

    The dataloaders are pickled to
    ``<checkpoint_dir>/<dataset>-for-<model>-dataloader.pth``; the checkpoint
    directory is created first if it does not exist yet.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    save_path = config['checkpoint_dir']
    # Create the target directory up front: ``open`` below does not create
    # missing parent directories and would fail on a fresh checkpoint dir.
    # An empty path means "current directory" and needs no creation.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth'
    file_path = os.path.join(save_path, saved_dataloaders_file)
    with open(file_path, 'wb') as f:
        pickle.dump(dataloaders, f)

    # Log only after the dump succeeded so the message is never a false claim.
    logger = getLogger()
    logger.info(set_color('Saved split dataloaders', 'blue') + f': {file_path}')
def load_split_dataloaders(saved_dataloaders_file):
    """Load split dataloaders from a pickle file.

    Args:
        saved_dataloaders_file (str): The path of split dataloaders.

    Returns:
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(saved_dataloaders_file, 'rb') as pickled:
        return pickle.load(pickled)
def get_data_loader(name, config, neg_sample_args):
    """Return a dataloader class according to :attr:`config` and :attr:`neg_sample_args`.

    Models listed in the register table are delegated to their customized
    selector. Otherwise the class name is composed from the model type and the
    negative sampling strategy and looked up in :mod:`recbole.data.dataloader`.
    Knowledge-based models are handled by a dedicated branch.

    Args:
        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
        config (Config): An instance object of Config, used to record parameter information.
        neg_sample_args (dict) : Settings of negative sampling.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`neg_sample_args`.

    Raises:
        NotImplementedError: If no dataloader exists for the given model type /
            negative sampling strategy combination.
    """
    register_table = {
        'DIN': _get_DIN_data_loader,
        'DIEN': _get_DIEN_data_loader,
        "MultiDAE": _get_AE_data_loader,
        "MultiVAE": _get_AE_data_loader,
        'MacridVAE': _get_AE_data_loader,
        'CDAE': _get_AE_data_loader,
        'ENMF': _get_AE_data_loader,
        'RaCT': _get_AE_data_loader,
        'RecVAE': _get_AE_data_loader
    }

    if config['model'] in register_table:
        return register_table[config['model']](name, config, neg_sample_args)

    model_type_table = {
        ModelType.GENERAL: 'General',
        ModelType.TRADITIONAL: 'General',
        ModelType.CONTEXT: 'Context',
        ModelType.SEQUENTIAL: 'Sequential',
        ModelType.DECISIONTREE: 'DecisionTree',
    }
    neg_sample_strategy_table = {
        'none': 'DataLoader',
        'by': 'NegSampleDataLoader',
        'full': 'FullDataLoader',
    }
    model_type = config['MODEL_TYPE']
    neg_sample_strategy = neg_sample_args['strategy']

    if model_type in model_type_table and neg_sample_strategy in neg_sample_strategy_table:
        # Class name is "<type prefix><strategy suffix>", e.g. 'GeneralNegSampleDataLoader'.
        dataloader_module = importlib.import_module('recbole.data.dataloader')
        dataloader_name = model_type_table[model_type] + neg_sample_strategy_table[neg_sample_strategy]
        return getattr(dataloader_module, dataloader_name)

    if model_type == ModelType.KNOWLEDGE:
        if neg_sample_strategy == 'by':
            # Only training uses the knowledge-based loader.
            return KnowledgeBasedDataLoader if name == 'train' else GeneralNegSampleDataLoader
        if neg_sample_strategy == 'full':
            return GeneralFullDataLoader
        if neg_sample_strategy == 'none':
            raise NotImplementedError(
                'The use of external negative sampling for knowledge model has not been implemented'
            )
        # Previously this case fell through and returned None, which only
        # surfaced later as an AttributeError in the caller; fail here instead.
        raise NotImplementedError(
            f'Negative sampling strategy [{neg_sample_strategy}] has not been implemented for knowledge model.'
        )

    raise NotImplementedError(f'Model_type [{model_type}] has not been implemented.')
def _get_DIN_data_loader(name, config, neg_sample_args):
    """Pick the dataloader class for DIN from the negative sampling strategy.

    Args:
        name (str): The stage of dataloader ('train' or 'evaluation'). Not used here.
        config (Config): An instance object of Config. Not used here.
        neg_sample_args (dict): Settings of negative sampling; only ``'strategy'`` is read.

    Returns:
        type: The sequential dataloader class for the strategy, or ``None`` when the
        strategy is none of ``'none'``, ``'by'`` or ``'full'``.
    """
    strategy = neg_sample_args['strategy']
    if strategy == 'none':
        return SequentialDataLoader
    if strategy == 'by':
        return SequentialNegSampleDataLoader
    if strategy == 'full':
        return SequentialFullDataLoader
    return None


def _get_DIEN_data_loader(name, config, neg_sample_args):
    """Pick the dataloader class for DIEN from the negative sampling strategy.

    Args:
        name (str): The stage of dataloader ('train' or 'evaluation'). Not used here.
        config (Config): An instance object of Config. Not used here.
        neg_sample_args (dict): Settings of negative sampling; only ``'strategy'`` is read.

    Returns:
        type: The DIEN dataloader class for the strategy, or ``None`` when the
        strategy is none of ``'none'``, ``'by'`` or ``'full'``.
    """
    strategy = neg_sample_args['strategy']
    if strategy == 'none':
        return DIENDataLoader
    if strategy == 'by':
        return DIENNegSampleDataLoader
    if strategy == 'full':
        return DIENFullDataLoader
    return None


def _get_AE_data_loader(name, config, neg_sample_args):
    """Pick the dataloader class for Multi-DAE, Multi-VAE and the other models registered to this selector.

    Training always uses ``UserDataLoader``; any other stage uses the general
    dataloader matching the negative sampling strategy.

    Args:
        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
        config (Config): An instance object of Config. Not used here.
        neg_sample_args (dict): Settings of negative sampling; only ``'strategy'`` is read.

    Returns:
        type: The selected dataloader class, or ``None`` when the stage is not
        'train' and the strategy is none of ``'none'``, ``'by'`` or ``'full'``.
    """
    strategy = neg_sample_args['strategy']
    if name == "train":
        return UserDataLoader
    if strategy == 'none':
        return GeneralDataLoader
    if strategy == 'by':
        return GeneralNegSampleDataLoader
    if strategy == 'full':
        return GeneralFullDataLoader
    return None
class DLFriendlyAPI(object):
    """A Decorator class, which helps copying :class:`Dataset` methods to :class:`DataLoader`.

    These methods are called *DataLoader Friendly APIs*.

    E.g. if ``train_data`` is an object of :class:`DataLoader`,
    and :meth:`~recbole.data.dataset.dataset.Dataset.num` is a method of
    :class:`~recbole.data.dataset.dataset.Dataset`,
    because it has been decorated, :meth:`~recbole.data.dataset.dataset.Dataset.num`
    can be called directly by ``train_data``.

    See the example of :meth:`set` for details.

    Attributes:
        dataloader_apis (set): Register table that saves all the method names of DataLoader Friendly APIs.
    """

    def __init__(self):
        self.dataloader_apis = set()

    def __iter__(self):
        """Iterate over the registered method names."""
        # ``__iter__`` must return an iterator. Returning the set itself made
        # ``iter(obj)``, ``for name in obj`` and ``name in obj`` raise TypeError.
        return iter(self.dataloader_apis)

    def set(self):
        """Return a decorator that registers the decorated function's name.

        The decorated function itself is returned unchanged.

        Example:
            .. code:: python

                from recbole.data.utils import dlapi

                @dlapi.set()
                def dataset_meth():
                    ...
        """

        def decorator(f):
            self.dataloader_apis.add(f.__name__)
            return f

        return decorator
# Module-level registry instance: decorating a function with ``@dlapi.set()``
# records its name in ``dlapi.dataloader_apis``.
dlapi = DLFriendlyAPI()