scdataset.strategy.ClassBalancedSampling
- class scdataset.strategy.ClassBalancedSampling(labels: _Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], block_size: int = 8, indices: _Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | None = None, total_size: int | None = None, replace: bool = True, sampling_size: int | None = None)
Bases: BlockWeightedSampling
Class-balanced sampling with automatic weight computation.
This strategy extends BlockWeightedSampling by automatically computing balanced weights from the provided labels, making each class equally likely to be sampled regardless of the original class distribution in the dataset.
The weights are computed as the inverse of the class frequencies, so underrepresented classes get a higher sampling probability and overrepresented classes get a lower one.
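The inverse-frequency idea described above can be sketched with plain NumPy. This is a minimal illustration, not scdataset's implementation: the helper name `inverse_frequency_weights` is hypothetical, and the per-sample normalization shown here is an assumption that may differ from how the library normalizes internally.

```python
import numpy as np

def inverse_frequency_weights(labels):
    """Hypothetical helper: per-sample weights proportional to 1 / class count."""
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("labels must not be empty")
    # counts[inverse[i]] is the number of samples that share sample i's class
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    # Normalize so the per-sample weights form a probability distribution
    return weights / weights.sum()

labels = [0, 0, 0, 0, 1, 1, 2]  # imbalanced: 4, 2, and 1 samples per class
weights = inverse_frequency_weights(labels)

# Summing the per-sample weights within each class gives the probability
# that a single draw lands in that class: 1/3 for each of the three classes.
class_mass = np.bincount(np.asarray(labels), weights=weights)
print(class_mass)
```

Samples from the rare class 2 each carry four times the weight of a class 0 sample, which is what equalizes the per-class probability.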
- Parameters:
labels (array-like) – Class labels for each sample. Can be a NumPy array, a pandas Series, or any other array-like object. Must have the same length as the data collection (validated during get_indices).
block_size (int, default=8) – Size of blocks for block shuffling after sampling.
indices (array-like, optional) – Subset of indices to sample from. If None, uses all indices.
total_size (int, optional) – Total number of samples to draw. If None, uses the length of indices or data_collection.
replace (bool, default=True) – Whether to sample with replacement.
sampling_size (int, optional) – Size of each sampling round when replace=False. Required when replace=False.
- Variables:
labels (numpy.ndarray) – Array of class labels for each sample.
- Raises:
ValueError – If labels array is empty.
Examples
>>> # Balanced sampling from imbalanced dataset
>>> labels = [0, 0, 0, 0, 1, 1, 2]  # Imbalanced classes
>>> strategy = ClassBalancedSampling(labels, total_size=12)
>>> indices = strategy.get_indices(range(7), seed=42)
>>> len(indices)
12

>>> # Class weights are automatically computed
>>> strategy = ClassBalancedSampling([0, 0, 1])
>>> # Class 0 appears twice, class 1 once
>>> # So class 0 gets weight 1/2 = 0.5, class 1 gets weight 1/1 = 1.0
>>> # After normalization: class 0 weight ≈ 0.33, class 1 weight ≈ 0.67
See also
BlockWeightedSampling – For manual weight specification
BlockShuffling – For unweighted sampling
Notes
The computed weights ensure that each class has equal probability of being sampled, not that each class appears equally often in the final sample. The actual class distribution in samples will depend on the random sampling process and may vary between different runs.
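The distinction in this note can be checked with a small NumPy simulation. It is a sketch under stated assumptions: it draws with `numpy.random.Generator.choice` using inverse-frequency weights and does not call scdataset, and `class_counts` is a hypothetical helper.

```python
import numpy as np

labels = np.array([0, 0, 0, 0, 1, 1, 2])

# Inverse-frequency per-sample weights, normalized to sum to 1
_, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
weights = 1.0 / counts[inverse]
weights /= weights.sum()

def class_counts(seed, total_size=12):
    """Hypothetical helper: draw with replacement and count samples per class."""
    rng = np.random.default_rng(seed)
    drawn = rng.choice(len(labels), size=total_size, replace=True, p=weights)
    return np.bincount(labels[drawn], minlength=3)

# Small samples: per-class counts generally differ from seed to seed
counts_per_seed = [class_counts(seed) for seed in range(3)]
for seed, c in enumerate(counts_per_seed):
    print(seed, c)

# Large sample: class proportions approach 1/3 each
long_run = class_counts(seed=0, total_size=30000) / 30000
print(long_run)
```

With only 12 draws the per-class counts fluctuate, while over many draws the class proportions converge toward equal shares.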
Methods
__init__(labels[, block_size, indices, ...]) – Initialize class-balanced sampling strategy.
get_indices(data_collection[, seed, rng]) – Generate indices using weighted sampling followed by block shuffling.
get_len(data_collection) – Get the effective length of the data collection for this strategy.
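As a rough illustration of "weighted sampling followed by block shuffling", the sketch below shows one plausible reading of that pipeline in NumPy. It is not scdataset's code: the function name `weighted_then_block_shuffle` is hypothetical, and the sorting step and the choice to shuffle whole blocks while keeping order within each block are assumptions.

```python
import numpy as np

def weighted_then_block_shuffle(weights, total_size, block_size=8, seed=None):
    """Hypothetical sketch; the library's actual procedure may differ."""
    rng = np.random.default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    p = weights / weights.sum()
    # Step 1: weighted sampling with replacement
    sampled = rng.choice(len(p), size=total_size, replace=True, p=p)
    # Step 2 (assumption): sort so each block covers nearby positions
    sampled.sort()
    # Step 3: split into consecutive blocks and shuffle only the block order
    starts = range(0, total_size, block_size)
    blocks = [sampled[s:s + block_size] for s in starts]
    order = rng.permutation(len(blocks))
    return np.concatenate([blocks[i] for i in order])

# Inverse-frequency weights for labels [0, 0, 0, 0, 1, 1, 2]
sample_weights = [0.25] * 4 + [0.5] * 2 + [1.0]
out = weighted_then_block_shuffle(sample_weights, total_size=12, block_size=4, seed=42)
print(out)
```

Under this reading, indices stay ordered inside each block, which tends to favor contiguous reads from the underlying data, while the block order is randomized.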
- __init__(labels: _Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], block_size: int = 8, indices: _Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | None = None, total_size: int | None = None, replace: bool = True, sampling_size: int | None = None)
Initialize class-balanced sampling strategy.
- Parameters:
labels (array-like) – Class labels for each sample in the data collection.
block_size (int, default=8) – Size of blocks for shuffling. Must be positive.
indices (array-like, optional) – Subset of indices to sample from.
total_size (int, optional) – Total number of samples to generate.
replace (bool, default=True) – Whether to sample with replacement.
sampling_size (int, optional) – Required when replace=False. Size of each sampling round.
- Raises:
ValueError – If labels array is empty or block_size is not positive.