DataSet

Main dataset object for torch_tools.

class torch_tools.datasets._dataset.DataSet(*args: Any, **kwargs: Any)

A highly flexible, general-purpose dataset that pairs inputs with optional targets and applies custom transforms.

Parameters:
  • inputs (Sequence[str, Path, Tensor, ndarray]) – Inputs (or x items) for the dataset.

  • targets (Optional[Sequence[str, Path, Tensor, ndarray]]) – Targets (or y items) for the dataset.

  • input_tfms (Optional[Compose]) – A composition of transforms to apply to the inputs as they are selected.

  • target_tfms (Optional[Compose]) – A composition of transforms to apply to the targets as they are selected.

  • both_tfms (Optional[Compose]) – A composition of transforms to apply jointly to the input and target. Note: these transforms are applied after input_tfms and target_tfms, at which point the inputs and targets should be tensors. Each input–target pair is concatenated along dim=0, transformed, and sliced apart again, which is how one would apply the same rotation or reflection to an image and its segmentation mask. The dimensionality matters: because the pair is concatenated along dim=0, the input and target must match in every other dimension.

  • mixup (bool) – Whether to apply mixup augmentation (see the paper: https://arxiv.org/abs/1710.09412). If True, mixup is applied, with the mixing coefficient lambda sampled from a beta distribution with parameters alpha=beta=0.4.
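The both_tfms behaviour described above (concatenate along dim=0, transform once, slice apart) can be sketched as follows. This is a minimal illustration assuming channels-first tensors; the helper `apply_joint_transform` and the `hflip` function are hypothetical and not part of torch_tools.

```python
import torch


def apply_joint_transform(x, y, joint_tfm):
    """Hypothetical helper: apply one transform to an input-target pair.

    Mimics the documented both_tfms behaviour: concatenate along dim=0,
    transform the stacked tensor once, then slice it back apart.
    """
    n_input_channels = x.shape[0]
    # torch.cat requires every dimension except dim=0 to match.
    stacked = torch.cat([x, y], dim=0)
    transformed = joint_tfm(stacked)
    return transformed[:n_input_channels], transformed[n_input_channels:]


def hflip(t):
    """Flip a (channels, height, width) tensor along its width axis."""
    return torch.flip(t, dims=[-1])


# An RGB image and a single-channel mask with the same height and width.
image = torch.rand(3, 4, 5)
mask = torch.randint(0, 2, (1, 4, 5)).float()

new_image, new_mask = apply_joint_transform(image, mask, hflip)
```

Because the flip is applied to the stacked tensor in a single call, the image and mask receive exactly the same transformation; with a random transform, both would share the same random parameters.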

Notes

This dataset works for

  • Simple perceptron-style inputs.

  • Computer vision experiments, where the inputs are images.

  • Just about any problem where you need inputs, targets and custom transforms.

  • Inference: set targets=None and the dataset yields inputs only.
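The mixup option described in the parameters above blends pairs of samples and their targets using a coefficient lambda drawn from Beta(0.4, 0.4). The sketch below shows the idea in PyTorch; the `mixup` function is hypothetical, and how torch_tools picks the second sample of each pair is not specified in this documentation.

```python
import torch


def mixup(x1, y1, x2, y2, alpha=0.4):
    """Hypothetical sketch of mixup (Zhang et al., arXiv:1710.09412).

    lambda ~ Beta(alpha, alpha); alpha=0.4 matches the value stated above.
    Inputs and targets are blended with the same coefficient.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample()
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y, lam


torch.manual_seed(0)  # make the sampled lambda reproducible

x1, y1 = torch.zeros(3), torch.tensor([1.0, 0.0])  # sample of class 0
x2, y2 = torch.ones(3), torch.tensor([0.0, 1.0])   # sample of class 1
x_mixed, y_mixed, lam = mixup(x1, y1, x2, y2)
```

With one-hot targets, the mixed target becomes a soft label [lambda, 1 - lambda] that still sums to one.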

ShapesDataset

Synthetic dataset object.

class torch_tools.datasets._shapes_dataset.ShapesDataset(*args: Any, **kwargs: Any)

Synthetic dataset which produces images with spots and squares.

Warning: this dataset object is untested.

Parameters:
  • spot_prob (float, optional) – Probability of including spots in the image.

  • square_prob (float, optional) – Probability of including squares in the image.

  • num_spots (int, optional) – The number of spots that will be included in the image.

  • num_squares (int, optional) – The number of squares that will be included in the image.

  • length (int, optional) – The number of items in the dataset.

  • image_size (int, optional) – The side length of the square images.

  • input_tfms (Compose, optional) – A composition of transforms to apply to the input.

  • target_tfms (Compose, optional) – A composition of transforms to apply to the target.

  • seed (int) – Integer seed for NumPy’s default random number generator.

Notes

The images have white backgrounds, and the shapes have RGB colours sampled at random from [0, 1)^3.
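A minimal sketch of an image matching this description, using NumPy's `default_rng` as the seed parameter suggests. The function name, the channels-first layout and the fixed square size are assumptions for illustration; the real ShapesDataset may draw its shapes differently.

```python
import numpy as np


def white_image_with_square(image_size=32, square_size=8, seed=123):
    """Hypothetical sketch: one randomly coloured square on a white canvas."""
    rng = np.random.default_rng(seed)
    # White background: every pixel is (1, 1, 1); channels-first layout assumed.
    image = np.ones((3, image_size, image_size), dtype=np.float32)
    # Random RGB colour on [0, 1)^3 (rng.random never returns 1.0).
    colour = rng.random(3).astype(np.float32)
    # Random top-left corner that keeps the whole square inside the image.
    row, col = rng.integers(0, image_size - square_size + 1, size=2)
    image[:, row:row + square_size, col:col + square_size] = colour[:, None, None]
    return image, colour


img, colour = white_image_with_square(seed=0)
```

Fixing the seed makes the image reproducible, which is presumably the purpose of the seed parameter.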

To get the index of each shape, use, for example:

>>> data_set = ShapesDataset()
>>> spot_index = data_set.target_names.index("spot")
>>> square_index = data_set.target_names.index("square")

To print the classes as a list, use

>>> print(data_set.target_names)
property target_names: List[str]

Return a list of target names ordered by their one-hot indices.

Returns:

A list of the names of the shapes, ordered by their one-hot indices.

Return type:

List[str]
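Since target_names is ordered by one-hot index, it can be used to move between shape names and target vectors. The sketch below assumes a (multi-)hot target vector, because an image may include both spots and squares; the helpers `encode_target` and `decode_target`, and the literal names "spot" and "square", are hypothetical stand-ins for the values returned by `data_set.target_names`.

```python
from typing import List, Sequence


def encode_target(present: Sequence[str], target_names: List[str]) -> List[float]:
    """Hypothetical helper: build a (multi-)hot vector from shape names."""
    vector = [0.0] * len(target_names)
    for name in present:
        # Each name's position in target_names is its one-hot index.
        vector[target_names.index(name)] = 1.0
    return vector


def decode_target(
    target: Sequence[float], target_names: List[str], threshold: float = 0.5
) -> List[str]:
    """Hypothetical helper: return the names whose entries exceed threshold."""
    if len(target) != len(target_names):
        raise ValueError("target and target_names must have the same length")
    return [name for name, value in zip(target_names, target) if value > threshold]


names = ["spot", "square"]  # hypothetical; use data_set.target_names in practice
encoded = encode_target(["square"], names)  # [0.0, 1.0]
decoded = decode_target([1.0, 1.0], names)  # ['spot', 'square']
```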