Federated Optimization
Standard FL optimization consists of two parts: (1) local training on the clients and (2) global aggregation on the server. Both the local training and the aggregation procedures are customizable in FedLab: you define them by implementing ClientTrainer and ParameterServerBackendHandler.
Because ClientTrainer and ParameterServerBackendHandler both need to manipulate a PyTorch model, they both inherit from ModelMaintainer.
class ModelMaintainer(object):
    """Maintain PyTorch model.

    Provide necessary attributes and operation methods.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
    """

    def __init__(self, model, cuda) -> None:
        self.cuda = cuda
        if cuda:
            # Acquire a GPU dynamically.
            self.gpu = get_best_gpu()
            self._model = model.cuda(self.gpu)
        else:
            self._model = model.cpu()

    @property
    def model(self):
        """Return the underlying torch.nn.Module."""
        return self._model

    @property
    def model_parameters(self):
        """Return serialized model parameters."""
        return SerializationTool.serialize_model(self._model)

    @property
    def shape_list(self):
        """Return the shape of each parameter."""
        shape_list = [param.shape for param in self._model.parameters()]
        return shape_list
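To make the roles of model_parameters and shape_list more concrete, the sketch below mimics the idea on plain Python lists instead of tensors, so it runs without PyTorch or FedLab. It assumes that serialization means concatenating every flattened parameter into a single vector. The helpers flatten, shape_of, unflatten, serialize and deserialize are hypothetical names written for this illustration; they are not FedLab functions.

```python
from math import prod


def flatten(nested):
    """Recursively flatten a nested list of numbers into one flat list."""
    if not isinstance(nested, list):
        return [nested]
    flat = []
    for item in nested:
        flat.extend(flatten(item))
    return flat


def shape_of(nested):
    """Return the shape of a regular (non-ragged, non-empty) nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def unflatten(flat, shape):
    """Rebuild a nested list of the given shape from a flat list."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def serialize(params):
    """Concatenate all parameters into one flat vector."""
    return [value for p in params for value in flatten(p)]


def deserialize(vector, shape_list):
    """Split a flat vector back into parameters using their shapes."""
    params, start = [], 0
    for shape in shape_list:
        size = prod(shape)
        params.append(unflatten(vector[start:start + size], shape))
        start += size
    return params


# Toy "model": a 2x3 weight matrix and a bias vector of length 2.
params = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [0.5, -0.5]]
shape_list = [shape_of(p) for p in params]   # analogous to shape_list
vector = serialize(params)                   # analogous to model_parameters
restored = deserialize(vector, shape_list)
```

The flat vector is convenient to send over the network and to average on the server, while the shape list carries the information needed to map that vector back onto individual parameters.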
Client local training
The base class ClientTrainer is shown below. We encourage users to define their local training process following this code pattern:
class ClientTrainer(ModelMaintainer):
    """An abstract class representing a client backend trainer.

    In our framework, the backend of a client trainer should manage its local model.
    It should have a function to update its model called :meth:`local_process`.

    If you use our framework to define the activities of a client, please make sure that
    your self-defined class subclasses it. All subclasses should override :meth:`local_process`.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
    """

    def __init__(self, model, cuda):
        super().__init__(model, cuda)
        self.client_num = 1  # default is 1.
        self.type = ORDINARY_TRAINER

    @property
    def uplink_package(self):
        """Return a tensor list for uploading to server.

        This attribute will be called by client manager.
        Customize it for new algorithms.
        """
        return [self.model_parameters]

    def local_process(self, payload):
        """Manager of the upper layer will call this function with the accepted payload."""
        raise NotImplementedError()

    def train(self):
        """Override this method to define the algorithm for training your model. This function should manipulate :attr:`self._model`."""
        raise NotImplementedError()

    def evaluate(self):
        """Evaluate quality of local model."""
        raise NotImplementedError()
- Override ClientTrainer.local_process() to define the local procedure. Typically, you need to implement a standard PyTorch training pipeline.
- The attributes model and model_parameters are associated with self._model. Please make sure that local_process() manipulates self._model.
A standard implementation of this part is in SGDClientTrainer.
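The call pattern described above can be sketched without any dependencies. The class below is a hypothetical stand-in, not FedLab code: ToyClientTrainer keeps a one-element weight list in self._model instead of a torch.nn.Module, and its training loop is hand-written SGD on a one-parameter linear model. The payload layout (a list whose first element is the serialized global model) is an assumption that mirrors the downlink_package property of the server handler. A real implementation would subclass ClientTrainer and run a PyTorch training loop inside train().

```python
class ToyClientTrainer:
    """Stand-in for a ClientTrainer subclass (no PyTorch, no FedLab)."""

    def __init__(self, data, lr=0.1, epochs=50):
        self._model = [0.0]   # local model y = w * x, stored as [w]
        self.data = data      # list of (x, y) training pairs
        self.lr = lr
        self.epochs = epochs

    @property
    def model_parameters(self):
        """Serialized parameters (here simply a copy of the weight list)."""
        return list(self._model)

    @property
    def uplink_package(self):
        """What the client manager would upload to the server."""
        return [self.model_parameters]

    def local_process(self, payload):
        """Load the global parameters from the payload, then train locally."""
        global_parameters = payload[0]
        self._model = list(global_parameters)
        self.train()

    def train(self):
        """Plain SGD on squared error; updates self._model in place."""
        for _ in range(self.epochs):
            for x, y in self.data:
                grad = 2 * (self._model[0] * x - y) * x
                self._model[0] -= self.lr * grad


# The local data follows y = 2x, so training should drive w towards 2.
trainer = ToyClientTrainer(data=[(1.0, 2.0), (2.0, 4.0)])
trainer.local_process(payload=[[0.0]])
uploaded = trainer.uplink_package
```

The point of the sketch is the division of labour: local_process() receives the payload and synchronizes the local model, train() changes self._model, and uplink_package reads the result back out through model_parameters.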
Server global aggregation
Computation tasks related to PyTorch should be defined in the server handler. In FedLab, the base class for handlers is ParameterServerBackendHandler.
class ParameterServerBackendHandler(ModelMaintainer):
    """An abstract class representing the handler of a parameter server.

    Please make sure that your self-defined server handler class subclasses this class.

    Example:
        Read source code of :class:`SyncParameterServerHandler` and :class:`AsyncParameterServerHandler`.
    """

    def __init__(self, model, cuda=False):
        super().__init__(model, cuda)

    @property
    def downlink_package(self):
        """Property for manager layer. Server manager will call this property when activating clients."""
        return [self.model_parameters]

    @property
    def if_stop(self):
        """:class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` is returned."""
        return False

    def _update_global_model(self, *args, **kwargs):
        """Override this function for iterating global model (aggregation or optimization)."""
        raise NotImplementedError()
Users define the server aggregation strategy by implementing the following method:
- Override _update_global_model() to customize the global procedure. _update_global_model() must manipulate the global model parameters (self._model).
- Common FL aggregation strategies are implemented in fedlab.utils.aggregator.
A standard implementation of this part is in SyncParameterServerHandler.
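As an illustration of what an aggregation step can look like, the sketch below implements FedAvg-style weighted averaging on plain Python lists that stand in for serialized parameter tensors. Both weighted_average and ToyServerHandler are hypothetical names for this example, and the arguments of _update_global_model (client parameter vectors plus per-client sample counts) are an assumption. A real handler would subclass ParameterServerBackendHandler and could rely on the strategies in fedlab.utils.aggregator instead.

```python
def weighted_average(parameters_list, weights=None):
    """Weighted average of flat parameter vectors of equal length."""
    if weights is None:
        weights = [1.0] * len(parameters_list)
    total = sum(weights)
    dim = len(parameters_list[0])
    return [
        sum(w * params[i] for w, params in zip(weights, parameters_list)) / total
        for i in range(dim)
    ]


class ToyServerHandler:
    """Stand-in for a server handler subclass (no PyTorch, no FedLab)."""

    def __init__(self, init_parameters):
        self._model = list(init_parameters)   # global model as a flat list

    @property
    def model_parameters(self):
        """Serialized global parameters (here a copy of the list)."""
        return list(self._model)

    @property
    def downlink_package(self):
        """What the server manager would send to activated clients."""
        return [self.model_parameters]

    def _update_global_model(self, client_parameters, sample_counts):
        """Replace the global model by the weighted average of client models."""
        self._model = weighted_average(client_parameters, sample_counts)


handler = ToyServerHandler([0.0, 0.0])
handler._update_global_model(
    client_parameters=[[1.0, 2.0], [3.0, 6.0]],
    sample_counts=[1, 3],   # the second client holds three times more data
)
new_global = handler.downlink_package[0]
```

Weighting each client by its number of local samples lets clients with more data contribute proportionally more to the global model; passing no weights falls back to a simple mean.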