def run(
    model,
    dataset,
    config_file_list=None,
    config_dict=None,
    saved=True,
    nproc=1,
    world_size=-1,
    ip="localhost",
    port="5678",
    group_offset=0,
):
    r"""Run a single-process or multi-process (distributed) training job.

    Args:
        model (str): Model name.
        dataset (str): Dataset name.
        config_file_list (list, optional): Config files used to modify experiment
            parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment
            parameters. Defaults to ``None``. The dict is copied before use, so the
            caller's object is never mutated.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        nproc (int, optional): Number of worker processes to spawn on this node.
            Defaults to ``1``.
        world_size (int, optional): Total number of processes across all nodes;
            ``-1`` means "same as ``nproc``". Defaults to ``-1``.
        ip (str, optional): Master node address for distributed training.
            Defaults to ``"localhost"``.
        port (str, optional): Master node port for distributed training.
            Defaults to ``"5678"``.
        group_offset (int, optional): Rank offset of this node's process group.
            Defaults to ``0``.

    Returns:
        dict or None: The result dict from :func:`run_recbole`; in the spawned
        case, ``None`` if no worker put a result on the queue.
    """
    if nproc == 1 and world_size <= 0:
        # Single-process path: run the whole pipeline in this process.
        res = run_recbole(
            model=model,
            dataset=dataset,
            config_file_list=config_file_list,
            config_dict=config_dict,
            saved=saved,
        )
    else:
        if world_size == -1:
            world_size = nproc
        import torch.multiprocessing as mp

        # Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
        # https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
        queue = mp.get_context("spawn").SimpleQueue()

        # Copy before updating: the original code mutated the caller's dict
        # in place, leaking distributed-only keys back to the caller.
        config_dict = dict(config_dict or {})
        config_dict.update(
            {
                "world_size": world_size,
                "ip": ip,
                "port": port,
                "nproc": nproc,
                "offset": group_offset,
            }
        )
        kwargs = {
            "config_dict": config_dict,
            "queue": queue,
        }

        mp.spawn(
            run_recboles,
            args=(model, dataset, config_file_list, kwargs),
            nprocs=nproc,
            join=True,
        )

        # Normally, there should be only one item in the queue
        res = None if queue.empty() else queue.get()
    return res
