Quickstart: Using scDataset with the Tahoe-100M Dataset#
This tutorial demonstrates how to use scDataset to load the Tahoe-100M single-cell dataset in h5ad format and train a simple linear classifier in PyTorch.
Downloading the Tahoe-100M h5ad Files#
You can download the Tahoe-100M dataset in h5ad format from the Arc Institute Google Cloud Storage or convert it from the HuggingFace version. For this tutorial, download one or more plate*_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad files and note their paths.
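Once the files are on disk, a quick check that they are where you expect can save a failed run later. The snippet below is not part of the original tutorial; it assumes only that the plate files sit together in one local folder (update the path to match your setup).
[ ]:
# Optional: verify the downloaded plate files before loading them
from pathlib import Path

h5ad_dir = Path('/path/to/h5ad_folder')  # hypothetical path; update to your folder
plate_files = sorted(h5ad_dir.glob('plate*_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad'))
print(f'Found {len(plate_files)} plate file(s)')
for f in plate_files:
    print(f'  {f.name}: {f.stat().st_size / 1e9:.1f} GB')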
[ ]:
# Install required packages (uncomment if running in a fresh environment)
# %pip install scipy scikit-learn tqdm torch anndata scDataset
# Import libraries
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import sparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import anndata as ad
from anndata.experimental import AnnCollection
from scdataset import scDataset, Streaming, BlockShuffling
1. Load the Tahoe-100M Dataset (h5ad)#
We will use the AnnData library to load one or more Tahoe-100M h5ad files and combine them into an AnnCollection for efficient access.
[ ]:
# Load multiple Tahoe-100M h5ad files and create an AnnCollection
h5ad_FILES_PATH = '/path/to/h5ad_folder' # Update this to your local folder with the h5ad files
h5ad_paths = [f'{h5ad_FILES_PATH}/plate{i}_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad' for i in range(1, 2)] # Loading only 1 plate for demonstration
adatas = [ad.read_h5ad(path, backed='r') for path in h5ad_paths]
for adata in adatas:
    # Keep 'cell_line' column
    adata.obs = adata.obs[['cell_line']]
collection = AnnCollection(adatas)
print('Total cells:', collection.n_obs)
print('Total genes:', collection.n_vars)
Total cells: 5481420
Total genes: 62710
2. Cell Line Recognition Task and Stratified Split#
For this tutorial, we will train a linear classifier to recognize the cell line of each cell. We will split the dataset into train and test sets using a stratified split on the cell_line column.
[ ]:
# Get cell_line labels from the AnnCollection
cell_lines = collection.obs['cell_line'].values
indices = np.arange(collection.n_obs)
# Stratified train/test split on cell_line
train_idx, test_idx = train_test_split(
    indices, stratify=cell_lines, test_size=0.2, random_state=42)
print(f'Train size: {len(train_idx)}, Test size: {len(test_idx)}')
Train size: 4385136, Test size: 1096284
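As a quick sanity check (not part of the original notebook), you can confirm that the split preserved the per-class proportions; with stratify set, the largest train/test difference should be close to zero.
[ ]:
# Optional: verify that the stratified split preserved cell line proportions
import pandas as pd

train_frac = pd.Series(cell_lines[train_idx]).value_counts(normalize=True)
test_frac = pd.Series(cell_lines[test_idx]).value_counts(normalize=True)
print('Max per-class proportion difference:', (train_frac - test_frac).abs().max())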
Custom fetch_transform and batch_transform for AnnCollection#
Before we wrap our data with scDataset, it’s important to know that scDataset v0.2.0 allows you to provide custom functions for handling how data is fetched and transformed. This makes it easy to adapt to different data formats and preprocessing needs, especially for large or complex datasets like Tahoe-100M.
- fetch_transform: ensures that each chunk from the AnnCollection is materialized as an in-memory AnnData object. This is necessary because AnnCollection chunks are lazy by default and do not load the full X matrix until to_adata() is called.
- batch_transform: converts each AnnData batch into a tuple (X, y) suitable for PyTorch training. It densifies the X matrix if it is sparse and encodes the cell line labels as integer tensors. This makes each batch ready for direct use in a PyTorch model.
[ ]:
# Prepare label encoder for cell lines
cell_line_encoder = LabelEncoder()
cell_line_encoder.fit(collection.obs['cell_line'].values)
# Define fetch_transform and batch_transform for scDataset
def fetch_transform(batch):
    # Materialize the AnnData batch (X matrix) in memory
    return batch.to_adata()

def batch_transform(batch, cell_line_encoder=cell_line_encoder):
    # Convert AnnData batch to (X, y) tensors for training
    X = batch.X.astype('float32')
    # Densify if X is a sparse matrix
    if sparse.issparse(X):
        X = X.toarray()
    X = torch.from_numpy(X)
    y = cell_line_encoder.transform(batch.obs['cell_line'].values)
    y = torch.from_numpy(y).long()
    return X, y
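Before wiring these into scDataset, it can help to smoke-test the pair on a small slice of the collection. This check is not in the original tutorial, but it uses only AnnCollection indexing and the two functions defined above.
[ ]:
# Optional: smoke-test the transforms on the first 8 cells
view = collection[np.arange(8)]   # lazy view of 8 observations
small = fetch_transform(view)     # materialized in-memory AnnData
X, y = batch_transform(small)
print(X.shape, X.dtype)  # expected: torch.Size([8, 62710]) torch.float32
print(y.shape, y.dtype)  # expected: torch.Size([8]) torch.int64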
3. Wrap the AnnCollection with scDataset#
To efficiently train and evaluate models, we use scDataset v0.2.0 with sampling strategies to wrap the AnnCollection and create PyTorch DataLoaders for both train and test splits.
New in v0.2.0: scDataset now uses a strategy-based approach for sampling. Each dataset requires a sampling strategy that defines how data is accessed and in what order.
For training, we’ll use BlockShuffling to randomize data while maintaining some locality for better I/O performance. For evaluation, we’ll use Streaming for deterministic, sequential access.
You can tune batch_size and fetch_factor for your hardware; for best performance, set prefetch_factor = fetch_factor + 1 in the DataLoader. The sampling strategies control how indices are generated and whether shuffling occurs. See the scDataset documentation for more details.
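To make the knobs concrete: as the scDataset documentation describes, each fetch against the backed h5ad reads batch_size × fetch_factor cells in one operation and then yields fetch_factor mini-batches from it, so raising fetch_factor trades RAM per worker for fewer, larger HDF5 reads. A quick worked example with the values used below:
[ ]:
# Rough arithmetic behind the settings used below: one fetch feeds
# fetch_factor mini-batches, so larger fetches mean fewer HDF5 reads.
cells_per_fetch = 64 * 16  # batch_size * fetch_factor
print(f'{cells_per_fetch} cells per fetch -> 16 mini-batches of 64 per fetch')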
[ ]:
# Set up scDataset for train and test splits
batch_size = 64
fetch_factor = 16
num_workers = 12
# Training split with block shuffling for randomization
train_strategy = BlockShuffling(block_size=8, indices=train_idx)
scdata_train = scDataset(
    data_collection=collection,
    strategy=train_strategy,
    batch_size=batch_size,
    fetch_factor=fetch_factor,
    fetch_transform=fetch_transform,
    batch_transform=batch_transform,
)
train_loader = DataLoader(
    scdata_train,
    batch_size=None,
    num_workers=num_workers,
    prefetch_factor=fetch_factor + 1,
    persistent_workers=True,
    pin_memory=True,
)
# Test split with streaming for deterministic evaluation
test_strategy = Streaming(indices=test_idx)
scdata_test = scDataset(
    data_collection=collection,
    strategy=test_strategy,
    batch_size=batch_size,
    fetch_factor=fetch_factor,
    fetch_transform=fetch_transform,
    batch_transform=batch_transform,
)
test_loader = DataLoader(
    scdata_test,
    batch_size=None,
    num_workers=num_workers,
    prefetch_factor=fetch_factor + 1,
    persistent_workers=True,
    pin_memory=True,
)
4. Train and Evaluate a Linear Classifier for Cell Line Recognition#
We will train a simple linear classifier to predict the cell line from gene expression. The model will be trained for one epoch on the training set and evaluated on the test set.
[ ]:
# Model definition
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_genes = collection.n_vars
num_classes = len(cell_line_encoder.classes_)
model = nn.Linear(num_genes, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# Training loop (1 epoch)
model.train()
for i, (x, y) in enumerate(tqdm(train_loader, desc='Training')):
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    loss = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 1000 == 0:
        print(f'Train batch {i}: loss = {loss.item():.4f}')
# Evaluation loop (one pass over test set)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for x, y in tqdm(test_loader, desc='Evaluating'):
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)
print(f'Test accuracy: {100.0 * correct / total:.2f}% ({correct}/{total})')
Train batch 0: loss = 3.9899
Train batch 1000: loss = 0.3440
Train batch 2000: loss = 0.0654
Train batch 3000: loss = 0.0337
Train batch 4000: loss = 0.1021
Train batch 5000: loss = 0.4971
Train batch 6000: loss = 0.1427
Train batch 7000: loss = 0.0129
Train batch 8000: loss = 0.1937
Train batch 9000: loss = 0.3482
Train batch 10000: loss = 0.0561
Train batch 11000: loss = 0.0025
Train batch 12000: loss = 0.0513
Train batch 13000: loss = 0.5847
Train batch 14000: loss = 0.1122
Train batch 15000: loss = 0.4273
Train batch 16000: loss = 0.2817
Train batch 17000: loss = 0.1180
Train batch 18000: loss = 1.0325
Train batch 19000: loss = 0.0683
Train batch 20000: loss = 0.0278
Train batch 21000: loss = 0.2712
Train batch 22000: loss = 0.0053
Train batch 23000: loss = 0.1565
Train batch 24000: loss = 0.0108
Train batch 25000: loss = 0.2106
Train batch 26000: loss = 0.0147
Train batch 27000: loss = 0.3865
Train batch 28000: loss = 0.9068
Train batch 29000: loss = 0.1300
Train batch 30000: loss = 0.4086
Train batch 31000: loss = 0.0032
Train batch 32000: loss = 0.2427
Train batch 33000: loss = 0.5098
Train batch 34000: loss = 1.5931
Train batch 35000: loss = 0.2432
Train batch 36000: loss = 0.6859
Train batch 37000: loss = 0.5389
Train batch 38000: loss = 0.5876
Train batch 39000: loss = 0.1008
Train batch 40000: loss = 0.0231
Train batch 41000: loss = 0.0160
Train batch 42000: loss = 0.3354
Train batch 43000: loss = 0.0277
Train batch 44000: loss = 0.2737
Train batch 45000: loss = 0.0011
Train batch 46000: loss = 0.2583
Train batch 47000: loss = 0.2095
Train batch 48000: loss = 0.5802
Train batch 49000: loss = 0.4957
Train batch 50000: loss = 0.5093
Train batch 51000: loss = 0.9492
Train batch 52000: loss = 0.1337
Train batch 53000: loss = 0.0673
Train batch 54000: loss = 0.4136
Train batch 55000: loss = 0.4310
Train batch 56000: loss = 0.0508
Train batch 57000: loss = 0.8195
Train batch 58000: loss = 0.6672
Train batch 59000: loss = 0.4724
Train batch 60000: loss = 0.7914
Train batch 61000: loss = 0.0464
Train batch 62000: loss = 0.3595
Train batch 63000: loss = 0.2254
Train batch 64000: loss = 0.2061
Train batch 65000: loss = 0.2344
Train batch 66000: loss = 0.0641
Train batch 67000: loss = 0.0083
Train batch 68000: loss = 0.9251
Training: 100%|██████████| 68512/68512 [20:52<00:00, 54.72it/s]
Evaluating: 100%|██████████| 17130/17130 [05:17<00:00, 53.97it/s]
Test accuracy: 97.11% (1064575/1096284)
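If you want to reuse the classifier later, save the weights together with the fitted label encoder so predicted indices can be mapped back to cell line names. A minimal sketch (the file names are arbitrary; joblib ships with scikit-learn):
[ ]:
# Optional: persist the trained model and the label encoder
import joblib

torch.save(model.state_dict(), 'cell_line_linear.pt')
joblib.dump(cell_line_encoder, 'cell_line_encoder.joblib')
# Reload later with:
#   model.load_state_dict(torch.load('cell_line_linear.pt'))
#   cell_line_encoder = joblib.load('cell_line_encoder.joblib')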
5. Summary and Next Steps#
You have now seen how to load the Tahoe-100M dataset, perform a stratified split, and train/test a linear classifier for cell line recognition using scDataset v0.2.0.
Key features demonstrated:
- Strategy-based sampling: used BlockShuffling for training randomization and Streaming for deterministic evaluation
- Custom transforms: applied fetch_transform and batch_transform to handle the AnnCollection data format
- Efficient data loading: configured fetch_factor and multiprocessing for optimal performance
Try exploring other sampling strategies:
- ClassBalancedSampling for handling imbalanced cell type distributions
- BlockWeightedSampling for custom sample weighting
- MultiIndexable for handling multi-modal data (gene expression + protein data)
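Swapping strategies only changes the object passed to scDataset; everything else stays the same. The sketch below is illustrative only: the constructor arguments for ClassBalancedSampling are assumptions, not the documented signature, so check the API reference before use.
[ ]:
# Hypothetical sketch -- argument names below are assumptions, not the documented API
from scdataset import ClassBalancedSampling

balanced_strategy = ClassBalancedSampling(
    labels=cell_lines[train_idx],  # assumed: per-sample class labels
    indices=train_idx,             # assumed: same indices argument as BlockShuffling
)
scdata_balanced = scDataset(
    data_collection=collection,
    strategy=balanced_strategy,
    batch_size=batch_size,
    fetch_factor=fetch_factor,
    fetch_transform=fetch_transform,
    batch_transform=batch_transform,
)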
For more advanced usage and complete API documentation, see the scDataset documentation.