# Source code for scdataset.transforms

"""
Transform functions for scDataset.

This module provides utility transform functions for common data formats
like AnnData/AnnCollection, HuggingFace datasets, and BioNeMo datasets.
These transforms can be used as ``fetch_callback``, ``fetch_transform``,
``batch_callback``, or ``batch_transform`` arguments in scDataset.

.. autosummary::
   :toctree: generated/

   adata_to_mindex
   hf_tahoe_to_tensor
   bionemo_to_tensor
"""

from typing import List, Optional, Sequence, Union

import numpy as np
import torch

__all__ = [
    "adata_to_mindex",
    "hf_tahoe_to_tensor",
    "bionemo_to_tensor",
]


def adata_to_mindex(batch, columns: Optional[List[str]] = None):
    """
    Convert an AnnData/AnnCollection batch into a MultiIndexable.

    The batch is first materialized in memory (``to_adata()`` for
    AnnCollectionView, ``to_memory()`` for backed AnnData), its expression
    matrix is densified if sparse, and the result is packed together with
    any requested observation columns into a MultiIndexable that supports
    synchronized indexing in later batch operations.

    Parameters
    ----------
    batch : AnnData-like
        Batch from AnnCollection or backed AnnData. Must expose ``.X``
        (expression matrix) and ``.obs`` (observation metadata), plus
        ``.to_adata()`` or ``.to_memory()`` when lazily backed.
    columns : list of str, optional
        Names of ``obs`` columns to carry along. When None, only the
        expression matrix is included.

    Returns
    -------
    MultiIndexable
        Contains ``'X'`` (dense numpy expression matrix) and one key per
        requested column (numpy arrays).

    Examples
    --------
    >>> from functools import partial
    >>> from scdataset import scDataset, BlockShuffling
    >>> from scdataset.transforms import adata_to_mindex
    >>> fetch_fn = partial(adata_to_mindex, columns=['cell_type', 'batch'])
    >>> dataset = scDataset(
    ...     ann_collection,
    ...     BlockShuffling(),
    ...     batch_size=64,
    ...     fetch_transform=fetch_fn
    ... )
    >>> for batch in dataset:
    ...     X = batch['X']
    ...     cell_types = batch['cell_type']
    ...     break

    See Also
    --------
    MultiIndexable : Container for synchronized multi-modal data
    """
    # Deferred import avoids a circular dependency at module load time.
    from .multiindexable import MultiIndexable

    # scipy is an optional dependency; only needed to detect sparse X.
    try:
        import scipy.sparse as sparse_mod
    except ImportError:
        sparse_mod = None

    # Materialize the batch: AnnCollectionView exposes to_adata(),
    # backed AnnData exposes to_memory(); in-memory AnnData passes through.
    if hasattr(batch, "to_adata"):
        materialized = batch.to_adata()
    elif hasattr(batch, "to_memory"):
        materialized = batch.to_memory()
    else:
        materialized = batch

    matrix = materialized.X
    # Densify sparse matrices so downstream positional indexing works.
    if sparse_mod is not None and sparse_mod.issparse(matrix):
        matrix = matrix.toarray()

    payload = {"X": matrix}
    if columns is not None:
        obs_frame = materialized.obs
        for name in columns:
            payload[name] = obs_frame[name].values

    return MultiIndexable(payload)
