Source code for torch_tools.datasets._shapes_dataset

"""Synthetic dataset object."""

from typing import Tuple, Optional, Dict, Callable, List

from torch import from_numpy, Tensor  # pylint: disable=no-name-in-module
from torch.utils.data import Dataset

from torchvision.transforms import Compose  # type: ignore

from numpy import ndarray, array, where, full
from numpy.random import default_rng

from skimage.morphology import star, square, octagon, disk


[docs] class ShapesDataset(Dataset): """Synthetic dataset which produces images withs spots and squares. *Warning*—this dataset object is untested. Parameters ---------- spot_prob : float, optional Probability of including spots in the image. square_prob : float, optional Probability of including sqaures in the image. num_spots : int, optional The number of spots that will be included in the image. num_squares : int, optional The number of squares that will be included in the image. length : int, optional The length of the data set. image_size : int, optional The length of the square images. input_tfms : Compose, optional A composition of transforms to apply to the input. target_tfms : Compose, optional A composition of transforms to apply to the target. seed : int Integer seed for numpy's default rng. Notes ----- The images have white backrounds and the shapes have randomly selected RGB colours on [0, 1)^{3}. To get the indices of each shape, use, for example >>> data_set = ShapesDataset() >>> spot_index = data_set.target_names.index("spot") >>> star_index = data_set.target_names.index("star") To print the classes as a list, use >>> print(data_set.target_names) """ # pylint: disable=too-many-positional-arguments,too-many-arguments def __init__( self, spot_prob: float = 0.5, square_prob: float = 0.5, star_prob: float = 0.5, octagon_prob: float = 0.5, num_shapes: int = 3, length: int = 1000, image_size: int = 256, input_tfms: Optional[Compose] = None, target_tfms: Optional[Compose] = None, seed: int = 666, ): """Build ``ShapesDataset``.""" self._len = length self._num_shapes = num_shapes self._probs = { "square": square_prob, "spot": spot_prob, "star": star_prob, "octagon": octagon_prob, } self._img_size = image_size self._x_tfms = input_tfms self._y_tfms = target_tfms self._rng = default_rng(seed=seed) _shapes: Dict[str, Callable] = { "square": lambda x: square(2 * x), "star": star, "octagon": lambda x: octagon(x, x), "spot": disk, } def __len__(self) -> int: """Return the length of the dataset. Returns ------- int The length of the dataset. """ return self._len def __getitem__(self, idx: int): """Return an image-target pair. Parameters ---------- idx : int The index of the item to be returned. Returns ------- img : Tensor An RGB image of shape (c, H, W). tgt : Tensor Target vector: — If there are no spots or squares, [0.0, 0.0] — If there are spots only, [1.0, 0.0] — If there are squares only, [0.0, 1.0] — If there are both, [1.0, 1.0] """ img, tgt = self._create_image() if self._x_tfms is not None: img = self._x_tfms(img) if self._y_tfms is not None: tgt = self._y_tfms(tgt) return img, tgt def _add_shape(self, image: ndarray, shape: str) -> bool: """Add spots to ``image``. Parameters ---------- image : ndarray RGB image. shape : str Name of the shape to include. Returns ------- include_spots : bool Whether or not the spots were added. """ include_shape = self._rng.random() <= self._probs[shape] if include_shape: for _ in range(self._num_shapes): colour = self._rng.random(size=(1, 3)) radius = self._img_size // 20 shape_arr = self._shapes[shape](radius) # pylint: disable=unbalanced-tuple-unpacking rows, cols = where(shape_arr == 1) left, top = self._rng.integers( 0, self._img_size - len(shape_arr), size=2, ) rows, cols = rows + top, cols + left image[rows, cols] = colour return include_shape def _create_image(self) -> Tuple[Tensor, Tensor]: """Create image. Returns ------- Tensor An RGB image of shape (c, H, W). Tensor Target vector: — If there are no spots or squares, [0.0, 0.0] — If there are spots only, [1.0, 0.0] — If there are squares only, [0.0, 1.0] — If there are both, [1.0, 1.0] """ # image = ones((self._img_size, self._img_size, 3), dtype=float32) image = full( (self._img_size, self._img_size, 3), fill_value=self._rng.random(size=(1, 3)), ) targets = [] for key in self._shapes: targets.append(self._add_shape(image, key)) return ( from_numpy(image).permute(2, 0, 1).float(), from_numpy(array(targets)).float(), ) @property def target_names(self) -> List[str]: """Return a list of target names order by their one-hot indices. Returns ------- List[str] A list of the names of the shapes, ordered by their one-hot indices. """ return list(self._shapes.keys())