scdataset.strategy.SamplingStrategy
- class scdataset.strategy.SamplingStrategy(indices: numpy.typing.ArrayLike | None = None)
Bases: object
Abstract base class for sampling strategies.
This class defines the interface that all sampling strategies must implement. Sampling strategies determine how indices are generated from a data collection for training or inference.
- Variables:
_shuffle_before_yield (bool or None) – Whether to shuffle indices before yielding batches. Set by subclasses.
_indices (numpy.ndarray or None) – Subset of indices passed at construction, or None if no subset was provided. Always stored sorted for optimal I/O.
Notes
All subclasses must implement the get_indices() method to define their specific sampling behavior.
scDataset relies on sorted indices for efficient sequential I/O access patterns. When indices are provided to any strategy, they are automatically sorted to ensure optimal performance.
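To make the sorted-index note concrete, here is a small sketch. It assumes a storage backend (for example an on-disk array store) that can serve each run of consecutive rows with one sequential read; the helper `count_contiguous_reads` is hypothetical and not part of scDataset.

```python
import numpy as np


def count_contiguous_reads(indices) -> int:
    """Count the runs of consecutive row numbers in an access order.

    Hypothetical helper: each run (i, i+1, i+2, ...) is assumed to be
    servable by a single sequential read.
    """
    indices = np.asarray(indices)
    if indices.size == 0:
        return 0
    # A new run starts wherever the next index is not exactly previous + 1.
    breaks = np.diff(indices) != 1
    return int(breaks.sum()) + 1


requested = np.array([7, 2, 3, 8, 1])
print(count_contiguous_reads(requested))           # 4 (unsorted order)
print(count_contiguous_reads(np.sort(requested)))  # 2 (sorted: 1-3 and 7-8)
```

The same five rows need four separate reads in the requested order but only two once sorted, which is the effect the note describes.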
Examples
>>> import numpy as np
>>> from scdataset.strategy import SamplingStrategy
>>> # Custom sampling strategy
>>> class CustomStrategy(SamplingStrategy):
...     def get_indices(self, data_collection, seed=None, rng=None):
...         return np.arange(len(data_collection))
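The example above ignores both the stored `indices` subset and the `seed`/`rng` arguments. The fuller sketch below uses both. So that it runs without scDataset installed, it defines a hypothetical stand-in `SamplingStrategy` that only mirrors the documented behaviour (sorted `_indices`, `NotImplementedError` in `get_indices`). `SubsampleStrategy` and its `n` parameter are also hypothetical, and `data_collection` is assumed to support `len()`.

```python
import numpy as np


class SamplingStrategy:
    """Hypothetical stand-in that mirrors the documented interface."""

    def __init__(self, indices=None):
        # Documented behaviour: a provided subset is stored sorted.
        self._indices = None if indices is None else np.sort(np.asarray(indices))
        self._shuffle_before_yield = None

    def get_indices(self, data_collection, seed=None, rng=None):
        raise NotImplementedError


class SubsampleStrategy(SamplingStrategy):
    """Hypothetical subclass: draw up to `n` rows without replacement."""

    def __init__(self, n, indices=None):
        super().__init__(indices)
        self.n = n
        # Per the docs, subclasses set this flag (shuffle before yielding batches).
        self._shuffle_before_yield = True

    def get_indices(self, data_collection, seed=None, rng=None):
        # Documented precedence: an explicit rng overrides seed.
        if rng is None:
            rng = np.random.default_rng(seed)
        # Sample from the stored subset if there is one, else from every row
        # (assumes data_collection supports len()).
        if self._indices is not None:
            pool = self._indices
        else:
            pool = np.arange(len(data_collection))
        chosen = rng.choice(pool, size=min(self.n, len(pool)), replace=False)
        # Return sorted indices so reads stay sequential.
        return np.sort(chosen)


data = np.zeros((100, 5))  # toy stand-in for a 100-row data collection
strategy = SubsampleStrategy(n=10, indices=range(50, 100))
first = strategy.get_indices(data, seed=0)
second = strategy.get_indices(data, seed=0)
```

Calling `get_indices` twice with the same `seed` reproduces the draw, and the returned indices stay sorted and inside the requested subset.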
Methods
__init__([indices]) – Initialize the sampling strategy.
get_indices(data_collection[, seed, rng]) – Generate indices for sampling from the data collection.
- __init__(indices: numpy.typing.ArrayLike | None = None)
Initialize the sampling strategy.
- Parameters:
indices (array-like, optional) – Subset of indices to use for sampling. If provided, they will be automatically sorted to ensure optimal I/O performance.
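A minimal sketch of the sorting step described for `indices`, assuming the constructor converts the input with `numpy.asarray` and sorts it with `numpy.sort`. The helper `normalize_indices` is hypothetical and not scDataset's code.

```python
import numpy as np


def normalize_indices(indices=None):
    """Hypothetical helper mirroring the documented __init__ behaviour."""
    if indices is None:
        return None  # no subset provided
    # np.asarray accepts lists, tuples, ranges and arrays; np.sort returns a
    # sorted copy, so the caller's array is left untouched.
    return np.sort(np.asarray(indices))


user_subset = np.array([42, 3, 17, 8])
stored = normalize_indices(user_subset)
print(stored)       # [ 3  8 17 42]
print(user_subset)  # [42  3 17  8]  (unchanged)
```

Sorting a copy rather than in place is a design choice of this sketch: it avoids surprising callers who reuse their own index array elsewhere.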
- get_indices(data_collection, seed: int | None = None, rng: numpy.random.Generator | None = None) → numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.int64]]
Generate indices for sampling from the data collection.
This is an abstract method that must be implemented by all subclasses.
- Parameters:
data_collection (object) – The data collection to sample from.
seed (int, optional) – Random seed for reproducible sampling. Ignored if rng is provided.
rng (numpy.random.Generator, optional) – Random number generator to use. If provided, seed is ignored.
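The precedence between `seed` and `rng` can be sketched as follows. `resolve_rng` is a hypothetical helper rather than an scDataset function, and it assumes `seed` is handed to `numpy.random.default_rng`.

```python
import numpy as np


def resolve_rng(seed=None, rng=None):
    """Hypothetical helper: apply the documented rng-over-seed precedence."""
    if rng is not None:
        return rng  # an explicit Generator wins; seed is ignored
    # default_rng(None) draws fresh OS entropy, so seed=None is not reproducible.
    return np.random.default_rng(seed)


# The same seed gives the same stream of draws.
a = resolve_rng(seed=123).integers(0, 1000, size=5)
b = resolve_rng(seed=123).integers(0, 1000, size=5)

# An explicit rng overrides the seed argument.
shared = np.random.default_rng(7)
picked = resolve_rng(seed=999, rng=shared)
```

Passing one shared `Generator` lets successive calls advance a single stream (for example a different draw each epoch), whereas a fixed `seed` replays the same draw every time.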
- Returns:
Array of indices to sample from the data collection.
- Return type:
numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.int64]]
- Raises:
NotImplementedError – Always raised as this is an abstract method.
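The Bases line lists only `object`, which suggests the abstract contract is enforced at call time through `NotImplementedError` rather than through the `abc` module. The sketch below contrasts the two Python patterns using hypothetical stand-in classes (`PlainBase`, `AbcBase`); it illustrates general Python behaviour and is not scDataset's code.

```python
import abc


class PlainBase:
    """Hypothetical stand-in for the documented pattern: object + NotImplementedError."""

    def get_indices(self, data_collection, seed=None, rng=None):
        raise NotImplementedError("Subclasses must implement get_indices()")


class AbcBase(abc.ABC):
    """For contrast only: the abc-module alternative (not what the docs describe)."""

    @abc.abstractmethod
    def get_indices(self, data_collection, seed=None, rng=None):
        ...


def raises(exc_type, fn):
    """Return True if calling fn() raises exc_type."""
    try:
        fn()
    except exc_type:
        return True
    return False


# Plain base: construction works; the error surfaces only when get_indices() runs.
plain_fails_on_call = raises(NotImplementedError, lambda: PlainBase().get_indices([]))

# ABC base: construction itself is rejected with TypeError.
abc_fails_on_init = raises(TypeError, lambda: AbcBase())
```

With the plain pattern a subclass that forgets to override `get_indices()` is only caught when sampling starts; the `abc` pattern would move that failure to construction time.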