trainer
Module Contents
ClientTrainer | An abstract class representing a client backend trainer.
ClientSGDTrainer | Client backend handler; this class provides data-processing methods to the upper layer.
- class ClientTrainer(model, cuda)
Bases: fedlab.core.model_maintainer.ModelMaintainer
An abstract class representing a client backend trainer.
In our framework, the client trainer backend manages its local model. It should have a function called train() that updates the model. If you use our framework to define client activities, make sure your self-defined class subclasses ClientTrainer. All subclasses should override train().
- Parameters
model (torch.nn.Module) – PyTorch model.
cuda (bool) – Use GPUs or not.
- abstract train(self)
Override this method to define the algorithm for training your model. This function should manipulate self._model.
- abstract evaluate(self)
Evaluate the quality of the local model.
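The contract described above (subclass the base trainer, override train() so that it updates self._model, and implement evaluate()) can be sketched with Python's abc module. This is a stdlib-only illustration under stated assumptions: ClientTrainerSketch and CountingTrainer are hypothetical names, the toy "model" is a dict rather than a torch.nn.Module, and the real ClientTrainer also inherits from ModelMaintainer, which this sketch omits.

```python
from abc import ABC, abstractmethod


class ClientTrainerSketch(ABC):
    """Hypothetical stand-in for ClientTrainer: holds a local model and
    forces subclasses to implement train() and evaluate()."""

    def __init__(self, model, cuda=False):
        self._model = model  # the local model this trainer manages
        self.cuda = cuda     # recorded only; this sketch never uses a GPU

    @abstractmethod
    def train(self):
        """Update self._model in place."""

    @abstractmethod
    def evaluate(self):
        """Return a quality measure for self._model."""


class CountingTrainer(ClientTrainerSketch):
    """Hypothetical toy subclass: the 'model' is a dict and each call to
    train() increments a step counter."""

    def train(self):
        self._model["steps"] += 1

    def evaluate(self):
        return self._model["steps"]


trainer = CountingTrainer({"steps": 0})
trainer.train()
trainer.train()
print(trainer.evaluate())  # 2
```

Because train() and evaluate() are declared with @abstractmethod, instantiating the base class directly raises TypeError, which is how the sketch enforces the "all subclasses should override train()" rule.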
- class ClientSGDTrainer(model, data_loader, epochs, optimizer, criterion, cuda=False, logger=Logger())
Bases: ClientTrainer
Client backend handler; this class provides data-processing methods to the upper layer.
- Parameters
model (torch.nn.Module) – PyTorch model.
data_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
epochs (int) – the number of local epochs.
optimizer (torch.optim.Optimizer, optional) – optimizer for this client’s model.
criterion (torch.nn.Loss, optional) – loss function used in local training process.
cuda (bool, optional) – use GPUs or not. Default: False.
logger (Logger, optional) – a Logger object.
- train(self, model_parameters) → None
The client trains its local model on its local dataset.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
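The docstring only states that train() receives serialized model parameters and trains on the local dataset. As a hedged illustration of what such a loop typically does (load the received flat parameter vector into the local model, then run `epochs` passes of mini-batch SGD over the data loader), here is a stdlib-only sketch. LinearSGDTrainerSketch, its 1-D linear model, the mean-squared-error loss, and the learning rate are all hypothetical stand-ins for the PyTorch model, criterion, and optimizer; this is not FedLab's implementation.

```python
class LinearSGDTrainerSketch:
    """Hypothetical pure-Python analogue of ClientSGDTrainer for a 1-D
    linear model y = w * x + b trained with mean-squared-error loss."""

    def __init__(self, data_loader, epochs, lr):
        self.data_loader = data_loader  # iterable of (xs, ys) mini-batches
        self.epochs = epochs            # number of local epochs
        self.lr = lr                    # SGD step size (hypothetical)
        self.params = [0.0, 0.0]        # local model parameters [w, b]

    def train(self, model_parameters):
        # Overwrite the local model with the serialized (flat) parameters
        # received from the server before local training starts.
        self.params = list(model_parameters)
        for _ in range(self.epochs):
            for xs, ys in self.data_loader:
                w, b = self.params
                n = len(xs)
                # Gradients of mean((w*x + b - y)^2) with respect to w and b.
                grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
                grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
                # Plain SGD update, playing the role of optimizer.step().
                self.params = [w - self.lr * grad_w, b - self.lr * grad_b]
        return None


# Noise-free local data from y = 3x + 1, split into 4 interleaved batches.
xs = [i / 10 for i in range(20)]
ys = [3 * x + 1 for x in xs]
batches = [(xs[i::4], ys[i::4]) for i in range(4)]

trainer = LinearSGDTrainerSketch(batches, epochs=200, lr=0.1)
trainer.train([0.0, 0.0])  # start from the "global" parameters [0, 0]
print([round(p, 3) for p in trainer.params])  # [3.0, 1.0]
```

After train() returns, the updated self.params is what a client would serialize and send back to the server; in PyTorch, the flat-vector round trip is commonly done with torch.nn.utils.parameters_to_vector and vector_to_parameters.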