scdataset.scDataset
- class scdataset.scDataset(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_callback: Callable | None = None, fetch_transform: Callable | None = None, batch_callback: Callable | None = None, batch_transform: Callable | None = None, rank: int | None = None, world_size: int | None = None, seed: int | None = None)
Bases: IterableDataset
Iterable PyTorch Dataset for on-disk data collections with flexible sampling strategies.
This dataset implementation provides efficient iteration over large on-disk data collections using configurable sampling strategies. It supports various data transformations, custom fetch/batch callbacks, and automatic handling of multiprocessing workers and distributed training (DDP).
- Parameters:
data_collection (object) – The data collection to sample from (e.g., AnnCollection, HuggingFace Dataset, numpy array, etc.). Must support indexing (__getitem__) and length (__len__) operations.
strategy (SamplingStrategy) – Strategy for sampling indices from the data collection. Determines the order and selection of samples.
batch_size (int) – Number of samples per minibatch. Must be positive.
fetch_factor (int, default=16) – Multiplier for fetch size relative to batch size. Higher values may improve I/O efficiency by fetching more data at once.
drop_last (bool, default=False) – Whether to drop the last incomplete batch if it contains fewer than batch_size samples.
fetch_callback (Callable, optional) – Custom function to fetch data given indices. Should accept (data_collection, indices) and return the fetched data. If None, uses default indexing (data_collection[indices]).
fetch_transform (Callable, optional) – Function to transform data after fetching but before batching. Applied to the entire fetch (multiple batches worth of data).
batch_callback (Callable, optional) – Custom function to extract batch data from fetched data. Should accept (fetched_data, batch_indices) and return the batch. If None, uses default indexing (fetched_data[batch_indices]).
batch_transform (Callable, optional) – Function to transform each individual batch before yielding.
rank (int, optional) – Process rank for distributed training (DDP). If None, auto-detects from torch.distributed if initialized. Defaults to 0 for non-distributed.
world_size (int, optional) – Total number of processes for distributed training (DDP). If None, auto-detects from torch.distributed if initialized. Defaults to 1.
seed (int or None, optional) – Base seed for reproducible shuffling. Combined with auto-incrementing epoch counter to produce different shuffling each epoch while ensuring reproducibility across runs with the same seed. If None, a random seed is generated and shared across all DDP ranks to ensure consistent shuffling while providing variety between runs.
- Variables:
collection (object) – The underlying data collection.
strategy (SamplingStrategy) – The sampling strategy being used.
batch_size (int) – Size of each batch.
fetch_factor (int) – Fetch size multiplier.
drop_last (bool) – Whether incomplete batches are dropped.
fetch_size (int) – Total number of samples fetched at once (batch_size * fetch_factor).
sort_before_fetch (bool) – Always True. Indices are sorted before fetching for optimal I/O.
rank (int) – Process rank for distributed training.
world_size (int) – Total number of distributed processes.
- Raises:
ValueError – If batch_size or fetch_factor is not positive.
TypeError – If data_collection doesn’t support required operations or strategy is not a SamplingStrategy instance.
Examples
>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np

>>> # Simple streaming dataset
>>> data = np.random.randn(1000, 50)  # 1000 samples, 50 features
>>> strategy = Streaming()
>>> dataset = scDataset(data, strategy, batch_size=32)
>>> len(dataset)  # Number of batches
32

>>> # With custom transforms
>>> def normalize_batch(batch):
...     return (batch - batch.mean()) / batch.std()
>>> dataset = scDataset(
...     data, strategy, batch_size=32,
...     batch_transform=normalize_batch
... )

>>> # Iterate through batches
>>> for batch in dataset:
...     print(batch.shape)
...     break
(32, 50)

>>> # Distributed Data Parallel (DDP) usage
>>> # In DDP training script:
>>> # import torch.distributed as dist
>>> # dist.init_process_group(...)
>>> # dataset = scDataset(data, strategy, batch_size=32)  # Auto-detects DDP
>>> # Or manually specify:
>>> # dataset = scDataset(data, strategy, batch_size=32, rank=0, world_size=4)
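The callback signatures listed under Parameters can be tried out without the library. The following is a minimal sketch, assuming a plain Python list as the data collection, which cannot be indexed with a list of indices and therefore needs custom callbacks. The names `list_fetch` and `list_batch` are hypothetical. Whether scDataset passes indices as a list or as an array is not stated here, so the functions only iterate over them.

```python
# Hypothetical callbacks for a collection without fancy indexing
# (a plain Python list). These names are illustrative, not part of scdataset.

def list_fetch(data_collection, indices):
    """Fetch rows one by one; follows the (data_collection, indices) signature."""
    return [data_collection[int(i)] for i in indices]


def list_batch(fetched_data, batch_indices):
    """Pick a batch out of an already-fetched chunk by position."""
    return [fetched_data[int(i)] for i in batch_indices]


# Stand-alone check that does not need scDataset:
records = [{"id": i, "value": i * i} for i in range(10)]
chunk = list_fetch(records, [2, 5, 7, 8])   # one fetch
batch = list_batch(chunk, [0, 3])           # positions within the fetch
print([r["id"] for r in batch])             # ids 2 and 8
```

Functions like these would be passed through the documented keyword arguments, for example `fetch_callback=list_fetch` and `batch_callback=list_batch`.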
See also
scdataset.strategy.SamplingStrategy – Base class for sampling strategies
scdataset.strategy.Streaming – Sequential sampling without shuffling
scdataset.strategy.BlockShuffling – Block-based shuffling
scdataset.strategy.BlockWeightedSampling – Weighted sampling with blocks
scdataset.strategy.ClassBalancedSampling – Automatic class balancing
Notes
The dataset automatically handles PyTorch’s multiprocessing by distributing fetch ranges among workers. Each worker gets a different subset of the data to avoid duplication.
Data is fetched in chunks of size batch_size * fetch_factor and then divided into batches. This can improve I/O efficiency, especially for datasets where accessing non-contiguous indices is expensive.
DDP Support: When using Distributed Data Parallel, fetches are distributed across ranks in round-robin fashion for better load balancing. Each rank processes every world_size-th fetch, ensuring no data duplication. Combined with PyTorch DataLoader’s num_workers, this provides two levels of parallelism: across DDP ranks and across DataLoader workers within each rank.
Methods
__init__(data_collection, strategy, batch_size) – Initialize the scDataset.
- __init__(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_callback: Callable | None = None, fetch_transform: Callable | None = None, batch_callback: Callable | None = None, batch_transform: Callable | None = None, rank: int | None = None, world_size: int | None = None, seed: int | None = None)
Initialize the scDataset.
- Parameters:
data_collection (object) – Data collection supporting indexing and len().
strategy (SamplingStrategy) – Sampling strategy instance.
batch_size (int) – Positive integer for batch size.
fetch_factor (int, default=16) – Positive integer for fetch size multiplier.
drop_last (bool, default=False) – Whether to drop incomplete batches.
fetch_callback (Callable, optional) – Custom fetch function.
fetch_transform (Callable, optional) – Transform applied to fetched data.
batch_callback (Callable, optional) – Custom batch extraction function.
batch_transform (Callable, optional) – Transform applied to each batch.
rank (int, optional) – Process rank for DDP. Auto-detected if None.
world_size (int, optional) – Number of DDP processes. Auto-detected if None.
seed (int or None, optional) – Base seed for reproducible shuffling. If None, generates a random seed shared across DDP ranks for consistent shuffling. Combined with auto-incrementing epoch counter for different shuffling each epoch.
- __len__() → int
Return the number of batches in the dataset for this rank.
Calculates the number of batches that will be yielded by the iterator based on the sampling strategy’s effective length, batch size, and the number of DDP ranks (if using distributed training).
- Returns:
Number of batches in the dataset for this rank.
- Return type:
int
Examples
>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> dataset = scDataset(range(100), Streaming(), batch_size=10)
>>> len(dataset)
10

>>> # With drop_last=True
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=True)
>>> len(dataset)  # 105 // 10 = 10 (drops 5 samples)
10

>>> # With drop_last=False (default)
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=False)
>>> len(dataset)  # ceil(105 / 10) = 11
11
Notes
When drop_last=True, only complete batches are counted. When drop_last=False, the last incomplete batch is included in the count.
When using DDP (world_size > 1), the returned length is the number of batches this specific rank will process, which is approximately total_batches / world_size.
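The single-process counts in the examples above follow from floor and ceiling division. The sketch below reproduces that arithmetic under the assumption that the strategy's effective length equals the number of samples, as the Streaming examples suggest. The helper `num_batches` is hypothetical and is not the library's implementation. The exact per-rank count under DDP is not derived here, since this text only describes it as approximately total_batches / world_size.

```python
import math


def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batch count for a single process (world_size == 1)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return n_samples // batch_size           # only complete batches
    return math.ceil(n_samples / batch_size)     # include the partial batch


print(num_batches(100, 10))                    # 10
print(num_batches(105, 10, drop_last=True))    # 10
print(num_batches(105, 10, drop_last=False))   # 11
```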
- __iter__()
Yield batches of data according to the sampling strategy.
Creates an iterator that yields batches of data by:
1. Getting indices from the sampling strategy (same across all DDP ranks)
2. Dividing indices into fetch ranges
3. Distributing fetch ranges among DDP ranks (round-robin)
4. Further distributing among DataLoader workers (if multiprocessing)
5. Fetching data in chunks and applying fetch transforms
6. Dividing fetched data into batches and applying batch transforms
7. Yielding transformed batches
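The steps that divide indices into fetch ranges and hand them to DDP ranks can be sketched in a few lines. This is an illustration under stated assumptions, not the library's code: `fetch_ranges` and `ranges_for_rank` are hypothetical helpers, and the ranges are positions into the index sequence produced by the strategy. The further split among DataLoader workers is omitted because its scheme is not specified in this text.

```python
def fetch_ranges(n_indices: int, fetch_size: int):
    """Split positions 0..n_indices-1 into consecutive (start, stop) ranges."""
    return [(start, min(start + fetch_size, n_indices))
            for start in range(0, n_indices, fetch_size)]


def ranges_for_rank(ranges, rank: int, world_size: int):
    """Round-robin: rank r takes fetches r, r + world_size, r + 2*world_size, ..."""
    return ranges[rank::world_size]


batch_size, fetch_factor = 4, 2                       # fetch_size = 8
ranges = fetch_ranges(40, batch_size * fetch_factor)  # 5 fetch ranges
per_rank = [ranges_for_rank(ranges, r, world_size=2) for r in range(2)]
print(per_rank[0])   # [(0, 8), (16, 24), (32, 40)]
print(per_rank[1])   # [(8, 16), (24, 32)]
```

With five fetches and two ranks, rank 0 receives three fetches and rank 1 receives two, which is why the per-rank length is only approximately total_batches / world_size.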
- Yields:
object – Batches of data after applying all transforms. The exact type depends on the data collection and any applied transforms.
Examples
>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> data = np.random.randn(100, 10)
>>> dataset = scDataset(data, Streaming(), batch_size=5)
>>> for i, batch in enumerate(dataset):
...     print(f"Batch {i}: shape {batch.shape}")
...     if i >= 2:  # Just show first few batches
...         break
Batch 0: shape (5, 10)
Batch 1: shape (5, 10)
Batch 2: shape (5, 10)
Notes
The fetch-then-batch approach can improve I/O efficiency by:
Sorting indices before fetching for better disk access patterns
Fetching multiple batches worth of data at once
Reducing the number of data access operations
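The fetch-then-batch idea can be shown on a toy collection. The sketch below assumes an in-memory list standing in for on-disk data, and the function `fetch_then_batch` is hypothetical. It sorts the indices of one fetch, performs a single read, and then cuts the result into batches. Any reshuffling the library may do before yielding is left out.

```python
def fetch_then_batch(collection, indices, batch_size):
    """Read one chunk using sorted indices, then cut it into batches."""
    ordered = sorted(indices)                     # sorted access pattern
    fetched = [collection[i] for i in ordered]    # stands in for one bulk read
    for start in range(0, len(fetched), batch_size):
        yield fetched[start:start + batch_size]


data = list(range(100, 200))                      # toy "on-disk" collection
chunk_indices = [42, 7, 19, 3, 88, 61, 25, 50]    # one fetch = 2 batches of 4
batches = list(fetch_then_batch(data, chunk_indices, batch_size=4))
print(batches)   # [[103, 107, 119, 125], [142, 150, 161, 188]]
```

One read serves two batches here. With the default fetch_factor of 16, one read would serve sixteen.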
Shuffling behavior is controlled by the sampling strategy’s _shuffle_before_yield attribute.
DDP Distribution: When using multiple ranks (world_size > 1), fetches are distributed in round-robin fashion. Rank 0 gets fetches 0, world_size, 2*world_size, etc. Rank 1 gets fetches 1, world_size+1, 2*world_size+1, etc. This ensures even load distribution and that all data is processed exactly once across all ranks.
Auto-incrementing epoch: The epoch counter automatically increments each time the dataset is iterated. This ensures different shuffling each epoch without requiring manual set_epoch() calls.
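The auto-incrementing epoch can be illustrated with a toy iterable. This is a sketch only: the class `EpochShuffler` is hypothetical, and combining the seed and epoch by addition is an assumption made for the example, since this text does not state how the library combines them.

```python
import random


class EpochShuffler:
    """Toy iterable whose order is reseeded on every pass."""

    def __init__(self, n: int, seed: int):
        self.n = n
        self.seed = seed
        self._epoch = 0                  # advanced automatically in __iter__

    def __iter__(self):
        # ASSUMPTION: seed and epoch are combined by addition; the library's
        # actual rule is not stated in this text.
        rng = random.Random(self.seed + self._epoch)
        self._epoch += 1
        order = list(range(self.n))
        rng.shuffle(order)
        return iter(order)


a, b = EpochShuffler(8, seed=0), EpochShuffler(8, seed=0)
first_a, second_a = list(a), list(a)     # two passes over the same object
first_b = list(b)
print(first_a == first_b)                  # same seed and epoch: same order
print(sorted(second_a) == list(range(8)))  # each pass still covers every item
```

Because the counter lives on the object, two instances built with the same seed replay the same sequence of epochs, which is the reproducibility property described for the seed parameter.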