# Source code for VmaxBuilder.utils.lookup_cache

"""
Thread-safe, disk-backed key-value cache for external API lookups.

Designed to be shared by:
- Gene sequence lookups (Ensembl, RefSeq) → namespaces ``ensembl_sequences`` /
  ``refseq_sequences``
- SMILES lookups (future) → namespace ``smiles_metabolite``
- Any other external API calls that are expensive and should survive process restarts

Cache files live in ``{project_root}/.lookup_cache/`` by default, or in a directory
set by the ``VmaxBuilder_CACHE_DIR`` environment variable.  Each namespace is a separate
JSON file, e.g. ``.lookup_cache/ensembl_sequences.json``.

Thread-safety
-------------
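All public methods acquire an internal ``threading.Lock`` so the cache can safely
be used from ``ThreadPoolExecutor`` workers that call :meth:`set` concurrently.
The lock belongs to a single instance: separate ``LookupCache`` objects (or
separate processes) pointing at the same file do not coordinate, and the last
write wins.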

Atomic writes
-------------
Saves write the JSON to a temporary file next to the cache file and then move it
into place with :meth:`Path.replace`, so a crash mid-write never leaves a
partially written cache file.

Usage example
-------------
::

    from src.VmaxBuilder.utils.lookup_cache import (
        LookupCache,
        gene_result_to_dict,
        get_default_cache_dir,
        sequence_cache_key,
    )

    cache = LookupCache(get_default_cache_dir(), "ensembl_sequences")
    key   = sequence_cache_key("homo_sapiens", "ENSG00000139618", "canonical_only")

    if key not in cache:
        result = expensive_api_call(...)
        cache.set(key, gene_result_to_dict(result))   # saved to disk immediately

    data = cache.get(key)  # returns the stored dict

"""

from __future__ import annotations

import json
import os
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

# Name of the cache directory placed under the project root when no
# environment override is configured.
_DEFAULT_CACHE_SUBDIR: str = ".lookup_cache"


def get_default_cache_dir() -> Path:
    """Resolve the default cache directory.

    Resolution order:

    1. The ``VmaxBuilder_CACHE_DIR`` environment variable, if set to a
       non-empty value. Absolute or relative paths are accepted; a leading
       ``~`` is expanded to the user's home directory, and the result is
       resolved to an absolute path (relative paths resolve against the
       current working directory).
    2. ``{project_root}/.lookup_cache/``.

    Returns
    -------
    Path
        The cache directory. It is not created by this function.
    """
    # An empty string is treated the same as "not set".
    env_override = os.environ.get("VmaxBuilder_CACHE_DIR")
    if env_override:
        # expanduser() so "~/cache" does not become a literal "./~/cache".
        return Path(env_override).expanduser().resolve()

    # Lazy import to avoid a circular dependency at module load time.
    from .file_handling import get_project_root  # noqa: PLC0415

    return get_project_root() / _DEFAULT_CACHE_SUBDIR
# ---------------------------------------------------------------------------
# Core cache class
# ---------------------------------------------------------------------------
class LookupCache:
    """Thread-safe, disk-backed key-value store for a single namespace.

    Parameters
    ----------
    cache_dir:
        Directory where the JSON file for this namespace lives. A ``str`` is
        accepted and converted to :class:`~pathlib.Path`.
    namespace:
        Logical name of the cache. ``".json"`` is appended to form the file
        name, e.g. ``"ensembl_sequences"`` → ``cache_dir/ensembl_sequences.json``.
        Dots inside the namespace are preserved (``"genes.v2"`` →
        ``genes.v2.json``); a namespace that already ends in ``".json"`` is
        used as-is.
    autosave:
        When ``True`` (the default) every :meth:`set` / :meth:`set_many` /
        :meth:`invalidate` call flushes the updated data to disk immediately.
        Set to ``False`` for bulk-loading scenarios where you want to call
        :meth:`save` once at the end. :meth:`clear` always writes.

    Notes
    -----
    The lock is per instance. If an autosave write fails (unwritable
    directory, a value that is not JSON-serializable, ...), the in-memory
    change is undone before the exception propagates. This means a single bad
    value cannot make every later write fail.
    """

    def __init__(
        self,
        cache_dir: Path | str,
        namespace: str,
        autosave: bool = True,
    ) -> None:
        # Append ".json" rather than calling Path.with_suffix(): with_suffix
        # replaces the last dotted part, so "genes.v1" and "genes.v2" would
        # both collapse onto "genes.json".
        filename = namespace if namespace.endswith(".json") else f"{namespace}.json"
        self._path = Path(cache_dir) / filename
        self._lock = threading.Lock()
        self._autosave = autosave
        # Loaded once here, before the instance can be shared between threads.
        self._data: dict[str, Any] = self._load_from_disk()

    # ------------------------------------------------------------------
    # Internal I/O
    # ------------------------------------------------------------------

    def _load_from_disk(self) -> dict[str, Any]:
        """Read the backing file, falling back to an empty cache.

        A missing or unreadable file, invalid JSON, bytes that are not valid
        UTF-8, or a top-level JSON value that is not an object all yield
        ``{}``; those entries are simply treated as absent.
        """
        try:
            with self._path.open("r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # OSError covers FileNotFoundError, permission errors, etc.
            return {}
        return data if isinstance(data, dict) else {}

    def _write_to_disk(self) -> None:
        """Atomically replace the backing file with the current entries.

        Callers must hold ``self._lock``. The entries are serialized before
        any file is touched, so an unserializable value raises without leaving
        a stray temp file. The temp file name embeds the process id and thread
        id, so independent writers targeting the same cache file never share
        one temp file.

        Raises
        ------
        OSError
            If the directory cannot be created or the file cannot be written.
        TypeError, ValueError
            If the entries cannot be serialized to JSON.
        """
        payload = json.dumps(self._data, indent=2)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self._path.with_name(
            f"{self._path.name}.{os.getpid()}.{threading.get_ident()}.tmp"
        )
        try:
            tmp.write_text(payload, encoding="utf-8")
            tmp.replace(self._path)
        except BaseException:
            # Remove the partial temp file on any failure (including
            # KeyboardInterrupt), then let the original exception propagate.
            tmp.unlink(missing_ok=True)
            raise

    def _persist_or_restore(self, snapshot: dict[str, Any]) -> None:
        """Write to disk; on failure reinstate *snapshot* and re-raise.

        Callers must hold ``self._lock``.
        """
        try:
            self._write_to_disk()
        except Exception:
            # Roll back so memory matches what is on disk, then re-raise.
            self._data = snapshot
            raise

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, key: str) -> Any | None:
        """Return the cached value for *key*, or ``None`` if absent.

        The stored object itself is returned, not a copy.
        """
        with self._lock:
            return self._data.get(key)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* and write to disk if *autosave* is on.

        If the autosave write raises, the entry is rolled back and the
        exception propagates.
        """
        with self._lock:
            # Snapshot only when a write follows, so autosave=False bulk
            # loads do not pay an O(n) copy per call.
            snapshot = dict(self._data) if self._autosave else None
            self._data[key] = value
            if snapshot is not None:
                self._persist_or_restore(snapshot)

    def set_many(self, items: dict[str, Any]) -> None:
        """Store multiple key-value pairs in one operation (single disk write).

        If the autosave write raises, none of *items* are kept.
        """
        with self._lock:
            snapshot = dict(self._data) if self._autosave else None
            self._data.update(items)
            if snapshot is not None:
                self._persist_or_restore(snapshot)

    def save(self) -> None:
        """Flush current state to disk (useful when *autosave* is ``False``)."""
        with self._lock:
            self._write_to_disk()

    def invalidate(self, key: str) -> None:
        """Remove a single cached entry and persist if *autosave* is on.

        A missing *key* is not an error.
        """
        with self._lock:
            snapshot = dict(self._data) if self._autosave else None
            self._data.pop(key, None)
            if snapshot is not None:
                self._persist_or_restore(snapshot)

    def clear(self) -> None:
        """Remove all entries and persist an empty cache file (always writes)."""
        with self._lock:
            snapshot = self._data
            self._data = {}
            self._persist_or_restore(snapshot)

    def __contains__(self, key: str) -> bool:
        with self._lock:
            return key in self._data

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)

    def keys(self) -> list[str]:
        """Return a snapshot of all cache keys."""
        with self._lock:
            return list(self._data.keys())

    def hits_and_misses(self, keys: list[str]) -> tuple[list[str], list[str]]:
        """Split *keys* into *(cached_keys, missing_keys)*.

        Use this before making API calls to determine which items can be
        served from cache and which need fetching::

            hits, misses = cache.hits_and_misses(gene_symbols)
            # only fetch 'misses', reuse 'hits' directly

        *keys* is iterated exactly once, so any iterable (including a
        generator) works. Input order is preserved in both lists.
        """
        hits: list[str] = []
        misses: list[str] = []
        with self._lock:
            for k in keys:
                (hits if k in self._data else misses).append(k)
        return hits, misses

    @property
    def path(self) -> Path:
        """Path to the JSON file backing this cache."""
        return self._path

    def __repr__(self) -> str:
        with self._lock:
            count = len(self._data)
        return (
            f"LookupCache(namespace={self._path.stem!r}, "
            f"entries={count}, path={self._path})"
        )
@dataclass(eq=True, frozen=True)
class SequenceRecord:
    """A single retrieved sequence, kept as plain JSON-friendly fields.

    Instances are immutable and compare (and hash) by value.
    """

    # The sequence text itself.
    sequence: str
    # Where the sequence came from.
    source: str
    # Accession identifier of the sequence.
    accession: str
    # Canonical flag; defaults to None when not supplied.
    is_canonical: bool | None = field(default=None)
@dataclass(eq=True, frozen=True)
class GeneSequenceResult:
    """Outcome of a sequence lookup for one gene symbol.

    The dataclass is frozen, so attributes cannot be reassigned, but the
    list fields themselves stay mutable. Because those fields are lists,
    calling ``hash()`` on an instance raises ``TypeError``.
    """

    gene_symbol: str
    sequences: list[SequenceRecord]
    # A fresh list per instance; never shared between results.
    errors: list[Any] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Sequence-specific helpers
# ---------------------------------------------------------------------------
def sequence_cache_key(species: str, gene_symbol: str, mode: str) -> str:
    """Build the canonical cache key for a :class:`GeneSequenceResult`.

    Parameters
    ----------
    species:
        Provider-level species string (e.g. ``"homo_sapiens"``).
    gene_symbol:
        Gene identifier / symbol.
    mode:
        The :class:`~src.VmaxBuilder.sequence_retrieval.types.SequenceMode`
        value string (e.g. ``"canonical_only"`` or ``"all_isoforms"``).

    Returns
    -------
    str
        The three parts joined by colons: ``"<species>:<gene_symbol>:<mode>"``.
    """
    parts = (species, gene_symbol, mode)
    return "{}:{}:{}".format(*parts)
def gene_result_to_dict(result: GeneSequenceResult) -> dict[str, Any]:
    """Serialize a :class:`GeneSequenceResult` into a plain dict for caching.

    Parameters
    ----------
    result:
        Sequence lookup result to serialize.

    Returns
    -------
    dict[str, Any]
        ``{"gene_symbol": ..., "sequences": [...], "errors": [...]}``. Each
        sequence is a dict with ``sequence``, ``source``, ``accession`` and
        ``is_canonical`` keys. ``errors`` is a shallow copy of
        ``result.errors``, so later mutation of the result does not leak into
        a cached entry. Error items are passed through unchanged; the output
        is only JSON-serializable if they are.
    """
    return {
        "gene_symbol": result.gene_symbol,
        "sequences": [
            {
                "sequence": record.sequence,
                "source": record.source,
                "accession": record.accession,
                "is_canonical": record.is_canonical,
            }
            for record in result.sequences
        ],
        # Copy: the dataclass is frozen, but its list is still mutable.
        "errors": list(result.errors),
    }
def dict_to_gene_result(data: dict[str, Any]) -> GeneSequenceResult:
    """Rebuild a :class:`GeneSequenceResult` from its cached dict form.

    Parameters
    ----------
    data:
        A dict shaped like the output of :func:`gene_result_to_dict`.
        ``gene_symbol`` is required. ``sequences`` and ``errors`` may be
        absent or ``None`` (JSON ``null``) and are then treated as empty.
        ``is_canonical`` is optional per sequence.

    Returns
    -------
    GeneSequenceResult
        The reconstructed result. ``errors`` is a fresh list, so mutating it
        cannot alter the dict held by the cache.

    Raises
    ------
    KeyError
        If ``gene_symbol`` or a sequence's ``sequence`` / ``source`` /
        ``accession`` key is missing.
    """
    return GeneSequenceResult(
        gene_symbol=data["gene_symbol"],
        sequences=[
            SequenceRecord(
                sequence=entry["sequence"],
                source=entry["source"],
                accession=entry["accession"],
                is_canonical=entry.get("is_canonical"),
            )
            for entry in data.get("sequences") or []
        ],
        # Copy so the result does not alias the cached list.
        errors=list(data.get("errors") or []),
    )
# ---------------------------------------------------------------------------
# SMILES-specific helpers (ready for use once SMILES retrieval is implemented)
# ---------------------------------------------------------------------------
def smiles_cache_key(database: str, metabolite_id: str) -> str:
    """Build the canonical cache key for a metabolite SMILES lookup.

    Parameters
    ----------
    database:
        Source database name (e.g. ``"chebi"``, ``"hmdb"``).
    metabolite_id:
        The metabolite identifier in that database.

    Returns
    -------
    str
        ``"<database>:<metabolite_id>"``.
    """
    return "{}:{}".format(database, metabolite_id)