Source code for scdataset.strategy

"""
Sampling strategies for on-disk data collections.

This module provides various sampling strategies for efficiently iterating through
large on-disk datasets. Each strategy defines how indices are generated and
ordered for data loading.

.. autosummary::
   :toctree: generated/

   SamplingStrategy
   Streaming
   BlockShuffling
   BlockWeightedSampling
   ClassBalancedSampling
"""

import numpy as np
import warnings
from typing import Optional, Union
from numpy.typing import NDArray, ArrayLike


__all__ = [
    "SamplingStrategy",
    "Streaming", 
    "BlockShuffling",
    "BlockWeightedSampling",
    "ClassBalancedSampling",
]


[docs] class SamplingStrategy: """ Abstract base class for sampling strategies. This class defines the interface that all sampling strategies must implement. Sampling strategies determine how indices are generated from a data collection for training or inference. Attributes ---------- _shuffle_before_yield : bool or None Whether to shuffle indices before yielding batches. Set by subclasses. Notes ----- All subclasses must implement the :meth:`get_indices` method to define their specific sampling behavior. Examples -------- >>> # Custom sampling strategy >>> class CustomStrategy(SamplingStrategy): ... def get_indices(self, data_collection, seed=None, rng=None): ... return np.arange(len(data_collection)) """
[docs] def __init__(self): """Initialize the sampling strategy.""" self._shuffle_before_yield = None
[docs] def get_indices(self, data_collection, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> NDArray[np.intp]: """ Generate indices for sampling from the data collection. This is an abstract method that must be implemented by all subclasses. Parameters ---------- data_collection : object The data collection to sample from. Must support ``len()`` and indexing. seed : int, optional Random seed for reproducible sampling. Ignored if ``rng`` is provided. rng : numpy.random.Generator, optional Random number generator to use. If provided, ``seed`` is ignored. Returns ------- numpy.ndarray Array of indices to sample from the data collection. Raises ------ NotImplementedError Always raised as this is an abstract method. """ raise NotImplementedError("Subclasses must implement get_indices method")
def _get_rng(self, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> np.random.Generator: """ Get a random number generator from seed or rng parameter. This helper method provides a consistent way to obtain a random number generator across all sampling strategies. Parameters ---------- seed : int, optional Random seed to create a new generator. Ignored if ``rng`` is provided. rng : numpy.random.Generator, optional Existing random number generator to use. Takes precedence over ``seed``. Returns ------- numpy.random.Generator Random number generator instance. Examples -------- >>> strategy = SamplingStrategy() >>> rng = strategy._get_rng(seed=42) >>> isinstance(rng, np.random.Generator) True """ if rng is not None: return rng return np.random.default_rng(seed)
[docs] class Streaming(SamplingStrategy): """ Sequential streaming sampling strategy with optional buffer-level shuffling. This strategy provides indices in sequential order, with optional shuffling at the buffer level (defined by fetch_factor in scDataset). When shuffle=True, batches within each fetch buffer are shuffled, similar to Ray Dataset or WebDataset behavior, while maintaining overall sequential order across buffers. Parameters ---------- indices : array-like, optional Subset of indices to use for sampling. If None, uses all indices from 0 to len(data_collection)-1. shuffle : bool, default=False Whether to shuffle batches within each fetch buffer. When True, enables buffer-level shuffling that maintains sequential order between buffers but randomizes the order of batches within each buffer (defined by fetch_factor * batch_size). Attributes ---------- _shuffle_before_yield : bool Controlled by the shuffle parameter. True if buffer-level shuffling is enabled, False otherwise. _indices : numpy.ndarray or None Stored subset of indices if provided. shuffle : bool Whether buffer-level shuffling is enabled. Examples -------- >>> # Stream through entire dataset without shuffling >>> strategy = Streaming() >>> indices = strategy.get_indices(range(100)) >>> len(indices) 100 >>> # Stream through subset of indices >>> subset_strategy = Streaming(indices=[10, 20, 30]) >>> indices = subset_strategy.get_indices(range(100)) >>> list(indices) [10, 20, 30] >>> # Stream with buffer-level shuffling (like Ray Dataset/WebDataset) >>> shuffle_strategy = Streaming(shuffle=True) >>> # Batches within each fetch buffer will be shuffled, >>> # but buffers themselves maintain sequential order See Also -------- BlockShuffling : For shuffled block-based sampling BlockWeightedSampling : For weighted sampling with shuffling Notes ----- When shuffle=True, this strategy provides behavior similar to: - Ray Dataset's local shuffling within windows - WebDataset's shuffle buffer functionality The key difference from BlockShuffling is that Streaming maintains the overall sequential order of fetch buffers, only shuffling within each buffer, while BlockShuffling shuffles the order of blocks themselves. """
[docs] def __init__(self, indices: Optional[ArrayLike] = None, shuffle: bool = False): """ Initialize streaming strategy. Parameters ---------- indices : array-like, optional Subset of indices to stream through. If None, streams through all available indices. shuffle : bool, default=False Whether to enable buffer-level shuffling. When True, batches within each fetch buffer are shuffled while maintaining sequential order between buffers. """ super().__init__() self.shuffle = shuffle self._shuffle_before_yield = shuffle self._indices = indices
[docs] def get_len(self, data_collection) -> int: """ Get the effective length of the data collection for this strategy. Parameters ---------- data_collection : object The data collection to get length from. Must support ``len()``. Returns ------- int Number of samples that will be yielded by this strategy. Examples -------- >>> strategy = Streaming() >>> strategy.get_len(range(100)) 100 >>> subset_strategy = Streaming(indices=[1, 3, 5]) >>> subset_strategy.get_len(range(100)) 3 """ if self._indices is None: return len(data_collection) return len(self._indices)
[docs] def get_indices(self, data_collection, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> NDArray[np.intp]: """ Get indices for streaming sampling. Returns indices in sequential order. If shuffle=True was set during initialization, the _shuffle_before_yield attribute will cause buffer-level shuffling during iteration. Parameters ---------- data_collection : object The data collection to sample from. Must support ``len()``. seed : int, optional Random seed. Only used if shuffle=True for buffer-level shuffling during iteration, not for index generation which remains sequential. rng : numpy.random.Generator, optional Random number generator. Only used if shuffle=True for buffer-level shuffling during iteration. Returns ------- numpy.ndarray Array of indices in sequential order. Examples -------- >>> strategy = Streaming() >>> indices = strategy.get_indices(range(5)) >>> list(indices) [0, 1, 2, 3, 4] >>> subset_strategy = Streaming(indices=[2, 4, 6]) >>> indices = subset_strategy.get_indices(range(10)) >>> list(indices) [2, 4, 6] >>> # With shuffle=True, indices are still sequential >>> shuffle_strategy = Streaming(shuffle=True) >>> indices = shuffle_strategy.get_indices(range(5)) >>> list(indices) # Still sequential - shuffling happens at buffer level [0, 1, 2, 3, 4] """ if self._indices is None: return np.arange(len(data_collection)) return self._indices
[docs] class BlockShuffling(SamplingStrategy): """ Block-based shuffling sampling strategy. This strategy divides the data into blocks of fixed size and shuffles the order of blocks while maintaining the original order within each block. This provides a balance between randomization and maintaining some locality of data access patterns. Parameters ---------- block_size : int, default=8 Size of each block for shuffling. Larger blocks maintain more locality but provide less randomization. indices : array-like, optional Subset of indices to use for sampling. If None, uses all indices from 0 to len(data_collection)-1. drop_last : bool, default=False Whether to drop the last incomplete block if the total number of indices is not divisible by block_size. Attributes ---------- _shuffle_before_yield : bool Always True for block shuffling strategy. _indices : numpy.ndarray or None Stored subset of indices if provided. block_size : int Size of blocks for shuffling. drop_last : bool Whether to drop incomplete blocks. Notes ----- When ``drop_last=False`` and there's a remainder block smaller than ``block_size``, it's inserted at a random position among the shuffled complete blocks. Examples -------- >>> # Basic block shuffling >>> strategy = BlockShuffling(block_size=3) >>> np.random.seed(42) # For reproducible example >>> indices = strategy.get_indices(range(10), seed=42) >>> len(indices) 10 >>> # Drop incomplete blocks >>> strategy = BlockShuffling(block_size=3, drop_last=True) >>> indices = strategy.get_indices(range(10), seed=42) >>> len(indices) # 10 // 3 * 3 = 9 9 See Also -------- Streaming : For sequential sampling without shuffling BlockWeightedSampling : For weighted block-based sampling """
[docs] def __init__(self, block_size: int = 8, indices: Optional[ArrayLike] = None, drop_last: bool = False): """ Initialize block shuffling strategy. Parameters ---------- block_size : int, default=8 Size of blocks for shuffling. Must be positive. indices : array-like, optional Subset of indices to sample from. drop_last : bool, default=False Whether to drop the last incomplete block. Raises ------ ValueError If block_size is not positive. """ super().__init__() if block_size <= 0: raise ValueError("block_size must be positive") self._shuffle_before_yield = True self._indices = indices self.block_size = block_size self.drop_last = drop_last
[docs] def get_len(self, data_collection) -> int: """ Get the effective length of the data collection for this strategy. Takes into account the drop_last setting when calculating the effective length. Parameters ---------- data_collection : object The data collection to get length from. Must support ``len()``. Returns ------- int Number of samples that will be yielded by this strategy. Examples -------- >>> strategy = BlockShuffling(block_size=3, drop_last=False) >>> strategy.get_len(range(10)) 10 >>> strategy = BlockShuffling(block_size=3, drop_last=True) >>> strategy.get_len(range(10)) # 10 - (10 % 3) = 9 9 """ if self._indices is None: l = len(data_collection) else: l = len(self._indices) if self.drop_last: l -= l % self.block_size return l
[docs] def get_indices(self, data_collection, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> NDArray[np.intp]: """ Generate indices with block-based shuffling. Divides indices into blocks and shuffles the order of complete blocks. Incomplete blocks are either dropped or inserted at random positions depending on the ``drop_last`` setting. Parameters ---------- data_collection : object The data collection to sample from. Must support ``len()``. seed : int, optional Random seed for reproducible shuffling. Ignored if ``rng`` is provided. rng : numpy.random.Generator, optional Random number generator to use for shuffling. If provided, ``seed`` is ignored. Returns ------- numpy.ndarray Array of indices with blocks shuffled. Notes ----- When ``drop_last=True`` and there are remainder indices that don't form a complete block, they are randomly removed from the dataset. When ``drop_last=False``, remainder indices are inserted at a random position among the shuffled complete blocks. Examples -------- >>> strategy = BlockShuffling(block_size=2, drop_last=False) >>> indices = strategy.get_indices(range(5), seed=42) >>> len(indices) 5 >>> strategy = BlockShuffling(block_size=2, drop_last=True) >>> indices = strategy.get_indices(range(5), seed=42) >>> len(indices) # Drops the last incomplete block 4 Raises ------ ValueError If the random number generator cannot sample the required number of indices for removal when drop_last=True. """ if self._indices is None: indices = np.arange(len(data_collection)) else: indices = self._indices rng_obj = self._get_rng(seed, rng) n = len(indices) n_blocks = n // self.block_size n_complete = n_blocks * self.block_size remainder = n - n_complete if self.drop_last and remainder > 0: remove_indices = rng_obj.choice(indices, size=remainder, replace=False) mask = ~np.isin(indices, remove_indices) complete_part = indices[mask] else: complete_part = indices[:n_complete] # Reshape complete part into blocks blocks = complete_part.reshape(n_blocks, self.block_size) blocks = rng_obj.permutation(blocks, axis=0) if self.drop_last or remainder == 0: return blocks.reshape(-1) else: # Insert remainder block at a random block boundary insert_pos = rng_obj.integers(0, n_blocks + 1) before = blocks[:insert_pos].reshape(-1) after = blocks[insert_pos:].reshape(-1) return np.concatenate([before, indices[n_complete:], after])
[docs] class BlockWeightedSampling(SamplingStrategy): """ Weighted sampling with block-based shuffling. This strategy performs weighted sampling from the data collection and then applies block-based shuffling to the sampled indices. It supports both sampling with and without replacement, and can generate a different total number of samples than the original data collection size. Parameters ---------- block_size : int, default=8 Size of blocks for shuffling after weighted sampling. indices : array-like, optional Subset of indices to sample from. If None, uses all indices. weights : array-like, optional Sampling weights for each element in the data collection. Must be non-negative and sum to a positive value. If None, uses uniform sampling. total_size : int, optional Total number of samples to draw. If None, uses the length of indices or data_collection. replace : bool, default=True Whether to sample with replacement. sampling_size : int, optional Size of each sampling round when ``replace=False``. Required when ``replace=False``. Attributes ---------- _shuffle_before_yield : bool Always True for weighted sampling strategy. _indices : numpy.ndarray or None Stored subset of indices if provided. block_size : int Size of blocks for shuffling. weights : numpy.ndarray or None Normalized sampling weights. total_size : int or None Total number of samples to generate. replace : bool Whether sampling is with replacement. sampling_size : int or None Size of each sampling round for replacement=False. Raises ------ ValueError If weights are negative, sum to zero, or don't match data collection length. If sampling_size is not provided when replace=False. Warns ----- UserWarning If sampling_size is provided when replace=True (it will be ignored). Examples -------- >>> # Uniform weighted sampling >>> strategy = BlockWeightedSampling(block_size=2, total_size=6) >>> indices = strategy.get_indices(range(4), seed=42) >>> len(indices) 6 >>> # Custom weights favoring certain indices >>> weights = [0.1, 0.1, 0.4, 0.4] # Favor indices 2 and 3 >>> strategy = BlockWeightedSampling(weights=weights, total_size=8, seed=42) >>> indices = strategy.get_indices(range(4), seed=42) >>> len(indices) 8 >>> # Sampling without replacement >>> strategy = BlockWeightedSampling( ... total_size=10, replace=False, sampling_size=5 ... ) See Also -------- BlockShuffling : For unweighted block-based shuffling ClassBalancedSampling : For automatic class-balanced sampling """
[docs] def __init__( self, block_size: int = 8, indices: Optional[ArrayLike] = None, weights: Optional[ArrayLike] = None, total_size: Optional[int] = None, replace: bool = True, sampling_size: Optional[int] = None ): """ Initialize weighted sampling strategy. Parameters ---------- block_size : int, default=8 Size of blocks for shuffling. Must be positive. indices : array-like, optional Subset of indices to sample from. weights : array-like, optional Sampling weights. Will be normalized automatically. total_size : int, optional Total number of samples to generate. replace : bool, default=True Whether to sample with replacement. sampling_size : int, optional Required when replace=False. Size of each sampling round. Raises ------ ValueError If block_size is not positive, weights are invalid, or sampling_size is missing when replace=False. """ super().__init__() if block_size <= 0: raise ValueError("block_size must be positive") self._shuffle_before_yield = True self._indices = indices self.block_size = block_size if weights is not None: weights = np.asarray(weights) if np.any(weights < 0): raise ValueError("weights must be non-negative") if np.sum(weights) == 0: raise ValueError("weights must sum to a positive value") # Normalize weights weights = weights / np.sum(weights) self.weights = weights self.total_size = total_size self.replace = replace if not replace and sampling_size is None: raise ValueError("sampling_size must be provided when replace=False") if replace and sampling_size is not None: warnings.warn("sampling_size is ignored when replace=True, since it will sample with replacement") self.sampling_size = sampling_size
[docs] def get_len(self, data_collection) -> int: """ Get the effective length of the data collection for this strategy. Returns the total number of samples that will be generated, which may be different from the original data collection size. Parameters ---------- data_collection : object The data collection to get length from. Must support ``len()``. Returns ------- int Number of samples that will be yielded by this strategy. Examples -------- >>> strategy = BlockWeightedSampling(total_size=100) >>> strategy.get_len(range(50)) # Returns total_size 100 >>> strategy = BlockWeightedSampling() # No total_size specified >>> strategy.get_len(range(50)) # Returns collection length 50 """ if self.total_size is not None: l = self.total_size else: # Use the length of indices or data_collection if total_size not specified if self._indices is None: l = len(data_collection) else: l = len(self._indices) return l
[docs] def get_indices(self, data_collection, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> NDArray[np.intp]: """ Generate indices using weighted sampling followed by block shuffling. First performs weighted sampling (with or without replacement) to select indices, then applies block-based shuffling to the selected indices. Parameters ---------- data_collection : object The data collection to sample from. Must support ``len()``. seed : int, optional Random seed for reproducible sampling. Ignored if ``rng`` is provided. rng : numpy.random.Generator, optional Random number generator to use. If provided, ``seed`` is ignored. Returns ------- numpy.ndarray Array of indices after weighted sampling and block shuffling. Raises ------ ValueError If weights don't match the data collection length. Notes ----- When ``replace=False``, sampling is performed in rounds of size ``sampling_size`` until the desired ``total_size`` is reached. This helps avoid memory issues with large datasets. The selected indices are sorted before block shuffling to ensure consistent behavior across different random seeds. Examples -------- >>> # Weighted sampling with replacement >>> weights = [0.25, 0.25, 0.25, 0.25] # Uniform weights >>> strategy = BlockWeightedSampling( ... weights=weights, total_size=8, block_size=2 ... ) >>> indices = strategy.get_indices(range(4), seed=42) >>> len(indices) 8 >>> # Sampling without replacement >>> strategy = BlockWeightedSampling( ... total_size=6, replace=False, sampling_size=3, block_size=2 ... ) >>> indices = strategy.get_indices(range(10), seed=42) >>> len(indices) 6 """ if self.weights is not None: if len(self.weights) != len(data_collection): raise ValueError("weights must have the same length as data_collection") if self._indices is None: _indices = np.arange(len(data_collection)) else: _indices = self._indices rng_obj = self._get_rng(seed, rng) if self.replace: # Sample with replacement if self.total_size is not None: size = self.total_size else: size = len(_indices) indices = rng_obj.choice(_indices, size=size, replace=True, p=self.weights) else: # Sample without replacement until we have total_size sampled = 0 indices_list = [] while sampled < self.total_size: remaining = self.total_size - sampled current_size = min(self.sampling_size, remaining) new_indices = rng_obj.choice(_indices, size=current_size, replace=False, p=self.weights) indices_list.append(new_indices) sampled += len(new_indices) indices = np.concatenate(indices_list) indices.sort() n = len(indices) n_blocks = n // self.block_size n_complete = n_blocks * self.block_size complete_part = indices[:n_complete] remainder_part = indices[n_complete:] # Reshape complete part into blocks blocks = complete_part.reshape(n_blocks, self.block_size) rng_obj = self._get_rng(seed, rng) blocks = rng_obj.permutation(blocks, axis=0) if len(remainder_part) == 0: # Only return shuffled complete blocks return blocks.reshape(-1) else: # Insert remainder block at a random block boundary insert_pos = rng_obj.integers(0, n_blocks + 1) before = blocks[:insert_pos].reshape(-1) after = blocks[insert_pos:].reshape(-1) return np.concatenate([before, remainder_part, after])
[docs] class ClassBalancedSampling(BlockWeightedSampling): """ Class-balanced sampling with automatic weight computation. This strategy extends :class:`BlockWeightedSampling` by automatically computing balanced weights from provided labels, making each class equally likely to be sampled regardless of the original class distribution in the dataset. The weights are computed as the inverse of class frequencies, ensuring that underrepresented classes get higher sampling probability and overrepresented classes get lower sampling probability. Parameters ---------- labels : array-like Class labels for each sample. Can be numpy array, pandas Series, or any array-like object. Must have the same length as the data collection (validated during get_indices). block_size : int, default=8 Size of blocks for block shuffling after sampling. indices : array-like, optional Subset of indices to sample from. If None, uses all indices. total_size : int, optional Total number of samples to draw. If None, uses the length of indices or data_collection. replace : bool, default=True Whether to sample with replacement. sampling_size : int, optional Size of each sampling round when ``replace=False``. Required when ``replace=False``. Attributes ---------- labels : numpy.ndarray Array of class labels for each sample. Raises ------ ValueError If labels array is empty. Examples -------- >>> # Balanced sampling from imbalanced dataset >>> labels = [0, 0, 0, 0, 1, 1, 2] # Imbalanced classes >>> strategy = ClassBalancedSampling(labels, total_size=12) >>> indices = strategy.get_indices(range(7), seed=42) >>> len(indices) 12 >>> # Class weights are automatically computed >>> strategy = ClassBalancedSampling([0, 0, 1]) >>> # Class 0 appears twice, class 1 once >>> # So class 0 gets weight 1/2 = 0.5, class 1 gets weight 1/1 = 1.0 >>> # After normalization: class 0 weight ≈ 0.33, class 1 weight ≈ 0.67 See Also -------- BlockWeightedSampling : For manual weight specification BlockShuffling : For unweighted sampling Notes ----- The computed weights ensure that each class has equal probability of being sampled, not that each class appears equally often in the final sample. The actual class distribution in samples will depend on the random sampling process and may vary between different runs. """
[docs] def __init__( self, labels: ArrayLike, block_size: int = 8, indices: Optional[ArrayLike] = None, total_size: Optional[int] = None, replace: bool = True, sampling_size: Optional[int] = None ): """ Initialize class-balanced sampling strategy. Parameters ---------- labels : array-like Class labels for each sample in the data collection. block_size : int, default=8 Size of blocks for shuffling. Must be positive. indices : array-like, optional Subset of indices to sample from. total_size : int, optional Total number of samples to generate. replace : bool, default=True Whether to sample with replacement. sampling_size : int, optional Required when replace=False. Size of each sampling round. Raises ------ ValueError If labels array is empty or block_size is not positive. """ # Store labels and validate basic properties self.labels = np.asarray(labels) if len(self.labels) == 0: raise ValueError("labels cannot be empty") weights = self._compute_class_weights() super().__init__( block_size=block_size, indices=indices, weights=weights, total_size=total_size, replace=replace, sampling_size=sampling_size )
def _compute_class_weights(self) -> NDArray[np.floating]: """ Compute balanced weights for each sample based on inverse class frequency. Computes weights that are inversely proportional to class frequency, ensuring that all classes have equal sampling probability regardless of their representation in the original dataset. Returns ------- numpy.ndarray Normalized weights for each sample. Samples from less frequent classes receive higher weights. Examples -------- >>> labels = np.array([0, 0, 1, 2, 2, 2]) >>> strategy = ClassBalancedSampling(labels) >>> weights = strategy._compute_class_weights() >>> # Class 0: 2 samples -> weight 1/2 = 0.5 per sample >>> # Class 1: 1 sample -> weight 1/1 = 1.0 per sample >>> # Class 2: 3 samples -> weight 1/3 ≈ 0.33 per sample >>> weights.round(3) array([0.5, 0.5, 1.0, 0.333, 0.333, 0.333]) Notes ----- The weights are not normalized to sum to 1.0, as this normalization is handled by the parent class :class:`BlockWeightedSampling`. """ unique_classes, class_counts = np.unique(self.labels, return_counts=True) # Compute inverse frequency weights class_weights = 1.0 / class_counts # Create sample weights by mapping class weights to each sample weights = np.zeros(len(self.labels)) for cls, weight in zip(unique_classes, class_weights): weights[self.labels == cls] = weight return weights