scdataset.strategy.SamplingStrategy


class scdataset.strategy.SamplingStrategy(indices: ArrayLike | None = None)

Bases: object

Abstract base class for sampling strategies.

This class defines the interface that all sampling strategies must implement. Sampling strategies determine how indices are generated from a data collection for training or inference.

Variables:
  • _shuffle_before_yield (bool or None) – Whether to shuffle indices before yielding batches. Set by subclasses.

  • _indices (numpy.ndarray or None) – Subset of indices passed to the constructor, or None if no subset was given. Stored in ascending order to favor sequential I/O.
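The two attributes above can be illustrated with a small stand-in. This is a minimal sketch, not scDataset's source: BaseStrategySketch and FullPassSketch are hypothetical names, and the sketch assumes the constructor stores a sorted copy of indices while each subclass sets _shuffle_before_yield itself.

```python
import numpy as np


class BaseStrategySketch:
    """Hypothetical stand-in mirroring the documented attributes."""

    def __init__(self, indices=None):
        # Assumption: None means "not set yet"; subclasses overwrite it
        self._shuffle_before_yield = None
        # Keep a sorted copy of the subset, or None when no subset is given
        self._indices = None if indices is None else np.sort(np.asarray(indices))


class FullPassSketch(BaseStrategySketch):
    """Hypothetical subclass that never shuffles before yielding."""

    def __init__(self, indices=None):
        super().__init__(indices)
        self._shuffle_before_yield = False


s = FullPassSketch(indices=[9, 2, 5])
print(s._indices, s._shuffle_before_yield)
```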

Notes

All subclasses must implement the get_indices() method to define their specific sampling behavior.

scDataset relies on sorted indices for efficient sequential I/O. When indices are provided to any strategy, they are automatically sorted in ascending order, so that neighboring indices can be read together rather than in scattered order.
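One rough way to see why sorting helps is to count the runs of consecutive indices, since each run could in principle be served by a single sequential read. The helper below is hypothetical and only a proxy: it assumes one run maps to one read and says nothing about how a particular storage backend actually batches requests.

```python
import numpy as np


def count_contiguous_runs(indices):
    """Count maximal runs of consecutive integers (hypothetical helper)."""
    idx = np.asarray(indices)
    if idx.size == 0:
        return 0
    # A new run starts wherever the step to the next index is not +1
    return int(np.count_nonzero(np.diff(idx) != 1)) + 1


unsorted = [5, 1, 2, 3, 7, 6]
print(count_contiguous_runs(unsorted))           # 4 runs
print(count_contiguous_runs(np.sort(unsorted)))  # 2 runs
```

The same six indices drop from four runs to two once sorted, which is the access-pattern effect the note above describes.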

Examples

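>>> import numpy as np
>>> from scdataset.strategy import SamplingStrategy
>>> # Custom sampling strategy that returns every index in order
>>> class CustomStrategy(SamplingStrategy):
...     def get_indices(self, data_collection, seed=None, rng=None):
...         return np.arange(len(data_collection))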

Methods

__init__([indices]) – Initialize the sampling strategy.

get_indices(data_collection[, seed, rng]) – Generate indices for sampling from the data collection.

__init__(indices: ArrayLike | None = None)

Initialize the sampling strategy.

Parameters:

indices (array-like, optional) – Subset of indices to use for sampling. If provided, they are automatically sorted in ascending order so that data is read sequentially.
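A subclass might honor the indices subset by returning it when present and falling back to the full range otherwise. This is a self-contained sketch that does not import scDataset: SubsetAwareSketch is a hypothetical name, and storing the subset in _indices follows the Variables description above rather than any confirmed implementation detail.

```python
import numpy as np


class SubsetAwareSketch:
    """Hypothetical strategy that samples a stored subset, else everything."""

    def __init__(self, indices=None):
        # Sort once at construction, matching the documented behaviour
        self._indices = None if indices is None else np.sort(np.asarray(indices))

    def get_indices(self, data_collection, seed=None, rng=None):
        if self._indices is not None:
            # Return a copy so callers cannot mutate the stored subset
            return self._indices.copy()
        # No subset: use every position in the collection, as int64
        return np.arange(len(data_collection), dtype=np.int64)


data = list("abcdefgh")
print(SubsetAwareSketch().get_indices(data))                   # [0 1 2 3 4 5 6 7]
print(SubsetAwareSketch(indices=[6, 0, 3]).get_indices(data))  # [0 3 6]
```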

get_indices(data_collection, seed: int | None = None, rng: Generator | None = None) → ndarray[tuple[int, ...], dtype[int64]]

Generate indices for sampling from the data collection.

This is an abstract method that must be implemented by all subclasses.

Parameters:
  • data_collection (object) – The data collection to sample from.

  • seed (int, optional) – Random seed for reproducible sampling. Ignored if rng is provided.

  • rng (numpy.random.Generator, optional) – Random number generator to use. If provided, seed is ignored.
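The precedence between seed and rng described above can be expressed as a small helper. resolve_rng is a hypothetical name used only for illustration; it relies on numpy.random.default_rng, which accepts None or an integer seed.

```python
import numpy as np


def resolve_rng(seed=None, rng=None):
    """Hypothetical helper: prefer an explicit Generator, else build one from seed."""
    if rng is not None:
        return rng  # seed is ignored when rng is supplied
    return np.random.default_rng(seed)


# Same seed, fresh generators: identical permutations
a = resolve_rng(seed=42).permutation(5)
b = resolve_rng(seed=42).permutation(5)

# An explicit generator wins; seed=42 is not used here
shared = np.random.default_rng(0)
c = resolve_rng(seed=42, rng=shared)
```

Passing one shared Generator across calls lets a caller advance a single random stream, while passing only a seed gives each call an independent, reproducible stream.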

Returns:

Array of indices to sample from the data collection.

Return type:

numpy.ndarray

Raises:

NotImplementedError – Always raised by the base class; subclasses must override this method.
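The abstract-method contract can be sketched with two hypothetical classes: a base whose get_indices raises NotImplementedError, and a concrete subclass that overrides it to return a random permutation. Neither StrategySketch nor ShuffledSketch is part of scDataset; they only mirror the documented interface.

```python
import numpy as np


class StrategySketch:
    """Hypothetical base mirroring the documented abstract method."""

    def get_indices(self, data_collection, seed=None, rng=None):
        raise NotImplementedError("Subclasses must implement get_indices()")


class ShuffledSketch(StrategySketch):
    """Hypothetical subclass returning a random permutation of all indices."""

    def get_indices(self, data_collection, seed=None, rng=None):
        # An explicit rng takes precedence over seed
        rng = rng if rng is not None else np.random.default_rng(seed)
        return rng.permutation(len(data_collection))


try:
    StrategySketch().get_indices([1, 2, 3])
except NotImplementedError as err:
    message = str(err)

perm = ShuffledSketch().get_indices(range(6), seed=0)
print(message)
print(perm)
```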