
Commit 280e2a1

Added '--recomp-embs' flag to recompute embeddings if already saved
1 parent 6ac6ce3 commit 280e2a1

File tree

5 files changed: +31, -6 lines

src/thunder/benchmark.py

Lines changed: 20 additions & 6 deletions
@@ -17,6 +17,7 @@ def benchmark(
     lora: bool = False,
     ckpt_save_all: bool = False,
     online_wandb: bool = False,
+    recomp_embs: bool = False,
     **kwargs,
 ):
     """
@@ -36,6 +37,7 @@ def benchmark(
         lora (bool): Whether to use LoRA (Low-Rank Adaptation) for model adaptation. Default is False.
         ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
         online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
+        recomp_embs (bool): Whether to recompute embeddings if already saved. Default is False.
     """
     from hydra import compose, initialize
     from omegaconf import OmegaConf
@@ -45,6 +47,7 @@ def benchmark(
     wandb_mode = "online" if online_wandb else "offline"
     adaptation_type = "lora" if lora else "frozen"
     ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
+    embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
     model_name = model if isinstance(model, str) else None

     if model_name and model_name.startswith("custom:"):
@@ -60,6 +63,7 @@ def benchmark(
         adaptation_type,
         loading_mode,
         wandb_mode,
+        embedding_recomputing,
         **kwargs,
     )
@@ -253,7 +257,23 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
             f"No pre-computed embeddings found for the (dataset, model) pair "
             f"({dataset_name}, {model_name}). Computing them."
         )
+        comp_embs = True
+    else:
+        emb_info_str = (
+            f"Pre-computed embeddings already found for the (dataset, model) pair "
+            f"({dataset_name}, {model_name})."
+        )
+
+        if cfg.embedding_recomputing.recompute_embeddings:
+            emb_info_str += " Re-computing them as explicitly requested."
+            comp_embs = True
+        else:
+            emb_info_str += " Not re-computing them."
+            comp_embs = False
+
+    logging.info(emb_info_str)

+    if comp_embs:
         pre_computing_patch_embeddings(
             cfg,
             embeddings_folder,
@@ -269,12 +289,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
             model_cls,
         )

-    else:
-        logging.info(
-            f"Pre-computed embeddings already found for the (dataset, model) pair "
-            f"({dataset_name}, {model_name}). Not re-computing them."
-        )
-
     if task_type in [
         "alignment_scoring",
         "embedding_space_visualization",
(New file; its path is not shown in this view. Given the "no_recomp_embs" string set in benchmark.py, this is presumably that option of the new embedding_recomputing config group.)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+recompute_embeddings: False
(New file; path likewise not shown. Presumably the "recomp_embs" option of the same config group.)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+recompute_embeddings: True
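
Together, these two files look like the options of a Hydra config group named embedding_recomputing, selected by the string computed in benchmark() ("recomp_embs" or "no_recomp_embs"). A minimal sketch of how such a group override resolves, assuming placeholder names config_path="conf" and config_name="base" (neither is taken from the repository):

from hydra import compose, initialize

# Sketch only; "conf" and "base" are placeholders for the real layout.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="base",
        overrides=["+embedding_recomputing=recomp_embs"],
    )
    # The selected file's key lands under the group name:
    print(cfg.embedding_recomputing.recompute_embeddings)  # True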

src/thunder/main.py

Lines changed: 7 additions & 0 deletions
@@ -71,6 +71,12 @@ def benchmark(
     online_wandb: Annotated[
         bool, typer.Option(help="Logging with the online mode of wandb")
     ] = False,
+    recomp_embs: Annotated[
+        bool,
+        typer.Option(
+            help="If provided, embeddings will be re-computed even if already saved"
+        ),
+    ] = False,
     kwargs: Annotated[List[str], typer.Argument(help="Additional arguments")] = None,
 ):
     from . import benchmark
@@ -94,6 +100,7 @@ def benchmark(
         lora,
         ckpt_save_all,
         online_wandb,
+        recomp_embs,
         **kwargs,
     )
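
With the Typer option in place, the flag can be used from the command line or the Python API. A hedged usage sketch (Typer renders the recomp_embs parameter as --recomp-embs, matching the commit message; the entry-point name and the positional arguments are placeholders):

# CLI, assuming a "thunder" console entry point:
#   thunder benchmark <model> <dataset> <task> --recomp-embs
# Python API; positional arguments are placeholders:
from thunder import benchmark

benchmark("some_model", "some_dataset", "some_task", recomp_embs=True)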

src/thunder/utils/config.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@ def get_config(
     adaptation: Optional[str] = None,
     data_loading_type: Optional[str] = None,
     wandb_mode: Optional[str] = None,
+    embedding_recomputing: Optional[str] = None,
     **kwargs,
 ) -> DictConfig:
     params = {
@@ -22,6 +23,7 @@ def get_config(
         "pretrained_model": pretrained_model,
         "task": task,
         "wandb": wandb_mode,
+        "embedding_recomputing": embedding_recomputing,
     }

     overrides = [f"+{k}={v}" for k, v in params.items() if v is not None]
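
For illustration, the override construction at the end of get_config behaves as follows (example values only): None entries are dropped, and everything else becomes a "+key=value" Hydra override.

params = {
    "wandb": "offline",
    "embedding_recomputing": "no_recomp_embs",
    "adaptation": None,  # skipped: None values are filtered out
}
overrides = [f"+{k}={v}" for k, v in params.items() if v is not None]
print(overrides)  # ['+wandb=offline', '+embedding_recomputing=no_recomp_embs']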
