# @Time : 2020/7/21
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn
# UPDATE:
# @Time : 2021/7/9, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan
# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn
"""
recbole.data.utils
########################
"""
import copy
import importlib
import os
import pickle
from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType, ensure_dir, get_local_time, set_color
from recbole.utils.argument_list import dataset_arguments
def create_dataset(config):
    """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.

    If the file at :attr:`config['dataset_save_path']` (or, when that is empty, the default
    file under :attr:`config['checkpoint_dir']`) exists and the dataset-related settings stored
    with it are equal to the current ones, the saved dataset is returned instead of
    constructing a new one.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.
    """
    dataset_module = importlib.import_module('recbole.data.dataset')
    # A model-specific dataset class (``<model>Dataset``) takes precedence over the
    # generic class selected by the model type.
    model_dataset_name = config['model'] + 'Dataset'
    if hasattr(dataset_module, model_dataset_name):
        dataset_class = getattr(dataset_module, model_dataset_name)
    else:
        model_type = config['MODEL_TYPE']
        type2class = {
            ModelType.GENERAL: 'Dataset',
            ModelType.SEQUENTIAL: 'SequentialDataset',
            ModelType.CONTEXT: 'Dataset',
            ModelType.KNOWLEDGE: 'KnowledgeBasedDataset',
            ModelType.TRADITIONAL: 'Dataset',
            ModelType.DECISIONTREE: 'Dataset',
        }
        dataset_class = getattr(dataset_module, type2class[model_type])

    default_file = os.path.join(config['checkpoint_dir'], f'{config["dataset"]}-{dataset_class.__name__}.pth')
    dataset_file = config['dataset_save_path'] or default_file
    if os.path.exists(dataset_file):
        # SECURITY: unpickling executes arbitrary code embedded in the file, so the path
        # configured here must only ever point at files written by a trusted run.
        with open(dataset_file, 'rb') as f:
            dataset = pickle.load(f)
        # Reuse the saved dataset only when every dataset-related argument is unchanged.
        checked_args = dataset_arguments + ['seed', 'repeatable']
        if not any(config[arg] != dataset.config[arg] for arg in checked_args):
            logger = getLogger()
            logger.info(set_color('Load filtered dataset from', 'pink') + f': [{dataset_file}]')
            return dataset

    # No reusable saved dataset: build from scratch and optionally save it.
    dataset = dataset_class(config)
    if config['save_dataset']:
        dataset.save()
    return dataset
def save_split_dataloaders(config, dataloaders):
    """Pickle the split dataloaders into the checkpoint directory.

    The target file is ``<dataset>-for-<model>-dataloader.pth`` inside
    :attr:`config['checkpoint_dir']`, which is created first via ``ensure_dir``.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    checkpoint_dir = config['checkpoint_dir']
    ensure_dir(checkpoint_dir)

    target_path = os.path.join(
        checkpoint_dir,
        f'{config["dataset"]}-for-{config["model"]}-dataloader.pth',
    )

    getLogger().info(set_color('Saving split dataloaders into', 'pink') + f': [{target_path}]')
    with open(target_path, 'wb') as out_file:
        pickle.dump(dataloaders, out_file)
def load_split_dataloaders(config):
    """Load split dataloaders if saved dataloaders exist and
    their :attr:`config` of dataset are the same as current :attr:`config` of dataset.

    The file is taken from :attr:`config['dataloaders_save_path']`; when that is empty the
    default ``<dataset>-for-<model>-dataloader.pth`` under :attr:`config['checkpoint_dir']`
    is used.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        dataloaders (tuple of AbstractDataLoader or None): The split dataloaders, or ``None``
        when the file does not exist or was saved with different dataset settings.
    """
    default_file = os.path.join(config['checkpoint_dir'], f'{config["dataset"]}-for-{config["model"]}-dataloader.pth')
    dataloaders_save_path = config['dataloaders_save_path'] or default_file
    if not os.path.exists(dataloaders_save_path):
        return None

    # SECURITY: unpickling executes arbitrary code embedded in the file, so the path
    # configured here must only ever point at files written by a trusted run.
    with open(dataloaders_save_path, 'rb') as f:
        train_data, valid_data, test_data = pickle.load(f)

    # The saved loaders are only reusable when every setting that influences the data
    # or its split is unchanged; the training loader's config is used for the comparison.
    checked_args = dataset_arguments + ['seed', 'repeatable', 'eval_args']
    if any(config[arg] != train_data.config[arg] for arg in checked_args):
        return None

    # Hand the current config to each restored loader.
    for loader in (train_data, valid_data, test_data):
        loader.update_config(config)

    logger = getLogger()
    logger.info(set_color('Load split dataloaders from', 'pink') + f': [{dataloaders_save_path}]')
    return train_data, valid_data, test_data
def data_preparation(config, dataset):
    """Produce the training, validation and test dataloaders for :attr:`dataset`.

    Note:
        Previously saved dataloaders returned by :meth:`load_split_dataloaders` are reused
        when available; otherwise the dataset is split with ``dataset.build()`` and new
        dataloaders are created (and saved when :attr:`config['save_dataloaders']` is set).

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    restored = load_split_dataloaders(config)
    if restored is None:
        model_type = config['MODEL_TYPE']
        splits = dataset.build()
        train_dataset, valid_dataset, test_dataset = splits
        train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, splits)

        if model_type == ModelType.KNOWLEDGE:
            # Knowledge-based models additionally receive a sampler over the knowledge graph.
            kg_sampler = KGSampler(dataset, config['train_neg_sample_args']['distribution'])
            train_loader_cls = get_dataloader(config, 'train')
            train_data = train_loader_cls(config, train_dataset, train_sampler, kg_sampler, shuffle=True)
        else:
            train_loader_cls = get_dataloader(config, 'train')
            train_data = train_loader_cls(config, train_dataset, train_sampler, shuffle=True)

        valid_loader_cls = get_dataloader(config, 'evaluation')
        valid_data = valid_loader_cls(config, valid_dataset, valid_sampler, shuffle=False)
        test_loader_cls = get_dataloader(config, 'evaluation')
        test_data = test_loader_cls(config, test_dataset, test_sampler, shuffle=False)

        if config['save_dataloaders']:
            save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))
    else:
        train_data, valid_data, test_data = restored

    logger = getLogger()

    # Summarise the training setup.
    train_summary = (
        set_color('[Training]: ', 'pink')
        + set_color('train_batch_size', 'cyan')
        + ' = '
        + set_color(f'[{config["train_batch_size"]}]', 'yellow')
        + set_color(' negative sampling', 'cyan')
        + ': '
        + set_color(f'[{config["neg_sampling"]}]', 'yellow')
    )
    logger.info(train_summary)

    # Summarise the evaluation setup.
    eval_summary = (
        set_color('[Evaluation]: ', 'pink')
        + set_color('eval_batch_size', 'cyan')
        + ' = '
        + set_color(f'[{config["eval_batch_size"]}]', 'yellow')
        + set_color(' eval_args', 'cyan')
        + ': '
        + set_color(f'[{config["eval_args"]}]', 'yellow')
    )
    logger.info(eval_summary)

    return train_data, valid_data, test_data
