Source code for niacin.augment.randaugment

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

Implementation of RandAugment algorithm

from functools import partial
import typing as t

from numpy.random import default_rng

[docs]class RandAugment: """Implements RandAugment algorithm (randaugment_) as an iterator. RandAugment selects ``n`` functions at random from a sequence, and initializes them with magnitude ``m``. To use it, initialize this class with a list of transformation functions. Every time it is iterated over (e.g. in a for loop) it yields a random subset of those transformation functions. The original paper claims that m is on a scale from 0-10, but its listed experiments regularly use magnitudes in the 20s and 30s. We take this to mean that "10" was a typo, and the scale was meant to extend to 100. The defaults for m and n have been set to 1 and 10, respectively, for safety. Depending on your model size and task, you may achieve more accurate results with n ∈ {2, 3} and an m in the range [10, 20). By default, the transforms in a single sample will be returned in random order. If your transforms must occur in a logical sequence (e.g. swapping synonyms before removing random characters), set shuffle to False. Args: transforms: sequence of transformation functions m: magnitude of transformation, on a scale of 0-100 n: number of transforms to select shuffle: return transforms in random order seed: seed to use for the random number generator .. _randaugment: """ # niacin functions use floats for magnitude _p: float = 0.0 def __init__( self, transforms: t.List[t.Callable], m: int = 10, n: int = 1, shuffle: bool = True, seed: int = None ): self._transforms = list(transforms) self.n = n self.m = m self._shuffle = shuffle self._rng = default_rng(seed=seed) def __len__(self): return len(self._transforms) def __iter__(self): choices = [ partial(fun, p=self._p) for fun in self._rng.choice( self._transforms, size=self._n, replace=False, shuffle=self._shuffle ) ] return iter(choices) @property def n(self) -> int: """Size of sample - should be less than total available transforms """ return self._n @n.setter def n(self, value: int): n_funs = len(self._transforms) if value > n_funs: msg = f"Sample size n={value} must be <= number of transforms={n_funs}" raise ValueError(msg) self._n = value @property def m(self) -> int: """Magnitude of transformation - should be a number between 0 and 100 """ return int(self._p * 100) @m.setter def m(self, value: int): self._p = min(max(value / 100, 0.0), 1.0)