serial_trainer

Module Contents

  • SerialTrainer – Base class. Simulate multiple clients in sequence in a single process.

  • SubsetSerialTrainer – Train multiple clients in a single process.

class SerialTrainer(model, client_num, cuda=False, logger=None)

Bases: fedlab.core.client.trainer.ClientTrainer

Base class. Simulate multiple clients in sequence in a single process.

Parameters
  • model (torch.nn.Module) – Model used in this federation.

  • client_num (int) – Number of clients in current trainer.

  • cuda (bool) – Use GPUs or not. Default: False.

  • logger (Logger, optional) – Object of Logger.

Return a list of tensors to upload to the server.

This attribute is read by the client manager. Customize it for new algorithms.

abstract _train_alone(self, model_parameters, train_loader)

Train local model with model_parameters on train_loader.

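Parameters
  • model_parameters – Model parameters to train the local model with.

  • train_loader – DataLoader that supplies the training data.

abstract _get_dataloader(self, client_id)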

Get DataLoader for client_id.

local_process(self, id_list, payload)

Train the local model once per client in id_list, each time on that client's own dataset.

Parameters
  • id_list (list[int]) – Ids of the clients to train in this serial run.

  • payload (list[torch.Tensor]) – Communication payload from the server.
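The sequential pattern described above can be sketched without any library. ToySerialTrainer below is a hypothetical stand-in, not FedLab's SerialTrainer. The plain-list "loader", the toy update rule, reading the model parameters from payload[0], and returning the per-client results as a list are all assumptions made for illustration.

```python
class ToySerialTrainer:
    """Stand-in for illustration only; this is not FedLab's SerialTrainer."""

    def __init__(self, client_data):
        # client_data[client_id] is that client's local dataset (a list of numbers).
        self.client_data = client_data

    def _get_dataloader(self, client_id):
        # A real trainer would return a DataLoader; here the list is the "loader".
        return self.client_data[client_id]

    def _train_alone(self, model_parameters, train_loader):
        # Toy "training": move every parameter halfway toward the local data mean.
        target = sum(train_loader) / len(train_loader)
        return [p + 0.5 * (target - p) for p in model_parameters]

    def local_process(self, id_list, payload):
        # Assumption: the first payload entry carries the model parameters.
        model_parameters = payload[0]
        results = []
        for client_id in id_list:  # clients run one after another, not in parallel
            loader = self._get_dataloader(client_id)
            # Every client starts from the same server-side parameters.
            results.append(self._train_alone(model_parameters, loader))
        return results


trainer = ToySerialTrainer({0: [1.0, 3.0], 1: [10.0]})
updates = trainer.local_process(id_list=[0, 1], payload=[[0.0]])
```

Only local_process knows about the loop over clients. The two underscore methods see a single client at a time, which is why a subclass can customize them without touching the loop.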

class SubsetSerialTrainer(model, dataset, data_slices, logger=None, cuda=False, args={'epochs': 5, 'batch_size': 100, 'lr': 0.1})

Bases: SerialTrainer

Train multiple clients in a single process.

Override _get_dataloader() or _train_alone() to implement algorithm-specific client behavior.

Parameters
  • model (torch.nn.Module) – Model used in this federation.

  • dataset (torch.utils.data.Dataset) – Local dataset for this group of clients.

  • data_slices (list[list]) – One list of dataset indices per client.

  • logger (Logger, optional) – Object of Logger.

  • cuda (bool) – Use GPUs or not. Default: False.

  • args (dict) – Training settings that vary by algorithm. Default: {"epochs": 5, "batch_size": 100, "lr": 0.1}

Note

len(data_slices) == client_num; that is, the i-th entry of data_slices holds the dataset indices that make up the i-th client’s local dataset.
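To illustrate the relationship between data_slices and client_num, here is a minimal sketch. The helper make_data_slices is hypothetical (it is not part of FedLab), and the even, contiguous split is only one assumed partitioning scheme. Any partition with exactly one index list per client satisfies the note.

```python
def make_data_slices(num_samples, client_num):
    """Split sample indices 0..num_samples-1 into client_num contiguous slices.

    Hypothetical helper for illustration only.
    """
    base, extra = divmod(num_samples, client_num)
    slices, start = [], 0
    for cid in range(client_num):
        # The first `extra` clients take one additional sample each.
        size = base + (1 if cid < extra else 0)
        slices.append(list(range(start, start + size)))
        start += size
    return slices


data_slices = make_data_slices(num_samples=10, client_num=3)
# data_slices[client_id] holds the dataset indices owned by that client.
```

A list built this way has one entry per client, so a client's index in the trainer can be used directly to look up its slice.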

_get_dataloader(self, client_id)

Return the training DataLoader used in train() for the client with id client_id.

Parameters

client_id (int) – Id of the client whose DataLoader is generated.

Note

client_id here is not the client's id in the global FL setting. It is the client's index within the current SerialTrainer.

Returns

DataLoader for specific client’s sub-dataset

_train_alone(self, model_parameters, train_loader)

Single round of local training for one client.

Note

Override this method to customize the PyTorch training pipeline.
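As a rough picture of what one round of local training involves, the sketch below runs mini-batch gradient descent in plain Python. It stands in for a PyTorch loop and is not FedLab's implementation. The function train_alone, the one-parameter linear model, and the squared-error loss are assumptions for illustration. Only the keys of args mirror the documented default dict.

```python
def train_alone(weight, samples, args):
    """One round of local training for a 1-D linear model y = weight * x.

    Pure-Python stand-in for a PyTorch loop; `samples` plays the role of the
    data loader's dataset and `args` uses the keys epochs, batch_size and lr.
    """
    batch_size, lr = args["batch_size"], args["lr"]
    for _ in range(args["epochs"]):
        for start in range(0, len(samples), batch_size):
            batch = samples[start:start + batch_size]
            # Gradient of the mean squared error (weight * x - y)**2 w.r.t. weight.
            grad = sum(2 * (weight * x - y) * x for x, y in batch) / len(batch)
            weight -= lr * grad
    return weight


samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated by y = 2 * x
trained = train_alone(0.0, samples, {"epochs": 5, "batch_size": 100, "lr": 0.1})
```

The function takes the starting parameters and returns the updated ones without touching shared state, so it can be called once per client from a sequential loop.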

Parameters