trainer
Module Contents
SerialTrainer – Base class. Train multiple clients in sequence with a single process.
SubsetSerialTrainer – Train multiple clients in a single process.
- class SerialTrainer(model, client_num, aggregator=None, cuda=False, logger=Logger())
Bases: fedlab.core.client.trainer.ClientTrainer
Base class. Train multiple clients in sequence with a single process.
- Parameters
model (torch.nn.Module) – Model used in this federation.
client_num (int) – Number of clients in the current trainer.
aggregator (Aggregators, callable, optional) – Function that aggregates a list of serialized model parameters into one. Default: None.
cuda (bool) – Whether to use GPUs. Default: False.
logger (Logger, optional) – Logger object. Default: Logger().
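The aggregator only has to be a callable that turns a list of serialized parameter vectors into a single vector. The sketch below shows one possible choice, an element-wise uniform or weighted average in the style of FedAvg. It is not FedLab's Aggregators implementation: average_aggregator is a hypothetical name, and plain Python lists of floats stand in for the flat torch.Tensor objects the trainer actually works with.

```python
from typing import List, Optional, Sequence


def average_aggregator(
    params_list: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Element-wise (weighted) average of flat parameter vectors.

    Hypothetical helper: lists of floats stand in for serialized torch.Tensor.
    """
    if not params_list:
        raise ValueError("params_list must not be empty")
    if weights is None:
        weights = [1.0] * len(params_list)  # uniform average
    if len(weights) != len(params_list):
        raise ValueError("one weight per parameter vector is required")
    dim = len(params_list[0])
    if any(len(p) != dim for p in params_list):
        raise ValueError("all parameter vectors must have the same length")
    total = float(sum(weights))
    # Weighted sum of each coordinate, divided by the total weight.
    return [
        sum(w * p[i] for w, p in zip(weights, params_list)) / total
        for i in range(dim)
    ]


# Usage on toy vectors; a real aggregator would receive torch.Tensor objects.
client_params = [[1.0, 2.0], [3.0, 4.0]]
print(average_aggregator(client_params))                  # [2.0, 3.0]
print(average_aggregator(client_params, weights=[3, 1]))  # [1.5, 2.5]
```

Weights are typically proportional to each client's number of samples, so clients with more data pull the average further toward their local model.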
- abstract _train_alone(self, model_parameters, train_loader)
Train the local model, starting from model_parameters, on train_loader.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters of one model.
train_loader (torch.utils.data.DataLoader) – DataLoader for this client.
- abstract _get_dataloader(self, client_id)
Get the DataLoader for client_id.
- train(self, model_parameters, id_list, aggregate=False)
Train the local model on each client's dataset, one client after another, for every client ID in id_list.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
id_list (list[int]) – IDs of the clients to train.
aggregate (bool) – Whether to perform partial aggregation on this group of clients’ local models at the end of each local training round. Default: False.
Note
Aggregation is normally performed by the server. The aggregate option lets this trainer perform partial aggregation on the current client group instead, which reduces the server's aggregation workload.
- Returns
Serialized model parameters when partial aggregation is performed; otherwise a list of model parameters, one per trained client.
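The interplay between train(), _get_dataloader() and _train_alone() can be sketched without PyTorch or FedLab. In the sketch below, SerialTrainerSketch, MeanShiftTrainer and mean_of are hypothetical names, plain lists of floats stand in for serialized tensors, and _train_alone returns the updated parameters. That return convention is an assumption made for brevity; the documented signature does not specify a return value.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Sequence

Params = List[float]  # stand-in for a serialized (flat) torch.Tensor


class SerialTrainerSketch(ABC):
    """Hypothetical stand-in mirroring the documented SerialTrainer flow."""

    def __init__(self, client_num: int,
                 aggregator: Optional[Callable[[List[Params]], Params]] = None):
        self.client_num = client_num
        self.aggregator = aggregator

    @abstractmethod
    def _get_dataloader(self, client_id: int):
        """Return the training data for one client."""

    @abstractmethod
    def _train_alone(self, model_parameters: Params, train_loader) -> Params:
        """Train from model_parameters on train_loader; return new params."""

    def train(self, model_parameters: Params, id_list: Sequence[int],
              aggregate: bool = False):
        param_list = []
        for client_id in id_list:  # clients are trained one after another
            loader = self._get_dataloader(client_id)
            # Pass a copy so every client starts from the same global parameters.
            param_list.append(self._train_alone(list(model_parameters), loader))
        if aggregate and self.aggregator is not None:
            return self.aggregator(param_list)  # one partially aggregated vector
        return param_list  # one vector per client


class MeanShiftTrainer(SerialTrainerSketch):
    """Toy client update: move each parameter halfway toward the client's data mean."""

    def __init__(self, client_data: List[List[float]], aggregator=None):
        super().__init__(len(client_data), aggregator)
        self.client_data = client_data

    def _get_dataloader(self, client_id):
        return self.client_data[client_id]

    def _train_alone(self, model_parameters, train_loader):
        target = sum(train_loader) / len(train_loader)
        return [p + 0.5 * (target - p) for p in model_parameters]


def mean_of(vectors: List[Params]) -> Params:
    """Uniform element-wise average, used here as the aggregator."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


trainer = MeanShiftTrainer([[2.0, 2.0], [4.0, 8.0]], aggregator=mean_of)
print(trainer.train([0.0], id_list=[0, 1]))                  # [[1.0], [3.0]]
print(trainer.train([0.0], id_list=[0, 1], aggregate=True))  # [2.0]
```

With aggregate=True the server receives one vector for the whole group instead of len(id_list) vectors, which is the workload reduction the Note describes.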
- class SubsetSerialTrainer(model, dataset, data_slices, aggregator=None, logger=Logger(), cuda=False, args=None)
Bases: SerialTrainer
Train multiple clients in a single process.
Customize _get_dataloader() or _train_alone() to implement a specific algorithm on the client side.
- Parameters
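model (torch.nn.Module) – Model used in this federation.
dataset (torch.utils.data.Dataset) – Local dataset for this group of clients.
data_slices (list[list[int]]) – Index lists into dataset, one per client; data_slices[i] selects the samples of the i-th client.
aggregator (Aggregators, callable, optional) – Function to perform aggregation on a list of model parameters. Default: None.
logger (Logger, optional) – Logger object. Default: Logger().
cuda (bool) – Whether to use GPUs. Default: False.
args (dict, optional) – Miscellaneous variables with no fixed schema. Default: None.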
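data_slices is a list of index lists into dataset, one per client. A common way to build it is a random IID split of the sample indices. The helper below is a hypothetical sketch of such a split, not a FedLab utility; num_samples, client_num and seed are its own illustrative parameters.

```python
import random
from typing import List


def iid_slices(num_samples: int, client_num: int, seed: int = 0) -> List[List[int]]:
    """Randomly split range(num_samples) into client_num disjoint index lists.

    Shard sizes differ by at most one sample.
    """
    if client_num <= 0:
        raise ValueError("client_num must be positive")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    # Slice i takes every client_num-th shuffled index, starting at position i.
    return [sorted(indices[i::client_num]) for i in range(client_num)]


data_slices = iid_slices(num_samples=10, client_num=3)
print([len(s) for s in data_slices])  # [4, 3, 3]
```

Each inner list becomes one client's local dataset, so the number of slices fixes the number of clients the trainer simulates.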
Note
len(data_slices) == client_num, that is, each index list in data_slices selects exactly one client's local subset of dataset.
- _get_dataloader(self, client_id)
Return the training DataLoader used in train() for the client with index client_id.
- Parameters
client_id (int) – Index of the client whose DataLoader is generated.
Note
client_id here is not the client ID in the global FL setting. It is the index of the client within the current SerialTrainer.
- Returns
DataLoader for the specific client's sub-dataset.
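Conceptually, _get_dataloader(client_id) restricts dataset to data_slices[client_id] and batches it, where client_id is the trainer-local index described in the Note. A PyTorch version could, for example, wrap torch.utils.data.Subset(dataset, data_slices[client_id]) in a DataLoader. The stdlib generator below only sketches that behaviour; client_batches, batch_size, shuffle and seed are hypothetical names.

```python
import random
from typing import Iterator, List, Sequence


def client_batches(dataset: Sequence, data_slices: List[List[int]],
                   client_id: int, batch_size: int,
                   shuffle: bool = True, seed: int = 0) -> Iterator[list]:
    """Yield mini-batches drawn only from the samples of one client.

    client_id is the trainer-local index (0 .. len(data_slices) - 1),
    not a global federated-learning client ID.
    """
    indices = list(data_slices[client_id])
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


dataset = ["a", "b", "c", "d", "e", "f"]
data_slices = [[0, 2, 4], [1, 3, 5]]
print(list(client_batches(dataset, data_slices, client_id=1,
                          batch_size=2, shuffle=False)))
# [['b', 'd'], ['f']]
```

The last batch may be smaller than batch_size, matching a DataLoader with drop_last=False.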
- _train_alone(self, model_parameters, train_loader)
Single round of local training for one client.
Note
Override this method to customize the PyTorch training pipeline.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
train_loader (torch.utils.data.DataLoader) – DataLoader for this client.
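An overridden _train_alone usually loads model_parameters into the model, runs a few epochs of optimization over train_loader, and leaves the updated parameters for train() to collect. The stdlib sketch below mimics that pipeline for a one-feature linear model y ≈ w * x + b trained by mini-batch gradient descent on squared error. The epochs and lr arguments, and returning the new parameters, are assumptions for illustration and not part of the documented signature.

```python
from typing import Iterable, List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def train_alone(model_parameters: Sequence[float], train_loader: Iterable[Batch],
                epochs: int = 1, lr: float = 0.1) -> List[float]:
    """Local training for one client on a linear model y ~ w * x + b.

    model_parameters is the flat vector [w, b] (stand-in for a serialized tensor).
    """
    w, b = model_parameters  # "deserialize" the incoming global parameters
    batches = list(train_loader)  # materialize so it can be reused each epoch
    for _ in range(epochs):
        for batch in batches:
            n = len(batch)
            # Gradients of the mean squared error over the batch.
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
            w -= lr * grad_w
            b -= lr * grad_b
    return [w, b]  # "serialize" the locally trained parameters


# Client data generated by y = 2x + 1, split into two mini-batches.
loader = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0), (3.0, 7.0)]]
w, b = train_alone([0.0, 0.0], loader, epochs=500, lr=0.05)
print(round(w, 2), round(b, 2))  # 2.0 1.0
```

In a PyTorch override the same steps map onto loading the parameters into the model, a loss function, an optimizer step per batch, and serializing the model at the end.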