"""Generated: validation needed.
Description:
Expression submodule implementation used by the protein-stage coordinator.
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
import pandas as pd
from VmaxBuilder.config.dataclasses import APIConfig
from VmaxBuilder.config.validation import ConfigurationError
from VmaxBuilder.core.protocols import Scaffold
from VmaxBuilder.database_retrieval.identifier_translation import (
IdentifierTranslationResult,
IdentifierTranslationService,
)
from VmaxBuilder.protein.input_resolution import resolve_dataframe_input
class IdentifierTranslationServiceProtocol(Protocol):
    """Generated: validation needed.
    Description:
    Structural interface for a service that translates expression identifiers
    between namespaces and maps transcript identifiers to genes. Any object
    providing both methods with these signatures satisfies the protocol.
    """

    def translate_identifiers(
        self,
        identifiers: Sequence[str],
        *,
        source_id_type: str,
        target_id_type: str,
        species: str | None,
        provider: str,
        max_workers: int,
        batch_size: int,
    ) -> IdentifierTranslationResult:
        """Generated: validation needed.
        Description:
        Translate a collection of identifiers from one namespace to another.
        Args:
            identifiers (Sequence[str]): Identifiers in the source namespace.
            source_id_type (str): Namespace the identifiers are expressed in.
            target_id_type (str): Namespace to translate into.
            species (str | None): Optional species hint for the provider.
            provider (str): Key selecting the translation provider.
            max_workers (int): Upper bound on worker threads.
            batch_size (int): Number of identifiers submitted per batch.
        Returns:
            IdentifierTranslationResult: Translation outcome for the identifiers.
        """
        ...

    def build_transcript_gene_dataframe(
        self,
        transcript_ids: Sequence[str],
        *,
        transcript_id_type: str,
        target_gene_id_type: str,
        species: str | None,
        provider: str,
        max_workers: int,
        batch_size: int,
    ) -> pd.DataFrame:
        """Generated: validation needed.
        Description:
        Produce a table that links transcript identifiers to gene identifiers.
        Args:
            transcript_ids (Sequence[str]): Transcript identifiers to look up.
            transcript_id_type (str): Namespace of the transcript identifiers.
            target_gene_id_type (str): Namespace the gene identifiers should use.
            species (str | None): Optional species hint for the provider.
            provider (str): Key selecting the translation provider.
            max_workers (int): Upper bound on worker threads.
            batch_size (int): Number of identifiers submitted per batch.
        Returns:
            pd.DataFrame: Transcript-to-gene mapping table.
        """
        ...
class DefaultExpressionImplementation:
    """Generated: validation needed.
    Description:
    Resolve and preprocess expression input for downstream protein abundance assembly.
    Args:
        translation_service (IdentifierTranslationServiceProtocol | None):
            Optional identifier translation service override. When None, a
            default IdentifierTranslationService is constructed.
    """

    def __init__(
        self,
        translation_service: IdentifierTranslationServiceProtocol | None = None,
    ) -> None:
        # Compare against None explicitly: a valid service object that happens to
        # be falsy (e.g. defines ``__len__`` returning 0) must not be replaced.
        if translation_service is None:
            translation_service = IdentifierTranslationService()
        self._translation_service = translation_service

    def resolve_expression_frame(
        self,
        scaffold: Scaffold,
        config: APIConfig,
    ) -> pd.DataFrame | None:
        """Generated: validation needed.
        Description:
        Resolve expression dataframe from configured scaffold/config sources by
        delegating to ``resolve_dataframe_input`` with input key "expression".
        Args:
            scaffold (Scaffold): Shared pipeline scaffold.
            config (APIConfig): Root API configuration.
        Returns:
            pd.DataFrame | None: Expression dataframe when available.
        """
        return resolve_dataframe_input(scaffold, config, input_key="expression")

    @staticmethod
    def _build_id_type_name(provider: str | None, level: str) -> str | None:
        """Generated: validation needed.
        Description:
        Build full identifier type name from provider and level.
        Only the "ensembl" provider is expanded (to ``ensembl_<level>_id``);
        any other provider value is returned unchanged.
        Args:
            provider (str | None): Identifier provider.
            level (str): Gene or transcript granularity (case-insensitive).
        Returns:
            str | None: Full identifier type name, or None if provider is None.
        """
        if provider is None:
            return None
        level_lower = level.lower()
        if provider == "ensembl":
            return f"ensembl_{level_lower}_id"
        return provider

    def prepare_expression_frame(
        self,
        scaffold: Scaffold,
        expression_df: pd.DataFrame,
        config: APIConfig,
    ) -> pd.DataFrame:
        """Generated: validation needed.
        Description:
        Align expression identifiers with the model's identifier namespace.

        Transcript-level input: a transcript-to-gene map is built via the
        translation service and stored as an artifact; when the run target
        level is "gene" the rows are aggregated to genes, otherwise the input
        is returned unchanged.

        Other input levels: translation is skipped when either identifier type
        is missing or both are equal; otherwise identifiers are translated and
        rows that collide after mapping are summed.
        Args:
            scaffold (Scaffold): Shared pipeline scaffold.
            expression_df (pd.DataFrame): Expression input table.
            config (APIConfig): Root API configuration.
        Returns:
            pd.DataFrame: Possibly converted expression table.
        Raises:
            ConfigurationError: If transcript-level input lacks an expression or
                model id_type, or an unsupported transcript aggregation policy
                is configured.
        Modifies:
            scaffold["artifacts"] and scaffold["diagnostics"] with translation metadata.
        """
        source_level = config.expression.level.lower()
        target_level = config.run_target_transcript_gene_level.lower()
        source_id_type = self._build_id_type_name(config.expression.id_type, source_level)
        # NOTE(review): for transcript-level input this value is passed as
        # ``target_gene_id_type`` but is derived from ``config.model.level``;
        # confirm model.level is "gene" whenever gene aggregation is requested.
        target_id_type = self._build_id_type_name(config.model.id_type, config.model.level)
        expression_index = [str(index_value) for index_value in expression_df.index]
        diagnostics_payload = scaffold.setdefault("diagnostics", {}).setdefault(
            "expression_preparation", {}
        )
        if source_level == "transcript":
            # Explicit raises instead of ``assert``: asserts are stripped under
            # ``python -O`` and would let ``None`` reach the translation service.
            if source_id_type is None:
                raise ConfigurationError(
                    "Expression id_type must be set for transcript-level conversion"
                )
            if target_id_type is None:
                raise ConfigurationError(
                    "Model id_type must be set for transcript-level conversion"
                )
            transcript_map_df = self._translation_service.build_transcript_gene_dataframe(
                expression_index,
                transcript_id_type=source_id_type,
                target_gene_id_type=target_id_type,
                species=config.expression.id_translation_species,
                provider=config.expression.id_translation_provider,
                max_workers=config.expression.id_translation_max_workers,
                batch_size=config.expression.id_translation_batch_size,
            )
            scaffold.setdefault("artifacts", {})["transcript_gene_map"] = transcript_map_df
            diagnostics_payload["transcript_gene_map_rows"] = int(len(transcript_map_df))
            if target_level == "gene":
                return self._aggregate_transcripts_to_genes(
                    expression_df,
                    transcript_map_df,
                    aggregation_policy=config.expression.transcript_aggregation_policy,
                    protein_coding_only=config.transcript_processing.protein_coding_only,
                    protein_coding_aggregation_policy=(
                        config.transcript_processing.protein_coding_aggregation_policy
                    ),
                    diagnostics_payload=diagnostics_payload,
                )
            return expression_df
        if not source_id_type or not target_id_type:
            diagnostics_payload["id_translation"] = "skipped_missing_id_type"
            return expression_df
        if source_id_type == target_id_type:
            diagnostics_payload["id_translation"] = "skipped_matching_id_type"
            return expression_df
        translation_result = self._translation_service.translate_identifiers(
            expression_index,
            source_id_type=source_id_type,
            target_id_type=target_id_type,
            species=config.expression.id_translation_species,
            provider=config.expression.id_translation_provider,
            max_workers=config.expression.id_translation_max_workers,
            batch_size=config.expression.id_translation_batch_size,
        )
        diagnostics_payload["id_translation"] = {
            "source_id_type": source_id_type,
            "target_id_type": target_id_type,
            "mapped_identifiers": len(translation_result.mapped_identifiers),
            "unresolved_identifiers": translation_result.unresolved_identifiers,
        }
        return self._apply_identifier_mapping(
            expression_df,
            identifier_mapping=translation_result.mapped_identifiers,
        )

    @staticmethod
    def _aggregate_transcripts_to_genes(
        expression_df: pd.DataFrame,
        transcript_gene_map_df: pd.DataFrame,
        *,
        aggregation_policy: str,
        protein_coding_only: bool,
        protein_coding_aggregation_policy: str,
        diagnostics_payload: dict[str, object],
    ) -> pd.DataFrame:
        """Generated: validation needed.
        Description:
        Aggregate transcript expression rows to genes and keep unresolved transcripts.

        A transcript is resolved when the map pairs it with a non-null gene_id.
        Resolved rows are re-indexed by gene and aggregated. Unresolved rows are
        kept under their transcript identifier. When the protein-coding filter
        is active, resolved transcripts not flagged as protein-coding are
        dropped rather than being retained as unresolved rows.
        Args:
            expression_df (pd.DataFrame): Transcript-level expression table.
            transcript_gene_map_df (pd.DataFrame): Transcript-to-gene mapping
                dataframe with "transcript_id" and "gene_id" columns and an
                optional "is_protein_coding" column.
            aggregation_policy (str): Configured aggregation policy ("sum" or "mean").
            protein_coding_only (bool): Whether to keep only protein-coding transcripts.
            protein_coding_aggregation_policy (str): Aggregation policy for
                protein-coding transcript rows ("sum" or "mean").
            diagnostics_payload (dict[str, object]): Mutable diagnostics payload.
        Returns:
            pd.DataFrame: Gene-level table grouped by identifier, with unresolved
            transcripts retained. Returned unchanged when the map is empty.
        Raises:
            ConfigurationError: If unsupported aggregation policy is configured.
        Modifies:
            diagnostics_payload: "transcript_unresolved_count"; when the filter
            is active also "protein_coding_only_filter",
            "protein_coding_map_rows" and "protein_coding_excluded_count".
        """
        supported_policies = {"sum", "mean"}
        if aggregation_policy not in supported_policies:
            raise ConfigurationError(
                "Unsupported transcript aggregation policy: "
                f"{aggregation_policy!r}. Supported values: ['sum', 'mean']."
            )
        if protein_coding_aggregation_policy not in supported_policies:
            raise ConfigurationError(
                "Unsupported protein-coding transcript aggregation policy: "
                f"{protein_coding_aggregation_policy!r}. Supported values: ['sum', 'mean']."
            )
        if transcript_gene_map_df.empty:
            diagnostics_payload["transcript_unresolved_count"] = int(len(expression_df.index))
            return expression_df

        def _lookup(map_df: pd.DataFrame) -> dict[str, str]:
            # Build transcript -> gene lookup; rows missing either id are unusable.
            pairs = map_df[["transcript_id", "gene_id"]].dropna()
            return {
                str(transcript_id): str(gene_id)
                for transcript_id, gene_id in zip(pairs["transcript_id"], pairs["gene_id"])
            }

        normalised_index = [str(index_value) for index_value in expression_df.index]
        transcript_to_gene = _lookup(transcript_gene_map_df)
        aggregation_policy_to_use = aggregation_policy
        excluded_transcripts: set[str] = set()
        if protein_coding_only and "is_protein_coding" in transcript_gene_map_df.columns:
            # ``eq(True)`` treats missing flags as non-coding; plain boolean
            # indexing raises on a mask containing NaN/None.
            coding_map_df = transcript_gene_map_df[
                transcript_gene_map_df["is_protein_coding"].eq(True)
            ]
            coding_transcript_to_gene = _lookup(coding_map_df)
            # Resolved-but-non-coding transcripts must be dropped; otherwise they
            # fall through to the "unresolved" rows and defeat the filter.
            excluded_transcripts = set(transcript_to_gene) - set(coding_transcript_to_gene)
            transcript_to_gene = coding_transcript_to_gene
            aggregation_policy_to_use = protein_coding_aggregation_policy
            diagnostics_payload["protein_coding_only_filter"] = True
            diagnostics_payload["protein_coding_map_rows"] = int(len(coding_map_df))
            diagnostics_payload["protein_coding_excluded_count"] = sum(
                index_value in excluded_transcripts for index_value in normalised_index
            )

        def _aggregate(frame: pd.DataFrame) -> pd.DataFrame:
            grouped = frame.groupby(level=0)
            if aggregation_policy_to_use == "mean":
                return grouped.mean()
            return grouped.sum()

        mapped_rows = [index_value in transcript_to_gene for index_value in normalised_index]
        unresolved_rows = [
            not is_mapped and index_value not in excluded_transcripts
            for index_value, is_mapped in zip(normalised_index, mapped_rows)
        ]
        mapped_expression_df = expression_df.loc[mapped_rows].copy()
        mapped_expression_df.index = [
            transcript_to_gene[str(index_value)] for index_value in mapped_expression_df.index
        ]
        aggregated_expression_df = _aggregate(mapped_expression_df)
        unresolved_expression_df = expression_df.loc[unresolved_rows].copy()
        diagnostics_payload["transcript_unresolved_count"] = int(
            len(unresolved_expression_df.index)
        )
        if unresolved_expression_df.empty:
            return pd.DataFrame(aggregated_expression_df)
        combined_expression_df = pd.concat(
            [aggregated_expression_df, unresolved_expression_df]
        )
        # Re-group so an unresolved id colliding with a gene id is merged.
        return _aggregate(combined_expression_df)

    @staticmethod
    def _apply_identifier_mapping(
        expression_df: pd.DataFrame,
        *,
        identifier_mapping: dict[str, str],
    ) -> pd.DataFrame:
        """Generated: validation needed.
        Description:
        Apply partial identifier mapping and sum rows whose mapped ids collide.
        Unmapped rows keep their identifier (as a string).
        Args:
            expression_df (pd.DataFrame): Input expression table.
            identifier_mapping (dict[str, str]):
                Source identifier to target identifier mapping.
        Returns:
            pd.DataFrame: Table indexed by mapped identifiers where available;
            the input is returned unchanged when the mapping is empty.
        """
        if not identifier_mapping:
            return expression_df
        mapped_df = expression_df.copy()
        mapped_df.index = [
            identifier_mapping.get(str(index_value), str(index_value))
            for index_value in mapped_df.index
        ]
        return mapped_df.groupby(level=0).sum()