@@ -17,6 +17,7 @@ def benchmark(
1717 lora : bool = False ,
1818 ckpt_save_all : bool = False ,
1919 online_wandb : bool = False ,
20+ recomp_embs : bool = False ,
2021 ** kwargs ,
2122):
2223 """
@@ -36,6 +37,7 @@ def benchmark(
3637 lora (bool): Whether to use LoRA (Low-Rank Adaptation) for model adaptation. Default is False.
3738 ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
3839 online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
40+ recomp_embs (bool): Whether to recompute embeddings if already saved.
3941 """
4042 from hydra import compose , initialize
4143 from omegaconf import OmegaConf
@@ -45,6 +47,7 @@ def benchmark(
4547 wandb_mode = "online" if online_wandb else "offline"
4648 adaptation_type = "lora" if lora else "frozen"
4749 ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
50+ embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
4851 model_name = model if isinstance (model , str ) else None
4952
5053 if model_name and model_name .startswith ("custom:" ):
@@ -60,6 +63,7 @@ def benchmark(
6063 adaptation_type ,
6164 loading_mode ,
6265 wandb_mode ,
66+ embedding_recomputing ,
6367 ** kwargs ,
6468 )
6569
@@ -253,7 +257,23 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
253257 f"No pre-computed embeddings found for the (dataset, model) pair "
254258 f"({ dataset_name } , { model_name } ). Computing them."
255259 )
260+ comp_embs = True
261+ else :
262+ emb_info_str = (
263+ f"Pre-computed embeddings already found for the (dataset, model) pair "
264+ f"({ dataset_name } , { model_name } )."
265+ )
266+
267+ if cfg .embedding_recomputing .recompute_embeddings :
268+ emb_info_str += " Re-computing them as explictly requested."
269+ comp_embs = True
270+ else :
271+ emb_info_str += " Not re-computing them."
272+ comp_embs = False
273+
274+ logging .info (emb_info_str )
256275
276+ if comp_embs :
257277 pre_computing_patch_embeddings (
258278 cfg ,
259279 embeddings_folder ,
@@ -269,12 +289,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
269289 model_cls ,
270290 )
271291
272- else :
273- logging .info (
274- f"Pre-computed embeddings already found for the (dataset, model) pair "
275- f"({ dataset_name } , { model_name } ). Not re-computing them."
276- )
277-
278292 if task_type in [
279293 "alignment_scoring" ,
280294 "embedding_space_visualization" ,
0 commit comments