Source code for recbole.data.utils

# @Time   : 2020/7/21
# @Author : Yupeng Hou
# @Email  : houyupeng@ruc.edu.cn

# UPDATE:
# @Time   : 2021/7/9, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1, 2022/7/6
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan, Gaowei Zhang
# @Email  : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn, zgw15630559577@163.com

"""
recbole.data.utils
########################
"""

import copy
import importlib
import os
import pickle
import warnings
from typing import Literal

from recbole.data.dataloader import *
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import ModelType, ensure_dir, get_local_time, set_color
from recbole.utils.argument_list import dataset_arguments


def create_dataset(config):
    """Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.

    If the file at :attr:`config['dataset_save_path']` (or the default path under
    :attr:`config['checkpoint_dir']`) exists and the dataset-related entries of its
    stored :attr:`config` equal those of the current :attr:`config`, the saved
    dataset is returned instead of building a new one.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.
    """
    dataset_module = importlib.import_module("recbole.data.dataset")

    # A model-specific dataset class ("<model>Dataset") takes precedence over
    # the generic class chosen from the model type.
    specific_name = config["model"] + "Dataset"
    if hasattr(dataset_module, specific_name):
        dataset_class = getattr(dataset_module, specific_name)
    else:
        type2class = {
            ModelType.GENERAL: "Dataset",
            ModelType.SEQUENTIAL: "SequentialDataset",
            ModelType.CONTEXT: "Dataset",
            ModelType.KNOWLEDGE: "KnowledgeBasedDataset",
            ModelType.TRADITIONAL: "Dataset",
            ModelType.DECISIONTREE: "Dataset",
        }
        dataset_class = getattr(dataset_module, type2class[config["MODEL_TYPE"]])

    default_file = os.path.join(
        config["checkpoint_dir"], f'{config["dataset"]}-{dataset_class.__name__}.pth'
    )
    cache_path = config["dataset_save_path"] or default_file

    if os.path.exists(cache_path):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # `dataset_save_path` / `checkpoint_dir` at trusted files.
        with open(cache_path, "rb") as cache_file:
            dataset = pickle.load(cache_file)

        # The cached dataset is reused only if every dataset-related argument
        # still matches the current configuration.
        compared_args = dataset_arguments + ["seed", "repeatable"]
        args_changed = any(config[arg] != dataset.config[arg] for arg in compared_args)
        if not args_changed:
            logger = getLogger()
            logger.info(
                set_color("Load filtered dataset from", "pink") + f": [{cache_path}]"
            )
            return dataset

    dataset = dataset_class(config)
    if config["save_dataset"]:
        dataset.save()
    return dataset
