"""
Sampling strategies for on-disk data collections.
This module provides various sampling strategies for efficiently iterating through
large on-disk datasets. Each strategy defines how indices are generated and
ordered for data loading.
.. autosummary::
:toctree: generated/
SamplingStrategy
Streaming
BlockShuffling
BlockWeightedSampling
ClassBalancedSampling
"""
import warnings
from typing import Optional
import numpy as np
from numpy.typing import ArrayLike, NDArray
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "SamplingStrategy",
    "Streaming",
    "BlockShuffling",
    "BlockWeightedSampling",
    "ClassBalancedSampling",
]
class SamplingStrategy:
    """
    Abstract base class for sampling strategies.

    This class defines the interface that all sampling strategies must implement.
    Sampling strategies determine how indices are generated from a data collection
    for training or inference.

    Attributes
    ----------
    _shuffle_before_yield : bool or None
        Whether to shuffle indices before yielding batches. Set by subclasses.
    _indices : numpy.ndarray or None
        Stored subset of indices if provided. Always sorted for optimal I/O.

    Notes
    -----
    All subclasses must implement the :meth:`get_indices` method to define
    their specific sampling behavior.

    scDataset relies on sorted indices for efficient sequential I/O access
    patterns. When indices are provided to any strategy, they are automatically
    sorted to ensure optimal performance.

    Examples
    --------
    >>> # Custom sampling strategy
    >>> class CustomStrategy(SamplingStrategy):
    ...     def get_indices(self, data_collection, seed=None, rng=None):
    ...         return np.arange(len(data_collection))
    """

    def __init__(self, indices: Optional[ArrayLike] = None):
        """
        Initialize the sampling strategy.

        Parameters
        ----------
        indices : array-like, optional
            Subset of indices to use for sampling. If provided, they will be
            automatically sorted to ensure optimal I/O performance.

        Raises
        ------
        ValueError
            If ``indices`` is not one-dimensional.
        """
        self._shuffle_before_yield = None
        self._indices = self._validate_and_sort_indices(indices)

    def _validate_and_sort_indices(
        self, indices: Optional[ArrayLike]
    ) -> Optional[np.ndarray]:
        """
        Validate and sort indices for optimal I/O performance.

        Ensures the provided indices form a one-dimensional array, sorts them
        and emits a warning if reordering was necessary.

        Parameters
        ----------
        indices : array-like, optional
            Indices to validate and sort.

        Returns
        -------
        numpy.ndarray or None
            Sorted indices array, or None if no indices provided.

        Raises
        ------
        ValueError
            If ``indices`` is a scalar or has more than one dimension.

        Warns
        -----
        UserWarning
            If indices were not already sorted and had to be reordered.
        """
        if indices is None:
            return None

        indices = np.asarray(indices)
        if indices.ndim != 1:
            # A scalar would make np.sort fail with an axis error and an N-D
            # array would be sorted row by row; neither is a usable index list.
            raise ValueError(
                f"indices must be a 1-D array-like, got {indices.ndim} dimension(s)"
            )
        if indices.size == 0:
            # np.asarray([]) defaults to float64, which cannot be used for
            # indexing; give the empty selection an integer dtype instead.
            return indices.astype(np.intp)

        sorted_indices = np.sort(indices)
        if not np.array_equal(indices, sorted_indices):
            warnings.warn(
                "Provided indices were not sorted. They have been automatically "
                "sorted to ensure optimal I/O performance. scDataset relies on "
                "sorted indices for efficient data access patterns.",
                UserWarning,
                stacklevel=2,
            )
        return sorted_indices

    def get_indices(
        self,
        data_collection,
        seed: Optional[int] = None,
        rng: Optional[np.random.Generator] = None,
    ) -> NDArray[np.intp]:
        """
        Generate indices for sampling from the data collection.

        This is an abstract method that must be implemented by all subclasses.

        Parameters
        ----------
        data_collection : object
            The data collection to sample from.
        seed : int, optional
            Random seed for reproducible sampling. Ignored if ``rng`` is provided.
        rng : numpy.random.Generator, optional
            Random number generator to use. If provided, ``seed`` is ignored.

        Returns
        -------
        numpy.ndarray
            Array of indices to sample from the data collection.

        Raises
        ------
        NotImplementedError
            Always raised as this is an abstract method.
        """
        raise NotImplementedError("Subclasses must implement get_indices method")

    def _get_rng(
        self, seed: Optional[int] = None, rng: Optional[np.random.Generator] = None
    ) -> np.random.Generator:
        """
        Get a random number generator from seed or rng parameter.

        This helper method provides a consistent way to obtain a random number
        generator across all sampling strategies.

        Parameters
        ----------
        seed : int, optional
            Random seed to create a new generator. Ignored if ``rng`` is provided.
        rng : numpy.random.Generator, optional
            Existing random number generator to use. Takes precedence over ``seed``.

        Returns
        -------
        numpy.random.Generator
            Random number generator instance.

        Examples
        --------
        >>> strategy = SamplingStrategy()
        >>> rng = strategy._get_rng(seed=42)
        >>> isinstance(rng, np.random.Generator)
        True
        """
        if rng is not None:
            return rng
        return np.random.default_rng(seed)
