trainer
Module Contents
ClientTrainer – An abstract class representing a client trainer.
SerialClientTrainer – Base class. Simulate multiple clients in sequence in a single process.
- class ClientTrainer(model: torch.nn.Module, cuda: bool, device: str = None)
Bases: fedlab.core.model_maintainer.ModelMaintainer
An abstract class representing a client trainer.
In FedLab, the client trainer is the backend that manages a client's local model. It must provide a function called local_process() that updates the model. If you use our framework to define client activities, make sure your self-defined class subclasses ClientTrainer. All subclasses should override local_process() and the property uplink_package.
- Parameters:
model (torch.nn.Module) – PyTorch model.
cuda (bool) – Use GPUs or not.
device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.
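As a rough illustration of the subclassing contract above, the sketch below models it with Python's abc module. ClientTrainerSketch, EchoClient, and IncompleteClient are hypothetical names, not part of FedLab, and plain lists of floats stand in for torch.Tensor so the example runs without PyTorch. A subclass that leaves uplink_package abstract cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import List, Optional

# Hypothetical stand-in for fedlab's ClientTrainer. Lists of floats play the
# role of torch.Tensor so the sketch runs without PyTorch.
class ClientTrainerSketch(ABC):
    def __init__(self, model: List[float], cuda: bool = False, device: Optional[str] = None):
        self._model = model
        self.cuda = cuda
        self.device = device

    @property
    @abstractmethod
    def uplink_package(self) -> List[List[float]]:
        """Tensors to upload to the server."""

    @abstractmethod
    def local_process(self, payload: List[List[float]]):
        """Handle the payload received from the upper-layer manager."""


class EchoClient(ClientTrainerSketch):
    # Adopts the received global parameters and uploads a copy of them.
    def local_process(self, payload):
        self._model = list(payload[0])

    @property
    def uplink_package(self):
        return [list(self._model)]


class IncompleteClient(ClientTrainerSketch):
    # Overrides local_process but not uplink_package, so it stays abstract.
    def local_process(self, payload):
        self._model = list(payload[0])


client = EchoClient(model=[0.0, 0.0])
client.local_process([[1.0, 2.0]])
print(client.uplink_package)  # [[1.0, 2.0]]

try:
    IncompleteClient(model=[0.0])
except TypeError as err:
    print("TypeError:", err)  # abstract uplink_package blocks instantiation
```

Stacking @property on top of @abstractmethod is the standard abc idiom for an abstract property, which is how the sketch models uplink_package; returning a copy keeps the server side from mutating the client's local state.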
- abstract property uplink_package: List[torch.Tensor]
Return a list of tensors to upload to the server.
This property is read by the client manager. Customize it for new algorithms.
- abstract setup_dataset()
Set up the local dataset self.dataset for clients.
- abstract setup_optim()
Set up variables for optimization algorithms.
- abstract classmethod local_process(payload: List[torch.Tensor])
The upper-layer manager calls this function with the received payload.
In synchronous mode, return True to end the current FL round.
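One possible reading of the synchronous-mode rule is that the manager keeps handing payloads to local_process and closes the round once it returns True. The sketch below assumes that reading; CountingClient, steps_per_round, and run_sync_round are hypothetical and do not reflect FedLab's actual manager code.

```python
from typing import List


class CountingClient:
    """Hypothetical client that needs `steps_per_round` calls before it is done."""

    def __init__(self, steps_per_round: int):
        self.steps_per_round = steps_per_round
        self.calls = 0
        self.model = [0.0]

    def local_process(self, payload: List[List[float]]) -> bool:
        self.model = list(payload[0])
        self.calls += 1
        # Returning True tells the (assumed) synchronous manager to close the round.
        return self.calls % self.steps_per_round == 0


def run_sync_round(client, payload, max_calls: int = 100) -> int:
    """Hypothetical manager loop: call local_process until it returns True."""
    for n in range(1, max_calls + 1):
        if client.local_process(payload):
            return n
    raise RuntimeError("round did not finish within max_calls")


client = CountingClient(steps_per_round=3)
print(run_sync_round(client, [[1.0]]))  # 3
```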
- abstract train()
Override this method to define the training procedure. This function should manipulate self._model.
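The sketch below shows one way setup_dataset(), setup_optim(), and train() might fit together, with train() updating self._model in place. LinearClient is hypothetical: its "model" is a single weight w for y = w * x, and the optimizer is hand-written SGD on squared error rather than a torch.optim optimizer.

```python
from typing import List, Tuple


class LinearClient:
    """Hypothetical trainer; self._model is [w] for the model y = w * x."""

    def __init__(self):
        self._model: List[float] = [0.0]

    def setup_dataset(self, dataset: List[Tuple[float, float]]):
        self.dataset = dataset  # local (x, y) pairs

    def setup_optim(self, lr: float = 0.1, epochs: int = 50):
        self.lr = lr
        self.epochs = epochs

    def train(self):
        # Plain SGD on squared error; writes the new weight back into self._model.
        for _ in range(self.epochs):
            for x, y in self.dataset:
                w = self._model[0]
                grad = 2.0 * (w * x - y) * x  # d/dw of (w * x - y) ** 2
                self._model[0] = w - self.lr * grad


client = LinearClient()
client.setup_dataset([(1.0, 2.0), (2.0, 4.0)])
client.setup_optim(lr=0.05, epochs=200)
client.train()
print(round(client._model[0], 4))  # 2.0
```

Keeping dataset and optimizer setup in separate hooks lets train() stay focused on the update rule.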
- abstract validate()
Validate quality of local model.
- abstract evaluate()
Evaluate quality of local model.
- class SerialClientTrainer(model: torch.nn.Module, num_clients: int, cuda: bool, device: str = None, personal: bool = False)
Bases: fedlab.core.model_maintainer.SerialModelMaintainer
Base class. Simulate multiple clients in sequence in a single process.
- Parameters:
model (torch.nn.Module) – Model used in this federation.
num_clients (int) – Number of clients in current trainer.
cuda (bool) – Use GPUs or not. Default: False.
device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.
personal (bool, optional) – If True is passed, SerialModelMaintainer will generate a copy of the local parameter list for each client and maintain the copies separately. These parameters are indexed by [0, num_clients - 1]. Defaults to False.
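The personal flag can be pictured as follows: with personal=True each client gets its own copy of the parameters, indexed from 0 to num_clients - 1, while with personal=False all clients share one set. SerialMaintainerSketch and params_for are hypothetical names, and lists of floats stand in for tensors; FedLab's SerialModelMaintainer may organize this state differently.

```python
from typing import List, Optional


class SerialMaintainerSketch:
    """Hypothetical stand-in for the per-client bookkeeping described above."""

    def __init__(self, model_params: List[float], num_clients: int, personal: bool = False):
        self.num_clients = num_clients
        self.personal = personal
        self._global = list(model_params)
        # With personal=True, keep an independent copy per client (index 0..num_clients-1).
        self.parameters_list: Optional[List[List[float]]] = (
            [list(model_params) for _ in range(num_clients)] if personal else None
        )

    def params_for(self, client_id: int) -> List[float]:
        if not 0 <= client_id < self.num_clients:
            raise IndexError(f"client_id must be in [0, {self.num_clients - 1}]")
        # Personal mode returns the client's own copy; otherwise the shared parameters.
        return self.parameters_list[client_id] if self.personal else self._global


m = SerialMaintainerSketch([0.0, 0.0], num_clients=3, personal=True)
m.params_for(1)[0] = 5.0
print(m.params_for(0), m.params_for(1))  # [0.0, 0.0] [5.0, 0.0]
```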
- abstract property uplink_package: List[List[torch.Tensor]]
Return a list of tensor lists to upload to the server.
This property is read by the client manager. Customize it for new algorithms.
- abstract setup_dataset()
Override this function to set up the local datasets for clients.
- abstract setup_optim()
- abstract classmethod local_process(id_list: list, payload: List[torch.Tensor])
Define the local main process.
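A serial trainer runs every selected client one after another inside the same process and then exposes one package per client. In the sketch below, SerialMeanTrainer is hypothetical, its toy train() simply moves a single parameter halfway toward each client's local mean, and uplink_package returns one inner list of "tensors" per client in id_list order. Lists of floats stand in for torch.Tensor.

```python
from typing import List


class SerialMeanTrainer:
    """Hypothetical serial trainer: simulates clients one after another in one process."""

    def __init__(self, local_data: List[List[float]]):
        self.local_data = local_data  # local_data[i] holds client i's samples
        self.num_clients = len(local_data)
        self.cache: List[List[List[float]]] = []

    def train(self, global_params: List[float], client_id: int) -> List[float]:
        # Toy "training": move the single parameter halfway toward the local mean.
        data = self.local_data[client_id]
        mean = sum(data) / len(data)
        return [0.5 * global_params[0] + 0.5 * mean]

    def local_process(self, id_list: List[int], payload: List[List[float]]):
        global_params = payload[0]
        self.cache = []
        for cid in id_list:  # clients run in sequence, not in parallel
            self.cache.append([self.train(global_params, cid)])

    @property
    def uplink_package(self) -> List[List[List[float]]]:
        # One inner list of "tensors" per client, in id_list order.
        return self.cache


trainer = SerialMeanTrainer([[1.0, 3.0], [10.0]])
trainer.local_process([1, 0], [[0.0]])
print(trainer.uplink_package)  # [[[5.0]], [[1.0]]]
```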
- abstract train()
Override this method to define the training algorithm for your model. This function should manipulate self._model.
- abstract evaluate()
Evaluate quality of local model.
- abstract validate()
Validate quality of local model.