serial_trainer
Module Contents
SerialTrainer | Base class. Simulate multiple clients in sequence in a single process.
SubsetSerialTrainer | Train multiple clients in a single process.
- class SerialTrainer(model, client_num, cuda=False, logger=None)
Bases: fedlab.core.client.trainer.ClientTrainer
Base class. Simulate multiple clients in sequence in a single process.
- Parameters
model (torch.nn.Module) – Model used in this federation.
client_num (int) – Number of clients in the current trainer.
cuda (bool) – Use GPUs or not. Default: False.
logger (Logger, optional) – Object of Logger.
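To make the base-class contract concrete, the sketch below mimics it in plain Python, without FedLab or PyTorch. `ToySerialTrainer` and `MeanShiftTrainer` are hypothetical names, and parameters are assumed to be plain lists of floats rather than `torch.Tensor`; the point is only that the two hooks are abstract and a subclass must supply both.

```python
from abc import ABC, abstractmethod


class ToySerialTrainer(ABC):
    """Hypothetical stand-in for SerialTrainer; not FedLab code."""

    def __init__(self, model_parameters, client_num):
        self.model_parameters = model_parameters  # assumed: flat list of floats
        self.client_num = client_num

    @abstractmethod
    def _get_dataloader(self, client_id):
        """Return the local training data of one client."""

    @abstractmethod
    def _train_alone(self, model_parameters, train_loader):
        """Train one client from model_parameters and return its new parameters."""


class MeanShiftTrainer(ToySerialTrainer):
    """Illustrative subclass: moves each parameter halfway toward the client's data mean."""

    def __init__(self, model_parameters, client_data):
        super().__init__(model_parameters, client_num=len(client_data))
        self.client_data = client_data  # client_data[i] holds client i's samples

    def _get_dataloader(self, client_id):
        return self.client_data[client_id]

    def _train_alone(self, model_parameters, train_loader):
        mean = sum(train_loader) / len(train_loader)
        return [p + 0.5 * (mean - p) for p in model_parameters]


try:
    ToySerialTrainer([0.0], client_num=1)  # abstract hooks are not implemented
except TypeError as err:
    print("cannot instantiate:", err)

trainer = MeanShiftTrainer([0.0, 2.0], client_data=[[1.0, 3.0], [4.0]])
print(trainer._train_alone(trainer.model_parameters, trainer._get_dataloader(0)))  # [1.0, 2.0]
```

Using `abc` means a subclass that forgets either hook fails at construction time rather than mid-simulation.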
- property uplink_package(self)
Return a tensor list for uploading to the server.
This attribute is called by the client manager. Customize it for new algorithms.
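As one hypothetical way to customize `uplink_package`, a trainer could upload each client's sample count next to its parameters so that the server can weight them. The weighting shown is a FedAvg-style assumption, not something this page specifies; the class and function names are illustrative, and lists of floats stand in for tensors.

```python
class WeightedUplinkTrainer:
    """Hypothetical trainer state after local training has run; not FedLab code."""

    def __init__(self):
        self.param_list = []   # one parameter vector per simulated client
        self.num_samples = []  # local dataset size of each simulated client

    def record(self, params, n):
        self.param_list.append(params)
        self.num_samples.append(n)

    @property
    def uplink_package(self):
        # Pair each client's parameters with its sample count for the server.
        return [[params, n] for params, n in zip(self.param_list, self.num_samples)]


def weighted_average(package):
    """Server-side sketch: average parameter vectors weighted by sample count."""
    total = sum(n for _, n in package)
    dim = len(package[0][0])
    return [sum(params[i] * n for params, n in package) / total for i in range(dim)]


trainer = WeightedUplinkTrainer()
trainer.record([1.0, 0.0], 30)
trainer.record([3.0, 4.0], 10)
print(weighted_average(trainer.uplink_package))  # [1.5, 1.0]
```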
- abstract _train_alone(self, model_parameters, train_loader)
Train the local model with model_parameters on train_loader.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters of one model.
train_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
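Here, "serialized" is taken to mean that all of a model's parameter tensors are flattened into one 1-D vector. The framework-free sketch below shows that round trip with hypothetical `serialize`/`deserialize` helpers over lists of floats; with PyTorch, `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters` perform the equivalent conversion.

```python
def serialize(layers):
    """Flatten a list of per-layer parameter lists into one flat vector."""
    return [value for layer in layers for value in layer]


def deserialize(vector, layer_sizes):
    """Split a flat vector back into per-layer lists of the given sizes."""
    if len(vector) != sum(layer_sizes):
        raise ValueError("vector length does not match the model's parameter count")
    layers, start = [], 0
    for size in layer_sizes:
        layers.append(vector[start:start + size])
        start += size
    return layers


weights, bias = [0.1, 0.2, 0.3, 0.4], [0.5, 0.6]  # e.g. a 2x2 weight matrix and a bias
flat = serialize([weights, bias])
print(flat)  # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
print(deserialize(flat, [len(weights), len(bias)]))
```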
- abstract _get_dataloader(self, client_id)
Get the DataLoader for client_id.
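A minimal stand-in for what `_get_dataloader` returns: one client's samples, shuffled and cut into mini-batches. `client_samples`, `make_batches`, and `get_dataloader` are hypothetical; a PyTorch implementation would more likely return `torch.utils.data.DataLoader(dataset, batch_size=..., shuffle=True)` over that client's data.

```python
import random


def make_batches(samples, batch_size, seed=None):
    """Shuffle a copy of samples and split it into mini-batches (the last may be smaller)."""
    order = list(samples)
    random.Random(seed).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


# Hypothetical per-client storage: client_id -> that client's samples.
client_samples = {0: list(range(5)), 1: list(range(100, 103))}


def get_dataloader(client_id, batch_size=2):
    # Seeding with client_id makes repeated calls return the same batches for a client.
    return make_batches(client_samples[client_id], batch_size, seed=client_id)


print([len(batch) for batch in get_dataloader(0)])  # [2, 2, 1]
```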
- local_process(self, id_list, payload)
Train the local model on each client's own dataset, for every client ID in id_list.
- Parameters
id_list (list) – IDs of the clients to train.
payload (list[torch.Tensor]) – Communication payload from the server.
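The loop implied by `local_process` can be sketched as a plain function. It assumes `payload[0]` holds the global model parameters and that every client starts from that same copy, which is how `_train_alone` is described above; the function and the toy training step are hypothetical.

```python
def local_process(id_list, payload, get_dataloader, train_alone):
    """Train each listed client in turn, always starting from the server's parameters."""
    global_params = payload[0]
    param_list = []
    for client_id in id_list:
        loader = get_dataloader(client_id)
        # Pass a copy so one client's update cannot leak into the next client's start point.
        param_list.append(train_alone(list(global_params), loader))
    return param_list


def toy_train_alone(params, loader):
    # Hypothetical local update: add the sum of the client's data to every parameter in place.
    delta = sum(loader)
    for i in range(len(params)):
        params[i] += delta
    return params


data = {3: [1.0], 7: [2.0, 2.0]}
payload = [[0.0, 10.0]]
updates = local_process([3, 7], payload, data.__getitem__, toy_train_alone)
print(updates)  # [[1.0, 11.0], [4.0, 14.0]]
```

Because `toy_train_alone` mutates its input, dropping the `list(...)` copy would make client 7 start from client 3's result, which is the kind of leak a sequential simulation has to avoid.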
- class SubsetSerialTrainer(model, dataset, data_slices, logger=None, cuda=False, args={'epochs': 5, 'batch_size': 100, 'lr': 0.1})
Bases: SerialTrainer
Train multiple clients in a single process.
Customize _get_dataloader() or _train_alone() for specific algorithm design in clients.
- Parameters
model (torch.nn.Module) – Model used in this federation.
dataset (torch.utils.data.Dataset) – Local dataset for this group of clients.
data_slices (list[list]) – Subsets of indices of dataset, one per client.
logger (Logger, optional) – Object of Logger.
cuda (bool) – Use GPUs or not. Default: False.
args (dict) – Hyperparameters for local training. Default: {"epochs": 5, "batch_size": 100, "lr": 0.1}.
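`data_slices` is expected to hold one list of dataset indices per client. A simple way to build it, assuming an IID split, is to shuffle all indices and deal them out; `iid_data_slices` is a hypothetical helper, not part of this module.

```python
import random


def iid_data_slices(num_samples, client_num, seed=0):
    """Shuffle dataset indices and deal them into client_num near-equal, disjoint slices."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return [indices[k::client_num] for k in range(client_num)]


data_slices = iid_data_slices(num_samples=10, client_num=3)
print([len(s) for s in data_slices])  # [4, 3, 3]
```

Every index lands in exactly one slice, so `len(data_slices)` equals the number of clients and no sample is shared between clients.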
Note
len(data_slices) == client_num, that is, each index subset in data_slices corresponds one-to-one to a client's local dataset.
- _get_dataloader(self, client_id)
Return a training dataloader used in train() for the client with index client_id.
- Parameters
client_id (int) – Index of the client for which to generate a dataloader.
Note
client_id here is not equal to client_id in the global FL setting. It is the index of the client in the current SerialTrainer.
- Returns
DataLoader for the specified client's sub-dataset.
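Because `client_id` here is a local index, a caller that tracks global client IDs has to translate them first. The sketch assumes the trainer knows which global IDs it simulates (`global_ids` and `subset_batches` are hypothetical) and batches `dataset[i]` for the indices in `data_slices[client_id]`; in PyTorch, `torch.utils.data.Subset(dataset, indices)` wrapped in a `DataLoader` plays the same role.

```python
def subset_batches(dataset, data_slices, client_id, batch_size):
    """Batches of one client's samples; client_id is the local index into data_slices."""
    samples = [dataset[i] for i in data_slices[client_id]]
    return [samples[j:j + batch_size] for j in range(0, len(samples), batch_size)]


dataset = ["a", "b", "c", "d", "e", "f"]
data_slices = [[0, 2, 4], [1, 3], [5]]
global_ids = [40, 41, 42]  # hypothetical global FL client IDs handled by this trainer

# Translate a global ID into the local index that _get_dataloader expects.
local_index = global_ids.index(41)
print(subset_batches(dataset, data_slices, local_index, batch_size=1))  # [['b'], ['d']]
```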
- _train_alone(self, model_parameters, train_loader)
Single round of local training for one client.
Note
Override this method to customize the PyTorch training pipeline.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
train_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
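A hypothetical override of `_train_alone`, written without PyTorch so it runs anywhere: mini-batch SGD on a one-feature linear model `y = w * x + b` with squared loss. It assumes the serialized parameters are `[w, b]` and reads `epochs` and `lr` from `args`; `batch_size` is consumed when the loader is built.

```python
def train_alone(model_parameters, train_loader, args):
    """Hypothetical local training: mini-batch SGD on y = w * x + b with squared loss."""
    w, b = model_parameters  # assumed serialized layout: [weight, bias]
    for _ in range(args["epochs"]):
        for batch in train_loader:
            # Gradients of the batch mean of 0.5 * (w * x + b - y) ** 2.
            grad_w = sum((w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum((w * x + b - y) for x, y in batch) / len(batch)
            w -= args["lr"] * grad_w
            b -= args["lr"] * grad_b
    return [w, b]


samples = [(x, 2.0 * x + 1.0) for x in [0.0, 0.5, 1.0, 1.5, 2.0]]
args = {"epochs": 200, "batch_size": 2, "lr": 0.1}
loader = [samples[i:i + args["batch_size"]] for i in range(0, len(samples), args["batch_size"])]
w, b = train_alone([0.0, 0.0], loader, args)
print(round(w, 2), round(b, 2))
```

On this noise-free data generated from `y = 2x + 1`, the parameters should end close to `w = 2` and `b = 1`.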