Some random PyTorch utility functions

Torch utils

PyTorch utilities.

torch_tools.torch_utils.disable_biases(model: torch.nn.Module)[source]

Disable all bias parameters in model.

Parameters:

model (Module) – The model to disable biases in.
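
The docstring does not say how the biases are disabled. The sketch below assumes that "disable" means setting each module's bias parameter to None. The name disable_biases_sketch is hypothetical, and this is an illustration of the idea rather than the library's implementation.

```python
import torch
from torch import nn


def disable_biases_sketch(model: nn.Module) -> None:
    """Set every ``bias`` parameter found in ``model`` to ``None``.

    Hypothetical helper: assumes "disabling" a bias means removing it.
    """
    for module in model.modules():
        # Only touch modules whose ``bias`` is a registered parameter.
        if isinstance(getattr(module, "bias", None), nn.Parameter):
            module.bias = None


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
disable_biases_sketch(model)

# Names of any bias parameters still registered on the model.
remaining = [name for name, _ in model.named_parameters() if name.endswith("bias")]

# With no biases, an all-zero input maps to an all-zero output.
output = model(torch.zeros(1, 4))
```

The change is made in place, so the model keeps its weights and can still be called as before.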

torch_tools.torch_utils.img_batch_dims_power_of_2(batch: torch.Tensor)[source]

Check height and width of batch are powers of 2.

Parameters:

batch (Tensor) – A mini-batch of image-like inputs.

Raises:
  • TypeError – If batch is not a Tensor.

  • RuntimeError – If batch does not have four dimensions.

  • RuntimeError – If the batch’s images’ heights are not a power of 2.

  • RuntimeError – If the batch’s images’ widths are not a power of 2.
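
The checks listed above can be sketched as follows. The names is_power_of_2 and check_img_batch_dims_sketch are hypothetical, the error messages are invented for illustration, and the bit trick is just one common way to test for a power of 2. The library may do this differently.

```python
import torch


def is_power_of_2(value: int) -> bool:
    """Return True if ``value`` is a positive power of 2."""
    # A power of 2 has exactly one bit set, so value & (value - 1) is zero.
    return value > 0 and (value & (value - 1)) == 0


def check_img_batch_dims_sketch(batch: torch.Tensor) -> None:
    """Raise if ``batch`` is not a 4D Tensor with power-of-2 height and width."""
    if not isinstance(batch, torch.Tensor):
        raise TypeError(f"Expected a Tensor, got {type(batch)}.")
    if batch.dim() != 4:
        raise RuntimeError(f"Expected four dimensions, got {batch.dim()}.")

    _, _, height, width = batch.shape
    if not is_power_of_2(height):
        raise RuntimeError(f"Height {height} is not a power of 2.")
    if not is_power_of_2(width):
        raise RuntimeError(f"Width {width} is not a power of 2.")


# Passes silently: 16 and 32 are both powers of 2.
check_img_batch_dims_sketch(torch.zeros(2, 3, 16, 32))
```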

torch_tools.torch_utils.patchify_img_batch(img_batch: torch.Tensor, patch_size: int) torch.Tensor[source]

Turn img_batch into a collection of patches.

Note: gradients flow through this function.

Parameters:
  • img_batch (Tensor) – The mini-batch of images to break into sub-patches. Should have size (N, C, H, W), where N is the batch size, C is the number of channels, H is the image height and W is the image width.

  • patch_size (int) – Size of the square patches to break the images in img_batch into.

Returns:

img_batch as a collection of small patches. The returned Tensor has size (N * (H / patch_size) * (W / patch_size), C, patch_size, patch_size). For example, a batch of 10 RGB images of size 16x16 with a patch size of 4 gives a Tensor of shape (160, 3, 4, 4).

Return type:

Tensor
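
The shape arithmetic above can be reproduced with a reshape and a permute. This is a minimal sketch with the hypothetical name patchify_sketch. The order in which patches are returned is an assumption (row by row within each image), and the divisibility check is added for illustration only. The library's own ordering and checks may differ.

```python
import torch


def patchify_sketch(img_batch: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split a (N, C, H, W) batch into square patches of side ``patch_size``."""
    n, c, h, w = img_batch.shape
    if h % patch_size != 0 or w % patch_size != 0:
        raise ValueError("Height and width must be divisible by patch_size.")

    # (N, C, H/p, p, W/p, p): split each spatial axis into (blocks, pixels).
    patches = img_batch.reshape(
        n, c, h // patch_size, patch_size, w // patch_size, patch_size
    )
    # (N, H/p, W/p, C, p, p): bring the block indices next to the batch axis.
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    # Merge batch and block indices into a single leading dimension.
    return patches.reshape(-1, c, patch_size, patch_size)


imgs = torch.rand(10, 3, 16, 16, requires_grad=True)
patches = patchify_sketch(imgs, patch_size=4)
print(patches.shape)  # torch.Size([160, 3, 4, 4])
```

Because only reshape and permute are used, autograd can propagate gradients from the patches back to the original batch.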

torch_tools.torch_utils.target_from_mask_img(mask_img: torch.Tensor, num_classes: int) torch.Tensor[source]

Convert 1-channel image to a target tensor for semantic segmentation.

Parameters:

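  • mask_img (Tensor) – An image holding the segmentation mask. Its values should lie in [0, num_classes), and it should have shape (H, W), where H is the image height and W is the image width.

  • num_classes (int) – The number of classes in the segmentation problem. Must be at least 2.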

Returns:

Target Tensor of shape (num_classes, H, W). Each element target[:, i, j] is a one-hot-encoded vector.

Return type:

Tensor

Raises:
  • TypeError – If mask_img is not a Tensor.

  • TypeError – If num_classes is not an int.

  • ValueError – If any of the values in mask_img cannot be cast as int.

  • ValueError – If num_classes < 2.

  • ValueError – If mask_img has values less than zero, or greater than/equal to num_classes.

  • RuntimeError – If mask_img is not two-dimensional.
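
The conversion can be illustrated with torch.nn.functional.one_hot. This is a minimal sketch with the hypothetical name target_from_mask_sketch. It skips the validation listed above, and the float output dtype is an assumption, not something the docstring states.

```python
import torch
from torch.nn.functional import one_hot


def target_from_mask_sketch(mask_img: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode a (H, W) mask into a (num_classes, H, W) target."""
    # one_hot needs integer class indices and returns shape (H, W, num_classes).
    encoded = one_hot(mask_img.long(), num_classes=num_classes)
    # Move the class axis to the front; float dtype is an assumption.
    return encoded.permute(2, 0, 1).float()


mask = torch.tensor([[0, 1], [2, 1]])
target = target_from_mask_sketch(mask, num_classes=3)
print(target.shape)     # torch.Size([3, 2, 2])
print(target[:, 1, 0])  # tensor([0., 0., 1.]) since mask[1, 0] == 2
```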

torch_tools.torch_utils.total_image_variation(img_batch: torch.Tensor, mean_reduce: bool = True) torch.Tensor[source]

Compute the total variation of img_batch.

Parameters:
  • img_batch (Tensor) – A mini-batch of image-like inputs.

  • mean_reduce (bool, optional) – If True, the returned value is normalised by the number of elements in img_batch. If False, no division is performed.

Returns:

The total variation measure.

Return type:

Tensor

Raises:
  • TypeError – If img_batch is not a Tensor.

  • TypeError – If mean_reduce is not a bool.

  • RuntimeError – If img_batch is not a 4D Tensor.

Notes

Heavily inspired by the TensorFlow code: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/image_ops_impl.py#L3322-L3391
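
One common definition of total variation is the sum of absolute differences between neighbouring pixels along the height and width axes. The sketch below, with the hypothetical name total_variation_sketch, uses that definition and assumes a single scalar is returned for the whole batch. It omits the type and dimension checks listed above, and the library's exact reduction may differ.

```python
import torch


def total_variation_sketch(img_batch: torch.Tensor, mean_reduce: bool = True) -> torch.Tensor:
    """Sum of absolute neighbouring-pixel differences over a (N, C, H, W) batch."""
    # Differences between vertically adjacent pixels (along H).
    row_diffs = img_batch[:, :, 1:, :] - img_batch[:, :, :-1, :]
    # Differences between horizontally adjacent pixels (along W).
    col_diffs = img_batch[:, :, :, 1:] - img_batch[:, :, :, :-1]

    total = row_diffs.abs().sum() + col_diffs.abs().sum()
    if mean_reduce:
        # Normalise by the number of elements in the batch.
        total = total / img_batch.numel()
    return total


flat = torch.ones(2, 3, 8, 8)
stripes = torch.tensor([[[[0.0, 1.0], [0.0, 1.0]]]])  # shape (1, 1, 2, 2)

flat_tv = total_variation_sketch(flat)
stripes_sum = total_variation_sketch(stripes, mean_reduce=False)
stripes_mean = total_variation_sketch(stripes, mean_reduce=True)
```

A constant image has zero variation, while the striped image has two horizontal jumps of size 1, giving a sum of 2 and a mean of 2 / 4 = 0.5.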