"""
Iterable PyTorch Dataset for on-disk data collections.
This module provides the main :class:`scDataset` class for creating efficient
iterable datasets from on-disk data collections with flexible sampling
strategies and customizable data transformation pipelines.
.. autosummary::
:toctree: generated/
scDataset
"""
from typing import Callable, Optional
import numpy as np
from torch.utils.data import IterableDataset, get_worker_info
from .strategy import SamplingStrategy
class scDataset(IterableDataset):
"""
Iterable PyTorch Dataset for on-disk data collections with flexible sampling strategies.
This dataset implementation provides efficient iteration over large on-disk
data collections using configurable sampling strategies. It supports various
data transformations, custom fetch/batch callbacks, and automatic handling
of multiprocessing workers and distributed training (DDP).
Parameters
----------
data_collection : object
The data collection to sample from (e.g., AnnCollection, HuggingFace Dataset,
numpy array, etc.). Must support indexing (``__getitem__``) and length
(``__len__``) operations.
strategy : SamplingStrategy
Strategy for sampling indices from the data collection. Determines the
order and selection of samples.
batch_size : int
Number of samples per minibatch. Must be positive.
fetch_factor : int, default=16
Multiplier for fetch size relative to batch size. Higher values may
improve I/O efficiency by fetching more data at once.
drop_last : bool, default=False
Whether to drop the last incomplete batch if it contains fewer than
``batch_size`` samples.
fetch_callback : Callable, optional
Custom function to fetch data given indices. Should accept
``(data_collection, indices)`` and return the fetched data.
If None, uses default indexing (``data_collection[indices]``).
fetch_transform : Callable, optional
Function to transform data after fetching but before batching.
        Applied to the entire fetch (multiple batches' worth of data).
batch_callback : Callable, optional
Custom function to extract batch data from fetched data.
Should accept ``(fetched_data, batch_indices)`` and return the batch.
If None, uses default indexing (``fetched_data[batch_indices]``).
batch_transform : Callable, optional
Function to transform each individual batch before yielding.
rank : int, optional
Process rank for distributed training (DDP). If None, auto-detects
from torch.distributed if initialized. Defaults to 0 for non-distributed.
world_size : int, optional
Total number of processes for distributed training (DDP). If None,
auto-detects from torch.distributed if initialized. Defaults to 1.
seed : int or None, optional
Base seed for reproducible shuffling. Combined with auto-incrementing
epoch counter to produce different shuffling each epoch while ensuring
reproducibility across runs with the same seed.
If None, a random seed is generated and shared across all DDP ranks
to ensure consistent shuffling while providing variety between runs.
Attributes
----------
collection : object
The underlying data collection.
strategy : SamplingStrategy
The sampling strategy being used.
batch_size : int
Size of each batch.
fetch_factor : int
Fetch size multiplier.
drop_last : bool
Whether incomplete batches are dropped.
fetch_size : int
Total number of samples fetched at once (batch_size * fetch_factor).
sort_before_fetch : bool
Always True. Indices are sorted before fetching for optimal I/O.
rank : int
Process rank for distributed training.
world_size : int
Total number of distributed processes.
Raises
------
ValueError
If batch_size or fetch_factor is not positive.
TypeError
If data_collection doesn't support required operations or strategy
is not a SamplingStrategy instance.
Examples
--------
>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> # Simple streaming dataset
>>> data = np.random.randn(1000, 50) # 1000 samples, 50 features
>>> strategy = Streaming()
>>> dataset = scDataset(data, strategy, batch_size=32)
>>> len(dataset) # Number of batches
32
>>> # With custom transforms
>>> def normalize_batch(batch):
... return (batch - batch.mean()) / batch.std()
>>> dataset = scDataset(
... data, strategy, batch_size=32,
... batch_transform=normalize_batch
... )
>>> # Iterate through batches
>>> for batch in dataset: # doctest: +ELLIPSIS
... print(batch.shape)
... break
(32, 50)
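    >>> # Custom fetch callback: a sketch assuming a HuggingFace
    >>> # ``datasets.Dataset``, whose ``select`` method takes row indices
    >>> def fetch_rows(collection, indices):
    ...     return collection.select(indices.tolist())
    >>> # dataset = scDataset(hf_dataset, strategy, batch_size=32,
    >>> #                     fetch_callback=fetch_rows)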
>>> # Distributed Data Parallel (DDP) usage
>>> # In DDP training script:
>>> # import torch.distributed as dist
>>> # dist.init_process_group(...)
>>> # dataset = scDataset(data, strategy, batch_size=32) # Auto-detects DDP
>>> # Or manually specify:
>>> # dataset = scDataset(data, strategy, batch_size=32, rank=0, world_size=4)
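    >>> # Wrap in a DataLoader for multi-worker loading; scDataset already
    >>> # yields whole batches, so disable automatic batching:
    >>> # from torch.utils.data import DataLoader
    >>> # loader = DataLoader(dataset, batch_size=None, num_workers=4)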
See Also
--------
scdataset.strategy.SamplingStrategy : Base class for sampling strategies
scdataset.strategy.Streaming : Sequential sampling without shuffling
scdataset.strategy.BlockShuffling : Block-based shuffling
scdataset.strategy.BlockWeightedSampling : Weighted sampling with blocks
scdataset.strategy.ClassBalancedSampling : Automatic class balancing
Notes
-----
    The dataset automatically handles PyTorch DataLoader multiprocessing by
    distributing fetch ranges among workers. Each worker processes a distinct
    subset of the fetches, so no sample is duplicated within an epoch.
Data is fetched in chunks of size ``batch_size * fetch_factor`` and then
divided into batches. This can improve I/O efficiency, especially for
datasets where accessing non-contiguous indices is expensive.
**DDP Support**: When using Distributed Data Parallel, fetches are distributed
across ranks in round-robin fashion for better load balancing. Each rank
processes every ``world_size``-th fetch, ensuring no data duplication.
Combined with PyTorch DataLoader's ``num_workers``, this provides two levels
of parallelism: across DDP ranks and across DataLoader workers within each rank.
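    For example, with 8 fetches and ``world_size=3``, rank 0 processes
    fetches 0, 3, and 6; rank 1 processes fetches 1, 4, and 7; and rank 2
    processes fetches 2 and 5.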
"""
def __init__(
self,
data_collection,
strategy: SamplingStrategy,
batch_size: int,
fetch_factor: int = 16,
drop_last: bool = False,
fetch_callback: Optional[Callable] = None,
fetch_transform: Optional[Callable] = None,
batch_callback: Optional[Callable] = None,
batch_transform: Optional[Callable] = None,
rank: Optional[int] = None,
world_size: Optional[int] = None,
seed: Optional[int] = None,
):
"""
Initialize the scDataset.
Parameters
----------
data_collection : object
Data collection supporting indexing and len().
strategy : SamplingStrategy
Sampling strategy instance.
batch_size : int
Positive integer for batch size.
fetch_factor : int, default=16
Positive integer for fetch size multiplier.
drop_last : bool, default=False
Whether to drop incomplete batches.
fetch_callback : Callable, optional
Custom fetch function.
fetch_transform : Callable, optional
Transform applied to fetched data.
batch_callback : Callable, optional
Custom batch extraction function.
batch_transform : Callable, optional
Transform applied to each batch.
rank : int, optional
Process rank for DDP. Auto-detected if None.
world_size : int, optional
Number of DDP processes. Auto-detected if None.
seed : int or None, optional
Base seed for reproducible shuffling. If None, generates a random
seed shared across DDP ranks for consistent shuffling. Combined
with auto-incrementing epoch counter for different shuffling each epoch.
"""
# Input validation
if batch_size <= 0:
raise ValueError("batch_size must be positive")
if fetch_factor <= 0:
raise ValueError("fetch_factor must be positive")
        if not isinstance(strategy, SamplingStrategy):
            raise TypeError("strategy must be an instance of SamplingStrategy")
        if not (hasattr(data_collection, "__getitem__") and hasattr(data_collection, "__len__")):
            raise TypeError(
                "data_collection must support indexing (__getitem__) and len()"
            )
self.collection = data_collection
self.strategy = strategy
self.batch_size = batch_size
self.fetch_factor = fetch_factor
self.drop_last = drop_last
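        # Samples read per fetch, e.g. batch_size=32 with fetch_factor=16
        # pulls 512 samples from disk at a time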
self.fetch_size = self.batch_size * self.fetch_factor
        self.sort_before_fetch = True  # Indices are always sorted before fetching for efficient I/O
# Store callback functions
self.fetch_callback = fetch_callback
self.fetch_transform = fetch_transform
self.batch_callback = batch_callback
self.batch_transform = batch_transform
# DDP support with auto-detection
self.rank, self.world_size = self._detect_ddp(rank, world_size)
# Epoch counter for reproducible shuffling - auto-increments each iteration
self._epoch = 0
# Base seed for deterministic shuffling sequences
# If None, generate random seed shared across DDP ranks
self._base_seed = self._init_seed(seed)
def _init_seed(self, seed: Optional[int]) -> int:
"""
Initialize or generate the base seed for shuffling.
If a seed is provided, use it directly. If None, generate a random
seed and broadcast it from rank 0 to all other ranks to ensure
consistent shuffling across DDP processes.
Parameters
----------
seed : int or None
Explicit seed, or None to generate a random shared seed.
Returns
-------
int
The seed to use for shuffling.
"""
if seed is not None:
return seed
# Generate random seed - will be shared across ranks
import torch
# Single process: just generate random seed
if self.world_size == 1:
return int(torch.randint(0, 2**31, (1,)).item())
# Multi-process DDP: broadcast seed from rank 0
import torch.distributed as dist
if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError(
f"world_size={self.world_size} but torch.distributed is not initialized. "
"Please call torch.distributed.init_process_group() before creating the dataset."
)
# Rank 0 generates seed and broadcasts to all ranks
# Use device-appropriate tensor for the backend (NCCL needs CUDA tensors)
backend = dist.get_backend()
if backend == "nccl" and torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
seed_tensor = torch.zeros(1, dtype=torch.int64, device=device)
if self.rank == 0:
seed_tensor[0] = torch.randint(0, 2**31, (1,), device=device).item()
else:
# gloo and other backends work with CPU tensors
seed_tensor = torch.zeros(1, dtype=torch.int64)
if self.rank == 0:
seed_tensor[0] = torch.randint(0, 2**31, (1,)).item()
dist.broadcast(seed_tensor, src=0)
return int(seed_tensor.item())
def _detect_ddp(self, rank: Optional[int], world_size: Optional[int]) -> tuple:
"""
Detect or validate DDP settings.
Auto-detects from torch.distributed if available and initialized,
otherwise defaults to single-process settings.
Parameters
----------
rank : int or None
Explicit rank, or None for auto-detection.
world_size : int or None
Explicit world_size, or None for auto-detection.
Returns
-------
tuple
            The resolved ``(rank, world_size)`` pair.
"""
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
detected_rank = dist.get_rank()
detected_world_size = dist.get_world_size()
else:
detected_rank = 0
detected_world_size = 1
except ImportError:
detected_rank = 0
detected_world_size = 1
final_rank = rank if rank is not None else detected_rank
final_world_size = world_size if world_size is not None else detected_world_size
return final_rank, final_world_size
def __len__(self) -> int:
"""
Return the number of batches in the dataset for this rank.
Calculates the number of batches that will be yielded by the iterator
based on the sampling strategy's effective length, batch size, and
the number of DDP ranks (if using distributed training).
Returns
-------
int
Number of batches in the dataset for this rank.
Examples
--------
>>> from scdataset.strategy import Streaming
>>> dataset = scDataset(range(100), Streaming(), batch_size=10)
>>> len(dataset)
10
>>> # With drop_last=True
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=True)
>>> len(dataset) # 105 // 10 = 10 (drops 5 samples)
10
>>> # With drop_last=False (default)
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=False)
>>> len(dataset) # ceil(105 / 10) = 11
11
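        >>> # Per-rank lengths under DDP: whole fetches are assigned to
        >>> # ranks, so the split can be uneven when there are few fetches.
        >>> # Here the default fetch_factor=16 makes 100 samples one fetch:
        >>> dataset = scDataset(range(100), Streaming(), batch_size=10,
        ...                     rank=0, world_size=2, seed=0)
        >>> len(dataset)
        10
        >>> dataset = scDataset(range(100), Streaming(), batch_size=10,
        ...                     rank=1, world_size=2, seed=0)
        >>> len(dataset)
        0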
Notes
-----
When ``drop_last=True``, only complete batches are counted.
When ``drop_last=False``, the last incomplete batch is included in the count.
When using DDP (``world_size > 1``), the returned length is the number
of batches this specific rank will process, which is approximately
``total_batches / world_size``.
"""
# Get the total number of samples from the sampling strategy
n = self.strategy.get_len(self.collection)
# Calculate total fetches and per-rank fetches
fetch_size = self.fetch_size
num_fetches = (n + fetch_size - 1) // fetch_size
# Round-robin distribution: this rank gets fetches at positions
# rank, rank + world_size, rank + 2*world_size, ...
rank_fetch_ids = list(range(self.rank, num_fetches, self.world_size))
# Check if this rank gets any fetches
if len(rank_fetch_ids) == 0:
return 0
        # Count batches per fetch; drop_last applies within each fetch,
        # since each fetch may leave a remainder smaller than batch_size
        num_batches = 0
        for fetch_id in rank_fetch_ids:
            fetch_start = fetch_id * fetch_size
            fetch_end = min((fetch_id + 1) * fetch_size, n)
            fetch_samples = fetch_end - fetch_start
            if self.drop_last:
                # Only complete batches are yielded
                num_batches += fetch_samples // self.batch_size
            else:
                # The trailing partial batch of each fetch is also yielded
                num_batches += (fetch_samples + self.batch_size - 1) // self.batch_size
        return num_batches
def __iter__(self):
"""
Yield batches of data according to the sampling strategy.
Creates an iterator that yields batches of data by:
1. Getting indices from the sampling strategy (same across all DDP ranks)
2. Dividing indices into fetch ranges
3. Distributing fetch ranges among DDP ranks (round-robin)
4. Further distributing among DataLoader workers (if multiprocessing)
5. Fetching data in chunks and applying fetch transforms
6. Dividing fetched data into batches and applying batch transforms
7. Yielding transformed batches
Yields
------
object
Batches of data after applying all transforms. The exact type
depends on the data collection and any applied transforms.
Examples
--------
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> data = np.random.randn(100, 10)
>>> dataset = scDataset(data, Streaming(), batch_size=5)
>>> for i, batch in enumerate(dataset):
... print(f"Batch {i}: shape {batch.shape}")
... if i >= 2: # Just show first few batches
... break
Batch 0: shape (5, 10)
Batch 1: shape (5, 10)
Batch 2: shape (5, 10)
Notes
-----
The fetch-then-batch approach can improve I/O efficiency by:
- Sorting indices before fetching for better disk access patterns
- Fetching multiple batches worth of data at once
- Reducing the number of data access operations
Shuffling behavior is controlled by the sampling strategy's
``_shuffle_before_yield`` attribute.
**DDP Distribution**: When using multiple ranks (``world_size > 1``),
fetches are distributed in round-robin fashion. Rank 0 gets fetches
0, world_size, 2*world_size, etc. Rank 1 gets fetches 1, world_size+1,
2*world_size+1, etc. This ensures even load distribution and that
all data is processed exactly once across all ranks.
**Auto-incrementing epoch**: The epoch counter automatically increments
each time the dataset is iterated. This ensures different shuffling
each epoch without requiring manual ``set_epoch()`` calls.
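        For example, with ``seed=0`` the first pass over the dataset uses
        seed 0 and the second pass uses seed 1000.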
"""
worker_info = get_worker_info()
# Generate seed for sampling strategy - combine base_seed with epoch
# All ranks use the same seed for consistent global ordering
# epoch * 1000 provides sufficient separation between epochs
current_seed = self._base_seed + self._epoch * 1000
# Auto-increment epoch for next iteration (different shuffling each epoch)
self._epoch += 1
        # All ranks and DataLoader workers use the same seed so that the
        # global ordering is consistent; they partition the work, not the
        # randomness
        rng = np.random.default_rng(current_seed)
# Get indices from sampling strategy - same ordering across all ranks
indices = self.strategy.get_indices(self.collection, seed=current_seed)
# Calculate fetch ranges
n = len(indices)
fetch_size = self.fetch_size
num_fetches = (n + fetch_size - 1) // fetch_size
# DDP: Distribute fetches among ranks in round-robin fashion
# This rank gets fetches: rank, rank + world_size, rank + 2*world_size, ...
rank_fetch_ids = list(range(self.rank, num_fetches, self.world_size))
# Build fetch ranges for this rank only
fetch_ranges = [
(i * fetch_size, min((i + 1) * fetch_size, n)) for i in rank_fetch_ids
]
# Handle DataLoader multiprocessing by distributing fetch ranges among workers
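        # Contiguous split with the remainder going to the lowest worker
        # ids, e.g. 7 fetches across 3 workers -> worker 0 takes fetches
        # 0-2, workers 1 and 2 take two each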
if worker_info is not None and len(fetch_ranges) > 0:
num_rank_fetches = len(fetch_ranges)
per_worker = num_rank_fetches // worker_info.num_workers
remainder = num_rank_fetches % worker_info.num_workers
if worker_info.id < remainder:
start = worker_info.id * (per_worker + 1)
end = start + per_worker + 1
else:
start = worker_info.id * per_worker + remainder
end = start + per_worker
fetch_ranges = fetch_ranges[start:end]
# Process each fetch range
for fetch_start, fetch_end in fetch_ranges:
fetch_indices = indices[fetch_start:fetch_end]
if self.sort_before_fetch:
fetch_indices = np.sort(fetch_indices)
# Use custom fetch callback if provided, otherwise use default indexing
if self.fetch_callback is not None:
data = self.fetch_callback(self.collection, fetch_indices)
else:
data = self.collection[fetch_indices]
# Call fetch transform if provided
if self.fetch_transform is not None:
data = self.fetch_transform(data)
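            # Indices were sorted before fetching for I/O efficiency; when
            # the strategy requests it, permute in memory to restore random
            # order within the fetch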
if self.strategy._shuffle_before_yield:
shuffle_indices = rng.permutation(len(fetch_indices))
else:
shuffle_indices = np.arange(len(fetch_indices))
# Yield batches
batch_start = 0
while batch_start < len(fetch_indices):
batch_end = min(batch_start + self.batch_size, len(fetch_indices))
# Handle drop_last
if self.drop_last and (batch_end - batch_start) < self.batch_size:
break
# Get batch indices
batch_indices = shuffle_indices[batch_start:batch_end]
# Use custom batch callback if provided, otherwise use default indexing
if self.batch_callback is not None:
batch_data = self.batch_callback(data, batch_indices)
else:
batch_data = data[batch_indices]
# Call batch transform if provided
if self.batch_transform is not None:
batch_data = self.batch_transform(batch_data)
yield batch_data
batch_start = batch_end