def run_recbole(
    model=None,
    dataset=None,
    config_file_list=None,
    config_dict=None,
    saved=True,
    queue=None,
):
    r"""A fast running api, which includes the complete process of training and
    testing a model on a specified dataset.

    Args:
        model (str, optional): Model name. Defaults to ``None``.
        dataset (str, optional): Dataset name. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment
            parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment
            parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        queue (torch.multiprocessing.Queue, optional): The queue used to pass the
            result to the main process. Defaults to ``None``.

    Returns:
        dict: Best validation score, metric direction flag, best validation
        result, and test result.
    """
    # configurations initialization
    config = Config(
        model=model,
        dataset=dataset,
        config_file_list=config_file_list,
        config_dict=config_dict,
    )
    # Seed immediately after config so dataset filtering/splitting below is
    # reproducible across runs.
    init_seed(config["seed"], config["reproducibility"])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(sys.argv)
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    # Re-seed with a per-rank offset so distributed workers initialize their
    # model parameters from different (but deterministic) seeds.
    init_seed(config["seed"] + config["local_rank"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    logger.info(model)

    transform = construct_transform(config)
    flops = get_flops(model, dataset, config["device"], logger, transform)
    logger.info(set_color("FLOPs", "blue") + f": {flops}")

    # trainer loading and initialization
    trainer = get_trainer(config["MODEL_TYPE"], config["model"])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=saved, show_progress=config["show_progress"]
    )

    # model evaluation
    # load_best_model=saved: only reload the checkpoint if one was written.
    test_result = trainer.evaluate(
        test_data, load_best_model=saved, show_progress=config["show_progress"]
    )

    environment_tb = get_environment(config)
    logger.info(
        "The running environment of this training is as follows:\n"
        + environment_tb.draw()
    )

    logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
    logger.info(set_color("test result", "yellow") + f": {test_result}")

    result = {
        "best_valid_score": best_valid_score,
        "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }

    # Tear down the distributed process group before returning results;
    # only relevant when running under multi-process training.
    if not config["single_spec"]:
        dist.destroy_process_group()

    # Only rank 0 reports back, so the spawning process sees a single result.
    if config["local_rank"] == 0 and queue is not None:
        queue.put(result)  # for multiprocessing, e.g., mp.spawn

    return result  # for the single process
def run_recboles(rank, *args):
    r"""Per-worker entry point used with ``torch.multiprocessing.spawn``.

    Args:
        rank (int): Local rank assigned to this worker by ``mp.spawn``.
        *args: Positional payload ``(model, dataset, config_file_list, kwargs)``,
            where the trailing element is a mutable mapping of keyword
            arguments forwarded to :func:`run_recbole`.

    Raises:
        ValueError: If the last positional argument is not a mutable mapping.
    """
    kwargs = args[-1]
    if not isinstance(kwargs, MutableMapping):
        raise ValueError(
            f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
        )
    # Guarantee a config dict exists, then stamp this worker's rank into it.
    config_dict = kwargs.setdefault("config_dict", {})
    config_dict["local_rank"] = rank
    model, dataset, config_file_list = args[:3]
    run_recbole(model, dataset, config_file_list, **kwargs)
def objective_function(config_dict=None, config_file_list=None, saved=True):
    r"""The default objective_function used in HyperTuning.

    Trains and evaluates one model configuration, reports the test metrics to
    the tuner, and returns a summary dict.

    Args:
        config_dict (dict, optional): Parameters dictionary used to modify
            experiment parameters. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify
            experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.

    Returns:
        dict: Model name, best validation score/result, metric direction flag,
        and the test result.
    """
    config = Config(config_dict=config_dict, config_file_list=config_file_list)
    init_seed(config["seed"], config["reproducibility"])
    logger = getLogger()
    # Detach every previously-installed handler so repeated tuning trials
    # don't accumulate duplicate log outputs.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
    init_logger(config)
    logging.basicConfig(level=logging.ERROR)
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    # Re-seed so model initialization is reproducible regardless of how much
    # randomness the data pipeline consumed.
    init_seed(config["seed"], config["reproducibility"])
    model_name = config["model"]
    model = get_model(model_name)(config, train_data._dataset)
    model = model.to(config["device"])
    trainer = get_trainer(config["MODEL_TYPE"], config["model"])(config, model)
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, verbose=False, saved=saved
    )
    test_result = trainer.evaluate(test_data, load_best_model=saved)
    # Surface the test metrics to the hyperparameter tuner.
    tune.report(**test_result)
    summary = {
        "model": model_name,
        "best_valid_score": best_valid_score,
        "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
        "test_result": test_result,
    }
    return summary
def load_data_and_model(model_file):
    r"""Load filtered dataset, split dataloaders and saved model.

    Args:
        model_file (str): The path of saved model file.

    Returns:
        tuple:
            - config (Config): An instance object of Config, which record
              parameter information in :attr:`model_file`.
            - model (AbstractRecommender): The model load from :attr:`model_file`.
            - dataset (Dataset): The filtered dataset.
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    import torch

    # NOTE(review): the checkpoint holds a pickled Config object, so this load
    # relies on full unpickling (torch.load's weights_only=False semantics) —
    # only load trusted checkpoint files. Also, without map_location a
    # GPU-saved checkpoint presumably cannot be loaded on a CPU-only machine;
    # confirm whether that case needs supporting.
    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # Rebuild the dataset and splits; seeded so filtering/splitting matches
    # what was used when the checkpoint was produced.
    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # Re-seed before model construction, then overwrite the fresh parameters
    # with the saved state.
    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    # "other_parameter" may be absent in older checkpoints; .get keeps that safe.
    model.load_other_parameter(checkpoint.get("other_parameter"))

    return config, model, dataset, train_data, valid_data, test_data