Federated Optimization#
Standard FL optimization consists of two parts: 1. local training on the client; 2. global aggregation on the server. Both the local training and the aggregation procedures are customizable in FedLab: you need to define a ClientTrainer and a ParameterServerBackendHandler. Since ClientTrainer and ParameterServerBackendHandler both need to manipulate a PyTorch model, they both inherit from ModelMaintainer.
class ModelMaintainer(object):
    """Maintain PyTorch model.

    Provide necessary attributes and operation methods.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): use GPUs or not.
    """
    def __init__(self, model, cuda) -> None:
        self.cuda = cuda

        if cuda:
            # dynamic gpu acquire.
            self.gpu = get_best_gpu()
            self._model = model.cuda(self.gpu)
        else:
            self._model = model.cpu()

    @property
    def model(self):
        """Return :class:`torch.nn.Module`."""
        return self._model

    @property
    def model_parameters(self):
        """Return serialized model parameters."""
        return SerializationTool.serialize_model(self._model)

    @property
    def shape_list(self):
        """Return shapes of model parameters."""
        shape_list = [param.shape for param in self._model.parameters()]
        return shape_list
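For example, the flat tensor returned by model_parameters can be written back into a model with SerializationTool.deserialize_model. A minimal usage sketch, assuming SerializationTool (with serialize_model and deserialize_model) is importable from fedlab.utils.serialization:

import torch
from fedlab.utils.serialization import SerializationTool

# A small model to demonstrate serialization.
model = torch.nn.Linear(10, 2)

# serialize_model flattens all parameters into a single 1-D tensor:
# 10 * 2 weights + 2 biases = 22 elements.
serialized = SerializationTool.serialize_model(model)
print(serialized.shape)  # torch.Size([22])

# deserialize_model writes a flat tensor back into the model's parameters.
SerializationTool.deserialize_model(model, serialized)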
Client local training#
The basic class of ClientTrainer is shown below. We encourage users to define the local training process following our code pattern:
class ClientTrainer(ModelMaintainer):
    """An abstract class representing a client backend handler.

    In our framework, the client backend handler manages its local model.
    It should have a function to update its model called :meth:`train`.

    If you use our framework to define the activities of a client, please make
    sure that your self-defined class subclasses it. All subclasses should
    overwrite :meth:`train`.

    Args:
        model (torch.nn.Module): Model used in this federation.
        cuda (bool): Use GPUs or not.
    """
    def __init__(self, model, cuda):
        super().__init__(model, cuda)
        self.client_num = 1  # default is 1.
        self.type = ORDINARY_TRAINER

    @abstractmethod
    def train(self):
        """Override this method to define the algorithm of training your model. This function should manipulate :attr:`self._model`."""
        raise NotImplementedError()
Overwrite ClientTrainer.train() to define the local training procedure. Typically, you need to implement a standard PyTorch training pipeline. The attributes model and model_parameters are associated with self._model. Please make sure the function train() manipulates self._model.
A standard implementation of this part is in :class:`ClientSGDTrainer`.
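For illustration, here is a minimal sketch of a self-defined trainer running local SGD. It is not FedLab's own implementation: the import paths, the extended train(model_parameters) signature (which mirrors how ClientSGDTrainer receives the latest global parameters), and the dataloader, epochs, and lr arguments are assumptions for this example.

import torch
from fedlab.core.client.trainer import ClientTrainer
from fedlab.utils.serialization import SerializationTool

class CustomSGDTrainer(ClientTrainer):
    """Example trainer: local SGD on a user-provided dataloader."""
    def __init__(self, model, dataloader, epochs, lr, cuda=False):
        super().__init__(model, cuda)
        self.dataloader = dataloader
        self.epochs = epochs
        self.optimizer = torch.optim.SGD(self._model.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()

    def train(self, model_parameters):
        # Load the latest global parameters before local training.
        SerializationTool.deserialize_model(self._model, model_parameters)
        self._model.train()
        for _ in range(self.epochs):
            for data, target in self.dataloader:
                if self.cuda:
                    data = data.cuda(self.gpu)
                    target = target.cuda(self.gpu)
                self.optimizer.zero_grad()
                loss = self.criterion(self._model(data), target)
                loss.backward()
                self.optimizer.step()
        # Return the serialized local model for the server to aggregate.
        return self.model_parameters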
Server global aggregation#
Calculation tasks related to PyTorch should be defined in the ServerHandler part. In FedLab, our basic handler class is ParameterServerBackendHandler.
class ParameterServerBackendHandler(ModelMaintainer):
    """An abstract class representing a handler of the parameter server.

    Please make sure that your self-defined server handler class subclasses this class.

    Example:
        Read the source code of :class:`SyncParameterServerHandler` and :class:`AsyncParameterServerHandler`.
    """
    def __init__(self, model, cuda=False) -> None:
        super().__init__(model, cuda)

    @abstractmethod
    def _update_model(self, model_parameters_list) -> torch.Tensor:
        """Override this function to update the global model.

        Args:
            model_parameters_list (list[torch.Tensor]): A list of serialized model parameters collected from different clients.
        """
        raise NotImplementedError()
You can define a server aggregation strategy by overwriting _update_model(model_parameters_list) to customize the aggregation procedure. _update_model(model_parameters_list) is required to manipulate the global model parameters (self._model). Typically, you can define aggregation functions like those implemented in fedlab.utils.aggregator, which are used in FedLab's standard implementations.
A standard implementation of this part is in :class:`SyncParameterServerHandler`.
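For illustration, a minimal sketch of a self-defined handler that performs plain FedAvg over the collected client parameters. The import paths and the use of Aggregators.fedavg_aggregate are assumptions for this example; the real SyncParameterServerHandler additionally manages client sampling and round bookkeeping.

from fedlab.core.server.handler import ParameterServerBackendHandler
from fedlab.utils.serialization import SerializationTool
from fedlab.utils.aggregator import Aggregators

class AverageHandler(ParameterServerBackendHandler):
    """Example handler: equal-weight FedAvg over client parameters."""
    def __init__(self, model, cuda=False):
        super().__init__(model, cuda)

    def _update_model(self, model_parameters_list):
        # Average the serialized client models element-wise (FedAvg with
        # equal weights) and write the result back into the global model.
        serialized_parameters = Aggregators.fedavg_aggregate(model_parameters_list)
        SerializationTool.deserialize_model(self._model, serialized_parameters)
        return serialized_parameters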