class Streaming(SamplingStrategy):
    """
    Sequential streaming sampling strategy with optional buffer-level shuffling.

    This strategy provides indices in sequential order, with optional shuffling
    at the buffer level (defined by fetch_factor in scDataset). When shuffle=True,
    batches within each fetch buffer are shuffled, similar to Ray Dataset or
    WebDataset behavior, while maintaining overall sequential order across buffers.

    Parameters
    ----------
    indices : array-like, optional
        Subset of indices to use for sampling. If None, uses all indices
        from 0 to len(data_collection)-1.
    shuffle : bool, default=False
        Whether to shuffle batches within each fetch buffer. When True,
        enables buffer-level shuffling that maintains sequential order
        between buffers but randomizes the order of batches within each
        buffer (defined by fetch_factor * batch_size).

    Attributes
    ----------
    _shuffle_before_yield : bool
        Controlled by the shuffle parameter. True if buffer-level shuffling
        is enabled, False otherwise.
    _indices : numpy.ndarray or None
        Stored subset of indices if provided.
    shuffle : bool
        Whether buffer-level shuffling is enabled.

    Examples
    --------
    >>> # Stream through entire dataset without shuffling
    >>> strategy = Streaming()
    >>> indices = strategy.get_indices(range(100))
    >>> len(indices)
    100

    >>> # Stream through subset of indices
    >>> subset_strategy = Streaming(indices=[10, 20, 30])
    >>> indices = subset_strategy.get_indices(range(100))
    >>> indices.tolist()
    [10, 20, 30]

    >>> # Stream with buffer-level shuffling (like Ray Dataset/WebDataset)
    >>> shuffle_strategy = Streaming(shuffle=True)
    >>> # Batches within each fetch buffer will be shuffled,
    >>> # but buffers themselves maintain sequential order

    See Also
    --------
    BlockShuffling : For shuffled block-based sampling
    BlockWeightedSampling : For weighted sampling with shuffling

    Notes
    -----
    When shuffle=True, this strategy provides behavior similar to:

    - Ray Dataset's local shuffling within windows
    - WebDataset's shuffle buffer functionality

    The key difference from BlockShuffling is that Streaming maintains
    the overall sequential order of fetch buffers, only shuffling within
    each buffer, while BlockShuffling shuffles the order of blocks themselves.
    """

    def __init__(self, indices: Optional[ArrayLike] = None, shuffle: bool = False):
        """
        Initialize streaming strategy.

        Parameters
        ----------
        indices : array-like, optional
            Subset of indices to stream through. If None, streams through
            all available indices. Indices will be automatically sorted
            to ensure optimal I/O performance.
        shuffle : bool, default=False
            Whether to enable buffer-level shuffling. When True, batches
            within each fetch buffer are shuffled while maintaining
            sequential order between buffers.
        """
        super().__init__(indices=indices)
        self.shuffle = shuffle
        # The flag is only stored here; this class never shuffles indices itself.
        self._shuffle_before_yield = shuffle

    def get_len(self, data_collection) -> int:
        """
        Get the effective length of the data collection for this strategy.

        Parameters
        ----------
        data_collection : object
            The data collection to get length from. Must support ``len()``.

        Returns
        -------
        int
            Number of samples that will be yielded by this strategy.

        Examples
        --------
        >>> strategy = Streaming()
        >>> strategy.get_len(range(100))
        100
        >>> subset_strategy = Streaming(indices=[1, 3, 5])
        >>> subset_strategy.get_len(range(100))
        3
        """
        if self._indices is None:
            return len(data_collection)
        return len(self._indices)

    def get_indices(
        self,
        data_collection,
        seed: Optional[int] = None,
        rng: Optional[np.random.Generator] = None,
    ) -> NDArray[np.intp]:
        """
        Get indices for streaming sampling.

        Returns indices in sequential order. If shuffle=True was set during
        initialization, the _shuffle_before_yield attribute will cause
        buffer-level shuffling during iteration.

        Parameters
        ----------
        data_collection : object
            The data collection to sample from. Must support ``len()``.
        seed : int, optional
            Random seed. Not used for index generation, which remains
            sequential.
        rng : numpy.random.Generator, optional
            Random number generator. Not used for index generation.

        Returns
        -------
        numpy.ndarray
            Array of indices in sequential order. When a subset was provided,
            the stored (sorted) array itself is returned, not a copy.

        Examples
        --------
        >>> strategy = Streaming()
        >>> indices = strategy.get_indices(range(5))
        >>> indices.tolist()
        [0, 1, 2, 3, 4]

        >>> subset_strategy = Streaming(indices=[2, 4, 6])
        >>> indices = subset_strategy.get_indices(range(10))
        >>> indices.tolist()
        [2, 4, 6]

        >>> # With shuffle=True, indices are still sequential
        >>> shuffle_strategy = Streaming(shuffle=True)
        >>> indices = shuffle_strategy.get_indices(range(5))
        >>> indices.tolist()
        [0, 1, 2, 3, 4]
        """
        if self._indices is None:
            return np.arange(len(data_collection))
        return self._indices
