diff --git a/src/bluesearch/entrypoint/database/add.py b/src/bluesearch/entrypoint/database/add.py
index 7055f2941..ce7ea1e96 100644
--- a/src/bluesearch/entrypoint/database/add.py
+++ b/src/bluesearch/entrypoint/database/add.py
@@ -82,6 +82,7 @@ def run(
     import sqlalchemy
 
     from bluesearch.database.article import Article
+    from bluesearch.sql import retrieve_existing_article_ids
     from bluesearch.utils import load_spacy_model
 
     if db_type == "sqlite":
@@ -116,6 +117,13 @@ def run(
     if not articles:
         raise RuntimeWarning(f"No article was loaded from '{parsed_path}'!")
 
+    # Keep only new articles; a set makes each uid membership test O(1)
+    existing_uids = set(retrieve_existing_article_ids(engine))
+    articles = [article for article in articles if article.uid not in existing_uids]
+
+    if not articles:
+        raise RuntimeWarning(f"All articles are already saved in '{db_url}'!")
+
     logger.info("Loading spacy model")
     nlp = load_spacy_model("en_core_sci_lg", disable=["ner"])
 
diff --git a/src/bluesearch/sql.py b/src/bluesearch/sql.py
index 6b32018be..4a759250d 100644
--- a/src/bluesearch/sql.py
+++ b/src/bluesearch/sql.py
@@ -16,12 +16,14 @@
 #
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
 
 import logging
 from typing import cast
 
 import numpy as np
 import pandas as pd
+import sqlalchemy
 import sqlalchemy.sql as sql
 
 
@@ -58,6 +60,24 @@ def get_titles(article_ids, engine):
     return titles
 
 
+def retrieve_existing_article_ids(engine: sqlalchemy.engine.Engine) -> list[str]:
+    """Retrieve all article_ids from the articles table.
+
+    Parameters
+    ----------
+    engine : sqlalchemy.engine.Engine
+        SQLAlchemy Engine connected to the database.
+
+    Returns
+    -------
+    article_ids : list[str]
+        List of existing article_ids
+    """
+    result_proxy = engine.execute("SELECT article_id FROM articles")
+    article_ids = [result[0] for result in result_proxy.fetchall()]
+    return article_ids
+
+
 def retrieve_article_ids(engine):
     """Retrieve all articles_id from sentences table.
 
diff --git a/tests/unit/entrypoint/database/test_add.py b/tests/unit/entrypoint/database/test_add.py
index 306915250..a62cd5987 100644
--- a/tests/unit/entrypoint/database/test_add.py
+++ b/tests/unit/entrypoint/database/test_add.py
@@ -95,6 +95,16 @@ def test_sqlite_cord19(engine_sqlite, tmp_path, monkeypatch, model_entities):
     (n_rows_sentences,) = engine_sqlite.execute(query_sentences).fetchone()
     assert n_rows_sentences > 0
 
+    # Test adding something that is already in the database
+    with pytest.raises(RuntimeWarning):
+        args_and_opts = [
+            "add",
+            engine_sqlite.url.database,
+            str(parsed_dir),
+            "--db-type=sqlite",
+        ]
+        main(args_and_opts)
+
     # Test adding something that does not exist
     with pytest.raises(ValueError):
         args_and_opts = [
diff --git a/tests/unit/test_sql.py b/tests/unit/test_sql.py
index 71a5a93f6..45f7eb027 100644
--- a/tests/unit/test_sql.py
+++ b/tests/unit/test_sql.py
@@ -30,6 +30,7 @@
     retrieve_article_ids,
     retrieve_article_metadata_from_article_id,
     retrieve_articles,
+    retrieve_existing_article_ids,
     retrieve_mining_cache,
     retrieve_paragraph,
     retrieve_paragraph_from_sentence_id,
@@ -182,6 +183,13 @@ def test_retrieve_article(
             == len(set(article_id)) * test_parameters["n_sections_per_article"]
         )
 
+    def test_retrieve_existing_article_ids(
+        self, fake_sqlalchemy_engine, test_parameters
+    ):
+        article_ids = retrieve_existing_article_ids(fake_sqlalchemy_engine)
+        assert isinstance(article_ids, list)
+        assert len(article_ids) == test_parameters["n_articles"]
+
     def test_retrieve_articles_ids(self, fake_sqlalchemy_engine, test_parameters):
         article_ids_dict = retrieve_article_ids(fake_sqlalchemy_engine)
         assert isinstance(article_ids_dict, dict)