recbole.trainer.trainer

class recbole.trainer.trainer.AbstractTrainer(config, model)

Bases: object

The Trainer class manages the training and evaluation processes of recommender system models. AbstractTrainer is an abstract class whose fit() and evaluate() methods should be implemented by subclasses according to their training and evaluation strategies.

evaluate(eval_data)

Evaluate the model based on the eval data.

fit(train_data)

Train the model based on the train data.

set_reduce_hook()

Call the forward function of ‘distributed_model’ to apply the gradient-reduce hook to each parameter of its module.

sync_grad_loss()

Ensure that every parameter participates in the loss computation, so that gradient reduction stays synchronized across all nodes.
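To illustrate the contract that AbstractTrainer describes, here is a minimal, self-contained sketch of the pattern: a base class whose fit() and evaluate() raise NotImplementedError, and a toy subclass that implements them. The names ToyAbstractTrainer and MeanTrainer, and the use of plain lists of floats as "data", are hypothetical stand-ins; this is not RecBole's code and does not use RecBole's DataLoader.

```python
class ToyAbstractTrainer:
    """Hypothetical stand-in for an abstract trainer: subclasses supply fit/evaluate."""

    def __init__(self, config, model):
        self.config = config
        self.model = model

    def fit(self, train_data):
        raise NotImplementedError("fit() must be implemented by a subclass")

    def evaluate(self, eval_data):
        raise NotImplementedError("evaluate() must be implemented by a subclass")


class MeanTrainer(ToyAbstractTrainer):
    """Toy strategy: the 'model' is a dict holding one number, the mean training target."""

    def fit(self, train_data):
        self.model["mean"] = sum(train_data) / len(train_data)
        return self.model["mean"]

    def evaluate(self, eval_data):
        # Mean absolute error of predicting the stored mean for every target.
        errors = [abs(y - self.model["mean"]) for y in eval_data]
        return {"mae": sum(errors) / len(errors)}


trainer = MeanTrainer(config={}, model={})
trainer.fit([1.0, 2.0, 3.0])
print(trainer.evaluate([2.0, 4.0]))  # {'mae': 1.0}
```

Each concrete trainer in this module plays the role of MeanTrainer here: it keeps the same fit()/evaluate() entry points while swapping in its own training strategy.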

class recbole.trainer.trainer.DecisionTreeTrainer(config, model)

Bases: recbole.trainer.trainer.AbstractTrainer

DecisionTreeTrainer is designed for decision-tree models.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)

Evaluate the model based on the eval data.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False)

Train the model based on the train data.

class recbole.trainer.trainer.KGATTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

KGATTrainer is designed for KGAT, which is a knowledge-aware recommendation method.

class recbole.trainer.trainer.KGTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

KGTrainer is designed for knowledge-aware recommendation methods. Some of these models need to alternate between training the recommendation task and the knowledge-related task.
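The alternating schedule can be pictured with a small pure-Python sketch that decides, epoch by epoch, which task to train. The function name task_for_epoch and its parameters rec_steps and kg_steps (how many consecutive epochs each task gets per cycle) are hypothetical; consult the RecBole configuration documentation for the options KGTrainer actually reads.

```python
def task_for_epoch(epoch_idx: int, rec_steps: int, kg_steps: int) -> str:
    """Return which task to train in a given epoch under a round-robin schedule.

    Hypothetical sketch: within each cycle of rec_steps + kg_steps epochs
    (assumed > 0), the first rec_steps epochs train the recommendation task
    and the remaining kg_steps epochs train the knowledge task.
    """
    cycle = rec_steps + kg_steps
    return "rec" if epoch_idx % cycle < rec_steps else "kg"


schedule = [task_for_epoch(e, rec_steps=2, kg_steps=1) for e in range(6)]
print(schedule)  # ['rec', 'rec', 'kg', 'rec', 'rec', 'kg']
```

A trainer following this idea would call task_for_epoch at the start of each epoch and then iterate over the matching data and loss.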

class recbole.trainer.trainer.LightGBMTrainer(config, model)

Bases: recbole.trainer.trainer.DecisionTreeTrainer

LightGBMTrainer is designed for LightGBM.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)

Evaluate the model based on the eval data.

class recbole.trainer.trainer.MKRTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

MKRTrainer is designed for MKR, which is a knowledge-aware recommendation method.

class recbole.trainer.trainer.NCLTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

NCLTrainer is designed for NCL.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data.

  • valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.

  • verbose (bool, optional) – whether to write training and evaluation information to the logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Whether to show the progress of training and evaluation epochs. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of an epoch; it is called with (epoch_idx, valid_score) as arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)

class recbole.trainer.trainer.PretrainTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

PretrainTrainer is designed for pre-training. Trainers that need both pre-training and fine-tuning can inherit from it.

pretrain(train_data, verbose=True, show_progress=False)

Pre-train the model based on the train data.

save_pretrained_model(epoch, saved_model_file)

Store the model parameters and training information.

Parameters
  • epoch (int) – the current epoch id

  • saved_model_file (str) – file name for the saved pre-trained model
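As a rough picture of what storing "model parameters and training information" can look like, the sketch below bundles the epoch id and a parameter dictionary into one checkpoint and writes it to saved_model_file. It uses pickle and a plain dict as stand-ins for PyTorch's torch.save and a model state_dict; the function names and checkpoint keys are assumptions, not RecBole's actual file format.

```python
import os
import pickle
import tempfile


def save_pretrained_checkpoint(epoch, params, saved_model_file):
    """Hypothetical stand-in: persist the epoch id together with model parameters."""
    checkpoint = {"epoch": epoch, "state_dict": dict(params)}
    with open(saved_model_file, "wb") as f:
        pickle.dump(checkpoint, f)


def load_checkpoint(saved_model_file):
    """Read back a checkpoint written by save_pretrained_checkpoint."""
    with open(saved_model_file, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "pretrain-10.pkl")
save_pretrained_checkpoint(10, {"item_embedding": [0.1, 0.2]}, path)
print(load_checkpoint(path)["epoch"])  # 10
```

Keeping the epoch next to the weights lets a later fine-tuning run tell which pre-training checkpoint it started from.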

class recbole.trainer.trainer.RaCTTrainer(config, model)

Bases: recbole.trainer.trainer.PretrainTrainer

RaCTTrainer is designed for RaCT, which is an actor-critic, reinforcement-learning-based general recommender. It includes three training stages: actor pre-training, critic pre-training and actor-critic training.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data

  • valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.

  • verbose (bool, optional) – whether to write training and evaluation information to the logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Whether to show the progress of training and evaluation epochs. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of an epoch; it is called with (epoch_idx, valid_score) as arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)
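One way to organize a multi-stage trainer like RaCTTrainer is to dispatch on a stage name read from the config. In the sketch below, the StagedTrainer class, the "train_stage" key, and the stage names are illustrative assumptions rather than a description of RaCTTrainer's internals; each handler only records that it ran.

```python
class StagedTrainer:
    """Hypothetical trainer that runs one of several stages chosen by config."""

    STAGES = ("actor_pretrain", "critic_pretrain", "finetune")

    def __init__(self, config):
        stage = config["train_stage"]
        if stage not in self.STAGES:
            raise ValueError(f"unknown train_stage: {stage!r}")
        self.stage = stage
        self.log = []

    def fit(self, train_data):
        # Map each stage name to the routine that trains it.
        handlers = {
            "actor_pretrain": self._pretrain_actor,
            "critic_pretrain": self._pretrain_critic,
            "finetune": self._train_actor_critic,
        }
        return handlers[self.stage](train_data)

    def _pretrain_actor(self, train_data):
        self.log.append("actor")
        return "actor pre-trained"

    def _pretrain_critic(self, train_data):
        self.log.append("critic")
        return "critic pre-trained"

    def _train_actor_critic(self, train_data):
        self.log.append("actor+critic")
        return "actor-critic trained"


# In this sketch the three stages are simply run one after another.
for stage in StagedTrainer.STAGES:
    print(StagedTrainer({"train_stage": stage}).fit(train_data=[]))
```

Keeping each stage in its own method means a later stage could start from a checkpoint written by an earlier one without the stages sharing in-memory state.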

class recbole.trainer.trainer.RecVAETrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

RecVAETrainer is designed for RecVAE, which is a general recommender.

class recbole.trainer.trainer.S3RecTrainer(config, model)

Bases: recbole.trainer.trainer.PretrainTrainer

S3RecTrainer is designed for S3Rec, which is a self-supervised-learning-based sequential recommender. It includes two training stages: pre-training and fine-tuning.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data

  • valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.

  • verbose (bool, optional) – whether to write training and evaluation information to the logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Whether to show the progress of training and evaluation epochs. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of an epoch; it is called with (epoch_idx, valid_score) as arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)

class recbole.trainer.trainer.TraditionalTrainer(config, model)

Bases: recbole.trainer.trainer.Trainer

TraditionalTrainer is designed for traditional models (e.g., Pop and ItemKNN); it sets the number of epochs to 1 regardless of the config.
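The override amounts to ignoring the configured epoch count. The minimal sketch below uses hypothetical ToyTrainer and ToyTraditionalTrainer classes, not RecBole's, to show a subclass that forces a single epoch after the base class has read the config.

```python
class ToyTrainer:
    """Hypothetical trainer that runs config['epochs'] training epochs."""

    def __init__(self, config):
        self.epochs = config["epochs"]

    def fit(self, train_data):
        epochs_run = 0
        for _ in range(self.epochs):
            # A real trainer would iterate over train_data and update the model here.
            epochs_run += 1
        return epochs_run


class ToyTraditionalTrainer(ToyTrainer):
    """Models such as Pop compute their statistics in one pass, so ignore the config."""

    def __init__(self, config):
        super().__init__(config)
        self.epochs = 1


print(ToyTrainer({"epochs": 300}).fit([]))             # 300
print(ToyTraditionalTrainer({"epochs": 300}).fit([]))  # 1
```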

class recbole.trainer.trainer.Trainer(config, model)

Bases: recbole.trainer.trainer.AbstractTrainer

The basic Trainer for standard training and evaluation strategies in recommender systems. This class defines common functions for the training and evaluation processes of most recommender system models, including fit(), evaluate(), resume_checkpoint() and other features helpful for model training and evaluation.

Generally speaking, this class can serve most recommender system models whose training process simply optimizes a single loss, without involving complex training strategies such as adversarial learning or pre-training.

Initializing the Trainer requires two parameters: config and model. config records the parameters that control training and evaluation, such as learning_rate, epochs and eval_step; model is an instantiated model object.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)

Evaluate the model based on the eval data.

Parameters
  • eval_data (DataLoader) – the eval data

  • load_best_model (bool, optional) – whether to load the best model found during training, default: True. It should be set to True if users want to test the model after training.

  • model_file (str, optional) – the saved model file, default: None. If users want to test a previously trained model file, they can set this parameter.

  • show_progress (bool) – Whether to show the progress of the evaluation epoch. Defaults to False.

Returns

The evaluation result, where each key is an evaluation metric and each value is the corresponding metric value.

Return type

collections.OrderedDict
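The interplay between load_best_model and model_file, and the shape of the returned OrderedDict, can be sketched as follows. The helper names resolve_checkpoint and format_result, the best_file argument, and the resolution rule (with load_best_model=True an explicit model_file takes precedence over the trainer's own best checkpoint; with False the in-memory weights are used) are assumptions for illustration. The metric names and values are made-up placeholders, not results.

```python
from collections import OrderedDict
from typing import Mapping, Optional


def resolve_checkpoint(load_best_model: bool, model_file: Optional[str],
                       best_file: str) -> Optional[str]:
    """Hypothetical helper: which checkpoint (if any) to load before evaluating."""
    if not load_best_model:
        return None  # evaluate the weights currently in memory
    return model_file or best_file


def format_result(result: Mapping[str, float]) -> str:
    # Keys are metric names, values are metric values, kept in insertion order.
    return "  ".join(f"{metric}: {value:.4f}" for metric, value in result.items())


# Placeholder values only, to show the structure of an evaluation result.
result = OrderedDict([("recall@10", 0.25), ("ndcg@10", 0.125)])
print(resolve_checkpoint(True, None, "saved/best.pth"))           # saved/best.pth
print(resolve_checkpoint(True, "old_run.pth", "saved/best.pth"))  # old_run.pth
print(resolve_checkpoint(False, None, "saved/best.pth"))          # None
print(format_result(result))  # recall@10: 0.2500  ndcg@10: 0.1250
```

Because the result is an OrderedDict, iterating over it reports metrics in a stable order, which keeps logged results easy to compare across runs.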

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data

  • valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.

  • verbose (bool, optional) – whether to write training and evaluation information to the logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Whether to show the progress of training and evaluation epochs. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of an epoch; it is called with (epoch_idx, valid_score) as arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)
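The documented contract of fit() (the (-1, None) return without validation data, early stopping, and a callback_fn receiving (epoch_idx, valid_score)) can be mimicked by a toy loop. In this sketch the function toy_fit, the precomputed list of per-epoch validation scores, the higher-is-better assumption, and the stopping_step rule are all hypothetical simplifications, not RecBole's implementation.

```python
from typing import Callable, List, Optional, Tuple


def toy_fit(
    valid_scores: Optional[List[float]],
    stopping_step: int = 2,
    callback_fn: Optional[Callable[[int, float], None]] = None,
) -> Tuple[float, Optional[dict]]:
    """Hypothetical fit loop driven by precomputed per-epoch validation scores.

    No validation data -> (-1, None). Otherwise track the best score (assumed
    higher-is-better) and stop once it has failed to improve for
    `stopping_step` consecutive evaluations.
    """
    if valid_scores is None:
        return -1, None
    best_score, best_result, bad_rounds = float("-inf"), None, 0
    for epoch_idx, valid_score in enumerate(valid_scores):
        # A real trainer would train one epoch and then evaluate here.
        if callback_fn is not None:
            callback_fn(epoch_idx, valid_score)
        if valid_score > best_score:
            best_score = valid_score
            best_result = {"score": valid_score, "epoch": epoch_idx}
            bad_rounds = 0
        else:
            bad_rounds += 1
            if bad_rounds >= stopping_step:
                break  # early stopping
    return best_score, best_result


seen_epochs = []
print(toy_fit([0.10, 0.15, 0.14, 0.13, 0.20],
              callback_fn=lambda epoch_idx, score: seen_epochs.append(epoch_idx)))
# (0.15, {'score': 0.15, 'epoch': 1})
print(seen_epochs)    # [0, 1, 2, 3]  (stopped before epoch 4)
print(toy_fit(None))  # (-1, None)
```

A callback with this two-argument shape is a convenient hook for logging or for external tools that monitor the validation score while training runs.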

resume_checkpoint(resume_file)

Load the model parameters and training information.

Parameters

resume_file (file) – the checkpoint file
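Resuming usually means restoring both the weights and the position in the training schedule. The sketch below uses a hypothetical ResumableToyTrainer, with pickle standing in for torch.save/torch.load; the checkpoint keys and the rule start_epoch = saved epoch + 1 are assumptions for illustration.

```python
import os
import pickle
import tempfile


class ResumableToyTrainer:
    """Hypothetical trainer that can continue from a pickled checkpoint."""

    def __init__(self):
        self.start_epoch = 0
        self.params = {}

    def save_checkpoint(self, epoch, path):
        with open(path, "wb") as f:
            pickle.dump({"epoch": epoch, "state_dict": dict(self.params)}, f)

    def resume_checkpoint(self, resume_file):
        with open(resume_file, "rb") as f:
            checkpoint = pickle.load(f)
        self.params = checkpoint["state_dict"]
        # Continue with the epoch after the one that was saved.
        self.start_epoch = checkpoint["epoch"] + 1


path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
first = ResumableToyTrainer()
first.params = {"w": 0.5}
first.save_checkpoint(epoch=4, path=path)

second = ResumableToyTrainer()
second.resume_checkpoint(path)
print(second.start_epoch, second.params)  # 5 {'w': 0.5}
```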

class recbole.trainer.trainer.XGBoostTrainer(config, model)

Bases: recbole.trainer.trainer.DecisionTreeTrainer

XGBoostTrainer is designed for XGBoost.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)

Evaluate the model based on the eval data.