Source code for scdataset.experimental.auto_config

"""
Automatic configuration utilities for scDataset.

This module provides functions to automatically suggest optimal parameters
for scDataset based on system resources and data characteristics.

.. autosummary::
   :toctree: generated/

   suggest_parameters
   estimate_sample_size
"""

import os
import sys
import warnings
from typing import Any, Callable, Dict, Optional, Set

import numpy as np


def _deep_sizeof(obj: Any, seen: Optional[Set[int]] = None) -> int:
    """
    Recursively estimate memory size of an object in bytes.

    Handles numpy arrays, scipy sparse matrices, torch tensors,
    pandas DataFrames/Series, dicts, lists, tuples, strings,
    AnnData objects, and nested structures.

    Parameters
    ----------
    obj : Any
        The object to estimate size for.
    seen : set, optional
        Set of object IDs already counted (to avoid double-counting).

    Returns
    -------
    int
        Estimated size in bytes.
    """
    if seen is None:
        seen = set()

    obj_id = id(obj)
    if obj_id in seen:
        return 0  # Avoid counting same object twice
    seen.add(obj_id)

    # Pandas DataFrame - handle first so memory_usage(deep=True) accounts for object columns
    if hasattr(obj, "memory_usage") and hasattr(obj, "columns"):
        try:
            return int(obj.memory_usage(deep=True).sum())
        except (TypeError, AttributeError):
            # Fall through to other handlers if memory_usage fails
            pass  # noqa: B110

    # Pandas Series - check before numpy arrays since Series have nbytes too
    if (
        hasattr(obj, "memory_usage")
        and hasattr(obj, "index")
        and not hasattr(obj, "columns")
    ):
        try:
            return int(obj.memory_usage(deep=True))
        except (TypeError, AttributeError):
            # Fall through to other handlers if memory_usage fails
            pass  # noqa: B110

    # NumPy arrays - check dtype attribute to distinguish from other objects with nbytes
    if hasattr(obj, "nbytes") and hasattr(obj, "dtype") and hasattr(obj, "shape"):
        return int(obj.nbytes)

    # PyTorch tensors
    if hasattr(obj, "element_size") and hasattr(obj, "numel"):
        return int(obj.element_size() * obj.numel())

    # Scipy sparse matrices in compressed formats (CSR, CSC, BSR):
    # count the data, indices, and indptr arrays
    if hasattr(obj, "data") and hasattr(obj, "indices") and hasattr(obj, "format"):
        size = obj.data.nbytes + obj.indices.nbytes
        if hasattr(obj, "indptr"):
            size += obj.indptr.nbytes
        return int(size)

    # Scipy sparse matrices in COO format store row/col index arrays
    # instead of indices/indptr
    if hasattr(obj, "data") and hasattr(obj, "row") and hasattr(obj, "col") and hasattr(obj, "format"):
        return int(obj.data.nbytes + obj.row.nbytes + obj.col.nbytes)

    # AnnCollection objects - count obs and internal structures
    # Check for AnnCollection by looking for adatas attribute and n_obs
    if hasattr(obj, "adatas") and hasattr(obj, "n_obs") and hasattr(obj, "n_vars"):
        size = 0
        # obs DataFrame (concatenated cell metadata) - this gets copied per worker
        if hasattr(obj, "obs") and obj.obs is not None:
            size += _deep_sizeof(obj.obs, seen)
        # var DataFrame (shared gene metadata)
        if hasattr(obj, "var") and obj.var is not None:
            size += _deep_sizeof(obj.var, seen)
        # obs_names and var_names (index objects)
        if hasattr(obj, "obs_names") and obj.obs_names is not None:
            size += _deep_sizeof(obj.obs_names, seen)
        if hasattr(obj, "var_names") and obj.var_names is not None:
            size += _deep_sizeof(obj.var_names, seen)
        # Internal adatas list
        for adata in obj.adatas:
            size += _deep_sizeof(adata, seen)
        return int(size)

    # AnnData objects - special handling for accurate estimation
    if hasattr(obj, "X") and hasattr(obj, "obs") and hasattr(obj, "var_names"):
        size = 0
        # X matrix (main data) - for backed mode, this doesn't load data
        if obj.X is not None:
            size += _deep_sizeof(obj.X, seen)
        # obs DataFrame (cell metadata for this sample)
        if hasattr(obj, "obs") and obj.obs is not None:
            size += _deep_sizeof(obj.obs, seen)
        # var DataFrame (gene metadata)
        if hasattr(obj, "var") and obj.var is not None:
            size += _deep_sizeof(obj.var, seen)
        # obsm matrices (e.g., embeddings)
        if hasattr(obj, "obsm") and obj.obsm is not None:
            for key in obj.obsm.keys():
                size += _deep_sizeof(obj.obsm[key], seen)
        # layers (alternative matrices like raw counts)
        if hasattr(obj, "layers") and obj.layers is not None:
            for key in obj.layers.keys():
                size += _deep_sizeof(obj.layers[key], seen)
        return int(size)

    # MultiIndexable from scdataset - has _indexables list and unstructured property
    if hasattr(obj, "_indexables") and hasattr(obj, "unstructured"):
        size = 0
        for indexable in obj._indexables:
            size += _deep_sizeof(indexable, seen)
        # unstructured is typically shared, count once
        if obj.unstructured:
            size += _deep_sizeof(obj.unstructured, seen)
        return int(size)

    # Dictionaries - recursive
    if isinstance(obj, dict):
        size = sys.getsizeof(obj)  # Dict overhead
        for k, v in obj.items():
            size += _deep_sizeof(k, seen)
            size += _deep_sizeof(v, seen)
        return int(size)

    # Lists and tuples - recursive
    if isinstance(obj, (list, tuple)):
        size = sys.getsizeof(obj)  # Container overhead
        for item in obj:
            size += _deep_sizeof(item, seen)
        return int(size)

    # Strings - UTF-8 encoded size is more accurate than sys.getsizeof
    if isinstance(obj, str):
        return len(obj.encode("utf-8"))

    # Bytes
    if isinstance(obj, bytes):
        return len(obj)

    # Default: sys.getsizeof for primitive types and unknown objects
    return sys.getsizeof(obj)
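

# A minimal illustrative sketch (hypothetical helper, not part of the public API):
# how _deep_sizeof composes across a nested container. Only the array contribution
# is exact; dict/list/str overheads add a few hundred bytes on top.
def _demo_deep_sizeof():  # pragma: no cover - documentation example only
    batch = {
        "X": np.zeros((64, 2000), dtype=np.float32),  # 64 * 2000 * 4 = 512,000 bytes
        "obs_names": [f"cell_{i}" for i in range(64)],  # strings counted as UTF-8 bytes
    }
    return _deep_sizeof(batch)  # ~512,000 bytes plus container overhead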


def estimate_sample_size(
    data_collection,
    n_samples: int = 16,
    fetch_callback: Optional[Callable] = None,
    fetch_transform: Optional[Callable] = None,
    batch_callback: Optional[Callable] = None,
    batch_transform: Optional[Callable] = None,
) -> int:
    """
    Estimate the memory size of a single sample from the data collection.

    This function samples a few elements from the data collection and
    estimates the average memory size per sample in bytes using recursive
    deep size estimation. If transforms/callbacks are provided, they are
    applied to simulate the actual memory usage during training.

    Parameters
    ----------
    data_collection : object
        Data collection to estimate sample size from. Must support indexing.
    n_samples : int, default=16
        Number of samples to average over for estimation.
    fetch_callback : Callable, optional
        Custom fetch function. If provided, called as
        ``fetch_callback(data_collection, [index])`` to get the sample.
        This should match the ``fetch_callback`` parameter used with scDataset.
    fetch_transform : Callable, optional
        Transform to apply after fetching data. This should match the
        ``fetch_transform`` parameter used with scDataset. Applied to the
        fetched sample before size estimation.
    batch_callback : Callable, optional
        Custom batch extraction function. If provided, called as
        ``batch_callback(fetched_data, [0])`` to extract a single sample.
        This should match the ``batch_callback`` parameter used with scDataset.
    batch_transform : Callable, optional
        Transform to apply to batches. This should match the
        ``batch_transform`` parameter used with scDataset. Applied after
        ``fetch_transform``.

    Returns
    -------
    int
        Estimated size per sample in bytes.

    Examples
    --------
    >>> import numpy as np
    >>> data = np.random.randn(1000, 2000)  # 1000 samples, 2000 features
    >>> size = estimate_sample_size(data)
    >>> print(f"Estimated sample size: {size} bytes")
    Estimated sample size: 16000 bytes

    For AnnData with fetch transform (not runnable without data):

    .. code-block:: python

        from scdataset import adata_to_mindex

        size = estimate_sample_size(adata_collection, fetch_transform=adata_to_mindex)

    Notes
    -----
    The estimation uses recursive deep size estimation that correctly handles:

    - NumPy arrays (using ``nbytes``)
    - Scipy sparse matrices (CSR, CSC, COO - counting the data and index arrays)
    - PyTorch tensors (using ``element_size * numel``)
    - Pandas DataFrames/Series (using ``memory_usage(deep=True)``)
    - AnnData objects (counting X, obs, obsm, and layers)
    - Dictionaries and lists (recursively counting all elements)
    - Strings (UTF-8 encoded byte length)

    Shared objects (same ``id()``) are only counted once to avoid
    double-counting.

    When using backed AnnData (``backed='r'``), it is important to provide
    the ``fetch_transform`` parameter to get accurate memory estimates, as
    backed data remains on disk until transformed.
    """
    n_samples = min(n_samples, len(data_collection))
    sizes = []

    # Check if data_collection is a MultiIndexable (uses list indexing for samples)
    is_multiindexable = hasattr(data_collection, "_indexables") and hasattr(
        data_collection, "unstructured"
    )

    for i in range(n_samples):
        # Step 1: Fetch sample (using callback or default indexing)
        if fetch_callback is not None:
            sample = fetch_callback(data_collection, [i])
        elif is_multiindexable:
            # MultiIndexable uses list indexing for samples, not integer indexing
            sample = data_collection[[i]]
        else:
            sample = data_collection[i]

        # Step 2: Apply fetch transform if provided
        if fetch_transform is not None:
            sample = fetch_transform(sample)

        # Step 3: Extract single item via batch_callback if provided
        # (This simulates what happens when extracting a batch of 1)
        if batch_callback is not None:
            sample = batch_callback(sample, [0])

        # Step 4: Apply batch transform if provided
        if batch_transform is not None:
            sample = batch_transform(sample)

        sizes.append(_deep_sizeof(sample))

    return int(np.mean(sizes)) if sizes else 0


def suggest_parameters(
    data_collection,
    batch_size: int,
    target_ram_fraction: float = 0.20,
    max_workers: int = 16,
    min_workers: int = 1,
    verbose: bool = True,
    fetch_callback: Optional[Callable] = None,
    fetch_transform: Optional[Callable] = None,
    batch_callback: Optional[Callable] = None,
    batch_transform: Optional[Callable] = None,
) -> Dict[str, Any]:
    r"""
    Suggest optimal parameters for scDataset based on system resources.

    This function analyzes the data collection and available system resources
    to suggest optimal values for the ``num_workers``, ``fetch_factor``, and
    ``block_size`` parameters.

    Parameters
    ----------
    data_collection : object
        The data collection to be used with scDataset.
    batch_size : int
        The batch size you plan to use.
    target_ram_fraction : float, default=0.20
        Maximum fraction of available RAM to use for data loading.
        The default of 20% leaves room for the model and other processes.
    max_workers : int, default=16
        Maximum number of workers to suggest. More than 16 workers typically
        yields diminishing returns.
    min_workers : int, default=1
        Minimum number of workers to suggest.
    verbose : bool, default=True
        If True, print detailed suggestions and explanations.
    fetch_callback : Callable, optional
        Custom fetch function. Pass the same function you will use with
        scDataset for accurate memory estimation.
    fetch_transform : Callable, optional
        Transform to apply after fetching data. Pass the same function you
        will use with scDataset for accurate memory estimation.
    batch_callback : Callable, optional
        Custom batch extraction function. Pass the same function you will
        use with scDataset for accurate memory estimation.
    batch_transform : Callable, optional
        Transform to apply to batches. Pass the same function you will use
        with scDataset for accurate memory estimation.

    Returns
    -------
    dict
        Dictionary containing suggested parameters:

        - ``num_workers``: Suggested number of DataLoader workers
        - ``fetch_factor``: Suggested fetch factor for scDataset
        - ``block_size_conservative``: Block size for more randomness
          (fetch_factor // 2)
        - ``block_size_balanced``: Block size balancing randomness and
          throughput
        - ``block_size_aggressive``: Block size for maximum throughput
          (fetch_factor * 2)
        - ``prefetch_factor``: Suggested prefetch_factor for DataLoader
        - ``estimated_memory_per_fetch_mb``: Estimated memory per fetch
          operation in MB
        - ``estimated_total_memory_mb``: Estimated total memory across all
          workers (including the prefetch buffer) in MB
        - ``system_info``: Dictionary with system information used for the
          calculation

    Examples
    --------
    >>> import numpy as np
    >>> from scdataset import scDataset, BlockShuffling
    >>> from scdataset.experimental import suggest_parameters
    >>> from torch.utils.data import DataLoader
    >>>
    >>> data = np.random.randn(10000, 200)
    >>> params = suggest_parameters(data, batch_size=64, verbose=False)
    >>>
    >>> # Use suggested parameters
    >>> strategy = BlockShuffling(block_size=params['block_size_balanced'])
    >>> dataset = scDataset(
    ...     data, strategy,
    ...     batch_size=64,
    ...     fetch_factor=params['fetch_factor']
    ... )
    >>> loader = DataLoader(
    ...     dataset, batch_size=None,
    ...     num_workers=min(params['num_workers'], 2),  # Limit for example
    ...     prefetch_factor=params['prefetch_factor']
    ... )

    Notes
    -----
    **Worker selection logic:**

    The number of workers is set to ``min(available_cores // 2, max_workers)``.
    Using half the cores leaves resources for the main process and model
    training.

    **Fetch factor selection logic:**

    The fetch factor is chosen such that the total data loaded by all workers
    does not exceed ``target_ram_fraction`` of available RAM. The calculation
    accounts for prefetching (``prefetch_factor = fetch_factor + 1``), which
    effectively doubles memory usage since both the current and prefetched
    data are in memory simultaneously:

    .. math::

        2 \times batch\_size \times fetch\_factor \times num\_workers
        \times sample\_size < target\_ram\_fraction \times RAM

    The factor of 2 accounts for the prefetch buffer in the DataLoader.

    **Block size recommendations:**

    - ``block_size_conservative`` (fetch_factor // 2): More randomness,
      slightly lower throughput. Good for training where randomization is
      important.
    - ``block_size_balanced`` (fetch_factor): Balanced randomness and
      throughput.
    - ``block_size_aggressive`` (fetch_factor * 2): Higher throughput, less
      randomness.

    Block sizes smaller than ``fetch_factor // 2`` or larger than
    ``fetch_factor * 2`` have diminishing returns.

    Warns
    -----
    UserWarning
        If psutil is not available, conservative defaults are used.
    """
    result = {}
    system_info = {}

    # Try to get system information
    try:
        import psutil

        available_ram = psutil.virtual_memory().available
        total_ram = psutil.virtual_memory().total
        cpu_count = os.cpu_count() or 4
        system_info["available_ram_gb"] = available_ram / (1024**3)
        system_info["total_ram_gb"] = total_ram / (1024**3)
        system_info["cpu_count"] = cpu_count
        has_psutil = True
    except ImportError:
        warnings.warn(
            "psutil not installed. Using conservative defaults. "
            "Install psutil for better parameter suggestions: pip install psutil",
            stacklevel=2,
        )
        # Conservative defaults
        available_ram = 8 * 1024**3  # Assume 8GB available
        total_ram = 16 * 1024**3  # Assume 16GB total
        cpu_count = 4
        system_info["available_ram_gb"] = "unknown (psutil not installed)"
        system_info["total_ram_gb"] = "unknown (psutil not installed)"
        system_info["cpu_count"] = cpu_count
        has_psutil = False

    # Calculate num_workers
    num_workers = min(max(cpu_count // 2, min_workers), max_workers)
    result["num_workers"] = num_workers

    # Estimate sample size (applying transforms/callbacks for accurate estimation)
    sample_size = estimate_sample_size(
        data_collection,
        fetch_transform=fetch_transform,
        batch_transform=batch_transform,
        fetch_callback=fetch_callback,
        batch_callback=batch_callback,
    )
    system_info["estimated_sample_size_bytes"] = sample_size

    # Calculate maximum fetch_factor based on RAM constraint
    # Formula: 2 * batch_size * fetch_factor * num_workers * sample_size
    #          < target_ram_fraction * available_ram
    # The factor of 2 accounts for prefetch_factor = fetch_factor + 1
    # (the prefetch buffer doubles memory)
    target_ram = target_ram_fraction * available_ram
    if sample_size > 0 and batch_size > 0 and num_workers > 0:
        # Account for prefetch doubling memory (factor of 2)
        max_fetch_factor = int(
            target_ram / (2 * batch_size * num_workers * sample_size)
        )
        # Clamp to reasonable range
        fetch_factor = max(1, min(max_fetch_factor, 256))
    else:
        fetch_factor = 8  # Default fallback

    result["fetch_factor"] = fetch_factor

    # Calculate block sizes
    result["block_size_conservative"] = max(1, fetch_factor // 2)
    result["block_size_balanced"] = max(1, fetch_factor)
    result["block_size_aggressive"] = max(1, fetch_factor * 2)

    # Prefetch factor should be fetch_factor + 1 for optimal performance
    result["prefetch_factor"] = fetch_factor + 1

    # Calculate estimated memory usage (includes prefetch buffer - hence * 2)
    memory_per_fetch = batch_size * fetch_factor * sample_size
    memory_total = memory_per_fetch * num_workers * 2  # * 2 for prefetch buffer
    result["estimated_memory_per_fetch_mb"] = memory_per_fetch / (1024**2)
    result["estimated_total_memory_mb"] = memory_total / (1024**2)
    result["system_info"] = system_info

    if verbose:
        print("=" * 60)
        print("scDataset Parameter Suggestions")
        print("=" * 60)
        print()
        print("System Information:")
        if has_psutil:
            print(f"  Available RAM: {system_info['available_ram_gb']:.1f} GB")
            print(f"  Total RAM: {system_info['total_ram_gb']:.1f} GB")
        else:
            print("  RAM info: Not available (install psutil)")
        print(f"  CPU cores: {system_info['cpu_count']}")
        print(
            f"  Estimated sample size: {sample_size:,} bytes ({sample_size/1024:.1f} KB)"
        )
        print()
        print("Suggested Parameters:")
        print(f"  num_workers: {num_workers}")
        print(f"  fetch_factor: {fetch_factor}")
        print(f"  prefetch_factor: {result['prefetch_factor']}")
        print()
        print("Block Size Options (choose based on your needs):")
        print(f"  block_size_conservative: {result['block_size_conservative']}")
        print("    └─ More randomness, good for training")
        print(f"  block_size_balanced: {result['block_size_balanced']}")
        print("    └─ Balanced randomness and throughput (recommended)")
        print(f"  block_size_aggressive: {result['block_size_aggressive']}")
        print("    └─ Maximum throughput, less randomness")
        print()
        print("Memory Estimates (includes prefetch buffer):")
        print(f"  Per fetch: {result['estimated_memory_per_fetch_mb']:.1f} MB")
        print(
            f"  Total (all workers + prefetch): {result['estimated_total_memory_mb']:.1f} MB"
        )
        print(f"  Target RAM usage: {target_ram_fraction*100:.0f}% of available")
        print()
        print("Tips:")
        print("  • block_size = fetch_factor is optimal (recommended)")
        print("  • block_size < fetch_factor/2: diminishing returns on randomness")
        print("  • block_size > fetch_factor*2: diminishing returns on throughput")
        print("  • Increase fetch_factor if I/O is the bottleneck")
        print("  • Decrease num_workers if memory is constrained")
        print("=" * 60)

    return result
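

# An illustrative sketch (hypothetical helper, not part of the public API): the
# fetch_factor bound worked through with assumed numbers (32 GiB available RAM,
# batch_size=64, 8 workers, 80 kB per sample), mirroring the clamping logic in
# suggest_parameters above.
def _demo_fetch_factor_bound():  # pragma: no cover - documentation example only
    available_ram = 32 * 1024**3  # assumed free RAM in bytes
    target_ram = 0.20 * available_ram  # default target_ram_fraction
    batch_size, num_workers, sample_size = 64, 8, 80_000
    max_fetch_factor = int(target_ram / (2 * batch_size * num_workers * sample_size))
    return max(1, min(max_fetch_factor, 256))  # -> 83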