# @Time : 2020/7/21
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn
# UPDATE:
# @Time : 2021/7/9, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan
# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn
"""
recbole.data.utils
########################
"""
import copy
import importlib
import os
import pickle
from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType, ensure_dir, get_local_time, set_color
def create_dataset(config):
    """Build the dataset object that matches the configured model.

    A model-specific class named ``<model>Dataset`` in :mod:`recbole.data.dataset`
    takes precedence. Otherwise a generic class is picked from
    :attr:`config['MODEL_TYPE']` (``SequentialDataset``, ``KnowledgeBasedDataset``
    or ``DecisionTreeDataset``), with the plain ``Dataset`` as the final fallback.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.
    """
    dataset_module = importlib.import_module('recbole.data.dataset')
    specific_name = config['model'] + 'Dataset'
    if hasattr(dataset_module, specific_name):
        return getattr(dataset_module, specific_name)(config)

    model_type = config['MODEL_TYPE']
    # Generic dataset class per model type; anything unlisted uses ``Dataset``.
    fallbacks = (
        (ModelType.SEQUENTIAL, 'SequentialDataset'),
        (ModelType.KNOWLEDGE, 'KnowledgeBasedDataset'),
        (ModelType.DECISIONTREE, 'DecisionTreeDataset'),
    )
    class_name = next((name for kind, name in fallbacks if model_type == kind), 'Dataset')
    return getattr(dataset_module, class_name)(config)
def save_split_dataloaders(config, dataloaders):
    """Save split dataloaders into ``config['checkpoint_dir']`` using :mod:`pickle`.

    The file is named ``{dataset}-for-{model}-dataloader.pth``. The checkpoint
    directory is created first if it does not exist yet.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    save_path = config['checkpoint_dir']
    # Without this, ``open`` below raises FileNotFoundError on a fresh checkpoint dir.
    # An empty path means "current directory", which always exists (and makedirs('') raises).
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth'
    file_path = os.path.join(save_path, saved_dataloaders_file)
    logger = getLogger()
    logger.info(set_color('Saved split dataloaders', 'blue') + f': {file_path}')
    with open(file_path, 'wb') as f:
        pickle.dump(dataloaders, f)
def load_split_dataloaders(saved_dataloaders_file):
    """Read back dataloaders previously written with :mod:`pickle`.

    Note: this unpickles the file, which can execute arbitrary code, so only
    pass paths to files you produced yourself.

    Args:
        saved_dataloaders_file (str): The path of split dataloaders.

    Returns:
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    with open(saved_dataloaders_file, mode='rb') as handle:
        return pickle.load(handle)
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_args']` and create training, validation and test dataloader.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_split_dataloaders` to save
            the split dataloaders. Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']
    # ``dataset.build()`` is expected to yield exactly three splits: train, valid, test.
    built_datasets = dataset.build()
    logger = getLogger()

    train_dataset, valid_dataset, test_dataset = built_datasets
    train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets)

    train_loader_cls = get_dataloader(config, 'train')
    if model_type != ModelType.KNOWLEDGE:
        train_data = train_loader_cls(config, train_dataset, train_sampler, shuffle=True)
    else:
        # Knowledge-based training loaders additionally receive a KG sampler over the full dataset.
        kg_sampler = KGSampler(dataset, config['train_neg_sample_args']['distribution'])
        train_data = train_loader_cls(config, train_dataset, train_sampler, kg_sampler, shuffle=True)

    # Validation and test share one dataloader class; resolve it once.
    eval_loader_cls = get_dataloader(config, 'evaluation')
    valid_data = eval_loader_cls(config, valid_dataset, valid_sampler, shuffle=False)
    test_data = eval_loader_cls(config, test_dataset, test_sampler, shuffle=False)

    logger.info(
        set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' +
        set_color(f'[{config["train_batch_size"]}]', 'yellow') + set_color(' negative sampling', 'cyan') + ': ' +
        set_color(f'[{config["neg_sampling"]}]', 'yellow')
    )
    logger.info(
        set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' +
        set_color(f'[{config["eval_batch_size"]}]', 'yellow') + set_color(' eval_args', 'cyan') + ': ' +
        set_color(f'[{config["eval_args"]}]', 'yellow')
    )
    if save:
        save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))

    return train_data, valid_data, test_data