class BlockShuffling(SamplingStrategy):
    """
    Block-based shuffling sampling strategy.

    This strategy divides the data into blocks of fixed size and shuffles
    the order of blocks while maintaining the original order within each block.
    This provides a balance between randomization and maintaining some locality
    of data access patterns.

    Parameters
    ----------
    block_size : int, default=8
        Size of each block for shuffling. Larger blocks maintain more locality
        but provide less randomization.
    indices : array-like, optional
        Subset of indices to use for sampling. If None, uses all indices
        from 0 to len(data_collection)-1.
    drop_last : bool, default=False
        Whether to drop elements so that only complete blocks remain when the
        total number of indices is not divisible by block_size.

    Attributes
    ----------
    _shuffle_before_yield : bool
        Always True for block shuffling strategy.
    _indices : numpy.ndarray or None
        Stored subset of indices if provided.
    block_size : int
        Size of blocks for shuffling.
    drop_last : bool
        Whether to drop incomplete blocks.

    Notes
    -----
    When ``drop_last=False`` and there's a remainder block smaller than
    ``block_size``, it's inserted at a random position among the shuffled
    complete blocks.

    Examples
    --------
    >>> # Basic block shuffling
    >>> strategy = BlockShuffling(block_size=3)
    >>> indices = strategy.get_indices(range(10), seed=42)
    >>> len(indices)
    10

    >>> # Drop incomplete blocks
    >>> strategy = BlockShuffling(block_size=3, drop_last=True)
    >>> indices = strategy.get_indices(range(10), seed=42)
    >>> len(indices)  # 10 // 3 * 3 = 9
    9

    See Also
    --------
    Streaming : For sequential sampling without shuffling
    BlockWeightedSampling : For weighted block-based sampling
    """

    def __init__(
        self,
        block_size: int = 8,
        indices: Optional[ArrayLike] = None,
        drop_last: bool = False,
    ):
        """
        Initialize block shuffling strategy.

        Parameters
        ----------
        block_size : int, default=8
            Size of blocks for shuffling. Must be positive.
        indices : array-like, optional
            Subset of indices to sample from. Indices will be automatically
            sorted to ensure optimal I/O performance.
        drop_last : bool, default=False
            Whether to drop the elements that do not fit into complete blocks.

        Raises
        ------
        ValueError
            If block_size is not positive.
        """
        super().__init__(indices=indices)
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        self._shuffle_before_yield = True
        self.block_size = block_size
        self.drop_last = drop_last

    def get_len(self, data_collection) -> int:
        """
        Get the effective length of the data collection for this strategy.

        Takes into account the drop_last setting when calculating the
        effective length.

        Parameters
        ----------
        data_collection : object
            The data collection to get length from. Must support ``len()``.

        Returns
        -------
        int
            Number of samples that will be yielded by this strategy.

        Examples
        --------
        >>> strategy = BlockShuffling(block_size=3, drop_last=False)
        >>> strategy.get_len(range(10))
        10
        >>> strategy = BlockShuffling(block_size=3, drop_last=True)
        >>> strategy.get_len(range(10))  # 10 - (10 % 3) = 9
        9
        """
        if self._indices is None:
            length = len(data_collection)
        else:
            length = len(self._indices)
        if self.drop_last:
            length -= length % self.block_size
        return length

    def get_indices(
        self,
        data_collection,
        seed: Optional[int] = None,
        rng: Optional[np.random.Generator] = None,
    ) -> NDArray[np.intp]:
        """
        Generate indices with block-based shuffling.

        Divides indices into blocks and shuffles the order of complete blocks.
        Leftover elements are either dropped or inserted as a short block at a
        random block boundary, depending on the ``drop_last`` setting.

        Parameters
        ----------
        data_collection : object
            The data collection to sample from. Must support ``len()``.
        seed : int, optional
            Random seed for reproducible shuffling. Ignored if ``rng`` is provided.
        rng : numpy.random.Generator, optional
            Random number generator to use for shuffling. If provided, ``seed`` is ignored.

        Returns
        -------
        numpy.ndarray
            Array of indices with blocks shuffled.

        Notes
        -----
        When ``drop_last=True`` and the number of indices is not a multiple of
        ``block_size``, as many elements as the remainder are removed at
        randomly chosen positions (not necessarily the trailing ones) before
        the blocks are formed.

        When ``drop_last=False``, the trailing remainder elements are inserted
        at a random position among the shuffled complete blocks.

        Examples
        --------
        >>> strategy = BlockShuffling(block_size=2, drop_last=False)
        >>> indices = strategy.get_indices(range(5), seed=42)
        >>> len(indices)
        5

        >>> strategy = BlockShuffling(block_size=2, drop_last=True)
        >>> indices = strategy.get_indices(range(5), seed=42)
        >>> len(indices)  # One element is dropped to leave complete blocks
        4
        """
        if self._indices is None:
            indices = np.arange(len(data_collection))
        else:
            indices = self._indices

        rng_obj = self._get_rng(seed, rng)
        n = len(indices)
        n_blocks = n // self.block_size
        n_complete = n_blocks * self.block_size
        remainder = n - n_complete

        if self.drop_last and remainder > 0:
            # Pick the *positions* to drop rather than the values: removing by
            # value would discard every occurrence of a repeated index and
            # leave fewer than n_complete elements, breaking the reshape below.
            drop_positions = rng_obj.choice(n, size=remainder, replace=False)
            keep = np.ones(n, dtype=bool)
            keep[drop_positions] = False
            complete_part = indices[keep]
        else:
            complete_part = indices[:n_complete]

        # Shuffle whole blocks; the order inside each block is preserved.
        blocks = complete_part.reshape(n_blocks, self.block_size)
        blocks = rng_obj.permutation(blocks, axis=0)

        if self.drop_last or remainder == 0:
            return blocks.reshape(-1)

        # Insert the short remainder block at a random block boundary.
        insert_pos = rng_obj.integers(0, n_blocks + 1)
        before = blocks[:insert_pos].reshape(-1)
        after = blocks[insert_pos:].reshape(-1)
        return np.concatenate([before, indices[n_complete:], after])
