scdataset.experimental.estimate_sample_size
- scdataset.experimental.estimate_sample_size(data_collection, n_samples: int = 16, fetch_callback: Callable | None = None, fetch_transform: Callable | None = None, batch_callback: Callable | None = None, batch_transform: Callable | None = None) → int
Estimate the memory size of a single sample from the data collection.
This function samples n_samples elements from the data collection and estimates the average memory size per sample in bytes using recursive deep size estimation. If transforms/callbacks are provided, they are applied so the estimate reflects the actual memory usage during training.
- Parameters:
data_collection (object) – Data collection to estimate sample size from. Must support indexing.
n_samples (int, default=16) – Number of samples to average over for estimation.
fetch_callback (Callable, optional) – Custom fetch function. If provided, called as fetch_callback(data_collection, [index]) to get the sample. This should match the fetch_callback parameter used with scDataset.
fetch_transform (Callable, optional) – Transform to apply after fetching data. This should match the fetch_transform parameter used with scDataset. Applied to the fetched sample before size estimation.
batch_callback (Callable, optional) – Custom batch extraction function. If provided, called as batch_callback(fetched_data, [0]) to extract a single sample. This should match the batch_callback parameter used with scDataset.
batch_transform (Callable, optional) – Transform to apply to batches. This should match the batch_transform parameter used with scDataset. Applied after fetch_transform.
- Returns:
Estimated size per sample in bytes.
- Return type:
int
Examples
>>> import numpy as np
>>> from scdataset.experimental import estimate_sample_size
>>> data = np.random.randn(1000, 2000)  # 1000 samples, 2000 features
>>> size = estimate_sample_size(data)
>>> print(f"Estimated sample size: {size} bytes")
Estimated sample size: 16000 bytes
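Each sample here is one row of 2000 float64 values, so the estimate works out to 2000 × 8 bytes = 16000 bytes.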
For AnnData with fetch transform (not runnable without data):
from scdataset import adata_to_mindex
size = estimate_sample_size(adata_collection, fetch_transform=adata_to_mindex)
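The callbacks mirror those of scDataset. A minimal sketch with custom fetch and batch callbacks (the my_fetch and my_take_first helpers below are hypothetical, written only to match the call conventions described under Parameters):

import numpy as np
from scdataset.experimental import estimate_sample_size

data = np.random.randn(1000, 2000)

# Hypothetical fetch callback: receives the collection and a list with one
# index, returns the fetched rows (plain fancy indexing here).
def my_fetch(collection, indices):
    return collection[indices]

# Hypothetical batch callback: receives the fetched data and [0],
# returns the single sample used for size estimation.
def my_take_first(fetched, indices):
    return fetched[indices]

size = estimate_sample_size(data, fetch_callback=my_fetch, batch_callback=my_take_first)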
Notes
The estimation uses recursive deep size estimation that correctly handles:
- NumPy arrays (using nbytes)
- SciPy sparse matrices (CSR, CSC, COO - counting data, indices, indptr)
- PyTorch tensors (using element_size * numel)
- Pandas DataFrames/Series (using memory_usage(deep=True))
- AnnData objects (counting X, obs, obsm, and layers)
- Dictionaries and lists (recursively counting all elements)
- Strings (UTF-8 encoded byte length)
Shared objects (same id()) are only counted once to avoid double-counting.
When using backed AnnData (backed='r'), it’s important to provide the fetch_transform parameter to get accurate memory estimates, as backed data remains on disk until transformed.
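For example, with a backed AnnData file (a sketch, not runnable without data on disk; the filename is a placeholder):

import anndata as ad
from scdataset import adata_to_mindex
from scdataset.experimental import estimate_sample_size

# In backed mode X stays on disk, so without a fetch_transform the estimate
# would not reflect the data actually materialized in memory during training.
adata = ad.read_h5ad("dataset.h5ad", backed="r")
size = estimate_sample_size(adata, fetch_transform=adata_to_mindex)
print(f"Estimated sample size: {size} bytes")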