Some random PyTorch utility functions
Torch utils
PyTorch utilities.
- torch_tools.torch_utils.disable_biases(model: torch.nn.Module)
Disable all bias parameters in model.
- Parameters:
model (Module) – The model to disable biases in.
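The entry above does not spell out what "disable" means. As a minimal sketch, assuming it means removing each layer's bias by setting it to None, the idea could look like the following. The helper name disable_biases_sketch is hypothetical and is not part of torch_tools.

```python
import torch
from torch import nn


def disable_biases_sketch(model: nn.Module) -> None:
    """Set every bias parameter found in ``model`` to ``None``.

    Assumption: "disabling" a bias means removing it entirely, rather
    than zeroing or freezing it.
    """
    for module in model.modules():
        # Only touch modules that actually own a bias parameter.
        if isinstance(getattr(module, "bias", None), nn.Parameter):
            module.bias = None


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
disable_biases_sketch(model)

# No parameter whose name ends in "bias" should remain.
remaining = [name for name, _ in model.named_parameters() if name.endswith("bias")]

# The model still runs a forward pass without its biases.
output = model(torch.zeros(3, 4))
```

Layers that store their bias under a different attribute name are not covered by this sketch.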
- torch_tools.torch_utils.img_batch_dims_power_of_2(batch: torch.Tensor)
Check that the height and width of batch are powers of 2.
- Parameters:
batch (Tensor) – A mini-batch of image-like inputs.
- Raises:
TypeError – If batch is not a Tensor.
RuntimeError – If batch does not have four dimensions.
RuntimeError – If the heights of the images in batch are not a power of 2.
RuntimeError – If the widths of the images in batch are not a power of 2.
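The checks listed above can be sketched with a bitwise power-of-2 test. This is an illustration under stated assumptions, not the library's implementation: the names is_power_of_2 and check_dims_sketch are hypothetical, and treating 1 as a valid power of 2 is an assumption.

```python
import torch


def is_power_of_2(value: int) -> bool:
    """Return True if ``value`` is a positive power of 2 (1, 2, 4, 8, ...)."""
    # A power of 2 has exactly one set bit, so value & (value - 1) is zero.
    return value > 0 and (value & (value - 1)) == 0


def check_dims_sketch(batch: torch.Tensor) -> None:
    """Raise if ``batch`` is not a 4D Tensor with power-of-2 height and width."""
    if not isinstance(batch, torch.Tensor):
        raise TypeError(f"Expected a Tensor, got '{type(batch)}'.")
    if batch.dim() != 4:
        raise RuntimeError(f"Expected four dimensions, got {batch.dim()}.")

    # Assumes the (N, C, H, W) layout used for image batches.
    _, _, height, width = batch.shape
    if not is_power_of_2(height):
        raise RuntimeError(f"Height {height} is not a power of 2.")
    if not is_power_of_2(width):
        raise RuntimeError(f"Width {width} is not a power of 2.")


# A 16x32 batch passes silently.
check_dims_sketch(torch.zeros(2, 3, 16, 32))
```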
- torch_tools.torch_utils.patchify_img_batch(img_batch: torch.Tensor, patch_size: int) -> torch.Tensor
Turn img_batch into a collection of patches.
Note: gradients flow through this function.
- Parameters:
img_batch (Tensor) – The batch of images to convert into sub-patches. Should have size (N, C, H, W), where N is the batch size, C is the number of channels, H is the image height and W the width.
patch_size (int) – Size of the square patches to break the images in img_batch into.
- Returns:
img_batch as a collection of small patches. The returned Tensor has size (N * H / patch_size * W / patch_size, C, patch_size, patch_size). For example, a batch of 10 RGB images of size 16x16 with a patch size of 4 returns a Tensor of shape (160, 3, 4, 4).
- Return type:
Tensor
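One way to obtain the shape described above is with Tensor.unfold. The following is a minimal sketch, not the library's implementation: the name patchify_sketch is hypothetical, and the ordering of patches in the output (image by image, then row by row) is an assumption.

```python
import torch


def patchify_sketch(img_batch: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split an (N, C, H, W) batch into non-overlapping square patches."""
    channels = img_batch.shape[1]

    # Slide a window over the height, then the width, with no overlap.
    # Resulting shape: (N, C, H / p, W / p, p, p).
    patches = img_batch.unfold(2, patch_size, patch_size)
    patches = patches.unfold(3, patch_size, patch_size)

    # Move the channel axis next to the patch axes, then merge the batch
    # and patch-grid axes into a single leading dimension.
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    return patches.reshape(-1, channels, patch_size, patch_size)


batch = torch.rand(10, 3, 16, 16, requires_grad=True)
patches = patchify_sketch(batch, patch_size=4)
patch_shape = tuple(patches.shape)  # 10 * (16 / 4) * (16 / 4) = 160 patches
```

Because unfold, permute and reshape are all differentiable, gradients propagate back to the input batch.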
- torch_tools.torch_utils.target_from_mask_img(mask_img: torch.Tensor, num_classes: int) -> torch.Tensor
Convert a 1-channel image to a target tensor for semantic segmentation.
- Parameters:
mask_img (Tensor) – An image holding the segmentation mask. Values should be on [0, num_classes), and the shape should be (H, W), where H is the image height and W the width.
num_classes (int) – The number of classes, which sets the length of each one-hot vector.
- Returns:
Target Tensor of shape (num_classes, H, W). Each element, target[:, i, j], is a one-hot-encoded vector.
- Return type:
Tensor
- Raises:
TypeError – If mask_img is not a Tensor.
TypeError – If num_classes is not an int.
ValueError – If any of the values in mask_img cannot be cast as int.
ValueError – If num_classes < 2.
ValueError – If mask_img has values less than zero, or greater than or equal to num_classes.
RuntimeError – If mask_img is not two-dimensional.
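The core conversion can be sketched with torch.nn.functional.one_hot. This is an illustration only: the name target_from_mask_sketch is hypothetical, the argument validation listed above is omitted, and the integer dtype of the result is an assumption.

```python
import torch
from torch.nn.functional import one_hot


def target_from_mask_sketch(mask_img: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode an (H, W) mask into a (num_classes, H, W) target."""
    # one_hot needs int64 class indices and appends the class axis last,
    # giving shape (H, W, num_classes).
    target = one_hot(mask_img.long(), num_classes=num_classes)

    # Move the class axis to the front: (num_classes, H, W).
    return target.permute(2, 0, 1)


mask = torch.tensor([[0, 1], [2, 1]])
target = target_from_mask_sketch(mask, num_classes=3)
target_shape = tuple(target.shape)
```

Each pixel's vector target[:, i, j] then holds a single 1 at the index of that pixel's class.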
- torch_tools.torch_utils.total_image_variation(img_batch: torch.Tensor, mean_reduce: bool = True) -> torch.Tensor
Compute the total variation of img_batch.
- Parameters:
img_batch (Tensor) – A mini-batch of image-like inputs.
mean_reduce (bool, optional) – If True, the returned value is normalised by the number of elements in img_batch. If False, no division is performed.
- Returns:
The total variation measure.
- Return type:
Tensor
- Raises:
TypeError – If img_batch is not a Tensor.
TypeError – If mean_reduce is not a bool.
RuntimeError – If img_batch is not a 4D Tensor.
Notes
Heavily inspired by the TensorFlow code: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/image_ops_impl.py#L3322-L3391
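As in the TensorFlow function linked above, total variation can be computed as the sum of absolute differences between neighbouring pixels. The sketch below follows that idea under stated assumptions: the name total_variation_sketch is hypothetical, argument validation is omitted, and summing over the whole batch into a single scalar (rather than returning one value per image) is an assumption.

```python
import torch


def total_variation_sketch(img_batch: torch.Tensor, mean_reduce: bool = True) -> torch.Tensor:
    """Sum absolute differences between vertically and horizontally adjacent pixels."""
    # Differences between each pixel and the one below it (along H).
    diff_h = img_batch[:, :, 1:, :] - img_batch[:, :, :-1, :]
    # Differences between each pixel and the one to its right (along W).
    diff_w = img_batch[:, :, :, 1:] - img_batch[:, :, :, :-1]

    variation = diff_h.abs().sum() + diff_w.abs().sum()

    if mean_reduce:
        # Normalise by the number of elements in the batch.
        variation = variation / img_batch.numel()
    return variation


# One 2x2 single-channel image with a vertical edge down the middle.
img = torch.tensor([[[[0.0, 1.0], [0.0, 1.0]]]])
tv_sum = total_variation_sketch(img, mean_reduce=False)
tv_mean = total_variation_sketch(img, mean_reduce=True)
```

A constant image gives a total variation of zero, which is why the measure is often used as a smoothness penalty.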