class BlockWeightedSampling(SamplingStrategy):
    """
    Weighted sampling with block-based shuffling.

    This strategy performs weighted sampling from the data collection and then
    applies block-based shuffling to the sampled indices. It supports both
    sampling with and without replacement, and can generate a different total
    number of samples than the original data collection size.

    Parameters
    ----------
    block_size : int, default=8
        Size of blocks for shuffling after weighted sampling.
    indices : array-like, optional
        Subset of indices to sample from. If None, uses all indices.
    weights : array-like, optional
        Sampling weights, one per element of the data collection or one per
        element of ``indices``. Must be finite, non-negative and sum to a
        positive value. If None, uses uniform sampling.
    total_size : int, optional
        Total number of samples to draw. If None, uses the length of
        indices or data_collection.
    replace : bool, default=True
        Whether to sample with replacement.
    sampling_size : int, optional
        Size of each sampling round when ``replace=False``.
        Required (and must be positive) when ``replace=False``.

    Attributes
    ----------
    _shuffle_before_yield : bool
        Always True for weighted sampling strategy.
    _indices : numpy.ndarray or None
        Stored subset of indices if provided.
    block_size : int
        Size of blocks for shuffling.
    weights : numpy.ndarray or None
        Normalized sampling weights.
    total_size : int or None
        Total number of samples to generate.
    replace : bool
        Whether sampling is with replacement.
    sampling_size : int or None
        Size of each sampling round for replacement=False.

    Raises
    ------
    ValueError
        If weights are negative, non-finite, sum to zero, or don't match the
        data collection length. If sampling_size is not provided or not
        positive when replace=False. If total_size is negative.

    Warns
    -----
    UserWarning
        If sampling_size is provided when replace=True (it will be ignored).

    Examples
    --------
    >>> # Uniform weighted sampling
    >>> strategy = BlockWeightedSampling(block_size=2, total_size=6)
    >>> indices = strategy.get_indices(range(4), seed=42)
    >>> len(indices)
    6

    >>> # Custom weights favoring certain indices
    >>> weights = [0.1, 0.1, 0.4, 0.4]  # Favor indices 2 and 3
    >>> strategy = BlockWeightedSampling(weights=weights, total_size=8)
    >>> indices = strategy.get_indices(range(4), seed=42)
    >>> len(indices)
    8

    >>> # Sampling without replacement
    >>> strategy = BlockWeightedSampling(
    ...     total_size=10, replace=False, sampling_size=5
    ... )

    See Also
    --------
    BlockShuffling : For unweighted block-based shuffling
    ClassBalancedSampling : For automatic class-balanced sampling
    """

    def __init__(
        self,
        block_size: int = 8,
        indices: Optional[ArrayLike] = None,
        weights: Optional[ArrayLike] = None,
        total_size: Optional[int] = None,
        replace: bool = True,
        sampling_size: Optional[int] = None,
    ):
        """
        Initialize weighted sampling strategy.

        Parameters
        ----------
        block_size : int, default=8
            Size of blocks for shuffling. Must be positive.
        indices : array-like, optional
            Subset of indices to sample from. Indices will be automatically
            sorted to ensure optimal I/O performance.
        weights : array-like, optional
            Sampling weights. Will be normalized automatically.
        total_size : int, optional
            Total number of samples to generate. Must be non-negative.
        replace : bool, default=True
            Whether to sample with replacement.
        sampling_size : int, optional
            Required when replace=False. Size of each sampling round.

        Raises
        ------
        ValueError
            If block_size is not positive, total_size is negative, weights
            are invalid, or sampling_size is missing or not positive when
            replace=False.
        """
        super().__init__(indices=indices)
        if block_size <= 0:
            raise ValueError("block_size must be positive")
        if total_size is not None and total_size < 0:
            raise ValueError("total_size must be non-negative")

        self._shuffle_before_yield = True
        self.block_size = block_size

        if weights is not None:
            weights = np.asarray(weights, dtype=float)
            if weights.ndim != 1:
                raise ValueError("weights must be a 1-D array-like")
            # NaN compares False against 0, so it has to be rejected explicitly.
            if not np.all(np.isfinite(weights)):
                raise ValueError("weights must be finite")
            if np.any(weights < 0):
                raise ValueError("weights must be non-negative")
            weight_sum = np.sum(weights)
            if weight_sum == 0:
                raise ValueError("weights must sum to a positive value")
            weights = weights / weight_sum

        self.weights = weights
        self.total_size = total_size
        self.replace = replace

        if not replace:
            if sampling_size is None:
                raise ValueError("sampling_size must be provided when replace=False")
            if sampling_size <= 0:
                # A round of size zero would never make progress in get_indices.
                raise ValueError("sampling_size must be positive when replace=False")
        elif sampling_size is not None:
            warnings.warn(
                "sampling_size is ignored when replace=True, since it will sample with replacement",
                UserWarning,
                stacklevel=2,
            )
        self.sampling_size = sampling_size

    def get_len(self, data_collection) -> int:
        """
        Get the effective length of the data collection for this strategy.

        Returns the total number of samples that will be generated,
        which may be different from the original data collection size.

        Parameters
        ----------
        data_collection : object
            The data collection to get length from. Must support ``len()``.

        Returns
        -------
        int
            Number of samples that will be yielded by this strategy.

        Examples
        --------
        >>> strategy = BlockWeightedSampling(total_size=100)
        >>> strategy.get_len(range(50))  # Returns total_size
        100
        >>> strategy = BlockWeightedSampling()  # No total_size specified
        >>> strategy.get_len(range(50))  # Returns collection length
        50
        """
        if self.total_size is not None:
            return self.total_size
        if self._indices is None:
            return len(data_collection)
        return len(self._indices)

    def _resolve_weights(self, data_collection) -> Optional[np.ndarray]:
        """Return probabilities aligned with the sampled population, or None for uniform."""
        if self.weights is None:
            return None

        n_weights = len(self.weights)
        if self._indices is None:
            # No indices - weights must match data_collection
            if n_weights != len(data_collection):
                raise ValueError(
                    "weights must have the same length as data_collection"
                )
            selected = self.weights
        elif n_weights == len(self._indices):
            # Weights already correspond one-to-one to the stored indices.
            selected = self.weights
        elif n_weights == len(data_collection):
            # Full-length weights: keep only the entries of the chosen subset.
            selected = self.weights[self._indices]
        else:
            raise ValueError(
                f"weights length ({n_weights}) must match either "
                f"data_collection length ({len(data_collection)}) or "
                f"indices length ({len(self._indices)})"
            )

        total = selected.sum()
        if not total > 0:
            # Happens when every selected element carries zero weight; dividing
            # would produce NaN probabilities.
            raise ValueError(
                "weights of the selected indices must sum to a positive value"
            )
        return selected / total

    def _shuffle_blocks(
        self, sorted_indices: np.ndarray, rng_obj: np.random.Generator
    ) -> np.ndarray:
        """Shuffle complete blocks of ``sorted_indices``; place the remainder at a random boundary."""
        n_blocks = len(sorted_indices) // self.block_size
        n_complete = n_blocks * self.block_size
        remainder_part = sorted_indices[n_complete:]

        blocks = sorted_indices[:n_complete].reshape(n_blocks, self.block_size)
        blocks = rng_obj.permutation(blocks, axis=0)
        if len(remainder_part) == 0:
            return blocks.reshape(-1)

        insert_pos = rng_obj.integers(0, n_blocks + 1)
        before = blocks[:insert_pos].reshape(-1)
        after = blocks[insert_pos:].reshape(-1)
        return np.concatenate([before, remainder_part, after])

    def get_indices(
        self,
        data_collection,
        seed: Optional[int] = None,
        rng: Optional[np.random.Generator] = None,
    ) -> NDArray[np.intp]:
        """
        Generate indices using weighted sampling followed by block shuffling.

        First performs weighted sampling (with or without replacement) to select
        indices, sorts them, then applies block-based shuffling to the result.

        Parameters
        ----------
        data_collection : object
            The data collection to sample from. Must support ``len()``.
        seed : int, optional
            Random seed for reproducible sampling. Ignored if ``rng`` is provided.
        rng : numpy.random.Generator, optional
            Random number generator to use. If provided, ``seed`` is ignored.

        Returns
        -------
        numpy.ndarray
            Array of indices after weighted sampling and block shuffling.

        Raises
        ------
        ValueError
            If weights match neither the data collection length nor the
            indices length, or the weights of the selected indices sum to zero.

        Notes
        -----
        When ``replace=False``, sampling is performed in rounds of size
        ``sampling_size`` until the desired number of samples is reached.
        Each round is drawn without replacement, but the same index may
        appear in several rounds.

        A single random generator is used for both the sampling and the block
        shuffling, so the shuffle continues the random stream instead of
        restarting it from ``seed``.

        Examples
        --------
        >>> # Weighted sampling with replacement
        >>> weights = [0.25, 0.25, 0.25, 0.25]  # Uniform weights
        >>> strategy = BlockWeightedSampling(
        ...     weights=weights, total_size=8, block_size=2
        ... )
        >>> indices = strategy.get_indices(range(4), seed=42)
        >>> len(indices)
        8

        >>> # Sampling without replacement
        >>> strategy = BlockWeightedSampling(
        ...     total_size=6, replace=False, sampling_size=3, block_size=2
        ... )
        >>> indices = strategy.get_indices(range(10), seed=42)
        >>> len(indices)
        6
        """
        probabilities = self._resolve_weights(data_collection)

        if self._indices is None:
            population = np.arange(len(data_collection))
        else:
            population = self._indices

        rng_obj = self._get_rng(seed, rng)
        # Same fallback as get_len: without total_size, draw one sample per
        # available index (this also covers replace=False).
        size = self.total_size if self.total_size is not None else len(population)

        if self.replace:
            sampled = rng_obj.choice(
                population, size=size, replace=True, p=probabilities
            )
        else:
            rounds = []
            drawn = 0
            while drawn < size:
                current_size = min(self.sampling_size, size - drawn)
                rounds.append(
                    rng_obj.choice(
                        population,
                        size=current_size,
                        replace=False,
                        p=probabilities,
                    )
                )
                drawn += current_size
            if rounds:
                sampled = np.concatenate(rounds)
            else:
                # size == 0: np.concatenate rejects an empty list.
                sampled = np.empty(0, dtype=population.dtype)

        # Sort so that every block covers nearby positions before shuffling.
        sampled.sort()
        return self._shuffle_blocks(sampled, rng_obj)
