Distributed & Parallel Training#
Added in version 0.3.0.
scDataset provides native support for distributed and parallel training with
PyTorch. All three parallel/distributed modalities work automatically and are
handled internally by scDataset:
DataLoader multiprocessing: num_workers for parallel data loading
DataParallel (DP): Simple multi-GPU on a single node
DistributedDataParallel (DDP): Multi-GPU, multi-node training
No special configuration is required: scDataset auto-detects the training
environment and partitions data appropriately.
How DDP Works#
When running in a distributed environment, scDataset automatically partitions
data across workers using round-robin assignment. Each worker (rank) processes
only its assigned portion of the data, ensuring no overlap.
Key features:
Automatic rank detection: Works with torch.distributed environment variables
Round-robin partitioning: Even distribution of data across all ranks
Per-epoch shuffling: Automatic epoch counter ensures different data ordering each epoch
Deterministic splits: Same data partition for a given rank/world_size/seed
Basic DDP Setup#
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

from scdataset import scDataset, BlockShuffling

def setup_distributed():
    """Initialize distributed training."""
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size

def main():
    rank, world_size = setup_distributed()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Load your data source
    adata = load_adata()  # Your data loading function

    # Create scDataset - it automatically detects rank and world_size
    dataset = scDataset(
        adata,
        BlockShuffling(block_size=64),
        batch_size=128,
        fetch_factor=64,
        fetch_callback=my_fetch_fn,
    )

    # Create DataLoader (no DistributedSampler needed!)
    loader = DataLoader(
        dataset,
        batch_size=None,     # Batching handled by scDataset
        num_workers=4,
        prefetch_factor=65,  # fetch_factor + 1
    )

    # Standard DDP model setup
    model = YourModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    for epoch in range(num_epochs):
        for batch in loader:
            batch = batch.to(local_rank)
            # Training code here
            pass

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Manual Rank Configuration#
You can also manually specify rank and world_size without relying on environment variables:
# Explicit rank configuration
dataset = scDataset(
    adata,
    BlockShuffling(block_size=64),
    batch_size=128,
    fetch_factor=64,
    fetch_callback=my_fetch_fn,
    rank=2,        # This worker's rank (0-indexed)
    world_size=4,  # Total number of workers
)
Data Partitioning#
scDataset uses round-robin partitioning to distribute data across workers:
Rank 0: Indices 0, 4, 8, 12, …
Rank 1: Indices 1, 5, 9, 13, …
Rank 2: Indices 2, 6, 10, 14, …
Rank 3: Indices 3, 7, 11, 15, …
This ensures even distribution and allows each worker to process its portion independently without communication during data loading.
# Example: Understanding partitioning
# With 1000 samples and world_size=4:
# - Each rank processes ~250 samples
# - Rank 0: samples at positions 0, 4, 8, ... (250 samples)
# - Rank 1: samples at positions 1, 5, 9, ... (250 samples)
# etc.
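As an illustration (this is not scDataset's internal code), the same round-robin split can be reproduced with plain NumPy slicing:
import numpy as np

# Illustrative sketch of the round-robin scheme described above:
# each rank keeps every world_size-th index, starting from its own rank.
n_samples, world_size = 1000, 4
indices = np.arange(n_samples)

for rank in range(world_size):
    part = indices[rank::world_size]
    print(f"Rank {rank}: {part[:4].tolist()} ... ({len(part)} samples)")
# Rank 0: [0, 4, 8, 12] ... (250 samples)
# Rank 1: [1, 5, 9, 13] ... (250 samples)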
Automatic Epoch Handling#
scDataset automatically increments an internal epoch counter each time the
dataset is iterated. This means different shuffling happens automatically each epoch.
Note
Unlike PyTorch’s DistributedSampler, you do not need to call set_epoch()
before each epoch. scDataset handles this internally by combining the base seed
with an auto-incrementing epoch counter.
# Different shuffling each epoch automatically - no set_epoch() needed!
for epoch in range(100):
    for batch in loader:
        train_step(batch)
    # Epoch counter auto-increments when iteration completes
For reproducibility across runs, you can set a base seed in the constructor:
dataset = scDataset(adata, strategy, batch_size=128, seed=42)
# Same seed = same shuffling sequence across runs
Launching Distributed Training#
Use torchrun to launch distributed training:
Single Node, Multiple GPUs:
torchrun --nproc_per_node=4 train.py
Multiple Nodes:
# On node 0 (master):
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=<master_ip> --master_port=29500 train.py
# On node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=<master_ip> --master_port=29500 train.py
Complete Training Example#
Here’s a complete example with all components:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import anndata as ad

from scdataset import scDataset, BlockShuffling
from scdataset.transforms import adata_to_mindex

def train():
    # Initialize distributed
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Load data (each rank loads independently)
    adata = ad.read_h5ad("large_dataset.h5ad", backed='r')

    # Create dataset - DDP handled automatically
    dataset = scDataset(
        adata,
        BlockShuffling(block_size=32),
        fetch_factor=32,
        batch_size=512,
        fetch_callback=lambda d, idx: adata_to_mindex(d[idx]),
    )
    loader = DataLoader(dataset, batch_size=None, num_workers=4, prefetch_factor=33)

    # Model setup
    model = nn.Sequential(
        nn.Linear(adata.n_vars, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
    ).to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    # Training loop - shuffling changes automatically each epoch!
    for epoch in range(10):
        for batch_idx, batch in enumerate(loader):
            batch = batch.to(local_rank)
            output = model(batch)
            loss = criterion(output, batch[:, :128])  # Reconstruction
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0 and local_rank == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    dist.destroy_process_group()

if __name__ == "__main__":
    train()
Weighted Sampling with DDP#
Added in version 0.3.0.
One of scDataset’s unique advantages is native support for weighted sampling
in distributed training scenarios. This is a capability that PyTorch does not
provide out of the box.
The Problem with PyTorch’s Approach
PyTorch provides WeightedRandomSampler for handling class imbalance and
DistributedSampler for distributed training, but these two samplers are
incompatible with each other. This has been a long-standing issue in the PyTorch ecosystem.
While a DistributedWeightedRandomSampler PR
exists, it remains unmerged as of January 2026.
scDataset’s Solution
With scDataset, weighted sampling and distributed training are decoupled by design.
The sampling strategy operates on the data collection level, and DDP partitioning
is applied automatically on top of that. This means you can use any strategy, including
BlockWeightedSampling and ClassBalancedSampling, without any special configuration:
import torch.distributed as dist
from torch.utils.data import DataLoader

from scdataset import scDataset, BlockWeightedSampling

dist.init_process_group(backend="nccl")

# Weighted sampling just works in DDP - no special wrappers needed!
weights = compute_sample_weights(adata.obs['cell_type'])  # Your weights

dataset = scDataset(
    adata,
    BlockWeightedSampling(
        weights=weights,
        total_size=10000,
        block_size=64,
    ),
    batch_size=128,
    fetch_factor=64,
    fetch_callback=my_fetch_fn,
)

# Each GPU gets a different portion of the weighted-sampled data
loader = DataLoader(dataset, batch_size=None, num_workers=4, prefetch_factor=65)
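The example above calls a compute_sample_weights helper that is not part of scDataset; a common, minimal choice is inverse-frequency weighting per cell type, sketched below (the normalization at the end is optional).
import numpy as np
import pandas as pd

def compute_sample_weights(labels) -> np.ndarray:
    """Hypothetical helper: inverse-frequency weights so rare cell types are sampled more often."""
    labels = pd.Series(np.asarray(labels))    # also handles categorical obs columns
    counts = labels.value_counts()            # number of cells per cell type
    weights = 1.0 / labels.map(counts).to_numpy(dtype=float)
    return weights / weights.sum()            # normalize (optional)

weights = compute_sample_weights(adata.obs['cell_type'])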
Why This Matters
To the best of our knowledge, scDataset is among the first libraries to provide
efficient on-disk weighted sampling that works natively with DDP. This is possible
because:
Strategy and distribution are decoupled: The sampling strategy generates indices, and DDP partitioning is applied afterward
Deterministic seeding: All ranks use the same base seed plus epoch offset, ensuring coordinated but different data across GPUs
On-disk efficiency: By fetching only the required indices, scDataset avoids loading the entire dataset into memory, making large-scale weighted sampling practical even for billion-cell datasets
DDP with Any Strategy#
scDataset supports any sampling strategy in DDP mode. The distributed
partitioning layer is orthogonal to the strategy layer:
| Strategy              | Description                                           |
|-----------------------|-------------------------------------------------------|
| Streaming             | Simple sequential access, partitioned across GPUs     |
| BlockShuffling        | Block-shuffled access, each GPU gets different blocks |
| BlockWeightedSampling | Weighted sampling with DDP partitioning               |
| ClassBalancedSampling | Class-balanced with automatic DDP support             |
# All strategies work the same way in DDP
# Simple streaming
dataset1 = scDataset(adata, Streaming(), ...)
# Block shuffling
dataset2 = scDataset(adata, BlockShuffling(block_size=64), ...)
# Weighted sampling
dataset3 = scDataset(adata, BlockWeightedSampling(weights=w), ...)
# Class-balanced
dataset4 = scDataset(adata, ClassBalancedSampling(labels=l), ...)
# All automatically partition data across GPUs when run with torchrun
Best Practices#
No DistributedSampler needed: scDataset handles partitioning internally
Use ``batch_size=None`` in DataLoader: Batching is handled by scDataset
Backed mode for large files: Use ad.read_h5ad(path, backed='r') to avoid loading entire datasets into memory on each rank
Same data across ranks: Ensure all ranks can access the same data files
Logging on rank 0 only: Print/log only from rank 0 to avoid duplicate output (see the sketch after this list)
Synchronize when needed: Use dist.barrier() for synchronization points
Set seed for reproducibility: Use the seed parameter (e.g., seed=42) if you need identical shuffling sequences across different runs
Weighted sampling works out of the box: Use BlockWeightedSampling or ClassBalancedSampling in DDP without any special configuration
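For the rank-0 logging and synchronization items, a minimal pattern using standard torch.distributed calls (not scDataset-specific) looks like this:
import torch.distributed as dist

def is_main_process() -> bool:
    # Rank 0 conventionally handles logging and checkpointing.
    return not dist.is_initialized() or dist.get_rank() == 0

if is_main_process():
    print("epoch finished")  # one line of output instead of one per GPU

# Make every rank wait here, e.g. before reading a file rank 0 just wrote.
if dist.is_initialized():
    dist.barrier()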
DataLoader Multiprocessing (num_workers)#
scDataset works seamlessly with PyTorch DataLoader’s num_workers parameter
for multiprocessing-based data loading. This enables parallel data fetching and
preprocessing while the GPU trains on the current batch.
Basic Usage
from torch.utils.data import DataLoader
from scdataset import scDataset, BlockShuffling
dataset = scDataset(
    data,
    BlockShuffling(block_size=32),
    fetch_factor=64,
    batch_size=128,
)

# Use num_workers for parallel data loading
loader = DataLoader(
    dataset,
    batch_size=None,     # IMPORTANT: batching handled by scDataset
    num_workers=4,       # Use 4 worker processes
    prefetch_factor=65,  # Prefetch 65 batches per worker
)

for batch in loader:
    # Training code here
    pass
How It Works
When num_workers > 0, PyTorch spawns worker processes that load data in parallel:
Each worker independently iterates over scDataset
Workers use PyTorch's automatic data splitting via get_worker_info() (see the sketch below)
scDataset detects worker count and ID, partitioning data among workers
Batches are prefetched and queued for the main process
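For reference, torch.utils.data.get_worker_info() is the mechanism PyTorch exposes for this; the sketch below shows how an iterable dataset can read the worker ID and count. It illustrates the idea rather than reproducing scDataset's actual implementation.
from torch.utils.data import get_worker_info

def worker_indices(num_items: int):
    """Illustrative: split num_items across DataLoader workers, round-robin style."""
    info = get_worker_info()
    if info is None:      # num_workers=0: everything runs in the main process
        worker_id, num_workers = 0, 1
    else:                 # inside a worker process
        worker_id, num_workers = info.id, info.num_workers
    return range(worker_id, num_items, num_workers)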
Combining with DDP
Both DDP partitioning and DataLoader multiprocessing work together:
# In a DDP training script launched with torchrun
dist.init_process_group(backend="nccl")

dataset = scDataset(
    adata,
    BlockShuffling(block_size=64),
    batch_size=128,
    # rank and world_size auto-detected from torch.distributed
)

# Each DDP rank uses multiple workers
loader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=4,  # 4 workers PER GPU
)

# With 4 GPUs and 4 workers each, you have 16 worker processes total
Choosing num_workers
Start with num_workers=2-4 and increase if GPU utilization is low
More workers = more memory usage (each loads data separately)
For backed HDF5 files, each worker opens its own file handle
Monitor CPU and memory usage to find the optimal value (see the timing sketch below)
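One practical way to choose is to time a fixed number of batches for several worker counts; the helper below is a rough sketch and assumes dataset is an already-constructed scDataset.
import time
from torch.utils.data import DataLoader

def batches_per_second(dataset, num_workers, n_batches=100):
    """Rough loading-throughput estimate for one worker count (illustrative only)."""
    kwargs = dict(batch_size=None, num_workers=num_workers)
    if num_workers > 0:
        kwargs["prefetch_factor"] = 65  # fetch_factor + 1, as in the examples above
    loader = DataLoader(dataset, **kwargs)
    start, count = time.perf_counter(), 0
    for _ in loader:
        count += 1
        if count >= n_batches:
            break
    return count / (time.perf_counter() - start)

for workers in (0, 2, 4, 8):
    print(f"num_workers={workers}: {batches_per_second(dataset, workers):.1f} batches/s")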
Important Notes
Always use batch_size=None in DataLoader (batching is handled by scDataset)
For backed AnnData, each worker opens its own file handle
The seed parameter ensures reproducibility across workers
DataParallel (DP) Support#
For simpler single-node multi-GPU setups, you can use PyTorch’s DataParallel
instead of DistributedDataParallel. This is easier to set up but less efficient
than DDP for large-scale training.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from scdataset import scDataset, BlockShuffling
# DataParallel doesn't need distributed setup
# scDataset works normally without rank/world_size
dataset = scDataset(
    data,
    BlockShuffling(block_size=64),
    batch_size=128,
    fetch_factor=64,
)
loader = DataLoader(dataset, batch_size=None, num_workers=4, prefetch_factor=65)

# Wrap model with DataParallel
model = nn.Sequential(
    nn.Linear(2000, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
    model = nn.DataParallel(model)
model = model.cuda()

for batch in loader:
    batch = batch.cuda()  # DataParallel handles GPU distribution
    output = model(batch)
    # Training code here
When to Use DataParallel vs DDP
|                 | DataParallel (DP)                   | DistributedDataParallel (DDP)        |
|-----------------|-------------------------------------|--------------------------------------|
| Setup           | Simple, no distributed setup needed | Requires distributed initialization  |
| Efficiency      | Slower due to GIL and data copying  | Faster, recommended for production   |
| Multi-node      | No                                  | Yes                                  |
| Data handling   | Single DataLoader, GPU scatter      | Per-GPU DataLoader, no scatter       |
| scDataset usage | Normal (no rank/world_size)         | Auto-detected rank/world_size        |
Recommendation: Use DDP for serious training workloads. DataParallel is acceptable for quick experiments or when simplicity is more important than maximum throughput.
Further Reading#
PyTorch Issue #23430 - Feature request for DistributedWeightedRandomSampler
PyTorch PR #150182 - Proposed DistributedWeightedRandomSampler (not merged)