scdataset.scDataset


class scdataset.scDataset(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_transform: Callable | None = None, batch_transform: Callable | None = None, fetch_callback: Callable | None = None, batch_callback: Callable | None = None)

Bases: IterableDataset

Iterable PyTorch Dataset for on-disk data collections with flexible sampling strategies.

This dataset implementation provides efficient iteration over large on-disk data collections using configurable sampling strategies. It supports various data transformations, custom fetch/batch callbacks, and automatic handling of multiprocessing workers.

Parameters:
  • data_collection (object) – The data collection to sample from (e.g., AnnCollection, HuggingFace Dataset, or NumPy array). Must support indexing (__getitem__) and length (__len__) operations.

  • strategy (SamplingStrategy) – Strategy for sampling indices from the data collection. Determines the order and selection of samples.

  • batch_size (int) – Number of samples per minibatch. Must be positive.

  • fetch_factor (int, default=16) – Multiplier for fetch size relative to batch size. Must be positive. Higher values may improve I/O efficiency by fetching more data at once.

  • drop_last (bool, default=False) – Whether to drop the last incomplete batch if it contains fewer than batch_size samples.

  • fetch_transform (Callable, optional) – Function to transform data after fetching but before batching. Applied to the entire fetch (multiple batches worth of data).

  • batch_transform (Callable, optional) – Function to transform each individual batch before yielding.

  • fetch_callback (Callable, optional) – Custom function to fetch data given indices. Should accept (data_collection, indices) and return the fetched data. If None, uses default indexing (data_collection[indices]).

  • batch_callback (Callable, optional) – Custom function to extract batch data from fetched data. Should accept (fetched_data, batch_indices) and return the batch. If None, uses default indexing (fetched_data[batch_indices]).
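To illustrate the two callback signatures described above, here is a minimal sketch for a collection that supports only integer indexing (a plain Python list). The helper names list_fetch_callback and list_batch_callback are hypothetical. The sketch assumes that batch_indices are positions within the fetched chunk, as the default fetched_data[batch_indices] suggests. The calls at the bottom only imitate how the callbacks might be invoked and do not use scdataset itself.

```python
# Hypothetical callbacks for a collection that only supports integer indexing
# (a plain Python list). The names are illustrative, not part of scdataset.

def list_fetch_callback(data_collection, indices):
    """Fetch rows one by one; stands in for data_collection[indices]."""
    return [data_collection[i] for i in indices]


def list_batch_callback(fetched_data, batch_indices):
    """Pick one batch out of a fetched chunk; stands in for fetched_data[batch_indices]."""
    return [fetched_data[i] for i in batch_indices]


# Imitate one fetch followed by one batch extraction.
collection = [f"cell_{i}" for i in range(10)]
fetched = list_fetch_callback(collection, [2, 5, 7, 9])
batch = list_batch_callback(fetched, [0, 1])
print(batch)  # ['cell_2', 'cell_5']
```

Callbacks of this shape would be passed as fetch_callback= and batch_callback= when constructing the dataset.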

Variables:
  • collection (object) – The underlying data collection.

  • strategy (SamplingStrategy) – The sampling strategy being used.

  • batch_size (int) – Size of each batch.

  • fetch_factor (int) – Fetch size multiplier.

  • drop_last (bool) – Whether incomplete batches are dropped.

  • fetch_size (int) – Total number of samples fetched at once (batch_size * fetch_factor).

  • sort_before_fetch (bool) – Always True. Indices are sorted before fetching for optimal I/O.

Raises:
  • ValueError – If batch_size or fetch_factor is not positive.

  • TypeError – If data_collection doesn’t support required operations or strategy is not a SamplingStrategy instance.

Examples

>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> # Simple streaming dataset
>>> data = np.random.randn(1000, 50)  # 1000 samples, 50 features
>>> strategy = Streaming()
>>> dataset = scDataset(data, strategy, batch_size=32)
>>> len(dataset)  # Number of batches
32
>>> # With custom transforms
>>> def normalize_batch(batch):
...     return (batch - batch.mean()) / batch.std()
>>> dataset = scDataset(
...     data, strategy, batch_size=32,
...     batch_transform=normalize_batch
... )
>>> # Iterate through batches
>>> for batch in dataset:
...     print(batch.shape)  # (32, 50) for most batches
...     break

See also

scdataset.strategy.SamplingStrategy

Base class for sampling strategies

scdataset.strategy.Streaming

Sequential sampling without shuffling

scdataset.strategy.BlockShuffling

Block-based shuffling

scdataset.strategy.BlockWeightedSampling

Weighted sampling with blocks

scdataset.strategy.ClassBalancedSampling

Automatic class balancing

Notes

When the dataset is iterated by multiple PyTorch worker processes, it distributes fetch ranges among the workers automatically. Each worker receives a different subset of the data, so no sample is duplicated.
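The following sketch shows one way fetch ranges could be split among workers without overlap. Round-robin assignment is an assumption made for illustration; the scheme scDataset actually uses may differ. The helper fetches_for_worker is hypothetical.

```python
# Sketch only: round-robin assignment of fetch ranges to workers.
# The real assignment scheme inside scDataset may differ.

def fetches_for_worker(num_fetches, worker_id, num_workers):
    """Return the fetch-range ids handled by one worker (hypothetical helper)."""
    return list(range(worker_id, num_fetches, num_workers))


num_fetches, num_workers = 7, 3
assignment = {
    w: fetches_for_worker(num_fetches, w, num_workers) for w in range(num_workers)
}
print(assignment)  # {0: [0, 3, 6], 1: [1, 4], 2: [2, 5]}

# Every fetch range is covered exactly once, so no sample is duplicated.
covered = sorted(f for ids in assignment.values() for f in ids)
assert covered == list(range(num_fetches))
```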

Data is fetched in chunks of size batch_size * fetch_factor and then divided into batches. This can improve I/O efficiency, especially for datasets where accessing non-contiguous indices is expensive.
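As a worked example of the chunk arithmetic, the sketch below computes how many samples and batches each fetch would contain for 1000 samples with batch_size=32 and fetch_factor=16. The helper fetch_plan is hypothetical and assumes drop_last=False and simple sequential chunking.

```python
import math


def fetch_plan(n_samples, batch_size, fetch_factor):
    """Return (samples, batches) per fetch for a simple sequential chunking."""
    fetch_size = batch_size * fetch_factor
    plan = []
    for start in range(0, n_samples, fetch_size):
        n_in_fetch = min(fetch_size, n_samples - start)
        # The last batch of a fetch may be incomplete (drop_last=False assumed).
        plan.append((n_in_fetch, math.ceil(n_in_fetch / batch_size)))
    return plan


plan = fetch_plan(n_samples=1000, batch_size=32, fetch_factor=16)
print(plan)  # [(512, 16), (488, 16)]
```

Under these assumptions, two data accesses yield 32 batches in total, rather than one access per batch.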

Methods


__init__(data_collection, strategy: SamplingStrategy, batch_size: int, fetch_factor: int = 16, drop_last: bool = False, fetch_transform: Callable | None = None, batch_transform: Callable | None = None, fetch_callback: Callable | None = None, batch_callback: Callable | None = None)

Initialize the scDataset.

Parameters:
  • data_collection (object) – Data collection supporting indexing and len().

  • strategy (SamplingStrategy) – Sampling strategy instance.

  • batch_size (int) – Positive integer for batch size.

  • fetch_factor (int, default=16) – Positive integer for fetch size multiplier.

  • drop_last (bool, default=False) – Whether to drop incomplete batches.

  • fetch_transform (Callable, optional) – Transform applied to fetched data.

  • batch_transform (Callable, optional) – Transform applied to each batch.

  • fetch_callback (Callable, optional) – Custom fetch function.

  • batch_callback (Callable, optional) – Custom batch extraction function.

__len__() -> int

Return the number of batches in the dataset.

Calculates the number of batches that will be yielded by the iterator based on the sampling strategy’s effective length and the batch size.

Returns:

Number of batches in the dataset.

Return type:

int

Examples

>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> dataset = scDataset(range(100), Streaming(), batch_size=10)
>>> len(dataset)
10
>>> # With drop_last=True
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=True)
>>> len(dataset)  # 105 // 10 = 10 (drops 5 samples)
10
>>> # With drop_last=False (default)
>>> dataset = scDataset(range(105), Streaming(), batch_size=10, drop_last=False)
>>> len(dataset)  # ceil(105 / 10) = 11
11

Notes

When drop_last=True, only complete batches are counted. When drop_last=False, the last incomplete batch is included in the count.
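The counting rule in the examples above can be written as a small standalone function. The name expected_num_batches is hypothetical, and n_samples stands for the sampling strategy's effective length; this is a sketch of the arithmetic, not the library's code.

```python
import math


def expected_num_batches(n_samples, batch_size, drop_last=False):
    """Batch count for n_samples indices (hypothetical helper)."""
    if drop_last:
        # Only complete batches are counted.
        return n_samples // batch_size
    # The final incomplete batch is included.
    return math.ceil(n_samples / batch_size)


print(expected_num_batches(105, 10, drop_last=True))   # 10
print(expected_num_batches(105, 10, drop_last=False))  # 11
```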

__iter__()

Yield batches of data according to the sampling strategy.

Creates an iterator that yields batches of data by:

  1. Getting indices from the sampling strategy

  2. Dividing indices into fetch ranges

  3. Distributing fetch ranges among workers (if multiprocessing)

  4. Fetching data in chunks and applying fetch transforms

  5. Dividing fetched data into batches and applying batch transforms

  6. Yielding transformed batches

Yields:

object – Batches of data after applying all transforms. The exact type depends on the data collection and any applied transforms.
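The six steps above can be sketched as a plain Python generator. This is a simplified stand-in under stated assumptions, not the library's implementation: iterate_batches is a hypothetical name, worker assignment is assumed to be round-robin, indices are sorted within each fetch, and drop_last and shuffling before yielding are not handled.

```python
def iterate_batches(collection, indices, batch_size, fetch_factor,
                    worker_id=0, num_workers=1,
                    fetch_transform=None, batch_transform=None):
    """Hypothetical, simplified stand-in for the six iteration steps."""
    fetch_size = batch_size * fetch_factor
    # Step 2: divide the strategy's indices into fetch ranges.
    fetch_ranges = [
        indices[i:i + fetch_size] for i in range(0, len(indices), fetch_size)
    ]
    # Step 3: keep only this worker's share (round-robin is an assumption).
    for fetch_indices in fetch_ranges[worker_id::num_workers]:
        # Step 4: fetch the chunk in sorted index order, then transform it.
        fetched = [collection[i] for i in sorted(fetch_indices)]
        if fetch_transform is not None:
            fetched = fetch_transform(fetched)
        # Steps 5 and 6: slice the chunk into batches, transform, and yield.
        for start in range(0, len(fetched), batch_size):
            batch = fetched[start:start + batch_size]
            if batch_transform is not None:
                batch = batch_transform(batch)
            yield batch


# Step 1 would come from the sampling strategy; here the indices are sequential.
data = list(range(10))
batches = list(iterate_batches(
    data, list(range(10)), batch_size=3, fetch_factor=2,
    batch_transform=lambda b: [x * 10 for x in b],
))
print(batches)  # [[0, 10, 20], [30, 40, 50], [60, 70, 80], [90]]
```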

Examples

>>> from scdataset import scDataset
>>> from scdataset.strategy import Streaming
>>> import numpy as np
>>> data = np.random.randn(100, 10)
>>> dataset = scDataset(data, Streaming(), batch_size=5)
>>> for i, batch in enumerate(dataset):
...     print(f"Batch {i}: shape {batch.shape}")
...     if i >= 2:  # Just show first few batches
...         break
Batch 0: shape (5, 10)
Batch 1: shape (5, 10)
Batch 2: shape (5, 10)

Notes

The fetch-then-batch approach can improve I/O efficiency by:

  • Sorting indices before fetching for better disk access patterns

  • Fetching multiple batches worth of data at once

  • Reducing the number of data access operations

Shuffling behavior is controlled by the sampling strategy’s _shuffle_before_yield attribute.