def get_dataloader(config, phase):
    """Return a dataloader class according to :attr:`config` and :attr:`phase`.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        phase (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
            Any value other than ``'train'`` is treated as evaluation.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.

    Raises:
        ValueError: If an evaluation dataloader is requested and
            ``config['eval_neg_sample_args']['strategy']`` is not ``'none'``, ``'by'`` or ``'full'``.
    """
    # Models whose dataloader classes are chosen by ``_get_AE_dataloader`` instead.
    ae_models = {'MultiDAE', 'MultiVAE', 'MacridVAE', 'CDAE', 'ENMF', 'RaCT', 'RecVAE'}
    if config['model'] in ae_models:
        return _get_AE_dataloader(config, phase)

    if phase == 'train':
        if config['MODEL_TYPE'] != ModelType.KNOWLEDGE:
            return TrainDataLoader
        return KnowledgeBasedDataLoader

    eval_strategy = config['eval_neg_sample_args']['strategy']
    if eval_strategy in {'none', 'by'}:
        return NegSampleEvalDataLoader
    if eval_strategy == 'full':
        return FullSortEvalDataLoader
    # Previously this fell through and returned None, failing later with an opaque
    # "'NoneType' object is not callable" at the dataloader construction site.
    raise ValueError(f'Unsupported eval_neg_sample_args strategy [{eval_strategy}] for phase [{phase}].')
def _get_AE_dataloader(config, phase):
"""Customized function for VAE models to get correct dataloader class.
Args:
config (Config): An instance object of Config, used to record parameter information.
phase (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
Returns:
type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.
"""
if phase == 'train':
return UserDataLoader
else:
eval_strategy = config['eval_neg_sample_args']['strategy']
if eval_strategy in {'none', 'by'}:
return NegSampleEvalDataLoader
elif eval_strategy == 'full':
return FullSortEvalDataLoader
def create_samplers(config, dataset, built_datasets):
    """Create the negative samplers used for training, validation and testing.

    A single sampler object is shared across phases: it is built for training when
    training needs negatives. When evaluation also needs negatives, that sampler is
    switched to the evaluation distribution, or one is built on demand if training
    had none. Phases whose strategy is ``'none'`` get ``None``.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        built_datasets (list of Dataset): A list of split Dataset, which contains dataset for
            training, validation and testing.

    Returns:
        tuple:
            - train_sampler (AbstractSampler): The sampler for training.
            - valid_sampler (AbstractSampler): The sampler for validation.
            - test_sampler (AbstractSampler): The sampler for testing.
    """
    phases = ['train', 'valid', 'test']
    train_args = config['train_neg_sample_args']
    eval_args = config['eval_neg_sample_args']

    def _new_sampler(neg_sample_args):
        # RepeatableSampler is built from the full dataset; Sampler from the split datasets.
        if config['repeatable']:
            return RepeatableSampler(phases, dataset, neg_sample_args['distribution'])
        return Sampler(phases, built_datasets, neg_sample_args['distribution'])

    shared = None
    train_sampler = valid_sampler = test_sampler = None

    if train_args['strategy'] != 'none':
        shared = _new_sampler(train_args)
        train_sampler = shared.set_phase('train')

    if eval_args['strategy'] != 'none':
        if shared is not None:
            shared.set_distribution(eval_args['distribution'])
        else:
            shared = _new_sampler(eval_args)
        valid_sampler = shared.set_phase('valid')
        test_sampler = shared.set_phase('test')

    return train_sampler, valid_sampler, test_sampler