From 1d1648d11ee4cc3d227450e3b45e618ba50e9f86 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 7 Mar 2022 11:25:25 +0100 Subject: [PATCH 1/4] Do not add existing uids --- src/bluesearch/entrypoint/database/add.py | 5 +++++ src/bluesearch/sql.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/bluesearch/entrypoint/database/add.py b/src/bluesearch/entrypoint/database/add.py index 7055f2941..23c57c948 100644 --- a/src/bluesearch/entrypoint/database/add.py +++ b/src/bluesearch/entrypoint/database/add.py @@ -82,6 +82,7 @@ def run( import sqlalchemy from bluesearch.database.article import Article + from bluesearch.sql import retrieve_existing_article_ids from bluesearch.utils import load_spacy_model if db_type == "sqlite": @@ -116,6 +117,10 @@ def run( if not articles: raise RuntimeWarning(f"No article was loaded from '{parsed_path}'!") + # Filter articles already present in the database + existing_uids = retrieve_existing_article_ids(engine) + articles = [article for article in articles if article.uid not in existing_uids] + logger.info("Loading spacy model") nlp = load_spacy_model("en_core_sci_lg", disable=["ner"]) diff --git a/src/bluesearch/sql.py b/src/bluesearch/sql.py index 6b32018be..4a759250d 100644 --- a/src/bluesearch/sql.py +++ b/src/bluesearch/sql.py @@ -16,12 +16,14 @@ # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . +from __future__ import annotations import logging from typing import cast import numpy as np import pandas as pd +import sqlalchemy import sqlalchemy.sql as sql @@ -58,6 +60,24 @@ def get_titles(article_ids, engine): return titles +def retrieve_existing_article_ids(engine: sqlalchemy.engine.Engine) -> list[str]: + """Retrieve all articles_ids from a database. + + Parameters + ---------- + engine : sqlalchemy.engine.Engine + SQLAlchemy Engine connected to the database. + + Returns + ------- + article_ids : list[str] + List of existing article_ids + """ + result_proxy = engine.execute("SELECT article_id FROM articles") + article_ids = [result[0] for result in result_proxy.fetchall()] + return article_ids + + def retrieve_article_ids(engine): """Retrieve all articles_id from sentences table. From c501614c98dc21100691e009bcf0757bda18f8ea Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 7 Mar 2022 12:00:53 +0100 Subject: [PATCH 2/4] Add a test for sql new function --- tests/unit/test_sql.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/test_sql.py b/tests/unit/test_sql.py index 71a5a93f6..200db1d4b 100644 --- a/tests/unit/test_sql.py +++ b/tests/unit/test_sql.py @@ -30,6 +30,7 @@ retrieve_article_ids, retrieve_article_metadata_from_article_id, retrieve_articles, + retrieve_existing_article_ids, retrieve_mining_cache, retrieve_paragraph, retrieve_paragraph_from_sentence_id, @@ -182,6 +183,11 @@ def test_retrieve_article( == len(set(article_id)) * test_parameters["n_sections_per_article"] ) + def test_retrieve_existing_article_ids(self, fake_sqlalchemy_engine, test_parameters): + article_ids = retrieve_existing_article_ids(fake_sqlalchemy_engine) + assert isinstance(article_ids, list) + assert len(article_ids) == test_parameters["n_articles"] + def test_retrieve_articles_ids(self, fake_sqlalchemy_engine, test_parameters): article_ids_dict = retrieve_article_ids(fake_sqlalchemy_engine) assert isinstance(article_ids_dict, dict) From 45858411869e0ff3415f8bb7abacfb3133daa10e Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 7 Mar 2022 17:01:01 +0100 Subject: [PATCH 3/4] Fix linting --- tests/unit/test_sql.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_sql.py b/tests/unit/test_sql.py index 200db1d4b..45f7eb027 100644 --- a/tests/unit/test_sql.py +++ b/tests/unit/test_sql.py @@ -183,7 +183,9 @@ def test_retrieve_article( == len(set(article_id)) * test_parameters["n_sections_per_article"] ) - def test_retrieve_existing_article_ids(self, fake_sqlalchemy_engine, test_parameters): + def test_retrieve_existing_article_ids( + self, fake_sqlalchemy_engine, test_parameters + ): article_ids = retrieve_existing_article_ids(fake_sqlalchemy_engine) assert isinstance(article_ids, list) assert len(article_ids) == test_parameters["n_articles"] From 86a1b16b67bfeaf07b8eda76957fd2ae09078b85 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Mon, 7 Mar 2022 17:27:37 +0100 Subject: [PATCH 4/4] Raise RuntimeWarning if no new articles --- src/bluesearch/entrypoint/database/add.py | 5 ++++- tests/unit/entrypoint/database/test_add.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/bluesearch/entrypoint/database/add.py b/src/bluesearch/entrypoint/database/add.py index 23c57c948..ce7ea1e96 100644 --- a/src/bluesearch/entrypoint/database/add.py +++ b/src/bluesearch/entrypoint/database/add.py @@ -117,10 +117,13 @@ def run( if not articles: raise RuntimeWarning(f"No article was loaded from '{parsed_path}'!") - # Filter articles already present in the database + # Keep only articles not already present in the database existing_uids = retrieve_existing_article_ids(engine) articles = [article for article in articles if article.uid not in existing_uids] + if not articles: + raise RuntimeWarning(f"All articles are already saved in '{db_url}'!") + logger.info("Loading spacy model") nlp = load_spacy_model("en_core_sci_lg", disable=["ner"]) diff --git a/tests/unit/entrypoint/database/test_add.py b/tests/unit/entrypoint/database/test_add.py index 306915250..a62cd5987 100644 --- a/tests/unit/entrypoint/database/test_add.py +++ b/tests/unit/entrypoint/database/test_add.py @@ -95,6 +95,16 @@ def test_sqlite_cord19(engine_sqlite, tmp_path, monkeypatch, model_entities): (n_rows_sentences,) = engine_sqlite.execute(query_sentences).fetchone() assert n_rows_sentences > 0 + # Test adding something that is already in the database + with pytest.raises(RuntimeWarning): + args_and_opts = [ + "add", + engine_sqlite.url.database, + str(parsed_dir), + "--db-type=sqlite", + ] + main(args_and_opts) + # Test adding something that does not exist with pytest.raises(ValueError): args_and_opts = [