From 55639066b42382d85292f40110c0bc5aef197dde Mon Sep 17 00:00:00 2001 From: Jonathan Carter Date: Fri, 19 Dec 2025 11:55:25 +0000 Subject: [PATCH 1/2] Reinstate checkpoint_path_prefix Signed-off-by: Jonathan Carter --- docs/source-pytorch/visualize/loggers.rst | 34 ++++++++++++++++++++++ src/lightning/pytorch/CHANGELOG.md | 2 +- src/lightning/pytorch/loggers/mlflow.py | 5 +++- tests/tests_pytorch/loggers/test_mlflow.py | 30 +++++++++++++++++++ 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst index bdf95ec1b675e..f4fd5b23b2311 100644 --- a/docs/source-pytorch/visualize/loggers.rst +++ b/docs/source-pytorch/visualize/loggers.rst @@ -54,3 +54,37 @@ Track and Visualize Experiments + +.. _mlflow_logger: + +MLflow Logger +------------- + +The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts. + +Example usage: + +.. code-block:: python + + import lightning as L + from lightning.pytorch.loggers import MLFlowLogger + + mlf_logger = MLFlowLogger( + experiment_name="lightning_logs", + tracking_uri="file:./ml-runs", + checkpoint_path_prefix="my_prefix" + ) + trainer = L.Trainer(logger=mlf_logger) + + # Your LightningModule definition + class LitModel(L.LightningModule): + def training_step(self, batch, batch_idx): + # example + self.logger.experiment.whatever_ml_flow_supports(...) 
+ + # Train your model + model = LitModel() + trainer.fit(model) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 591d258bcdd0b..c94e349962f9e 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Reinstated`checkpoint_path_prefix` parameter in `MLFlowLogger` to control the artifact path prefix for logged checkpoints. ### Changed diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..2db2a237099dd 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -98,6 +98,7 @@ def any_lightning_module_function_or_hook(self): which also logs every checkpoint during training. * if ``log_model == False`` (default), no checkpoint is logged. + checkpoint_path_prefix: A string to prefix the checkpoint artifact's path. prefix: A string to put at the beginning of metric keys. artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate default. 
@@ -121,6 +122,7 @@ def __init__( tags: Optional[dict[str, Any]] = None, save_dir: Optional[str] = "./mlruns", log_model: Literal[True, False, "all"] = False, + checkpoint_path_prefix: str = "", prefix: str = "", artifact_location: Optional[str] = None, run_id: Optional[str] = None, @@ -147,6 +149,7 @@ def __init__( self._artifact_location = artifact_location self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous} self._initialized = False + self._checkpoint_path_prefix = checkpoint_path_prefix from mlflow.tracking import MlflowClient @@ -361,7 +364,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] # Artifact path on mlflow - artifact_path = Path(p).stem + artifact_path = Path(self._checkpoint_path_prefix, Path(p).stem).as_posix() # Log the checkpoint self.experiment.log_artifact(self._run_id, p, artifact_path) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..8118349ea6721 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -427,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock): mlflow_mock.set_tracking_uri.assert_not_called() _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + + +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path): + """Test that the logger creates the folders and files in the right place with a prefix.""" + client = mlflow_mock.tracking.MlflowClient + + # Get model, logger, trainer and train + model = BoringModel() + logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix") + logger = mock_mlflow_run_creation(logger, experiment_id="test-id") + + trainer = Trainer( + 
default_root_dir=tmp_path, + logger=logger, + max_epochs=2, + limit_train_batches=3, + limit_val_batches=3, + ) + trainer.fit(model) + + # Checkpoint log + assert client.return_value.log_artifact.call_count == 2 + # Metadata and aliases log + assert client.return_value.log_artifacts.call_count == 2 + + # Check that the prefix is used in the artifact path + for call in client.return_value.log_artifact.call_args_list: + args, _ = call + assert str(args[2]).startswith("my_prefix") From e1f3d9a8f70c68153464b5f0d03ba58ff5825a60 Mon Sep 17 00:00:00 2001 From: Jonathan Carter <42900403+joncarter1@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:26:08 +0000 Subject: [PATCH 2/2] Update CHANGELOG.md --- src/lightning/pytorch/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index c94e349962f9e..990c5d21992e4 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Reinstated`checkpoint_path_prefix` parameter in `MLFlowLogger` to control the artifact path prefix for logged checkpoints. +- Reinstated the `checkpoint_path_prefix` parameter in `MLFlowLogger` ([#21432](https://github.com/Lightning-AI/pytorch-lightning/pull/21432)) ### Changed