trainer

Module Contents

ClientTrainer

An abstract class representing a client trainer.

SerialClientTrainer

Base class. Simulate multiple clients in sequence in a single process.

class ClientTrainer(model: torch.nn.Module, cuda: bool, device: str = None)

Bases: fedlab.core.model_maintainer.ModelMaintainer

An abstract class representing a client trainer.

In FedLab, the client trainer is the backend that manages a client's local model. It should provide a function called local_process() that updates this model.

If you use FedLab to define client activities, make sure your custom class subclasses ClientTrainer. All subclasses should override local_process() and the uplink_package property.
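The subclassing contract above can be illustrated with a minimal, dependency-free sketch. It does not import FedLab: `MiniClientTrainer` is a hypothetical stand-in for `ClientTrainer`, `AveragingTrainer` is a hypothetical subclass, and plain lists of floats stand in for `torch.Tensor` parameters. It only shows a concrete trainer overriding `local_process()` and the `uplink_package` property.

```python
from abc import ABC, abstractmethod
from typing import List


class MiniClientTrainer(ABC):
    """Hypothetical stand-in for fedlab's ClientTrainer (not the real class)."""

    def __init__(self, params: List[float]):
        # Plain floats stand in for the model's torch.Tensor parameters.
        self.params = list(params)

    @property
    @abstractmethod
    def uplink_package(self) -> List[List[float]]:
        """Return the list of 'tensors' to upload to the server."""

    @abstractmethod
    def local_process(self, payload: List[List[float]]) -> None:
        """Handle the payload received from the upper-layer manager."""


class AveragingTrainer(MiniClientTrainer):
    """Hypothetical subclass: adopt the server's parameters, then nudge them toward local data."""

    def __init__(self, params: List[float], local_target: float):
        super().__init__(params)
        self.local_target = local_target

    @property
    def uplink_package(self) -> List[List[float]]:
        # Upload a copy so the caller cannot mutate local state.
        return [list(self.params)]

    def local_process(self, payload: List[List[float]]) -> None:
        # payload[0] holds the global parameters sent by the server.
        self.params = list(payload[0])
        # Move each parameter halfway toward the local target (stand-in for training).
        self.params = [p + 0.5 * (self.local_target - p) for p in self.params]


trainer = AveragingTrainer(params=[0.0, 0.0], local_target=2.0)
trainer.local_process([[1.0, 3.0]])
print(trainer.uplink_package)  # [[1.5, 2.5]]
```

Because both members are abstract, the base class cannot be instantiated until a subclass supplies them, which mirrors the "should override" requirement.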

Parameters:
  • model (torch.nn.Module) – PyTorch model.

  • cuda (bool) – Use GPUs or not.

  • device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.

property uplink_package

Return a tensor list for uploading to the server.

This property is accessed by the client manager. Customize it for new algorithms.

abstract setup_dataset()

Set up the local dataset self.dataset for clients.

abstract setup_optim()

Set up variables for optimization algorithms.

abstract classmethod local_process(payload: List[torch.Tensor])

The upper-layer manager calls this function with the received payload.

In synchronous mode, return True to end the current FL round.
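The round-ending contract can be sketched without FedLab. `ThresholdTrainer` and `run_sync_round` are hypothetical names, not FedLab APIs, and FedLab's real client manager may drive `local_process()` differently; the sketch only shows a manager that keeps passing payloads to `local_process()` until it returns True.

```python
from typing import List, Optional


class ThresholdTrainer:
    """Hypothetical trainer that ends the round after `needed` payloads."""

    def __init__(self, needed: int):
        self.needed = needed
        self.received: List[List[float]] = []

    def local_process(self, payload: List[float]) -> bool:
        # Store the payload (stand-in for local training on it).
        self.received.append(payload)
        # Returning True signals the synchronous manager to end the current round.
        return len(self.received) >= self.needed


def run_sync_round(trainer: ThresholdTrainer, payloads: List[List[float]]) -> Optional[int]:
    """Hypothetical manager loop: return how many payloads were consumed before the round ended."""
    for count, payload in enumerate(payloads, start=1):
        if trainer.local_process(payload):
            return count
    return None  # the round never ended


print(run_sync_round(ThresholdTrainer(needed=2), [[0.1], [0.2], [0.3]]))  # 2
```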

abstract train()

Override this method to define the training procedure. This function should manipulate self._model.
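A toy illustration of a train() override that updates self._model in place. To stay dependency-free it assumes a hypothetical model: a single weight `w` stored in a dict and fitted to y = w * x with plain stochastic gradient descent on squared error. A real FedLab trainer would instead update a torch.nn.Module with a torch optimizer.

```python
from typing import List, Tuple


class ToyTrainer:
    """Hypothetical trainer; self._model is a dict holding one weight w for y = w * x."""

    def __init__(self, dataset: List[Tuple[float, float]], lr: float = 0.05, epochs: int = 50):
        self._model = {"w": 0.0}
        self.dataset = dataset
        self.lr = lr
        self.epochs = epochs

    def train(self) -> None:
        # Plain SGD on the squared error (w * x - y) ** 2, one sample at a time.
        for _ in range(self.epochs):
            for x, y in self.dataset:
                grad = 2 * (self._model["w"] * x - y) * x  # derivative with respect to w
                self._model["w"] -= self.lr * grad  # mutate self._model in place


trainer = ToyTrainer(dataset=[(1.0, 2.0), (2.0, 4.0)])
trainer.train()
print(round(trainer._model["w"], 4))  # 2.0
```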

abstract validate()

Validate the quality of the local model.

abstract evaluate()

Evaluate the quality of the local model.

class SerialClientTrainer(model: torch.nn.Module, num_clients: int, cuda: bool, device: str = None, personal: bool = False)

Bases: fedlab.core.model_maintainer.SerialModelMaintainer

Base class. Simulate multiple clients in sequence in a single process.

Parameters:
  • model (torch.nn.Module) – Model used in this federation.

  • num_clients (int) – Number of clients in current trainer.

  • cuda (bool) – Use GPUs or not.

  • device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.

  • personal (bool, optional) – If True is passed, SerialModelMaintainer creates a copy of the local parameters for each client and maintains them separately. These copies are indexed by [0, num_clients - 1]. Defaults to False.
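The bookkeeping behind personal=True can be sketched as follows. `MiniSerialMaintainer` and its attribute `personal_params` are hypothetical, not SerialModelMaintainer's real internals, and plain float lists stand in for model parameters; the sketch only shows one independent parameter copy per client, indexed 0 to num_clients - 1.

```python
import copy
from typing import List, Optional


class MiniSerialMaintainer:
    """Hypothetical stand-in for the `personal` option of SerialModelMaintainer."""

    def __init__(self, params: List[float], num_clients: int, personal: bool = False):
        self.params = list(params)  # parameters shared by all simulated clients
        self.num_clients = num_clients
        # With personal=True, keep one independent copy per client (index 0 .. num_clients - 1).
        self.personal_params: Optional[List[List[float]]] = (
            [copy.deepcopy(self.params) for _ in range(num_clients)] if personal else None
        )


m = MiniSerialMaintainer([0.5, -1.0], num_clients=3, personal=True)
m.personal_params[0][0] = 9.9  # update client 0 only
print(m.personal_params[1])  # [0.5, -1.0]
```

Deep copies matter here: if every index pointed to the same list, an update for one client would silently change all of them.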

property uplink_package

Return a tensor list for uploading to the server.

This property is accessed by the client manager. Customize it for new algorithms.

abstract setup_dataset()

Override this function to set up the local dataset for clients.

abstract setup_optim()

abstract classmethod local_process(id_list: list, payload: List[torch.Tensor])

Define the local main process.
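A dependency-free sketch of how a serial trainer might process the clients in id_list one after another in a single process. `MiniSerialTrainer` and its helper `_local_update` are hypothetical, the toy "training" just averages the global parameter with each client's data mean, and plain float lists stand in for `torch.Tensor` payloads.

```python
from typing import Dict, List


class MiniSerialTrainer:
    """Hypothetical stand-in for SerialClientTrainer: clients are simulated one after another."""

    def __init__(self, client_data: Dict[int, List[float]]):
        self.client_data = client_data  # client id -> local samples
        self.cache: List[List[float]] = []  # one upload per processed client

    def _local_update(self, global_param: float, data: List[float]) -> List[float]:
        # Toy local "training": average the global parameter with the local data mean.
        local_mean = sum(data) / len(data)
        return [(global_param + local_mean) / 2]

    def local_process(self, id_list: List[int], payload: List[List[float]]) -> None:
        global_param = payload[0][0]
        self.cache = []
        # Sequential simulation: each selected client runs in turn in this process.
        for cid in id_list:
            self.cache.append(self._local_update(global_param, self.client_data[cid]))

    @property
    def uplink_package(self) -> List[List[float]]:
        return self.cache


trainer = MiniSerialTrainer({0: [1.0, 3.0], 1: [5.0, 7.0], 2: [0.0]})
trainer.local_process(id_list=[0, 1], payload=[[2.0]])
print(trainer.uplink_package)  # [[2.0], [4.0]]
```

Running clients sequentially trades wall-clock time for memory: only one client's update is computed at a time, which is what lets a single process simulate many clients.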

abstract train()

Override this method to define the algorithm for training your model. This function should manipulate self._model.

abstract evaluate()

Evaluate the quality of the local model.

abstract validate()

Validate the quality of the local model.