serialization

Module Contents

class SerializationTool

Bases: object

static serialize_model_gradients(model: torch.nn.Module) → torch.Tensor

Unfold the gradients of all model parameters and concatenate them into a single 1-D tensor.

Parameters

model (torch.nn.Module) – model whose parameter gradients are serialized.

Returns

A 1-D tensor containing the gradients of all parameters, concatenated in model.parameters() order.

Return type

torch.Tensor
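As a rough illustration of what gradient serialization involves, here is a minimal sketch of the flatten step together with its inverse (the job of deserialize_model_gradients). The function names flatten_gradients and assign_gradients are hypothetical stand-ins, not the library's code; the sketch assumes backward() has already run so every parameter has a populated .grad, and the size check is an added assumption.

```python
import torch
import torch.nn as nn


def flatten_gradients(model: nn.Module) -> torch.Tensor:
    """Hypothetical stand-in for serialize_model_gradients.

    Assumes backward() has already run, so every parameter has a .grad.
    """
    grads = [p.grad.detach().reshape(-1) for p in model.parameters()]
    return torch.cat(grads).cpu()


def assign_gradients(model: nn.Module, gradients: torch.Tensor) -> None:
    """Hypothetical stand-in for deserialize_model_gradients.

    Writes consecutive slices of `gradients` back into each parameter's
    existing .grad tensor, in model.parameters() order.
    """
    total = sum(p.grad.numel() for p in model.parameters())
    if gradients.numel() != total:  # assumed safety check
        raise ValueError(f"expected {total} elements, got {gradients.numel()}")
    idx = 0
    for p in model.parameters():
        n = p.grad.numel()
        # copy_ converts device/dtype if they differ from p.grad.
        p.grad.copy_(gradients[idx:idx + n].reshape_as(p.grad))
        idx += n


# Usage: nn.Linear(3, 2) has a 2x3 weight plus a 2-element bias = 8 values.
torch.manual_seed(0)
model = nn.Linear(3, 2)
model(torch.randn(4, 3)).sum().backward()

flat = flatten_gradients(model)
print(flat.shape)  # torch.Size([8])

# Overwrite the gradients with 0..7 to show the slicing order.
assign_gradients(model, torch.arange(8, dtype=torch.float32))
print(model.weight.grad)  # rows [0, 1, 2] and [3, 4, 5]
print(model.bias.grad)    # values [6, 7]
```

Because the loss is a plain sum over a batch of 4, each bias gradient in flat equals 4, which makes the ordering easy to check by eye.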

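static deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor)

Assign serialized gradients to the .grad of each tensor in model.parameters(), taking consecutive slices of gradients in the same order used by serialize_model_gradients.

static serialize_model(model: torch.nn.Module) → torch.Tensor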

Unfold model parameters.

Flatten the parameters of every layer of the model and concatenate them into a single tensor. Returns a torch.Tensor of shape (size,), where size is the total number of parameter elements.

Parameters

model (torch.nn.Module) – model to serialize.
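A minimal sketch of the flatten-and-concatenate step, using a hypothetical helper flatten_parameters rather than the library's own code. Moving the result to the CPU is an assumption about where a serialized vector should live. PyTorch's real torch.nn.utils.parameters_to_vector produces the same element order, which gives a convenient cross-check.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


def flatten_parameters(model: nn.Module) -> torch.Tensor:
    """Hypothetical stand-in for serialize_model: flatten and concatenate."""
    # detach() keeps the result out of the autograd graph; .cpu() is an
    # assumption about where the serialized vector should live.
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()]).cpu()


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
flat = flatten_parameters(model)
# (4*3 + 3) + (3*1 + 1) = 19 parameters; ReLU has none.
print(flat.shape)  # torch.Size([19])

# PyTorch's built-in helper yields the same ordering.
assert torch.equal(flat, parameters_to_vector(model.parameters()).detach())
```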

static deserialize_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode='copy')

Assigns serialized parameters to model.parameters(). This is done by iterating through model.parameters() and assigning the corresponding slice of serialized_parameters to each parameter, in order. NOTE: this function modifies model.parameters() in place.

Parameters
  • model (torch.nn.Module) – model to deserialize.

  • serialized_parameters (torch.Tensor) – serialized model parameters.

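  • mode (str) – deserialize mode, either “copy” (overwrite each parameter with its slice) or “add” (add the slice to the current parameter values). Defaults to “copy”.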
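A minimal sketch of how the two modes might be used, for example copying a global model into a client and then applying an additive update. load_flat_parameters is a hypothetical helper, not the library's implementation; treating “add” as element-wise addition and rejecting size mismatches up front are assumptions.

```python
import torch
import torch.nn as nn


def load_flat_parameters(model: nn.Module, serialized: torch.Tensor, mode: str = "copy") -> None:
    """Hypothetical stand-in for deserialize_model.

    mode="copy" overwrites each parameter; mode="add" adds the slice to it.
    """
    if mode not in ("copy", "add"):
        raise ValueError(f"unknown mode: {mode!r}")
    total = sum(p.numel() for p in model.parameters())
    if serialized.numel() != total:  # assumed safety check
        raise ValueError(f"expected {total} elements, got {serialized.numel()}")
    idx = 0
    # no_grad: parameters are leaf tensors that require grad, so in-place
    # updates must happen outside autograd tracking.
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            chunk = serialized[idx:idx + n].reshape_as(p).to(device=p.device, dtype=p.dtype)
            if mode == "copy":
                p.copy_(chunk)
            else:
                p.add_(chunk)
            idx += n


torch.manual_seed(0)
global_model = nn.Linear(3, 2)
client_model = nn.Linear(3, 2)
flat = torch.cat([p.detach().reshape(-1) for p in global_model.parameters()])

load_flat_parameters(client_model, flat, mode="copy")          # client now matches global
load_flat_parameters(client_model, torch.ones(8), mode="add")  # shift every value by 1
```

Validating the length before touching any parameter means a mismatched vector leaves the model unchanged instead of partially overwritten.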