recbole.trainer.trainer
- class recbole.trainer.trainer.AbstractTrainer(config, model)
Bases: object
The Trainer class is used to manage the training and evaluation processes of recommender system models. AbstractTrainer is an abstract class in which the fit() and evaluate() methods should be implemented according to different training and evaluation strategies.
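The abstract-class contract described above can be illustrated with Python's `abc` module. This is a minimal sketch, not RecBole's source: `SketchAbstractTrainer` and `ConstantTrainer` are hypothetical names, and the sketch only assumes that every concrete trainer must supply its own `fit()` and `evaluate()`.

```python
from abc import ABC, abstractmethod


class SketchAbstractTrainer(ABC):
    """Hypothetical stand-in for an abstract trainer: stores config and model
    and forces subclasses to implement fit() and evaluate()."""

    def __init__(self, config, model):
        self.config = config
        self.model = model

    @abstractmethod
    def fit(self, train_data):
        """Train self.model on train_data."""

    @abstractmethod
    def evaluate(self, eval_data):
        """Return a mapping from metric name to metric value."""


class ConstantTrainer(SketchAbstractTrainer):
    """Hypothetical concrete strategy: 'trains' by recording the data size."""

    def fit(self, train_data):
        self.seen = len(train_data)
        return self.seen

    def evaluate(self, eval_data):
        # Toy metric: fraction of evaluation entries that are nonzero.
        hits = sum(1 for x in eval_data if x)
        return {"toy_hit": hits / len(eval_data)}


trainer = ConstantTrainer(config={"epochs": 1}, model=None)
print(trainer.fit([1, 2, 3]))          # 3
print(trainer.evaluate([1, 0, 1, 1]))  # {'toy_hit': 0.75}
```

Instantiating `SketchAbstractTrainer` directly raises `TypeError`, which is how `abc` enforces that each training strategy provides both methods.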
- class recbole.trainer.trainer.DecisionTreeTrainer(config, model)
Bases: recbole.trainer.trainer.AbstractTrainer
DecisionTreeTrainer is designed for decision tree models.
- class recbole.trainer.trainer.KGATTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
KGATTrainer is designed for KGAT, which is a knowledge-aware recommendation method.
- class recbole.trainer.trainer.KGTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
KGTrainer is designed for knowledge-aware recommendation methods. Some of these models need to train the recommendation-related task and the knowledge-related task alternately.
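One simple way to alternate the two tasks is to split each cycle of epochs into a block of recommendation epochs followed by a block of knowledge-graph epochs. The sketch below assumes that scheme; `task_for_epoch`, `rec_steps`, and `kg_steps` are hypothetical names, and KGTrainer's actual scheduling may differ.

```python
def task_for_epoch(epoch_idx, rec_steps, kg_steps):
    """Return which task to train at a given epoch (hypothetical schedule).

    Each cycle lasts rec_steps + kg_steps epochs: the first rec_steps epochs
    train the recommendation task ("rec"), the rest the knowledge task ("kg").
    If either count is not positive, both tasks are trained together.
    """
    if rec_steps <= 0 or kg_steps <= 0:
        return "rec+kg"
    position = epoch_idx % (rec_steps + kg_steps)
    return "rec" if position < rec_steps else "kg"


schedule = [task_for_epoch(e, rec_steps=2, kg_steps=1) for e in range(6)]
print(schedule)  # ['rec', 'rec', 'kg', 'rec', 'rec', 'kg']
```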
- class recbole.trainer.trainer.LightGBMTrainer(config, model)
Bases: recbole.trainer.trainer.DecisionTreeTrainer
LightGBMTrainer is designed for LightGBM.
- class recbole.trainer.trainer.MKRTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
MKRTrainer is designed for MKR, which is a knowledge-aware recommendation method.
- class recbole.trainer.trainer.NCLTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
- fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)
Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data.
valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
show_progress (bool) – whether to show the progress of the training and evaluation epochs. Defaults to False.
callback_fn (callable) – optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
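Because `callback_fn` receives `(epoch_idx, valid_score)`, any callable with that signature can be passed to `fit()`. The sketch below is a hypothetical callback object that records the validation history and reports the best epoch; the loop at the bottom only simulates the end-of-epoch calls `fit()` would make, and the scores are made up for illustration.

```python
class ScoreHistory:
    """Hypothetical callback: records (epoch_idx, valid_score) pairs."""

    def __init__(self):
        self.history = []

    def __call__(self, epoch_idx, valid_score):
        self.history.append((epoch_idx, valid_score))

    @property
    def best_epoch(self):
        # Epoch with the highest validation score (assumes higher is better).
        return max(self.history, key=lambda pair: pair[1])[0]


callback = ScoreHistory()
# Simulated end-of-epoch calls; with a real trainer you would instead pass
# callback_fn=callback to trainer.fit(...).
for epoch_idx, score in enumerate([0.21, 0.25, 0.24]):
    callback(epoch_idx, score)
print(callback.best_epoch)  # 1
```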
- class recbole.trainer.trainer.PretrainTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
PretrainTrainer is designed for pre-training. It can be inherited by trainers that need both pre-training and fine-tuning.
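A trainer that inherits pre-training support typically chooses between the inherited pre-training loop and its own fine-tuning loop based on the current training stage. The sketch below is hypothetical: the stage names `"pretrain"` and `"finetune"`, the `train_stage` config key, and both classes are illustrative assumptions rather than RecBole's implementation.

```python
class SketchPretrainTrainer:
    """Hypothetical base class that supplies a reusable pre-training loop."""

    def __init__(self, config):
        self.config = config
        self.log = []

    def pretrain(self, epochs):
        # Placeholder for pre-training epochs.
        for epoch_idx in range(epochs):
            self.log.append(("pretrain", epoch_idx))


class SketchTwoStageTrainer(SketchPretrainTrainer):
    """Hypothetical subclass: fit() picks a stage from config["train_stage"]."""

    def finetune(self, epochs):
        # Placeholder for ordinary (fine-tuning) training epochs.
        for epoch_idx in range(epochs):
            self.log.append(("finetune", epoch_idx))

    def fit(self):
        stage, epochs = self.config["train_stage"], self.config["epochs"]
        if stage == "pretrain":
            self.pretrain(epochs)  # inherited from the base class
        elif stage == "finetune":
            self.finetune(epochs)
        else:
            raise ValueError(f"unknown train_stage: {stage!r}")
        return self.log


print(SketchTwoStageTrainer({"train_stage": "pretrain", "epochs": 2}).fit())
# [('pretrain', 0), ('pretrain', 1)]
```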
- class recbole.trainer.trainer.RaCTTrainer(config, model)
Bases: recbole.trainer.trainer.PretrainTrainer
RaCTTrainer is designed for RaCT, an actor-critic, reinforcement-learning-based general recommender. It includes three training stages: actor pre-training, critic pre-training, and actor-critic training.
- fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)
Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data
valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
show_progress (bool) – whether to show the progress of the training and evaluation epochs. Defaults to False.
callback_fn (callable) – optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
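The three RaCT stages run in a fixed order, since the later stages build on the networks pre-trained in the earlier ones. The sketch below models that ordering as a tiny state machine; the stage identifiers and the `next_stage` helper are hypothetical and only mirror the stage names given in the description.

```python
RACT_STAGES = ("actor_pretrain", "critic_pretrain", "actor_critic")


def next_stage(current):
    """Return the stage that follows `current`, or None after the last one."""
    idx = RACT_STAGES.index(current)  # raises ValueError for unknown stages
    return RACT_STAGES[idx + 1] if idx + 1 < len(RACT_STAGES) else None


stage = RACT_STAGES[0]
order = []
while stage is not None:
    order.append(stage)
    stage = next_stage(stage)
print(order)  # ['actor_pretrain', 'critic_pretrain', 'actor_critic']
```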
- class recbole.trainer.trainer.RecVAETrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
RecVAETrainer is designed for RecVAE, which is a general recommender.
- class recbole.trainer.trainer.S3RecTrainer(config, model)
Bases: recbole.trainer.trainer.PretrainTrainer
S3RecTrainer is designed for S3Rec, a self-supervised-learning-based sequential recommender. It includes two training stages: pre-training and fine-tuning.
- fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)
Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data
valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
show_progress (bool) – whether to show the progress of the training and evaluation epochs. Defaults to False.
callback_fn (callable) – optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
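Between the two S3Rec stages, the parameters learned during pre-training have to be carried into fine-tuning. The sketch below illustrates that handoff with a plain JSON file in a temporary directory; the file name, the parameter dictionary, and both helper functions are hypothetical, and a real trainer would save framework-specific checkpoints instead.

```python
import json
import os
import tempfile


def pretrain_stage(checkpoint_path):
    """Hypothetical pre-training: produce parameters and save them."""
    params = {"item_embedding": [0.1, 0.2, 0.3]}
    with open(checkpoint_path, "w") as f:
        json.dump(params, f)
    return params


def finetune_stage(checkpoint_path, lr=0.5):
    """Hypothetical fine-tuning: load pre-trained parameters, then update them."""
    with open(checkpoint_path) as f:
        params = json.load(f)
    # Toy update standing in for fine-tuning on the downstream task.
    params["item_embedding"] = [w * (1 - lr) for w in params["item_embedding"]]
    return params


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pretrained.json")
    pretrain_stage(path)
    tuned = finetune_stage(path)
print(tuned["item_embedding"])  # [0.05, 0.1, 0.15]
```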
- class recbole.trainer.trainer.TraditionalTrainer(config, model)
Bases: recbole.trainer.trainer.Trainer
TraditionalTrainer is designed for traditional models (e.g., Pop and ItemKNN); it sets the number of epochs to 1 regardless of the config.
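Models such as Pop or ItemKNN collect their statistics (item counts, item-item similarities) in a single pass over the training data, so additional epochs would repeat the same work. A minimal sketch of the override, assuming a plain-dict config and hypothetical class names:

```python
class SketchTrainer:
    """Hypothetical general trainer: reads the epoch count from the config."""

    def __init__(self, config):
        self.epochs = config["epochs"]


class SketchTraditionalTrainer(SketchTrainer):
    """Hypothetical trainer for Pop/ItemKNN-style models: always one epoch."""

    def __init__(self, config):
        super().__init__(config)
        self.epochs = 1  # ignore config["epochs"]


config = {"epochs": 300}
print(SketchTrainer(config).epochs)             # 300
print(SketchTraditionalTrainer(config).epochs)  # 1
```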
- class recbole.trainer.trainer.Trainer(config, model)
Bases: recbole.trainer.trainer.AbstractTrainer
The basic trainer for standard training and evaluation strategies in recommender systems. This class defines common functions for the training and evaluation processes of most recommender system models, including fit(), evaluate(), resume_checkpoint(), and other features helpful for model training and evaluation.
Generally speaking, this class can serve most recommender system models whose training process simply optimizes a single loss without involving complex training strategies such as adversarial learning or pre-training.
Initializing the Trainer requires two parameters: config and model. config records the parameters that control training and evaluation, such as learning_rate, epochs, and eval_step; model is an instantiated object of a Model class.
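Of the config entries listed above, `eval_step` controls how often validation runs during `fit()`. The sketch below assumes validation happens after every `eval_step`-th epoch (counting epochs from 1) and is skipped entirely when `eval_step` is not positive; `validation_epochs` is a hypothetical helper, and the exact rule in RecBole may differ between versions.

```python
def validation_epochs(epochs, eval_step):
    """Return the 0-based epoch indices after which validation would run
    (hypothetical rule: every eval_step-th epoch; none if eval_step <= 0)."""
    if eval_step <= 0:
        return []
    return [e for e in range(epochs) if (e + 1) % eval_step == 0]


config = {"learning_rate": 0.001, "epochs": 10, "eval_step": 3}
print(validation_epochs(config["epochs"], config["eval_step"]))  # [2, 5, 8]
```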
- evaluate(eval_data, load_best_model=True, model_file=None, show_progress=False)
Evaluate the model based on the eval data.
- Parameters
eval_data (DataLoader) – the eval data
load_best_model (bool, optional) – whether to load the best model obtained during training, default: True. It should be set to True if users want to test the model after training.
model_file (str, optional) – the saved model file, default: None. If users want to test a previously trained model file, they can set this parameter.
show_progress (bool) – whether to show the progress of the evaluation epoch. Defaults to False.
- Returns
eval result, where each key is an eval metric and each value is the corresponding metric value.
- Return type
collections.OrderedDict
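The returned `collections.OrderedDict` maps each metric name to its value. As a hypothetical illustration of that shape, the sketch below computes Hit@k and Recall@k for two toy users; the `topk_metrics` function, the key spelling, and the data are assumptions, not output from RecBole.

```python
from collections import OrderedDict


def topk_metrics(ranked, relevant, k):
    """Average Hit@k and Recall@k over users.

    ranked:   dict user -> list of recommended item ids (best first)
    relevant: dict user -> set of ground-truth item ids
    """
    hit_sum = recall_sum = 0.0
    for user, items in ranked.items():
        found = len(set(items[:k]) & relevant[user])
        hit_sum += 1.0 if found > 0 else 0.0
        recall_sum += found / len(relevant[user])
    n_users = len(ranked)
    return OrderedDict([
        (f"hit@{k}", hit_sum / n_users),
        (f"recall@{k}", recall_sum / n_users),
    ])


ranked = {"u1": [5, 3, 9], "u2": [1, 2, 4]}
relevant = {"u1": {3, 7}, "u2": {8}}
result = topk_metrics(ranked, relevant, k=2)
for name, value in result.items():
    print(name, value)
# hit@2 0.5
# recall@2 0.25
```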
- fit(train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None)
Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data
valid_data (DataLoader, optional) – the valid data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
show_progress (bool) – whether to show the progress of the training and evaluation epochs. Defaults to False.
callback_fn (callable) – optional callback function executed at the end of each epoch. It receives (epoch_idx, valid_score) as input arguments.
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
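The return contract of `fit()` (the best valid score and best valid result, or `(-1, None)` without validation data) can be sketched as a loop with early stopping. This is a simplified, hypothetical model of the behaviour described above, not RecBole's code: `sketch_fit`, `stopping_step`, and the precomputed list of validation scores are assumptions, and higher scores are assumed to be better.

```python
def sketch_fit(valid_scores, stopping_step=2, callback_fn=None):
    """Hypothetical fit loop over precomputed per-epoch validation scores.

    valid_scores: list of scores, one per epoch, or None when there is no
    validation data. Returns (best_valid_score, best_valid_result).
    """
    if valid_scores is None:
        # A real trainer would still train for all epochs here; early
        # stopping is disabled and no best score is tracked.
        return -1, None
    best_score, best_result, bad_epochs = float("-inf"), None, 0
    for epoch_idx, score in enumerate(valid_scores):
        if score > best_score:
            best_score, bad_epochs = score, 0
            best_result = {"epoch": epoch_idx, "score": score}
        else:
            bad_epochs += 1
        if callback_fn is not None:
            callback_fn(epoch_idx, score)
        if bad_epochs >= stopping_step:
            break  # no improvement for stopping_step epochs in a row
    return best_score, best_result


print(sketch_fit(None))  # (-1, None)
print(sketch_fit([0.30, 0.35, 0.34, 0.33, 0.40]))
# (0.35, {'epoch': 1, 'score': 0.35})
```

In the second call, training stops after epoch 3, so the later score of 0.40 is never reached; this is the usual trade-off of a small `stopping_step`.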
- class recbole.trainer.trainer.XGBoostTrainer(config, model)
Bases: recbole.trainer.trainer.DecisionTreeTrainer
XGBoostTrainer is designed for XGBoost.