Source code for scdataset.multiindexable

"""
Multi-indexable data structure for synchronized indexing.

This module provides the :class:`MultiIndexable` class for grouping multiple
indexable objects that should be indexed together using the same indices.
This is particularly useful for multi-modal data or when working with
features and labels that need to stay synchronized.

.. autosummary::
   :toctree: generated/

   MultiIndexable
"""

from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np


class MultiIndexable:
    """
    Container for multiple indexable objects that should be indexed together.

    This class allows you to group multiple indexable objects (arrays, lists,
    etc.) and index them synchronously. It's particularly useful for scenarios
    like:

    - Multi-modal single-cell data (gene expression + protein data)
    - Features and labels (X, y) that need to stay aligned
    - Multiple data modalities that share the same sample dimension

    The class supports both positional and named access to the contained
    indexables, and ensures all indexables have the same length along the
    first dimension.

    Additionally, it supports storing unstructured metadata that is not
    indexed but remains accessible after indexing operations. This is useful
    for keeping metadata like gene names, dataset info, or other
    non-sample-aligned data.

    Parameters
    ----------
    *indexables : indexable objects or dict
        Variable number of indexable objects that should be indexed together,
        OR a single dictionary where keys become names and values are
        indexables. All indexables must have the same length in the first
        dimension.
    names : list of str, optional
        Names for the indexables when using positional arguments. Must have
        the same length as the number of indexables. Cannot be used with
        dictionary input.
    unstructured : dict, optional
        Dictionary of non-indexable metadata. This data is preserved unchanged
        when the MultiIndexable is indexed/subsetted. Useful for storing
        metadata like gene names, dataset descriptions, or configuration.
    **named_indexables : dict, optional
        Named indexable objects passed as keyword arguments. Cannot be used
        together with positional indexables.

    Attributes
    ----------
    names : list of str or None
        Names of the indexables if provided, None otherwise.
    count : int
        Number of indexables contained in this object.
    unstructured : dict
        Dictionary of non-indexable metadata (empty dict if none provided).

    Raises
    ------
    ValueError
        If indexables have different lengths along the first dimension, or if
        the number of names doesn't match the number of indexables.
    TypeError
        If both positional and keyword indexables are provided, if no
        indexable is provided at all, if an indexable does not support
        ``len()``, or if unstructured is not a dictionary.

    Examples
    --------
    Create with positional arguments:

    >>> import numpy as np
    >>> x = np.random.randn(100, 50)
    >>> y = np.random.randint(0, 3, 100)
    >>> multi = MultiIndexable(x, y, names=['features', 'labels'])
    >>> len(multi)
    100
    >>> multi.count
    2

    Create with dictionary as positional argument:

    >>> data_dict = {
    ...     'genes': np.random.randn(100, 2000),
    ...     'proteins': np.random.randn(100, 100)
    ... }
    >>> multi = MultiIndexable(data_dict)
    >>> subset = multi[10:20]  # Get samples 10-19 from both modalities
    >>> subset['genes'].shape
    (10, 2000)

    Create with keyword arguments:

    >>> multi = MultiIndexable(
    ...     genes=np.random.randn(100, 2000),
    ...     proteins=np.random.randn(100, 100)
    ... )
    >>> multi.names
    ['genes', 'proteins']

    Create with unstructured metadata:

    >>> gene_names = ['Gene_' + str(i) for i in range(2000)]
    >>> multi = MultiIndexable(
    ...     X=np.random.randn(100, 2000),
    ...     unstructured={'gene_names': gene_names, 'dataset_name': 'MyDataset'}
    ... )
    >>> multi.unstructured['gene_names'][:3]
    ['Gene_0', 'Gene_1', 'Gene_2']
    >>> subset = multi[10:20]  # Unstructured data is preserved
    >>> subset.unstructured['dataset_name']
    'MyDataset'

    Access by name or position:

    >>> multi = MultiIndexable(x, y, names=['x', 'y'])
    >>> same_x1 = multi[0]    # Access by position
    >>> same_x2 = multi['x']  # Access by name
    >>> np.array_equal(same_x1, same_x2)
    True

    Use with scDataset:

    >>> from scdataset import scDataset, Streaming
    >>> dataset = scDataset(multi, Streaming(), batch_size=32)
    >>> for batch in dataset:
    ...     genes, proteins = batch[0], batch[1]  # or batch['genes'], batch['proteins']
    ...     break

    See Also
    --------
    scdataset.scDataset : Main dataset class that can use MultiIndexable objects
    """

    def __init__(
        self,
        *indexables,
        names: Optional[List[str]] = None,
        unstructured: Optional[Dict[str, Any]] = None,
        **named_indexables,
    ):
        """
        Initialize MultiIndexable with indexable objects.

        Can be initialized in four ways:

        1. Positional: MultiIndexable(x, y, z)
        2. Positional with names: MultiIndexable(x, y, names=['x', 'y'])
        3. Dictionary as positional: MultiIndexable({'x': x_data, 'y': y_data})
        4. Named keywords: MultiIndexable(x=x_data, y=y_data)

        All variants support the optional ``unstructured`` parameter for
        non-indexable metadata.

        Raises
        ------
        TypeError
            If positional and keyword indexables are mixed, if
            ``unstructured`` is not a dict, if nothing indexable is given, or
            if an indexable does not support ``len()``.
        ValueError
            If ``names`` has the wrong length, if an empty dictionary is
            given, or if the indexables differ in length.
        """
        # Handle different initialization patterns
        if indexables and named_indexables:
            raise TypeError("Cannot provide both positional and named indexables")

        # Validate and store unstructured data (shallow copy so that later
        # changes to the caller's dict do not leak into this object).
        if unstructured is not None:
            if not isinstance(unstructured, dict):
                raise TypeError(
                    f"unstructured must be a dictionary, got {type(unstructured).__name__}"
                )
            self._unstructured = unstructured.copy()
        else:
            self._unstructured = {}

        # Check for single dictionary as positional argument. When ``names``
        # is given, a lone dict is treated as an ordinary positional
        # indexable instead.
        if (
            len(indexables) == 1
            and isinstance(indexables[0], dict)
            and not named_indexables
            and names is None
        ):
            # Dictionary passed as positional argument
            data_dict = indexables[0]
            self._names = list(data_dict.keys())
            self._indexables = list(data_dict.values())
            self._mapping = data_dict.copy()
        elif named_indexables:
            # Keyword-style initialization
            self._names = list(named_indexables.keys())
            self._indexables = list(named_indexables.values())
            self._mapping = named_indexables.copy()
        elif indexables:
            # Positional initialization
            self._indexables = list(indexables)
            if names is not None:
                if len(names) != len(indexables):
                    raise ValueError(
                        f"Length of names ({len(names)}) must match number of "
                        f"indexables ({len(indexables)})"
                    )
                self._names = list(names)
                self._mapping = dict(zip(self._names, self._indexables))
            else:
                self._names = None
                self._mapping = None
        else:
            raise TypeError("Must provide at least one indexable object")

        # Validate that all indexables have the same length.
        # An empty positional dict reaches this point with no indexables.
        if not self._indexables:
            raise ValueError("No indexables provided")

        try:
            first_len = len(self._indexables[0])
        except TypeError as err:
            raise TypeError("All indexables must support len() operation") from err

        for i, indexable in enumerate(self._indexables[1:], start=1):
            try:
                curr_len = len(indexable)
            except TypeError as err:
                raise TypeError(
                    f"Indexable at position {i} does not support len() operation"
                ) from err
            if curr_len != first_len:
                name_info = f" ('{self._names[i]}')" if self._names else ""
                raise ValueError(
                    f"All indexables must have the same length. "
                    f"First indexable has length {first_len}, but indexable {i}{name_info} "
                    f"has length {curr_len}"
                )

    @property
    def names(self) -> Optional[List[str]]:
        """Names of the indexables, if provided (returned as a copy)."""
        return self._names.copy() if self._names else None

    @property
    def count(self) -> int:
        """Number of indexables contained in this object."""
        return len(self._indexables)

    @property
    def unstructured(self) -> Dict[str, Any]:
        """
        Dictionary of non-indexable metadata.

        This data is preserved unchanged when the MultiIndexable is indexed
        or subsetted. Returns the internal dictionary directly for
        efficiency; modify with care if you need to preserve the original.

        Returns
        -------
        dict
            Dictionary containing unstructured metadata.

        Examples
        --------
        >>> multi = MultiIndexable(
        ...     X=np.zeros((10, 5)),
        ...     unstructured={'gene_names': ['A', 'B', 'C', 'D', 'E']}
        ... )
        >>> multi.unstructured['gene_names']
        ['A', 'B', 'C', 'D', 'E']
        """
        return self._unstructured

    @property
    def unstructured_keys(self) -> List[str]:
        """
        List of keys in the unstructured metadata dictionary.

        Returns
        -------
        list of str
            Keys present in the unstructured dictionary.

        Examples
        --------
        >>> multi = MultiIndexable(
        ...     X=np.zeros((10, 5)),
        ...     unstructured={'gene_names': ['A', 'B'], 'dataset': 'test'}
        ... )
        >>> multi.unstructured_keys
        ['gene_names', 'dataset']
        """
        return list(self._unstructured.keys())

    def __getitem__(self, key: Union[int, str, slice, Sequence[int], np.ndarray]):
        """
        Index the MultiIndexable object.

        Parameters
        ----------
        key : int, str, slice, or array-like
            - int: Return the indexable at that position
            - str: Return the indexable with that name (if names provided)
            - slice/array: Return new MultiIndexable with subsets at those
              sample indices

        Returns
        -------
        object or MultiIndexable
            - Single indexable if key is int or str
            - New MultiIndexable with subsets if key represents sample indices

        Raises
        ------
        IndexError
            If an integer position is out of range, or if any contained
            indexable rejects the sample indices.
        KeyError
            If a string key is used but no names exist, or the name is
            unknown.

        Notes
        -----
        When subsetting with slices or arrays, the unstructured metadata is
        preserved unchanged in the resulting MultiIndexable.

        The subset is rebuilt from positional indexables plus ``names``
        rather than from keyword arguments, so names that collide with
        constructor parameters (``'names'``, ``'unstructured'``), non-string
        names and duplicate names all survive subsetting.
        """
        if isinstance(key, int):
            # Return the indexable at position key (negative counts from end)
            if key < 0:
                key = len(self._indexables) + key
            if not 0 <= key < len(self._indexables):
                raise IndexError(
                    f"Index {key} out of range for {len(self._indexables)} indexables"
                )
            return self._indexables[key]

        if isinstance(key, str):
            # Return the named indexable
            if self._mapping is None:
                raise KeyError(
                    f"No named indexables available. Available indices: 0-{len(self._indexables)-1}"
                )
            if key not in self._mapping:
                raise KeyError(
                    f"Key '{key}' not found. Available keys: {list(self._mapping.keys())}"
                )
            return self._mapping[key]

        # Sample indices - return new MultiIndexable with subsets
        try:
            subset_indexables = [indexable[key] for indexable in self._indexables]
        except (IndexError, TypeError) as e:
            raise IndexError(f"Invalid indices for sample selection: {e}") from e

        # Preserve names and unstructured data if any. ``self._names`` is
        # None for unnamed containers, which selects the plain positional
        # constructor path; otherwise names are re-attached by position.
        return MultiIndexable(
            *subset_indexables,
            names=self._names,
            unstructured=self._unstructured if self._unstructured else None,
        )

    def __len__(self) -> int:
        """Return the number of samples (length of first dimension)."""
        return len(self._indexables[0]) if self._indexables else 0

    def __repr__(self) -> str:
        """Return string representation of the MultiIndexable."""
        n_samples = len(self)
        if self._names:
            indexable_info = f"names={self._names}"
        else:
            indexable_info = f"count={self.count}"

        # Add unstructured info if present
        if self._unstructured:
            unstructured_info = f", unstructured_keys={list(self._unstructured.keys())}"
        else:
            unstructured_info = ""

        return (
            f"MultiIndexable({indexable_info}, samples={n_samples}{unstructured_info})"
        )

    def __iter__(self):
        """Iterate over indexables."""
        return iter(self._indexables)

    def items(self):
        """
        Iterate over (name, indexable) pairs.

        Yields
        ------
        tuple
            (name, indexable) pairs if names are available,
            (index, indexable) pairs otherwise.
        """
        if self._names:
            yield from zip(self._names, self._indexables)
        else:
            yield from enumerate(self._indexables)

    def keys(self):
        """
        Get the names or indices of indexables.

        Returns
        -------
        list
            List of names if available, list of indices otherwise.
        """
        return self._names.copy() if self._names else list(range(len(self._indexables)))

    def values(self):
        """
        Get the indexable objects.

        Returns
        -------
        list
            List of indexable objects (a shallow copy of the internal list).
        """
        return self._indexables.copy()