class ClassBalancedSampling(BlockWeightedSampling):
    """
    Class-balanced sampling with automatic weight computation.

    This strategy extends :class:`BlockWeightedSampling` by automatically computing
    balanced weights from provided labels, making each class equally likely to be
    sampled regardless of the original class distribution in the dataset.

    The weights are computed as the inverse of class frequencies, ensuring that
    underrepresented classes get higher sampling probability and overrepresented
    classes get lower sampling probability.

    **Dual Behavior for Labels:**

    The strategy supports two modes based on the labels array length:

    1. **Global class balancing** (labels length = full dataset):
       Weights are computed from the full dataset's class distribution. When
       sampling from a subset (via ``indices``), samples are weighted according
       to their importance in the global distribution, not the subset.

    2. **Subset class balancing** (labels length = indices length):
       Weights are computed only from the labels of the subset indices.
       This balances classes within the subset, ignoring the global distribution.

    Parameters
    ----------
    labels : array-like
        Class labels for each sample. The length determines the balancing mode:

        - If ``len(labels) == len(data_collection)``: Global balancing mode.
          Weights computed from full dataset, then applied to subset.
        - If ``len(labels) == len(indices)``: Subset balancing mode.
          Weights computed only from the subset's labels.
    block_size : int, default=8
        Size of blocks for block shuffling after sampling.
    indices : array-like, optional
        Subset of indices to sample from. If None, uses all indices.
    total_size : int, optional
        Total number of samples to draw. If None, uses the length of indices
        or data_collection.
    replace : bool, default=True
        Whether to sample with replacement.
    sampling_size : int, optional
        Size of each sampling round when ``replace=False``. Required when
        ``replace=False``.

    Attributes
    ----------
    labels : numpy.ndarray
        Array of class labels for each sample.

    Raises
    ------
    ValueError
        If labels array is empty or not one-dimensional.

    Examples
    --------
    Global balancing - balance for full dataset distribution:

    >>> # Full dataset: 90% class 0, 10% class 1
    >>> full_labels = [0]*90 + [1]*10  # 100 samples total
    >>> subset_indices = [0, 1, 90, 91, 92, 93, 94, 95, 96, 97]  # 2 of class 0, 8 of class 1
    >>>
    >>> # Global balancing: uses full dataset weights
    >>> strategy = ClassBalancedSampling(full_labels, indices=subset_indices, total_size=20)
    >>> # Class 1 samples get ~9x higher weight (because 1/10 vs 1/90 in global dist)
    >>> # Even though subset is 80% class 1, global weights still favor class 1

    Subset balancing - balance within the subset only:

    >>> # Only provide labels for the subset indices
    >>> subset_labels = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]  # Labels for subset: 20% class 0, 80% class 1
    >>> strategy = ClassBalancedSampling(subset_labels, indices=subset_indices, total_size=20)
    >>> # Now class 0 samples get 4x higher weight (because 1/2 vs 1/8 in subset dist)
    >>> # This balances within the subset, ignoring global distribution

    See Also
    --------
    BlockWeightedSampling : For manual weight specification
    BlockShuffling : For unweighted sampling

    Notes
    -----
    The computed weights ensure that each class has equal probability of being
    sampled, not that each class appears equally often in the final sample.
    The actual class distribution in samples will depend on the random sampling
    process and may vary between different runs.

    When using global balancing with a subset that has different class proportions
    than the full dataset, the output may appear imbalanced relative to the subset.
    This is intentional - the weights reflect global importance.
    """

    def __init__(
        self,
        labels: ArrayLike,
        block_size: int = 8,
        indices: Optional[ArrayLike] = None,
        total_size: Optional[int] = None,
        replace: bool = True,
        sampling_size: Optional[int] = None,
    ):
        """
        Initialize class-balanced sampling strategy.

        Parameters
        ----------
        labels : array-like
            Class labels for samples. The length of labels determines the
            balancing mode (see class docstring for details):

            - If ``len(labels) == len(indices)``: **subset balancing** mode.
              Labels correspond to the subset samples only.
            - If ``len(labels) > len(indices)``: **global balancing** mode.
              Labels correspond to the full dataset.
        block_size : int, default=8
            Size of blocks for shuffling. Must be positive.
        indices : array-like, optional
            Subset of indices to sample from.
        total_size : int, optional
            Total number of samples to generate.
        replace : bool, default=True
            Whether to sample with replacement.
        sampling_size : int, optional
            Required when replace=False. Size of each sampling round.

        Raises
        ------
        ValueError
            If labels is not one-dimensional or is empty, block_size is not
            positive, or labels is shorter than indices.
        """
        self.labels = np.asarray(labels)
        if self.labels.ndim != 1:
            # One label per sample is required; a scalar has no length and an
            # N-D array cannot be mapped onto a flat weight vector.
            raise ValueError("labels must be a 1-D array-like")
        if len(self.labels) == 0:
            raise ValueError("labels cannot be empty")

        # Unsorted copy of the requested indices, used only for the length check.
        self._init_indices = np.asarray(indices) if indices is not None else None

        # - Subset mode: len(labels) == len(indices) - balance within subset
        # - Global mode: len(labels) > len(indices) - use global class frequencies
        if self._init_indices is not None and len(self.labels) < len(
            self._init_indices
        ):
            raise ValueError(
                f"labels length ({len(self.labels)}) must be either equal to "
                f"indices length ({len(self._init_indices)}) for subset balancing, "
                f"or greater than indices length for global balancing"
            )

        weights = self._compute_class_weights()
        super().__init__(
            block_size=block_size,
            indices=indices,
            weights=weights,
            total_size=total_size,
            replace=replace,
            sampling_size=sampling_size,
        )

    def _compute_class_weights(self) -> NDArray[np.floating]:
        """
        Compute balanced weights for each sample based on inverse class frequency.

        Every sample receives ``1 / (number of samples sharing its label)``,
        counted over the whole ``labels`` array. Whether that corresponds to
        subset or global balancing depends only on which labels were supplied.

        Returns
        -------
        numpy.ndarray
            Weights array with one entry per label.

        Notes
        -----
        This method is called internally during ``__init__`` to compute
        weights before they are passed to the parent class.

        The weights are not normalized to sum to 1.0 here; the parent
        constructor receives them and normalizes them.

        Example (Subset Mode):
            For labels ``[0, 0, 1, 2, 2, 2]`` (subset labels only):

            - Class 0: 2 samples -> weight 1/2 = 0.5 per sample
            - Class 1: 1 sample -> weight 1/1 = 1.0 per sample
            - Class 2: 3 samples -> weight 1/3 ≈ 0.33 per sample

            Result: ``array([0.5, 0.5, 1.0, 0.333, 0.333, 0.333])``

        Example (Global Mode):
            If full dataset labels are ``[0, 0, 0, 0, 1]`` (80% class 0, 20% class 1)
            and indices select samples 3, 4 (classes 0, 1):

            - Class 0: 4 samples globally -> weight 1/4 = 0.25
            - Class 1: 1 sample globally -> weight 1/1 = 1.0

            Subset sample 3 (class 0) gets weight 0.25,
            Subset sample 4 (class 1) gets weight 1.0.
            Class 1 samples are 4x more likely to be sampled despite
            being 50% of the subset.
        """
        # ``inverse`` maps every sample to the position of its class in the
        # unique-class list, so one fancy-indexing step replaces a per-class
        # boolean-mask loop.
        _, inverse, class_counts = np.unique(
            self.labels, return_inverse=True, return_counts=True
        )
        class_weights = 1.0 / class_counts
        return class_weights[inverse]