"""Generated: validation needed.
Description:
PTR submodule implementation used by protein-stage coordinator.
Provides standardization, deduplication, linear transform, within-sample
imputation, unobserved-gene imputation, and expression×PTR multiplication
with optional sample-type column mapping.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
import numpy as np
import pandas as pd
from VmaxBuilder.config.dataclasses import APIConfig
from VmaxBuilder.core.protocols import Scaffold
from VmaxBuilder.protein.input_resolution import resolve_dataframe_input
from VmaxBuilder.utils.extra_utils import (
get_transport_reaction_gene_ids,
resolve_gene_or_reaction_group_members,
)
from VmaxBuilder.utils.transformations import transform_dataframe
_logger = logging.getLogger(__name__)
# Lower-cased, whitespace-trimmed string tokens that PTR standardisation maps
# to missing values: generic NA spellings plus every signed/unsigned
# spelling of infinity.
_NA_TOKENS: frozenset[str] = frozenset(
    ("", "nan", "na", "n/a", "none")
    + tuple(
        f"{sign}{word}"
        for sign in ("", "+", "-")
        for word in ("inf", "infinity")
    )
)
def _series_mode(series: pd.Series) -> float:
    """Generated: validation needed.

    Description:
        Return one numeric mode of ``series`` ignoring missing values. When
        several values tie, pandas returns the modes sorted, so the smallest
        one is used. Returns NaN when no non-missing value exists.

    Args:
        series (pd.Series): Input series.

    Returns:
        float: First (sorted) mode value, or ``np.nan`` when no mode exists.
    """
    modes = series.mode(dropna=True)
    return float(np.nan) if modes.empty else float(modes.iloc[0])
def _normalize_sample_label(value: Any) -> str:
    """Generated: validation needed.

    Description:
        Canonicalise a sample/tissue label so that expression columns, PTR
        columns and explicit tissue-map entries can be matched robustly:
        the value is stringified, trimmed, lower-cased, and a single
        trailing ``_ptr`` suffix is removed.

    Args:
        value (Any): Raw label value.

    Returns:
        str: Normalized label.
    """
    suffix = "_ptr"
    label = str(value).strip().lower()
    return label[: -len(suffix)] if label.endswith(suffix) else label
# Supported statistic keys mapped to Series aggregation callables; each one
# skips missing values. Insertion order matters: error messages list the
# accepted keys by joining this mapping in order.
_IMPUTATION_STATISTICS: dict[str, Any] = dict(
    median=lambda values: values.median(skipna=True),
    mean=lambda values: values.mean(skipna=True),
    mode=_series_mode,
    max=lambda values: values.max(skipna=True),
    min=lambda values: values.min(skipna=True),
)
# Alias table for PTR pre-transformation type names ("none" -> "linear").
_PRETRANSFORM_ALIASES: dict[str, str] = dict(none="linear")
class DefaultPTRImplementation:
    """Generated: validation needed.

    Description:
        PTR preprocessing and combination logic for the expression+PTR
        protein abundance pathway: standardisation, deduplication,
        log→linear conversion, within-sample imputation, expansion to the
        expression gene index, and expression×PTR multiplication.

        Within ``prepare_ptr_frame`` and the unobserved-gene helpers, gene
        identifiers are matched by their ``str()`` form so that PTR,
        expression, metabolic-gene and special-group identifiers agree even
        when their label types differ (e.g. ``1`` vs ``"1"``).
    """

    # PTR strategies that cannot run without an explicit expression→PTR
    # sample/tissue mapping (queried through ``requires_sample_type_map``).
    _METHODS_REQUIRING_SAMPLE_TYPE_MAP: frozenset[str] = frozenset({"ptr_weighted_median"})

    # ------------------------------------------------------------------
    # Resolution
    # ------------------------------------------------------------------

    def resolve_ptr_frame(
        self,
        scaffold: Scaffold,
        config: APIConfig,
    ) -> pd.DataFrame | None:
        """Generated: validation needed.

        Description:
            Resolve the PTR dataframe from configured scaffold/config sources.

        Args:
            scaffold (Scaffold): Shared pipeline scaffold.
            config (APIConfig): Root API configuration.

        Returns:
            pd.DataFrame | None: PTR dataframe when available.
        """
        return resolve_dataframe_input(scaffold, config, input_key="ptr")

    # ------------------------------------------------------------------
    # Shared frame helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_float_frame(frame: pd.DataFrame) -> pd.DataFrame:
        """Return a float64 copy of ``frame`` with ``pd.NA`` turned into ``np.nan``."""
        return frame.replace({pd.NA: np.nan}).astype(float)

    @staticmethod
    def _append_missing_rows(df: pd.DataFrame, labels: list[Any]) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Append all-NaN rows for each label whose ``str()`` form is not
            already present in ``df.index``. Labels are appended in the given
            order, once each. Returns ``df`` itself when nothing is added.

        Args:
            df (pd.DataFrame): Frame with a unique index.
            labels (list[Any]): Candidate row labels.

        Returns:
            pd.DataFrame: Frame containing a row for every label.
        """
        present = {str(label) for label in df.index}
        new_labels: list[Any] = []
        for label in labels:
            key = str(label)
            if key in present:
                continue
            present.add(key)
            new_labels.append(label)
        if not new_labels:
            return df
        expanded_index = df.index.append(pd.Index(new_labels, name=df.index.name))
        return df.reindex(expanded_index)

    @staticmethod
    def _align_to_index(df: pd.DataFrame, target_index: pd.Index) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Reindex ``df`` onto ``target_index`` by matching the ``str()``
            form of row labels, so the result carries exactly
            ``target_index`` even when label types differ. Rows without a
            match are all-NaN; if two ``df`` labels stringify identically the
            first one wins.

        Args:
            df (pd.DataFrame): Source frame.
            target_index (pd.Index): Desired row index.

        Returns:
            pd.DataFrame: Frame indexed by ``target_index``.
        """
        keyed = df.copy()
        keyed.index = pd.Index([str(label) for label in df.index])
        keyed = keyed.loc[~keyed.index.duplicated(keep="first")]
        aligned = keyed.reindex([str(label) for label in target_index])
        aligned.index = target_index
        return aligned

    # ------------------------------------------------------------------
    # Standardization
    # ------------------------------------------------------------------

    @staticmethod
    def _clean_ptr_cell(value: Any) -> Any:
        """Map NA-like string tokens to NaN, strip other strings, pass non-strings through."""
        if not isinstance(value, str):
            return value
        text = value.strip()
        return np.nan if text.lower() in _NA_TOKENS else text

    @staticmethod
    def standardize_ptr_frame(ptr_df: pd.DataFrame) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Standardize missing-value tokens, trim string whitespace, and
            coerce all values to numeric (unparseable values and ±inf become
            missing). A frame whose int64/float64 index starts at 0 or 1 is
            treated as having a row-counter index, and its first column is
            used as the gene index instead. Column names are lower-cased.

        Args:
            ptr_df (pd.DataFrame): Raw PTR table (genes × samples).

        Returns:
            pd.DataFrame: Numeric PTR frame whose missing values are ``pd.NA``
            and whose column names are lower-case strings.
        """
        df = ptr_df.copy()
        looks_like_row_counter = (
            len(df.columns) > 0
            and df.index.dtype in ("int64", "float64")
            and df.index.min() in (0, 1)
        )
        if looks_like_row_counter:
            df = df.set_index(df.columns[0])
        df = df.replace({pd.NA: np.nan, np.inf: np.nan, -np.inf: np.nan, None: np.nan})
        for col in df.columns:
            if pd.api.types.is_object_dtype(df[col]) or pd.api.types.is_string_dtype(df[col]):
                df[col] = df[col].map(DefaultPTRImplementation._clean_ptr_cell)
        df = df.apply(pd.to_numeric, errors="coerce")
        # Numeric text such as "1e999" coerces to ±inf; treat it as missing too.
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.replace({np.nan: pd.NA})
        df.columns = df.columns.map(str).str.lower()
        return df

    # ------------------------------------------------------------------
    # Deduplication
    # ------------------------------------------------------------------

    @staticmethod
    def remove_ptr_duplicates(ptr_df: pd.DataFrame) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            For each duplicated gene label, retain the row with the most
            non-missing values; on ties keep the earliest occurrence. Kept
            rows stay in their original relative order.

        Args:
            ptr_df (pd.DataFrame): PTR table potentially containing duplicate
                gene identifiers in the index.

        Returns:
            pd.DataFrame: PTR table with a unique gene index.
        """
        if not ptr_df.index.has_duplicates:
            return ptr_df.copy()
        # Selection is positional: label-based drops cannot tell apart rows
        # sharing the same gene label.
        non_missing = ptr_df.notna().sum(axis=1).to_numpy()
        # A stable sort by descending completeness preserves the original
        # order among ties, so "first in ranked order" = most complete,
        # earliest on ties.
        ranked_positions = np.argsort(-non_missing, kind="stable")
        is_best = ~ptr_df.index[ranked_positions].duplicated(keep="first")
        kept_positions = np.sort(ranked_positions[is_best])
        return ptr_df.iloc[kept_positions].copy()

    # ------------------------------------------------------------------
    # Linear transform
    # ------------------------------------------------------------------

    @staticmethod
    def transform_ptr_to_linear(
        ptr_df: pd.DataFrame,
        pretransformed_type: str | None = "linear",
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Convert PTR values to linear scale according to the declared
            pre-transformation of the input table. Names are trimmed,
            lower-cased and resolved through ``_PRETRANSFORM_ALIASES`` (so
            ``None``/``"none"`` mean ``"linear"``).

            NOTE(review): ``prepare_ptr_frame`` calls this method but it was
            missing from the reviewed source. The accepted names below
            (``linear``, ``log2``, ``log10``, ``ln``/``loge``) are an
            assumption — confirm against the values allowed for
            ``config.ptr.pretransformed_type``. The module also imports
            ``transform_dataframe``, which may be the intended backend.

        Args:
            ptr_df (pd.DataFrame): PTR frame (numeric values or ``pd.NA``).
            pretransformed_type (str | None): Declared transform of the input.

        Returns:
            pd.DataFrame: Float frame in linear space (missing stays NaN).

        Raises:
            ValueError: When ``pretransformed_type`` is not recognised.
        """
        key = str(pretransformed_type).strip().lower()
        key = _PRETRANSFORM_ALIASES.get(key, key)
        df = DefaultPTRImplementation._as_float_frame(ptr_df)
        if key == "linear":
            return df
        if key == "log2":
            return df.rpow(2.0)
        if key == "log10":
            return df.rpow(10.0)
        if key in ("ln", "loge"):
            return df.apply(np.exp)
        raise ValueError(
            f"Unrecognised PTR pretransformed_type '{pretransformed_type}'. "
            "Expected one of: linear, none, log2, log10, ln, loge."
        )

    # ------------------------------------------------------------------
    # Within-sample imputation
    # ------------------------------------------------------------------

    @staticmethod
    def get_weights(
        df: pd.DataFrame,
        col_stat_function: Callable[[pd.Series], float],
    ) -> pd.Series:
        """Generated: validation needed.

        Description:
            Compute per-column weighting ratios for within-sample imputation:
            each column's statistic divided by the same statistic taken over
            all column statistics. Falls back to 1.0 for every column when
            that global statistic is 0 or NaN.

        Args:
            df (pd.DataFrame): PTR frame in linear space.
            col_stat_function (Callable[[pd.Series], float]): Statistic used
                for column aggregation and global normalisation.

        Returns:
            pd.Series: Weight ratio per PTR column.
        """
        col_stats = pd.Series(
            {col: float(col_stat_function(df[col])) for col in df.columns},
            dtype=float,
        )
        stat_of_col_stats = float(col_stat_function(col_stats))
        if stat_of_col_stats == 0 or np.isnan(stat_of_col_stats):
            return pd.Series(1.0, index=col_stats.index)
        return col_stats / stat_of_col_stats

    @staticmethod
    def _validate_within_sample_weighting(
        use_weighted: bool,
        weighted_statistic: str | None,
    ) -> None:
        """Generated: validation needed.

        Description:
            Validate effective within-sample weighting inputs.

        Args:
            use_weighted (bool): Weighted-imputation toggle.
            weighted_statistic (str | None): Column-statistic key for weighted mode.

        Raises:
            ValueError: When weighted mode lacks a statistic.
        """
        if not use_weighted:
            return
        if weighted_statistic is None:
            raise ValueError(
                "Weighted imputation requires weighted_statistic to be specified."
            )

    @staticmethod
    def _resolve_within_sample_stat_functions(
        use_weighted: bool,
        weighted_statistic: str | None,
        imputation_statistic: str,
    ) -> tuple[Callable[[pd.Series], float], Callable[[pd.Series], float] | None]:
        """Generated: validation needed.

        Description:
            Resolve the statistic callables used by within-sample imputation.

        Args:
            use_weighted (bool): Weighted-imputation toggle.
            weighted_statistic (str | None): Weighted-column statistic key.
            imputation_statistic (str): Row-wise statistic key.

        Returns:
            tuple: Row-statistic function and optional weighted-column
            statistic function.

        Raises:
            ValueError: When requested statistic keys are unsupported.
        """
        imputation_statistic_function = _IMPUTATION_STATISTICS.get(imputation_statistic)
        if imputation_statistic_function is None:
            raise ValueError(
                f"Unrecognised PTR partial_missing_imputation_statistic '"
                f"{imputation_statistic}'. "
                f"Expected one of: {', '.join(_IMPUTATION_STATISTICS)}."
            )
        weighted_statistic_function = (
            _IMPUTATION_STATISTICS.get(weighted_statistic) if weighted_statistic else None
        )
        if use_weighted and weighted_statistic_function is None:
            raise ValueError(
                f"Unrecognised PTR weighted imputation statistic "
                f"'{weighted_statistic}'. "
                f"Expected one of: {', '.join(_IMPUTATION_STATISTICS)}."
            )
        return imputation_statistic_function, weighted_statistic_function

    @staticmethod
    def impute_within_tissue_ptrs(
        ptr_df: pd.DataFrame,
        use_weighted: bool = True,
        weighted_statistic: str | None = "median",
        imputation_statistic: str = "median",
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Impute missing cells for genes observed in at least one sample:
            each missing cell gets its row statistic, multiplied by the
            column's weight ratio when ``use_weighted`` is set. Rows with no
            observed value keep NaN (they are handled as unobserved genes).

        Args:
            ptr_df (pd.DataFrame): PTR table in linear space (genes × samples).
            use_weighted (bool): Apply weighted per-column scaling.
            weighted_statistic (str | None): Statistic for the column ratio.
            imputation_statistic (str): Statistic for the row-wise base fill.

        Returns:
            pd.DataFrame: Float PTR table with within-sample gaps filled and
            lower-case string column names.

        Raises:
            ValueError: When weighting configuration or a statistic is unrecognised.
        """
        DefaultPTRImplementation._validate_within_sample_weighting(
            use_weighted,
            weighted_statistic,
        )
        row_stat_function, column_stat_function = (
            DefaultPTRImplementation._resolve_within_sample_stat_functions(
                use_weighted,
                weighted_statistic,
                imputation_statistic,
            )
        )
        df = DefaultPTRImplementation._as_float_frame(ptr_df)
        row_stats = df.apply(lambda row: float(row_stat_function(row)), axis=1)
        if use_weighted:
            if column_stat_function is None:  # unreachable: resolver raises first
                raise ValueError(
                    "Weighted imputation requires weighted_statistic to be specified."
                )
            ratio = DefaultPTRImplementation.get_weights(df, column_stat_function)
        else:
            ratio = pd.Series(1.0, index=df.columns)
        for col in df.columns:
            missing = df[col].isna()
            if missing.any():
                df.loc[missing, col] = row_stats[missing].astype(float) * float(ratio[col])
        # map(str) first: the .str accessor fails on non-string column labels.
        df.columns = df.columns.map(str).str.lower()
        return df

    # ------------------------------------------------------------------
    # Unobserved-gene imputation
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_unobserved_source_frame(
        ptr_df: pd.DataFrame,
        strategy: str,
        reference_df: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Select the frame used to compute unobserved-gene fill statistics:
            ``reference_df`` for ``sample_before_imputation``, otherwise
            ``ptr_df``; returned as a float copy.

        Args:
            ptr_df (pd.DataFrame): PTR frame after within-sample imputation.
            strategy (str): Unobserved-gene strategy.
            reference_df (pd.DataFrame | None): Optional pre-imputation frame.

        Returns:
            pd.DataFrame: Source frame for per-sample statistics.

        Raises:
            ValueError: When ``sample_before_imputation`` lacks ``reference_df``.
        """
        if strategy == "sample_before_imputation":
            if reference_df is None:
                raise ValueError(
                    "sample_before_imputation requires reference_df with original PTR values."
                )
            return DefaultPTRImplementation._as_float_frame(reference_df)
        return DefaultPTRImplementation._as_float_frame(ptr_df)

    @staticmethod
    def _compute_per_sample_fill_values(
        source_df: pd.DataFrame,
        statistic: str,
    ) -> dict[str, float]:
        """Generated: validation needed.

        Description:
            Compute one fill value per sample column of ``source_df``.

        Args:
            source_df (pd.DataFrame): Source frame used for statistics.
            statistic (str): Aggregation statistic key.

        Returns:
            dict[str, float]: Per-sample fill values.

        Raises:
            ValueError: When the statistic key is unsupported.
        """
        base_statistic_function = _IMPUTATION_STATISTICS.get(statistic)
        if base_statistic_function is None:
            raise ValueError(
                f"Unrecognised unobserved_gene_imputation_statistic '{statistic}'. "
                f"Expected one of: {', '.join(_IMPUTATION_STATISTICS)}."
            )
        return {
            col: float(base_statistic_function(source_df[col])) for col in source_df.columns
        }

    @staticmethod
    def _target_mask(df: pd.DataFrame, target_gene_ids: set[str]) -> pd.Series:
        """Boolean row mask: True where ``str(label)`` is in ``target_gene_ids``."""
        return pd.Series(
            [str(gene_id) in target_gene_ids for gene_id in df.index],
            index=df.index,
            dtype=bool,
        )

    @staticmethod
    def _apply_global_unobserved_fill(
        df: pd.DataFrame,
        fill_values: dict[str, float],
        target_gene_ids: set[str],
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Fill missing cells of target genes with one per-sample value.
            Target genes absent from ``df`` are added as rows first.

        Args:
            df (pd.DataFrame): Target PTR frame.
            fill_values (dict[str, float]): Per-sample fallback values.
            target_gene_ids (set[str]): Gene IDs eligible for filling.

        Returns:
            pd.DataFrame: New frame with target gaps filled.
        """
        if not target_gene_ids:
            return df
        df = DefaultPTRImplementation._append_missing_rows(df.copy(), sorted(target_gene_ids))
        target_mask = DefaultPTRImplementation._target_mask(df, target_gene_ids)
        for col in df.columns:
            mask = df[col].isna() & target_mask
            if mask.any():
                df.loc[mask, col] = fill_values.get(col, np.nan)
        return df

    @staticmethod
    def _apply_grouped_unobserved_fill(
        df: pd.DataFrame,
        source_df: pd.DataFrame,
        statistic: str,
        special_gene_groups: dict[str, list[str]],
        fallback_fill_values: dict[str, float],
        target_gene_ids: set[str],
        trace: dict[str, Any] | None = None,
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Fill missing cells of target genes using per-group statistics
            computed from ``source_df``. Genes outside every group, and groups
            with no members in ``source_df``, use the global fallback values.
            A gene listed in several groups belongs to the first one.

        Args:
            df (pd.DataFrame): Target PTR frame.
            source_df (pd.DataFrame): Source frame for statistic calculation.
            statistic (str): Aggregation statistic key.
            special_gene_groups (dict[str, list[str]]): Group name to gene IDs.
            fallback_fill_values (dict[str, float]): Global per-sample values.
            target_gene_ids (set[str]): Gene IDs eligible for filling.
            trace (dict[str, Any] | None): Optional mutable dict that receives
                the gene→group mapping, per-group fill values, and the value
                assigned to each filled cell. It is populated even when there
                are no targets.

        Returns:
            pd.DataFrame: New frame with grouped imputation applied.
        """
        gene_group_lookup: dict[str, str] = {}
        for group_name, group_genes in special_gene_groups.items():
            for gene_id in group_genes:
                gene_group_lookup.setdefault(str(gene_id), group_name)

        source_label_by_key = {str(label): label for label in source_df.index}
        group_fill_values: dict[str, dict[str, float]] = {}
        for group_name, group_genes in special_gene_groups.items():
            members = list(
                dict.fromkeys(
                    source_label_by_key[key]
                    for key in map(str, group_genes)
                    if key in source_label_by_key
                )
            )
            if not members:
                group_fill_values[group_name] = dict(fallback_fill_values)
                continue
            group_fill_values[group_name] = (
                DefaultPTRImplementation._compute_per_sample_fill_values(
                    pd.DataFrame(source_df.loc[members]),
                    statistic,
                )
            )

        assigned_values: dict[str, dict[str, float]] = {}
        if target_gene_ids:
            df = DefaultPTRImplementation._append_missing_rows(df.copy(), sorted(target_gene_ids))
            target_mask = DefaultPTRImplementation._target_mask(df, target_gene_ids)
            for col in df.columns:
                mask = df[col].isna() & target_mask
                for gene_id in df.index[mask]:
                    assigned_value = DefaultPTRImplementation._resolve_grouped_fill_value(
                        column_name=col,
                        group_name=gene_group_lookup.get(str(gene_id)),
                        group_fill_values=group_fill_values,
                        fallback_fill_values=fallback_fill_values,
                    )
                    df.at[gene_id, col] = assigned_value
                    assigned_values.setdefault(str(gene_id), {})[str(col)] = assigned_value

        if trace is not None:
            trace["special_group_gene_mapping"] = dict(gene_group_lookup)
            trace["special_group_fill_values_per_sample"] = {
                group_name: dict(fill_values)
                for group_name, fill_values in group_fill_values.items()
            }
            trace["special_group_assigned_values_per_sample"] = assigned_values
        return df

    @staticmethod
    def _resolve_grouped_fill_value(
        *,
        column_name: Any,
        group_name: str | None,
        group_fill_values: dict[str, dict[str, float]],
        fallback_fill_values: dict[str, float],
    ) -> float:
        """Generated: validation needed.

        Description:
            Resolve one grouped fill value: the group's value for the column,
            else the global fallback for the column, else NaN.

        Args:
            column_name (Any): Sample/column identifier.
            group_name (str | None): Resolved group name, if any.
            group_fill_values (dict[str, dict[str, float]]): Per-group values.
            fallback_fill_values (dict[str, float]): Global fallback values.

        Returns:
            float: Assigned fill value.
        """
        fallback = fallback_fill_values.get(column_name, np.nan)
        if group_name is None:
            return float(fallback)
        return float(group_fill_values.get(group_name, {}).get(column_name, fallback))

    @staticmethod
    def impute_unobserved_genes(
        ptr_df: pd.DataFrame,
        expression_df: pd.DataFrame,
        unobserved_gene_ids: set[str],
        strategy: str = "sample_after_imputation",
        statistic: str = "median",
        reference_df: pd.DataFrame | None = None,
        special_gene_groups: dict[str, list[str]] | None = None,
        use_special_groups: bool = False,
        trace: dict[str, Any] | None = None,
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Expand PTR to the gene index of ``expression_df`` and fill the
            missing cells of ``unobserved_gene_ids`` with a per-sample
            statistic. ``sample_after_imputation`` computes it from ``ptr_df``;
            ``sample_before_imputation`` computes it from ``reference_df``,
            which must then be supplied.

        Args:
            ptr_df (pd.DataFrame): PTR table after within-sample imputation.
            expression_df (pd.DataFrame): Expression table whose index
                defines the output gene index.
            unobserved_gene_ids (set[str]): Gene IDs to fill (compared by ``str``).
            strategy (str): ``sample_after_imputation`` or
                ``sample_before_imputation``.
            statistic (str): ``median``, ``mean``, ``mode``, ``max`` or ``min``.
            reference_df (pd.DataFrame | None): Pre-within-imputation PTR frame.
            special_gene_groups (dict[str, list[str]] | None): Optional groups
                imputed independently.
            use_special_groups (bool): Enable group-wise imputation.
            trace (dict[str, Any] | None): Optional mutable diagnostics dict.

        Returns:
            pd.DataFrame: Float PTR table indexed exactly by
            ``expression_df.index`` with unobserved genes filled.

        Raises:
            ValueError: When ``strategy``/``statistic`` is unrecognised or
                ``sample_before_imputation`` lacks a reference frame.
        """
        if strategy not in ("sample_after_imputation", "sample_before_imputation"):
            raise ValueError(
                f"Unrecognised unobserved_gene_imputation_strategy '{strategy}'. "
                "Expected one of: sample_after_imputation, sample_before_imputation."
            )
        target_ids = {str(gene_id) for gene_id in unobserved_gene_ids}
        df = DefaultPTRImplementation._as_float_frame(ptr_df)
        # Add target rows under the expression's own labels so the final
        # alignment keeps them whatever the label type.
        df = DefaultPTRImplementation._append_missing_rows(
            df,
            [label for label in expression_df.index if str(label) in target_ids],
        )
        source_df = DefaultPTRImplementation._resolve_unobserved_source_frame(
            ptr_df,
            strategy,
            reference_df,
        )
        fill_values = DefaultPTRImplementation._compute_per_sample_fill_values(
            source_df,
            statistic,
        )
        if use_special_groups and special_gene_groups:
            df = DefaultPTRImplementation._apply_grouped_unobserved_fill(
                df,
                source_df,
                statistic,
                special_gene_groups,
                fill_values,
                target_ids,
                trace=trace,
            )
        else:
            df = DefaultPTRImplementation._apply_global_unobserved_fill(
                df,
                fill_values,
                target_ids,
            )
        return DefaultPTRImplementation._align_to_index(df, expression_df.index)

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def prepare_ptr_frame(
        self,
        ptr_df: pd.DataFrame,
        expression_df: pd.DataFrame,
        config: APIConfig,
        metabolic_genes: list[str] | None = None,
        model_artifact: Any | None = None,
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Full PTR preprocessing pipeline: standardize → deduplicate →
            optionally filter to metabolic genes → convert to linear scale →
            impute within-sample gaps → impute unobserved genes → align to
            the expression gene index. Genes absent from PTR *or* whose PTR
            row had no observed value are treated as unobserved. Diagnostics
            for ``get_latest_preparation_diagnostics`` are refreshed on
            every call.

        Args:
            ptr_df (pd.DataFrame): Raw PTR input table (genes × tissue-types).
            expression_df (pd.DataFrame): Preprocessed expression table that
                defines the target gene universe.
            config (APIConfig): Root API configuration; PTR options are read
                from ``config.ptr``.
            metabolic_genes (list[str] | None): Optional model gene IDs; when
                given and ``config.ptr.impute_from_metabolic_genes_only`` is
                true, PTR is filtered to these genes before imputation.
            model_artifact (Any | None): Optional cobra-like model used to
                expand special gene groups such as ``transport_reactions``.

        Returns:
            pd.DataFrame: Preprocessed PTR table indexed exactly by
            ``expression_df.index``.
        """
        ptr_cfg = config.ptr
        ptr_imputation_trace: dict[str, Any] = {}
        expression_gene_ids = {str(gene_id) for gene_id in expression_df.index}

        df = self.standardize_ptr_frame(ptr_df)
        _logger.debug("PTR: standardized frame shape %s.", df.shape)
        df = self.remove_ptr_duplicates(df)
        _logger.debug("PTR: deduplicated frame shape %s.", df.shape)

        if ptr_cfg.impute_from_metabolic_genes_only and metabolic_genes is not None:
            before = len(df)
            metabolic_ids = {str(gene_id) for gene_id in metabolic_genes}
            df = df.loc[df.index.map(str).isin(metabolic_ids)]
            _logger.debug(
                "PTR: filtered to %d metabolic genes (dropped %d).",
                len(df),
                before - len(df),
            )
            if df.empty:
                _logger.warning("PTR: no PTR genes remain after metabolic-gene filtering.")

        df = self.transform_ptr_to_linear(
            df,
            pretransformed_type=ptr_cfg.pretransformed_type,
        )
        before_within_imputation_df = df.copy()
        df = self.impute_within_tissue_ptrs(
            df,
            use_weighted=ptr_cfg.partial_missing_use_weighted,
            weighted_statistic=ptr_cfg.partial_missing_weighted_statistic,
            imputation_statistic=ptr_cfg.partial_missing_imputation_statistic,
        )
        _logger.debug("PTR: within-sample imputation done.")

        special_gene_groups = self.resolve_special_gene_groups(
            config,
            model_artifact=model_artifact,
            expression_gene_ids=set(expression_gene_ids),
        )
        if (
            not special_gene_groups
            and ptr_cfg.use_special_groups_for_unobserved_imputation
            and model_artifact is not None
        ):
            special_gene_groups = {
                "transport_reactions": get_transport_reaction_gene_ids(
                    model_artifact,
                    expression_gene_ids=set(expression_gene_ids),
                )
            }
            _logger.debug(
                "PTR: auto-populated transport_reactions group (%d genes).",
                len(special_gene_groups.get("transport_reactions", [])),
            )

        # A PTR row with no observed value cannot be filled by within-sample
        # imputation, so such genes join the unobserved set.
        observed_rows = before_within_imputation_df.notna().any(axis=1).to_numpy()
        observed_gene_ids = {
            str(gene_id) for gene_id in before_within_imputation_df.index[observed_rows]
        }
        unobserved_genes = expression_gene_ids - observed_gene_ids

        if not unobserved_genes:
            _logger.debug(
                "PTR: no unobserved genes to impute after within-sample imputation."
            )
            result = self._align_to_index(df, expression_df.index)
        else:
            result = self.impute_unobserved_genes(
                df,
                expression_df,
                unobserved_gene_ids=unobserved_genes,
                strategy=ptr_cfg.unobserved_gene_imputation_strategy,
                statistic=ptr_cfg.unobserved_gene_imputation_statistic,
                reference_df=before_within_imputation_df,
                special_gene_groups=special_gene_groups,
                use_special_groups=ptr_cfg.use_special_groups_for_unobserved_imputation,
                trace=ptr_imputation_trace,
            )
            _logger.debug(
                "PTR: unobserved-gene imputation done, final shape %s.", result.shape
            )

        # Refresh on every path so a previous call's diagnostics never leak.
        self._latest_ptr_preparation_diagnostics = self._build_preparation_diagnostics(
            special_gene_groups,
            ptr_imputation_trace,
        )
        return result

    @staticmethod
    def _build_preparation_diagnostics(
        special_gene_groups: dict[str, list[str]],
        trace: dict[str, Any],
    ) -> dict[str, Any]:
        """Assemble the diagnostics payload stored after ``prepare_ptr_frame``."""
        return {
            "special_gene_groups": special_gene_groups,
            "special_group_gene_mapping": trace.get("special_group_gene_mapping", {}),
            "special_group_fill_values_per_sample": trace.get(
                "special_group_fill_values_per_sample",
                {},
            ),
            "special_group_assigned_values_per_sample": trace.get(
                "special_group_assigned_values_per_sample",
                {},
            ),
        }

    def get_latest_preparation_diagnostics(self) -> dict[str, Any]:
        """Generated: validation needed.

        Description:
            Return a shallow copy of the diagnostics captured by the latest
            ``prepare_ptr_frame`` call (empty before any call).

        Returns:
            dict[str, Any]: PTR preparation diagnostics.
        """
        diagnostics = getattr(self, "_latest_ptr_preparation_diagnostics", {})
        return dict(diagnostics)

    @staticmethod
    def resolve_special_gene_groups(
        config: APIConfig,
        model_artifact: Any | None = None,
        expression_gene_ids: set[str] | None = None,
    ) -> dict[str, list[str]]:
        """Generated: validation needed.

        Description:
            Resolve user-provided special gene groups for group-wise
            unobserved-gene imputation. Group names and entries are trimmed;
            blank ones are skipped. A group value may be a list of gene or
            reaction IDs, a single ID string, or ``None`` (treated as empty).
            ``transport_reactions`` with no entries auto-resolves
            transport-associated genes from the model.

        Args:
            config (APIConfig): Root API configuration.
            model_artifact (Any | None): Optional cobra-like model used for
                shorthand and reaction-based group expansion.
            expression_gene_ids (set[str] | None): Optional expression-gene
                universe used to filter resolved group members.

        Returns:
            dict[str, list[str]]: Mapping of group name to resolved gene IDs.

        Raises:
            ValueError: When ``transport_reactions`` shorthand is requested
                without a model artifact.
        """
        raw_groups = config.ptr.special_gene_groups
        if raw_groups is None:
            return {}
        normalized_groups: dict[str, list[str]] = {}
        for group_name, group_genes in raw_groups.items():
            normalized_name = str(group_name).strip()
            if not normalized_name:
                continue
            if group_genes is None:
                raw_entries: list[Any] = []
            elif isinstance(group_genes, str):
                # A bare string is one ID, not a sequence of characters.
                raw_entries = [group_genes]
            else:
                raw_entries = list(group_genes)
            normalized_entries = [
                text for text in (str(entry).strip() for entry in raw_entries) if text
            ]
            if normalized_name == "transport_reactions" and not normalized_entries:
                if model_artifact is None:
                    raise ValueError(
                        "special_gene_groups['transport_reactions'] requires a model "
                        "artifact when no explicit genes or reactions are supplied."
                    )
                normalized_groups[normalized_name] = get_transport_reaction_gene_ids(
                    model_artifact,
                    expression_gene_ids=expression_gene_ids,
                )
                continue
            normalized_groups[normalized_name] = resolve_gene_or_reaction_group_members(
                model_artifact,
                normalized_entries,
                expression_gene_ids=expression_gene_ids,
            )
        return normalized_groups

    @classmethod
    def requires_sample_type_map(cls, ptr_method: str) -> bool:
        """Generated: validation needed.

        Description:
            Report whether a PTR strategy requires an explicit
            expression→PTR sample/tissue mapping.

        Args:
            ptr_method (str): PTR strategy key from ``config.protein.ptr_method``.

        Returns:
            bool: ``True`` when the method requires ``expression.sample_type_map``.
        """
        return ptr_method in cls._METHODS_REQUIRING_SAMPLE_TYPE_MAP

    # ------------------------------------------------------------------
    # Sample-type map resolution
    # ------------------------------------------------------------------

    @staticmethod
    def resolve_sample_type_map(
        expression_df: pd.DataFrame,
        sample_type_map: dict[str, str] | str | None,
    ) -> dict[str, str]:
        """Generated: validation needed.

        Description:
            Build a ``{expression_column: normalized_ptr_label}`` mapping:

            * ``None`` → each expression column maps to its own normalized label.
            * ``str`` → every expression column maps to that single label.
            * ``dict`` → entries matched on normalized keys; unmapped
              expression columns fall back to their own normalized label.

            Normalization: stringify, trim, lower-case, drop a ``_ptr`` suffix.

        Args:
            expression_df (pd.DataFrame): Expression table whose columns are
                the source keys.
            sample_type_map (dict[str, str] | str | None): User mapping.

        Returns:
            dict[str, str]: Expression column → normalized PTR label.
        """
        expr_cols_normalized = {
            expr_col: _normalize_sample_label(expr_col) for expr_col in expression_df.columns
        }
        if sample_type_map is None:
            return expr_cols_normalized
        if isinstance(sample_type_map, str):
            normalized_target = _normalize_sample_label(sample_type_map)
            return {col: normalized_target for col in expression_df.columns}
        normalized_input_map = {
            _normalize_sample_label(src_col): _normalize_sample_label(dst_col)
            for src_col, dst_col in sample_type_map.items()
        }
        return {
            expr_col: normalized_input_map.get(expr_col_norm, expr_col_norm)
            for expr_col, expr_col_norm in expr_cols_normalized.items()
        }

    # ------------------------------------------------------------------
    # Combination
    # ------------------------------------------------------------------

    def combine_expression_with_ptr(
        self,
        expression_df: pd.DataFrame,
        ptr_df: pd.DataFrame,
        sample_type_map: dict[str, str] | str | None = None,
    ) -> pd.DataFrame:
        """Generated: validation needed.

        Description:
            Multiply expression values by PTR values gene-by-gene, pairing
            each expression column with the PTR column whose normalized label
            matches the resolved sample-type mapping. Genes absent from PTR,
            and expression columns with no matching PTR column (logged), keep
            their expression values.

        Args:
            expression_df (pd.DataFrame): Preprocessed expression table.
            ptr_df (pd.DataFrame): Preprocessed PTR table (genes × tissue-types).
            sample_type_map (dict[str, str] | str | None): Expression→PTR
                column mapping (see ``resolve_sample_type_map``).

        Returns:
            pd.DataFrame: Protein abundance table shaped like ``expression_df``.
        """
        protein_df = expression_df.copy()
        common_genes = expression_df.index.intersection(ptr_df.index)
        if common_genes.empty:
            _logger.warning(
                "PTR: no overlapping genes between expression and PTR; "
                "returning unmodified expression."
            )
            return protein_df
        col_map = self.resolve_sample_type_map(expression_df, sample_type_map)
        # First PTR column wins when several normalize to the same label.
        ptr_col_lookup: dict[str, Any] = {}
        for ptr_col in ptr_df.columns:
            ptr_col_lookup.setdefault(_normalize_sample_label(ptr_col), ptr_col)
        for expr_col, ptr_col in col_map.items():
            if expr_col not in expression_df.columns:
                continue
            ptr_col_actual = ptr_col_lookup.get(ptr_col)
            if ptr_col_actual is None:
                _logger.warning(
                    "PTR: column '%s' not found in PTR frame; "
                    "skipping multiplication for expression column '%s'.",
                    ptr_col,
                    expr_col,
                )
                continue
            protein_df.loc[common_genes, expr_col] = (
                expression_df.loc[common_genes, expr_col]
                * ptr_df.loc[common_genes, ptr_col_actual]
            )
        return protein_df