serialization
Module Contents
- class SerializationTool
Bases: object
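- static serialize_model_gradients(model: torch.nn.Module) → torch.Tensor
Unfold model gradients. Concatenates the gradient of every parameter of model into one tensor.
- Parameters
model (torch.nn.Module) – model whose gradients are serialized.
- Returns
A 1-D tensor holding all gradients of model.
- Return type
torch.Tensor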
- static deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor)
- static serialize_model(model: torch.nn.Module) → torch.Tensor
Unfold model parameters.
Unfolds every layer of model and concatenates all of its parameter tensors into one. Returns a torch.Tensor of shape (size, ), where size is the total number of parameter elements in model.
- Parameters
model (torch.nn.Module) – model to serialize.
- static deserialize_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode='copy')
Assigns serialized parameters to model.parameters(). This is done by iterating through model.parameters() and assigning the corresponding slice of serialized_parameters to each parameter. NOTE: this function modifies model.parameters() in place.
- Parameters
model (torch.nn.Module) – model to deserialize.
serialized_parameters (torch.Tensor) – serialized model parameters.
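mode (str) – deserialization mode, either “copy” (overwrite the current parameters) or “add” (add the serialized values to the current parameters).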
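The parameter round trip described by serialize_model and deserialize_model can be sketched as follows. This is a minimal illustration under stated assumptions, not the SerializationTool source: the helper names serialize_params and deserialize_params are hypothetical, and the ValueError checks are an added assumption. The key invariant is that both directions walk model.parameters() in the same order, so each parameter reclaims exactly numel() elements of the flat tensor.

```python
import torch
import torch.nn as nn


def serialize_params(model: nn.Module) -> torch.Tensor:
    # Hypothetical helper: flatten each parameter to 1-D and concatenate them
    # in model.parameters() order; the result has shape (size, ).
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()]).cpu()


def deserialize_params(model: nn.Module, flat: torch.Tensor, mode: str = "copy") -> None:
    # Hypothetical helper: slice the flat tensor back into per-parameter chunks.
    if mode not in ("copy", "add"):
        raise ValueError(f"invalid mode {mode!r}; expected 'copy' or 'add'")
    total = sum(p.numel() for p in model.parameters())
    if flat.numel() != total:
        raise ValueError(f"expected {total} elements, got {flat.numel()}")
    offset = 0
    with torch.no_grad():  # in-place updates on leaf parameters
        for p in model.parameters():
            n = p.numel()
            # Reshape the slice to the parameter's shape, matching its dtype/device.
            chunk = flat[offset:offset + n].view_as(p).to(p)
            if mode == "copy":
                p.copy_(chunk)  # overwrite the current values
            else:
                p.add_(chunk)   # accumulate onto the current values
            offset += n


torch.manual_seed(0)
src, dst = nn.Linear(4, 2), nn.Linear(4, 2)
flat = serialize_params(src)                 # 8 weights + 2 biases
deserialize_params(dst, flat)                # dst now matches src
deserialize_params(dst, flat, mode="add")    # dst now holds 2 * src
```

For the serialization half alone, PyTorch's built-in torch.nn.utils.parameters_to_vector performs the same flatten-and-concatenate step; the manual loop is shown here to make the "add" mode explicit.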
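serialize_model_gradients and deserialize_model_gradients apply the same flatten-and-slice idea to each parameter's .grad instead of its values. The sketch below assumes backward() has already populated every .grad; it is an illustration, not the library's code, and the helper names serialize_grads and deserialize_grads are hypothetical.

```python
import torch
import torch.nn as nn


def serialize_grads(model: nn.Module) -> torch.Tensor:
    # Hypothetical helper: concatenate every parameter's .grad into one 1-D tensor.
    # Assumes backward() has run, so no .grad is None.
    return torch.cat([p.grad.detach().reshape(-1) for p in model.parameters()]).cpu()


def deserialize_grads(model: nn.Module, flat: torch.Tensor) -> None:
    # Hypothetical helper: write consecutive slices of flat back into each .grad.
    offset = 0
    for p in model.parameters():
        n = p.numel()
        chunk = flat[offset:offset + n].view_as(p).to(p)
        if p.grad is None:
            p.grad = chunk.clone()  # create the gradient if it does not exist yet
        else:
            p.grad.copy_(chunk)     # overwrite the existing gradient in place
        offset += n


model = nn.Linear(3, 1)
model(torch.ones(2, 3)).sum().backward()
g = serialize_grads(model)          # 3 weight grads + 1 bias grad, each 2.0
deserialize_grads(model, g * 0.5)   # e.g. write back scaled (averaged) gradients
```

Because the input is all ones and the loss is a plain sum over a batch of two, every gradient entry equals 2.0 regardless of the random initialization, which makes the round trip easy to check.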