diff --git a/pyproject.toml b/pyproject.toml
index 969a661..6acaa59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,8 @@
 scikit-learn = "^1.6.1"
 requests-toolbelt = "^1.0.0"
 responses = "^0.25.7"
 aioresponses = "^0.7.8"
+huggingface_hub = "^0.24.0"
+pyarrow = "^18.0.0"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/tfbpapi/AbstractHfAPI.py b/tfbpapi/AbstractHfAPI.py
new file mode 100644
index 0000000..87d64cb
--- /dev/null
+++ b/tfbpapi/AbstractHfAPI.py
@@ -0,0 +1,263 @@
+import hashlib
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Any, Iterable, Mapping
+
+from .ParamsDict import ParamsDict
+
+
+class AbstractHfAPI:
+    """
+    Abstract base class for creating Hugging Face API clients.
+    """
+
+    def __init__(
+        self,
+        repo_id: str = "",
+        repo_type: str | None = "dataset",
+        revision: str | None = None,
+        token: str = "",
+        cache_dir: str | Path | None = None,
+        local_dir: str | Path | None = None,
+        **kwargs,
+    ):
+        """
+        Initialize the HF-backed API client.
+
+        :param repo_id: The repo identifier on HF (e.g., "user/dataset").
+        :param repo_type: One of {"model", "dataset", "space"}. Defaults to "dataset".
+        :param revision: Optional git revision (branch, tag, or commit SHA).
+        :param token: Authentication token. Falls back to the `HF_TOKEN` environment variable.
+        :param cache_dir: HF cache dir; passed to snapshot_download.
+        :param local_dir: Optional local materialization dir; if supported by the
+            installed `huggingface_hub`, downloaded files will be placed here.
+        :param kwargs: Additional keyword arguments for ParamsDict construction.
+        """
+        self.logger = logging.getLogger(self.__class__.__name__)
+        self._token = token or os.getenv("HF_TOKEN", "")
+        self.repo_id = repo_id
+        self.repo_type = repo_type
+        self.revision = revision
+        self.cache_dir = Path(cache_dir) if cache_dir is not None else None
+        self.local_dir = Path(local_dir) if local_dir is not None else None
+        self._snapshot_path: Path | None = None
+
+        self.params = ParamsDict(
+            params=kwargs.pop("params", {}),
+            valid_keys=kwargs.pop("valid_keys", []),
+        )
+
+        self._result_cache: dict[str, dict[str, Any]] = {}
+
+    @property
+    def token(self) -> str:
+        return self._token
+
+    @token.setter
+    def token(self, value: str) -> None:
+        self._token = value
+
+    def _hash_params(self, params: Mapping[str, Any] | None) -> str:
+        """Stable hash for query-result caching across identical filters.
+
+        Incorporates the repo identifiers to avoid cross-repo collisions.
+        """
+        payload = {
+            "repo_id": self.repo_id,
+            "repo_type": self.repo_type,
+            "revision": self.revision,
+            "params": params or {},
+        }
+        data = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
+        return hashlib.sha1(data).hexdigest()
+
+    def ensure_snapshot(
+        self,
+        *,
+        allow_patterns: Iterable[str] | None = None,
+        ignore_patterns: Iterable[str] | None = None,
+        local_files_only: bool = False,
+    ) -> Path:
+        """
+        Ensure a local snapshot exists in the HF cache and return its path.
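+
+        :param allow_patterns: Optional glob patterns selecting which files to download.
+        :param ignore_patterns: Optional glob patterns excluding files from the download.
+        :param local_files_only: If True, use only locally cached files (no network access).
+        :return: Path to the local snapshot directory.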
+        """
+        try:
+            # Lazy import to avoid hard dependency during package import
+            from huggingface_hub import snapshot_download  # type: ignore
+        except Exception as e:  # pragma: no cover - optional dependency
+            raise ImportError(
+                "huggingface_hub is required to use AbstractHfAPI.ensure_snapshot()"
+            ) from e
+
+        kwargs: dict[str, Any] = {
+            "repo_id": self.repo_id,
+            "repo_type": self.repo_type,
+            "revision": self.revision,
+            "cache_dir": str(self.cache_dir) if self.cache_dir is not None else None,
+            "allow_patterns": list(allow_patterns) if allow_patterns else None,
+            "ignore_patterns": list(ignore_patterns) if ignore_patterns else None,
+            "local_files_only": local_files_only,
+            "token": self.token or None,
+        }
+
+        if self.local_dir is not None:
+            kwargs_with_local = dict(kwargs)
+            kwargs_with_local["local_dir"] = str(self.local_dir)
+            try:
+                snapshot = snapshot_download(**kwargs_with_local)  # type: ignore[arg-type]
+            except TypeError:
+                self.logger.info(
+                    "Installed huggingface_hub does not support local_dir; retrying without it"
+                )
+                snapshot = snapshot_download(**kwargs)  # type: ignore[arg-type]
+        else:
+            snapshot = snapshot_download(**kwargs)  # type: ignore[arg-type]
+
+        self._snapshot_path = Path(snapshot)
+        return self._snapshot_path
+
+    def fetch_repo_metadata(self, *, filename_candidates: Iterable[str] | None = None) -> dict[str, Any] | None:
+        """
+        Fetch lightweight metadata for a file in the repo using get_hf_file_metadata.
+
+        Tries a list of candidate files (e.g., README.md, dataset_infos.json). Returns
+        a dict with keys {commit_hash, etag, location, filename} or None if none
+        of the candidates exist.
+        """
+        try:
+            from huggingface_hub import get_hf_file_metadata, hf_hub_url  # type: ignore
+        except Exception as e:  # pragma: no cover - optional dependency
+            raise ImportError(
+                "huggingface_hub is required to use AbstractHfAPI.fetch_repo_metadata()"
+            ) from e
+
+        candidates: list[str] = list(filename_candidates or [
+            "README.md",
+            "README.MD",
+            "README.rst",
+            "README.txt",
+            "dataset_infos.json",
+        ])
+
+        for fname in candidates:
+            try:
+                url = hf_hub_url(
+                    repo_id=self.repo_id,
+                    filename=fname,
+                    repo_type=self.repo_type,
+                    revision=self.revision,
+                )
+                meta = get_hf_file_metadata(url=url, token=self.token or None)
+                return {
+                    "commit_hash": getattr(meta, "commit_hash", None),
+                    "etag": getattr(meta, "etag", None),
+                    "location": getattr(meta, "location", None),
+                    "filename": fname,
+                }
+            except Exception as e:  # EntryNotFoundError, RepositoryNotFoundError, etc.
+                self.logger.debug(f"Metadata not found for {fname}: {e}")
+                continue
+        return None
+
+    def open_dataset(
+        self,
+        snapshot_path: str | Path | None = None,
+        *,
+        files: Iterable[str | Path] | None = None,
+        format: str | None = None,
+        partitioning: str | None = "hive",
+    ):
+        """
+        Build a pyarrow.dataset from the local snapshot.
+
+        :param snapshot_path: Path to the HF snapshot (defaults to the last ensured snapshot).
+        :param files: Optional iterable of files to include explicitly.
+        :param format: Optional format hint (e.g., "parquet").
+        :param partitioning: Optional partitioning strategy (default: "hive").
+        :return: A pyarrow.dataset.Dataset instance.
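+
+        Example (sketch; `api` is any concrete subclass instance that has already
+        called `ensure_snapshot()`):
+
+            dataset = api.open_dataset(format="parquet")
+            table = dataset.to_table()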
+        """
+        try:
+            import pyarrow.dataset as ds  # type: ignore
+        except Exception as e:  # pragma: no cover - optional dependency
+            raise ImportError(
+                "pyarrow is required to use AbstractHfAPI.open_dataset()"
+            ) from e
+
+        base_path = Path(snapshot_path) if snapshot_path is not None else self._snapshot_path
+        if base_path is None:
+            raise RuntimeError("Snapshot not ensured. Call ensure_snapshot() first.")
+
+        if files is not None:
+            file_list = [str(Path(f)) for f in files]
+            dataset = ds.dataset(file_list, format=format or None, partitioning=partitioning)
+        else:
+            dataset = ds.dataset(str(base_path), format=format or None, partitioning=partitioning)
+        return dataset
+
+    def build_query(self, params: Mapping[str, Any]) -> dict[str, Any]:
+        """
+        Translate user params into a dataset-specific query plan.
+
+        Expected keys in the returned dict (subclasses should implement):
+        - file_patterns: Optional[list[str]] of glob patterns relative to the snapshot root
+        - filter: Optional[pyarrow.dataset.Expression] to push down into scans
+        - format: Optional[str] dataset format hint (e.g., "parquet")
+        """
+        raise NotImplementedError(
+            f"`build_query()` is not implemented for {self.__class__.__name__}"
+        )
+
+    def read(self, params: Mapping[str, Any] | None = None, **kwargs) -> dict[str, Any]:
+        """
+        Execute a read using an HF snapshot plus optional dataset scanning.
+
+        Returns a dict with keys: {"metadata": <dict | None>, "data": <pyarrow.Table>}.
+        """
+        params = params or {}
+        cache_key = self._hash_params(params)
+        if cache_key in self._result_cache:
+            self.logger.debug("Returning cached query result")
+            return self._result_cache[cache_key]
+
+        # Build a query plan from params
+        plan = self.build_query(params)
+        file_patterns: Iterable[str] | None = plan.get("file_patterns")
+        filter_expr = plan.get("filter")
+        fmt_hint: str | None = plan.get("format") or kwargs.get("format")
+
+        # Optionally restrict the download via allow_patterns; otherwise download everything
+        snapshot_path = self.ensure_snapshot(
+            allow_patterns=file_patterns,
+        )
+
+        # If patterns were used, build the explicit file list under the snapshot
+        files: list[str] | None = None
+        if file_patterns:
+            files = []
+            for patt in file_patterns:
+                files.extend([str(p) for p in Path(snapshot_path).rglob(patt)])
+
+        dataset = self.open_dataset(snapshot_path=snapshot_path, files=files, format=fmt_hint)
+
+        # Materialize to a pyarrow.Table; subclasses choose the filter expression type
+        try:
+            table = (
+                dataset.to_table(filter=filter_expr)
+                if filter_expr is not None
+                else dataset.to_table()
+            )
+        except Exception as e:
+            # Add context, then re-raise
+            self.logger.error(f"Failed to scan dataset with the provided filter: {e}")
+            raise
+
+        metadata = self.fetch_repo_metadata()
+        result = {"metadata": metadata, "data": table}
+
+        # Memoize
+        self._result_cache[cache_key] = result
+        return result
+
+    def __call__(self, *args, **kwargs):
+        # Delegate to read() for ergonomic usage
+        params = args[0] if args else kwargs.pop("params", None)
+        return self.read(params=params or {}, **kwargs)
\ No newline at end of file
diff --git a/tfbpapi/HfHu2007Reimand2010API.py b/tfbpapi/HfHu2007Reimand2010API.py
new file mode 100644
index 0000000..c930c4f
--- /dev/null
+++ b/tfbpapi/HfHu2007Reimand2010API.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+from typing import Any, Mapping
+
+from .AbstractHfAPI import AbstractHfAPI
+
+
+class HfHu2007Reimand2010API(AbstractHfAPI):
+    """HF-backed API for the BrentLab/hu_2007_reimand_2010 dataset.
+
+    Exposes a minimal interface to scan the Parquet file(s) and return an
+    Arrow table along with lightweight repo metadata via the base class.
+
+    Parameters accepted in `params` for filtering:
+    - regulator_locus_tag: Optional[str | Iterable[str]] - equality/IN filter
+    - regulator_symbol: Optional[str | Iterable[str]] - equality/IN filter
+    - target_locus_tag: Optional[str | Iterable[str]] - equality/IN filter
+    - target_symbol: Optional[str | Iterable[str]] - equality/IN filter
+    - effect_min / effect_max: Optional[float] - numeric range on the log2 fold-change
+    - pval_min / pval_max: Optional[float] - numeric range on the p-value
+    """
+
+    DEFAULT_REPO_ID = "BrentLab/hu_2007_reimand_2010"
+
+    def __init__(
+        self,
+        repo_id: str = DEFAULT_REPO_ID,
+        repo_type: str | None = "dataset",
+        revision: str | None = None,
+        token: str = "",
+        cache_dir: str | None = None,
+        local_dir: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            revision=revision,
+            token=token,
+            cache_dir=cache_dir,
+            local_dir=local_dir,
+            **kwargs,
+        )
+
+    def build_query(self, params: Mapping[str, Any]) -> dict[str, Any]:
+        """Construct a query plan for this dataset.
+
+        The dataset contains a single Parquet file, "hu_2007_reimand_2010.parquet".
+        Basic equality/IN and numeric range filtering is supported on the columns above.
+        """
+        # Files to fetch
+        file_patterns: list[str] = ["hu_2007_reimand_2010.parquet"]
+
+        filter_expr = None
+        filters: list[Any] = []
+
+        # Attempt to construct pushdown expressions; fall back silently if pyarrow is absent
+        try:
+            import pyarrow.dataset as ds  # type: ignore
+        except Exception:
+            ds = None  # type: ignore
+
+        def as_iterable(value: Any) -> list[Any]:
+            if isinstance(value, (list, tuple, set)):
+                return list(value)
+            return [value]
+
+        if ds is not None:
+            # String equality/IN filters
+            for col in (
+                "regulator_locus_tag",
+                "regulator_symbol",
+                "target_locus_tag",
+                "target_symbol",
+            ):
+                if col in params and params[col] is not None:
+                    values = as_iterable(params[col])
+                    if len(values) == 1:
+                        filters.append(ds.field(col) == values[0])
+                    else:
+                        filters.append(ds.field(col).isin(values))
+
+            # Numeric ranges
+            if params.get("effect_min") is not None:
+                try:
+                    filters.append(ds.field("effect") >= float(params["effect_min"]))
+                except Exception:
+                    pass
+            if params.get("effect_max") is not None:
+                try:
+                    filters.append(ds.field("effect") <= float(params["effect_max"]))
+                except Exception:
+                    pass
+            if params.get("pval_min") is not None:
+                try:
+                    filters.append(ds.field("pval") >= float(params["pval_min"]))
+                except Exception:
+                    pass
+            if params.get("pval_max") is not None:
+                try:
+                    filters.append(ds.field("pval") <= float(params["pval_max"]))
+                except Exception:
+                    pass
+
+        # Combine filters with logical AND
+        valid_filters = [f for f in filters if f is not None]
+        if len(valid_filters) == 1:
+            filter_expr = valid_filters[0]
+        elif len(valid_filters) > 1:
+            expr = valid_filters[0]
+            for f in valid_filters[1:]:
+                expr = expr & f
+            filter_expr = expr
+
+        return {
+            "file_patterns": file_patterns,
+            "filter": filter_expr,
+            "format": "parquet",
+        }
+
+    def read_table(self, **kwargs: Any):
+        """Convenience wrapper: return only the Arrow table.
+
+        Equivalent to `self.read(...)["data"]`.
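+
+        Example (sketch; the filter keys follow the params documented on the class):
+
+            api = HfHu2007Reimand2010API()
+            table = api.read_table(params={"regulator_symbol": "HAP4", "pval_max": 0.05})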
+        """
+        result = self.read(params=kwargs.pop("params", {}), **kwargs)
+        return result["data"]
\ No newline at end of file
diff --git a/tfbpapi/__init__.py b/tfbpapi/__init__.py
index 5f54700..151fb4f 100644
--- a/tfbpapi/__init__.py
+++ b/tfbpapi/__init__.py
@@ -14,6 +14,7 @@
 from .rank_transforms import shifted_negative_log_ranks, stable_rank, transform
 from .RankResponseAPI import RankResponseAPI
 from .RegulatorAPI import RegulatorAPI
+from .HfHu2007Reimand2010API import HfHu2007Reimand2010API
 
 __all__ = [
     "BindingAPI",
@@ -34,4 +35,5 @@
     "RegulatorAPI",
     "stable_rank",
     "shifted_negative_log_ranks",
+    "HfHu2007Reimand2010API",
 ]
diff --git a/tfbpapi/tests/test_HfHu2007Reimand2010API.py b/tfbpapi/tests/test_HfHu2007Reimand2010API.py
new file mode 100644
index 0000000..12b085b
--- /dev/null
+++ b/tfbpapi/tests/test_HfHu2007Reimand2010API.py
@@ -0,0 +1,72 @@
+import pytest
+
+from tfbpapi import HfHu2007Reimand2010API
+
+
+@pytest.fixture()
+def api(tmp_path):
+    cache_dir = tmp_path / "hf_cache"
+    local_dir = tmp_path / "hf_local"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    local_dir.mkdir(parents=True, exist_ok=True)
+
+    # The token is optional for public datasets; if HF_TOKEN is set in the environment it is used.
+    return HfHu2007Reimand2010API(
+        cache_dir=str(cache_dir),
+        local_dir=str(local_dir),
+    )
+
+
+def test_retrieval_schema_and_nonempty(api):
+    result = api.read(params={})
+    table = result["data"]
+
+    # Verify expected columns
+    expected_cols = {
+        "regulator_locus_tag",
+        "regulator_symbol",
+        "target_locus_tag",
+        "target_symbol",
+        "effect",
+        "pval",
+    }
+    assert expected_cols.issubset(set(table.schema.names))
+
+    # Non-empty
+    assert table.num_rows > 0
+
+
+def test_filter_by_symbols_and_ranges(api):
+    # Choose a small set of regulators and thresholds
+    params = {
+        "regulator_symbol": ["HAP4", "GCR1", "ACE2", "HAP1"],
+        "effect_min": 1.0,
+        "pval_max": 0.05,
+    }
+    table = api.read(params=params)["data"]
+
+    # Validate that the constraints hold in the materialized result
+    df = table.to_pandas()
+    assert not df.empty
+    assert set(df["regulator_symbol"]).issubset(set(params["regulator_symbol"]))
+    assert (df["effect"] >= params["effect_min"]).all()
+    assert (df["pval"] <= params["pval_max"]).all()
+
+
+def test_memoization_cache_returns_same_object(api):
+    params = {"regulator_symbol": ["HAP4"], "effect_min": 0.5}
+    result1 = api.read(params=params)
+    result2 = api.read(params=params)
+
+    # The in-memory memoization returns the same dict instance
+    assert result1 is result2
+
+
+def test_metadata_best_effort(api):
+    result = api.read(params={})
+    meta = result["metadata"]
+    # Metadata may vary depending on HF API responses; assert type/keys best-effort
+    if meta is not None:
+        assert isinstance(meta, dict)
+        assert any(k in meta for k in ("filename", "etag", "commit_hash", "location"))
\ No newline at end of file