Federated Optimization
Standard FL optimization consists of two parts: (1) local training on the client and (2) global aggregation on the server. Both the local training and the aggregation procedure are customizable in FedLab: you define a ClientTrainer and a ParameterServerBackendHandler.
Because ClientTrainer and ParameterServerBackendHandler both need to manipulate a PyTorch model, they both inherit from ModelMaintainer.
class ModelMaintainer(object):
    """Maintain PyTorch model.

    Provide necessary attributes and operation methods.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): use GPUs or not.
    """

    def __init__(self, model, cuda) -> None:
        self.cuda = cuda
        if cuda:
            # Dynamically pick a GPU.
            self.gpu = get_best_gpu()
            self._model = model.cuda(self.gpu)
        else:
            self._model = model.cpu()

    @property
    def model(self):
        """Return torch.nn.Module"""
        return self._model

    @property
    def model_parameters(self):
        """Return serialized model parameters."""
        return SerializationTool.serialize_model(self._model)

    @property
    def shape_list(self):
        """Return shape of parameters"""
        shape_list = [param.shape for param in self._model.parameters()]
        return shape_list
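To make the roles of model_parameters and shape_list more concrete, the following is a minimal sketch of flattening a model into a single 1-D tensor and writing such a tensor back. It assumes that "serialized" means concatenating every parameter into one vector. The helpers flatten_parameters and load_flat_parameters are hypothetical names used for illustration only; they are not FedLab's SerializationTool API.

```python
import torch
from torch import nn


def flatten_parameters(model: nn.Module) -> torch.Tensor:
    """Concatenate all parameters into one 1-D tensor (hypothetical helper)."""
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])


def load_flat_parameters(model: nn.Module, flat: torch.Tensor) -> None:
    """Copy consecutive slices of `flat` back into the parameters (hypothetical helper)."""
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n


model = nn.Linear(3, 2)                  # weight is 2x3, bias has 2 entries
flat = flatten_parameters(model)         # one vector holding all 8 values
shapes = [p.shape for p in model.parameters()]

# Overwrite the model with an all-zero parameter vector of the same length.
load_flat_parameters(model, torch.zeros_like(flat))
```

A flat vector like this is convenient to send over the network, and the shape list is what allows the receiver to cut it back into per-layer tensors.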
Client local training
The base class ClientTrainer is shown below. We encourage users to define their local training process following this code pattern:
class ClientTrainer(ModelMaintainer):
    """An abstract class representing a client backend trainer.

    In our framework, the backend of a client trainer should manage its local model.
    It should have a function to update its model called :meth:`local_process`.

    If you use our framework to define the activities of a client, please make sure that your
    self-defined class subclasses it. All subclasses should overwrite :meth:`local_process`.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
    """

    def __init__(self, model, cuda):
        super().__init__(model, cuda)
        self.client_num = 1  # default is 1.
        self.type = ORDINARY_TRAINER

    @property
    def uplink_package(self):
        """Return a tensor list for uploading to server.

        This attribute will be called by client manager.
        Customize it for new algorithms.
        """
        return [self.model_parameters]

    def local_process(self, payload):
        """Manager of the upper layer will call this function with the accepted payload."""
        raise NotImplementedError()

    def train(self):
        """Override this method to define the algorithm of training your model. This function should manipulate :attr:`self._model`"""
        raise NotImplementedError()

    def evaluate(self):
        """Evaluate quality of local model."""
        raise NotImplementedError()
Overwrite ClientTrainer.local_process() to define the local procedure. Typically, you need to implement a standard PyTorch training pipeline.
The attributes model and model_parameters are both associated with self._model. Please make sure that local_process() manipulates self._model.
A standard implementation of this part is in SGDClientTrainer.
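As a rough illustration of the pattern described above, the sketch below receives global parameters in a payload, loads them into self._model, runs a few epochs of SGD, and exposes the result through uplink_package. SketchTrainer is a hypothetical, CPU-only stand-in written so the example runs without FedLab; it is not SGDClientTrainer. The assumption that payload[0] holds a flat parameter vector, as well as the set_model helper and the data, epochs and lr arguments, are made up for this example.

```python
import torch
from torch import nn


class SketchTrainer:
    """Hypothetical stand-in that mirrors the ClientTrainer pattern (CPU only)."""

    def __init__(self, model, data, epochs=5, lr=0.1):
        self._model = model
        self.data = data          # list of (inputs, targets) batches
        self.epochs = epochs
        self.lr = lr

    @property
    def model_parameters(self):
        # Flatten all parameters into one detached 1-D tensor.
        return torch.nn.utils.parameters_to_vector(self._model.parameters()).detach()

    @property
    def uplink_package(self):
        return [self.model_parameters]

    def set_model(self, flat):
        # Load a copy of the flat vector so the caller's tensor is never modified.
        torch.nn.utils.vector_to_parameters(flat.clone(), self._model.parameters())

    def local_process(self, payload):
        # Assumption: payload[0] holds the global model as a flat vector.
        self.set_model(payload[0])
        self.train()

    def train(self):
        optimizer = torch.optim.SGD(self._model.parameters(), lr=self.lr)
        loss_fn = nn.MSELoss()
        self._model.train()
        for _ in range(self.epochs):
            for inputs, targets in self.data:
                optimizer.zero_grad()
                loss = loss_fn(self._model(inputs), targets)
                loss.backward()
                optimizer.step()


torch.manual_seed(0)
x = torch.randn(16, 2)
y = x.sum(dim=1, keepdim=True)           # toy regression target
trainer = SketchTrainer(nn.Linear(2, 1), [(x, y)])

global_params = torch.zeros(3)           # 2 weights + 1 bias
trainer.local_process([global_params])
local_params = trainer.uplink_package[0]
```

The point of the split is that the manager layer only ever calls local_process() and reads uplink_package, so a new algorithm can change the training loop or the uploaded tensors without touching the communication code.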
Server global aggregation
Computation tasks related to PyTorch should be defined in the server handler. In FedLab, the base class for handlers is ParameterServerBackendHandler.
class ParameterServerBackendHandler(ModelMaintainer):
    """An abstract class representing handler of parameter server.

    Please make sure that your self-defined server handler class subclasses this class.

    Example:
        Read source code of :class:`SyncParameterServerHandler` and :class:`AsyncParameterServerHandler`.
    """

    def __init__(self, model, cuda=False):
        super().__init__(model, cuda)

    @property
    def downlink_package(self):
        """Property for manager layer. Server manager will call this property when it activates clients."""
        return [self.model_parameters]

    @property
    def if_stop(self):
        """:class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` is returned."""
        return False

    def _update_global_model(self, *args, **kwargs):
        """Override this function for iterating global model (aggregation or optimization)."""
        raise NotImplementedError()
Users define the server aggregation strategy by implementing the following function:
Overwrite _update_global_model() to customize the global procedure. _update_global_model() is required to manipulate the global model parameters (self._model).
Common FL aggregation strategies are implemented in fedlab.utils.aggregator.
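For intuition about what an aggregation strategy computes, here is a minimal sketch of a weighted average over flat parameter vectors, in the spirit of FedAvg. The function weighted_average is a hypothetical helper written for this example and does not reproduce the interface of fedlab.utils.aggregator. It assumes every client uploads a 1-D tensor of the same length.

```python
import torch


def weighted_average(param_list, weights=None):
    """Average equally sized flat parameter tensors (hypothetical helper)."""
    if weights is None:
        weights = torch.ones(len(param_list))
    else:
        weights = torch.as_tensor(weights, dtype=torch.float32)
    weights = weights / weights.sum()              # normalize so weights sum to 1
    stacked = torch.stack(param_list, dim=0)       # shape: (num_clients, num_params)
    return (weights.unsqueeze(1) * stacked).sum(dim=0)


client_a = torch.tensor([1.0, 2.0])
client_b = torch.tensor([3.0, 6.0])

# Weight each client by its number of local samples, here 1 and 3.
global_params = weighted_average([client_a, client_b], weights=[1, 3])
# 0.25 * [1, 2] + 0.75 * [3, 6] = [2.5, 5.0]
```

Weighting by sample count lets clients with more data pull the global model further; passing no weights falls back to a plain mean.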
A standard implementation of this part is in SyncParameterServerHandler.
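The handler pattern can be sketched in the same way. The example below buffers client uploads, calls _update_global_model() once a full round has arrived, and reports through if_stop when the configured number of rounds is reached. SketchHandler is a hypothetical, CPU-only stand-in so the example runs without FedLab; it is not SyncParameterServerHandler. The load method and the global_round and clients_per_round arguments are assumptions made for this example.

```python
import torch
from torch import nn


class SketchHandler:
    """Hypothetical stand-in that mirrors the server handler pattern (CPU only)."""

    def __init__(self, model, global_round, clients_per_round):
        self._model = model
        self.global_round = global_round
        self.clients_per_round = clients_per_round
        self.round = 0
        self.buffer = []          # flat parameter tensors received this round

    @property
    def model_parameters(self):
        return torch.nn.utils.parameters_to_vector(self._model.parameters()).detach()

    @property
    def downlink_package(self):
        return [self.model_parameters]

    @property
    def if_stop(self):
        return self.round >= self.global_round

    def load(self, payload):
        """Buffer one upload; aggregate when the round is complete (assumed method name)."""
        self.buffer.append(payload[0])
        if len(self.buffer) == self.clients_per_round:
            self._update_global_model(self.buffer)
            self.buffer = []
            self.round += 1
            return True
        return False

    def _update_global_model(self, param_list):
        # Plain unweighted mean, written back into self._model.
        mean = torch.stack(param_list).mean(dim=0)
        torch.nn.utils.vector_to_parameters(mean, self._model.parameters())


handler = SketchHandler(nn.Linear(2, 1), global_round=1, clients_per_round=2)
first = handler.load([torch.tensor([1.0, 1.0, 1.0])])    # round not complete yet
second = handler.load([torch.tensor([3.0, 3.0, 3.0])])   # triggers aggregation
```

After the second upload the global model holds the mean of the two vectors, and if_stop becomes True because the single configured round has finished.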