handler#
Module Contents#
ParameterServerBackendHandler | An abstract class representing handler of parameter server.
SyncParameterServerHandler | Synchronous Parameter Server Handler.
AsyncParameterServerHandler | Asynchronous Parameter Server Handler.
- class ParameterServerBackendHandler(model, cuda=False)#
Bases: fedlab.core.model_maintainer.ModelMaintainer
An abstract class representing the handler of a parameter server.
Please make sure that your self-defined server handler class subclasses this class.
Example
Read the source code of SyncParameterServerHandler and AsyncParameterServerHandler; a minimal self-defined handler is also sketched below.
- abstract _update_model(self, model_parameters_list)#
Override this function to define the global model aggregation strategy.
- Parameters
model_parameters_list (list[torch.Tensor]) – A list of serialized model parameters collected from different clients.
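As a concrete illustration of the subclassing contract above, here is a minimal sketch of a self-defined handler that overrides _update_model with a plain unweighted average. The import paths and the averaging rule are assumptions for illustration, not the library's built-in strategy.

```python
import torch

# Assumed import paths; adjust to your installed FedLab version.
from fedlab.core.server.handler import ParameterServerBackendHandler
from fedlab.utils.serialization import SerializationTool


class MyAverageHandler(ParameterServerBackendHandler):
    """Hypothetical handler: aggregate by unweighted averaging."""

    def _update_model(self, model_parameters_list):
        # Each element is a serialized parameter vector from one client.
        aggregated = torch.mean(torch.stack(model_parameters_list), dim=0)
        # Load the aggregated vector back into the maintained global model.
        SerializationTool.deserialize_model(self._model, aggregated)
```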
- class SyncParameterServerHandler(model, global_round=5, cuda=False, sample_ratio=1.0, logger=Logger())#
Bases: ParameterServerBackendHandler
Synchronous Parameter Server Handler.
Backend of the synchronous parameter server: this class is responsible for the backend computing of a synchronous server.
A synchronous parameter server waits for every client to finish its local training process before starting the next FL round.
Details in paper: http://proceedings.mlr.press/v54/mcmahan17a.html
- Parameters
model (torch.nn.Module) – Model used in this federation.
global_round (int) – Stop condition. Shut down the FL system when the global round is reached.
cuda (bool) – Use GPUs or not. Default: False.
sample_ratio (float) – sample_ratio * client_num is the number of clients to join in every FL round. Default: 1.0.
logger (Logger, optional) – Object of Logger.
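A minimal construction sketch for the signature above; the model, round count and sample ratio are placeholders, and the import path is assumed.

```python
import torch.nn as nn

# Assumed import path; adjust to your installed FedLab version.
from fedlab.core.server.handler import SyncParameterServerHandler

# Placeholder global model; with sample_ratio=0.5, half of the registered
# clients are sampled to participate in each FL round.
model = nn.Linear(784, 10)
handler = SyncParameterServerHandler(model, global_round=3, sample_ratio=0.5)
```

A server-side NetworkManager then drives the rounds and keeps polling stop_condition(), as described below.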
- stop_condition(self) → bool#
NetworkManager keeps monitoring the return value of this method, and it will stop all related processes and threads when True is returned.
- sample_clients(self)#
Return a list of client rank indices selected randomly. The client ID ranges from 1 to self.client_num_in_total + 1.
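The selection described above can be pictured with a short stand-in sketch; it is not the library's exact implementation, and the two arguments correspond to the handler's client counters.

```python
import random

def sample_clients(client_num_in_total, client_num_per_round):
    # Draw distinct client ranks uniformly at random; ranks start at 1
    # because rank 0 is conventionally reserved for the server process.
    return random.sample(range(1, client_num_in_total + 1), client_num_per_round)
```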
- add_model(self, sender_rank, model_parameters)#
Deal with incoming model parameters from one client.
Note
Return True when self._update_model is called.
- Parameters
sender_rank (int) – Rank of the sender client in the torch.distributed group.
model_parameters (torch.Tensor) – Serialized model parameters from one client.
- _update_model(self, model_parameters_list)#
Update global model with collected parameters from clients.
Note
The server handler will call this method when its client_buffer_cache is full. Users can overwrite the aggregation strategy applied to model_parameters_list, and use SerializationTool.deserialize_model() to load the serialized parameters back into self._model after aggregation.
- Parameters
model_parameters_list (list[torch.Tensor]) – A list of parameters.
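For instance, replacing the strategy with a weighted average could look like the helper below, to be called from an overridden _update_model. The weights are hypothetical (e.g. proportional to local dataset sizes), the import path is assumed, and any round bookkeeping the original method performs is not reproduced here.

```python
import torch

# Assumed import path; adjust to your installed FedLab version.
from fedlab.utils.serialization import SerializationTool

def weighted_aggregate(model_parameters_list, weights):
    # weights: hypothetical per-client coefficients, one per entry in the list.
    w = torch.tensor(weights, dtype=torch.float32)
    w = w / w.sum()
    stacked = torch.stack(model_parameters_list)   # (num_clients, num_params)
    return (w.unsqueeze(1) * stacked).sum(dim=0)   # (num_params,)

# Inside an overridden _update_model the result would be loaded back with:
#   SerializationTool.deserialize_model(self._model, weighted_aggregate(...))
```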
- property client_num_per_round(self)#
- class AsyncParameterServerHandler(model, alpha=0.5, total_time=5, strategy='constant', cuda=False, logger=Logger())#
Bases: ParameterServerBackendHandler
Asynchronous Parameter Server Handler.
Update the global model immediately after receiving a ParameterUpdate message.
Paper: https://arxiv.org/abs/1903.03934
- Parameters
model (torch.nn.Module) – Global model in server.
alpha (float) – Weight used in async aggregation.
total_time (int) – Stop condition. Shut down the FL system when total_time is reached.
strategy (str) – Adaptive strategy; constant, hinge and polynomial are optional. Default: constant.
cuda (bool) – Use GPUs or not.
logger (Logger, optional) – Object of Logger.
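A minimal construction sketch for the signature above; the model and strategy value are placeholders and the import path is assumed. Conceptually, following the referenced paper, each incoming client model is mixed into the global model as new_global = (1 - alpha) * old_global + alpha * client_model, with alpha scaled down for stale updates (see _adapt_alpha below).

```python
import torch.nn as nn

# Assumed import path; adjust to your installed FedLab version.
from fedlab.core.server.handler import AsyncParameterServerHandler

# Placeholder global model; strategy selects how alpha decays with staleness.
model = nn.Linear(784, 10)
handler = AsyncParameterServerHandler(model, alpha=0.5, total_time=10,
                                      strategy='polynomial')
```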
- property server_time(self)#
- stop_condition(self) → bool#
NetworkManager keeps monitoring the return value of this method, and it will stop all related processes and threads when True is returned.
- _update_model(self, client_model_parameters, model_time)#
Update the global model from client_model_queue.
- _adapt_alpha(self, receive_model_time)#
Update the alpha according to staleness.
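A sketch of the staleness adaptation and the asynchronous mixing step, following the formulas in the referenced paper (https://arxiv.org/abs/1903.03934). The hyperparameters a and b of the hinge/polynomial strategies and the function names are assumptions, not the library's internals.

```python
def adapt_alpha(alpha, server_time, receive_model_time, strategy='constant', a=0.5, b=4):
    # Scale the mixing weight by a staleness function s(t - tau);
    # a and b are hypothetical hyperparameters of the hinge/polynomial strategies.
    staleness = server_time - receive_model_time
    if strategy == 'constant':
        s = 1.0
    elif strategy == 'hinge':
        s = 1.0 if staleness <= b else 1.0 / (a * (staleness - b) + 1.0)
    elif strategy == 'polynomial':
        s = float(staleness + 1) ** (-a)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return alpha * s

def async_update(global_parameters, client_parameters, alpha_t):
    # Mix the (possibly stale) serialized client model into the global model;
    # both arguments are flat parameter vectors (e.g. torch.Tensor).
    return (1 - alpha_t) * global_parameters + alpha_t * client_parameters
```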