diff --git a/CHANGELOG.md b/CHANGELOG.md index a8dc485..a0ec1de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,26 @@ All notable changes to this project will be documented in this file. +## [1.0.0.a6] - 2025-03-12 + +### Bug Fixes + +- Better handle midline model. \ + This means disabling the evolution over midline extension. Also, since the new + version of `lymph-model`, the `midext_prob` parameter is not epected to be the + first one anymore when passed to `set_params()`. +- Pass only ipsilateral diagnosis to unilateral model. +- Pass diagnose & involvement correctly to models as dict. + +### Testing + +- Ensure unilateral model receives correct diagnosis. +- Test that diagnosis is used correctly in posteriors. + +### Build + +- Bump lydata & lymph-model dependency. + ## [1.0.0.a5] - 2025-02-05 ### Bug Fixes diff --git a/lyscripts/compute/posteriors.py b/lyscripts/compute/posteriors.py index 885700c..ee14a9f 100644 --- a/lyscripts/compute/posteriors.py +++ b/lyscripts/compute/posteriors.py @@ -55,6 +55,11 @@ def compute_posteriors( posteriors = [] kwargs = {"midext": midext} if isinstance(model, models.Midline) else {} + if isinstance(model, models.Unilateral | models.HPVUnilateral): + diagnosis = diagnosis.ipsi + else: + diagnosis = diagnosis.model_dump() + for prior in progress.track( sequence=priors, description=progress_desc, diff --git a/lyscripts/compute/risks.py b/lyscripts/compute/risks.py index 84a2842..d1a66a5 100644 --- a/lyscripts/compute/risks.py +++ b/lyscripts/compute/risks.py @@ -7,6 +7,7 @@ import numpy as np from loguru import logger +from lymph import models from pydantic import Field from rich import progress @@ -47,6 +48,11 @@ def compute_risks( model = add_modalities(model, modality_configs) risks = [] + if isinstance(model, models.Unilateral | models.HPVUnilateral): + involvement = involvement.ipsi + else: + involvement = involvement.model_dump() + for posterior in progress.track( sequence=posteriors, description=progress_desc, diff --git a/lyscripts/configs.py b/lyscripts/configs.py index 0f1e79a..d0d392b 100644 --- a/lyscripts/configs.py +++ b/lyscripts/configs.py @@ -247,12 +247,17 @@ def model_post_init(self, __context): ) if "Midline" in self.class_: self.class_ = "Midline" + warnings.warn( + "Model may not be recreated as expected due to extra parameter " + "`midext_prob`. Make sure to manually handle edge cases.", + stacklevel=2, + ) return super().model_post_init(__context) def translate(self) -> tuple[ModelConfig, dict[int | str, DistributionConfig]]: """Translate the deprecated model config to the new format.""" old_kwargs = self.kwargs.copy() - new_kwargs = {} + new_kwargs = {"use_midext_evo": False} if "Midline" in self.class_ else {} if (tumor_spread := old_kwargs.pop("base_symmetric")) is not None: new_kwargs["is_symmetric"] = new_kwargs.get("is_symmetric", {}) diff --git a/pyproject.toml b/pyproject.toml index bb89967..d5ec5d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,13 +44,13 @@ dependencies = [ "rich", "rich-argparse", "pyyaml", - "lymph-model >= 1.3.2", + "lymph-model >= 1.3.3", "deprecated", "joblib", "pydantic", "pydantic-settings >= 2.7.0", "numpydantic", - "lydata >= 0.2.0", + "lydata >= 0.2.5", "loguru", ] dynamic = ["version"] diff --git a/tests/predict/posteriors_test.py b/tests/predict/posteriors_test.py new file mode 100644 index 0000000..5511e0c --- /dev/null +++ b/tests/predict/posteriors_test.py @@ -0,0 +1,135 @@ +"""Test utilities of the predict submodule.""" + +import numpy as np +import pytest +from lydata.utils import ModalityConfig + +from lyscripts.compute.posteriors import compute_posteriors +from lyscripts.compute.priors import compute_priors +from lyscripts.compute.utils import complete_pattern +from lyscripts.configs import ( + DiagnosisConfig, + DistributionConfig, + GraphConfig, + ModelConfig, + add_distributions, + construct_model, +) + +RNG = np.random.default_rng(42) + + +@pytest.fixture(params=["Unilateral", "Bilateral"]) +def model_config(request) -> ModelConfig: + """Create unilateral model config.""" + return ModelConfig(class_name=request.param) + + +@pytest.fixture +def graph_config() -> GraphConfig: + """Create simple graph.""" + return GraphConfig( + tumor={"T": ["I", "II", "III"]}, + lnl={"I": ["II"], "II": ["III"], "III": []}, + ) + + +@pytest.fixture +def dist_configs() -> dict[str, DistributionConfig]: + """Provide early and late distributions.""" + return { + "early": DistributionConfig(kind="frozen", func="binomial"), + "late": DistributionConfig(kind="parametric", func="binomial"), + } + + +@pytest.fixture +def modality_config() -> ModalityConfig: + """Create modality config.""" + return ModalityConfig(spec=0.9, sens=0.8) + + +@pytest.fixture +def diagnosis_config() -> DiagnosisConfig: + """Create a simple diagnosis config.""" + return DiagnosisConfig( + ipsi={"D": {"I": True, "II": True, "III": False}}, + contra={"D": {"I": False, "II": True, "III": False}}, + ) + + +@pytest.fixture +def samples( + model_config: ModelConfig, + graph_config: GraphConfig, + dist_configs: dict[str, DistributionConfig], +) -> np.ndarray: + """Generate some samples.""" + model = construct_model(model_config, graph_config) + model = add_distributions(model, dist_configs) + return RNG.uniform(size=(100, model.get_num_dims())) + + +@pytest.fixture +def priors( + model_config: ModelConfig, + graph_config: GraphConfig, + dist_configs: dict[str, DistributionConfig], + samples: np.ndarray, +) -> np.ndarray: + """Provide some priors.""" + return compute_priors( + model_config=model_config, + graph_config=graph_config, + dist_configs=dist_configs, + samples=samples, + t_stages=["late"], + t_stages_dist=[1.0], + ) + + +def test_compute_posterior( + model_config: ModelConfig, + graph_config: GraphConfig, + dist_configs: dict[str, DistributionConfig], + modality_config: ModalityConfig, + diagnosis_config: DiagnosisConfig, + priors: np.ndarray, +) -> None: + """Ensure that the diagnosis is correctly treated.""" + posteriors = compute_posteriors( + model_config=model_config, + graph_config=graph_config, + dist_configs=dist_configs, + modality_configs={"D": modality_config}, + priors=priors, + diagnosis=diagnosis_config, + ) + + assert np.all(posteriors >= 0), "Negative probabilities in posterior." + assert np.all(posteriors <= 1), "Probabilities above 1 in posterior." + + +def test_clean_pattern(): + """Test outdated utility function.""" + empty_pattern = {} + one_pos_pattern = {"ipsi": {"II": True}} + nums_pattern = {"ipsi": {"I": 1}, "contra": {"III": 0}} + lnls = ["I", "II", "III"] + + empty_cleaned = complete_pattern(empty_pattern, lnls) + one_pos_cleaned = complete_pattern(one_pos_pattern, lnls) + nums_cleaned = complete_pattern(nums_pattern, lnls) + + assert empty_cleaned == { + "ipsi": {"I": None, "II": None, "III": None}, + "contra": {"I": None, "II": None, "III": None}, + }, "Empty pattern does not get filled correctly." + assert one_pos_cleaned == { + "ipsi": {"I": None, "II": True, "III": None}, + "contra": {"I": None, "II": None, "III": None}, + }, "Pattern with one positive LNL not cleaned properly." + assert nums_cleaned == { + "ipsi": {"I": True, "II": None, "III": None}, + "contra": {"I": None, "II": None, "III": False}, + }, "Number pattern cleaned wrongly." diff --git a/tests/predict/predict_utils_test.py b/tests/predict/predict_utils_test.py deleted file mode 100644 index 06fb7c7..0000000 --- a/tests/predict/predict_utils_test.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Test utilities of the predict submodule.""" - -from lyscripts.compute.utils import complete_pattern - - -def test_clean_pattern(): - """Test the utility function that cleans the involvement patterns from the - `params.yaml` file - """ - empty_pattern = {} - one_pos_pattern = {"ipsi": {"II": True}} - nums_pattern = {"ipsi": {"I": 1}, "contra": {"III": 0}} - lnls = ["I", "II", "III"] - - empty_cleaned = complete_pattern(empty_pattern, lnls) - one_pos_cleaned = complete_pattern(one_pos_pattern, lnls) - nums_cleaned = complete_pattern(nums_pattern, lnls) - - assert empty_cleaned == { - "ipsi": {"I": None, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": None}, - }, "Empty pattern does not get filled correctly." - assert one_pos_cleaned == { - "ipsi": {"I": None, "II": True, "III": None}, - "contra": {"I": None, "II": None, "III": None}, - }, "Pattern with one positive LNL not cleaned properly." - assert nums_cleaned == { - "ipsi": {"I": True, "II": None, "III": None}, - "contra": {"I": None, "II": None, "III": False}, - }, "Number pattern cleaned wrongly." diff --git a/tests/utils_test.py b/tests/utils_test.py index d137d6f..ebeaad1 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -83,7 +83,8 @@ def test_translate_deprecated_model_config( trans_model_config, trans_dist_configs = old_model_config.translate() - assert exp_model_config.model_dump( - exclude="kwargs" - ) == trans_model_config.model_dump(exclude="kwargs") + assert ( # noqa + exp_model_config.model_dump(exclude="kwargs") + == trans_model_config.model_dump(exclude="kwargs") + ) assert exp_dist_configs == trans_dist_configs