def hf_tahoe_to_tensor(batch, num_genes: int = 62713):
    """
    Transform HuggingFace Tahoe-100M sparse gene expression data to dense tensors.

    Converts sparse gene expression data stored in HuggingFace format
    (with separate 'genes' and 'expressions' arrays per sample) into a
    dense PyTorch tensor suitable for model training.

    Parameters
    ----------
    batch : dict or list
        Batch of data from a HuggingFace dataset. Either:
        - dict with 'genes' and 'expressions' keys (lists of arrays), or
        - list of dicts, each with 'genes' and 'expressions' keys.
    num_genes : int, default=62713
        Total number of genes (second dimension of the output tensor).
        Default is the Tahoe-100M gene count.

    Returns
    -------
    torch.Tensor
        Dense float32 tensor of shape (batch_size, num_genes) with gene
        expression values; positions not listed in 'genes' are zero.

    Raises
    ------
    ValueError
        If ``batch`` is neither a dict nor a list of dicts.

    Examples
    --------
    >>> from scdataset import scDataset, BlockShuffling
    >>> from scdataset.transforms import hf_tahoe_to_tensor
    >>> dataset = scDataset(
    ...     hf_dataset,
    ...     BlockShuffling(),
    ...     batch_size=64,
    ...     fetch_transform=hf_tahoe_to_tensor
    ... )

    Notes
    -----
    The sparse-to-dense conversion is done with numpy concatenation before
    a single conversion to PyTorch tensors, which is faster than building
    sparse PyTorch tensors directly.
    """
    if isinstance(batch, dict):
        batch_genes = batch["genes"]
        batch_expr = batch["expressions"]
    elif isinstance(batch, list):
        batch_genes = [item["genes"] for item in batch]
        batch_expr = [item["expressions"] for item in batch]
    else:
        raise ValueError("Batch must be a dictionary or a list of dictionaries.")

    batch_size = len(batch_genes)
    # Guard the empty batch: np.concatenate([]) raises ValueError.
    if batch_size == 0:
        return torch.zeros(0, num_genes, dtype=torch.float32)

    # Row index for every non-zero entry, repeated per-sample length.
    lengths = [len(arr) for arr in batch_genes]
    batch_indices_np = np.repeat(np.arange(batch_size, dtype=np.int64), lengths)

    # Force int64 gene indices: np.concatenate of Python lists yields int32
    # on some platforms (e.g. Windows) and empty arrays default to float64;
    # PyTorch advanced indexing requires integer (long) index tensors.
    gene_indices_np = np.concatenate(
        [np.asarray(g, dtype=np.int64) for g in batch_genes]
    )
    values_np = np.concatenate(
        [np.asarray(e, dtype=np.float32) for e in batch_expr]
    )

    # Single numpy -> torch conversion for each component.
    batch_indices = torch.from_numpy(batch_indices_np)
    gene_indices = torch.from_numpy(gene_indices_np)
    values = torch.from_numpy(values_np)

    # Scatter all values into the dense output in one assignment.
    dense_tensor = torch.zeros(batch_size, num_genes, dtype=values.dtype)
    dense_tensor[batch_indices, gene_indices] = values
    return dense_tensor
def bionemo_to_tensor(
    data_collection, idx: Union[int, slice, Sequence[int], np.ndarray, torch.Tensor]
) -> torch.Tensor:
    """
    Fetch callback for BioNeMo SingleCellMemMapDataset.

    Provides custom indexing logic for BioNeMo's SingleCellMemMapDataset,
    which returns sparse matrices that need to be collated and densified
    for use with scDataset. Use this as a ``fetch_callback`` in scDataset.

    Parameters
    ----------
    data_collection : SingleCellMemMapDataset
        The BioNeMo dataset to fetch from.
    idx : int, slice, sequence, or tensor
        Indices to fetch:
        - int: single index
        - slice: slice object (resolved against ``len(data_collection)``)
        - list/ndarray/tensor: batch of indices

    Returns
    -------
    torch.Tensor
        Dense tensor of shape (batch_size, num_genes) with expression values.

    Examples
    --------
    >>> from scdataset import scDataset, BlockShuffling
    >>> from scdataset.transforms import bionemo_to_tensor
    >>> from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset
    >>> bionemo_data = SingleCellMemMapDataset(data_path='/path/to/data')
    >>> dataset = scDataset(
    ...     bionemo_data,
    ...     BlockShuffling(),
    ...     batch_size=64,
    ...     fetch_callback=bionemo_to_tensor
    ... )

    Raises
    ------
    ImportError
        If bionemo-scdl is not installed.
    TypeError
        If ``idx`` is not one of the supported index types.

    Notes
    -----
    Requires the bionemo-scdl package. The collate function handles the
    sparse matrix format used by BioNeMo.
    """
    # Import the collate helper locally to avoid a hard dependency.
    try:
        from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch
    except ImportError as e:
        raise ImportError(
            "bionemo_to_tensor requires the bionemo-scdl package. "
            "Install it with: pip install bionemo-scdl"
        ) from e

    if isinstance(idx, int):
        # Single sample: wrap in a one-element batch before collating.
        return collate_sparse_matrix_batch([data_collection[idx]]).to_dense()
    if isinstance(idx, slice):
        # Resolve the slice against the dataset length, then fetch each row.
        indices = range(*idx.indices(len(data_collection)))
        samples = [data_collection[i] for i in indices]
        return collate_sparse_matrix_batch(samples).to_dense()
    if isinstance(idx, (list, np.ndarray, torch.Tensor)):
        # Batch indexing; tensors are converted to plain Python ints first.
        if isinstance(idx, torch.Tensor):
            idx = idx.tolist()
        samples = [data_collection[int(i)] for i in idx]
        return collate_sparse_matrix_batch(samples).to_dense()
    raise TypeError(f"Unsupported index type: {type(idx)}")