Some random PyTorch utility functions
Torch utils
PyTorch utilities.
- torch_tools.torch_utils.disable_biases(model: torch.nn.Module)
Disable all bias parameters in model.
- Parameters:
model (Module) – The model to disable biases in.
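The docstring does not say what "disable" means. The sketch below assumes it means zeroing every bias parameter and freezing it so the optimiser never updates it. `disable_biases_sketch` is a hypothetical name and is not the `torch_tools` implementation.

```python
import torch
from torch import nn


def disable_biases_sketch(model: nn.Module) -> None:
    """Zero and freeze every bias parameter in ``model`` (assumed semantics)."""
    for name, param in model.named_parameters():
        # Parameter names look like "0.bias", "conv.bias", ...
        if name.split(".")[-1] == "bias":
            with torch.no_grad():
                param.zero_()  # set the bias to zero in place
            param.requires_grad_(False)  # stop gradient updates to it


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
disable_biases_sketch(model)
```

If you control how the model is built, you can pass `bias=False` to layers such as `nn.Linear` or `nn.Conv2d`. This removes the bias parameters altogether instead of freezing them.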
- torch_tools.torch_utils.img_batch_dims_power_of_2(batch: torch.Tensor)
Check that the height and width of the images in batch are powers of 2.
- Parameters:
batch (Tensor) – A mini-batch of image-like inputs.
- Raises:
TypeError – If batch is not a Tensor.
RuntimeError – If batch does not have four dimensions.
RuntimeError – If the heights of the images in batch are not a power of 2.
RuntimeError – If the widths of the images in batch are not a power of 2.
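The check can be sketched as follows. This sketch assumes an `(N, C, H, W)` layout and treats 1 as a power of 2 (2^0). Whether `torch_tools` accepts a size of 1 is not stated. `is_power_of_2` and `check_dims_power_of_2_sketch` are hypothetical names. The sketch relies on the bit trick that a positive power of two has exactly one set bit.

```python
import torch


def is_power_of_2(n: int) -> bool:
    """True if ``n`` is a positive power of two (1, 2, 4, 8, ...)."""
    # A power of two has exactly one set bit, so n & (n - 1) clears it to 0.
    return n > 0 and (n & (n - 1)) == 0


def check_dims_power_of_2_sketch(batch: torch.Tensor) -> None:
    """Raise if the images in an (N, C, H, W) batch are not power-of-2 sized."""
    if not isinstance(batch, torch.Tensor):
        raise TypeError(f"Expected a Tensor, got {type(batch)}.")
    if batch.dim() != 4:
        raise RuntimeError(f"Expected four dimensions, got {batch.dim()}.")
    height, width = batch.shape[2], batch.shape[3]
    if not is_power_of_2(height):
        raise RuntimeError(f"Image height {height} is not a power of 2.")
    if not is_power_of_2(width):
        raise RuntimeError(f"Image width {width} is not a power of 2.")


check_dims_power_of_2_sketch(torch.zeros(2, 3, 16, 32))  # passes silently
```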
- torch_tools.torch_utils.patchify_img_batch(img_batch: torch.Tensor, patch_size: int) → torch.Tensor
Turn img_batch into a collection of square patches. Note: gradients flow through this function.
- Parameters:
img_batch (Tensor) – The mini-batch of images to break into sub-patches. Should have size (N, C, H, W), where N is the batch size, C is the number of channels, H is the image height and W is the image width.
patch_size (int) – Size of the square patches to break the images in img_batch into.
- Returns:
img_batch as a collection of small patches. The returned Tensor has size (N * H / patch_size * W / patch_size, C, patch_size, patch_size). For example, a batch of 10 RGB images of size 16x16 with a patch size of 4 gives a Tensor of shape (160, 3, 4, 4).
- Return type:
Tensor
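The output shape described above can be produced with a reshape/permute sequence. This sketch assumes that H and W are divisible by `patch_size`. It also assumes that patches are ordered image by image, then row-major over each image's patch grid. The ordering `torch_tools` actually uses is not documented. `reshape` and `permute` are differentiable, so gradients reach the input. `patchify_sketch` is a hypothetical name.

```python
import torch


def patchify_sketch(img_batch: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (N, C, H, W) images into non-overlapping square patches."""
    n, c, h, w = img_batch.shape
    p = patch_size
    if h % p != 0 or w % p != 0:
        raise ValueError("H and W must be divisible by patch_size.")
    # (N, C, H/p, p, W/p, p): split each spatial axis into (grid, within-patch).
    x = img_batch.reshape(n, c, h // p, p, w // p, p)
    # (N, H/p, W/p, C, p, p): move the patch-grid axes ahead of the channels.
    x = x.permute(0, 2, 4, 1, 3, 5)
    # Flatten image and grid axes; reshape copies because x is non-contiguous.
    return x.reshape(n * (h // p) * (w // p), c, p, p)


imgs = torch.randn(10, 3, 16, 16, requires_grad=True)
patches = patchify_sketch(imgs, 4)
print(patches.shape)  # torch.Size([160, 3, 4, 4])
patches.sum().backward()  # gradients flow back to imgs
```

`torch.nn.functional.unfold` is an alternative. It returns flattened patches of shape `(N, C * p * p, L)`, which then need further reshaping.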
- torch_tools.torch_utils.target_from_mask_img(mask_img: torch.Tensor, num_classes: int) → torch.Tensor
Convert a one-channel mask image into a target tensor for semantic segmentation.
- Parameters:
mask_img (Tensor) – An image holding the segmentation mask, with shape (H, W), where H is the image height and W is the image width. Its values should lie in [0, num_classes).
num_classes (int) – The number of classes in the segmentation mask. Must be at least 2.
- Returns:
Target Tensor of shape (num_classes, H, W). Each element target[:, i, j] is the one-hot encoding of the class at mask_img[i, j].
- Return type:
Tensor
- Raises:
TypeError – If mask_img is not a Tensor.
TypeError – If num_classes is not an int.
ValueError – If any of the values in mask_img cannot be cast as int.
ValueError – If num_classes < 2.
ValueError – If mask_img has values less than zero, or greater than or equal to num_classes.
RuntimeError – If mask_img is not two-dimensional.
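The conversion can be sketched with `torch.nn.functional.one_hot`. That function appends the class axis last, so the sketch permutes it to the front. This approach is an assumption and may not match `torch_tools`. `target_from_mask_sketch` is a hypothetical name. It returns an `int64` tensor, whereas the real function's dtype is not documented. Its `.long()` cast also truncates non-integer values silently instead of raising as described above.

```python
import torch
import torch.nn.functional as F


def target_from_mask_sketch(mask_img: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode an (H, W) class mask into a (num_classes, H, W) target."""
    if mask_img.dim() != 2:
        raise RuntimeError("mask_img must be two-dimensional.")
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2.")
    labels = mask_img.long()  # one_hot needs integer (int64) class indices
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("mask_img values must lie in [0, num_classes).")
    # one_hot appends the class axis last: (H, W, num_classes).
    one_hot = F.one_hot(labels, num_classes=num_classes)
    # Move the class axis to the front: (num_classes, H, W).
    return one_hot.permute(2, 0, 1)


mask = torch.tensor([[0, 1], [2, 1]])
target = target_from_mask_sketch(mask, num_classes=3)
print(target[:, 1, 0])  # tensor([0, 0, 1])
```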
- torch_tools.torch_utils.total_image_variation(img_batch: torch.Tensor, mean_reduce: bool = True) → torch.Tensor
Compute the total variation of img_batch.
- Parameters:
img_batch (Tensor) – A mini-batch of image-like inputs.
mean_reduce (bool, optional) – If True, the returned value is normalised by the number of elements in img_batch. If False, no division is performed.
- Returns:
The total variation measure.
- Return type:
Tensor
- Raises:
TypeError – If img_batch is not a Tensor.
TypeError – If mean_reduce is not a bool.
RuntimeError – If img_batch is not a 4D Tensor.
Notes
Heavily inspired by the TensorFlow code: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/image_ops_impl.py#L3322-L3391
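TensorFlow's `tf.image.total_variation` sums the absolute differences between neighbouring pixels along the height and width axes. It returns one value per image. The sketch below follows that definition but assumes the result is reduced to a single scalar over the whole batch, which the docs above do not state. With `mean_reduce`, it divides that scalar by `img_batch.numel()`. `total_variation_sketch` is a hypothetical name.

```python
import torch


def total_variation_sketch(img_batch: torch.Tensor, mean_reduce: bool = True) -> torch.Tensor:
    """Anisotropic total variation of an (N, C, H, W) batch, summed to a scalar."""
    if img_batch.dim() != 4:
        raise RuntimeError("img_batch must be a 4D Tensor.")
    # Differences between vertically adjacent pixels: (N, C, H - 1, W).
    dh = img_batch[:, :, 1:, :] - img_batch[:, :, :-1, :]
    # Differences between horizontally adjacent pixels: (N, C, H, W - 1).
    dw = img_batch[:, :, :, 1:] - img_batch[:, :, :, :-1]
    total = dh.abs().sum() + dw.abs().sum()
    if mean_reduce:
        total = total / img_batch.numel()  # normalise by the element count
    return total


img = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])  # shape (1, 1, 2, 2)
print(total_variation_sketch(img, mean_reduce=False))  # tensor(6.)
print(total_variation_sketch(img))  # tensor(1.5000)
```

In this 2x2 example, the vertical differences contribute |2 - 0| + |3 - 1| = 4 and the horizontal differences contribute |1 - 0| + |3 - 2| = 2. That gives a total of 6, or 6 / 4 = 1.5 after dividing by the four elements.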