def save_split_dataloaders(config, dataloaders):
    """Save split dataloaders.

    Each dataloader is pickled together with the state of its random generator.
    The generator objects themselves are detached from the dataloader (and from
    its sampler) while pickling and are re-attached afterwards, so the
    dataloaders passed in remain usable by the caller exactly as before.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    ensure_dir(config["checkpoint_dir"])
    save_path = config["checkpoint_dir"]
    saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth'
    file_path = os.path.join(save_path, saved_dataloaders_file)
    logger = getLogger()
    logger.info(set_color("Saving split dataloaders into", "pink") + f": [{file_path}]")

    serialized_dataloaders = []
    # (dataloader, generator, sampler generator) triples to re-attach later.
    detached = []
    try:
        for dataloader in dataloaders:
            generator = dataloader.generator
            sampler_generator = getattr(dataloader.sampler, "generator", None)
            detached.append((dataloader, generator, sampler_generator))

            generator_state = generator.get_state()
            # Only the generator *state* is stored; the generator objects are
            # removed before pickling (presumably because they cannot be
            # pickled along with the dataloader -- TODO confirm).
            dataloader.generator = None
            dataloader.sampler.generator = None
            serialized_dataloaders.append((dataloader, generator_state))

        with open(file_path, "wb") as f:
            pickle.dump(serialized_dataloaders, f)
    finally:
        # Re-attach the generators even if pickling failed, so the caller's
        # dataloaders are never left without their generator.
        for dataloader, generator, sampler_generator in detached:
            dataloader.generator = generator
            dataloader.sampler.generator = sampler_generator
def load_split_dataloaders(config):
    """Load split dataloaders if saved dataloaders exist and
    their :attr:`config` of dataset are the same as current :attr:`config` of dataset.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        dataloaders (tuple of AbstractDataLoader or None): The split dataloaders,
        or ``None`` when no saved file exists or its configuration differs.
    """
    fallback_path = os.path.join(
        config["checkpoint_dir"],
        f'{config["dataset"]}-for-{config["model"]}-dataloader.pth',
    )
    dataloaders_save_path = config["dataloaders_save_path"] or fallback_path
    if not os.path.exists(dataloaders_save_path):
        return None

    restored = []
    with open(dataloaders_save_path, "rb") as saved_file:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataloader files from trusted locations.
        for loader, saved_state in pickle.load(saved_file):
            # Rebuild the generator from its saved state and attach it to both
            # the dataloader and its sampler.
            rng = torch.Generator()
            rng.set_state(saved_state)
            loader.generator = rng
            loader.sampler.generator = rng
            restored.append(loader)

    train_data, valid_data, test_data = restored

    # The saved dataloaders are discarded if any compared argument differs
    # from the one stored with the training dataloader.
    compared_args = dataset_arguments + ["seed", "repeatable", "eval_args"]
    if any(config[arg] != train_data.config[arg] for arg in compared_args):
        return None

    for loader in (train_data, valid_data, test_data):
        loader.update_config(config)

    logger = getLogger()
    logger.info(
        set_color("Load split dataloaders from", "pink")
        + f": [{dataloaders_save_path}]"
    )
    return train_data, valid_data, test_data
def data_preparation(config, dataset):
    """Split the dataset by :attr:`config['[valid|test]_eval_args']` and create training, validation and test dataloader.

    Note:
        If we can load split dataloaders by :meth:`load_split_dataloaders`, we will not create new split dataloaders.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    saved_loaders = load_split_dataloaders(config)
    if saved_loaders is None:
        # Nothing reusable on disk: split the dataset and build fresh loaders.
        model_type = config["MODEL_TYPE"]
        split_datasets = dataset.build()

        train_split, valid_split, test_split = split_datasets
        train_sampler, valid_sampler, test_sampler = create_samplers(
            config, dataset, split_datasets
        )

        if model_type == ModelType.KNOWLEDGE:
            # Knowledge-based models additionally receive a KG sampler and
            # always shuffle the training data.
            kg_sampler = KGSampler(
                dataset,
                config["train_neg_sample_args"]["distribution"],
                config["train_neg_sample_args"]["alpha"],
            )
            train_data = get_dataloader(config, "train")(
                config, train_split, train_sampler, kg_sampler, shuffle=True
            )
        else:
            train_data = get_dataloader(config, "train")(
                config, train_split, train_sampler, shuffle=config["shuffle"]
            )

        valid_data = get_dataloader(config, "valid")(
            config, valid_split, valid_sampler, shuffle=False
        )
        test_data = get_dataloader(config, "test")(
            config, test_split, test_sampler, shuffle=False
        )
        if config["save_dataloaders"]:
            save_split_dataloaders(
                config, dataloaders=(train_data, valid_data, test_data)
            )
    else:
        train_data, valid_data, test_data = saved_loaders
        dataset._change_feat_format()

    logger = getLogger()

    training_summary = (
        set_color("[Training]: ", "pink")
        + set_color("train_batch_size", "cyan")
        + " = "
        + set_color(f'[{config["train_batch_size"]}]', "yellow")
        + set_color(" train_neg_sample_args", "cyan")
        + ": "
        + set_color(f'[{config["train_neg_sample_args"]}]', "yellow")
    )
    logger.info(training_summary)

    evaluation_summary = (
        set_color("[Evaluation]: ", "pink")
        + set_color("eval_batch_size", "cyan")
        + " = "
        + set_color(f'[{config["eval_batch_size"]}]', "yellow")
        + set_color(" eval_args", "cyan")
        + ": "
        + set_color(f'[{config["eval_args"]}]', "yellow")
    )
    logger.info(evaluation_summary)

    return train_data, valid_data, test_data
def get_dataloader(config, phase: Literal["train", "valid", "test", "evaluation"]):
    """Return a dataloader class according to :attr:`config` and :attr:`phase`.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        phase (str): The stage of dataloader. It can only take 4 values: 'train', 'valid', 'test' or 'evaluation'.
            Notes: 'evaluation' has been deprecated, please use 'valid' or 'test' instead.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.

    Raises:
        ValueError: If :attr:`phase` is not one of the four accepted values.
    """
    if phase not in ("train", "valid", "test", "evaluation"):
        raise ValueError(
            "`phase` can only be 'train', 'valid', 'test' or 'evaluation'."
        )
    if phase == "evaluation":
        # Deprecated alias: handled as the test phase from here on.
        phase = "test"
        warnings.warn(
            "'evaluation' has been deprecated, please use 'valid' or 'test' instead.",
            DeprecationWarning,
        )

    # Models listed here pick their dataloader class through a custom function.
    ae_models = (
        "MultiDAE",
        "MultiVAE",
        "MacridVAE",
        "CDAE",
        "ENMF",
        "RaCT",
        "RecVAE",
        "DiffRec",
        "LDiffRec",
    )
    register_table = dict.fromkeys(ae_models, _get_AE_dataloader)
    custom_getter = register_table.get(config["model"])
    if custom_getter is not None:
        return custom_getter(config, phase)

    model_type = config["MODEL_TYPE"]
    if phase == "train":
        return (
            TrainDataLoader
            if model_type != ModelType.KNOWLEDGE
            else KnowledgeBasedDataLoader
        )

    eval_mode = config["eval_args"]["mode"][phase]
    return FullSortEvalDataLoader if eval_mode == "full" else NegSampleEvalDataLoader
