trainer

Module Contents

SerialTrainer

Base class. Train multiple clients in sequence with a single process.

SubsetSerialTrainer

Train multiple clients in sequence in a single process, each on its own subset of a shared dataset.

class SerialTrainer(model, client_num, aggregator=None, cuda=False, logger=Logger())

Bases: fedlab.core.client.trainer.ClientTrainer

Base class. Train multiple clients in sequence with a single process.

Parameters
  • model (torch.nn.Module) – Model used in this federation.

  • client_num (int) – Number of clients in the current trainer.

  • aggregator (Aggregators, callable, optional) – Function to perform aggregation on a list of serialized model parameters.

  • cuda (bool) – Whether to use GPUs. Default: False.

  • logger (Logger, optional) – Logger instance used for logging. Default: Logger().
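The aggregator is described only as a callable applied to a list of serialized model parameters. As a minimal, framework-free sketch, assuming each serialized parameter vector is a flat list of floats (standing in for a 1-D torch.Tensor), a FedAvg-style weighted average could look like this. The function name `weighted_average` and its `weights` argument are hypothetical and are not part of FedLab's Aggregators API.

```python
from typing import List, Optional, Sequence


def weighted_average(
    serialized_params: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Hypothetical aggregator: element-wise (weighted) mean of flat parameter vectors.

    Each vector stands in for one client's serialized model parameters.
    With weights=None every client counts equally.
    """
    if not serialized_params:
        raise ValueError("need at least one parameter vector")
    dim = len(serialized_params[0])
    if any(len(p) != dim for p in serialized_params):
        raise ValueError("all parameter vectors must have the same length")
    if weights is None:
        weights = [1.0] * len(serialized_params)
    if len(weights) != len(serialized_params):
        raise ValueError("one weight per parameter vector is required")
    total = float(sum(weights))
    # Accumulate w_k * theta_k element-wise, then normalize by the total weight.
    result = [0.0] * dim
    for w, params in zip(weights, serialized_params):
        for i, value in enumerate(params):
            result[i] += w * value
    return [v / total for v in result]


# Two clients; the second holds three times as much data.
print(weighted_average([[1.0, 2.0], [3.0, 6.0]], weights=[1, 3]))  # [2.5, 5.0]
```

Weighting by local dataset size is one common choice; passing no weights gives a plain mean.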

abstract _train_alone(self, model_parameters, train_loader)

Train local model with model_parameters on train_loader.

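Parameters
  • model_parameters (torch.Tensor) – Serialized model parameters.

  • train_loader (torch.utils.data.DataLoader) – DataLoader over this client's local data.
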
abstract _get_dataloader(self, client_id)

Get DataLoader for client_id.

train(self, model_parameters, id_list, aggregate=False)

Train the local model on each client's dataset in turn, for every client id in id_list.

Parameters
  • model_parameters (torch.Tensor) – Serialized model parameters.

  • id_list (list[int]) – IDs of the clients to train in this serial round.

  • aggregate (bool) – Whether to perform partial aggregation on this group of clients’ local models at the end of each local training round.

Note

Normally, aggregation is performed by the server. The aggregate option performs partial aggregation over the current client group instead, which can reduce the server's aggregation workload.

Returns

Serialized model parameters aggregated over this client group when aggregate is True; otherwise, a list of the serialized model parameters of each client.
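The control flow described for train() can be sketched without PyTorch. This is a minimal illustration, not FedLab's implementation: serialized parameters are flat lists of floats, and `MiniSerialTrainer`, `ToyTrainer` and `mean_aggregator` are hypothetical names. The toy `_train_alone` simply pulls every parameter halfway toward the mean of the client's data, so the effect of partial aggregation is easy to check by hand.

```python
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Optional, Sequence

Params = List[float]  # stand-in for a serialized torch.Tensor


class MiniSerialTrainer(ABC):
    """Hypothetical sketch of a serial trainer: one process, clients in turn."""

    def __init__(self, client_num: int,
                 aggregator: Optional[Callable[[List[Params]], Params]] = None):
        self.client_num = client_num
        self.aggregator = aggregator

    @abstractmethod
    def _get_dataloader(self, client_id: int) -> Iterable[float]:
        """Return the local data of the client at index client_id."""

    @abstractmethod
    def _train_alone(self, model_parameters: Params,
                     train_loader: Iterable[float]) -> Params:
        """Run one round of local training and return the new parameters."""

    def train(self, model_parameters: Params, id_list: Sequence[int],
              aggregate: bool = False):
        param_list = []
        for client_id in id_list:
            # Every client starts from a copy of the same global parameters.
            loader = self._get_dataloader(client_id)
            param_list.append(self._train_alone(list(model_parameters), loader))
        if aggregate and self.aggregator is not None:
            # Partial aggregation over this client group only.
            return self.aggregator(param_list)
        return param_list


class ToyTrainer(MiniSerialTrainer):
    """Toy client: each parameter is pulled halfway toward the data mean."""

    def __init__(self, client_data: List[List[float]], aggregator=None):
        super().__init__(len(client_data), aggregator)
        self.client_data = client_data

    def _get_dataloader(self, client_id):
        return self.client_data[client_id]

    def _train_alone(self, model_parameters, train_loader):
        data = list(train_loader)
        mean = sum(data) / len(data)
        return [p + 0.5 * (mean - p) for p in model_parameters]


def mean_aggregator(param_list):
    # Element-wise mean over the clients' parameter vectors.
    return [sum(vals) / len(vals) for vals in zip(*param_list)]


trainer = ToyTrainer([[2.0, 2.0], [4.0, 8.0]], aggregator=mean_aggregator)
print(trainer.train([0.0], id_list=[0, 1]))                  # [[1.0], [3.0]]
print(trainer.train([0.0], id_list=[0, 1], aggregate=True))  # [2.0]
```

Marking both hooks with abc.abstractmethod means the base class cannot be instantiated until a subclass supplies the data access and the local training step.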

class SubsetSerialTrainer(model, dataset, data_slices, aggregator=None, logger=Logger(), cuda=False, args=None)

Bases: SerialTrainer

Train multiple clients in sequence in a single process, each on its own subset of a shared dataset.

Override _get_dataloader() or _train_alone() to implement an algorithm-specific client design.

Parameters
  • model (torch.nn.Module) – Model used in this federation.

  • dataset (torch.utils.data.Dataset) – Local dataset for this group of clients.

  • data_slices (list[list]) – Subsets of dataset indices; each inner list holds the indices of one client's local data.

  • aggregator (Aggregators, callable, optional) – Function to perform aggregation on a list of model parameters.

  • logger (Logger, optional) – Logger instance used for logging. Default: Logger().

  • cuda (bool) – Whether to use GPUs. Default: False.

  • args (dict, optional) – Additional algorithm-specific settings not covered by the other parameters.

Note

len(data_slices) == client_num; that is, each index subset in data_slices corresponds one-to-one to a client's local dataset.
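One way to produce data_slices that satisfies this one-to-one correspondence is a random, roughly equal (IID) split of the dataset indices. The helper `iid_slices` below is a hypothetical sketch, not a FedLab function; it only assumes that the dataset size is known.

```python
import random
from typing import List


def iid_slices(num_samples: int, client_num: int, seed: int = 0) -> List[List[int]]:
    """Hypothetical helper: split range(num_samples) into client_num disjoint index lists."""
    if client_num <= 0:
        raise ValueError("client_num must be positive")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    # Client k takes every client_num-th shuffled index starting at position k,
    # so slice sizes differ by at most one.
    return [sorted(indices[k::client_num]) for k in range(client_num)]


data_slices = iid_slices(num_samples=10, client_num=3)
print([len(s) for s in data_slices])  # [4, 3, 3]
```

Non-IID partitions (for example, grouping indices by label) satisfy the same contract as long as there is exactly one index list per client.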

_get_dataloader(self, client_id)

Return the training DataLoader used in train() for the client with index client_id.

Parameters

client_id (int) – Index of the client for which to generate the DataLoader.

Note

client_id here is not the client ID in the global FL setting; it is the index of the client within the current SerialTrainer.

Returns

DataLoader over the given client's sub-dataset.
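In PyTorch this is typically done by wrapping the dataset with torch.utils.data.Subset (or a subset sampler) and passing it to a DataLoader. The framework-free sketch below mimics that batching for one client, assuming client_id is the local index into data_slices as the note above explains; `client_batches` is a hypothetical name.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def client_batches(dataset: Sequence[T], data_slices: List[List[int]],
                   client_id: int, batch_size: int,
                   shuffle: bool = True, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical stand-in for a DataLoader over one client's subset.

    client_id is the local index of the client in this trainer, not a global FL id.
    """
    indices = list(data_slices[client_id])
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Yield consecutive batches; the last one may be smaller than batch_size.
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


dataset = ["a", "b", "c", "d", "e", "f"]
data_slices = [[0, 1, 2], [3, 4, 5]]
batches = list(client_batches(dataset, data_slices, client_id=1,
                              batch_size=2, shuffle=False))
print(batches)  # [['d', 'e'], ['f']]
```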

_train_alone(self, model_parameters, train_loader)

Single round of local training for one client.

Note

Override this method to customize the PyTorch training pipeline.
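As an illustration of what an override might contain, the sketch below runs one round of local mini-batch gradient descent for a one-feature linear model y ≈ w·x + b, with the parameters passed in and returned as a flat [w, b] list in place of a serialized tensor. The function name `train_alone_sketch`, its `epochs` and `lr` arguments, and the squared-error loss are hypothetical choices; a PyTorch override would typically load the parameters into the model, iterate over train_loader with an optimizer, and return the serialized result.

```python
from typing import Iterable, List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def train_alone_sketch(model_parameters: List[float],
                       train_loader: Iterable[Batch],
                       epochs: int = 1, lr: float = 0.1) -> List[float]:
    """Hypothetical local round: mini-batch gradient descent on mean squared error."""
    w, b = model_parameters  # "deserialize" the flat parameter vector
    batches = list(train_loader)  # allow several passes over the data
    for _ in range(epochs):
        for batch in batches:
            n = len(batch)
            # Gradients of (1/n) * sum((w*x + b - y)^2) with respect to w and b.
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
            w -= lr * grad_w
            b -= lr * grad_b
    return [w, b]  # "serialize" the updated parameters


# Data generated by y = 2x + 1, split into two mini-batches.
loader = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0), (3.0, 7.0)]]
params = train_alone_sketch([0.0, 0.0], loader, epochs=1000, lr=0.05)
print([round(p, 2) for p in params])  # [2.0, 1.0]
```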

Parameters