Customize Trainers¶
Here, we present how to develop a new Trainer and apply it in RecBole. If a new model has a complex training method that the existing Trainer cannot handle for training and evaluation, then we need to develop a new trainer.
The function used to train the model is fit(), which calls _train_epoch() to train the model. The function used to evaluate the model is evaluate(), which calls _valid_epoch() to evaluate the model.
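For reference, a minimal usage sketch is shown below. It assumes that config, model, train_data, valid_data and test_data have already been prepared with RecBole's standard quick-start flow; the variable names are illustrative.

from recbole.trainer import Trainer

# Minimal usage sketch (config, model and the data loaders are assumed to come
# from RecBole's standard quick-start flow).
trainer = Trainer(config, model)

# fit() trains the model and validates it during training, returning the best
# validation score and the corresponding validation result.
best_valid_score, best_valid_result = trainer.fit(train_data, valid_data)

# evaluate() computes the configured metrics on the given data loader.
test_result = trainer.evaluate(test_data)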
If the developed model needs a more complex training method, one can inherit from Trainer and override fit() or _train_epoch(). If the developed model needs a more complex evaluation method, one can inherit from Trainer and override evaluate() or _valid_epoch(), as in the sketch below.
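For the evaluation case, a minimal sketch is given below. It only assumes the base Trainer API described above; the class name and the two hook comments are hypothetical.

from recbole.trainer import Trainer


class NewEvalTrainer(Trainer):  # hypothetical name, for illustration only

    def evaluate(self, eval_data, *args, **kwargs):
        # Custom preparation before evaluation could go here (assumption).
        result = super(NewEvalTrainer, self).evaluate(eval_data, *args, **kwargs)
        # Custom post-processing of the returned result could go here (assumption).
        return result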
Example¶
Here we present a simple Trainer example for alternating optimization. We override the _train_epoch() method.
To begin with, we need to create a new class NewTrainer based on Trainer.
from recbole.trainer import Trainer


class NewTrainer(Trainer):

    def __init__(self, config, model):
        super(NewTrainer, self).__init__(config, model)
Then we revise _train_epoch(). Here, the two losses are optimized alternately, one per epoch: in even-numbered epochs the loss is computed by calculate_loss1(), and in odd-numbered epochs by calculate_loss2().
def _train_epoch(self, train_data, epoch_idx):
    self.model.train()
    total_loss = 0.
    if epoch_idx % 2 == 0:
        # Even-numbered epochs optimize the first loss.
        for batch_idx, interaction in enumerate(train_data):
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            loss = self.model.calculate_loss1(interaction)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
    else:
        # Odd-numbered epochs optimize the second loss.
        for batch_idx, interaction in enumerate(train_data):
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            loss = self.model.calculate_loss2(interaction)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
    return total_loss
Complete Code¶
from recbole.trainer import Trainer


class NewTrainer(Trainer):

    def __init__(self, config, model):
        super(NewTrainer, self).__init__(config, model)

    def _train_epoch(self, train_data, epoch_idx):
        self.model.train()
        total_loss = 0.
        if epoch_idx % 2 == 0:
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss1(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        else:
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss2(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        return total_loss
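To use the new trainer in a run script, instantiate it in place of the default Trainer. The sketch below assumes that config, model, train_data and valid_data come from RecBole's standard quick-start flow, and that the model implements calculate_loss1() and calculate_loss2().

# Hedged usage sketch: config, model, train_data and valid_data are assumed to
# be prepared elsewhere, and the model is assumed to implement the two loss
# methods used by NewTrainer.
trainer = NewTrainer(config, model)
best_valid_score, best_valid_result = trainer.fit(train_data, valid_data)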