scdataset.strategy.BlockWeightedSampling
- class scdataset.strategy.BlockWeightedSampling(block_size: int = 8, indices: ArrayLike | None = None, weights: ArrayLike | None = None, total_size: int | None = None, replace: bool = True, sampling_size: int | None = None)
Bases: SamplingStrategy
Weighted sampling with block-based shuffling.
This strategy performs weighted sampling from the data collection and then applies block-based shuffling to the sampled indices. It supports both sampling with and without replacement, and can generate a different total number of samples than the original data collection size.
- Parameters:
block_size (int, default=8) – Size of blocks for shuffling after weighted sampling.
indices (array-like, optional) – Subset of indices to sample from. If None, uses all indices.
weights (array-like, optional) – Sampling weights for each element in the data collection. Must be non-negative and sum to a positive value. If None, uses uniform sampling.
total_size (int, optional) – Total number of samples to draw. If None, uses the length of indices or data_collection.
replace (bool, default=True) – Whether to sample with replacement.
sampling_size (int, optional) – Size of each sampling round when replace=False. Required when replace=False.
- Variables:
_shuffle_before_yield (bool) – Always True for weighted sampling strategy.
_indices (numpy.ndarray or None) – Stored subset of indices if provided.
block_size (int) – Size of blocks for shuffling.
weights (numpy.ndarray or None) – Normalized sampling weights.
total_size (int or None) – Total number of samples to generate.
replace (bool) – Whether sampling is with replacement.
sampling_size (int or None) – Size of each sampling round when replace=False.
- Raises:
ValueError – If weights are negative, sum to zero, or don’t match data collection length. If sampling_size is not provided when replace=False.
- Warns:
UserWarning – If sampling_size is provided when replace=True (it will be ignored).
Examples
>>> # Uniform weighted sampling
>>> strategy = BlockWeightedSampling(block_size=2, total_size=6)
>>> indices = strategy.get_indices(range(4), seed=42)
>>> len(indices)
6

>>> # Custom weights favoring certain indices
>>> weights = [0.1, 0.1, 0.4, 0.4]  # Favor indices 2 and 3
>>> strategy = BlockWeightedSampling(weights=weights, total_size=8)
>>> indices = strategy.get_indices(range(4), seed=42)
>>> len(indices)
8

>>> # Sampling without replacement
>>> strategy = BlockWeightedSampling(
...     total_size=10, replace=False, sampling_size=5
... )
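The two stages named in the class description, weighted sampling followed by block shuffling, can be illustrated with a small dependency-free sketch. This is not the library's implementation: weighted_block_sample is a hypothetical helper, it uses the standard-library random module instead of NumPy, and it assumes that block shuffling means permuting contiguous blocks of the sorted sample while keeping the order inside each block.

```python
import random


def weighted_block_sample(n_items, weights, total_size, block_size, seed=None):
    """Hypothetical sketch: weighted draw with replacement, sort, shuffle blocks."""
    rng = random.Random(seed)
    # Weighted draw with replacement; random.choices accepts unnormalized weights.
    drawn = rng.choices(range(n_items), weights=weights, k=total_size)
    # Sort so that each block groups neighbouring indices.
    drawn.sort()
    # Cut the sorted sample into contiguous blocks (assumed meaning of "block").
    blocks = [drawn[i:i + block_size] for i in range(0, len(drawn), block_size)]
    # Permute the order of the blocks; order inside a block is untouched.
    rng.shuffle(blocks)
    return [idx for block in blocks for idx in block]


sample = weighted_block_sample(
    4, [0.1, 0.1, 0.4, 0.4], total_size=8, block_size=2, seed=42
)
print(len(sample))
```

The sample has total_size entries even though the collection holds only four items, which is possible because the draw is made with replacement.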
See also
BlockShuffling – For unweighted block-based shuffling
ClassBalancedSampling – For automatic class-balanced sampling
Methods
__init__([block_size, indices, weights, ...]) – Initialize weighted sampling strategy.
get_indices(data_collection[, seed, rng]) – Generate indices using weighted sampling followed by block shuffling.
get_len(data_collection) – Get the effective length of the data collection for this strategy.
- __init__(block_size: int = 8, indices: ArrayLike | None = None, weights: ArrayLike | None = None, total_size: int | None = None, replace: bool = True, sampling_size: int | None = None)
Initialize weighted sampling strategy.
- Parameters:
block_size (int, default=8) – Size of blocks for shuffling. Must be positive.
indices (array-like, optional) – Subset of indices to sample from.
weights (array-like, optional) – Sampling weights. Will be normalized automatically.
total_size (int, optional) – Total number of samples to generate.
replace (bool, default=True) – Whether to sample with replacement.
sampling_size (int, optional) – Required when replace=False. Size of each sampling round.
- Raises:
ValueError – If block_size is not positive, weights are invalid, or sampling_size is missing when replace=False.
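The weight rules stated in this reference (non-negative, positive sum, one weight per element, normalized automatically) can be sketched as a standalone check. normalize_weights is a hypothetical helper written for illustration, not a function of scdataset, and the error messages are invented.

```python
def normalize_weights(weights, n_items):
    """Hypothetical sketch of the documented weight checks and normalization."""
    weights = [float(w) for w in weights]
    # The length can only be compared once the data collection is known.
    if len(weights) != n_items:
        raise ValueError("weights must have one entry per element")
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    # Normalize so the weights form a probability distribution.
    return [w / total for w in weights]


probabilities = normalize_weights([1, 1, 2], n_items=3)
print(probabilities)
```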
- get_len(data_collection) → int
Get the effective length of the data collection for this strategy.
Returns the total number of samples that will be generated, which may be different from the original data collection size.
- Parameters:
data_collection (object) – The data collection to get length from. Must support len().
- Returns:
Number of samples that will be yielded by this strategy.
- Return type:
int
Examples
>>> strategy = BlockWeightedSampling(total_size=100)
>>> strategy.get_len(range(50))  # Returns total_size
100

>>> strategy = BlockWeightedSampling()  # No total_size specified
>>> strategy.get_len(range(50))  # Returns collection length
50
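The fallback described for total_size ("If None, uses the length of indices or data_collection") can be written out as a short sketch. effective_len is a hypothetical function, and the precedence of indices over data_collection is an assumption inferred from that sentence, not confirmed behavior.

```python
def effective_len(data_collection, total_size=None, indices=None):
    """Hypothetical sketch of how the effective length could be resolved."""
    # An explicit total_size always wins.
    if total_size is not None:
        return total_size
    # Assumption: a stored index subset takes precedence over the collection.
    if indices is not None:
        return len(indices)
    return len(data_collection)


print(effective_len(range(50), total_size=100))
print(effective_len(range(50)))
```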
- get_indices(data_collection, seed: int | None = None, rng: Generator | None = None) → ndarray[tuple[int, ...], dtype[int64]]
Generate indices using weighted sampling followed by block shuffling.
First performs weighted sampling (with or without replacement) to select indices, then applies block-based shuffling to the selected indices.
- Parameters:
data_collection (object) – The data collection to sample from. Must support len().
seed (int, optional) – Random seed for reproducible sampling. Ignored if rng is provided.
rng (numpy.random.Generator, optional) – Random number generator to use. If provided, seed is ignored.
- Returns:
Array of indices after weighted sampling and block shuffling.
- Return type:
numpy.ndarray
- Raises:
ValueError – If weights don’t match the data collection length.
Notes
When replace=False, sampling is performed in rounds of size sampling_size until the desired total_size is reached. This helps avoid memory issues with large datasets.
The selected indices are sorted before block shuffling to ensure consistent behavior across different random seeds.
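The round-based scheme for replace=False can be sketched as follows. This is an illustration under stated assumptions, not the library's code: sample_round and sample_in_rounds are hypothetical helpers, each round is assumed to draw independently from the full population (so an index may recur across rounds but never within one), and the weighted draw without replacement uses Efraimidis-Spirakis random keys from the standard library in place of NumPy.

```python
import random


def sample_round(population, weights, k, rng):
    """Draw k distinct items, favoring larger weights (Efraimidis-Spirakis keys)."""
    keyed = [
        (rng.random() ** (1.0 / w), item)
        for item, w in zip(population, weights)
        if w > 0  # zero-weight items can never be selected
    ]
    # The k largest keys form a weighted sample without replacement.
    keyed.sort(reverse=True)
    return [item for _, item in keyed[:k]]


def sample_in_rounds(n_items, weights, total_size, sampling_size, seed=None):
    """Hypothetical sketch: repeat rounds until total_size indices are drawn."""
    selectable = sum(1 for w in weights if w > 0)
    if sampling_size < 1 or sampling_size > selectable:
        raise ValueError("sampling_size must be between 1 and the number of selectable items")
    rng = random.Random(seed)
    population = list(range(n_items))
    out = []
    while len(out) < total_size:
        # The last round may be shorter so the result has exactly total_size items.
        k = min(sampling_size, total_size - len(out))
        out.extend(sample_round(population, weights, k, rng))
    return out


indices = sample_in_rounds(10, [1.0] * 10, total_size=6, sampling_size=3, seed=42)
print(len(indices))
```

Only sampling_size indices are materialized per round, which is one way to read the note about avoiding memory issues on large datasets.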
Examples
>>> # Weighted sampling with replacement
>>> weights = [0.25, 0.25, 0.25, 0.25]  # Uniform weights
>>> strategy = BlockWeightedSampling(
...     weights=weights, total_size=8, block_size=2
... )
>>> indices = strategy.get_indices(range(4), seed=42)
>>> len(indices)
8

>>> # Sampling without replacement
>>> strategy = BlockWeightedSampling(
...     total_size=6, replace=False, sampling_size=3, block_size=2
... )
>>> indices = strategy.get_indices(range(10), seed=42)
>>> len(indices)
6