trainer

Module Contents

ClientTrainer

An abstract class representing a client trainer.

SGDClientTrainer

Client backend handler. This class provides data-processing methods to the upper layer.

class ClientTrainer(model, cuda)

Bases: fedlab.core.model_maintainer.ModelMaintainer

An abstract class representing a client trainer.

In FedLab, the client trainer is the backend that manages the client's local model. It should provide a function called local_process() that updates this model.

If you use FedLab to define client activities, make sure your custom class subclasses ClientTrainer. All subclasses should override local_process() and the uplink_package property.
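As a rough illustration of this contract, the sketch below mimics it with Python's standard abc module instead of importing FedLab. ClientTrainerSketch and EchoTrainer are hypothetical names, and plain lists of floats stand in for torch tensors; the real base class (which subclasses ModelMaintainer) may differ in detail.

```python
from abc import ABC, abstractmethod
from typing import List


class ClientTrainerSketch(ABC):
    """Hypothetical stand-in for the ClientTrainer contract (not FedLab code)."""

    def __init__(self, model: List[float], cuda: bool = False) -> None:
        self._model = model  # a list of floats stands in for a torch model
        self.cuda = cuda  # kept only to mirror the constructor signature

    @property
    @abstractmethod
    def uplink_package(self) -> List[List[float]]:
        """Return the list of 'tensors' to upload to the server."""

    @abstractmethod
    def local_process(self, payload: List[List[float]]) -> bool:
        """Handle a payload from the upper-layer manager; True ends the round."""


class EchoTrainer(ClientTrainerSketch):
    """Toy subclass: adopts the received parameters and uploads them unchanged."""

    @property
    def uplink_package(self) -> List[List[float]]:
        return [list(self._model)]

    def local_process(self, payload: List[List[float]]) -> bool:
        self._model = list(payload[0])  # replace local parameters with the server's
        return True  # signal the end of the round in synchronous mode


trainer = EchoTrainer(model=[0.0, 0.0])
trainer.local_process([[1.0, 2.0]])
print(trainer.uplink_package)  # [[1.0, 2.0]]
```

Because both members are abstract, instantiating ClientTrainerSketch directly raises TypeError; only a subclass that overrides both can be created.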

Parameters

model – PyTorch model maintained by this trainer.
cuda (bool) – Whether to use GPUs.

property uplink_package

Return a tensor list for uploading to the server.

This property is called by the client manager. Customize it for new algorithms.
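One plausible way to build this tensor list is to flatten every layer's parameters into a single vector, which matches the "serialized model parameters" that train() receives. The sketch below assumes pure-Python lists in place of tensors; serialize() and deserialize() are hypothetical helpers, not FedLab functions.

```python
from typing import List, Tuple


def serialize(layers: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Flatten per-layer parameters into one vector, remembering each layer's size."""
    flat: List[float] = []
    sizes: List[int] = []
    for layer in layers:
        flat.extend(layer)
        sizes.append(len(layer))
    return flat, sizes


def deserialize(flat: List[float], sizes: List[int]) -> List[List[float]]:
    """Split a flat parameter vector back into per-layer lists."""
    layers: List[List[float]] = []
    offset = 0
    for size in sizes:
        layers.append(flat[offset:offset + size])
        offset += size
    if offset != len(flat):
        raise ValueError("sizes do not match the length of the flat vector")
    return layers


layers = [[0.5, -1.0], [2.0], [0.0, 0.25, 3.0]]
flat, sizes = serialize(layers)
uplink_package = [flat]  # a one-element "tensor list" for the server
print(uplink_package)  # [[0.5, -1.0, 2.0, 0.0, 0.25, 3.0]]
print(deserialize(flat, sizes) == layers)  # True
```

Keeping the layer sizes alongside the flat vector is what lets the receiver restore the original shapes.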

abstract classmethod local_process(self, payload) → bool

The upper-layer manager calls this function with the received payload.

In synchronous mode, return True to end the current FL round.
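To illustrate the return-value convention, here is a hypothetical synchronous manager loop that keeps handing payloads to a trainer until local_process() returns True. CountdownTrainer and run_sync_round() are invented for this sketch; FedLab's actual client managers are not reproduced here.

```python
from typing import List


class CountdownTrainer:
    """Hypothetical trainer that needs `steps` payloads before it finishes a round."""

    def __init__(self, steps: int) -> None:
        self.steps = steps
        self.received = 0

    def local_process(self, payload: List[float]) -> bool:
        self.received += 1
        # Returning True tells the synchronous manager to end the current round.
        return self.received >= self.steps


def run_sync_round(trainer: CountdownTrainer, payloads: List[List[float]]) -> int:
    """Feed payloads until local_process returns True; return how many were used."""
    for count, payload in enumerate(payloads, start=1):
        if trainer.local_process(payload):
            return count
    raise RuntimeError("round did not finish before the payloads ran out")


used = run_sync_round(CountdownTrainer(steps=2), [[0.1], [0.2], [0.3]])
print(used)  # 2
```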

abstract train(self)

Override this method to define the training algorithm for your model. This function should manipulate self._model.

abstract evaluate(self)

Evaluate the quality of the local model.
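A minimal sketch of what an evaluate() override might compute, assuming a one-feature linear model y = w*x + b and mean squared error as the quality metric. Both the model and the metric are illustrative choices, not something ClientTrainer prescribes, and evaluate() here is a free function rather than a method.

```python
from typing import List, Tuple


def evaluate(model: Tuple[float, float], data: List[Tuple[float, float]]) -> float:
    """Hypothetical evaluate(): mean squared error of y = w*x + b on local data."""
    w, b = model
    if not data:
        raise ValueError("no local data to evaluate on")
    total = 0.0
    for x, y in data:
        total += (w * x + b - y) ** 2
    return total / len(data)


local_data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]
print(evaluate((2.0, 1.0), local_data))  # 0.0
print(evaluate((2.0, 0.0), local_data))  # 1.0
```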

class SGDClientTrainer(model, data_loader, epochs, optimizer, criterion, cuda=False, logger=None)

Bases: ClientTrainer

Client backend handler. This class provides data-processing methods to the upper layer.

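Parameters

model – PyTorch model maintained by this trainer.
data_loader – Data loader for this client's local dataset.
epochs (int) – Number of local training epochs.
optimizer – Optimizer for this client's model.
criterion – Loss function used in local training.
cuda (bool, optional) – Whether to use GPUs. Default: False.
logger (optional) – Logger object. Default: None.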

property uplink_package

Return a tensor list for uploading to the server.

This property is called by the client manager. Customize it for new algorithms.

local_process(self, payload)

The upper-layer manager calls this function with the received payload.

In synchronous mode, return True to end the current FL round.

train(self, model_parameters) → None

The client trains its local model on its local dataset.

Parameters

model_parameters (torch.Tensor) – Serialized model parameters.
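One reading of train() is: load the received serialized parameters into the local model, then run `epochs` passes of SGD over `data_loader`. The sketch below is a pure-Python analogue under stated assumptions, not SGDClientTrainer itself: TinySGDTrainer and `lr` are hypothetical, a list of mini-batches stands in for the data loader, and a hand-written mean-squared-error gradient with a plain SGD step replaces the `optimizer` and `criterion` arguments.

```python
from typing import Iterable, List, Tuple

Batch = List[Tuple[float, float]]


class TinySGDTrainer:
    """Hypothetical pure-Python analogue of an SGD client trainer for y = w*x + b."""

    def __init__(self, data_loader: Iterable[Batch], epochs: int, lr: float) -> None:
        self.data_loader = list(data_loader)  # re-iterable list of mini-batches
        self.epochs = epochs
        self.lr = lr
        self.params = [0.0, 0.0]  # [w, b], standing in for serialized parameters

    def train(self, model_parameters: List[float]) -> None:
        # 1. Load the server's serialized parameters into the local model.
        self.params = list(model_parameters)
        for _ in range(self.epochs):
            for batch in self.data_loader:
                w, b = self.params
                # 2. Gradient of the mean squared error over the mini-batch.
                grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
                grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
                # 3. Plain SGD step.
                self.params = [w - self.lr * grad_w, b - self.lr * grad_b]

    @property
    def uplink_package(self) -> List[List[float]]:
        return [list(self.params)]


loader = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0), (3.0, 7.0)]]
trainer = TinySGDTrainer(loader, epochs=1000, lr=0.05)
trainer.train([0.0, 0.0])
w, b = trainer.uplink_package[0]
print(round(w, 2), round(b, 2))  # 2.0 1.0
```

After train() returns, uplink_package exposes the updated parameters in the same flattened form the client received.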