scdataset.scDataset#
- class scdataset.scDataset(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_transform: Callable | None = None, batch_transform: Callable | None = None, fetch_callback: Callable | None = None, batch_callback: Callable | None = None)[source]
Bases: IterableDataset

Iterable PyTorch Dataset for on-disk data collections with flexible sampling strategies.
This dataset implementation provides efficient iteration over large on-disk data collections using configurable sampling strategies. It supports various data transformations, custom fetch/batch callbacks, and automatic handling of multiprocessing workers.
- Parameters:
data_collection (object) – The data collection to sample from (e.g., AnnCollection, HuggingFace Dataset, numpy array, etc.). Must support indexing (__getitem__) and length (__len__) operations.
strategy (SamplingStrategy) – Strategy for sampling indices from the data collection. Determines the order and selection of samples.
batch_size (int) – Number of samples per minibatch. Must be positive.
fetch_factor (int, default=16) – Multiplier for fetch size relative to batch size. Higher values may improve I/O efficiency by fetching more data at once.
drop_last (bool, default=False) – Whether to drop the last incomplete batch if it contains fewer than batch_size samples.
fetch_transform (Callable, optional) – Function to transform data after fetching but before batching. Applied to the entire fetch (multiple batches worth of data).
batch_transform (Callable, optional) – Function to transform each individual batch before yielding.
fetch_callback (Callable, optional) – Custom function to fetch data given indices. Should accept (data_collection, indices) and return the fetched data. If None, uses default indexing (data_collection[indices]).
batch_callback (Callable, optional) – Custom function to extract batch data from fetched data. Should accept (fetched_data, batch_indices) and return the batch. If None, uses default indexing (fetched_data[batch_indices]).
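To illustrate the two callback signatures above, here is a minimal numpy-based sketch; the callbacks shown (`fetch_cb`, `batch_cb`) are hypothetical examples, not part of the library:

```python
import numpy as np

def fetch_cb(data_collection, indices):
    # Matches the fetch_callback signature: pull many rows in one read.
    # For on-disk formats, this is where a single bulk read would happen.
    return np.asarray(data_collection[indices])

def batch_cb(fetched_data, batch_indices):
    # Matches the batch_callback signature: slice the in-memory fetch.
    return fetched_data[batch_indices]

data = np.arange(20).reshape(10, 2)
fetched = fetch_cb(data, np.array([0, 2, 4, 6]))
batch = batch_cb(fetched, np.arange(2))  # first 2 rows of the fetch
```

Functions like these would be passed as `fetch_callback=fetch_cb, batch_callback=batch_cb` when constructing scDataset.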
- Variables:
collection (object) – The underlying data collection.
strategy (SamplingStrategy) – The sampling strategy being used.
batch_size (int) – Size of each batch.
fetch_factor (int) – Fetch size multiplier.
drop_last (bool) – Whether incomplete batches are dropped.
fetch_size (int) – Total number of samples fetched at once (batch_size * fetch_factor).
sort_before_fetch (bool) – Always True. Indices are sorted before fetching for optimal I/O.
- Raises:
ValueError – If batch_size or fetch_factor is not positive.
TypeError – If data_collection doesn’t support required operations or strategy is not a SamplingStrategy instance.
Examples
>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np

>>> # Simple streaming dataset
>>> data = np.random.randn(1000, 50)  # 1000 samples, 50 features
>>> strategy = Streaming()
>>> dataset = scDataset(data, strategy, batch_size=32)
>>> len(dataset)  # Number of batches
32

>>> # With custom transforms
>>> def normalize_batch(batch):
...     return (batch - batch.mean()) / batch.std()
>>> dataset = scDataset(
...     data, strategy, batch_size=32,
...     batch_transform=normalize_batch
... )

>>> # Iterate through batches
>>> for batch in dataset:
...     print(batch.shape)  # (32, 50) for most batches
...     break
(32, 50)
See also
scdataset.strategy.SamplingStrategy – Base class for sampling strategies
scdataset.strategy.Streaming – Sequential sampling without shuffling
scdataset.strategy.BlockShuffling – Block-based shuffling
scdataset.strategy.BlockWeightedSampling – Weighted sampling with blocks
scdataset.strategy.ClassBalancedSampling – Automatic class balancing
Notes
The dataset automatically handles PyTorch’s multiprocessing by distributing fetch ranges among workers. Each worker gets a different subset of the data to avoid duplication.
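The distribution of fetch ranges among workers can be pictured as a round-robin split over contiguous ranges. The following is a simplified sketch of the idea, not the library's actual assignment code:

```python
def split_fetch_ranges(n_samples, fetch_size, worker_id, num_workers):
    # Build contiguous [start, stop) fetch ranges, then give every
    # num_workers-th range to this worker so no range is duplicated.
    ranges = [(s, min(s + fetch_size, n_samples))
              for s in range(0, n_samples, fetch_size)]
    return ranges[worker_id::num_workers]

# Two workers splitting 1000 samples fetched 512 at a time:
w0 = split_fetch_ranges(1000, 512, worker_id=0, num_workers=2)
w1 = split_fetch_ranges(1000, 512, worker_id=1, num_workers=2)
```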
Data is fetched in chunks of size batch_size * fetch_factor and then divided into batches. This can improve I/O efficiency, especially for datasets where accessing non-contiguous indices is expensive.

Methods

__init__(data_collection, strategy, batch_size) – Initialize the scDataset.
- __init__(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_transform: Callable | None = None, batch_transform: Callable | None = None, fetch_callback: Callable | None = None, batch_callback: Callable | None = None)[source]
Initialize the scDataset.
- Parameters:
data_collection (object) – Data collection supporting indexing and len().
strategy (SamplingStrategy) – Sampling strategy instance.
batch_size (int) – Positive integer for batch size.
fetch_factor (int, default=16) – Positive integer for fetch size multiplier.
drop_last (bool, default=False) – Whether to drop incomplete batches.
fetch_transform (Callable, optional) – Transform applied to fetched data.
batch_transform (Callable, optional) – Transform applied to each batch.
fetch_callback (Callable, optional) – Custom fetch function.
batch_callback (Callable, optional) – Custom batch extraction function.
- __len__() → int[source]
Return the number of batches in the dataset.
Calculates the number of batches that will be yielded by the iterator based on the sampling strategy’s effective length and the batch size.
- Returns:
Number of batches in the dataset.
- Return type:
int
Examples
>>> from scdataset.strategy import Streaming
>>> dataset = scDataset(range(100), Streaming(), batch_size=10)
>>> len(dataset)
10

>>> # With drop_last=True
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=True)
>>> len(dataset)  # 105 // 10 = 10 (drops 5 samples)
10

>>> # With drop_last=False (default)
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=False)
>>> len(dataset)  # ceil(105 / 10) = 11
11
Notes
When drop_last=True, only complete batches are counted. When drop_last=False, the last incomplete batch is included in the count.
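The counting rule reduces to a one-liner, shown here for clarity (this is an illustration, not the library's own code):

```python
import math

def num_batches(n_samples, batch_size, drop_last):
    # drop_last=True counts only complete batches; otherwise the
    # trailing partial batch is included via ceiling division.
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)
```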
- __iter__()[source]
Yield batches of data according to the sampling strategy.
Creates an iterator that yields batches of data by:
1. Getting indices from the sampling strategy
2. Dividing indices into fetch ranges
3. Distributing fetch ranges among workers (if multiprocessing)
4. Fetching data in chunks and applying fetch transforms
5. Dividing fetched data into batches and applying batch transforms
6. Yielding transformed batches
- Yields:
object – Batches of data after applying all transforms. The exact type depends on the data collection and any applied transforms.
Examples
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> data = np.random.randn(100, 10)
>>> dataset = scDataset(data, Streaming(), batch_size=5)
>>> for i, batch in enumerate(dataset):
...     print(f"Batch {i}: shape {batch.shape}")
...     if i >= 2:  # Just show first few batches
...         break
Batch 0: shape (5, 10)
Batch 1: shape (5, 10)
Batch 2: shape (5, 10)
Notes
The fetch-then-batch approach can improve I/O efficiency by:
- Sorting indices before fetching for better disk access patterns
- Fetching multiple batches worth of data at once
- Reducing the number of data access operations
Shuffling behavior is controlled by the sampling strategy’s _shuffle_before_yield attribute.
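To make the sort-before-fetch idea concrete, here is a small numpy sketch of fetching with sorted indices and then restoring the requested order. This is an illustration of the technique only, not the library's implementation:

```python
import numpy as np

def sorted_fetch(data, indices):
    # Sort indices so the underlying read is as contiguous as possible,
    # then invert the permutation to restore the requested order.
    indices = np.asarray(indices)
    order = np.argsort(indices)
    fetched = data[indices[order]]    # one read over sorted positions
    return fetched[np.argsort(order)]

data = np.arange(12).reshape(6, 2)
out = sorted_fetch(data, [5, 0, 3])  # rows 5, 0, 3 in that order
```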