Federated Optimization
Standard FL optimization consists of two parts: (1) local training on the client and (2) global aggregation on the server. Both the local training and the aggregation procedures are customizable in FedLab. To customize them, you need to define a ClientTrainer and a ServerHandler.
Because ClientTrainer and ServerHandler both need to manipulate a PyTorch model, they both inherit from ModelMaintainer.
from copy import deepcopy
from typing import List

import torch

from fedlab.utils.functional import get_best_gpu
from fedlab.utils.serialization import SerializationTool


class ModelMaintainer(object):
    """Maintain PyTorch model.

    Provide necessary attributes and operation methods. More features with local or global model
    will be implemented here.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
        device (str, optional): Assign model/data to the given GPUs. E.g., 'device:0' or 'device:0,1'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest memory as default.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 cuda: bool,
                 device: str = None) -> None:
        self.cuda = cuda

        if cuda:
            # Acquire a GPU dynamically when no device is given.
            if device is None:
                self.device = get_best_gpu()
            else:
                self.device = device
            self._model = deepcopy(model).cuda(self.device)
        else:
            self._model = deepcopy(model).cpu()

    def set_model(self, parameters: torch.Tensor):
        """Assign parameters to self._model."""
        SerializationTool.deserialize_model(self._model, parameters)

    @property
    def model(self) -> torch.nn.Module:
        """Return :class:`torch.nn.Module`."""
        return self._model

    @property
    def model_parameters(self) -> torch.Tensor:
        """Return serialized model parameters."""
        return SerializationTool.serialize_model(self._model)

    @property
    def model_gradients(self) -> torch.Tensor:
        """Return serialized model gradients."""
        return SerializationTool.serialize_model_gradients(self._model)

    @property
    def shape_list(self) -> List[torch.Size]:
        """Return shape of model parameters.

        Currently, this attribute is used in tensor compression.
        """
        shape_list = [param.shape for param in self._model.parameters()]
        return shape_list
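To make the idea of "serialized model parameters" concrete, here is a minimal sketch of a flatten/restore round trip written with plain PyTorch utilities. It is not FedLab's SerializationTool; the helper names serialize_model and deserialize_model are hypothetical stand-ins chosen for illustration, and the sketch assumes a CPU model.

```python
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def serialize_model(model: torch.nn.Module) -> torch.Tensor:
    """Flatten all parameters into a single 1-D tensor (a detached copy)."""
    return parameters_to_vector(model.parameters()).detach().clone()


def deserialize_model(model: torch.nn.Module, flat: torch.Tensor) -> None:
    """Write the values of a flat vector back into the model's parameters.

    The vector is cloned first so the model does not end up sharing
    storage with the caller's tensor.
    """
    vector_to_parameters(flat.clone(), model.parameters())


if __name__ == "__main__":
    source = torch.nn.Linear(3, 2)
    target = torch.nn.Linear(3, 2)

    flat = serialize_model(source)      # weight (2x3) and bias (2) flattened together
    deserialize_model(target, flat)     # target now holds the same values as source

    # The per-parameter shapes are what a property like shape_list reports.
    print([tuple(p.shape) for p in target.parameters()])
    print(torch.equal(serialize_model(target), flat))
```

Exchanging one flat tensor keeps the communication payload simple: both sides only need to agree on the model architecture to restore the parameters.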
Client local training
The base class ClientTrainer is shown below. We encourage users to define their local training process following this code pattern:
from abc import abstractmethod

# FedDataset and ORDINARY_TRAINER are defined elsewhere in FedLab.


class ClientTrainer(ModelMaintainer):
    """An abstract class representing a client trainer.

    In FedLab, we define the backend of client trainer, which should manage its local model.
    It should have a function to update its model called :meth:`local_process`.

    If you use our framework to define the activities of client, please make sure that your self-defined class
    should subclass it. All subclasses should overwrite :meth:`local_process` and property ``uplink_package``.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
        device (str, optional): Assign model/data to the given GPUs. E.g., 'device:0' or 'device:0,1'. Defaults to ``None``.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 cuda: bool,
                 device: str = None) -> None:
        super().__init__(model, cuda, device)

        self.client_num = 1  # default is 1.
        self.dataset = FedDataset()  # or Dataset
        self.type = ORDINARY_TRAINER

    def setup_dataset(self):
        """Set up local dataset ``self.dataset`` for clients."""
        raise NotImplementedError()

    def setup_optim(self):
        """Set up variables for optimization algorithms."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def uplink_package(self) -> List[torch.Tensor]:
        """Return a tensor list for uploading to server.

        This attribute will be called by client manager.
        Customize it for new algorithms.
        """
        raise NotImplementedError()

    @abstractmethod
    def local_process(self, payload: List[torch.Tensor]):
        """Manager of the upper layer will call this function with accepted payload.

        In synchronous mode, return True to end current FL round.
        """
        raise NotImplementedError()

    def train(self):
        """Override this method to define the training procedure. This function should manipulate :attr:`self._model`."""
        raise NotImplementedError()

    def validate(self):
        """Validate quality of local model."""
        raise NotImplementedError()

    def evaluate(self):
        """Evaluate quality of local model."""
        raise NotImplementedError()
- Override ClientTrainer.local_process() to define the local procedure. Typically, you need to implement the standard PyTorch training pipeline.
- The attributes model and model_parameters are associated with self._model. Please make sure that local_process() manipulates self._model.
A standard implementation of this part is in SGDClientTrainer.
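As an illustration of the pattern described above, the following is a minimal, self-contained sketch of a trainer that mirrors the ClientTrainer interface (local_process, train, uplink_package). It deliberately does not import FedLab, and it is not the SGDClientTrainer implementation. The class name ToyClientTrainer, its constructor arguments, the in-memory data, and the assumption that payload[0] holds the serialized global parameters are all hypothetical choices made for this example.

```python
from copy import deepcopy
from typing import List

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters


class ToyClientTrainer:
    """Hypothetical stand-in that mirrors ClientTrainer's interface (not a FedLab class)."""

    def __init__(self, model: torch.nn.Module, data: torch.Tensor,
                 targets: torch.Tensor, epochs: int = 5, lr: float = 0.1) -> None:
        self._model = deepcopy(model).cpu()
        self.data, self.targets = data, targets
        self.epochs = epochs
        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.SGD(self._model.parameters(), lr=lr)

    @property
    def model_parameters(self) -> torch.Tensor:
        """Serialized (flattened) copy of the local model parameters."""
        return parameters_to_vector(self._model.parameters()).detach().clone()

    @property
    def uplink_package(self) -> List[torch.Tensor]:
        """Tensor list that would be uploaded to the server."""
        return [self.model_parameters]

    def local_process(self, payload: List[torch.Tensor]) -> None:
        """Assumption: payload[0] is the serialized global model."""
        self.train(payload[0])

    def train(self, global_parameters: torch.Tensor) -> None:
        """Load the global parameters into self._model, then run full-batch SGD."""
        vector_to_parameters(global_parameters.clone(), self._model.parameters())
        self._model.train()
        for _ in range(self.epochs):
            self.optimizer.zero_grad()
            loss = self.criterion(self._model(self.data), self.targets)
            loss.backward()
            self.optimizer.step()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(32, 3)
    y = x @ torch.tensor([[1.0], [-2.0], [0.5]])  # synthetic regression targets
    global_model = torch.nn.Linear(3, 1)

    trainer = ToyClientTrainer(global_model, x, y)
    payload = [parameters_to_vector(global_model.parameters()).detach().clone()]
    trainer.local_process(payload)
    print(trainer.uplink_package[0].shape)
```

The key point is that train() reads and writes self._model only, so model_parameters and uplink_package always reflect the result of the latest local update.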
Server global aggregation
Computation tasks that involve PyTorch should be defined in the ServerHandler part. In FedLab, the base class for handlers is ServerHandler.
class ServerHandler(ModelMaintainer):
    """An abstract class representing handler of parameter server.

    Please make sure that your self-defined server handler class subclasses this class.

    Example:
        Read source code of :class:`SyncServerHandler` and :class:`AsyncServerHandler`.

    Args:
        model (torch.nn.Module): PyTorch model.
        cuda (bool): Use GPUs or not.
        device (str, optional): Assign model/data to the given GPUs. E.g., 'device:0' or 'device:0,1'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest memory as default.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 cuda: bool,
                 device: str = None) -> None:
        super().__init__(model, cuda, device)

    @property
    @abstractmethod
    def downlink_package(self) -> List[torch.Tensor]:
        """Property for manager layer. Server manager will call this property when activates clients."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def if_stop(self) -> bool:
        """:class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` returned."""
        return False

    @abstractmethod
    def setup_optim(self):
        """Override this function to load your optimization hyperparameters."""
        raise NotImplementedError()

    @abstractmethod
    def global_update(self, buffer):
        raise NotImplementedError()

    @abstractmethod
    def load(self, payload):
        """Override this function to define how to update global model (aggregation or optimization)."""
        raise NotImplementedError()

    @abstractmethod
    def evaluate(self):
        """Override this function to define the evaluation of global model."""
        raise NotImplementedError()
Users can define a server aggregation strategy by implementing the following functions:
- You can override _update_global_model() to customize the global procedure. _update_global_model() is required to manipulate the global model parameters (self._model).
- Common FL aggregation strategies are implemented in fedlab.utils.aggregator.
A standard implementation of this part is in SyncParameterServerHandler.
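To show how the abstract methods listed in ServerHandler (load, global_update, downlink_package, if_stop) can fit together, here is a minimal sketch of a synchronous handler that buffers client uploads and averages them FedAvg-style. It does not import FedLab and is not the library's handler or aggregator code. The names ToyServerHandler and fedavg_aggregate, the constructor arguments, the equal client weights, and the assumption that each uploaded package carries the serialized parameters at index 0 are hypothetical.

```python
from copy import deepcopy
from typing import List

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def fedavg_aggregate(params_list: List[torch.Tensor], weights: List[float]) -> torch.Tensor:
    """Weighted average of serialized parameter vectors (weights are normalized)."""
    w = torch.tensor(weights, dtype=params_list[0].dtype)
    w = w / w.sum()
    stacked = torch.stack(params_list, dim=0)        # (num_clients, num_params)
    return (w.unsqueeze(1) * stacked).sum(dim=0)     # (num_params,)


class ToyServerHandler:
    """Hypothetical stand-in that mirrors ServerHandler's interface (not a FedLab class)."""

    def __init__(self, model: torch.nn.Module, global_round: int,
                 num_clients_per_round: int) -> None:
        self._model = deepcopy(model).cpu()
        self.global_round = global_round
        self.num_clients_per_round = num_clients_per_round
        self.round = 0
        self.buffer = []

    @property
    def model_parameters(self) -> torch.Tensor:
        return parameters_to_vector(self._model.parameters()).detach().clone()

    @property
    def downlink_package(self) -> List[torch.Tensor]:
        """Tensor list that would be sent to the selected clients."""
        return [self.model_parameters]

    @property
    def if_stop(self) -> bool:
        return self.round >= self.global_round

    def load(self, payload: List[torch.Tensor]) -> bool:
        """Buffer one client package; aggregate once enough packages have arrived."""
        self.buffer.append(payload)
        if len(self.buffer) < self.num_clients_per_round:
            return False
        self.global_update(self.buffer)
        self.buffer = []
        self.round += 1
        return True

    def global_update(self, buffer: List[List[torch.Tensor]]) -> None:
        """Average the buffered parameters (equal weights) into self._model."""
        params_list = [package[0] for package in buffer]
        weights = [1.0] * len(params_list)
        new_parameters = fedavg_aggregate(params_list, weights)
        vector_to_parameters(new_parameters, self._model.parameters())


if __name__ == "__main__":
    handler = ToyServerHandler(torch.nn.Linear(2, 1), global_round=1, num_clients_per_round=2)
    size = handler.model_parameters.numel()
    handler.load([torch.zeros(size)])
    handler.load([torch.full((size,), 2.0)])
    print(handler.model_parameters, handler.if_stop)
```

Separating the aggregation rule (fedavg_aggregate) from the buffering logic (load) makes it straightforward to swap in a different strategy, for example weighting each client by its local sample count.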