Source code for torch_tools.models._autoencoder_2d

"""A simple image encoder-decoder model."""

from torch.nn import Module

from torch import Tensor, set_grad_enabled

from torch_tools.models._encoder_2d import Encoder2d
from torch_tools.models._decoder_2d import Decoder2d

from torch_tools.models._argument_processing import (
    process_num_feats,
    process_u_architecture_layers,
    process_str_arg,
    process_negative_slope_arg,
    process_boolean_arg,
    process_2d_kernel_size,
    process_2d_block_style_arg,
    process_dropout_prob,
)

# pylint:disable=too-many-arguments, too-many-positional-arguments


class AutoEncoder2d(Module):
    """A simple encoder-decoder pair for image-like inputs.

    Parameters
    ----------
    in_chans : int
        The number of input channels.
    out_chans : int
        The number of output channels the model should produce.
    num_layers : int, optional
        The number of layers in the encoder/decoder.
    features_start : int, optional
        The number of features produced by the first conv block.
    lr_slope : float, optional
        The negative slope to use in the ``LeakyReLU`` layers.
    pool_style : str, optional
        The pool style to use in the downsampling blocks (``"avg"`` or
        ``"max"``).
    bilinear : bool, optional
        Whether to upsample with bilinear interpolation (``True``) or
        ``ConvTranspose2d`` (``False``).
    kernel_size : int, optional
        Size of the square convolutional kernel to use in the ``Conv2d``
        layers. Must be a positive, odd, int.
    block_style : str, optional
        Style of convolutional blocks to use in the encoding and decoding
        blocks. Use either ``"double_conv"`` or ``"conv_res"``.
    dropout : float, optional
        The dropout probability to apply at the output of the
        convolutional blocks.

    Notes
    -----
    Depending on the application, it may be convenient to pretrain this
    model and then use it for transfer learning, hence the
    ``frozen_encoder`` and ``frozen_decoder`` arguments in the ``forward``
    method. There are no pretrained weights available, however.

    Examples
    --------
    >>> from torch_tools import AutoEncoder2d
    >>> model = AutoEncoder2d(
    ...     in_chans=3,
    ...     out_chans=3,
    ...     features_start=64,
    ...     num_layers=4,
    ...     pool_style="max",
    ...     lr_slope=0.123,
    ... )

    Another (potentially) useful feature, if you want to do transfer
    learning, is the ability to *freeze*, i.e. fix, the parameters of
    either the encoder or the decoder:

    >>> from torch import rand
    >>> from torch_tools import AutoEncoder2d
    >>> # Mini-batch of ten, three-channel images of 64 by 64 pixels
    >>> mini_batch = rand(10, 3, 64, 64)
    >>> model = AutoEncoder2d(in_chans=3, out_chans=3)
    >>> # With nothing frozen (default behaviour)
    >>> pred = model(mini_batch, frozen_encoder=False, frozen_decoder=False)
    >>> # With the encoder frozen:
    >>> pred = model(mini_batch, frozen_encoder=True, frozen_decoder=False)
    >>> # With both the encoder and decoder frozen:
    >>> pred = model(mini_batch, frozen_encoder=True, frozen_decoder=True)

    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        num_layers: int = 4,
        features_start: int = 64,
        lr_slope: float = 0.1,
        pool_style: str = "max",
        bilinear: bool = False,
        kernel_size: int = 3,
        block_style: str = "double_conv",
        dropout: float = 0.25,
    ):
        """Build ``AutoEncoder2d``."""
        super().__init__()
        self.encoder = Encoder2d(
            process_num_feats(in_chans),
            process_num_feats(features_start),
            process_u_architecture_layers(num_layers),
            process_str_arg(pool_style),
            process_negative_slope_arg(lr_slope),
            process_2d_kernel_size(kernel_size),
            block_style=process_2d_block_style_arg(block_style),
            dropout=process_dropout_prob(dropout),
        )

        self.decoder = Decoder2d(
            process_num_feats((2 ** (num_layers - 1)) * features_start),
            process_num_feats(out_chans),
            process_u_architecture_layers(num_layers),
            process_boolean_arg(bilinear),
            process_negative_slope_arg(lr_slope),
            process_2d_kernel_size(kernel_size),
            block_style=process_2d_block_style_arg(block_style),
            dropout=process_dropout_prob(dropout),
        )
    def forward(
        self,
        batch: Tensor,
        frozen_encoder: bool = False,
        frozen_decoder: bool = False,
    ) -> Tensor:
        """Pass ``batch`` through the model.

        Parameters
        ----------
        batch : Tensor
            A mini-batch of inputs.
        frozen_encoder : bool, optional
            Boolean switch controlling whether the encoder's gradients are
            enabled or disabled (useful for transfer learning).
        frozen_decoder : bool, optional
            Boolean switch controlling whether the decoder's gradients are
            enabled or disabled (useful for transfer learning).

        Returns
        -------
        Tensor
            The result of passing ``batch`` through the model.

        """
        with set_grad_enabled(not frozen_encoder):
            encoded = self.encoder(batch)

        with set_grad_enabled(not frozen_decoder):
            decoded = self.decoder(encoded)

        return decoded
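The width of the decoder's input in ``__init__`` follows directly from the constructor arguments: the feature count starts at ``features_start`` and doubles once per encoder stage after the first. A minimal arithmetic sketch of that expression, using the default values from the signature (no torch_tools internals are assumed):

    # Sketch: the channel count handed to Decoder2d by the constructor above.
    # With the defaults (num_layers=4, features_start=64) this is
    # (2 ** (4 - 1)) * 64 = 512.
    num_layers, features_start = 4, 64
    bottleneck_chans = (2 ** (num_layers - 1)) * features_start
    print(bottleneck_chans)  # 512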
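Because ``forward`` wraps each half in ``set_grad_enabled``, freezing both halves produces an output with gradient tracking disabled. A minimal check of that behaviour, assuming only standard PyTorch autograd semantics and reusing the shapes from the docstring example:

    from torch import rand

    from torch_tools import AutoEncoder2d

    model = AutoEncoder2d(in_chans=3, out_chans=3)
    batch = rand(10, 3, 64, 64)

    # Default: both halves participate in autograd.
    pred = model(batch)
    print(pred.requires_grad)  # True

    # Fully frozen: the decoder runs under set_grad_enabled(False),
    # so the output carries no gradient history.
    pred = model(batch, frozen_encoder=True, frozen_decoder=True)
    print(pred.requires_grad)  # False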