recbole.trainer.trainer

class recbole.trainer.trainer.AbstractTrainer(config, model)[source]

Bases: object

The Trainer class manages the training and evaluation processes of recommender system models. AbstractTrainer is an abstract class whose fit() and evaluate() methods should be implemented according to the chosen training and evaluation strategies.

evaluate(eval_data)[source]

Evaluate the model based on the eval data.

fit(train_data)[source]

Train the model based on the train data.
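The contract described above can be illustrated with a minimal, self-contained sketch. It does not import RecBole: SketchAbstractTrainer and MeanTrainer are hypothetical stand-ins, the dict-based model and the "mae" metric are invented for illustration, and Python's abc module is just one way to express an abstract interface (the library's own mechanism may differ).

```python
from abc import ABC, abstractmethod


class SketchAbstractTrainer(ABC):
    """Hypothetical stand-in that mirrors the documented (config, model) interface."""

    def __init__(self, config, model):
        self.config = config
        self.model = model

    @abstractmethod
    def fit(self, train_data):
        """Train the model based on the train data."""

    @abstractmethod
    def evaluate(self, eval_data):
        """Evaluate the model based on the eval data."""


class MeanTrainer(SketchAbstractTrainer):
    """Toy strategy: the 'model' is a dict that stores the mean of the train data."""

    def fit(self, train_data):
        self.model["mean"] = sum(train_data) / len(train_data)
        return self.model["mean"]

    def evaluate(self, eval_data):
        # Mean absolute error of the stored mean against the eval data.
        errors = [abs(x - self.model["mean"]) for x in eval_data]
        return {"mae": sum(errors) / len(errors)}


trainer = MeanTrainer(config={}, model={})
trainer.fit([1.0, 2.0, 3.0])
result = trainer.evaluate([2.0, 4.0])
print(result)
```

In this sketch, a subclass that leaves either method unimplemented cannot be instantiated, which is the point of keeping fit() and evaluate() abstract.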

class recbole.trainer.trainer.KGATTrainer(config, model)[source]

Bases: recbole.trainer.trainer.Trainer

KGATTrainer is designed for KGAT, which is a knowledge-aware recommendation method.

class recbole.trainer.trainer.KGTrainer(config, model)[source]

Bases: recbole.trainer.trainer.Trainer

KGTrainer is designed for knowledge-aware recommendation methods. Some of these models need to train the recommendation-related task and the knowledge-related task alternately.
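One possible reading of "alternately" is a repeating per-epoch schedule. The sketch below is only an illustration under that assumption: alternating_schedule, rec_epochs and kg_epochs are hypothetical names, not KGTrainer's options, and the actual trainer may interleave the two tasks differently.

```python
def alternating_schedule(num_epochs, rec_epochs=1, kg_epochs=1):
    """Return which task ("rec" or "kg") each epoch would train.

    Hypothetical helper: each cycle runs `rec_epochs` recommendation epochs
    followed by `kg_epochs` knowledge-graph epochs, then repeats.
    """
    cycle = rec_epochs + kg_epochs
    return [
        "rec" if (epoch_idx % cycle) < rec_epochs else "kg"
        for epoch_idx in range(num_epochs)
    ]


schedule = alternating_schedule(5)
print(schedule)
```

A training loop built on this would dispatch each epoch to the loss of the task named in the schedule.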

class recbole.trainer.trainer.MKRTrainer(config, model)[source]

Bases: recbole.trainer.trainer.Trainer

MKRTrainer is designed for MKR, which is a knowledge-aware recommendation method.

class recbole.trainer.trainer.S3RecTrainer(config, model)[source]

Bases: recbole.trainer.trainer.Trainer

S3RecTrainer is designed for S3Rec, which is a self-supervised learning based sequential recommender. It includes two training stages: pre-training and fine-tuning.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)[source]

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data

  • valid_data (DataLoader, optional) – the valid data, default: None. If it’s None, early stopping has no effect.

  • verbose (bool, optional) – whether to write training and evaluation information to logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Show the progress of training epoch and evaluate epoch. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)

pretrain(train_data, verbose=True, show_progress=False)[source]

save_pretrained_model(epoch, saved_model_file)[source]

Store the model parameters information and training information.

Parameters
  • epoch (int) – the current epoch id

  • saved_model_file (str) – file name for saved pretrained model

class recbole.trainer.trainer.TraditionalTrainer(config, model)[source]

Bases: recbole.trainer.trainer.Trainer

TraditionalTrainer is designed for traditional models (Pop, ItemKNN). It sets the number of epochs to 1 regardless of the config.

class recbole.trainer.trainer.Trainer(config, model)[source]

Bases: recbole.trainer.trainer.AbstractTrainer

The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common functions for training and evaluation processes of most recommender system models, including fit(), evaluate(), resume_checkpoint() and some other features helpful for model training and evaluation.

Generally speaking, this class can serve most recommender system models, provided that the training process simply optimizes a single loss without involving complex training strategies such as adversarial learning or pre-training.

Initializing the Trainer needs two parameters: config and model. config records the parameters that control training and evaluation, such as learning_rate, epochs and eval_step. model is an instantiated object of a Model class.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)[source]

Evaluate the model based on the eval data.

Parameters
  • eval_data (DataLoader) – the eval data

  • load_best_model (bool, optional) – whether to load the best model found during training, default: True. It should be set to True if users want to test the model after training.

  • model_file (str, optional) – the saved model file, default: None. If users want to test the previously trained model file, they can set this parameter.

  • show_progress (bool) – Show the progress of evaluate epoch. Defaults to False.

Returns

The eval result: each key is an eval metric and each value is the corresponding metric value.

Return type

dict

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)[source]

Train the model based on the train data and the valid data.

Parameters
  • train_data (DataLoader) – the train data

  • valid_data (DataLoader, optional) – the valid data, default: None. If it’s None, early stopping has no effect.

  • verbose (bool, optional) – whether to write training and evaluation information to logger, default: True

  • saved (bool, optional) – whether to save the model parameters, default: True

  • show_progress (bool) – Show the progress of training epoch and evaluate epoch. Defaults to False.

  • callback_fn (callable) – Optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.

Returns

The best valid score and the best valid result. If valid_data is None, it returns (-1, None).

Return type

(float, dict)
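The interaction between valid_data, early stopping, callback_fn and the return value can be sketched with a toy loop. This is not the library's implementation: sketch_fit, train_epoch, validate and patience are hypothetical names, the metric "recall@10" is invented, and the sketch assumes that a higher valid score is better.

```python
def sketch_fit(train_epoch, validate=None, epochs=10, patience=2, callback_fn=None):
    """Toy training loop that mirrors the documented behavior of fit()."""
    best_valid_score, best_valid_result = -1, None
    stale_epochs = 0
    for epoch_idx in range(epochs):
        train_epoch(epoch_idx)
        if validate is None:
            # No valid data: nothing to compare, so early stopping cannot trigger.
            continue
        valid_score, valid_result = validate(epoch_idx)
        if valid_score > best_valid_score:
            best_valid_score, best_valid_result = valid_score, valid_result
            stale_epochs = 0
        else:
            stale_epochs += 1
        if callback_fn is not None:
            callback_fn(epoch_idx, valid_score)
        if stale_epochs >= patience:
            break  # early stopping
    return best_valid_score, best_valid_result


scores = [0.1, 0.3, 0.2, 0.25, 0.5]
seen = []

best = sketch_fit(
    train_epoch=lambda epoch_idx: None,
    validate=lambda epoch_idx: (scores[epoch_idx], {"recall@10": scores[epoch_idx]}),
    epochs=len(scores),
    patience=2,
    callback_fn=lambda epoch_idx, valid_score: seen.append((epoch_idx, valid_score)),
)
no_valid = sketch_fit(train_epoch=lambda epoch_idx: None, epochs=3)
print(best, no_valid, seen)
```

In this run the loop stops after the fourth epoch because the score has not improved for two consecutive epochs, so the later score of 0.5 is never reached. A callback of this shape is a natural hook for reporting intermediate scores to an external tool such as a hyperparameter tuner.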

plot_train_loss(show=True, save_path=None)[source]

Plot the train loss in each epoch.

Parameters
  • show (bool, optional) – Whether to show this figure, default: True

  • save_path (str, optional) – The data path to save the figure, default: None. If it’s None, it will not be saved.

resume_checkpoint(resume_file)[source]

Load the model parameters information and training information.

Parameters
  • resume_file (file) – the checkpoint file
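The idea of restoring both model parameters and training information can be sketched with the standard library alone. Everything here is an assumption made for illustration: the JSON file format, the key names and the helper functions save_checkpoint and load_checkpoint are hypothetical and do not describe the checkpoint layout that the Trainer actually uses.

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, best_valid_score, state):
    """Write model state plus training progress to one file (hypothetical layout)."""
    checkpoint = {"epoch": epoch, "best_valid_score": best_valid_score, "state": state}
    with open(path, "w") as f:
        json.dump(checkpoint, f)


def load_checkpoint(path):
    """Return (start_epoch, best_valid_score, state) so training can continue."""
    with open(path) as f:
        checkpoint = json.load(f)
    # Resume from the epoch after the one that was saved.
    return checkpoint["epoch"] + 1, checkpoint["best_valid_score"], checkpoint["state"]


with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "toy_checkpoint.json")
    save_checkpoint(checkpoint_path, epoch=4, best_valid_score=0.3, state={"w": [0.1, 0.2]})
    start_epoch, best_valid_score, state = load_checkpoint(checkpoint_path)

print(start_epoch, best_valid_score, state)
```

Storing the training progress next to the parameters is what lets a resumed run continue from the next epoch instead of restarting from zero.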

class recbole.trainer.trainer.xgboostTrainer(config, model)[source]

Bases: recbole.trainer.trainer.AbstractTrainer

xgboostTrainer is designed for XGBoost.

evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)[source]

Evaluate the model based on the eval data.

fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False)[source]

Train the model based on the train data.