def get_dataloader(config, phase):
    """Return a dataloader class according to :attr:`config` and :attr:`phase`.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        phase (str): The stage of dataloader. ``'train'`` selects a training dataloader;
            any other value (conventionally ``'evaluation'``) selects an evaluation dataloader.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.

    Raises:
        ValueError: If an evaluation dataloader is requested and
            :attr:`config['eval_neg_sample_args']['strategy']` is not ``'none'``, ``'by'`` or ``'full'``.
    """
    # Models whose dataloader class is chosen by the customized AE selection function.
    ae_models = {'MultiDAE', 'MultiVAE', 'MacridVAE', 'CDAE', 'ENMF', 'RaCT', 'RecVAE'}
    if config['model'] in ae_models:
        return _get_AE_dataloader(config, phase)

    model_type = config['MODEL_TYPE']
    if phase == 'train':
        if model_type != ModelType.KNOWLEDGE:
            return TrainDataLoader
        return KnowledgeBasedDataLoader

    eval_strategy = config['eval_neg_sample_args']['strategy']
    if eval_strategy in {'none', 'by'}:
        return NegSampleEvalDataLoader
    if eval_strategy == 'full':
        return FullSortEvalDataLoader
    # Fail here with a clear message instead of returning None, which the caller would
    # try to instantiate and crash on with an unrelated-looking TypeError.
    raise ValueError(
        f"Unknown evaluation negative sampling strategy [{eval_strategy}]; "
        "expected one of 'none', 'by' or 'full'."
    )
def _get_AE_dataloader(config, phase):
"""Customized function for VAE models to get correct dataloader class.
Args:
config (Config): An instance object of Config, used to record parameter information.
phase (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
Returns:
type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.
"""
if phase == 'train':
return UserDataLoader
else:
eval_strategy = config['eval_neg_sample_args']['strategy']
if eval_strategy in {'none', 'by'}:
return NegSampleEvalDataLoader
elif eval_strategy == 'full':
return FullSortEvalDataLoader
def create_samplers(config, dataset, built_datasets):
    """Create sampler for training, validation and testing.

    A single underlying sampler object is created at most once and shared by all phases;
    each returned sampler is the result of calling ``set_phase`` on it. A phase whose
    negative sampling strategy is ``'none'`` gets ``None`` instead of a sampler.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        built_datasets (list of Dataset): A list of split Dataset, which contains dataset for
            training, validation and testing.

    Returns:
        tuple:
            - train_sampler (AbstractSampler): The sampler for training.
            - valid_sampler (AbstractSampler): The sampler for validation.
            - test_sampler (AbstractSampler): The sampler for testing.
    """
    phases = ['train', 'valid', 'test']
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = config['eval_neg_sample_args']

    def _new_sampler(neg_sample_args):
        # Build the shared sampler: the repeatable variant is constructed from the full
        # dataset, the regular one from the split datasets.
        if not config['repeatable']:
            return Sampler(phases, built_datasets, neg_sample_args['distribution'])
        return RepeatableSampler(phases, dataset, neg_sample_args['distribution'])

    sampler = None
    train_sampler, valid_sampler, test_sampler = None, None, None

    if train_neg_sample_args['strategy'] != 'none':
        sampler = _new_sampler(train_neg_sample_args)
        train_sampler = sampler.set_phase('train')

    if eval_neg_sample_args['strategy'] != 'none':
        if sampler is None:
            sampler = _new_sampler(eval_neg_sample_args)
        else:
            # Reuse the training sampler object, switched to the evaluation distribution.
            # NOTE(review): relies on set_phase('train') above having returned a sampler
            # that is unaffected by this change - confirm against the Sampler implementation.
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        valid_sampler = sampler.set_phase('valid')
        test_sampler = sampler.set_phase('test')

    return train_sampler, valid_sampler, test_sampler