Data Transforms and Callbacks
scDataset provides a flexible data transformation pipeline through four hook points:
fetch_callback, fetch_transform, batch_callback, and batch_transform.
Understanding when and how to use each one is key to efficient data loading.
Built-in Transform Functions
scDataset includes pre-built transform and callback functions for common use cases:
from scdataset.transforms import adata_to_mindex, hf_tahoe_to_tensor, bionemo_to_tensor
- adata_to_mindex: Transforms an AnnData batch into a MultiIndexable object. Handles sparse matrices, backed data materialization, and optional observation columns.
- hf_tahoe_to_tensor: Converts HuggingFace sparse gene expression data to dense tensors or MultiIndexable objects.
- bionemo_to_tensor: Fetch callback for BioNeMo's SingleCellMemMapDataset. Handles the sparse matrix format used by BioNeMo and returns dense tensors. Requires bionemo-scdl.
See the examples below for detailed usage.
Overview
The data loading pipeline in scDataset follows this flow:
┌─────────────────────────────────────────────────────────────────┐
│                       scDataset Pipeline                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. Strategy generates indices                                  │
│     ↓                                                           │
│  2. fetch_callback(collection, indices) → raw_data              │
│     [or default: collection[indices]]                           │
│     ↓                                                           │
│  3. fetch_transform(raw_data) → transformed_data                │
│     [Applied to ENTIRE fetch: batch_size × fetch_factor]        │
│     ↓                                                           │
│  4. batch_callback(transformed_data, batch_indices) → batch     │
│     [or default: transformed_data[batch_indices]]               │
│     ↓                                                           │
│  5. batch_transform(batch) → final_batch                        │
│     [Applied to EACH batch: batch_size samples]                 │
│     ↓                                                           │
│  6. yield final_batch                                           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
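The flow above can be modeled in a few lines of plain Python. This is a simplified sketch for illustration only, not scDataset's implementation. `run_pipeline` is a hypothetical name. Indices are chunked in order, whereas real strategies may shuffle them. Item-by-item lookups stand in for the default `collection[indices]` and `transformed_data[batch_indices]`. The sketch makes the call frequency visible: the fetch hooks run once per fetch of up to batch_size × fetch_factor samples, and the batch hooks run once per batch.

```python
from typing import Callable, Iterator, Optional, Sequence


def run_pipeline(
    collection,
    indices: Sequence[int],
    batch_size: int,
    fetch_factor: int,
    fetch_callback: Optional[Callable] = None,
    fetch_transform: Optional[Callable] = None,
    batch_callback: Optional[Callable] = None,
    batch_transform: Optional[Callable] = None,
) -> Iterator:
    """Toy model of the hook order (hypothetical, not scDataset's code)."""
    fetch_size = batch_size * fetch_factor
    for start in range(0, len(indices), fetch_size):
        fetch_idx = list(indices[start:start + fetch_size])
        # 2. fetch: the custom callback, or item-by-item lookup as a
        #    stand-in for collection[indices] on a plain list
        if fetch_callback is not None:
            raw = fetch_callback(collection, fetch_idx)
        else:
            raw = [collection[i] for i in fetch_idx]
        # 3. fetch_transform runs once on the whole fetch
        data = fetch_transform(raw) if fetch_transform is not None else raw
        for b in range(0, len(fetch_idx), batch_size):
            # Batch positions are relative to the fetched chunk
            batch_idx = list(range(b, min(b + batch_size, len(fetch_idx))))
            # 4. batch_callback, or item-by-item lookup as the default
            if batch_callback is not None:
                batch = batch_callback(data, batch_idx)
            else:
                batch = [data[i] for i in batch_idx]
            # 5. batch_transform runs once per batch; 6. yield the result
            yield batch_transform(batch) if batch_transform is not None else batch
```

Consider 10 samples with batch_size=2 and fetch_factor=3. They form two fetches (6 and 4 samples) and five batches. As a result, fetch_transform runs twice while batch_transform runs five times.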
The Four Hook Points
fetch_callback
Purpose: Custom function to fetch data from the collection using indices.
Signature: (data_collection, indices) -> fetched_data
When to use:
- When your data collection doesn't support standard indexing
- When you need special handling for batch vs. single indexing
- When working with custom data formats (e.g., BioNeMo sparse matrices)
Default behavior: data_collection[indices]
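The sketch below covers the first case. It assumes a hypothetical collection `RowStore` whose `__getitem__` accepts only a single integer, so `data_collection[indices]` would fail. The callback `fetch_rows` is also an illustrative name, not part of scDataset. It looks rows up one at a time and stacks them with NumPy into a 2-D array that later hooks can slice.

```python
import numpy as np


class RowStore:
    """Hypothetical collection that supports only scalar integer indexing."""

    def __init__(self, rows):
        self._rows = rows

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, i):
        if not isinstance(i, (int, np.integer)):
            raise TypeError("RowStore only supports single integer indices")
        return self._rows[i]


def fetch_rows(data_collection, indices):
    """fetch_callback sketch: fetch rows one by one and stack them into a 2-D array."""
    if isinstance(indices, (int, np.integer)):
        indices = [indices]
    return np.stack([np.asarray(data_collection[int(i)]) for i in indices])
```

Passing `fetch_callback=fetch_rows` would then let such a collection be used like any other. The `int(i)` conversion handles indices that arrive as a list or as a NumPy array.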
Example - Custom fetch for a database:
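from scdataset import scDataset, Streaming

def fetch_from_database(db_connection, indices):
    """Fetch rows from a database by indices."""
    # indices are integers, so joining them into the IN (...) list is safe;
    # note that SQL does not guarantee rows come back in the order of `indices`
    query = f"SELECT * FROM data WHERE id IN ({','.join(map(str, indices))})"
    return db_connection.execute(query).fetchall()

dataset = scDataset(
    db_connection,
    Streaming(),
    batch_size=64,
    fetch_callback=fetch_from_database
)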
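The database example can be tried end to end with Python's built-in sqlite3 module. This sketch assumes a table named `data` with an integer `id` column and a `value` column, all hypothetical. It uses `?` placeholders, which is sqlite3's parameter style. It also reorders the returned rows to match `indices`, because SQL does not guarantee that `IN (...)` results come back in the requested order.

```python
import sqlite3


def fetch_from_sqlite(conn, indices):
    """fetch_callback sketch: rows returned in the same order as `indices`."""
    ids = [int(i) for i in indices]
    placeholders = ",".join("?" for _ in ids)
    query = f"SELECT id, value FROM data WHERE id IN ({placeholders})"
    rows = {row[0]: row for row in conn.execute(query, ids).fetchall()}
    # Reorder to match the requested indices
    return [rows[i] for i in ids]


# Build a small in-memory table to try it out
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE data (id INTEGER PRIMARY KEY, value REAL)")
conn.executemany("INSERT INTO data VALUES (?, ?)", [(i, i * 0.5) for i in range(10)])

print(fetch_from_sqlite(conn, [7, 2, 5]))  # [(7, 3.5), (2, 1.0), (5, 2.5)]
```

Reordering in Python keeps the query simple. The alternative, an ORDER BY that reproduces an arbitrary index order, is awkward to express in SQL.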
Example - BioNeMo sparse matrices:
from scdataset import scDataset, BlockShuffling
from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch

def bionemo_to_tensor(data_collection, idx):
    """Handle BioNeMo's sparse matrix format."""
    if isinstance(idx, int):
        # Single index: collate a one-item batch
        return collate_sparse_matrix_batch([data_collection[idx]]).to_dense()
    else:
        # Multiple indices: fetch each row, then collate into one dense tensor
        batch = [data_collection[int(i)] for i in idx]
        return collate_sparse_matrix_batch(batch).to_dense()

dataset = scDataset(
    bionemo_data,
    BlockShuffling(),
    batch_size=64,
    fetch_callback=bionemo_to_tensor
)
fetch_transform
Purpose: Transform data after fetching, before splitting into batches.
Signature: (fetched_data) -> transformed_data
When to use:
- Converting data formats (e.g., AnnData to NumPy)
- Materializing lazy/backed data into memory
- Operations that are more efficient on larger chunks
- Creating MultiIndexable objects for downstream indexing
Applied to: Entire fetch (batch_size × fetch_factor samples)
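An indexable return value matters because the default batch_callback simply does `transformed_data[batch_indices]`. The class below is a hypothetical stand-in for MultiIndexable, written only to illustrate synchronized indexing. It is not the library class, and its behavior may differ. It holds equal-length arrays and applies one index to all of them, so X and labels stay aligned.

```python
import numpy as np


class AlignedArrays:
    """Hypothetical stand-in for MultiIndexable: indexes all arrays together."""

    def __init__(self, **arrays):
        self._arrays = {k: np.asarray(v) for k, v in arrays.items()}
        lengths = {len(v) for v in self._arrays.values()}
        if len(lengths) != 1:
            raise ValueError("need at least one array, all of the same length")

    def __len__(self):
        return len(next(iter(self._arrays.values())))

    def __getitem__(self, key):
        # A string key returns a single named array
        if isinstance(key, str):
            return self._arrays[key]
        # An int becomes a length-1 selection so every entry stays an array
        if isinstance(key, (int, np.integer)):
            key = [key]
        # Any other key (slice, list, array) is applied to every array
        return AlignedArrays(**{k: v[key] for k, v in self._arrays.items()})
```

For example, `AlignedArrays(X=np.arange(12).reshape(6, 2), labels=np.array([0, 1, 0, 1, 1, 0]))[[4, 1]]` selects rows 4 and 1 of X together with labels 4 and 1.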
Built-in: adata_to_mindex
Use adata_to_mindex() for AnnData and AnnCollection:
from scdataset import scDataset, BlockShuffling
from scdataset.transforms import adata_to_mindex
from functools import partial

# Basic usage - returns MultiIndexable with 'X' key
dataset = scDataset(
    ann_collection,
    BlockShuffling(),
    batch_size=64,
    fetch_transform=adata_to_mindex
)

# With observation columns
fetch_fn = partial(adata_to_mindex, columns=['cell_type', 'batch'])
dataset = scDataset(
    ann_collection,
    BlockShuffling(),
    batch_size=64,
    fetch_transform=fetch_fn
)
Built-in: hf_tahoe_to_tensor
Use hf_tahoe_to_tensor() for the Tahoe-100M sparse dataset on HuggingFace:
from scdataset import scDataset, Streaming
from scdataset.transforms import hf_tahoe_to_tensor
from functools import partial

# Returns dense tensor (default)
dataset = scDataset(
    hf_dataset,
    Streaming(),
    batch_size=64,
    fetch_transform=hf_tahoe_to_tensor
)

# With custom gene count
fetch_fn = partial(hf_tahoe_to_tensor, num_genes=62713)

# With dict output format for multi-modal data
fetch_fn = partial(
    hf_tahoe_to_tensor,
    output_format='dict',
    obs_columns=['cell_type', 'batch']
)
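Turning a sparse HuggingFace batch into a dense matrix amounts to scattering each cell's nonzero values into a row of length `num_genes`. The NumPy-only sketch below assumes the fetched batch is a dict of columns, which is what slicing a HuggingFace `Dataset` with a list of indices returns. The column names `genes` and `expressions` are hypothetical. This is not the implementation of hf_tahoe_to_tensor, and the actual Tahoe-100M schema should be checked before reusing the sketch.

```python
import numpy as np


def sparse_columns_to_dense(batch, num_genes, index_key="genes", value_key="expressions"):
    """fetch_transform sketch: scatter per-cell (index, value) lists into a dense matrix.

    `batch` is a dict of columns; the default column names are assumptions,
    not the confirmed Tahoe-100M schema.
    """
    n_cells = len(batch[index_key])
    dense = np.zeros((n_cells, num_genes), dtype=np.float32)
    for row, (idx, vals) in enumerate(zip(batch[index_key], batch[value_key])):
        # Write this cell's nonzero values at their gene positions
        dense[row, np.asarray(idx, dtype=np.intp)] = vals
    return dense
```

As with the built-in, the fixed arguments can be bound with `partial(sparse_columns_to_dense, num_genes=...)` before passing the result as `fetch_transform`.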
Built-in: bionemo_to_tensor
Use bionemo_to_tensor() for BioNeMo’s SingleCellMemMapDataset:
from scdataset import scDataset, BlockShuffling
from scdataset.transforms import bionemo_to_tensor
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# Load BioNeMo dataset
bionemo_data = SingleCellMemMapDataset(data_path='/path/to/data')

# Use bionemo_to_tensor as fetch_callback (not fetch_transform)
dataset = scDataset(
    bionemo_data,
    BlockShuffling(block_size=32),
    batch_size=64,
    fetch_factor=32,
    fetch_callback=bionemo_to_tensor  # Handles sparse matrix collation
)
Note: bionemo_to_tensor is a fetch_callback, not a fetch_transform. It needs access to both the data collection and the indices to handle BioNeMo's sparse matrix format correctly. It requires the bionemo-scdl package (pip install bionemo-scdl).
Custom Transform Example
For custom data formats, write your own transform:
import scipy.sparse as sp
from scdataset import MultiIndexable

def custom_fetch_transform(batch, columns=None):
    """Transform custom batch to MultiIndexable."""
    batch = batch.to_memory()  # Materialize if backed
    X = batch.X
    if sp.issparse(X):
        X = X.toarray()
    data_dict = {'X': X}
    if columns is not None:
        for col in columns:
            data_dict[col] = batch.obs[col].values
    return MultiIndexable(data_dict)
batch_callback
Purpose: Custom function to extract a batch from transformed data.
Signature: (transformed_data, batch_indices) -> batch
When to use:
- When transformed data doesn't support standard indexing
- When you need special slicing logic
- Rarely needed if fetch_transform produces indexable output
Default behavior: transformed_data[batch_indices]
Example - Custom batch extraction:
def custom_batch_callback(data, indices):
    """Extract batch with custom logic."""
    # Assumes `data` is a custom container with get_features/get_labels methods
    return {
        'features': data.get_features(indices),
        'labels': data.get_labels(indices)
    }

dataset = scDataset(
    custom_data,
    Streaming(),
    batch_size=64,
    batch_callback=custom_batch_callback
)
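Here is a minimal runnable version of the example above. It uses a hypothetical `FeatureLabelStore` container that offers `get_features` and `get_labels` accessors but no `__getitem__`. The callback is the same one defined above, repeated so the block is self-contained.

```python
import numpy as np


class FeatureLabelStore:
    """Hypothetical container that exposes accessors instead of __getitem__."""

    def __init__(self, features, labels):
        self._features = np.asarray(features)
        self._labels = np.asarray(labels)

    def __len__(self):
        return len(self._labels)

    def get_features(self, indices):
        return self._features[indices]

    def get_labels(self, indices):
        return self._labels[indices]


def custom_batch_callback(data, indices):
    """batch_callback: build a dict batch via the container's accessors."""
    return {
        "features": data.get_features(indices),
        "labels": data.get_labels(indices),
    }
```

Returning a dict here means that batch_transform receives `batch['features']` and `batch['labels']` rather than a single array.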
batch_transform
Purpose: Transform each individual batch before yielding.
Signature: (batch) -> transformed_batch
When to use:
- Per-batch normalization
- Data augmentation
- Converting to a model-ready format
- Adding noise or other batch-level operations
Applied to: Each individual batch (batch_size samples)
Example - Normalization and augmentation:
import torch
from functools import partial
from scdataset import scDataset, BlockShuffling

def batch_transform(batch, training=True):
    """Normalize and augment a batch; return (X, y) tensors."""
    X, y = batch['X'], batch['labels']
    # Convert to a float tensor if needed
    if not isinstance(X, torch.Tensor):
        X = torch.from_numpy(X).float()
    # Standardize each sample (row) to zero mean and unit variance
    X = (X - X.mean(dim=1, keepdim=True)) / (X.std(dim=1, keepdim=True) + 1e-8)
    # Add small noise for regularization (training only)
    if training:
        X = X + torch.randn_like(X) * 0.01
    # Labels are assumed to be integer-encoded
    return X, torch.as_tensor(y).long()

dataset = scDataset(
    data,
    BlockShuffling(),
    batch_size=64,
    batch_transform=partial(batch_transform, training=True)
)
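The per-sample standardization can also be sanity-checked without torch. This NumPy sketch uses a hypothetical function, `standardize_rows`, that mirrors the example above. It makes training-only noise an explicit `noise_std` argument with an optional seeded generator. One difference to note: `np.std` uses the population formula by default, while torch's `std` applies Bessel's correction by default. The two therefore give slightly different scales, especially for short rows.

```python
import numpy as np


def standardize_rows(X, eps=1e-8, noise_std=0.0, rng=None):
    """Per-sample (row-wise) z-scoring, with optional Gaussian noise."""
    X = np.asarray(X, dtype=np.float32)
    mean = X.mean(axis=1, keepdims=True)
    std = X.std(axis=1, keepdims=True)  # population std (ddof=0)
    X = (X - mean) / (std + eps)
    if noise_std > 0:
        # Noise is opt-in, e.g. enabled only for training
        rng = np.random.default_rng() if rng is None else rng
        X = X + rng.normal(0.0, noise_std, size=X.shape).astype(np.float32)
    return X
```

After standardization each row has mean close to 0 and standard deviation close to 1. A constant row becomes all zeros, because `eps` keeps the division finite.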
Example - Log transformation for gene expression:
import numpy as np

def log_transform(batch):
    """Apply log1p transformation to gene expression."""
    return np.log1p(batch)

dataset = scDataset(
    gene_expression_data,
    Streaming(),
    batch_size=64,
    batch_transform=log_transform
)
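In single-cell workflows, log1p is often preceded by scaling each cell to a fixed total count. This is similar in spirit to Scanpy's `normalize_total` followed by `log1p`. The sketch below assumes `batch` is a dense counts array. The function name and the `target_sum=1e4` default are illustrative choices, not scDataset defaults.

```python
import numpy as np


def normalize_log1p(batch, target_sum=1e4):
    """Scale each cell (row) to `target_sum` total counts, then apply log1p."""
    X = np.asarray(batch, dtype=np.float64)
    totals = X.sum(axis=1, keepdims=True)
    # Avoid dividing by zero for cells with no counts
    totals[totals == 0] = 1.0
    return np.log1p(X / totals * target_sum)
```

The function needs only the batch itself, so it would be passed the same way as `log_transform`, e.g. `batch_transform=normalize_log1p`.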
Best Practices
- Use fetch_transform for heavy operations: Operations such as densifying sparse matrices or loading data from disk are more efficient when applied to larger chunks (the entire fetch) rather than to individual batches.
- Use batch_transform for sample-wise operations: Normalization, augmentation, and format conversion that operate per sample belong in batch_transform.
- Return indexable objects from fetch_transform: If you use fetch_transform, make sure it returns something that can be indexed (NumPy array, tensor, MultiIndexable, etc.) unless you also provide a custom batch_callback.
- Use MultiIndexable for multi-modal data: When your transform creates multiple outputs (X, y, metadata), wrap them in MultiIndexable so they are indexed in sync.
- Profile your transforms: Use Python's cProfile or line_profiler to make sure your transforms aren't a bottleneck.
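The standard library is enough to follow the profiling advice. Run a transform repeatedly under `cProfile`, then print the top entries with `pstats`. The helper `profile_transform` and the random 64 × 2000 batch are illustrative.

```python
import cProfile
import io
import pstats

import numpy as np


def log_transform(batch):
    """Example batch_transform to profile."""
    return np.log1p(batch)


def profile_transform(transform, batch, repeats=100, top=5):
    """Run `transform` repeatedly under cProfile and return a text report."""
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(repeats):
        transform(batch)
    profiler.disable()
    stream = io.StringIO()
    # Sort by cumulative time so the costliest call paths come first
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(top)
    return stream.getvalue()


report = profile_transform(log_transform, np.random.rand(64, 2000))
print(report)
```

Sorting by cumulative time surfaces the transforms that dominate overall. line_profiler, a third-party package, can then break a slow function down line by line.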
Common Use Cases

| Use Case | fetch_transform | batch_transform |
|---|---|---|
| Densify sparse matrix | ✓ | |
| Load backed AnnData | ✓ | |
| Per-sample normalize | ✓ | ✓ |
| Data augmentation | ✓ | ✓ |
| Convert to tensor | ✓ | ✓ |
| Add labels from obs | ✓ | |
| Log transformation | ✓ | ✓ |
See Also
Examples - More complete examples
Quick Start Guide - Getting started guide
API Reference - API reference