Weight initialisation functions

Weight initialisation

torch_tools.weight_init.normal_init(model: Module, attr_name: str = 'weight', mean: float = 0.0, std: float = 0.02)

Initialise model’s weights by sampling from a normal distribution.

The weights and biases are initialised.
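The sketch below shows one plausible way to get this behaviour across a whole network. It is not the torch_tools source. The name normal_init_sketch is hypothetical, and it assumes the function looks up the tensor named attr_name with getattr, fills it in place with torch.nn.init.normal_, and silently skips modules that lack the attribute. Weights and biases are covered by applying it twice with different attr_name values.

```python
from functools import partial

import torch
from torch import nn


def normal_init_sketch(
    model: nn.Module,
    attr_name: str = "weight",
    mean: float = 0.0,
    std: float = 0.02,
) -> None:
    """Hypothetical stand-in for normal_init; not the torch_tools source."""
    # Look the attribute up by name; modules without it yield None.
    tensor = getattr(model, attr_name, None)
    # Skip activations, containers and layers built with bias=False.
    if isinstance(tensor, torch.Tensor):
        # normal_ fills the tensor in place with samples from N(mean, std**2).
        nn.init.normal_(tensor, mean=mean, std=std)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# Module.apply calls the function on every submodule, then on net itself.
net.apply(normal_init_sketch)                             # weights
net.apply(partial(normal_init_sketch, attr_name="bias"))  # biases
```

Passing the function to Module.apply keeps the per-module logic simple: each call only has to handle a single layer, and PyTorch takes care of walking the module tree.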

Parameters:
  • model (Module) – The Module to be initialised.

  • attr_name (str, optional) – The name of the attribute of model to initialise with normally distributed data.

  • mean (float, optional) – The mean of the normal distribution the weights are sampled from.

  • std (float, optional) – The standard deviation of the normal distribution the weights are sampled from.
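To make the three sampling parameters concrete, the following sketch targets a tensor that is not called weight and uses a non-default mean and std. The Gain module is hypothetical. The sketch assumes attr_name is resolved with getattr and that mean and std are passed straight to torch.nn.init.normal_, which draws from N(mean, std²).

```python
import torch
from torch import nn


class Gain(nn.Module):
    """Hypothetical module whose learnable tensor is called ``gain``."""

    def __init__(self, size: int) -> None:
        super().__init__()
        self.gain = nn.Parameter(torch.ones(size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gain


torch.manual_seed(1)
layer = Gain(10_000)

# attr_name picks the tensor; mean and std set the sampling distribution.
attr_name, mean, std = "gain", 1.0, 0.1
nn.init.normal_(getattr(layer, attr_name), mean=mean, std=std)
```

Under the documented signature, the corresponding call would presumably be normal_init(layer, attr_name="gain", mean=1.0, std=0.1).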

Raises:
  • TypeError – If model is not an instance of torch.nn.Module.

  • TypeError – If attr_name is not a str.

  • TypeError – If mean is not a float.

  • TypeError – If std is not a float.
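The Raises list corresponds to a set of isinstance checks. Below is a minimal sketch of such a validator. The name check_normal_init_args is hypothetical, and the sketch reads the documented rules literally, so an int such as 0 passed as mean or std is rejected. The library may handle that case differently.

```python
from torch import nn


def check_normal_init_args(
    model: object, attr_name: object, mean: object, std: object
) -> None:
    """Hypothetical validator mirroring the documented Raises list."""
    if not isinstance(model, nn.Module):
        raise TypeError(f"model should be a torch.nn.Module, got {type(model)}.")
    if not isinstance(attr_name, str):
        raise TypeError(f"attr_name should be a str, got {type(attr_name)}.")
    # Taken literally, "not a float" means an int such as 0 is rejected too.
    for name, value in (("mean", mean), ("std", std)):
        if not isinstance(value, float):
            raise TypeError(f"{name} should be a float, got {type(value)}.")


check_normal_init_args(nn.Linear(4, 2), "weight", 0.0, 0.02)  # passes silently
```

If the library applies the same strict check, you would write mean=0.0 and std=1.0 instead of mean=0 and std=1.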