client#
Package Contents#
- ClientManager: Base class for ClientManager.
- ClientActiveManager: Active communication NetworkManager for client in asynchronous FL pattern.
- ClientPassiveManager: Passive communication NetworkManager for client in synchronous FL pattern.
- ScaleClientPassiveManager: Special client manager for SerialTrainer.
- SerialTrainer: Base class. Train multiple clients in sequence with a single process.
- SubsetSerialTrainer: Train multiple clients in a single process.
- ORDINARY_TRAINER = 0#
- SERIAL_TRAINER = 1#
- class ClientManager(network, trainer)#
Bases: fedlab.core.network_manager.NetworkManager
Base class for ClientManager. ClientManager defines client activation for different communication stages.
- Parameters
network (DistNetwork) – Network configuration.
trainer (ClientTrainer) – Subclass of ClientTrainer. Provides train() and model. Defines the local client training procedure.
- setup(self)#
Initialization stage. ClientManager reports the number of clients simulated by the current client process.
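The manager/trainer wiring described above can be sketched with plain-Python stand-ins. StubNetwork, StubTrainer, and MiniClientManager below are hypothetical names for illustration only, not part of fedlab; real code would use DistNetwork and a ClientTrainer subclass:

```python
class StubNetwork:
    """Stands in for DistNetwork: records connection state and sent packages."""
    def __init__(self):
        self.connected = False
        self.sent = []

    def init_network_connection(self):
        self.connected = True

    def send(self, content, dst=0):
        self.sent.append((content, dst))


class StubTrainer:
    """Stands in for ClientTrainer: owns a model and a train() procedure."""
    def __init__(self):
        self.model = {"w": 0.0}

    def train(self, payload):
        self.model["w"] += payload  # pretend one local training round


class MiniClientManager:
    """Mirrors ClientManager's role: wires a network to a trainer."""
    def __init__(self, network, trainer):
        self._network = network
        self._trainer = trainer

    def setup(self):
        # The real ClientManager also reports the number of simulated
        # clients to the server at this stage.
        self._network.init_network_connection()


manager = MiniClientManager(StubNetwork(), StubTrainer())
manager.setup()
```

The point of the split is separation of concerns: the manager owns communication, the trainer owns optimization, and subclasses only override main_loop().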
- class ClientActiveManager(network, trainer, logger=Logger())#
Bases: ClientManager
Active communication NetworkManager for client in asynchronous FL pattern.
- Parameters
network (DistNetwork) – Network configuration.
trainer (ClientTrainer) – Subclass of ClientTrainer. Provides train() and model. Defines the local client training procedure.
logger (Logger, optional) – Object of Logger.
- main_loop(self)#
Actions to perform on receiving a new message, including local training.
- Main procedure of each client:
client requests data from server (ACTIVELY).
after receiving data, client trains the local model.
client synchronizes with server actively.
- synchronize(self)#
Synchronize local model with server.
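The active pattern's request-train-synchronize cycle can be sketched with toy stand-ins (ToyServer and active_main_loop are illustrative names, not fedlab API; a scalar weight and one gradient-free update step stand in for real training):

```python
class ToyServer:
    """Hypothetical server stub: hands out a global weight, collects updates."""
    def __init__(self, global_w=1.0):
        self.global_w = global_w
        self.received = []

    def request(self):           # client pulls the package (active step 1)
        return self.global_w

    def upload(self, w):         # client pushes its update (active step 3)
        self.received.append(w)


def active_main_loop(server, rounds=3, lr=0.1):
    """Sketch of the main_loop above: request -> train -> synchronize."""
    w = None
    for _ in range(rounds):
        w = server.request()     # 1. client requests data from server
        w = w - lr * w           # 2. a pretend local training step
        server.upload(w)         # 3. client synchronizes actively
    return w


server = ToyServer()
final_w = active_main_loop(server)
```

The defining trait of the asynchronous pattern is step 1: the client initiates every exchange, so rounds need not be globally synchronized across clients.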
- class ClientPassiveManager(network, trainer, logger=Logger())#
Bases: ClientManager
Passive communication NetworkManager for client in synchronous FL pattern.
- Parameters
network (DistNetwork) – Network configuration.
trainer (ClientTrainer) – Subclass of ClientTrainer. Provides train() and model. Defines the local client training procedure.
logger (Logger, optional) – Object of Logger.
- main_loop(self)#
Actions to perform when receiving a new message, including local training.
- Main procedure of each client:
client waits for data from server (PASSIVELY).
after receiving data, client starts the local model training procedure.
client synchronizes with server actively.
- synchronize(self)#
Synchronize local model with server.
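The passive counterpart differs only in who initiates: the client blocks until the server's package arrives instead of requesting it. A toy sketch using queues as stand-ins for the network (passive_main_loop is an illustrative name, not fedlab API):

```python
from queue import Queue


def passive_main_loop(inbox, outbox, rounds=2, lr=0.1):
    """Sketch of the main_loop above:
    wait (PASSIVELY) -> train -> synchronize (actively)."""
    for _ in range(rounds):
        w = inbox.get()      # 1. client blocks until the server sends data
        w = w - lr * w       # 2. a pretend local training step
        outbox.put(w)        # 3. client pushes its update back


inbox, outbox = Queue(), Queue()
for w in (1.0, 2.0):         # server pre-loads two rounds of packages
    inbox.put(w)
passive_main_loop(inbox, outbox)
```

Because the server decides when packages go out, all clients in a round work from the same global model, which is what makes the pattern synchronous.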
- class ScaleClientPassiveManager(network, trainer, logger=Logger())#
Bases: fedlab.core.client.manager.ClientPassiveManager
Special client manager for SerialTrainer. We modify the communication agreements to create a mapping from client id to process rank. Thus, ScaleClientPassiveManager is able to simulate multiple clients in sequence.
- Parameters
network (DistNetwork) – Distributed network to use.
trainer (ClientTrainer) – Subclass of ClientTrainer, providing train() and model. To simulate more clients with a single process, you are supposed to use SerialTrainer here.
logger (Logger, optional) – Object of Logger.
- main_loop(self)#
Actions to perform when receiving a new message, including local training.
- synchronize(self)#
Synchronize local model with server actively.
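The client-id-to-process-rank mapping that ScaleClientPassiveManager relies on can be illustrated with a small helper. The even partition below is an illustrative assumption, not fedlab's actual communication agreement, and rank_to_client_ids is a hypothetical name:

```python
def rank_to_client_ids(rank, world_size, client_num):
    """Map a process rank (1..world_size-1; rank 0 is assumed to be the
    server) to the list of client ids that process simulates in sequence.
    Splits client ids as evenly as possible across client processes."""
    n_procs = world_size - 1                  # processes hosting clients
    per_proc, rest = divmod(client_num, n_procs)
    idx = rank - 1
    start = idx * per_proc + min(idx, rest)
    size = per_proc + (1 if idx < rest else 0)
    return list(range(start, start + size))


# 10 simulated clients over 3 client processes (world_size=4, rank 0 = server)
groups = [rank_to_client_ids(r, 4, 10) for r in (1, 2, 3)]
```

With such a mapping in place, one process can receive a single package from the server and replay it for each client id it owns, which is exactly what pairing this manager with a SerialTrainer achieves.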
- class SerialTrainer(model, client_num, aggregator=None, cuda=False, logger=Logger())#
Bases: fedlab.core.client.trainer.ClientTrainer
Base class. Train multiple clients in sequence with a single process.
- Parameters
model (torch.nn.Module) – Model used in this federation.
client_num (int) – Number of clients in current trainer.
aggregator (Aggregators, callable, optional) – Function to perform aggregation on a list of serialized model parameters.
cuda (bool) – Use GPUs or not. Default: False.
logger (Logger, optional) – Object of Logger.
- abstract _train_alone(self, model_parameters, train_loader)#
Train local model with model_parameters on train_loader.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters of one model.
train_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
- abstract _get_dataloader(self, client_id)#
Get DataLoader for client_id.
- train(self, model_parameters, id_list, aggregate=False)#
Train local model with different dataset according to client id in id_list.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
id_list (list[int]) – Client ids to train in this serial round.
aggregate (bool) – Whether to perform partial aggregation on this group of clients’ local models at the end of each local training round.
Note
Normally, aggregation is performed by the server, while we provide the aggregate option here to perform partial aggregation on the current client group. This partial aggregation can reduce the aggregation workload of the server.
- Returns
Serialized model parameters / list of serialized model parameters.
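The serial train-then-optionally-aggregate flow can be sketched in plain Python. A float stands in for the serialized parameter tensor, a mean stands in for the Aggregators callable, and serial_train/toy_train_alone are illustrative names, not fedlab API:

```python
def serial_train(model_parameters, id_list, train_alone, aggregate=False):
    """Sketch of SerialTrainer.train: every client in id_list trains from
    the same starting parameters; optionally partially aggregate at the end."""
    param_list = [train_alone(model_parameters, cid) for cid in id_list]
    if aggregate:
        # stand-in for the aggregator callable (e.g. a weighted average)
        return sum(param_list) / len(param_list)
    return param_list


def toy_train_alone(w, cid):
    # pretend local training: each client shifts the weights by its id
    return w + cid


updates = serial_train(10.0, [1, 2, 3], toy_train_alone)
merged = serial_train(10.0, [1, 2, 3], toy_train_alone, aggregate=True)
```

Note the return type depends on aggregate, matching the "serialized model parameters / list of model parameters" description above: a single merged parameter when partial aggregation is on, otherwise one entry per client.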
- class SubsetSerialTrainer(model, dataset, data_slices, aggregator=None, logger=Logger(), cuda=False, args=None)#
Bases: SerialTrainer
Train multiple clients in a single process.
Customize _get_dataloader() or _train_alone() for specific algorithm design in clients.
- Parameters
model (torch.nn.Module) – Model used in this federation.
dataset (torch.utils.data.Dataset) – Local dataset for this group of clients.
data_slices (list[list[int]]) – Subset indices of dataset; each element indexes one client’s local dataset.
aggregator (Aggregators, callable, optional) – Function to perform aggregation on a list of model parameters.
logger (Logger, optional) – Object of Logger.
cuda (bool) – Use GPUs or not. Default: False.
args (dict, optional) – Uncertain variables.
Note
len(data_slices) == client_num, that is, each sub-index of dataset corresponds to a client’s local dataset one-by-one.
- _get_dataloader(self, client_id)#
Return a training dataloader used in train() for the client with client_id.
- Parameters
client_id (int) – client_id of the client to generate a dataloader for.
Note
client_id here is not equal to client_id in the global FL setting. It is the index of the client in the current SerialTrainer.
- Returns
DataLoader for the specific client’s sub-dataset.
- _train_alone(self, model_parameters, train_loader)#
Single round of local training for one client.
Note
Overwrite this method to customize the PyTorch training pipeline.
- Parameters
model_parameters (torch.Tensor) – Serialized model parameters.
train_loader (torch.utils.data.DataLoader) – torch.utils.data.DataLoader for this client.
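How data_slices partitions one shared dataset across simulated clients can be shown with index lists alone. The even IID split below is only one possible partition, and make_data_slices is an illustrative helper, not part of fedlab:

```python
def make_data_slices(dataset_len, client_num):
    """Split sample indices 0..dataset_len-1 into client_num contiguous
    sub-index lists; len(result) == client_num, matching the Note above."""
    per_client, rest = divmod(dataset_len, client_num)
    slices, start = [], 0
    for i in range(client_num):
        size = per_client + (1 if i < rest else 0)
        slices.append(list(range(start, start + size)))
        start += size
    return slices


# 103 samples shared by 10 simulated clients: the first 3 get 11, the rest 10
data_slices = make_data_slices(dataset_len=103, client_num=10)
```

In real use, each such index list would typically back a per-client sampler or torch.utils.data.Subset inside _get_dataloader(), so all simulated clients read from one dataset object without copying it.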