[BUG] quantile method of BaseDistribution fails for multi-index inputs #678

@marrov

Describe the bug

The quantile() method in skpro.distributions.base.BaseDistribution raises a KeyError when the distribution's row index is a pd.MultiIndex. The error occurs because _Indexer.__getitem__ mishandles its special case for MultiIndex row keys when both the row and column indices are passed as pandas Index objects (e.g., dist.loc[x.index, x.columns]).

To Reproduce

    import pandas as pd
    import numpy as np
    from skpro.distributions import Normal

    # Create a distribution with MultiIndex
    index = pd.MultiIndex.from_tuples([
        ('h0_0', 'h1_0', '2000-03-31'),
        ('h0_0', 'h1_0', '2000-04-01'),
        ('h0_1', 'h1_0', '2000-03-31'),
        ('h0_1', 'h1_0', '2000-04-01'),
    ], names=['h0', 'h1', 'time'])

    columns = pd.Index(['c0'])

    # Create distribution with parameters
    mu = pd.DataFrame(np.random.randn(4, 1), index=index, columns=columns)
    sigma = pd.DataFrame(np.abs(np.random.randn(4, 1)), index=index, columns=columns)

    dist = Normal(mu=mu, sigma=sigma, index=index, columns=columns)

    # This should work but raises KeyError
    quantiles = dist.quantile([0.1, 0.5, 0.9])

Expected behavior

The quantile() method should successfully compute quantiles for distributions with MultiIndex, returning a DataFrame with the same MultiIndex as the distribution and columns representing the requested quantile levels.

Environment

  • Operating System: macOS
  • Python version: 3.12
  • skpro version: (installed from package)
  • pandas version: 2.x

Additional context

The bug is caused by two issues in _base.py:

  1. Lines ~1839-1842 in _Indexer.__getitem__: The special-case handling for MultiIndex incorrectly treats a (pd.MultiIndex, pd.Index) tuple as a single row key instead of as separate row and column indices.

  2. Lines ~248-310 in _get_indexer_like_pandas: There is no early check for keys that is already a pd.Index (including pd.MultiIndex), so it falls through to logic that attempts to use get_locs() on full tuple keys.
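
To illustrate the two pandas lookups involved in the second point, here is a minimal sketch that uses only pandas and the index from the reproduction above. It is not skpro code; it only shows that get_indexer_for accepts full keys passed as an Index, while get_locs is the call suited to partial keys.

```python
import pandas as pd

index = pd.MultiIndex.from_tuples(
    [
        ("h0_0", "h1_0", "2000-03-31"),
        ("h0_0", "h1_0", "2000-04-01"),
        ("h0_1", "h1_0", "2000-03-31"),
        ("h0_1", "h1_0", "2000-04-01"),
    ],
    names=["h0", "h1", "time"],
)

# Full keys passed as a MultiIndex: one position per key, in key order.
keys = index[[2, 0]]
full_positions = index.get_indexer_for(keys)

# Partial key (first level only): positions of all rows under "h0_1".
partial_positions = index.get_locs(("h0_1",))

print(full_positions, partial_positions)
```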

The fix could involve:

  1. Adding a check in _Indexer.__getitem__ to only apply the MultiIndex special case when key elements are not Index-like objects
  2. Adding an early return in _get_indexer_like_pandas when keys is already a pd.Index to use get_indexer_for() directly
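
The first proposed check can be sketched in isolation as follows. The helper names is_index_like and classify_loc_key are hypothetical and exist only for this illustration; the sketch assumes the row index is a MultiIndex and leaves out the other conditions of the real method (the indexer method name and the Empirical exclusion).

```python
import numpy as np
import pandas as pd


def is_index_like(obj):
    """True if obj selects several labels rather than being part of one label."""
    # pd.MultiIndex is a subclass of pd.Index, so it is covered here.
    return isinstance(obj, (pd.Index, list, np.ndarray, slice))


def classify_loc_key(key):
    """Hypothetical helper: decide how a .loc key should be interpreted."""
    if not isinstance(key, tuple):
        return ("rows", key, None)
    has_nested_tuple = any(isinstance(k, tuple) for k in key)
    if not has_nested_tuple and not any(is_index_like(k) for k in key):
        # e.g. ("a", "b"): a single (possibly partial) MultiIndex row label
        return ("row_label", key, None)
    if len(key) != 2:
        raise ValueError("expected one key or a (rows, columns) pair")
    return ("rows_cols", key[0], key[1])


index = pd.MultiIndex.from_tuples([("h0_0", "h1_0"), ("h0_1", "h1_0")])
columns = pd.Index(["c0"])

label_kind = classify_loc_key(("h0_0", "h1_0"))[0]
pair_kind = classify_loc_key((index, columns))[0]
rows_kind = classify_loc_key(index)[0]
print(label_kind, pair_kind, rows_kind)
```

With this check, a (rows, columns) pair of Index objects is no longer mistaken for a single row label, which is the case that dist.loc[x.index, x.columns] hits.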

In my fork, the following versions of BaseDistribution._get_indexer_like_pandas and _Indexer.__getitem__ fix this issue:

    def _get_indexer_like_pandas(self, index, keys):
        """Return indexer for keys in index.

        A unified helper that mimics pandas' get_indexer_for but supports:

        - scalar key (e.g., "a", ("a", 1))
        - tuple key (partial or full)
        - list of keys (partial or full)
        - works for both Index and MultiIndex

        Returns
        -------
        np.ndarray of positions (integers)
        """
        # regular index, not multiindex
        if not isinstance(index, pd.MultiIndex):
            return index.get_indexer_for(keys)

        # if isinstance(index, pd.MultiIndex):

        # If keys is already a pd.Index (including MultiIndex), use get_indexer_for
        # This handles the case where keys is passed as x.index from a DataFrame
        if isinstance(keys, pd.Index):
            return index.get_indexer_for(keys)

        if is_scalar_notnone(keys) or isinstance(keys, tuple):
            keys = [keys]

        n_levels = index.nlevels

        # Check if keys are full keys (tuples with same length as number of levels)
        # If so, use get_indexer_for which handles full tuple keys correctly
        def _is_full_key(key):
            return isinstance(key, tuple) and len(key) == n_levels

        # If all keys are full keys (not slices), use get_indexer_for directly
        # This handles MultiIndex keys correctly
        if all(_is_full_key(k) for k in keys):
            return index.get_indexer_for(keys)

        # Use get_locs for each key (partial key or slice)
        ilocs = []
        for key in keys:
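            if isinstance(key, slice):
                # slice_indexer returns a slice; convert it to integer positions
                # so that it can be concatenated with the other results
                sl = index.slice_indexer(key.start, key.stop, key.step)
                ilocs.append(np.arange(len(index))[sl])
            elif _is_full_key(key):
                # For full tuple keys, use get_loc to get the position
                iloc = index.get_loc(key)
                if isinstance(iloc, slice):
                    iloc = np.arange(len(index))[iloc]
                elif isinstance(iloc, (int, np.integer)):
                    iloc = np.array([iloc])
                else:
                    # get_loc can also return a boolean mask
                    iloc = np.flatnonzero(iloc)
                ilocs.append(iloc)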
            else:
                # For partial keys, use get_locs
                if not isinstance(key, tuple):
                    key = (key,)
                iloc = index.get_locs(key)
                if isinstance(iloc, slice):
                    iloc = np.arange(len(index))[iloc]
                ilocs.append(iloc)
        return np.concatenate(ilocs) if ilocs else np.array([], dtype=int)

    def __getitem__(self, key):
        """Getitem dunder, for use in my_distr.loc[index] and my_distr.iloc[index]."""

        def is_noneslice(obj):
            res = isinstance(obj, slice)
            res = res and obj.start is None and obj.stop is None and obj.step is None
            return res

        ref = self.ref
        indexer = getattr(ref, self.method)

        # handle special case of multiindex in loc with single tuple key
        # This handles cases like dist.loc[('a', 'b')] where ('a', 'b') is a
        # single row label in a MultiIndex. We should NOT trigger this when
        # the key contains Index objects (e.g., dist.loc[x.index, x.columns]).
        def _is_index_like(obj):
            # pd.MultiIndex is a subclass of pd.Index, so pd.Index covers both
            return isinstance(obj, (pd.Index, list, np.ndarray, slice))

        if isinstance(key, tuple) and not any(isinstance(k, tuple) for k in key):
            # Only treat as single row key if no element is Index-like
            if not any(_is_index_like(k) for k in key):
                if isinstance(ref.index, pd.MultiIndex) and self.method == "_loc":
                    if type(ref).__name__ != "Empirical":
                        return indexer(rowidx=key, colidx=None)

        # general case
        if isinstance(key, tuple):
            if not len(key) == 2:
                raise ValueError(
                    "there should be one or two keys when calling .loc, "
                    "e.g., mydist[key], or mydist[key1, key2]"
                )
            rows = key[0]
            cols = key[1]
            if is_noneslice(rows) and is_noneslice(cols):
                return ref
            elif is_noneslice(cols):
                return indexer(rowidx=rows, colidx=None)
            elif is_noneslice(rows):
                return indexer(rowidx=None, colidx=cols)
            else:
                return indexer(rowidx=rows, colidx=cols)
        else:
            return indexer(rowidx=key, colidx=None)
