# @Time : 2020/10/6, 2022/7/18
# @Author : Shanlei Mu, Lei Wang
# @Email : slmu@ruc.edu.cn, zxcptss@gmail.com
# UPDATE:
# @Time : 2022/7/8, 2022/07/10, 2022/07/13, 2023/2/11
# @Author : Zhen Tian, Junjie Zhang, Gaowei Zhang
# @Email : chenyuwuxinn@gmail.com, zjj001128@163.com, zgw15630559577@163.com
"""
recbole.quick_start
########################
"""
import logging
import sys
import torch.distributed as dist
from collections.abc import MutableMapping
from logging import getLogger
from ray import tune
from recbole.config import Config
from recbole.data import (
create_dataset,
data_preparation,
)
from recbole.data.transform import construct_transform
from recbole.utils import (
init_logger,
get_model,
get_trainer,
init_seed,
set_color,
get_flops,
get_environment,
)
def run(
    model,
    dataset,
    config_file_list=None,
    config_dict=None,
    saved=True,
    nproc=1,
    world_size=-1,
    ip="localhost",
    port="5678",
    group_offset=0,
):
    r"""Run a single- or multi-process training/evaluation job.

    Args:
        model (str): Model name.
        dataset (str): Dataset name.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        nproc (int, optional): Number of worker processes to spawn on this node. Defaults to ``1``.
        world_size (int, optional): Total number of processes in the distributed
            group; ``-1`` means "same as ``nproc``". Defaults to ``-1``.
        ip (str, optional): Address of the master node. Defaults to ``"localhost"``.
        port (str, optional): Port of the master node. Defaults to ``"5678"``.
        group_offset (int, optional): Rank offset of this node's processes within
            the distributed group. Defaults to ``0``.

    Returns:
        dict or None: The result dict produced by :func:`run_recbole`, or ``None``
        if no worker put a result on the queue.
    """
    if nproc == 1 and world_size <= 0:
        # Single-process path: run the full pipeline directly in this process.
        res = run_recbole(
            model=model,
            dataset=dataset,
            config_file_list=config_file_list,
            config_dict=config_dict,
            saved=saved,
        )
    else:
        if world_size == -1:
            world_size = nproc
        import torch.multiprocessing as mp

        # Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
        # https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
        queue = mp.get_context("spawn").SimpleQueue()

        # Copy before updating so the caller's dict is not mutated as a side effect.
        config_dict = dict(config_dict or {})
        config_dict.update(
            {
                "world_size": world_size,
                "ip": ip,
                "port": port,
                "nproc": nproc,
                "offset": group_offset,
            }
        )
        kwargs = {
            "config_dict": config_dict,
            "queue": queue,
        }

        mp.spawn(
            run_recboles,
            args=(model, dataset, config_file_list, kwargs),
            nprocs=nproc,
            join=True,
        )

        # Normally, there should be only one item in the queue (rank 0's result).
        res = None if queue.empty() else queue.get()
    return res
