class AbstractTrainer(object):
    r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
    AbstractTrainer is an abstract class in which the fit() and evaluate() method should be implemented according
    to different training and evaluation strategies.
    """

    def __init__(self, config, model):
        self.config = config
        self.model = model
        if not config["single_spec"]:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            self.distributed_model = DistributedDataParallel(
                self.model, device_ids=[config["local_rank"]]
            )
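    # Distributed note (illustrative sketch, not library code): in multi-process
    # runs the plain model is wrapped exactly once here; self.model keeps the
    # SyncBatchNorm-converted module while self.distributed_model adds the DDP
    # gradient hooks. The equivalent manual steps, assuming the process group
    # has already been initialized elsewhere:
    #
    #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    #     model = DistributedDataParallel(model, device_ids=[local_rank])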
    def fit(self, train_data):
        r"""Train the model based on the train data."""
        raise NotImplementedError("Method [fit] should be implemented.")
    def evaluate(self, eval_data):
        r"""Evaluate the model based on the eval data."""
        raise NotImplementedError("Method [evaluate] should be implemented.")
    def set_reduce_hook(self):
        r"""Call the forward function of 'distributed_model' to register the gradient-reduce hook on each
        parameter of its module.
        """
        t = self.model.forward
        self.model.forward = lambda x: x
        self.distributed_model(torch.LongTensor([0]).to(self.device))
        self.model.forward = t
    def sync_grad_loss(self):
        r"""Ensure that every parameter contributes to the loss so that the gradient all-reduce stays in sync
        across nodes.
        """
        sync_loss = 0
        for params in self.model.parameters():
            sync_loss += torch.sum(params) * 0
        return sync_loss
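# Why sync_grad_loss multiplies by zero (illustrative note, not library code):
# ``torch.sum(p) * 0`` adds nothing to the loss value but keeps every parameter
# in the autograd graph, so DDP's gradient all-reduce fires for all parameters
# on every node even when a batch only exercises a subset of them.
#
#     p = torch.nn.Parameter(torch.ones(3))
#     loss = torch.sum(p) * 0      # value 0.0, but still depends on p
#     loss.backward()
#     assert torch.equal(p.grad, torch.zeros(3))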
class Trainer(AbstractTrainer):
    r"""The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines
    common functions for training and evaluation processes of most recommender system models, including fit(),
    evaluate(), resume_checkpoint() and some other features helpful for model training and evaluation.

    Generally speaking, this class can serve most recommender system models, if the training process of the model
    simply optimizes a single loss without involving any complex training strategies, such as adversarial
    learning, pre-training and so on.

    Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameter
    information for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
    `model` is the instantiated object of a Model Class.
    """

    def __init__(self, config, model):
        super(Trainer, self).__init__(config, model)

        self.logger = getLogger()
        self.tensorboard = get_tensorboard(self.logger)
        self.wandblogger = WandbLogger(config)
        self.learner = config["learner"]
        self.learning_rate = config["learning_rate"]
        self.epochs = config["epochs"]
        self.eval_step = min(config["eval_step"], self.epochs)
        self.stopping_step = config["stopping_step"]
        self.clip_grad_norm = config["clip_grad_norm"]
        self.valid_metric = config["valid_metric"].lower()
        self.valid_metric_bigger = config["valid_metric_bigger"]
        self.test_batch_size = config["eval_batch_size"]
        self.gpu_available = torch.cuda.is_available() and config["use_gpu"]
        self.device = config["device"]
        self.checkpoint_dir = config["checkpoint_dir"]
        self.enable_amp = config["enable_amp"]
        self.enable_scaler = torch.cuda.is_available() and config["enable_scaler"]
        ensure_dir(self.checkpoint_dir)
        saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time())
        self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
        self.weight_decay = config["weight_decay"]
        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf
        self.best_valid_result = None
        self.train_loss_dict = dict()
        self.optimizer = self._build_optimizer()
        self.eval_type = config["eval_type"]
        self.eval_collector = Collector(config)
        self.evaluator = Evaluator(config)
        self.item_tensor = None
        self.tot_item_num = None

    def _build_optimizer(self, **kwargs):
        r"""Init the Optimizer

        Args:
            params (torch.nn.Parameter, optional): The parameters to be optimized.
                Defaults to ``self.model.parameters()``.
            learner (str, optional): The name of used optimizer. Defaults to ``self.learner``.
            learning_rate (float, optional): Learning rate. Defaults to ``self.learning_rate``.
            weight_decay (float, optional): The L2 regularization weight. Defaults to ``self.weight_decay``.
        Returns:
            torch.optim: the optimizer
        """
        params = kwargs.pop("params", self.model.parameters())
        learner = kwargs.pop("learner", self.learner)
        learning_rate = kwargs.pop("learning_rate", self.learning_rate)
        weight_decay = kwargs.pop("weight_decay", self.weight_decay)
        if (
            self.config["reg_weight"]
            and weight_decay
            and weight_decay * self.config["reg_weight"] > 0
        ):
            self.logger.warning(
                "The parameters [weight_decay] and [reg_weight] are specified simultaneously, "
                "which may lead to double regularization."
            )

        if learner.lower() == "adam":
            optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "adamw":
            optimizer = optim.AdamW(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "sgd":
            optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "adagrad":
            optimizer = optim.Adagrad(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "rmsprop":
            optimizer = optim.RMSprop(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "sparse_adam":
            optimizer = optim.SparseAdam(params, lr=learning_rate)
            if weight_decay > 0:
                self.logger.warning(
                    "Sparse Adam does not support the argument [weight_decay]; it will be ignored."
                )
        else:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            optimizer = optim.Adam(params, lr=learning_rate)
        return optimizer

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        r"""Train the model in an epoch

        Args:
            train_data (DataLoader): The train data.
            epoch_idx (int): The current epoch id.
            loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
                :attr:`self.model.calculate_loss`. Defaults to ``None``.
            show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

        Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
            multiple parts and the model returns these multiple parts of loss instead of the sum of loss, it will
            return a tuple which includes the sum of loss in each part.
        """
        self.model.train()
        loss_func = loss_func or self.model.calculate_loss
        total_loss = None
        iter_data = (
            tqdm(
                train_data,
                total=len(train_data),
                ncols=100,
                desc=set_color(f"Train {epoch_idx:>5}", "pink"),
            )
            if show_progress
            else train_data
        )

        if not self.config["single_spec"] and train_data.shuffle:
            train_data.sampler.set_epoch(epoch_idx)

        scaler = amp.GradScaler(enabled=self.enable_scaler)
        for batch_idx, interaction in enumerate(iter_data):
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            sync_loss = 0
            if not self.config["single_spec"]:
                self.set_reduce_hook()
                sync_loss = self.sync_grad_loss()

            with torch.autocast(device_type=self.device.type, enabled=self.enable_amp):
                losses = loss_func(interaction)

            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = (
                    loss_tuple
                    if total_loss is None
                    else tuple(map(sum, zip(total_loss, loss_tuple)))
                )
            else:
                loss = losses
                total_loss = (
                    losses.item() if total_loss is None else total_loss + losses.item()
                )
            self._check_nan(loss)
            scaler.scale(loss + sync_loss).backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            scaler.step(self.optimizer)
            scaler.update()
            if self.gpu_available and show_progress:
                iter_data.set_postfix_str(
                    set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
                )
        return total_loss

    def _valid_epoch(self, valid_data, show_progress=False):
        r"""Valid the model with valid data

        Args:
            valid_data (DataLoader): the valid data.
            show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``.
        Returns:
            float: valid score
            dict: valid result
        """
        valid_result = self.evaluate(
            valid_data, load_best_model=False, show_progress=show_progress
        )
        valid_score = calculate_valid_score(valid_result, self.valid_metric)
        return valid_score, valid_result

    def _save_checkpoint(self, epoch, verbose=True, **kwargs):
        r"""Store the model parameters information and training information.

        Args:
            epoch (int): the current epoch id
        """
        if not self.config["single_spec"] and self.config["local_rank"] != 0:
            return
        saved_model_file = kwargs.pop("saved_model_file", self.saved_model_file)
        state = {
            "config": self.config,
            "epoch": epoch,
            "cur_step": self.cur_step,
            "best_valid_score": self.best_valid_score,
            "state_dict": self.model.state_dict(),
            "other_parameter": self.model.other_parameter(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, saved_model_file, pickle_protocol=4)
        if verbose:
            self.logger.info(
                set_color("Saving current", "blue") + f": {saved_model_file}"
            )
    def resume_checkpoint(self, resume_file):
        r"""Load the model parameters information and training information.

        Args:
            resume_file (file): the checkpoint file
        """
        resume_file = str(resume_file)
        self.saved_model_file = resume_file
        checkpoint = torch.load(resume_file, map_location=self.device)
        self.start_epoch = checkpoint["epoch"] + 1
        self.cur_step = checkpoint["cur_step"]
        self.best_valid_score = checkpoint["best_valid_score"]

        # load architecture params from checkpoint
        if checkpoint["config"]["model"].lower() != self.config["model"].lower():
            self.logger.warning(
                "Architecture configuration given in config file is different from that of checkpoint. "
                "This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.load_other_parameter(checkpoint.get("other_parameter"))

        # load optimizer state from checkpoint only when optimizer type is not changed
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        message_output = "Checkpoint loaded. Resume training from epoch {}".format(
            self.start_epoch
        )
        self.logger.info(message_output)
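    # Usage sketch (illustrative): resuming an interrupted run; the checkpoint
    # path below is hypothetical.
    #
    #     trainer = Trainer(config, model)
    #     trainer.resume_checkpoint("saved/BPR-Jun-01-2023_00-00-00.pth")
    #     trainer.fit(train_data, valid_data)  # continues from start_epoch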
    def _check_nan(self, loss):
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
        des = self.config["loss_decimal_place"] or 4
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        if isinstance(losses, tuple):
            des = set_color("train_loss%d", "blue") + ": %." + str(des) + "f"
            train_loss_output += ", ".join(
                des % (idx + 1, loss) for idx, loss in enumerate(losses)
            )
        else:
            des = "%." + str(des) + "f"
            train_loss_output += set_color("train loss", "blue") + ": " + des % losses
        return train_loss_output + "]"

    def _add_train_loss_to_tensorboard(self, epoch_idx, losses, tag="Loss/Train"):
        if isinstance(losses, tuple):
            for idx, loss in enumerate(losses):
                self.tensorboard.add_scalar(tag + str(idx), loss, epoch_idx)
        else:
            self.tensorboard.add_scalar(tag, losses, epoch_idx)

    def _add_hparam_to_tensorboard(self, best_valid_result):
        # base hparam
        hparam_dict = {
            "learner": self.config["learner"],
            "learning_rate": self.config["learning_rate"],
            "train_batch_size": self.config["train_batch_size"],
        }
        # unrecorded parameter
        unrecorded_parameter = {
            parameter
            for parameters in self.config.parameters.values()
            for parameter in parameters
        }.union({"model", "dataset", "config_files", "device"})
        # other model-specific hparam
        hparam_dict.update(
            {
                para: val
                for para, val in self.config.final_config_dict.items()
                if para not in unrecorded_parameter
            }
        )
        for k in hparam_dict:
            if hparam_dict[k] is not None and not isinstance(
                hparam_dict[k], (bool, str, float, int)
            ):
                hparam_dict[k] = str(hparam_dict[k])

        self.tensorboard.add_hparams(
            hparam_dict, {"hparam/best_valid_result": best_valid_result}
        )
    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
                If it's None, the early_stopping is invalid.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                Includes (epoch_idx, valid_score) input arguments.

        Returns:
            (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1, verbose=verbose)

        self.eval_collector.data_collect(train_data)
        if self.config["train_neg_sample_args"].get("dynamic", False):
            train_data.get_model(self.model)
        valid_step = 0

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(
                train_data, epoch_idx, show_progress=show_progress
            )
            self.train_loss_dict[epoch_idx] = (
                sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            )
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss
            )
            if verbose:
                self.logger.info(train_loss_output)
            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
            self.wandblogger.log_metrics(
                {"epoch": epoch_idx, "train_loss": train_loss, "train_step": epoch_idx},
                head="train",
            )

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx, verbose=verbose)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(
                    valid_data, show_progress=show_progress
                )

                (
                    self.best_valid_score,
                    self.cur_step,
                    stop_flag,
                    update_flag,
                ) = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger,
                )
                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("valid_score", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = (
                    set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
                )
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                self.tensorboard.add_scalar("Valid_score", valid_score, epoch_idx)
                self.wandblogger.log_metrics(
                    {**valid_result, "valid_step": valid_step}, head="valid"
                )

                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx, verbose=verbose)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    stop_output = "Finished training, best eval result in epoch %d" % (
                        epoch_idx - self.cur_step * self.eval_step
                    )
                    if verbose:
                        self.logger.info(stop_output)
                    break

                valid_step += 1

        self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result
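    # Usage sketch (illustrative): a typical training call; ``train_data`` and
    # ``valid_data`` come from RecBole's data preparation. With valid_data
    # supplied, early stopping is governed by ``stopping_step``.
    #
    #     best_score, best_result = trainer.fit(
    #         train_data, valid_data, saved=True, show_progress=True
    #     )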
    def _full_sort_batch_eval(self, batched_data):
        interaction, history_index, positive_u, positive_i = batched_data
        try:
            # Note: interaction without item ids
            scores = self.model.full_sort_predict(interaction.to(self.device))
        except NotImplementedError:
            inter_len = len(interaction)
            new_inter = interaction.to(self.device).repeat_interleave(
                self.tot_item_num
            )
            batch_size = len(new_inter)
            new_inter.update(self.item_tensor.repeat(inter_len))
            if batch_size <= self.test_batch_size:
                scores = self.model.predict(new_inter)
            else:
                scores = self._spilt_predict(new_inter, batch_size)

        scores = scores.view(-1, self.tot_item_num)
        scores[:, 0] = -np.inf
        if history_index is not None:
            scores[history_index] = -np.inf
        return interaction, scores, positive_u, positive_i

    def _neg_sample_batch_eval(self, batched_data):
        interaction, row_idx, positive_u, positive_i = batched_data
        batch_size = interaction.length
        if batch_size <= self.test_batch_size:
            origin_scores = self.model.predict(interaction.to(self.device))
        else:
            origin_scores = self._spilt_predict(interaction, batch_size)

        if self.config["eval_type"] == EvaluatorType.VALUE:
            return interaction, origin_scores, positive_u, positive_i
        elif self.config["eval_type"] == EvaluatorType.RANKING:
            col_idx = interaction[self.config["ITEM_ID_FIELD"]]
            batch_user_num = positive_u[-1] + 1
            scores = torch.full(
                (batch_user_num, self.tot_item_num), -np.inf, device=self.device
            )
            scores[row_idx, col_idx] = origin_scores
            return interaction, scores, positive_u, positive_i
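    # Note on the -inf masking above (illustrative): a score of -inf excludes an
    # item from any top-k ranking, which is how the padding item (column 0) and
    # each user's training history are removed at evaluation time.
    #
    #     scores = torch.tensor([[0.9, 0.5, 0.7]])
    #     scores[0, 0] = -float("inf")         # mask the padding item
    #     torch.topk(scores, k=2).indices      # -> tensor([[2, 1]])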
    @torch.no_grad()
    def evaluate(
        self, eval_data, load_best_model=True, model_file=None, show_progress=False
    ):
        r"""Evaluate the model based on the eval data.

        Args:
            eval_data (DataLoader): the eval data
            load_best_model (bool, optional): whether load the best model in the training process, default: True.
                It should be set True, if users want to test the model after training.
            model_file (str, optional): the saved model file, default: None. If users want to test the previously
                trained model file, they can set this parameter.
            show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``.

        Returns:
            collections.OrderedDict: eval result, key is the eval metric and value is the corresponding metric value.
        """
        if not eval_data:
            return

        if load_best_model:
            checkpoint_file = model_file or self.saved_model_file
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.model.load_state_dict(checkpoint["state_dict"])
            self.model.load_other_parameter(checkpoint.get("other_parameter"))
            message_output = "Loading model structure and parameters from {}".format(
                checkpoint_file
            )
            self.logger.info(message_output)

        self.model.eval()

        if isinstance(eval_data, FullSortEvalDataLoader):
            eval_func = self._full_sort_batch_eval
            if self.item_tensor is None:
                self.item_tensor = eval_data._dataset.get_item_feature().to(
                    self.device
                )
        else:
            eval_func = self._neg_sample_batch_eval
        if self.config["eval_type"] == EvaluatorType.RANKING:
            self.tot_item_num = eval_data._dataset.item_num

        iter_data = (
            tqdm(
                eval_data,
                total=len(eval_data),
                ncols=100,
                desc=set_color(f"Evaluate ", "pink"),
            )
            if show_progress
            else eval_data
        )

        num_sample = 0
        for batch_idx, batched_data in enumerate(iter_data):
            num_sample += len(batched_data)
            interaction, scores, positive_u, positive_i = eval_func(batched_data)
            if self.gpu_available and show_progress:
                iter_data.set_postfix_str(
                    set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
                )
            self.eval_collector.eval_batch_collect(
                scores, interaction, positive_u, positive_i
            )
        self.eval_collector.model_collect(self.model)
        struct = self.eval_collector.get_data_struct()
        result = self.evaluator.evaluate(struct)
        if not self.config["single_spec"]:
            result = self._map_reduce(result, num_sample)
        self.wandblogger.log_eval_metrics(result, head="eval")
        return result
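# Usage sketch (illustrative): evaluating a previously trained model from an
# explicit checkpoint file (the path below is hypothetical).
#
#     result = trainer.evaluate(
#         test_data, load_best_model=True, model_file="saved/BPR-Jun-01-2023.pth"
#     )
#     print(result)  # OrderedDict mapping each metric to its value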
class KGTrainer(Trainer):
    r"""KGTrainer is designed for knowledge-aware recommendation methods. Some of these models need to train the
    recommendation-related task and the knowledge-related task alternately.
    """

    def __init__(self, config, model):
        super(KGTrainer, self).__init__(config, model)

        self.train_rec_step = config["train_rec_step"]
        self.train_kg_step = config["train_kg_step"]

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        if self.train_rec_step is None or self.train_kg_step is None:
            interaction_state = KGDataLoaderState.RSKG
        elif (
            epoch_idx % (self.train_rec_step + self.train_kg_step)
            < self.train_rec_step
        ):
            interaction_state = KGDataLoaderState.RS
        else:
            interaction_state = KGDataLoaderState.KG
        if not self.config["single_spec"]:
            train_data.knowledge_shuffle(epoch_idx)
        train_data.set_mode(interaction_state)
        if interaction_state in [KGDataLoaderState.RSKG, KGDataLoaderState.RS]:
            return super()._train_epoch(
                train_data, epoch_idx, show_progress=show_progress
            )
        elif interaction_state in [KGDataLoaderState.KG]:
            return super()._train_epoch(
                train_data,
                epoch_idx,
                loss_func=self.model.calculate_kg_loss,
                show_progress=show_progress,
            )
        return None
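# Schedule sketch (illustrative): with train_rec_step=5 and train_kg_step=1,
# the test ``epoch_idx % 6 < 5`` selects the recommendation task, so the epochs
# run RS, RS, RS, RS, RS, KG, RS, ... If either step is None, every epoch
# trains both tasks jointly (KGDataLoaderState.RSKG).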
class KGATTrainer(Trainer):
    r"""KGATTrainer is designed for KGAT, which is a knowledge-aware recommendation method."""

    def __init__(self, config, model):
        super(KGATTrainer, self).__init__(config, model)

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        # train rs
        if not self.config["single_spec"]:
            train_data.knowledge_shuffle(epoch_idx)
        train_data.set_mode(KGDataLoaderState.RS)
        rs_total_loss = super()._train_epoch(
            train_data, epoch_idx, show_progress=show_progress
        )

        # train kg
        train_data.set_mode(KGDataLoaderState.KG)
        kg_total_loss = super()._train_epoch(
            train_data,
            epoch_idx,
            loss_func=self.model.calculate_kg_loss,
            show_progress=show_progress,
        )

        # update A
        self.model.eval()
        with torch.no_grad():
            self.model.update_attentive_A()

        return rs_total_loss, kg_total_loss
class PretrainTrainer(Trainer):
    r"""PretrainTrainer is designed for pre-training.
    It can be inherited by the trainer which needs pre-training and fine-tuning.
    """

    def __init__(self, config, model):
        super(PretrainTrainer, self).__init__(config, model)
        self.pretrain_epochs = self.config["pretrain_epochs"]
        self.save_step = self.config["save_step"]
    def save_pretrained_model(self, epoch, saved_model_file):
        r"""Store the model parameters information and training information.

        Args:
            epoch (int): the current epoch id
            saved_model_file (str): file name for saved pretrained model
        """
        state = {
            "config": self.config,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "other_parameter": self.model.other_parameter(),
        }
        torch.save(state, saved_model_file)
        self.saved_model_file = saved_model_file
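    # Usage sketch (illustrative): a pre-training loop would typically call
    # save_pretrained_model every ``save_step`` epochs; the file-name pattern
    # below is hypothetical.
    #
    #     if (epoch_idx + 1) % self.save_step == 0:
    #         file = os.path.join(
    #             self.checkpoint_dir, f"{self.config['model']}-{epoch_idx + 1}.pth"
    #         )
    #         self.save_pretrained_model(epoch_idx, file)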
class S3RecTrainer(PretrainTrainer):
    r"""S3RecTrainer is designed for S3Rec, which is a self-supervised learning based sequential recommender.
    It includes two training stages: pre-training and fine-tuning.
    """

    def __init__(self, config, model):
        super(S3RecTrainer, self).__init__(config, model)
    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        if self.model.train_stage == "pretrain":
            return self.pretrain(train_data, verbose, show_progress)
        elif self.model.train_stage == "finetune":
            return super().fit(
                train_data, valid_data, verbose, saved, show_progress, callback_fn
            )
        else:
            raise ValueError(
                "Please make sure that the 'train_stage' is 'pretrain' or 'finetune'!"
            )
class MKRTrainer(Trainer):
    r"""MKRTrainer is designed for MKR, which is a knowledge-aware recommendation method."""

    def __init__(self, config, model):
        super(MKRTrainer, self).__init__(config, model)
        self.kge_interval = config["kge_interval"]

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        rs_total_loss, kg_total_loss = 0.0, 0.0

        # train rs
        self.logger.info("Train RS")
        train_data.set_mode(KGDataLoaderState.RS)
        rs_total_loss = super()._train_epoch(
            train_data,
            epoch_idx,
            loss_func=self.model.calculate_rs_loss,
            show_progress=show_progress,
        )

        # train kg
        if epoch_idx % self.kge_interval == 0:
            self.logger.info("Train KG")
            train_data.set_mode(KGDataLoaderState.KG)
            kg_total_loss = super()._train_epoch(
                train_data,
                epoch_idx,
                loss_func=self.model.calculate_kg_loss,
                show_progress=show_progress,
            )

        return rs_total_loss, kg_total_loss
class TraditionalTrainer(Trainer):
    r"""TraditionalTrainer is designed for traditional models (e.g., Pop, ItemKNN); it sets the number of epochs
    to 1 regardless of the config."""

    def __init__(self, config, model):
        super(TraditionalTrainer, self).__init__(config, model)
        self.epochs = 1  # Set the epoch to 1 when running memory-based models
class DecisionTreeTrainer(AbstractTrainer):
    """DecisionTreeTrainer is designed for DecisionTree model."""

    def __init__(self, config, model):
        super(DecisionTreeTrainer, self).__init__(config, model)

        self.logger = getLogger()
        self.tensorboard = get_tensorboard(self.logger)
        self.label_field = config["LABEL_FIELD"]
        self.convert_token_to_onehot = self.config["convert_token_to_onehot"]

        # evaluator
        self.eval_type = config["eval_type"]
        self.epochs = config["epochs"]
        self.eval_step = min(config["eval_step"], self.epochs)
        self.valid_metric = config["valid_metric"].lower()
        self.eval_collector = Collector(config)
        self.evaluator = Evaluator(config)

        # model saved
        self.checkpoint_dir = config["checkpoint_dir"]
        ensure_dir(self.checkpoint_dir)
        temp_file = "{}-{}-temp.pth".format(self.config["model"], get_local_time())
        self.temp_file = os.path.join(self.checkpoint_dir, temp_file)

        temp_best_file = "{}-{}-temp-best.pth".format(
            self.config["model"], get_local_time()
        )
        self.temp_best_file = os.path.join(self.checkpoint_dir, temp_best_file)

        saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time())
        self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)

        self.stopping_step = config["stopping_step"]
        self.valid_metric_bigger = config["valid_metric_bigger"]
        self.cur_step = 0
        self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf
        self.best_valid_result = None

    def _interaction_to_sparse(self, dataloader):
        r"""Convert data format from interaction to sparse or numpy

        Args:
            dataloader (DecisionTreeDataLoader): DecisionTreeDataLoader dataloader.
        Returns:
            cur_data (sparse or numpy): data.
            interaction_np[self.label_field] (numpy): label.
        """
        interaction = dataloader._dataset[:]
        interaction_np = interaction.numpy()
        cur_data = np.array([])
        columns = []
        for key, value in interaction_np.items():
            value = np.resize(value, (value.shape[0], 1))
            if key != self.label_field:
                columns.append(key)
                if cur_data.shape[0] == 0:
                    cur_data = value
                else:
                    cur_data = np.hstack((cur_data, value))

        if self.convert_token_to_onehot:
            from scipy import sparse
            from scipy.sparse import dok_matrix

            convert_col_list = dataloader._dataset.convert_col_list
            hash_count = dataloader._dataset.hash_count

            new_col = cur_data.shape[1] - len(convert_col_list)
            for key, values in hash_count.items():
                new_col = new_col + values
            onehot_data = dok_matrix((cur_data.shape[0], new_col))

            cur_j = 0
            new_j = 0
            for key in columns:
                if key in convert_col_list:
                    for i in range(cur_data.shape[0]):
                        onehot_data[i, int(new_j + cur_data[i, cur_j])] = 1
                    new_j = new_j + hash_count[key] - 1
                else:
                    for i in range(cur_data.shape[0]):
                        onehot_data[i, new_j] = cur_data[i, cur_j]
                cur_j = cur_j + 1
                new_j = new_j + 1

            cur_data = sparse.csc_matrix(onehot_data)

        return cur_data, interaction_np[self.label_field]

    def _interaction_to_lib_datatype(self, dataloader):
        pass

    def _valid_epoch(self, valid_data):
        r"""

        Args:
            valid_data (DecisionTreeDataLoader): DecisionTreeDataLoader, which is the same with GeneralDataLoader.
        """
        valid_result = self.evaluate(valid_data, load_best_model=False)
        valid_score = calculate_valid_score(valid_result, self.valid_metric)
        return valid_score, valid_result

    def _save_checkpoint(self, epoch):
        r"""Store the model parameters information and training information.

        Args:
            epoch (int): the current epoch id
        """
        state = {
            "config": self.config,
            "epoch": epoch,
            "cur_step": self.cur_step,
            "best_valid_score": self.best_valid_score,
            "state_dict": self.temp_best_file,
            "other_parameter": None,
        }
        torch.save(state, self.saved_model_file)
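    # One-hot sketch (illustrative) for _interaction_to_sparse above: a token
    # column with hash_count[key] == 3 expands into a block of 3 columns, and a
    # cell value of 2 sets the third column of that block:
    #
    #     cur_data[i, cur_j] == 2
    #     onehot_data[i, new_j + 2] = 1   # new_j is the block's first column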
    def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False):
        for epoch_idx in range(self.epochs):
            self._train_at_once(train_data, valid_data)

            if (epoch_idx + 1) % self.eval_step == 0:
                # evaluate
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(valid_data)

                (
                    self.best_valid_score,
                    self.cur_step,
                    stop_flag,
                    update_flag,
                ) = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger,
                )
                valid_end_time = time()

                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("valid_score", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = (
                    set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
                )
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                self.tensorboard.add_scalar("Valid_score", valid_score, epoch_idx)

                if update_flag:
                    if saved:
                        self.model.save_model(self.temp_best_file)
                        self._save_checkpoint(epoch_idx)
                    self.best_valid_result = valid_result

                if stop_flag:
                    stop_output = "Finished training, best eval result in epoch %d" % (
                        epoch_idx - self.cur_step * self.eval_step
                    )
                    if self.temp_file:
                        os.remove(self.temp_file)
                    if verbose:
                        self.logger.info(stop_output)
                    break

        return self.best_valid_score, self.best_valid_result
class XGBoostTrainer(DecisionTreeTrainer):
    """XGBoostTrainer is designed for XGBoost."""

    def __init__(self, config, model):
        super(XGBoostTrainer, self).__init__(config, model)

        self.xgb = __import__("xgboost")
        self.boost_model = config["xgb_model"]
        self.silent = config["xgb_silent"]
        self.nthread = config["xgb_nthread"]

        # train params
        self.params = config["xgb_params"]
        self.num_boost_round = config["xgb_num_boost_round"]
        self.evals = ()
        self.early_stopping_rounds = config["xgb_early_stopping_rounds"]
        self.evals_result = {}
        self.verbose_eval = config["xgb_verbose_eval"]
        self.callbacks = None
        self.deval = None
        self.eval_pred = self.eval_true = None

    def _interaction_to_lib_datatype(self, dataloader):
        r"""Convert data format from interaction to DMatrix

        Args:
            dataloader (DecisionTreeDataLoader): xgboost dataloader.
        Returns:
            DMatrix: Data in the form of 'DMatrix'.
        """
        data, label = self._interaction_to_sparse(dataloader)
        return self.xgb.DMatrix(
            data=data, label=label, silent=self.silent, nthread=self.nthread
        )

    def _train_at_once(self, train_data, valid_data):
        r"""

        Args:
            train_data (DecisionTreeDataLoader): DecisionTreeDataLoader, which is the same with GeneralDataLoader.
            valid_data (DecisionTreeDataLoader): DecisionTreeDataLoader, which is the same with GeneralDataLoader.
        """
        self.dtrain = self._interaction_to_lib_datatype(train_data)
        self.dvalid = self._interaction_to_lib_datatype(valid_data)
        self.evals = [(self.dtrain, "train"), (self.dvalid, "valid")]
        self.model = self.xgb.train(
            self.params,
            self.dtrain,
            self.num_boost_round,
            self.evals,
            early_stopping_rounds=self.early_stopping_rounds,
            evals_result=self.evals_result,
            verbose_eval=self.verbose_eval,
            xgb_model=self.boost_model,
            callbacks=self.callbacks,
        )
        self.model.save_model(self.temp_file)
        self.boost_model = self.temp_file
class LightGBMTrainer(DecisionTreeTrainer):
    """LightGBMTrainer is designed for LightGBM."""

    def __init__(self, config, model):
        super(LightGBMTrainer, self).__init__(config, model)

        self.lgb = __import__("lightgbm")

        # train params
        self.params = config["lgb_params"]
        self.num_boost_round = config["lgb_num_boost_round"]
        self.evals = ()
        self.deval_data = self.deval_label = None
        self.eval_pred = self.eval_true = None

    def _interaction_to_lib_datatype(self, dataloader):
        r"""Convert data format from interaction to Dataset

        Args:
            dataloader (DecisionTreeDataLoader): lightgbm dataloader.
        Returns:
            dataset (lgb.Dataset): Data in the form of 'lgb.Dataset'.
        """
        data, label = self._interaction_to_sparse(dataloader)
        return self.lgb.Dataset(data=data, label=label)

    def _train_at_once(self, train_data, valid_data):
        r"""

        Args:
            train_data (DecisionTreeDataLoader): DecisionTreeDataLoader, which is the same with GeneralDataLoader.
            valid_data (DecisionTreeDataLoader): DecisionTreeDataLoader, which is the same with GeneralDataLoader.
        """
        self.dtrain = self._interaction_to_lib_datatype(train_data)
        self.dvalid = self._interaction_to_lib_datatype(valid_data)
        self.evals = [self.dtrain, self.dvalid]
        self.model = self.lgb.train(
            self.params, self.dtrain, self.num_boost_round, self.evals
        )
        self.model.save_model(self.temp_file)
        self.boost_model = self.temp_file
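# Config sketch (illustrative): the lgb_* keys read above can be set in the
# YAML config; the values here are hypothetical examples, not defaults.
#
#     lgb_params:
#       objective: binary
#       metric: auc
#       learning_rate: 0.05
#     lgb_num_boost_round: 300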
class RaCTTrainer(PretrainTrainer):
    r"""RaCTTrainer is designed for RaCT, which is an actor-critic reinforcement learning based general
    recommender. It includes three training stages: actor pre-training, critic pre-training and actor-critic
    training.
    """

    def __init__(self, config, model):
        super(RaCTTrainer, self).__init__(config, model)
    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        if self.model.train_stage == "actor_pretrain":
            return self.pretrain(train_data, verbose, show_progress)
        elif self.model.train_stage == "critic_pretrain":
            return self.pretrain(train_data, verbose, show_progress)
        elif self.model.train_stage == "finetune":
            return super().fit(
                train_data, valid_data, verbose, saved, show_progress, callback_fn
            )
        else:
            raise ValueError(
                "Please make sure that the 'train_stage' is "
                "'actor_pretrain', 'critic_pretrain' or 'finetune'!"
            )
class RecVAETrainer(Trainer):
    r"""RecVAETrainer is designed for RecVAE, which is a general recommender."""

    def __init__(self, config, model):
        super(RecVAETrainer, self).__init__(config, model)
        self.n_enc_epochs = config["n_enc_epochs"]
        self.n_dec_epochs = config["n_dec_epochs"]

        self.optimizer_encoder = self._build_optimizer(
            params=self.model.encoder.parameters()
        )
        self.optimizer_decoder = self._build_optimizer(
            params=self.model.decoder.parameters()
        )

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        self.optimizer = self.optimizer_encoder
        encoder_loss_func = lambda data: self.model.calculate_loss(
            data, encoder_flag=True
        )
        for epoch in range(self.n_enc_epochs):
            super()._train_epoch(
                train_data,
                epoch_idx,
                loss_func=encoder_loss_func,
                show_progress=show_progress,
            )

        self.model.update_prior()
        loss = 0.0
        self.optimizer = self.optimizer_decoder
        decoder_loss_func = lambda data: self.model.calculate_loss(
            data, encoder_flag=False
        )
        for epoch in range(self.n_dec_epochs):
            loss += super()._train_epoch(
                train_data,
                epoch_idx,
                loss_func=decoder_loss_func,
                show_progress=show_progress,
            )
        return loss
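# Pattern note (illustrative): RecVAE alternates optimization by rebinding
# self.optimizer before delegating to the parent's _train_epoch, so the shared
# loop steps whichever optimizer is currently bound. The same pattern in brief:
#
#     for opt, loss_fn in ((opt_enc, enc_loss), (opt_dec, dec_loss)):
#         self.optimizer = opt
#         super()._train_epoch(train_data, epoch_idx, loss_func=loss_fn)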
class NCLTrainer(Trainer):
    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data.
            valid_data (DataLoader, optional): the valid data, default: None.
                If it's None, the early_stopping is invalid.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                Includes (epoch_idx, valid_score) input arguments.

        Returns:
            (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1)

        self.eval_collector.data_collect(train_data)

        for epoch_idx in range(self.start_epoch, self.epochs):
            # only difference from the original trainer: run an E-step every
            # num_m_step epochs
            if epoch_idx % self.num_m_step == 0:
                self.logger.info("Running E-step!")
                self.model.e_step()
            # train
            training_start_time = time()
            train_loss = self._train_epoch(
                train_data, epoch_idx, show_progress=show_progress
            )
            self.train_loss_dict[epoch_idx] = (
                sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            )
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss
            )
            if verbose:
                self.logger.info(train_loss_output)
            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = (
                        set_color("Saving current", "blue")
                        + ": %s" % self.saved_model_file
                    )
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(
                    valid_data, show_progress=show_progress
                )
                (
                    self.best_valid_score,
                    self.cur_step,
                    stop_flag,
                    update_flag,
                ) = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger,
                )
                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("valid_score", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = (
                    set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
                )
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                self.tensorboard.add_scalar("Valid_score", valid_score, epoch_idx)

                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx)
                        update_output = (
                            set_color("Saving current best", "blue")
                            + ": %s" % self.saved_model_file
                        )
                        if verbose:
                            self.logger.info(update_output)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    stop_output = "Finished training, best eval result in epoch %d" % (
                        epoch_idx - self.cur_step * self.eval_step
                    )
                    if verbose:
                        self.logger.info(stop_output)
                    break

        self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result
    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        r"""Train the model in an epoch

        Args:
            train_data (DataLoader): The train data.
            epoch_idx (int): The current epoch id.
            loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
                :attr:`self.model.calculate_loss`. Defaults to ``None``.
            show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

        Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
            multiple parts and the model returns these multiple parts of loss instead of the sum of loss, it will
            return a tuple which includes the sum of loss in each part.
        """
        self.model.train()
        loss_func = loss_func or self.model.calculate_loss
        total_loss = None
        iter_data = (
            tqdm(
                train_data,
                total=len(train_data),
                ncols=100,
                desc=set_color(f"Train {epoch_idx:>5}", "pink"),
            )
            if show_progress
            else train_data
        )

        scaler = amp.GradScaler(enabled=self.enable_scaler)
        if not self.config["single_spec"] and train_data.shuffle:
            train_data.sampler.set_epoch(epoch_idx)

        for batch_idx, interaction in enumerate(iter_data):
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            sync_loss = 0
            if not self.config["single_spec"]:
                self.set_reduce_hook()
                sync_loss = self.sync_grad_loss()

            with amp.autocast(enabled=self.enable_amp):
                losses = loss_func(interaction)

            if isinstance(losses, tuple):
                if epoch_idx < self.config["warm_up_step"]:
                    losses = losses[:-1]
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = (
                    loss_tuple
                    if total_loss is None
                    else tuple(map(sum, zip(total_loss, loss_tuple)))
                )
            else:
                loss = losses
                total_loss = (
                    losses.item() if total_loss is None else total_loss + losses.item()
                )
            self._check_nan(loss)
            scaler.scale(loss + sync_loss).backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            scaler.step(self.optimizer)
            scaler.update()
            if self.gpu_available and show_progress:
                iter_data.set_postfix_str(
                    set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
                )
        return total_loss