def _get_AE_dataloader(config, phase: Literal["train", "valid", "test", "evaluation"]):
    """Customized function for VAE models to get correct dataloader class.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        phase (str): The stage of dataloader. It can only take 4 values: 'train', 'valid', 'test' or 'evaluation'.
            Notes: 'evaluation' has been deprecated, please use 'valid' or 'test' instead.

    Returns:
        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.

    Raises:
        ValueError: If :attr:`phase` is not one of the four accepted values.
    """
    if phase not in ("train", "valid", "test", "evaluation"):
        raise ValueError(
            "`phase` can only be 'train', 'valid', 'test' or 'evaluation'."
        )
    if phase == "evaluation":
        # Deprecated alias: handled as the test phase from here on.
        phase = "test"
        warnings.warn(
            "'evaluation' has been deprecated, please use 'valid' or 'test' instead.",
            DeprecationWarning,
        )

    if phase == "train":
        return UserDataLoader

    eval_mode = config["eval_args"]["mode"][phase]
    return FullSortEvalDataLoader if eval_mode == "full" else NegSampleEvalDataLoader


def _create_sampler(
    dataset,
    built_datasets,
    distribution: str,
    repeatable: bool,
    alpha: float = 1.0,
    base_sampler=None,
):
    """Create a sampler for the given distribution, or reuse ``base_sampler``.

    Args:
        dataset (Dataset): The full dataset; passed to :class:`RepeatableSampler`.
        built_datasets (list of Dataset): The split datasets; passed to :class:`Sampler`.
        distribution (str): The sampling distribution. ``"none"`` means no sampler is needed.
        repeatable (bool): Whether to build a :class:`RepeatableSampler` instead of a :class:`Sampler`.
        alpha (float): Forwarded unchanged to the sampler constructor. Defaults to ``1.0``.
        base_sampler: An existing sampler to reuse. If given (and the distribution is not
            ``"none"``), its distribution is updated and it is returned as is.

    Returns:
        The sampler to use, or ``None`` when :attr:`distribution` is ``"none"``.
    """
    if distribution == "none":
        return None

    if base_sampler is not None:
        base_sampler.set_distribution(distribution)
        return base_sampler

    phase_names = ["train", "valid", "test"]
    if repeatable:
        return RepeatableSampler(phase_names, dataset, distribution, alpha)
    return Sampler(phase_names, built_datasets, distribution, alpha)
def create_samplers(config, dataset, built_datasets):
    """Create sampler for training, validation and testing.

    The sampler created for training is passed on as the base sampler when
    creating the validation and test samplers.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        built_datasets (list of Dataset): A list of split Dataset, which contains dataset for
            training, validation and testing.

    Returns:
        tuple:
            - train_sampler (AbstractSampler): The sampler for training.
            - valid_sampler (AbstractSampler): The sampler for validation.
            - test_sampler (AbstractSampler): The sampler for testing.
    """
    train_args = config["train_neg_sample_args"]
    valid_args = config["valid_neg_sample_args"]
    test_args = config["test_neg_sample_args"]
    repeatable = config["repeatable"]

    shared_sampler = _create_sampler(
        dataset,
        built_datasets,
        train_args["distribution"],
        repeatable,
        train_args["alpha"],
    )
    train_sampler = None
    if shared_sampler:
        train_sampler = shared_sampler.set_phase("train")

    valid_sampler = _create_sampler(
        dataset,
        built_datasets,
        valid_args["distribution"],
        repeatable,
        base_sampler=shared_sampler,
    )
    if valid_sampler:
        valid_sampler = valid_sampler.set_phase("valid")
    else:
        valid_sampler = None

    test_sampler = _create_sampler(
        dataset,
        built_datasets,
        test_args["distribution"],
        repeatable,
        base_sampler=shared_sampler,
    )
    if test_sampler:
        test_sampler = test_sampler.set_phase("test")
    else:
        test_sampler = None

    return train_sampler, valid_sampler, test_sampler