def run_recbole(
    model=None,
    dataset=None,
    config_file_list=None,
    config_dict=None,
    saved=True,
    queue=None,
):
    r"""A fast running api, which includes the complete process of
    training and testing a model on a specified dataset.

    Args:
        model (str, optional): Model name. Defaults to ``None``.
        dataset (str, optional): Dataset name. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.

    Returns:
        dict: Contains ``best_valid_score``, ``valid_score_bigger``,
        ``best_valid_result`` and ``test_result``.
    """
    # configurations initialization
    config = Config(
        model=model,
        dataset=dataset,
        config_file_list=config_file_list,
        config_dict=config_dict,
    )
    init_seed(config["seed"], config["reproducibility"])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(sys.argv)
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization; re-seed with the local rank so each
    # distributed worker gets a distinct but reproducible stream
    init_seed(config["seed"] + config["local_rank"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    logger.info(model)

    transform = construct_transform(config)
    flops = get_flops(model, dataset, config["device"], logger, transform)
    logger.info(set_color("FLOPs", "blue") + f": {flops}")

    # trainer loading and initialization
    trainer = get_trainer(config["MODEL_TYPE"], config["model"])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=saved, show_progress=config["show_progress"]
    )

    # model evaluation
    test_result = trainer.evaluate(
        test_data, load_best_model=saved, show_progress=config["show_progress"]
    )

    environment_tb = get_environment(config)
    logger.info(
        "The running environment of this training is as follows:\n"
        + environment_tb.draw()
    )

    logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
    logger.info(set_color("test result", "yellow") + f": {test_result}")

    result = {
        "best_valid_score": best_valid_score,
        "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }

    # In distributed mode, tear down the process group before returning.
    if not config["single_spec"]:
        dist.destroy_process_group()

    # Only rank 0 reports back through the multiprocessing queue (mp.spawn path).
    if config["local_rank"] == 0 and queue is not None:
        queue.put(result)  # for multiprocessing, e.g., mp.spawn

    return result  # for the single process
def run_recboles(rank, *args):
    r"""Entry point for each process spawned by :func:`torch.multiprocessing.spawn`.

    Args:
        rank (int): Local rank assigned by ``mp.spawn`` (becomes ``local_rank``).
        *args: ``(model, dataset, config_file_list, kwargs)`` where ``kwargs``
            must be a dict forwarded to :func:`run_recbole`.

    Raises:
        ValueError: If the last positional argument is not a mapping.
    """
    kwargs = args[-1]
    if not isinstance(kwargs, MutableMapping):
        raise ValueError(
            f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
        )
    # Inject this worker's rank into the config so run_recbole can seed and
    # report per-rank.
    kwargs["config_dict"] = kwargs.get("config_dict", {})
    kwargs["config_dict"]["local_rank"] = rank
    run_recbole(
        *args[:3],
        **kwargs,
    )
def objective_function(config_dict=None, config_file_list=None, saved=True):
    r"""The default objective_function used in HyperTuning.

    Args:
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.

    Returns:
        dict: Contains ``model``, ``best_valid_score``, ``valid_score_bigger``,
        ``best_valid_result`` and ``test_result``.
    """
    config = Config(config_dict=config_dict, config_file_list=config_file_list)
    init_seed(config["seed"], config["reproducibility"])

    # Reset logging so repeated tuning trials don't accumulate handlers
    # (each trial calls init_logger again).
    logger = getLogger()
    for hdlr in logger.handlers[:]:  # remove all old handlers
        logger.removeHandler(hdlr)
    init_logger(config)
    # Keep tuning output quiet: only errors from the root logger.
    logging.basicConfig(level=logging.ERROR)

    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    init_seed(config["seed"], config["reproducibility"])
    model_name = config["model"]
    model = get_model(model_name)(config, train_data._dataset).to(config["device"])
    trainer = get_trainer(config["MODEL_TYPE"], config["model"])(config, model)
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, verbose=False, saved=saved
    )
    test_result = trainer.evaluate(test_data, load_best_model=saved)
    # Report metrics to Ray Tune so the scheduler can track this trial.
    tune.report(**test_result)

    return {
        "model": model_name,
        "best_valid_score": best_valid_score,
        "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }
def load_data_and_model(model_file):
    r"""Load filtered dataset, split dataloaders and saved model.

    Args:
        model_file (str): The path of saved model file.

    Returns:
        tuple:
            - config (Config): An instance object of Config, which record parameter information in :attr:`model_file`.
            - model (AbstractRecommender): The model load from :attr:`model_file`.
            - dataset (Dataset): The filtered dataset.
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    import torch

    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # Rebuild the dataset and splits exactly as they were at training time
    # (same seed, same config), so the loaded weights match the data layout.
    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # Re-seed before model construction to reproduce the original init order.
    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))

    return config, model, dataset, train_data, valid_data, test_data