serialization
Module Contents
- class SerializationTool
Bases: object
- static serialize_model_gradients(model: torch.nn.Module) → torch.Tensor
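Serialize the gradients of all model parameters into a single 1-D tensor.
- Parameters
model (torch.nn.Module) – model whose parameter gradients are serialized.
- Returns
the flattened gradients of model.parameters().
- Return type
torch.Tensor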
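A minimal sketch of what serializing gradients involves, assuming backward() has already populated .grad on every parameter. flatten_grads is a hypothetical helper written for illustration, not the SerializationTool implementation; it only uses standard PyTorch calls.

```python
import torch
import torch.nn as nn


def flatten_grads(model: nn.Module) -> torch.Tensor:
    """Hypothetical helper: concatenate every parameter's .grad into one 1-D tensor.

    Assumes backward() has already run, so no parameter's .grad is None.
    """
    return torch.cat([p.grad.reshape(-1) for p in model.parameters()])


# Usage: nn.Linear(3, 2) has 3*2 weights + 2 biases = 8 parameters.
model = nn.Linear(3, 2)
model(torch.ones(4, 3)).sum().backward()
flat = flatten_grads(model)
print(flat.shape)  # torch.Size([8])
```

Parameters are visited in model.parameters() order (here weight, then bias), so any consumer of the flat tensor must unflatten in that same order.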
- static deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor)
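deserialize_model_gradients carries no description here. Judging by its name and signature, it presumably performs the inverse of serialize_model_gradients, writing consecutive slices of a flat tensor back into each parameter's .grad. The sketch below assumes exactly that; load_flat_grads is a hypothetical helper, not the library's code.

```python
import torch
import torch.nn as nn


def load_flat_grads(model: nn.Module, flat: torch.Tensor) -> None:
    """Hypothetical helper: write slices of a flat tensor into each parameter's .grad."""
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    # Check the length first so the model is never left partially updated.
    if flat.numel() != total:
        raise ValueError(f"expected {total} values, got {flat.numel()}")
    offset = 0
    for p in params:
        n = p.numel()
        chunk = flat[offset:offset + n].view_as(p)
        if p.grad is None:
            p.grad = chunk.clone()  # allocate a fresh gradient tensor
        else:
            p.grad.copy_(chunk)  # overwrite the existing gradient in place
        offset += n


# Usage: nn.Linear(3, 2) has 8 parameters (weight 2x3, then bias of size 2).
model = nn.Linear(3, 2)
load_flat_grads(model, torch.arange(8, dtype=torch.float32))
```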
- static serialize_model(model: torch.nn.Module) → torch.Tensor
Unfold model parameters.
Unfolds every layer of the model and concatenates all of its parameter tensors into one. Returns a torch.Tensor with shape (size, ), where size is the total number of model parameters.
- Parameters
model (torch.nn.Module) – model to serialize.
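A minimal sketch of the unfold-and-concatenate step, assuming the caller wants a detached copy that is not tied to the autograd graph. flatten_params is a hypothetical helper; the check against torch.nn.utils.parameters_to_vector (a real PyTorch utility) shows the two produce the same vector.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


def flatten_params(model: nn.Module) -> torch.Tensor:
    """Hypothetical helper: flatten every parameter and concatenate into shape (size,)."""
    # detach() so the flat copy does not track gradients back into the model.
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])


# Usage: (4*3 + 3) + (3*1 + 1) = 19 parameters in total.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
flat = flatten_params(model)
print(flat.shape)  # torch.Size([19])
```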
- static deserialize_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode='copy')
Assigns serialized parameters to model.parameters. This is done by iterating through
model.parameters() and assigning the relevant slice of serialized_parameters to each parameter. NOTE: this function modifies model.parameters in place.
- Parameters
model (torch.nn.Module) – model to deserialize.
serialized_parameters (torch.Tensor) – serialized model parameters.
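mode (str) – deserialize mode, either “copy” (overwrite each parameter with its slice of serialized_parameters) or “add” (add the slice to the parameter's current value). Defaults to “copy”.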
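A minimal sketch of the two deserialize modes, assuming the flat tensor follows model.parameters() order. load_flat_params is a hypothetical helper, not the SerializationTool implementation; it writes under torch.no_grad() so the in-place updates on leaf parameters are allowed.

```python
import torch
import torch.nn as nn


def load_flat_params(model: nn.Module, flat: torch.Tensor, mode: str = "copy") -> None:
    """Hypothetical helper: copy or add slices of a flat tensor into model parameters."""
    if mode not in ("copy", "add"):
        raise ValueError(f'invalid mode {mode!r}, expected "copy" or "add"')
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    # Validate length before touching any parameter to avoid a half-updated model.
    if flat.numel() != total:
        raise ValueError(f"expected {total} values, got {flat.numel()}")
    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            chunk = flat[offset:offset + n].view_as(p)
            if mode == "copy":
                p.copy_(chunk)  # overwrite the parameter
            else:
                p.add_(chunk)  # accumulate onto the current value
            offset += n


# Usage: nn.Linear(2, 1) has 3 parameters (weight 1x2, then bias of size 1).
model = nn.Linear(2, 1)
load_flat_params(model, torch.zeros(3))               # "copy": reset to zeros
load_flat_params(model, torch.ones(3), mode="add")    # "add": apply an update
```

The "add" mode is what makes applying an update delta (for example, an aggregated change to global weights) a one-liner. For the "copy" case alone, torch.nn.utils.vector_to_parameters is a built-in PyTorch alternative.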