recbole.trainer.trainer¶
-
class
recbole.trainer.trainer.
AbstractTrainer
(config, model)[source]¶ Bases:
object
The Trainer class manages the training and evaluation processes of recommender system models. AbstractTrainer is an abstract class whose fit() and evaluate() methods should be implemented according to the training and evaluation strategy in use.
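The contract described above can be illustrated with a minimal, self-contained sketch. It does not import RecBole and is not RecBole's source: `SketchAbstractTrainer` and `MeanRatingTrainer` are hypothetical names, and the dict-based `model` is only a stand-in for a real model object.

```python
class SketchAbstractTrainer:
    """Stand-in for an abstract trainer (hypothetical, not RecBole's code)."""

    def __init__(self, config, model):
        self.config = config
        self.model = model

    def fit(self, train_data):
        raise NotImplementedError("Subclasses must implement fit().")

    def evaluate(self, eval_data):
        raise NotImplementedError("Subclasses must implement evaluate().")


class MeanRatingTrainer(SketchAbstractTrainer):
    """Hypothetical subclass: 'trains' by storing the mean rating."""

    def fit(self, train_data):
        self.model["mean"] = sum(train_data) / len(train_data)
        return self.model["mean"]

    def evaluate(self, eval_data):
        # Mean absolute error of predicting the stored mean for every rating.
        mean = self.model["mean"]
        mae = sum(abs(r - mean) for r in eval_data) / len(eval_data)
        return {"mae": mae}


trainer = MeanRatingTrainer(config={}, model={})
trainer.fit([1.0, 3.0, 5.0])
result = trainer.evaluate([2.0, 4.0])
```

Calling fit() or evaluate() on the base class directly raises NotImplementedError in this sketch, which is what forces each subclass to supply its own strategy.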
-
class
recbole.trainer.trainer.
KGATTrainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.Trainer
KGATTrainer is designed for KGAT, which is a knowledge-aware recommendation method.
-
class
recbole.trainer.trainer.
KGTrainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.Trainer
KGTrainer is designed for knowledge-aware recommendation methods. Some of these models need to train the recommendation task and the knowledge-related task alternately.
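One possible way to alternate between the two tasks is sketched below. This is an assumption for illustration only: the function `training_phase` and its parameters `rec_epochs` and `kg_epochs` are hypothetical and are not taken from KGTrainer.

```python
def training_phase(epoch_idx, rec_epochs=1, kg_epochs=1):
    """Return which task to train in a given epoch.

    Hypothetical scheduler: cycles through `rec_epochs` recommendation
    epochs followed by `kg_epochs` knowledge-graph epochs.
    """
    if rec_epochs < 1 or kg_epochs < 1:
        raise ValueError("rec_epochs and kg_epochs must be positive")
    position = epoch_idx % (rec_epochs + kg_epochs)
    return "rec" if position < rec_epochs else "kg"


# Two recommendation epochs, then one knowledge-graph epoch, repeated.
schedule = [training_phase(e, rec_epochs=2, kg_epochs=1) for e in range(6)]
```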
-
class
recbole.trainer.trainer.
MKRTrainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.Trainer
MKRTrainer is designed for MKR, which is a knowledge-aware recommendation method.
-
class
recbole.trainer.trainer.
S3RecTrainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.Trainer
S3RecTrainer is designed for S3Rec, a sequential recommender based on self-supervised learning. It includes two training stages: pre-training and fine-tuning.
-
fit
(train_data, valid_data=None, verbose=True, saved=True)[source]¶ Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data
valid_data (DataLoader, optional) – the validation data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
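The two-stage design of S3RecTrainer could be organized as a simple dispatch inside fit(). The sketch below is hedged: the config key `train_stage`, its values `"pretrain"` and `"finetune"`, and the helper `run_stage` are assumptions made for illustration, not confirmed by this page.

```python
def run_stage(config, pretrain_fn, finetune_fn):
    """Dispatch to one of two training routines (hypothetical helper)."""
    stage = config["train_stage"]  # assumed key name
    if stage == "pretrain":
        return pretrain_fn()
    if stage == "finetune":
        return finetune_fn()
    raise ValueError(f"Unknown train_stage: {stage!r}")


# Stand-in routines; a real trainer would run self-supervised pre-training
# or the usual supervised loop here.
outcome = run_stage(
    {"train_stage": "finetune"},
    pretrain_fn=lambda: "pretrained",
    finetune_fn=lambda: "finetuned",
)
```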
-
-
class
recbole.trainer.trainer.
TraditionalTrainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.Trainer
TraditionalTrainer is designed for traditional models (Pop, ItemKNN); it sets the number of epochs to 1 regardless of the config.
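The epoch override can be pictured with a small sketch. `effective_epochs` is a hypothetical helper and the plain dict stands in for a config object; neither is part of RecBole.

```python
def effective_epochs(config, is_traditional):
    """Return the number of epochs a trainer would run (hypothetical helper)."""
    # Models such as Pop or ItemKNN are fitted in a single pass over the
    # data, so the configured epoch count is ignored for them.
    return 1 if is_traditional else config["epochs"]


config = {"epochs": 300}
traditional_epochs = effective_epochs(config, is_traditional=True)
regular_epochs = effective_epochs(config, is_traditional=False)
```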
-
class
recbole.trainer.trainer.
Trainer
(config, model)[source]¶ Bases:
recbole.trainer.trainer.AbstractTrainer
The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common functions for training and evaluation processes of most recommender system models, including fit(), evaluate(), resume_checkpoint() and some other features helpful for model training and evaluation.
Generally speaking, this class can serve most recommender system models, provided that the training process simply optimizes a single loss without involving complex training strategies such as adversarial learning or pre-training.
Initializing the Trainer requires two parameters: config and model. config records the parameters that control training and evaluation, such as learning_rate, epochs, eval_step and so on. More information can be found in [placeholder]. model is an instantiated object of a Model class.
-
evaluate
(eval_data, load_best_model=True, model_file=None)[source]¶ Evaluate the model based on the eval data.
- Parameters
eval_data (DataLoader) – the eval data
load_best_model (bool, optional) – whether to load the best model found during training, default: True. It should be set to True if users want to test the model right after training.
model_file (str, optional) – the saved model file, default: None. Users can set this parameter to test a previously trained model file.
- Returns
eval result; each key is an evaluation metric and each value is the corresponding metric value
- Return type
dict
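The shape of the returned dict, metric name mapped to metric value, can be sketched for a single user as follows. The helper `evaluate_top_k` and the key format `recall@k` / `precision@k` are assumptions for illustration; the actual keys depend on the metrics configured.

```python
def evaluate_top_k(recommended, relevant, k):
    """Return a metric-name -> value dict for one user (hypothetical helper)."""
    top_k = recommended[:k]
    hits = sum(1 for item in top_k if item in relevant)
    return {
        f"recall@{k}": hits / len(relevant) if relevant else 0.0,
        f"precision@{k}": hits / k,
    }


result = evaluate_top_k(recommended=[5, 2, 9, 7], relevant={2, 7, 8}, k=2)
for metric, value in result.items():
    print(f"{metric}: {value:.4f}")
```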
-
fit
(train_data, valid_data=None, verbose=True, saved=True)[source]¶ Train the model based on the train data and the valid data.
- Parameters
train_data (DataLoader) – the train data
valid_data (DataLoader, optional) – the validation data, default: None. If it is None, early stopping is disabled.
verbose (bool, optional) – whether to write training and evaluation information to logger, default: True
saved (bool, optional) – whether to save the model parameters, default: True
- Returns
best valid score and best valid result. If valid_data is None, it returns (-1, None)
- Return type
(float, dict)
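The return contract of fit(), including the (-1, None) case, can be illustrated with a simplified early-stopping loop. This is a sketch under stated assumptions, not the Trainer's implementation: `fit_sketch` is hypothetical, `valid_scores` stands in for evaluating on valid_data after each epoch, `stopping_step` is an assumed patience parameter, and higher scores are assumed to be better.

```python
def fit_sketch(valid_scores, stopping_step=2):
    """Simulate fit()'s return value from per-epoch validation scores.

    Pass None to mimic valid_data=None.
    """
    if valid_scores is None:
        # No validation data: there is nothing to early-stop on.
        return -1, None

    best_score, best_result, stale_epochs = float("-inf"), None, 0
    for epoch_idx, score in enumerate(valid_scores):
        if score > best_score:
            best_score = score
            best_result = {"score": score, "epoch": epoch_idx}
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs > stopping_step:
                break  # stop once patience is exhausted
    return best_score, best_result


no_valid = fit_sketch(None)
with_valid = fit_sketch([0.1, 0.3, 0.2, 0.2, 0.2, 0.9], stopping_step=2)
```

In the second call the loop stops after three consecutive epochs without improvement, so the final score of 0.9 is never reached and the best result comes from epoch 1.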
-