basic_client

Module Contents

  • SGDClientTrainer – Client backend handler. This class provides data-processing methods to the upper layer.

  • SGDSerialClientTrainer – Train multiple clients in a single process.

class SGDClientTrainer(model: torch.nn.Module, cuda: bool = False, device: str = None, logger: fedlab.utils.Logger = None)

Bases: fedlab.core.client.trainer.ClientTrainer

Client backend handler. This class provides data-processing methods to the upper layer.

Parameters:
  • model (torch.nn.Module) – PyTorch model.

  • cuda (bool, optional) – Use GPUs or not. Default: False.

  • device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.

  • logger (Logger, optional) – Object of Logger.

Return a list of tensors to upload to the server.

This attribute is called by the client manager. Customize it for new algorithms.
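To illustrate the attribute described above, here is a minimal pure-Python sketch. It is not FedLab code: `ToyClientTrainer`, `WeightedToyTrainer`, and the attribute name `uplink_package` are hypothetical, and plain lists of floats stand in for torch tensors. The base class uploads one flattened (serialized) parameter vector; the subclass shows how a new algorithm might append extra data, here a local sample count that a server could use for weighted averaging.

```python
from typing import List


class ToyClientTrainer:
    """Hypothetical stand-in for a client trainer (not FedLab's class)."""

    def __init__(self, layers: List[List[float]]):
        # Per-layer parameters, e.g. [weights, bias]; lists stand in for tensors.
        self.layers = layers

    @property
    def model_parameters(self) -> List[float]:
        # Flatten ("serialize") all layers into a single vector.
        return [p for layer in self.layers for p in layer]

    @property
    def uplink_package(self) -> List[List[float]]:
        # Hypothetical upload attribute: a list of "tensors" for the server.
        return [self.model_parameters]


class WeightedToyTrainer(ToyClientTrainer):
    """Customized upload: also send the local sample count."""

    def __init__(self, layers: List[List[float]], num_samples: int):
        super().__init__(layers)
        self.num_samples = num_samples

    @property
    def uplink_package(self) -> List[List[float]]:
        return [self.model_parameters, [float(self.num_samples)]]


base = ToyClientTrainer([[0.5, -1.0], [0.1]])
weighted = WeightedToyTrainer([[0.5, -1.0], [0.1]], num_samples=600)
print(base.uplink_package)      # [[0.5, -1.0, 0.1]]
print(weighted.uplink_package)  # [[0.5, -1.0, 0.1], [600.0]]
```

Keeping the upload as a list lets a customized algorithm add entries without changing what the manager expects to receive.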

setup_dataset(dataset)

Set up the local dataset self.dataset for clients.

setup_optim(epochs, batch_size, lr)

Set up local optimization configuration.

Parameters:
  • epochs (int) – Local epochs.

  • batch_size (int) – Local batch size.

  • lr (float) – Learning rate.
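The three values passed to setup_optim jointly determine how much local work a client does per round. The sketch below uses a hypothetical helper (`OptimConfig` is not part of FedLab) and assumes one optimizer step per minibatch with the final partial batch kept, which matches the default `drop_last=False` of `torch.utils.data.DataLoader`.

```python
import math
from dataclasses import dataclass


@dataclass
class OptimConfig:
    """Hypothetical container for the values given to setup_optim."""

    epochs: int
    batch_size: int
    lr: float

    def total_steps(self, num_samples: int) -> int:
        # One SGD step per minibatch; the last batch may be smaller.
        # lr only scales each step, it does not change the count.
        return self.epochs * math.ceil(num_samples / self.batch_size)


cfg = OptimConfig(epochs=5, batch_size=32, lr=0.01)
print(cfg.total_steps(1000))  # 5 * ceil(1000 / 32) = 5 * 32 = 160
```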

local_process(payload, id)

The upper-layer manager calls this function with the accepted payload.

In synchronous mode, return True to end the current FL round.
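The round trip that local_process handles can be sketched as follows. This is a hypothetical pure-Python outline, not FedLab's implementation: it assumes the first element of the payload is the serialized global model, uses a placeholder in place of real training, and returns True as the synchronous-mode signal described above.

```python
from typing import List, Optional


class ToyClient:
    """Hypothetical client used only to illustrate the local_process flow."""

    def __init__(self) -> None:
        self.local_parameters: Optional[List[float]] = None

    def train(self, model_parameters: List[float]) -> List[float]:
        # Placeholder for real local training: halve every parameter.
        return [0.5 * p for p in model_parameters]

    def local_process(self, payload: List[List[float]], id: int) -> bool:
        # Assumption: payload[0] holds the serialized global model parameters.
        global_parameters = payload[0]
        self.local_parameters = self.train(global_parameters)
        # Synchronous mode: True tells the manager this client's round is done.
        return True


client = ToyClient()
finished = client.local_process([[1.0, 2.0]], id=0)
print(finished, client.local_parameters)  # True [0.5, 1.0]
```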

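train(model_parameters, train_loader) → None

Client trains its local model on the local dataset.

Parameters:
  • model_parameters (torch.Tensor) – Serialized model parameters.

  • train_loader (torch.utils.data.DataLoader) – Data loader over this client's local dataset.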
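A minimal sketch of what one call to train does, assuming the model is a one-dimensional linear regressor y = w*x + b trained with minibatch SGD on mean squared error. Plain floats replace torch tensors, a list of batches replaces the DataLoader, and the extra `epochs` and `lr` arguments stand in for the values that setup_optim stores; this standalone function is illustrative and is not FedLab's method.

```python
from typing import Iterable, List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) samples


def train(model_parameters: List[float], train_loader: Iterable[Batch],
          epochs: int, lr: float) -> List[float]:
    w, b = model_parameters  # "deserialize" the flat parameter vector
    for _ in range(epochs):
        for batch in train_loader:
            n = len(batch)
            # Gradients of the batch MSE: (1/n) * sum((w*x + b - y)^2).
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
            w -= lr * grad_w
            b -= lr * grad_b
    return [w, b]  # "serialize" again for upload


loader = [
    [(-1.0, -1.0), (1.0, 3.0)],  # batch 1, samples of y = 2x + 1
    [(-0.5, 0.0), (0.5, 2.0)],   # batch 2
]
w, b = train([0.0, 0.0], loader, epochs=50, lr=0.1)
print(round(w, 3), round(b, 3))  # 2.0 1.0
```

Because the outer loop iterates over train_loader once per epoch, the loader must be re-iterable: a list or a PyTorch DataLoader works, but a one-shot generator would be exhausted after the first epoch.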

class SGDSerialClientTrainer(model, num_clients, cuda=False, device=None, logger=None, personal=False)

Bases: fedlab.core.client.trainer.SerialClientTrainer

Train multiple clients in a single process.

Customize _get_dataloader() or _train_alone() to implement algorithm-specific client behavior.

Parameters:
  • model (torch.nn.Module) – Model used in this federation.

  • num_clients (int) – Number of clients in current trainer.

  • cuda (bool) – Use GPUs or not. Default: False.

  • device (str, optional) – Assign model/data to the given GPUs. E.g., ‘device:0’ or ‘device:0,1’. Defaults to None.

  • logger (Logger, optional) – Object of Logger.

  • personal (bool, optional) – If True, SerialModelMaintainer generates a separate copy of the local parameters for each client and maintains each copy independently. These parameters are indexed by [0, num_clients - 1]. Defaults to False.
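The effect of personal=True can be pictured with the hypothetical bookkeeping class below (`ToySerialMaintainer` is not FedLab's SerialModelMaintainer, and lists of floats stand in for parameter tensors). With personal=True, each client gets its own copy of the parameters, indexed from 0 to num_clients - 1; otherwise all clients share a single parameter vector.

```python
from typing import List, Optional


class ToySerialMaintainer:
    """Hypothetical per-client parameter bookkeeping for a serial trainer."""

    def __init__(self, init_params: List[float], num_clients: int,
                 personal: bool = False):
        self.num_clients = num_clients
        self.shared = list(init_params)
        # One independent copy per client. A comprehension is used because
        # [copy] * num_clients would make every entry alias the same list.
        self.personal: Optional[List[List[float]]] = (
            [list(init_params) for _ in range(num_clients)] if personal else None
        )

    def params_for(self, client_id: int) -> List[float]:
        if not 0 <= client_id < self.num_clients:
            raise IndexError(f"client_id must be in [0, {self.num_clients - 1}]")
        if self.personal is None:
            return self.shared  # all clients read and write the same vector
        return self.personal[client_id]


m = ToySerialMaintainer([0.0, 0.0], num_clients=3, personal=True)
m.params_for(1)[0] = 5.0                  # update client 1's copy only
print(m.params_for(0), m.params_for(1))   # [0.0, 0.0] [5.0, 0.0]
```

The trade-off is memory: personal copies cost num_clients times the parameter storage, in exchange for letting each client keep its own state across rounds.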

Return a list of tensors to upload to the server.

This attribute is called by the client manager. Customize it for new algorithms.

setup_dataset(dataset)

Override this function to set up the local dataset for clients.

setup_optim(epochs, batch_size, lr)

Set up local optimization configuration.

Parameters:
  • epochs (int) – Local epochs.

  • batch_size (int) – Local batch size.

  • lr (float) – Learning rate.

local_process(payload, id_list)

Define the local main process.
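The serial variant runs the single-client logic once per selected client inside one process. In the hypothetical sketch below (`ToySerialTrainer`, `client_data`, `cache`, and `uplink_package` are illustrative names, not confirmed FedLab API), the model is a single scalar estimate of each client's data mean. Every client in id_list starts from the same global parameters in payload[0], and the per-client results are collected in a list that would then be uploaded.

```python
from typing import Dict, List


class ToySerialTrainer:
    """Hypothetical serial trainer: trains several clients one after another."""

    def __init__(self, client_data: Dict[int, List[float]], epochs: int, lr: float):
        self.client_data = client_data  # client id -> local samples
        self.epochs = epochs
        self.lr = lr
        self.cache: List[List[float]] = []

    def train(self, model_parameters: List[float], data: List[float]) -> List[float]:
        theta = model_parameters[0]
        for _ in range(self.epochs):
            # Gradient of the mean of (theta - x)^2 over the local samples.
            grad = sum(2 * (theta - x) for x in data) / len(data)
            theta -= self.lr * grad
        return [theta]

    def local_process(self, payload: List[List[float]], id_list: List[int]) -> None:
        global_parameters = payload[0]  # assumption: serialized global model
        self.cache = []
        for cid in id_list:
            # Each client restarts from the same global parameters.
            self.cache.append(self.train(global_parameters, self.client_data[cid]))

    @property
    def uplink_package(self) -> List[List[float]]:
        return self.cache  # one entry per trained client, in id_list order


trainer = ToySerialTrainer({0: [1.0, 3.0], 1: [10.0]}, epochs=1, lr=0.5)
trainer.local_process([[0.0]], id_list=[0, 1])
print(trainer.uplink_package)  # [[2.0], [10.0]]
```

With lr=0.5 a single step lands exactly on each client's local mean for this loss, which keeps the example's numbers easy to check.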

train(model_parameters, train_loader)

Single round of local training for one client.

Note

Override this method to customize the PyTorch training pipeline.

Parameters: