fedprox#

Module Contents#

FedProxServerHandler

FedProx server handler.

FedProxClientTrainer

Federated client using a local SGD solver with a proximal term.

FedProxSerialClientTrainer

Train multiple clients in a single process.

class FedProxServerHandler(model: torch.nn.Module, global_round: int, num_clients: int = 0, sample_ratio: float = 1, cuda: bool = False, device: str = None, sampler: fedlab.contrib.client_sampler.base_sampler.FedSampler = None, logger: fedlab.utils.Logger = None)#

Bases: fedlab.contrib.algorithm.basic_server.SyncServerHandler

FedProx server handler.
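A minimal construction sketch for the handler, assuming the module path fedlab.contrib.algorithm.fedprox and an illustrative torch.nn model; the round, client, and sampling values are placeholders:

    import torch.nn as nn
    from fedlab.contrib.algorithm.fedprox import FedProxServerHandler

    # Illustrative model; any torch.nn.Module used in the federation works here.
    model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

    # Synchronous FedProx server: 50 communication rounds,
    # sampling 10% of 100 clients in each round.
    handler = FedProxServerHandler(
        model=model,
        global_round=50,
        num_clients=100,
        sample_ratio=0.1,
        cuda=False,
    )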

class FedProxClientTrainer(model: torch.nn.Module, cuda: bool = False, device: str = None, logger: fedlab.utils.Logger = None)#

Bases: fedlab.contrib.algorithm.basic_client.SGDClientTrainer

Federated client using a local SGD solver with a proximal term.

setup_optim(epochs, batch_size, lr, mu)#

Set up local optimization configuration.

Parameters:
  • epochs (int) – Local epochs.

  • batch_size (int) – Local batch size.

  • lr (float) – Learning rate.

  • mu (float) – Coefficient of the proximal term.
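A usage sketch for the single-process client, assuming the same module path and the model object from the handler example above; the hyperparameter values are placeholders:

    from fedlab.contrib.algorithm.fedprox import FedProxClientTrainer

    trainer = FedProxClientTrainer(model=model, cuda=False)

    # Local SGD configuration: 5 local epochs, batch size 64,
    # learning rate 0.01, proximal coefficient mu = 0.1.
    trainer.setup_optim(epochs=5, batch_size=64, lr=0.01, mu=0.1)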

local_process(payload, id)#

The manager of the upper layer will call this function with the accepted payload.

In synchronous mode, return True to end the current FL round.

train(model_parameters, train_loader, mu) → None#

Client trains its local model on its local dataset.

Parameters:
  • model_parameters (torch.Tensor) – Serialized model parameters.

  • train_loader (torch.utils.data.DataLoader) – DataLoader of the local training dataset.

  • mu (float) – Coefficient of the proximal term.
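FedProx replaces the plain local objective F_k(w) with F_k(w) + (mu / 2) * ||w - w_t||^2, where w_t is the received global model. The following is a minimal sketch of one such proximal SGD step, not the exact body of this method; the model, criterion, optimizer, and the list of frozen global parameters are assumed to be prepared by the caller:

    import torch

    def proximal_sgd_step(model, global_params, batch, criterion, optimizer, mu):
        """One local step on `batch`: task loss plus (mu / 2) * ||w - w_t||^2."""
        data, target = batch
        optimizer.zero_grad()
        loss = criterion(model(data), target)

        # The proximal term penalizes drift from the received global weights.
        # `global_params` holds detached clones of the global model's parameters.
        proximal = 0.0
        for w, w_t in zip(model.parameters(), global_params):
            proximal = proximal + torch.sum((w - w_t) ** 2)
        loss = loss + (mu / 2) * proximal

        loss.backward()
        optimizer.step()
        return loss.item()

Here global_params would be taken as [p.detach().clone() for p in model.parameters()] right after loading model_parameters, so the penalty is measured against the received global weights rather than the moving local ones.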

class FedProxSerialClientTrainer(model, num_clients, cuda=False, device=None, logger=None, personal=False)#

Bases: fedlab.contrib.algorithm.basic_client.SGDSerialClientTrainer

Train multiple clients in a single process.

Customize _get_dataloader() or _train_alone() for specific algorithm design in clients.

Parameters:
  • model (torch.nn.Module) – Model used in this federation.

  • num_clients (int) – Number of clients in current trainer.

  • cuda (bool) – Use GPUs or not. Default: False.

  • device (str, optional) – Assign model/data to the given GPUs. E.g., 'cuda:0' or 'cuda:0,1'. Defaults to None.

  • logger (Logger, optional) – Object of Logger.

  • personal (bool, optional) – If True is passed, SerialModelMaintainer will generate a copy of the local parameters list for each client and maintain them respectively. These parameters are indexed by [0, num_clients-1]. Defaults to False.
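A construction sketch for the serial trainer, assuming the module path fedlab.contrib.algorithm.fedprox and reusing the model object from the handler example; the dataset call is commented out because the partitioned dataset object is not defined here:

    from fedlab.contrib.algorithm.fedprox import FedProxSerialClientTrainer

    # One process simulates 100 clients that share a single model object.
    serial_trainer = FedProxSerialClientTrainer(model=model, num_clients=100, cuda=False)
    serial_trainer.setup_optim(epochs=5, batch_size=64, lr=0.01, mu=0.1)

    # serial_trainer.setup_dataset(fed_dataset)  # attach a partitioned dataset (assumed fedlab helper)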

setup_optim(epochs, batch_size, lr, mu)#

Set up local optimization configuration.

Parameters:
  • epochs (int) – Local epochs.

  • batch_size (int) – Local batch size.

  • lr (float) – Learning rate.

  • mu (float) – Coefficient of the proximal term.

local_process(payload, id_list)#

Define the local main process.

train(model_parameters, train_loader, mu) → None#

Client trains its local model on its local dataset.

Parameters:
  • model_parameters (torch.Tensor) – Serialized model parameters.

  • train_loader (torch.utils.data.DataLoader) – DataLoader of the local training dataset.

  • mu (float) – Coefficient of the proximal term.
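Putting the handler and the serial trainer together, one synchronous FedProx experiment can be driven manually as sketched below; this assumes the standard fedlab standalone-pipeline interfaces (if_stop, sample_clients, downlink_package, uplink_package, load) rather than anything specific to this module:

    # Run synchronous FedProx rounds until the handler signals completion.
    while not handler.if_stop:
        sampled_clients = handler.sample_clients()   # ids of this round's participants
        broadcast = handler.downlink_package         # serialized global model (and metadata)
        serial_trainer.local_process(broadcast, sampled_clients)
        for pack in serial_trainer.uplink_package:   # one package per trained client
            handler.load(pack)                       # handler aggregates once all arrive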