"""2D CNN model which wraps Torchvision's ResNet and VGG models."""
from typing import Dict, Any, Optional
from torch import Tensor, set_grad_enabled
from torch.nn import Module, Sequential, Flatten, Conv2d
from torch_tools.models._argument_processing import process_num_feats
from torch_tools.models._torchvision_encoder_backbones_2d import get_backbone
from torch_tools.models._adaptive_pools_2d import get_adaptive_pool
from torch_tools.models._fc_net import FCNet
# pylint: disable=too-many-arguments, too-many-positional-arguments
class ConvNet2d(Module):
"""CNN model which wraps Torchvision's ResNet, VGG and Mobilenet_v3 models.
The model contains:
— An encoder, taken from Torchvision's ResNet/VGG models.
— An adaptive pooling layer.
— A fully-connected classification/regression head.
Parameters
----------
out_feats : int
The number of output features the model should produce (for example,
the number of classes).
in_channels : int
Number of input channels the model should take. Warning: if you don't
use three input channels, the first conv layer is overwritten, which
renders freezing the encoder pointless.
encoder_style : str, optional
The encoder option to use. The encoders are loaded from Torchvision's
models; all of Torchvision's VGG, ResNet and MobileNet V3 variants are
available (e.g. ``"vgg11"``, ``"vgg11_bn"``, ``"resnet18"``,
``"mobilenet_v3_small"``, etc.).
pretrained : bool, optional
Determines whether the encoder is initialised with Torchvision's
pretrained weights. If ``True``, the model loads Torchvision's most
up-to-date ImageNet-trained weights.
pool_style : str, optional
The type of adaptive pooling layer to use. Choose from ``"avg"``,
``"max"`` or ``"avg-max-concat"`` (the latter simply concatenates the
former two). See ``torch_tools.models._adaptive_pools_2d`` for more
info.
fc_net_kwargs : Dict[str, Any], optional
Keyword arguments for ``torch_tools.models._fc_net.FCNet``, which
serves as the classification/regression part of the model.
Examples
--------
>>> from torch_tools import ConvNet2d
>>> model = ConvNet2d(out_feats=512,
...                   in_channels=3,
...                   encoder_style="vgg11_bn",
...                   pretrained=True,
...                   pool_style="avg-max-concat",
...                   fc_net_kwargs={"hidden_sizes": (1024, 1024), "hidden_dropout": 0.25})
Another potentially useful feature is the ability to *freeze* the encoder
and take advantage of the available pretrained weights for transfer
learning.
>>> from torch import rand
>>> from torch_tools import ConvNet2d
>>> model = ConvNet2d(out_feats=10, pretrained=True)
>>> # Batch of 10 fake three-channel images of 256x256 pixels
>>> mini_batch = rand(10, 3, 256, 256)
>>> # With the encoder frozen
>>> preds = model(mini_batch, frozen_encoder=True)
>>> # Without the encoder frozen (default behaviour)
>>> preds = model(mini_batch, frozen_encoder=False)
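The number of input channels can also differ from three, in which case the
first conv layer is replaced and randomly re-initialised (see the Notes
below). A minimal sketch:
>>> from torch import rand
>>> from torch_tools import ConvNet2d
>>> model = ConvNet2d(out_feats=10, in_channels=1)
>>> # Batch of 10 fake single-channel images of 256x256 pixels
>>> preds = model(rand(10, 1, 256, 256))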
Notes
-----
- Even if you load pretrained weights but *don't* freeze the encoder, you
will likely find better performance than you would by randomly
initialising the model, even if it doesn't make sense. Welcome to deep
learning.
- If you change the number of input channels, don't bother freezing the
encoder: the first convolutional layer is overwritten and randomly
initialised.
- See ``torch_tools.models._conv_net_2d.ConvNet2d`` for more info.
"""
def __init__(
self,
out_feats: int,
in_channels: int = 3,
encoder_style: str = "resnet34",
pretrained: bool = True,
pool_style: str = "avg-max-concat",
fc_net_kwargs: Optional[Dict[str, Any]] = None,
):
"""Build `ConvNet2d`."""
super().__init__()
self.backbone, num_feats, pool_size = get_backbone(
encoder_style,
pretrained=pretrained,
)
self._replace_first_conv_if_necessary(process_num_feats(in_channels))
self.pool = Sequential(
get_adaptive_pool(pool_style, pool_size),
Flatten(),
)
if fc_net_kwargs is not None:
_forbidden_args_in_dn_kwargs(fc_net_kwargs)
# Copy the class-level defaults before updating so one instance's
# kwargs cannot leak into the defaults shared by other instances.
self._dn_args = {**self._dn_args, **fc_net_kwargs}
# "avg-max-concat" pooling doubles the feature dimension.
self.dense_layers = FCNet(
2 * num_feats if pool_style == "avg-max-concat" else num_feats,
process_num_feats(out_feats),
**self._dn_args,
)
# Default keyword arguments for the fully-connected head (FCNet).
_dn_args: Dict[str, Any] = {
"hidden_sizes": None,
"input_bnorm": False,
"hidden_bnorm": False,
"input_dropout": 0.0,
"hidden_dropout": 0.0,
"negative_slope": 0.2,
}
def _replace_first_conv_if_necessary(self, in_channels: int) -> None:
"""Replace the first conv layer if input channels don't match.
Parameters
----------
in_channels : int
The number of input channels requested by the user.
"""
for name, module in self.backbone.named_children():
if isinstance(module, Conv2d):
config = _conv_config(module)
if config["in_channels"] != in_channels:
# Rebuild the first conv layer with the requested number of
# input channels; its weights are randomly initialised.
config["in_channels"] = in_channels
setattr(self.backbone, name, Conv2d(**config))  # type: ignore
break
def forward(self, batch: Tensor, frozen_encoder: bool = False) -> Tensor:
"""Pass `batch` through the model.
Parameters
----------
batch : Tensor
A mini-batch of inputs with shape (N, C, H, W), where N is the
batch-size, C the number of channels and (H, W) the input
size.
frozen_encoder : bool, optional
If ``True``, the gradients are disabled in the encoder. If
``False``, the gradients are enabled in the encoder.
Returns
-------
Tensor
The result of passing ``batch`` through the model.
"""
with set_grad_enabled(not frozen_encoder):
encoder_out = self.backbone(batch)
pool_out = self.pool(encoder_out)
return self.dense_layers(pool_out)
def get_features(self, batch: Tensor) -> Tensor:
"""Return the features produced by the encoder and pool.
Parameters
----------
batch : Tensor
A mini-batch of image-like inputs.
Returns
-------
Tensor
The encoded features for the items in ``batch``.
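Examples
--------
A minimal sketch, assuming the default ``resnet34`` encoder:
>>> from torch import rand
>>> from torch_tools import ConvNet2d
>>> model = ConvNet2d(out_feats=10)
>>> # Pooled, flattened encoder features for 10 fake RGB images
>>> feats = model.get_features(rand(10, 3, 256, 256))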
"""
encoder_out = self.backbone(batch)
return self.pool(encoder_out)
def _conv_config(conv: Conv2d) -> Dict[str, Any]:
"""Return a dictionary with the `conv`'s instantiation arguments.
Parameters
----------
conv : Conv2d
Two-dimensional convolutional layer.
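Returns
-------
Dict[str, Any]
Keyword arguments which can be used to instantiate an equivalent
``Conv2d``.
Examples
--------
A minimal sketch (the layer sizes here are illustrative only):
>>> from torch.nn import Conv2d
>>> conv = Conv2d(3, 16, kernel_size=3)
>>> config = _conv_config(conv)
>>> config["in_channels"] = 1
>>> single_channel_conv = Conv2d(**config)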
"""
return {
"in_channels": conv.in_channels,
"out_channels": conv.out_channels,
"kernel_size": conv.kernel_size,
"stride": conv.stride,
"padding": conv.padding,
"dilation": conv.dilation,
"groups": conv.groups,
"bias": not conv.bias is None,
"padding_mode": conv.padding_mode,
}
def _forbidden_args_in_dn_kwargs(user_dn_kwargs: Dict[str, Any]) -> None:
"""Check there are no forbidden arguments in ``user_dn_kwargs``.
Parameters
----------
user_dn_kwargs : Dict[str, Any]
The dense net kwargs supplied by the user.
Raises
------
RuntimeError
If ``in_feats`` is in ``user_dn_kwargs``, the user has tried to
set the number of input features to the fully connected head,
which is forbidden.
RuntimeError
If ``out_feats`` is in ``user_dn_kwargs``, the user has tried to
set the number of output features of the fully connected head,
which is forbidden.
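Examples
--------
Illustrative inputs (allowed kwargs pass silently, forbidden ones raise):
>>> _forbidden_args_in_dn_kwargs({"hidden_sizes": (128, 64)})
>>> _forbidden_args_in_dn_kwargs({"in_feats": 512})
Traceback (most recent call last):
...
RuntimeError: Do not supply 'in_feats' in 'fc_net_kwargs'. This quantity is determined by the choice of encoder.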
"""
if "in_feats" in user_dn_kwargs:
msg = "Do not supply 'in_feats' in 'fc_net_kwargs'. This "
msg += "quantity is determined by the choice of encoder."
raise RuntimeError(msg)
if "out_feats" in user_dn_kwargs:
msg = "Do not supply 'out_feats' in 'fc_net_kwargs'. Instead "
msg += "set this quanitiy using the 'out_feats' argument of "
msg += "'ConvNet2d' at instantiation."
raise RuntimeError(msg)