trainer
Module Contents
| Class | Description |
|---|---|
| ClientTrainer | An abstract class representing a client trainer. |
| SGDClientTrainer | Client backend handler; this class provides data-processing methods to the upper layer. |
- class ClientTrainer(model, cuda)
Bases: fedlab.core.model_maintainer.ModelMaintainer
An abstract class representing a client trainer.
In FedLab, the client trainer is the backend that manages a client's local model. It should provide a method, local_process(), that updates this model. If you use the framework to define the behavior of a client, make sure your class subclasses ClientTrainer. All subclasses should override local_process() and the uplink_package property.
- Parameters
model (torch.nn.Module) – PyTorch model.
cuda (bool) – Use GPUs or not.
- property uplink_package(self) → List[torch.Tensor]
Return a list of tensors to upload to the server.
The client manager reads this property. Customize it for new algorithms.
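The docstring invites customizing this property for new algorithms. As one illustration, the sketch below returns the serialized model parameters together with the local sample count, which a server could use for weighted aggregation. It is a dependency-free stand-in: plain Python lists take the place of `torch.Tensor`, and the class `WeightedUplinkTrainer` and its attributes are hypothetical, not part of FedLab.

```python
from typing import List


class WeightedUplinkTrainer:
    """Hypothetical trainer whose uplink carries parameters plus a sample count.

    Plain lists of floats stand in for torch.Tensor so the sketch runs
    without PyTorch or FedLab installed.
    """

    def __init__(self, model_parameters: List[float], num_samples: int):
        self.model_parameters = model_parameters  # serialized (flat) parameters
        self.num_samples = num_samples            # size of the local dataset

    @property
    def uplink_package(self) -> List[List[float]]:
        # Element 0: the model update; element 1: extra metadata for the server.
        return [self.model_parameters, [float(self.num_samples)]]


trainer = WeightedUplinkTrainer([0.5, -1.0, 2.0], num_samples=128)
params, (n,) = trainer.uplink_package
print(params, n)  # [0.5, -1.0, 2.0] 128.0
```

Keeping the model update at index 0 and appending metadata after it is only one possible layout; whatever order is chosen, the server-side code must unpack the list the same way.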
- abstract classmethod local_process(self, payload) → bool
The upper-layer manager calls this function with the payload it has received.
In synchronous mode, return True to end the current FL round.
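The calling convention can be sketched as follows: a hypothetical manager hands each received payload to `local_process` and, in synchronous mode, stops feeding payloads for the round once the trainer returns `True`. The function `run_sync_round`, the class `CountingTrainer`, and the list of string payloads are illustrative assumptions, not FedLab's actual manager code.

```python
from typing import Iterable, List


class CountingTrainer:
    """Hypothetical trainer that finishes its round after `needed` payloads."""

    def __init__(self, needed: int):
        self.needed = needed
        self.received: List[object] = []

    def local_process(self, payload) -> bool:
        self.received.append(payload)
        # Returning True tells the (synchronous) manager the round is over.
        return len(self.received) >= self.needed


def run_sync_round(trainer, payloads: Iterable) -> int:
    """Feed payloads to the trainer until it signals the end of the round.

    Returns the number of payloads consumed.
    """
    consumed = 0
    for payload in payloads:
        consumed += 1
        if trainer.local_process(payload):
            break
    return consumed


trainer = CountingTrainer(needed=2)
print(run_sync_round(trainer, ["p0", "p1", "p2"]))  # 2
```

The boolean return keeps the round-termination decision inside the trainer, so the manager loop stays generic across algorithms.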
- abstract train(self)
Override this method to define the training algorithm for your model. This function should update
self._model.
- abstract evaluate(self)
Evaluate the quality of the local model.
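The subclassing contract can be sketched with Python's `abc` module. `ClientTrainerSketch` below is a hypothetical stand-in that mirrors the documented interface (an abstract `uplink_package` property and abstract `local_process`, `train`, and `evaluate` methods); it is not FedLab's source. `MeanTrainer` is an invented toy subclass that stores its "model" as a list of floats instead of a `torch.nn.Module`.

```python
from abc import ABC, abstractmethod
from typing import List


class ClientTrainerSketch(ABC):
    """Hypothetical stand-in for the documented ClientTrainer interface."""

    def __init__(self, model: List[float], cuda: bool = False):
        self._model = model  # a list of floats stands in for torch.nn.Module
        self.cuda = cuda

    @property
    @abstractmethod
    def uplink_package(self) -> List[List[float]]:
        """Data to upload to the server."""

    @abstractmethod
    def local_process(self, payload) -> bool:
        """Handle a payload from the upper layer."""

    @abstractmethod
    def train(self):
        """Update self._model."""

    @abstractmethod
    def evaluate(self):
        """Report the quality of the local model."""


class MeanTrainer(ClientTrainerSketch):
    """Toy subclass: 'training' moves every weight halfway toward 1.0."""

    @property
    def uplink_package(self):
        return [list(self._model)]

    def local_process(self, payload) -> bool:
        self._model = list(payload[0])  # adopt the parameters sent down
        self.train()
        return True  # signal that this client is done for the round

    def train(self):
        self._model = [w + 0.5 * (1.0 - w) for w in self._model]

    def evaluate(self):
        # Mean squared distance from the toy target 1.0 (lower is better).
        return sum((1.0 - w) ** 2 for w in self._model) / len(self._model)


try:
    ClientTrainerSketch([0.0])
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

t = MeanTrainer([0.0])
print(t.local_process([[0.0, 2.0]]), t.uplink_package, t.evaluate())  # True [[0.5, 1.5]] 0.25
```

Because every member is declared abstract here, Python refuses to instantiate the base class or any subclass that leaves one of them unimplemented.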
- class SGDClientTrainer(model, data_loader, epochs, optimizer, criterion, cuda=False, logger=None)
Bases: ClientTrainer
Client backend handler; this class provides data-processing methods to the upper layer.
- Parameters
model (torch.nn.Module) – PyTorch model.
data_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
epochs (int) – the number of local epochs.
optimizer (torch.optim.Optimizer) – optimizer for this client’s model.
criterion (torch.nn.Module) – loss function used in the local training process, such as torch.nn.CrossEntropyLoss().
cuda (bool, optional) – use GPUs or not. Default: False.
logger (Logger, optional) – Logger object. Default: None.
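To make the roles of `data_loader`, `epochs`, `optimizer`, and `criterion` concrete, here is a dependency-free sketch of the loop such a trainer runs: `epochs` full passes over the loader, computing the loss with the criterion and taking one optimizer step per mini-batch. It fits a one-weight linear model with mean-squared error and plain SGD. The list-of-batches loader, the learning rate, and the names `local_sgd` and `mse_and_grad` are assumptions for illustration; the real class works with PyTorch objects and autograd.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_and_grad(w: float, batch: Batch) -> Tuple[float, float]:
    """Criterion: mean-squared error of y_hat = w * x, and its gradient in w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def local_sgd(w: float, data_loader: List[Batch], epochs: int, lr: float) -> float:
    """Run `epochs` passes over the batches, one SGD step per batch."""
    for _ in range(epochs):
        for batch in data_loader:
            _, grad = mse_and_grad(w, batch)
            w -= lr * grad  # what optimizer.step() does for plain SGD
    return w


# Data generated by y = 3x, split into two mini-batches.
loader = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
w = local_sgd(0.0, loader, epochs=50, lr=0.02)
print(round(w, 4))  # 3.0
```

With PyTorch, the inner step would instead be `optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`, with `criterion` computing the loss on each batch from `data_loader`.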
- property uplink_package(self)
Return a list of tensors to upload to the server.
The client manager reads this property. Customize it for new algorithms.
- local_process(self, payload)
The upper-layer manager calls this function with the payload it has received.
In synchronous mode, return True to end the current FL round.
- train(self, model_parameters) → None
Train the local model on the client's local dataset.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
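Here `model_parameters` is a single serialized vector rather than a per-layer structure. As far as we can tell, FedLab produces it by flattening each parameter tensor and concatenating the results (see its `SerializationTool` utilities); the sketch below mimics that round trip with nested Python lists so the layout is visible. `serialize` and `deserialize` are hypothetical helpers, not FedLab functions.

```python
from typing import List

Layer = List[float]  # one flattened parameter tensor


def serialize(layers: List[Layer]) -> List[float]:
    """Concatenate per-layer parameters into one flat vector."""
    return [value for layer in layers for value in layer]


def deserialize(vector: List[float], sizes: List[int]) -> List[Layer]:
    """Split a flat vector back into layers of the given sizes."""
    if sum(sizes) != len(vector):
        raise ValueError("vector length does not match the layer sizes")
    layers, start = [], 0
    for size in sizes:
        layers.append(vector[start:start + size])
        start += size
    return layers


weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6]]  # three toy layers
flat = serialize(weights)
print(flat)  # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
print(deserialize(flat, [3, 2, 1]) == weights)  # True
```

The receiver must know each layer's size (and, for real tensors, its shape) in the same order used for flattening; in practice that information comes from the local model's own parameter list.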