Customize Trainers
This section shows how to develop a new Trainer and use it in RecBole. If a new model's training procedure is too complex for the existing trainers to handle its training and evaluation, a new trainer needs to be developed.
The function used to train the model is fit(), which calls _train_epoch() to train the model.
The function used to evaluate the model is evaluate(), which calls _valid_epoch() to evaluate the model.
If the new model needs a more complex training method, inherit from Trainer and override fit() or _train_epoch().
If the new model needs a more complex evaluation method, inherit from Trainer and override evaluate() or _valid_epoch().
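The relationship between fit() and _train_epoch() can be illustrated without RecBole installed. The following is a minimal sketch, not RecBole's implementation: SketchTrainer and HalvingTrainer are hypothetical stand-in classes, and the "loss" is just a sum of numbers. It only shows that overriding the per-epoch hook changes what fit() produces while the outer loop stays untouched.

```python
class SketchTrainer:
    """Hypothetical stand-in for a trainer base class (not RecBole's Trainer)."""

    def __init__(self, epochs):
        self.epochs = epochs

    def _train_epoch(self, train_data, epoch_idx):
        # Default per-epoch step: pretend the loss is the sum of the batch values.
        return float(sum(train_data))

    def fit(self, train_data):
        # The outer loop delegates every epoch to the _train_epoch() hook.
        return [self._train_epoch(train_data, epoch_idx)
                for epoch_idx in range(self.epochs)]


class HalvingTrainer(SketchTrainer):
    """Subclass that replaces only the per-epoch hook."""

    def _train_epoch(self, train_data, epoch_idx):
        # Custom step: reuse the base computation, then halve it once per epoch.
        base_loss = super()._train_epoch(train_data, epoch_idx)
        return base_loss / (2 ** epoch_idx)


base_losses = SketchTrainer(epochs=3).fit([1, 2, 3])
custom_losses = HalvingTrainer(epochs=3).fit([1, 2, 3])
print(base_losses)
print(custom_losses)
```

Because fit() is inherited unchanged, the subclass only has to describe what one epoch does.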
Example
Here we present a simple Trainer example that performs alternating optimization.
To do this, we override the _train_epoch() method.
First, create a new class NewTrainer that inherits from Trainer.
from recbole.trainer import Trainer

class NewTrainer(Trainer):

    def __init__(self, config, model):
        super(NewTrainer, self).__init__(config, model)
Next, override _train_epoch().
The two losses, computed by calculate_loss1() and calculate_loss2(), are optimized alternately: calculate_loss1() is used in even-numbered epochs and calculate_loss2() in odd-numbered epochs.
    def _train_epoch(self, train_data, epoch_idx):
        self.model.train()
        total_loss = 0.

        if epoch_idx % 2 == 0:
            # Even-numbered epochs: optimize the first loss.
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss1(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        else:
            # Odd-numbered epochs: optimize the second loss.
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss2(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        # Sum of the batch losses accumulated over this epoch.
        return total_loss
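The two branches above differ only in which loss method they call. As a sketch of one way to avoid the duplicated loop, the loss method can be chosen once per epoch and then used in a single loop. The names select_loss_fn, train_epoch and ToyModel below are hypothetical and used only for illustration; ToyModel returns plain floats instead of tensors so that the snippet runs without PyTorch or RecBole.

```python
def select_loss_fn(model, epoch_idx):
    """Return calculate_loss1 on even epochs and calculate_loss2 on odd ones."""
    if epoch_idx % 2 == 0:
        return model.calculate_loss1
    return model.calculate_loss2


class ToyModel:
    """Hypothetical model exposing the two loss methods named in the example."""

    def calculate_loss1(self, interaction):
        return 1.0 * interaction

    def calculate_loss2(self, interaction):
        return 10.0 * interaction


def train_epoch(model, train_data, epoch_idx):
    # Pick the loss once, then run a single loop over the batches.
    loss_fn = select_loss_fn(model, epoch_idx)
    total_loss = 0.0
    for interaction in train_data:
        total_loss += loss_fn(interaction)
    return total_loss


model = ToyModel()
batches = [1, 2, 3]
epoch_losses = [train_epoch(model, batches, epoch_idx) for epoch_idx in range(4)]
print(epoch_losses)
```

In a real _train_epoch(), the loop body would keep the zero_grad(), _check_nan(), backward() and step() calls from the example unchanged; only the choice of loss method moves out of the loop.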
Complete Code
from recbole.trainer import Trainer

class NewTrainer(Trainer):

    def __init__(self, config, model):
        super(NewTrainer, self).__init__(config, model)

    def _train_epoch(self, train_data, epoch_idx):
        self.model.train()
        total_loss = 0.

        if epoch_idx % 2 == 0:
            # Even-numbered epochs: optimize the first loss.
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss1(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        else:
            # Odd-numbered epochs: optimize the second loss.
            for batch_idx, interaction in enumerate(train_data):
                interaction = interaction.to(self.device)
                self.optimizer.zero_grad()
                loss = self.model.calculate_loss2(interaction)
                self._check_nan(loss)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        # Sum of the batch losses accumulated over this epoch.
        return total_loss