scdataset.strategy.SamplingStrategy


class scdataset.strategy.SamplingStrategy

Bases: object

Abstract base class for sampling strategies.

This class defines the interface that all sampling strategies must implement. Sampling strategies determine how indices are generated from a data collection for training or inference.

Variables:

_shuffle_before_yield (bool or None) – Whether to shuffle indices before yielding batches. Set by subclasses.

Notes

All subclasses must implement the get_indices() method to define their specific sampling behavior.

Examples

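>>> import numpy as np
>>> from scdataset.strategy import SamplingStrategy
>>> # Custom sampling strategy that returns every index in order
>>> class CustomStrategy(SamplingStrategy):
...     def get_indices(self, data_collection, seed=None, rng=None):
...         return np.arange(len(data_collection))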

Methods

__init__()

Initialize the sampling strategy.

get_indices(data_collection[, seed, rng])

Generate indices for sampling from the data collection.

__init__()

Initialize the sampling strategy.

get_indices(data_collection, seed: int | None = None, rng: Generator | None = None) → ndarray[tuple[int, ...], dtype[int64]]

Generate indices for sampling from the data collection.

This is an abstract method that must be implemented by all subclasses.

Parameters:
  • data_collection (object) – The data collection to sample from. Must support len() and indexing.

  • seed (int, optional) – Random seed for reproducible sampling. Ignored if rng is provided.

  • rng (numpy.random.Generator, optional) – Random number generator to use. If provided, seed is ignored.

Returns:

Array of indices to sample from the data collection.

Return type:

numpy.ndarray

Raises:

NotImplementedError – Always raised by the base class, because this method is abstract. Subclasses must override it.
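
The following is a minimal sketch of how a subclass might honor the documented seed and rng parameters. It is not taken from the scdataset source. To keep it runnable on its own, it defines a hypothetical stand-in for the SamplingStrategy base class. ShuffledStrategy is an illustrative name, and initializing _shuffle_before_yield to None in the base __init__() is an assumption.

```python
import numpy as np


class SamplingStrategy:
    """Hypothetical stand-in for scdataset.strategy.SamplingStrategy."""

    def __init__(self):
        # Assumption: the base class leaves this unset for subclasses to fill in.
        self._shuffle_before_yield = None

    def get_indices(self, data_collection, seed=None, rng=None):
        # Mirrors the documented behavior: the base method always raises.
        raise NotImplementedError("Subclasses must implement get_indices()")


class ShuffledStrategy(SamplingStrategy):
    """Illustrative subclass: returns every index once, in random order."""

    def get_indices(self, data_collection, seed=None, rng=None):
        # An explicit generator takes precedence; seed is used only without one.
        if rng is None:
            rng = np.random.default_rng(seed)
        return rng.permutation(len(data_collection)).astype(np.int64)


data = list("abcdefgh")
strategy = ShuffledStrategy()

# The same seed reproduces the same ordering.
first = strategy.get_indices(data, seed=0)
second = strategy.get_indices(data, seed=0)

# When rng is supplied, seed has no effect on the result.
with_rng_a = strategy.get_indices(data, seed=0, rng=np.random.default_rng(123))
with_rng_b = strategy.get_indices(data, seed=999, rng=np.random.default_rng(123))
```

Creating the generator inside get_indices() only when rng is None keeps the precedence rule in one place: callers that share a single Generator across several strategies get a continuous random stream, while callers that pass only seed get a fresh, reproducible one.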