Source code for scdataset.scdataset

"""
Iterable PyTorch Dataset for on-disk data collections.

This module provides the main :class:`scDataset` class for creating efficient
iterable datasets from on-disk data collections with flexible sampling
strategies and customizable data transformation pipelines.

.. autosummary::
   :toctree: generated/

   scDataset
"""

from typing import Optional, Callable

import numpy as np
from torch.utils.data import IterableDataset, get_worker_info

from .strategy import SamplingStrategy


class scDataset(IterableDataset):
    """
    Iterable PyTorch Dataset for on-disk data collections with flexible
    sampling strategies.

    This dataset implementation provides efficient iteration over large
    on-disk data collections using configurable sampling strategies. It
    supports various data transformations, custom fetch/batch callbacks,
    and automatic handling of multiprocessing workers.

    Parameters
    ----------
    data_collection : object
        The data collection to sample from (e.g., AnnCollection, HuggingFace
        Dataset, numpy array, etc.). Must support indexing (``__getitem__``)
        and length (``__len__``) operations.
    strategy : SamplingStrategy
        Strategy for sampling indices from the data collection. Determines
        the order and selection of samples.
    batch_size : int
        Number of samples per minibatch. Must be positive.
    fetch_factor : int, default=16
        Multiplier for fetch size relative to batch size. Higher values may
        improve I/O efficiency by fetching more data at once.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch if it contains fewer than
        ``batch_size`` samples.
    fetch_transform : Callable, optional
        Function to transform data after fetching but before batching.
        Applied to the entire fetch (multiple batches' worth of data).
    batch_transform : Callable, optional
        Function to transform each individual batch before yielding.
    fetch_callback : Callable, optional
        Custom function to fetch data given indices. Should accept
        ``(data_collection, indices)`` and return the fetched data.
        If None, uses default indexing (``data_collection[indices]``).
    batch_callback : Callable, optional
        Custom function to extract batch data from fetched data. Should
        accept ``(fetched_data, batch_indices)`` and return the batch.
        If None, uses default indexing (``fetched_data[batch_indices]``).

    Attributes
    ----------
    collection : object
        The underlying data collection.
    strategy : SamplingStrategy
        The sampling strategy being used.
    batch_size : int
        Size of each batch.
    fetch_factor : int
        Fetch size multiplier.
    drop_last : bool
        Whether incomplete batches are dropped.
    fetch_size : int
        Total number of samples fetched at once (``batch_size * fetch_factor``).
    sort_before_fetch : bool
        Always True. Indices are sorted before fetching for optimal I/O.

    Raises
    ------
    ValueError
        If ``batch_size`` or ``fetch_factor`` is not positive.
    TypeError
        If ``data_collection`` doesn't support the required operations or
        ``strategy`` is not a SamplingStrategy instance.

    Examples
    --------
    >>> from scdataset import scDataset
    >>> from scdataset.strategy import Streaming
    >>> import numpy as np
    >>> # Simple streaming dataset
    >>> data = np.random.randn(1000, 50)  # 1000 samples, 50 features
    >>> strategy = Streaming()
    >>> dataset = scDataset(data, strategy, batch_size=32)
    >>> len(dataset)  # Number of batches
    32
    >>> # With custom transforms
    >>> def normalize_batch(batch):
    ...     return (batch - batch.mean()) / batch.std()
    >>> dataset = scDataset(
    ...     data, strategy, batch_size=32,
    ...     batch_transform=normalize_batch
    ... )
    >>> # Iterate through batches
    >>> for batch in dataset:
    ...     print(batch.shape)  # (32, 50) for most batches
    ...     break
    (32, 50)

    See Also
    --------
    scdataset.strategy.SamplingStrategy : Base class for sampling strategies
    scdataset.strategy.Streaming : Sequential sampling without shuffling
    scdataset.strategy.BlockShuffling : Block-based shuffling
    scdataset.strategy.BlockWeightedSampling : Weighted sampling with blocks
    scdataset.strategy.ClassBalancedSampling : Automatic class balancing

    Notes
    -----
    The dataset automatically handles PyTorch's multiprocessing by
    distributing fetch ranges among workers. Each worker gets a different
    subset of the data to avoid duplication.

    Data is fetched in chunks of size ``batch_size * fetch_factor`` and then
    divided into batches. This can improve I/O efficiency, especially for
    datasets where accessing non-contiguous indices is expensive.
    """
    def __init__(
        self,
        data_collection,
        strategy: SamplingStrategy,
        batch_size: int,
        fetch_factor: int = 16,
        drop_last: bool = False,
        fetch_transform: Optional[Callable] = None,
        batch_transform: Optional[Callable] = None,
        fetch_callback: Optional[Callable] = None,
        batch_callback: Optional[Callable] = None
    ):
        """
        Initialize the scDataset.

        Parameters
        ----------
        data_collection : object
            Data collection supporting indexing and len().
        strategy : SamplingStrategy
            Sampling strategy instance.
        batch_size : int
            Positive integer for batch size.
        fetch_factor : int, default=16
            Positive integer for fetch size multiplier.
        drop_last : bool, default=False
            Whether to drop incomplete batches.
        fetch_transform : Callable, optional
            Transform applied to fetched data.
        batch_transform : Callable, optional
            Transform applied to each batch.
        fetch_callback : Callable, optional
            Custom fetch function.
        batch_callback : Callable, optional
            Custom batch extraction function.
        """
        # Input validation
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if fetch_factor <= 0:
            raise ValueError("fetch_factor must be positive")
        if not hasattr(data_collection, '__len__') or not hasattr(data_collection, '__getitem__'):
            raise TypeError("data_collection must support indexing and len()")
        if not isinstance(strategy, SamplingStrategy):
            raise TypeError("strategy must be an instance of SamplingStrategy")

        self.collection = data_collection
        self.strategy = strategy
        self.batch_size = batch_size
        self.fetch_factor = fetch_factor
        self.drop_last = drop_last
        self.fetch_size = self.batch_size * self.fetch_factor
        self.sort_before_fetch = True  # Always sort before fetch as per new design

        # Store callback functions
        self.fetch_transform = fetch_transform
        self.batch_transform = batch_transform
        self.fetch_callback = fetch_callback
        self.batch_callback = batch_callback
    def __len__(self) -> int:
        """
        Return the number of batches in the dataset.

        Calculates the number of batches that will be yielded by the
        iterator based on the sampling strategy's effective length and the
        batch size.

        Returns
        -------
        int
            Number of batches in the dataset.

        Examples
        --------
        >>> from scdataset.strategy import Streaming
        >>> dataset = scDataset(range(100), Streaming(), batch_size=10)
        >>> len(dataset)
        10
        >>> # With drop_last=True
        >>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=True)
        >>> len(dataset)  # 105 // 10 = 10 (drops 5 samples)
        10
        >>> # With drop_last=False (default)
        >>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=False)
        >>> len(dataset)  # ceil(105 / 10) = 11
        11

        Notes
        -----
        When ``drop_last=True``, only complete batches are counted. When
        ``drop_last=False``, the last incomplete batch is included in the
        count.
        """
        # Get the total number of samples from the sampling strategy
        n = self.strategy.get_len(self.collection)
        if self.drop_last:
            # When dropping the last incomplete batch, count complete batches only
            return n // self.batch_size
        else:
            # When keeping the last incomplete batch, round up
            return (n + self.batch_size - 1) // self.batch_size
    def __iter__(self):
        """
        Yield batches of data according to the sampling strategy.

        Creates an iterator that yields batches of data by:

        1. Getting indices from the sampling strategy
        2. Dividing indices into fetch ranges
        3. Distributing fetch ranges among workers (if multiprocessing)
        4. Fetching data in chunks and applying fetch transforms
        5. Dividing fetched data into batches and applying batch transforms
        6. Yielding transformed batches

        Yields
        ------
        object
            Batches of data after applying all transforms. The exact type
            depends on the data collection and any applied transforms.

        Examples
        --------
        >>> from scdataset.strategy import Streaming
        >>> import numpy as np
        >>> data = np.random.randn(100, 10)
        >>> dataset = scDataset(data, Streaming(), batch_size=5)
        >>> for i, batch in enumerate(dataset):
        ...     print(f"Batch {i}: shape {batch.shape}")
        ...     if i >= 2:  # Just show first few batches
        ...         break
        Batch 0: shape (5, 10)
        Batch 1: shape (5, 10)
        Batch 2: shape (5, 10)

        Notes
        -----
        The fetch-then-batch approach can improve I/O efficiency by:

        - Sorting indices before fetching for better disk access patterns
        - Fetching multiple batches' worth of data at once
        - Reducing the number of data access operations

        Shuffling behavior is controlled by the sampling strategy's
        ``_shuffle_before_yield`` attribute.
        """
        worker_info = get_worker_info()

        # Seed the RNG. PyTorch assigns each worker the seed base_seed + worker_id,
        # so subtracting the worker id recovers the shared base seed: every worker
        # draws the same index sequence, and the fetch ranges below partition the
        # data without overlap.
        if worker_info is None:
            rng = np.random.default_rng()
        else:
            rng = np.random.default_rng(worker_info.seed - worker_info.id)

        # Get indices from the sampling strategy
        indices = self.strategy.get_indices(self.collection, rng=rng)

        # Calculate fetch ranges
        n = len(indices)
        fetch_size = self.fetch_size
        num_fetches = (n + fetch_size - 1) // fetch_size
        fetch_ranges = [
            (i * fetch_size, min((i + 1) * fetch_size, n))
            for i in range(num_fetches)
        ]

        # Handle multiprocessing by distributing fetch ranges among workers
        if worker_info is not None:
            per_worker = num_fetches // worker_info.num_workers
            remainder = num_fetches % worker_info.num_workers
            if worker_info.id < remainder:
                start = worker_info.id * (per_worker + 1)
                end = start + per_worker + 1
            else:
                start = worker_info.id * per_worker + remainder
                end = start + per_worker
            fetch_ranges = fetch_ranges[start:end]

        # Process each fetch range
        for fetch_start, fetch_end in fetch_ranges:
            fetch_indices = indices[fetch_start:fetch_end]
            if self.sort_before_fetch:
                fetch_indices = np.sort(fetch_indices)

            # Use the custom fetch callback if provided, otherwise default indexing
            if self.fetch_callback is not None:
                data = self.fetch_callback(self.collection, fetch_indices)
            else:
                data = self.collection[fetch_indices]

            # Apply the fetch transform if provided
            if self.fetch_transform is not None:
                data = self.fetch_transform(data)

            # Re-shuffle within the fetch if the strategy requires it, restoring
            # the randomness lost to the pre-fetch sort
            if self.strategy._shuffle_before_yield:
                shuffle_indices = rng.permutation(len(fetch_indices))
            else:
                shuffle_indices = np.arange(len(fetch_indices))

            # Yield batches
            batch_start = 0
            while batch_start < len(fetch_indices):
                batch_end = min(batch_start + self.batch_size, len(fetch_indices))

                # Handle drop_last
                if self.drop_last and (batch_end - batch_start) < self.batch_size:
                    break

                # Get batch indices
                batch_indices = shuffle_indices[batch_start:batch_end]

                # Use the custom batch callback if provided, otherwise default indexing
                if self.batch_callback is not None:
                    batch_data = self.batch_callback(data, batch_indices)
                else:
                    batch_data = data[batch_indices]

                # Apply the batch transform if provided
                if self.batch_transform is not None:
                    batch_data = self.batch_transform(batch_data)

                yield batch_data
                batch_start = batch_end
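

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). A minimal demo
# of the fetch-then-batch pipeline described above, assuming the package is
# installed so the module can be run as ``python -m scdataset.scdataset``.
# The names ``logging_fetch`` and ``fetch_log`` are hypothetical, invented
# for this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    from scdataset.strategy import Streaming

    data = np.random.randn(1000, 50).astype(np.float32)

    # scDataset yields ready-made batches, so wrap it in a DataLoader with
    # batch_size=None to disable PyTorch's own re-batching; the workers split
    # the fetch ranges between themselves as described in __iter__.
    dataset = scDataset(data, Streaming(), batch_size=32, fetch_factor=4)
    loader = DataLoader(dataset, batch_size=None, num_workers=2)
    for i, batch in enumerate(loader):
        print(f"batch {i}: {tuple(batch.shape)}")  # batch 0: (32, 50) ...
        if i >= 2:
            break

    # A fetch_callback that logs each fetch size before delegating to plain
    # indexing (the default behaviour), plus a per-batch normalization.
    fetch_log = []

    def logging_fetch(collection, indices):
        fetch_log.append(len(indices))
        return collection[indices]

    dataset = scDataset(
        data,
        Streaming(),
        batch_size=32,
        fetch_factor=4,
        fetch_callback=logging_fetch,
        batch_transform=lambda b: (b - b.mean()) / (b.std() + 1e-8),
    )
    next(iter(dataset))
    print(f"first fetch pulled {fetch_log[0]} rows")  # 32 * 4 = 128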