# Source code for torch_tools.weight_init

"""Weight initialisation functions."""

from torch.nn import Module, Parameter
from torch import randn_like, no_grad  # pylint: disable=no-name-in-module


def normal_init(
    model: Module,
    attr_name: str = "weight",
    mean: float = 0.0,
    std: float = 0.02,
) -> None:
    """Initialise one attribute of ``model`` from a normal distribution.

    Only the attribute named ``attr_name`` is modified. If ``model`` has no
    such attribute, or the attribute is ``None``, the module is left
    untouched, so the function can be handed to ``Module.apply`` and will
    simply skip sub-modules it does not concern.

    The attribute is replaced by a new ``Parameter`` with the same shape,
    dtype and device as the old value (via ``randn_like``), and it keeps the
    old value's ``requires_grad`` flag.

    Parameters
    ----------
    model : Module
        The ``Module`` to be initialised.
    attr_name : str, optional
        The name of the attribute in ``model`` to be initialised with
        normally distributed data.
    mean : float, optional
        The mean of the normal distribution the values are sampled from.
    std : float, optional
        The standard deviation of the normal distribution the values are
        sampled from.

    Raises
    ------
    TypeError
        If ``model`` is not an instance of ``torch.nn.Module``.
    TypeError
        If ``attr_name`` is not a str.
    TypeError
        If ``mean`` is not a float.
    TypeError
        If ``std`` is not a float.

    """
    if not isinstance(model, Module):
        msg = "Weight init can only be applied to torch.nn.Module. "
        msg += f"Got '{type(model)}'."
        raise TypeError(msg)
    if not isinstance(attr_name, str):
        raise TypeError(f"'attr_name' should be str. Got '{type(attr_name)}'.")
    if not isinstance(mean, float):
        raise TypeError(f"'mean' should be a float. Got '{type(mean)}'.")
    if not isinstance(std, float):
        # Report the offending type, in line with the other argument checks.
        raise TypeError(f"'std' should be a float. Got '{type(std)}'.")

    # A missing attribute and an attribute explicitly set to ``None`` are
    # treated alike: there is nothing to initialise.
    attr = getattr(model, attr_name, None)
    if attr is None:
        return

    with no_grad():
        sampled = (randn_like(attr) * std) + mean
        # Carry over ``requires_grad`` so initialising a frozen attribute
        # does not silently make it trainable again.
        setattr(
            model,
            attr_name,
            Parameter(sampled, requires_grad=attr.requires_grad),
        )