Examples#
This section provides comprehensive examples of using scDataset in various scenarios.
Single-Cell Data Analysis#
Working with AnnData#
import anndata as ad
import scanpy as sc
import numpy as np
from scdataset import scDataset, BlockShuffling
from torch.utils.data import DataLoader

# Load single-cell data
adata = sc.datasets.pbmc3k_processed()

# Create a custom fetch function for AnnData
def fetch_anndata(adata, indices):
    # Get the expression matrix and convert to dense if sparse
    data = adata[indices].X
    if hasattr(data, 'toarray'):
        data = data.toarray()
    return data.astype(np.float32)

# Create the dataset with block shuffling
dataset = scDataset(
    adata,
    BlockShuffling(block_size=8),
    batch_size=64,
    fetch_callback=fetch_anndata
)

# Use with DataLoader
loader = DataLoader(dataset, batch_size=None, num_workers=4)

for batch in loader:
    print(f"Processing batch of shape: {batch.shape}")
    # Your model training code here
    break
Class-Balanced Training#
import pandas as pd
from scdataset import ClassBalancedSampling

# Assume you have cell type annotations
cell_types = adata.obs['cell_type'].values

# Create a balanced sampling strategy
strategy = ClassBalancedSampling(
    cell_types,
    total_size=10000,  # Generate 10k balanced samples per epoch
    block_size=8
)

dataset = scDataset(adata, strategy, batch_size=32, fetch_callback=fetch_anndata)
loader = DataLoader(dataset, batch_size=None, num_workers=4)

# Training loop with balanced batches
for epoch in range(10):
    for batch in loader:
        # Each batch will be class-balanced
        train_step(batch)
Multi-Modal Data#
For multi-modal single-cell data (e.g., gene expression + protein measurements),
you can use the MultiIndexable class to keep different data modalities
synchronized during indexing:
import numpy as np
from scdataset import scDataset, BlockShuffling, MultiIndexable
from torch.utils.data import DataLoader

# Simulate multi-modal data
n_cells = 1000
gene_data = np.random.randn(n_cells, 2000)    # Gene expression
protein_data = np.random.randn(n_cells, 100)  # Protein measurements
metadata = np.random.randn(n_cells, 10)       # Cell metadata

# Method 1: Using keyword arguments
multimodal_data = MultiIndexable(
    genes=gene_data,
    proteins=protein_data,
    metadata=metadata
)

# Method 2: Using a dictionary as positional argument
data_dict = {
    'genes': gene_data,
    'proteins': protein_data,
    'metadata': metadata
}
multimodal_data = MultiIndexable(data_dict)

# Create the dataset - all modalities will be indexed together
dataset = scDataset(
    multimodal_data,
    BlockShuffling(block_size=8),
    batch_size=32
)

# Use with DataLoader
loader = DataLoader(dataset, batch_size=None, num_workers=4)

for batch in loader:
    genes = batch['genes']        # Shape: (32, 2000)
    proteins = batch['proteins']  # Shape: (32, 100)
    meta = batch['metadata']      # Shape: (32, 10)
    print(f"Genes: {genes.shape}, Proteins: {proteins.shape}, Meta: {meta.shape}")
    # All correspond to the same 32 cells
    break
Alternative approach with custom fetch function (for AnnData objects):
def fetch_multimodal(adata, indices):
    # Fetch both gene expression and protein data,
    # converting sparse matrices to dense if needed
    gene_data = adata[indices].X
    protein_data = adata[indices].obsm['protein']
    if hasattr(gene_data, 'toarray'):
        gene_data = gene_data.toarray()
    if hasattr(protein_data, 'toarray'):
        protein_data = protein_data.toarray()
    return MultiIndexable(
        genes=gene_data,
        proteins=protein_data
    )

dataset = scDataset(
    adata,
    BlockShuffling(block_size=8),
    batch_size=32,
    fetch_callback=fetch_multimodal
)
Large-Scale Training#
Memory-Efficient Data Loading#
from scdataset import BlockShuffling

# For very large datasets, use a higher fetch factor
dataset = scDataset(
    large_data_collection,
    BlockShuffling(block_size=4),
    batch_size=64,
    fetch_factor=16,  # Fetch 16 batches' worth of data at once
)

# Configure the DataLoader for optimal performance
loader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=12,      # Use multiple workers
    prefetch_factor=17,  # fetch_factor + 1
    pin_memory=True,     # For GPU training
)
Subset Training and Validation#
from sklearn.model_selection import train_test_split
from scdataset import scDataset, BlockShuffling, Streaming

# Split indices for train/validation
indices = np.arange(len(data))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

# Training dataset
train_dataset = scDataset(
    data,
    BlockShuffling(indices=train_idx, block_size=8),
    batch_size=64
)

# Validation dataset (streaming for deterministic evaluation)
val_dataset = scDataset(
    data,
    Streaming(indices=val_idx),
    batch_size=64
)

# Training loader
train_loader = DataLoader(train_dataset, batch_size=None)
# Validation loader
val_loader = DataLoader(val_dataset, batch_size=None)

# Training loop
for epoch in range(num_epochs):
    # Training
    for batch in train_loader:
        train_step(batch)
    # Validation
    for batch in val_loader:
        val_step(batch)
Custom Data Transformations#
On-the-Fly Normalization#
def log_normalize(batch):
    # Apply log1p normalization per batch
    return np.log1p(batch)

def standardize_genes(batch):
    # Standardize genes (features) across the batch
    return (batch - batch.mean(axis=0)) / (batch.std(axis=0) + 1e-8)

dataset = scDataset(
    data,
    BlockShuffling(block_size=8),
    batch_size=64,
    batch_transform=lambda x: standardize_genes(log_normalize(x))
)
Data Augmentation#
def add_noise(batch, noise_level=0.1):
    # Add Gaussian noise for data augmentation
    noise = np.random.normal(0, noise_level, batch.shape)
    return batch + noise

def dropout_genes(batch, dropout_rate=0.1):
    # Randomly set some genes to zero
    mask = np.random.random(batch.shape) > dropout_rate
    return batch * mask

def augment_batch(batch):
    batch = add_noise(batch)
    batch = dropout_genes(batch)
    return batch.astype(np.float32)

dataset = scDataset(
    data,
    BlockShuffling(block_size=8),
    batch_size=64,
    batch_transform=augment_batch
)
Working with HuggingFace Datasets#
Basic Usage#
from datasets import load_dataset
from torch.utils.data import DataLoader
from scdataset import scDataset, Streaming

# Load a HuggingFace dataset
hf_dataset = load_dataset("imdb", split="train[:1000]")

# Custom batch callback for HuggingFace datasets: indexing an HF dataset
# returns a dict of column lists, so the callback picks the requested rows
# from each column (a minimal sketch; adapt it to your dataset).
def extract_hf_batch(fetched_data, batch_indices):
    return {key: [values[i] for i in batch_indices]
            for key, values in fetched_data.items()}

# Create the dataset with the custom batch callback
dataset = scDataset(hf_dataset, Streaming(), batch_size=32,
                    batch_callback=extract_hf_batch)

for batch in DataLoader(dataset, batch_size=None):
    print(batch.keys())
    break
Custom Processing for HuggingFace Data#
# Reuse extract_hf_batch from above and add a batch_transform step
def process_hf_batch(batch_dict):
    # Example processing (a sketch): keep only the fields needed for
    # training; the IMDB split loaded above has 'text' and 'label' columns
    return {'text': batch_dict['text'], 'label': np.array(batch_dict['label'])}

dataset = scDataset(hf_dataset, Streaming(), batch_size=32,
                    batch_callback=extract_hf_batch, batch_transform=process_hf_batch)
Working with MultiIndexable#
The MultiIndexable class provides a convenient way to group multiple
indexable objects that should be indexed together. This is particularly useful
for multi-modal data or for keeping features and labels paired.
Basic MultiIndexable Usage#
import numpy as np
from scdataset import MultiIndexable, scDataset, Streaming
from torch.utils.data import DataLoader

# Create sample data
features = np.random.randn(1000, 50)    # Features
labels = np.random.randint(0, 3, 1000)  # Labels

# Group them together
data = MultiIndexable(features, labels, names=['X', 'y'])
# Or using keyword syntax
data = MultiIndexable(X=features, y=labels)

# Create the dataset
dataset = scDataset(data, Streaming(), batch_size=64)
loader = DataLoader(dataset, batch_size=None)

for batch in loader:
    X_batch = batch['X']  # or batch[0]
    y_batch = batch['y']  # or batch[1]
    print(f"Features: {X_batch.shape}, Labels: {y_batch.shape}")
    break
Multi-Modal Single-Cell Data#
# Simulate CITE-seq data (RNA + protein)
n_cells = 5000
rna_data = np.random.randn(n_cells, 2000)                  # Gene expression
protein_data = np.random.randn(n_cells, 50)                # Surface proteins
cell_types = np.random.choice(['T', 'B', 'NK'], n_cells)   # Labels

# Group all modalities
cite_seq_data = MultiIndexable(
    rna=rna_data,
    proteins=protein_data,
    cell_types=cell_types
)

# Use with class-balanced sampling
from scdataset import ClassBalancedSampling

strategy = ClassBalancedSampling(cell_types, total_size=2000, block_size=16)
dataset = scDataset(cite_seq_data, strategy, batch_size=32)

for batch in dataset:
    rna = batch['rna']            # RNA expression for 32 cells
    proteins = batch['proteins']  # Protein expression for the same 32 cells
    types = batch['cell_types']   # Cell type labels for the same 32 cells
    # All data is synchronized - same cells across modalities
    break
Subsetting and Indexing#
# Create a MultiIndexable
data = MultiIndexable(
    features=np.random.randn(1000, 100),
    labels=np.random.randint(0, 5, 1000),
    metadata=np.random.randn(1000, 10)
)

# Access individual indexables
features = data['features']  # or data[0]
labels = data['labels']      # or data[1]

# Subset by sample indices - returns a new MultiIndexable
subset = data[100:200]            # Samples 100-199 from all modalities
train_data = data[train_indices]  # Training subset

# Check properties
print(f"Original length: {len(data)}")        # 1000 samples
print(f"Subset length: {len(subset)}")        # 100 samples
print(f"Number of modalities: {data.count}")  # 3 modalities
print(f"Modality names: {data.names}")        # ['features', 'labels', 'metadata']
Integration with PyTorch Lightning#
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class SingleCellDataModule(pl.LightningDataModule):
    def __init__(self, data_path, batch_size=64, num_workers=4):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Load your data
        self.data = load_data(self.data_path)
        # Split indices
        indices = np.arange(len(self.data))
        train_idx, val_idx = train_test_split(indices, test_size=0.2)
        # Create datasets
        self.train_dataset = scDataset(
            self.data,
            BlockShuffling(indices=train_idx, block_size=8),
            batch_size=self.batch_size
        )
        self.val_dataset = scDataset(
            self.data,
            Streaming(indices=val_idx),
            batch_size=self.batch_size
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=None,
            num_workers=self.num_workers,
            prefetch_factor=2
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=None,
            num_workers=self.num_workers,
            prefetch_factor=2
        )
Advanced Sampling Strategies#
Custom Weighted Sampling#
from scdataset import BlockWeightedSampling

# Create weights based on cell type frequency (inverse-frequency weighting)
cell_types = adata.obs['cell_type']
type_counts = cell_types.value_counts()
weights = 1.0 / type_counts[cell_types].values
weights = weights / weights.sum()  # Normalize

strategy = BlockWeightedSampling(
    weights=weights,
    total_size=5000,
    block_size=8,
    replace=True
)

dataset = scDataset(adata, strategy, batch_size=32)
Temporal Sampling for Time-Series Data#
# Custom strategy for time-series single-cell data
def create_temporal_indices(timepoints, window_size=5):
    indices = []
    for i in range(len(timepoints) - window_size + 1):
        indices.extend(range(i, i + window_size))
    return np.array(indices)

temporal_indices = create_temporal_indices(adata.obs['timepoint'])

dataset = scDataset(
    adata,
    Streaming(indices=temporal_indices),
    batch_size=32
)
Performance Benchmarking#
import time
from contextlib import contextmanager

@contextmanager
def timer():
    start = time.time()
    yield
    end = time.time()
    print(f"Time taken: {end - start:.2f} seconds")

# Compare different configurations
configs = [
    {'block_size': 32, 'fetch_factor': 1},
    {'block_size': 64, 'fetch_factor': 2},
    {'block_size': 128, 'fetch_factor': 4},
]

for config in configs:
    dataset = scDataset(
        large_data,
        BlockShuffling(block_size=config['block_size']),
        batch_size=64,
        fetch_factor=config['fetch_factor']
    )
    loader = DataLoader(dataset, batch_size=None, num_workers=4)
    with timer():
        for i, batch in enumerate(loader):
            if i >= 100:  # Test the first 100 batches
                break
    print(f"Config {config}: done")
Tips and Best Practices#
Choose appropriate block sizes: Larger blocks (128-512) work better for sequential data access, smaller blocks (16-64) for more randomness.
Use fetch_factor > 1 for better I/O efficiency, especially with slow storage.
Set prefetch_factor = fetch_factor + 1 in DataLoader for optimal performance.
For validation, use the Streaming strategy for deterministic results.
For large datasets, consider using fewer workers but a higher fetch_factor to reduce memory overhead.
Profile your pipeline to find the optimal configuration for your specific data and hardware setup.
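Putting these tips together, a starting configuration might look like the following sketch. The data object and the concrete values (block size, fetch factor, worker count) are placeholders to tune for your own data and hardware:
from scdataset import scDataset, BlockShuffling
from torch.utils.data import DataLoader

# A sketch of a configuration following the tips above; block_size,
# fetch_factor, and num_workers are assumptions to adjust for your setup.
fetch_factor = 8

dataset = scDataset(
    data,
    BlockShuffling(block_size=64),  # mid-sized blocks balance I/O locality and randomness
    batch_size=64,
    fetch_factor=fetch_factor,      # fetch several batches' worth of data per I/O call
)

loader = DataLoader(
    dataset,
    batch_size=None,                   # scDataset already yields full batches
    num_workers=4,                     # fewer workers, higher fetch_factor
    prefetch_factor=fetch_factor + 1,  # fetch_factor + 1, as recommended above
    pin_memory=True,                   # for GPU training
)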