From abbf3f411771eeb40b4f7195bfbe5b1336798908 Mon Sep 17 00:00:00 2001 From: Peng Ren Date: Tue, 30 Dec 2025 20:52:30 -0500 Subject: [PATCH] Implemented executemany() --- .coverage | Bin 0 -> 53248 bytes HOW_IT_WORKS.md | 5 - README.md | 4 +- pymongosql/__init__.py | 2 +- pymongosql/cursor.py | 50 +++- tests/test_cursor_delete.py | 21 ++ tests/test_cursor_insert.py | 25 +- tests/test_cursor_update.py | 24 ++ tests/test_delete_builder.py | 145 +++++++++++ tests/test_delete_handler.py | 270 +++++++++++++++++++ tests/test_insert_builder.py | 268 +++++++++++++++++++ tests/test_insert_handler.py | 359 +++++++++++++++++++++++++ tests/test_query_builder.py | 432 +++++++++++++++++++++++++++++++ tests/test_query_handler.py | 264 +++++++++++++++++++ tests/test_sqlalchemy_dialect.py | 2 +- tests/test_sqlalchemy_dml.py | 2 +- tests/test_sqlalchemy_query.py | 2 +- tests/test_update_builder.py | 179 +++++++++++++ tests/test_update_handler.py | 358 +++++++++++++++++++++++++ 19 files changed, 2390 insertions(+), 22 deletions(-) create mode 100644 .coverage create mode 100644 tests/test_delete_builder.py create mode 100644 tests/test_delete_handler.py create mode 100644 tests/test_insert_builder.py create mode 100644 tests/test_insert_handler.py create mode 100644 tests/test_query_builder.py create mode 100644 tests/test_query_handler.py create mode 100644 tests/test_update_builder.py create mode 100644 tests/test_update_handler.py diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..2e867e9a535be20df2a8f52f087e2a776e16c3e9 GIT binary patch literal 53248 zcmeI4d2kd}8o;~f>bW~1$K-?-g%EN;h=CPW373knT)7|&bS5*MWMGcaGXrEbzzE8g zTU#Px?c%x;i&Adb#bb+`UAC4Ll;|2jkQJ>GbV&ndQ z?)Tnz{NDH8>o=1(Wz#1+{Q_6#72P&JS4<^PG)}Z>x0)e$i$Z#PGXm+>8lh$|rE;W5!ONz=iK|gY+C61qIv~PV_F~8U>N7bGii1 z>2WyiHowzT&-ojKXz74YaFB@(?!iG3)JPOpa( z77O-(UvS*f!F&QOdH}q5Z$MTJai}I-GomYwPOS1BYDS7jD%XK=9&Z$DtRhhjl_i@! z0e3ARpBHd9lE#4=Va0xLcQw1K>-R>S{n2f6hqjFQ+od_9$VrQaYXv;c1p%S6@$fvR zC-76NPjp>^rt@|;p+oY!j~miUx&{g9MJ^ZTt&_S8sm+BO2ckK?&kTveu!*Fj zao$Klu=FR5HexK;z}33EwOwOPiWr{&9R$XLXtBD0$4xDlF=HmQ8L;`B01nr0i&^u$?VRZNIAYJa-t0~D#ntMLI=Y=f{aV#rS9ScxfAOG zw@d^--uyz*+l^|VJ^+h zQQj>KNNe#$2c*6q;AOR$R1dsp{Kmter2_OL^fCP{L68KpXFe;SU zZ1g`oh9=O;%=F0r0E{RsDlQzV2d@3P#I<0<2EV}HZH z4D|*^|C;_`-8o&eE=9XnThD};Rm>pGe$6EH*Xln(RXmUY5+J==3wT}1N18%|N_gPB{#tRDtm$wnNZCk6n;(T8tY+kocYxflu&v zHa=`67n^uTE$DGuKo8%;did}zbmCTO54zJ9XjL2?9=+Lm*?y+yqcC zMm`1M4iS^jvpwQK1MEe1g&i>Wgaa}8yh&aMoF15k{CrHoOd}{Lk*@_}l6ix?raJ`q zWfvjN^m08YD3H%|n~(J2Dji6bX>cHOcB2i3Y!d6dR$d;Fri=yqOa>Il*pxBJyhdKr zBTX4oFiQ;zM#$HcabsAS3Pj5^gER;(n1P`!E0iE~h&=Yjjq5WMAU0Pfv7#u=6&196 zrp93<-~VT;2UF~3ZG-x<`enA#95Qd#?lVs{t}v}JT{7Nlc-vHLXfU4EKVs18&g(a5 z4rnH;2kVw-DP1a)%RJY+AN$ceB!C2v07dy!K_@NGn11)wZEXFYU9Jk|$)vM${hu{c z6)cuZX#Dy=Q&a`(WfFXMLY?dX45uns(Q_)}=E?MWRdBjon(ngxPxGjPd{66%t^X}O z=~mqOKUL01j;{YxWQ^o4>;L4Qq!Ra>nADSc;)1tNjVd@!uAsWp`ah9Z1;@y!;MVnj zLeDr5Tja7m;y_HkS!M@rTmPG8se&akwm{5uqs*q?w*EI%sDcGDnjT*N>t!4qx~%_o zGV{1?{jZg|DPzlXMrPA*TmNgSRlyN5Hf7uxrj~IA>0JM-DpkQDGT9q9t}Dw_!Cbk- zM%Mod=p{pR@;raWVa4D7!?g_(Kmter2_OL^fCP{L5P5lI%)sEk)hbH*pJv%Y&|=ZO*DUF zK5RZ0+Kdkyi=ZkVNB{{S0VIF~ zkN^@u0!RP}{JaPZ%2v_385Zl>kZUFFpYI;6Fz%j_$K7b8?rS}D^`2!X5)St7PrYYP zq?59sw>A-9mh3{Q3KfW#!=(uJw0I3YKP{X z@9V#Qqw=_-^uCq#ODeBd4dM6grKrEpJ@q*|`r7)p}nzVt!AI7m}1&rceft$4Iw3fwAX&xqev zhOT~acyrUV;OFYlj>fvDt}NYkeE-$r#AH$@(YmXB)2A15x10!UQ}#_FYRqTEW35|{ z<<6fqeP58OTy!yGh&!r0s>2G$={!ifKoE?MP zLzm~3R8*Z*rYAzN`*mk-?tFaVx_?%0SzS15+a-V1CUM}N+yuxfH$8Fg=IpmFiy{8N zl9sdQo>1;PIjr^jUmicbbqRB6`_j<$$4|D;v9M5dsFv6_r@3Qy`=#CO+mBZU0@?G; zglwg8?>|C^PgE;synB89Y?k`-l)PCl&l;}izxDanwz=zCcJ2N+bNRJ5cI<0-Xf)NZ zeDKEBUPy5byZs6E$JJ-Qi$J|pj zAD!HP;F|kTO_~X;7-l#lF5a?Z>a!I$4*qc9#?rs;*mAa?*$6p(HP&4LebwR3f$S^; zT#sa|)C0+%o@#4rTk|(&8Q)$#fA*deb7w2^^^jYw?>LZubji_mKU80x*3SGeX6bt$ zt6DZ}$Z2_yrK+o%Precw7g}**(?c2^lm<5owUEwDK2zBK?t?cczq)1Z@xibEK8f05 z&0Ch-l3V^_Zpt=I+x=CmH@)^w@WuV_K9j~k`C%$mm8OQGti_|6veqAb>%XCrP|^eq zfu>jPy(wBoU88C~oHr+9GqYy=3-7e%PdYGn`(vjr4p_N=|C$$?U)j-dZFMRC*pySs z>qk689=P(V^awv0Da5*ijE zwQd^AP%UkB%eajdTPicpJ!PIfQAyPI$?W)M=S`aWFg=GqJi2VabhT>y@}ub;m8Ff! zGX>P}ZEud9zjfZ1+m7wkt~zNLyyLSMpj#-Y7HZ%!U1{kxPf$^lNGB^m=in6U+o$&k z_JOspFXD^q&X<^Il}^{UTzs1Q&4k9F^!xwW>mTW0W%aW3$Oesmtkdi*>Qj(n}B^eeeNl%p$HbqJjlcmI*Bqa%b zq(qk}C8h)^(XvuvG)sxWBqbW7lrRP<(d(r|t&@^ut(2%3DN$;qM4^_FER~dGDy1Yv zp@I(rr0@U3pOE#+c0#9+01`j~NB{{S0VIF~kN^@u0!RP}AORpi_W$GhA435NAOR$R z1dsp{Kmter2_OL^fCP|0? _T: def executemany( self, operation: str, - seq_of_parameters: List[Optional[Dict[str, Any]]], + seq_of_parameters: List[Optional[Any]], ) -> None: """Execute a SQL statement multiple times with different parameters - Note: This is not yet fully implemented for MongoDB operations + This method executes the operation once for each parameter set in + seq_of_parameters. It's particularly useful for bulk INSERT, UPDATE, + or DELETE operations. + + Args: + operation: SQL statement to execute + seq_of_parameters: Sequence of parameter sets. Each element should be + a sequence (list/tuple) for positional parameters with ? placeholders, + or a dict for named parameters with :name placeholders. + + Returns: + None (executemany does not produce a result set) + + Note: The rowcount property will reflect the total number of rows affected + across all executions. """ self._check_closed() - # For now, just execute once and ignore parameters - _logger.warning("executemany not fully implemented, executing once without parameters") - self.execute(operation) + if not seq_of_parameters: + return + + total_rowcount = 0 + + try: + # Execute the operation for each parameter set + for params in seq_of_parameters: + self.execute(operation, params) + # Accumulate rowcount from each execution + if self.rowcount > 0: + total_rowcount += self.rowcount + + # Update the final result set with accumulated rowcount + if self._result_set: + self._result_set._rowcount = total_rowcount + + except (SqlSyntaxError, DatabaseError, OperationalError, ProgrammingError): + # Re-raise known errors + raise + except Exception as e: + _logger.error(f"Unexpected error during executemany: {e}") + raise DatabaseError(f"executemany failed: {e}") def execute_transaction(self) -> None: - """Execute transaction (MongoDB has limited transaction support)""" + """Execute transaction - not yet implemented""" self._check_closed() - # MongoDB transactions are complex and require specific setup - # For now, this is a placeholder - raise NotImplementedError("Transaction support not yet implemented") + raise NotImplementedError("Transaction using this function not yet implemented") def flush(self) -> None: """Flush any pending operations""" diff --git a/tests/test_cursor_delete.py b/tests/test_cursor_delete.py index d955218..045ea76 100644 --- a/tests/test_cursor_delete.py +++ b/tests/test_cursor_delete.py @@ -180,3 +180,24 @@ def test_delete_followed_by_insert(self, conn): assert len(list(db[self.TEST_COLLECTION].find())) == 1 doc = list(db[self.TEST_COLLECTION].find())[0] assert doc["title"] == "New Song" + + def test_delete_executemany_with_parameters(self, conn): + """Test executemany for bulk delete operations with parameters.""" + cursor = conn.cursor() + sql = f"DELETE FROM {self.TEST_COLLECTION} WHERE artist = '?'" + + # Delete multiple artists using executemany + params = [["Alice"], ["Charlie"], ["Eve"]] + + cursor.executemany(sql, params) + + # Verify specified artists were deleted + db = conn.database + remaining = list(db[self.TEST_COLLECTION].find()) + assert len(remaining) == 2 # Only Bob and Diana remain + + remaining_artists = {doc["artist"] for doc in remaining} + assert remaining_artists == {"Bob", "Diana"} + assert "Alice" not in remaining_artists + assert "Charlie" not in remaining_artists + assert "Eve" not in remaining_artists diff --git a/tests/test_cursor_insert.py b/tests/test_cursor_insert.py index 145fcc6..1082e04 100644 --- a/tests/test_cursor_insert.py +++ b/tests/test_cursor_insert.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -"""Test suite for INSERT statement execution via Cursor.""" - import pytest from pymongosql.error import ProgrammingError, SqlSyntaxError @@ -212,3 +210,26 @@ def test_insert_followed_by_select(self, conn): col_names = [desc[0] for desc in cursor.result_set.description] assert "name" in col_names assert "score" in col_names + + def test_insert_executemany_with_parameters(self, conn): + """Test executemany for bulk insert operations with parameters.""" + sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': '?', 'age': '?', 'instrument': '?'}}" + cursor = conn.cursor() + + # Multiple parameter sets for bulk insert + params = [["Frank", 28, "Guitar"], ["Grace", 32, "Piano"], ["Henry", 27, "Drums"], ["Iris", 30, "Violin"]] + + cursor.executemany(sql, params) + + # Verify all documents were inserted + db = conn.database + docs = list(db[self.TEST_COLLECTION].find()) + assert len(docs) == 4 + + names = {doc["name"] for doc in docs} + assert names == {"Frank", "Grace", "Henry", "Iris"} + + # Verify specific document + frank = db[self.TEST_COLLECTION].find_one({"name": "Frank"}) + assert frank["age"] == 28 + assert frank["instrument"] == "Guitar" diff --git a/tests/test_cursor_update.py b/tests/test_cursor_update.py index b49e728..bb94461 100644 --- a/tests/test_cursor_update.py +++ b/tests/test_cursor_update.py @@ -210,3 +210,27 @@ def test_update_set_null(self, conn): book_b = db[self.TEST_COLLECTION].find_one({"title": "Book B"}) assert book_b is not None assert book_b["stock"] is None + + def test_update_executemany_with_parameters(self, conn): + """Test executemany for bulk update operations with parameters.""" + cursor = conn.cursor() + sql = f"UPDATE {self.TEST_COLLECTION} SET price = '?' WHERE title = '?'" + + # Update prices for multiple books using executemany + params = [[25.99, "Book A"], [35.99, "Book B"], [45.99, "Book D"]] + + cursor.executemany(sql, params) + + # Verify all specified books were updated + db = conn.database + book_a = db[self.TEST_COLLECTION].find_one({"title": "Book A"}) + book_b = db[self.TEST_COLLECTION].find_one({"title": "Book B"}) + book_d = db[self.TEST_COLLECTION].find_one({"title": "Book D"}) + + assert book_a["price"] == 25.99 + assert book_b["price"] == 35.99 + assert book_d["price"] == 45.99 + + # Verify other books remain unchanged + book_c = db[self.TEST_COLLECTION].find_one({"title": "Book C"}) + assert book_c["price"] == 19.99 # Original price unchanged diff --git a/tests/test_delete_builder.py b/tests/test_delete_builder.py new file mode 100644 index 0000000..eded94e --- /dev/null +++ b/tests/test_delete_builder.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +import pytest + +from pymongosql.sql.delete_builder import DeleteExecutionPlan, MongoDeleteBuilder + + +class TestDeleteExecutionPlan: + """Test DeleteExecutionPlan dataclass.""" + + def test_to_dict(self): + """Test to_dict conversion.""" + plan = DeleteExecutionPlan(collection="users", filter_conditions={"age": {"$lt": 18}}) + + result = plan.to_dict() + assert result["collection"] == "users" + assert result["filter"] == {"age": {"$lt": 18}} + + def test_to_dict_empty_filter(self): + """Test to_dict with empty filter.""" + plan = DeleteExecutionPlan(collection="logs", filter_conditions={}) + + result = plan.to_dict() + assert result["collection"] == "logs" + assert result["filter"] == {} + + def test_validate_success(self): + """Test validate returns True for valid plan.""" + plan = DeleteExecutionPlan(collection="products", filter_conditions={"status": "inactive"}) + + assert plan.validate() is True + + def test_validate_empty_filter_allowed(self): + """Test validate allows empty filter (delete all).""" + plan = DeleteExecutionPlan(collection="temp_data", filter_conditions={}) + + assert plan.validate() is True + + def test_copy(self): + """Test copy creates independent copy.""" + original = DeleteExecutionPlan(collection="orders", filter_conditions={"status": "cancelled", "year": 2020}) + + copy = original.copy() + + # Verify all fields copied + assert copy.collection == original.collection + assert copy.filter_conditions == original.filter_conditions + + # Verify it's independent + copy.collection = "new_collection" + copy.filter_conditions["new_field"] = "value" + assert original.collection == "orders" + assert "new_field" not in original.filter_conditions + + def test_copy_with_empty_filter(self): + """Test copy handles empty filter dict.""" + original = DeleteExecutionPlan(collection="test", filter_conditions={}) + + copy = original.copy() + assert copy.filter_conditions == {} + + # Verify it's independent + copy.filter_conditions["new"] = "value" + assert original.filter_conditions == {} + + def test_parameter_style_default(self): + """Test default parameter style is qmark.""" + plan = DeleteExecutionPlan(collection="test") + assert plan.parameter_style == "qmark" + + +class TestMongoDeleteBuilder: + """Test MongoDeleteBuilder class.""" + + def test_collection(self): + """Test setting collection name.""" + builder = MongoDeleteBuilder() + result = builder.collection("users") + + assert builder._plan.collection == "users" + assert result is builder # Fluent interface + + def test_filter_conditions(self): + """Test setting filter conditions.""" + builder = MongoDeleteBuilder() + builder.filter_conditions({"status": "deleted", "age": {"$gt": 100}}) + + assert builder._plan.filter_conditions == {"status": "deleted", "age": {"$gt": 100}} + + def test_filter_conditions_empty(self): + """Test filter_conditions with empty dict doesn't update.""" + builder = MongoDeleteBuilder() + builder.filter_conditions({}) + + assert builder._plan.filter_conditions == {} + + def test_filter_conditions_none(self): + """Test filter_conditions with None doesn't update.""" + builder = MongoDeleteBuilder() + builder._plan.filter_conditions = {"existing": "filter"} + builder.filter_conditions(None) + + # Should preserve existing + assert builder._plan.filter_conditions == {"existing": "filter"} + + def test_build_success(self): + """Test build returns execution plan when valid.""" + builder = MongoDeleteBuilder() + builder.collection("products").filter_conditions({"price": {"$lt": 10}}) + + plan = builder.build() + + assert isinstance(plan, DeleteExecutionPlan) + assert plan.collection == "products" + assert plan.filter_conditions == {"price": {"$lt": 10}} + + def test_build_success_empty_filter(self): + """Test build succeeds with empty filter (delete all).""" + builder = MongoDeleteBuilder() + builder.collection("temp_logs") + + plan = builder.build() + + assert isinstance(plan, DeleteExecutionPlan) + assert plan.collection == "temp_logs" + assert plan.filter_conditions == {} + + def test_build_validation_failure(self): + """Test build raises ValueError when validation fails.""" + builder = MongoDeleteBuilder() + # Don't set collection + + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "invalid delete plan" in str(exc_info.value).lower() + + def test_fluent_interface_chaining(self): + """Test all methods return self for chaining.""" + builder = MongoDeleteBuilder() + + result = builder.collection("orders").filter_conditions({"status": "expired"}) + + assert result is builder + assert builder._plan.collection == "orders" + assert builder._plan.filter_conditions == {"status": "expired"} diff --git a/tests/test_delete_handler.py b/tests/test_delete_handler.py new file mode 100644 index 0000000..9ce1cbc --- /dev/null +++ b/tests/test_delete_handler.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- +from pymongosql.sql.delete_handler import DeleteHandler, DeleteParseResult +from pymongosql.sql.partiql.PartiQLParser import PartiQLParser + + +class TestDeleteParseResult: + """Test DeleteParseResult dataclass.""" + + def test_for_visitor_factory(self): + """Test factory method creates fresh instance.""" + result = DeleteParseResult.for_visitor() + assert result.collection is None + assert result.filter_conditions == {} + assert result.has_errors is False + assert result.error_message is None + + def test_validate_missing_collection(self): + """Test validation fails when collection is missing.""" + result = DeleteParseResult() + is_valid = result.validate() + + assert is_valid is False + assert result.has_errors is True + assert result.error_message == "Collection name is required" + + def test_validate_with_collection(self): + """Test validation passes when collection is set.""" + result = DeleteParseResult(collection="test_collection") + is_valid = result.validate() + + assert is_valid is True + assert result.has_errors is False + + def test_to_dict(self): + """Test to_dict conversion.""" + result = DeleteParseResult( + collection="users", filter_conditions={"age": {"$gt": 25}}, has_errors=False, error_message=None + ) + + result_dict = result.to_dict() + assert result_dict["collection"] == "users" + assert result_dict["filter_conditions"] == {"age": {"$gt": 25}} + assert result_dict["has_errors"] is False + assert result_dict["error_message"] is None + + def test_repr(self): + """Test string representation.""" + result = DeleteParseResult(collection="products", filter_conditions={"price": {"$lt": 100}}, has_errors=False) + + repr_str = repr(result) + assert "DeleteParseResult" in repr_str + assert "collection=products" in repr_str + assert "has_errors=False" in repr_str + + +class TestDeleteHandler: + """Test DeleteHandler class.""" + + def test_can_handle_delete_context(self): + """Test can_handle returns True for DELETE context.""" + handler = DeleteHandler() + + # Mock context with DELETE attribute + class MockDeleteContext: + def DELETE(self): + return True + + ctx = MockDeleteContext() + assert handler.can_handle(ctx) is True + + def test_can_handle_delete_command_context(self): + """Test can_handle returns True for DeleteCommandContext.""" + handler = DeleteHandler() + + # Mock DeleteCommandContext + class MockDeleteCommandContext(PartiQLParser.DeleteCommandContext): + def __init__(self): + pass # Skip parent init + + ctx = MockDeleteCommandContext() + assert handler.can_handle(ctx) is True + + def test_can_handle_non_delete_context(self): + """Test can_handle returns False for non-DELETE context.""" + handler = DeleteHandler() + + class MockOtherContext: + pass + + ctx = MockOtherContext() + assert handler.can_handle(ctx) is False + + def test_handle_visitor_success(self): + """Test handle_visitor resets parse result.""" + handler = DeleteHandler() + parse_result = DeleteParseResult( + collection="old_collection", filter_conditions={"old": "value"}, has_errors=True, error_message="old error" + ) + + class MockContext: + pass + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + # Verify reset + assert result.collection is None + assert result.filter_conditions == {} + assert result.has_errors is False + assert result.error_message is None + + def test_handle_visitor_with_exception(self): + """Test handle_visitor handles exceptions.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + # Force an exception by passing None + result = handler.handle_visitor(None, parse_result) + + # Should handle error gracefully + assert isinstance(result, DeleteParseResult) + + def test_handle_from_clause_explicit_success(self): + """Test handle_from_clause_explicit extracts collection name.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + # Mock context with pathSimple + class MockPathSimple: + def getText(self): + return "test_collection" + + class MockFromClauseContext: + def pathSimple(self): + return MockPathSimple() + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_explicit(ctx, parse_result) + + assert collection == "test_collection" + assert parse_result.collection == "test_collection" + assert parse_result.has_errors is False + + def test_handle_from_clause_explicit_no_path(self): + """Test handle_from_clause_explicit when pathSimple returns None.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockFromClauseContext: + def pathSimple(self): + return None + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_explicit(ctx, parse_result) + + assert collection is None + + def test_handle_from_clause_explicit_with_error(self): + """Test handle_from_clause_explicit handles exceptions.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockPathSimple: + def getText(self): + raise ValueError("Test error") + + class MockFromClauseContext: + def pathSimple(self): + return MockPathSimple() + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_explicit(ctx, parse_result) + + assert collection is None + assert parse_result.has_errors is True + assert "Test error" in parse_result.error_message + + def test_handle_from_clause_implicit_success(self): + """Test handle_from_clause_implicit extracts collection name.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockPathSimple: + def getText(self): + return "implicit_collection" + + class MockFromClauseContext: + def pathSimple(self): + return MockPathSimple() + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_implicit(ctx, parse_result) + + assert collection == "implicit_collection" + assert parse_result.collection == "implicit_collection" + assert parse_result.has_errors is False + + def test_handle_from_clause_implicit_no_path(self): + """Test handle_from_clause_implicit when pathSimple returns None.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockFromClauseContext: + def pathSimple(self): + return None + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_implicit(ctx, parse_result) + + assert collection is None + + def test_handle_from_clause_implicit_with_error(self): + """Test handle_from_clause_implicit handles exceptions.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockPathSimple: + def getText(self): + raise RuntimeError("Implicit error") + + class MockFromClauseContext: + def pathSimple(self): + return MockPathSimple() + + ctx = MockFromClauseContext() + collection = handler.handle_from_clause_implicit(ctx, parse_result) + + assert collection is None + assert parse_result.has_errors is True + assert "Implicit error" in parse_result.error_message + + def test_handle_where_clause_no_expression(self): + """Test handle_where_clause when no expression is present.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockWhereContext: + arg = None + + def expr(self): + return None + + ctx = MockWhereContext() + result = handler.handle_where_clause(ctx, parse_result) + + # Should return empty dict (delete all) + assert result == {} + assert parse_result.filter_conditions == {} + + def test_handle_where_clause_with_error(self): + """Test handle_where_clause handles exceptions.""" + handler = DeleteHandler() + parse_result = DeleteParseResult() + + class MockExpr: + def getText(self): + raise Exception("WHERE error") + + class MockWhereContext: + arg = None + + def expr(self): + return MockExpr() + + ctx = MockWhereContext() + result = handler.handle_where_clause(ctx, parse_result) + + assert result == {} + assert parse_result.has_errors is True + assert "WHERE error" in parse_result.error_message diff --git a/tests/test_insert_builder.py b/tests/test_insert_builder.py new file mode 100644 index 0000000..220aed2 --- /dev/null +++ b/tests/test_insert_builder.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- +import pytest + +from pymongosql.sql.insert_builder import InsertExecutionPlan, MongoInsertBuilder + + +class TestInsertExecutionPlan: + """Test InsertExecutionPlan dataclass.""" + + def test_to_dict(self): + """Test to_dict conversion.""" + plan = InsertExecutionPlan( + collection="users", insert_documents=[{"name": "Alice"}, {"name": "Bob"}], parameter_count=2 + ) + + result = plan.to_dict() + assert result["collection"] == "users" + assert result["documents"] == [{"name": "Alice"}, {"name": "Bob"}] + assert result["parameter_count"] == 2 + + def test_validate_success(self): + """Test validate returns True for valid plan.""" + plan = InsertExecutionPlan(collection="products", insert_documents=[{"name": "Product A", "price": 99.99}]) + + assert plan.validate() is True + + def test_validate_no_documents(self): + """Test validate fails when no documents.""" + plan = InsertExecutionPlan(collection="products", insert_documents=[]) + + assert plan.validate() is False + + def test_copy(self): + """Test copy creates independent copy.""" + original = InsertExecutionPlan( + collection="orders", + insert_documents=[{"id": 1, "total": 100}, {"id": 2, "total": 200}], + parameter_style="qmark", + parameter_count=4, + ) + + copy = original.copy() + + # Verify all fields copied + assert copy.collection == original.collection + assert copy.insert_documents == original.insert_documents + assert copy.parameter_style == original.parameter_style + assert copy.parameter_count == original.parameter_count + + # Verify it's independent + copy.collection = "new_collection" + copy.insert_documents[0]["new_field"] = "value" + assert original.collection == "orders" + assert "new_field" not in original.insert_documents[0] + + +class TestMongoInsertBuilder: + """Test MongoInsertBuilder class.""" + + def test_collection_valid(self): + """Test setting collection name.""" + builder = MongoInsertBuilder() + result = builder.collection("users") + + assert builder._execution_plan.collection == "users" + assert result is builder # Fluent interface + + def test_collection_empty_string(self): + """Test collection with empty string adds error.""" + builder = MongoInsertBuilder() + builder.collection("") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "cannot be empty" in errors[0].lower() + + def test_collection_whitespace_only(self): + """Test collection with whitespace only adds error.""" + builder = MongoInsertBuilder() + builder.collection(" ") + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_collection_strips_whitespace(self): + """Test collection strips whitespace.""" + builder = MongoInsertBuilder() + builder.collection(" users ") + + assert builder._execution_plan.collection == "users" + + def test_insert_documents_valid_list(self): + """Test insert_documents with valid list.""" + builder = MongoInsertBuilder() + docs = [{"name": "Alice"}, {"name": "Bob"}] + builder.insert_documents(docs) + + assert builder._execution_plan.insert_documents == docs + + def test_insert_documents_invalid_type(self): + """Test insert_documents with non-list adds error.""" + builder = MongoInsertBuilder() + builder.insert_documents("not a list") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be a list" in errors[0].lower() + + def test_insert_documents_empty_list(self): + """Test insert_documents with empty list adds error.""" + builder = MongoInsertBuilder() + builder.insert_documents([]) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "at least one document" in errors[0].lower() + + def test_parameter_style_qmark(self): + """Test parameter_style with qmark.""" + builder = MongoInsertBuilder() + builder.parameter_style("qmark") + + assert builder._execution_plan.parameter_style == "qmark" + + def test_parameter_style_named(self): + """Test parameter_style with named.""" + builder = MongoInsertBuilder() + builder.parameter_style("named") + + assert builder._execution_plan.parameter_style == "named" + + def test_parameter_style_invalid(self): + """Test parameter_style with invalid value adds error.""" + builder = MongoInsertBuilder() + builder.parameter_style("invalid") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "invalid parameter style" in errors[0].lower() + + def test_parameter_style_none(self): + """Test parameter_style with None is allowed.""" + builder = MongoInsertBuilder() + builder.parameter_style(None) + + assert builder._execution_plan.parameter_style is None + assert len(builder.get_errors()) == 0 + + def test_parameter_count_valid(self): + """Test parameter_count with valid value.""" + builder = MongoInsertBuilder() + builder.parameter_count(5) + + assert builder._execution_plan.parameter_count == 5 + + def test_parameter_count_zero(self): + """Test parameter_count with zero is allowed.""" + builder = MongoInsertBuilder() + builder.parameter_count(0) + + assert builder._execution_plan.parameter_count == 0 + + def test_parameter_count_negative(self): + """Test parameter_count with negative value adds error.""" + builder = MongoInsertBuilder() + builder.parameter_count(-1) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "non-negative" in errors[0].lower() + + def test_parameter_count_non_integer(self): + """Test parameter_count with non-integer adds error.""" + builder = MongoInsertBuilder() + builder.parameter_count(5.5) + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_validate_success(self): + """Test validate returns True when valid.""" + builder = MongoInsertBuilder() + builder.collection("users").insert_documents([{"name": "Alice"}]) + + assert builder.validate() is True + + def test_validate_missing_collection(self): + """Test validate returns False when collection missing.""" + builder = MongoInsertBuilder() + builder.insert_documents([{"name": "Alice"}]) + + assert builder.validate() is False + errors = builder.get_errors() + assert "collection name is required" in errors[0].lower() + + def test_validate_missing_documents(self): + """Test validate returns False when documents missing.""" + builder = MongoInsertBuilder() + builder.collection("users") + + assert builder.validate() is False + errors = builder.get_errors() + assert "at least one document" in errors[0].lower() + + def test_build_success(self): + """Test build returns execution plan when valid.""" + builder = MongoInsertBuilder() + builder.collection("products").insert_documents([{"name": "Product A"}]) + + plan = builder.build() + + assert isinstance(plan, InsertExecutionPlan) + assert plan.collection == "products" + assert plan.insert_documents == [{"name": "Product A"}] + + def test_build_validation_failure(self): + """Test build raises ValueError when validation fails.""" + builder = MongoInsertBuilder() + # Don't set collection or documents + + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "validation failed" in str(exc_info.value).lower() + + def test_reset(self): + """Test reset clears builder state.""" + builder = MongoInsertBuilder() + builder.collection("users").insert_documents([{"name": "Alice"}]).parameter_count(3) + + # Add an error + builder.collection("") + + # Reset + builder.reset() + + assert builder._execution_plan.collection is None + assert builder._execution_plan.insert_documents == [] + assert builder._execution_plan.parameter_count == 0 + assert len(builder.get_errors()) == 0 + + def test_str_representation(self): + """Test __str__ method.""" + builder = MongoInsertBuilder() + builder.collection("products").insert_documents([{"a": 1}, {"b": 2}]) + + str_repr = str(builder) + + assert "MongoInsertBuilder" in str_repr + assert "collection=products" in str_repr + assert "documents=2" in str_repr + + def test_fluent_interface_chaining(self): + """Test all methods return self for chaining.""" + builder = MongoInsertBuilder() + + result = ( + builder.collection("orders") + .insert_documents([{"id": 1}, {"id": 2}]) + .parameter_style("qmark") + .parameter_count(4) + ) + + assert result is builder + assert builder._execution_plan.collection == "orders" + assert len(builder._execution_plan.insert_documents) == 2 + assert builder._execution_plan.parameter_style == "qmark" + assert builder._execution_plan.parameter_count == 4 diff --git a/tests/test_insert_handler.py b/tests/test_insert_handler.py new file mode 100644 index 0000000..2ba5a42 --- /dev/null +++ b/tests/test_insert_handler.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +import pytest + +from pymongosql.sql.insert_handler import InsertHandler, InsertParseResult + + +class TestInsertParseResult: + """Test InsertParseResult dataclass.""" + + def test_for_visitor_factory(self): + """Test factory method creates fresh instance.""" + result = InsertParseResult.for_visitor() + assert result.collection is None + assert result.insert_columns is None + assert result.insert_values is None + assert result.insert_documents is None + assert result.insert_type is None + assert result.parameter_style is None + assert result.parameter_count == 0 + assert result.has_errors is False + assert result.error_message is None + + +class TestInsertHandler: + """Test InsertHandler class.""" + + def test_can_handle_insert_context(self): + """Test can_handle returns True for INSERT context.""" + handler = InsertHandler() + + class MockInsertContext: + def INSERT(self): + return True + + ctx = MockInsertContext() + assert handler.can_handle(ctx) is True + + def test_can_handle_non_insert_context(self): + """Test can_handle returns False for non-INSERT context.""" + handler = InsertHandler() + + class MockOtherContext: + pass + + ctx = MockOtherContext() + assert handler.can_handle(ctx) is False + + def test_extract_collection_with_symbol_primitive(self): + """Test _extract_collection with symbolPrimitive.""" + handler = InsertHandler() + + class MockSymbol: + def getText(self): + return "test_collection" + + class MockContext: + def symbolPrimitive(self): + return MockSymbol() + + ctx = MockContext() + collection = handler._extract_collection(ctx) + assert collection == "test_collection" + + def test_extract_collection_with_path_simple(self): + """Test _extract_collection with pathSimple (legacy).""" + handler = InsertHandler() + + class MockPath: + def getText(self): + return "legacy_collection" + + class MockContext: + def symbolPrimitive(self): + return None + + def pathSimple(self): + return MockPath() + + ctx = MockContext() + collection = handler._extract_collection(ctx) + assert collection == "legacy_collection" + + def test_extract_collection_missing(self): + """Test _extract_collection raises when collection missing.""" + handler = InsertHandler() + + class MockContext: + def symbolPrimitive(self): + return None + + def pathSimple(self): + return None + + ctx = MockContext() + with pytest.raises(ValueError) as exc_info: + handler._extract_collection(ctx) + + assert "missing collection name" in str(exc_info.value).lower() + + def test_parse_expression_value_null(self): + """Test _parse_expression_value with NULL.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "NULL" + + value = handler._parse_expression_value(MockExpr()) + assert value is None + + def test_parse_expression_value_boolean_true(self): + """Test _parse_expression_value with TRUE.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "TRUE" + + value = handler._parse_expression_value(MockExpr()) + assert value is True + + def test_parse_expression_value_boolean_false(self): + """Test _parse_expression_value with FALSE.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "FALSE" + + value = handler._parse_expression_value(MockExpr()) + assert value is False + + def test_parse_expression_value_string_single_quote(self): + """Test _parse_expression_value with single-quoted string.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "'hello'" + + value = handler._parse_expression_value(MockExpr()) + assert value == "hello" + + def test_parse_expression_value_string_double_quote(self): + """Test _parse_expression_value with double-quoted string.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return '"world"' + + value = handler._parse_expression_value(MockExpr()) + assert value == "world" + + def test_parse_expression_value_integer(self): + """Test _parse_expression_value with integer.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "42" + + value = handler._parse_expression_value(MockExpr()) + assert value == 42 + + def test_parse_expression_value_float(self): + """Test _parse_expression_value with float.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "3.14" + + value = handler._parse_expression_value(MockExpr()) + assert value == 3.14 + + def test_parse_expression_value_qmark(self): + """Test _parse_expression_value with ? parameter.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "?" + + value = handler._parse_expression_value(MockExpr()) + assert value == "?" + + def test_parse_expression_value_named_param(self): + """Test _parse_expression_value with :name parameter.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return ":name" + + value = handler._parse_expression_value(MockExpr()) + assert value == ":name" + + def test_parse_expression_value_none(self): + """Test _parse_expression_value with None.""" + handler = InsertHandler() + value = handler._parse_expression_value(None) + assert value is None + + def test_parse_expression_value_complex(self): + """Test _parse_expression_value with complex expression.""" + handler = InsertHandler() + + class MockExpr: + def getText(self): + return "COMPLEX_EXPR" + + value = handler._parse_expression_value(MockExpr()) + # Complex expressions are returned as-is + assert value == "COMPLEX_EXPR" + + def test_convert_rows_to_documents_with_columns(self): + """Test _convert_rows_to_documents with explicit columns.""" + handler = InsertHandler() + columns = ["name", "age", "city"] + rows = [["Alice", 25, "NYC"], ["Bob", 30, "LA"]] + + docs = handler._convert_rows_to_documents(columns, rows) + + assert len(docs) == 2 + assert docs[0] == {"name": "Alice", "age": 25, "city": "NYC"} + assert docs[1] == {"name": "Bob", "age": 30, "city": "LA"} + + def test_convert_rows_to_documents_without_columns(self): + """Test _convert_rows_to_documents without explicit columns.""" + handler = InsertHandler() + rows = [["value1", "value2"], ["value3", "value4"]] + + docs = handler._convert_rows_to_documents(None, rows) + + assert len(docs) == 2 + assert docs[0] == {"col0": "value1", "col1": "value2"} + assert docs[1] == {"col0": "value3", "col1": "value4"} + + def test_convert_rows_to_documents_column_count_mismatch(self): + """Test _convert_rows_to_documents with column count mismatch.""" + handler = InsertHandler() + columns = ["name", "age"] + rows = [["Alice", 25, "Extra"]] # Too many values + + with pytest.raises(ValueError) as exc_info: + handler._convert_rows_to_documents(columns, rows) + + assert "column count" in str(exc_info.value).lower() + assert "value count" in str(exc_info.value).lower() + + def test_normalize_literals(self): + """Test _normalize_literals replaces PartiQL booleans/null.""" + handler = InsertHandler() + + # Test null variations + assert "None" in handler._normalize_literals("null") + assert "None" in handler._normalize_literals("NULL") + + # Test boolean variations + assert "True" in handler._normalize_literals("true") + assert "True" in handler._normalize_literals("TRUE") + assert "False" in handler._normalize_literals("false") + assert "False" in handler._normalize_literals("FALSE") + + def test_parse_literal_dict_valid(self): + """Test _parse_literal_dict with valid dict.""" + handler = InsertHandler() + text = "{'name': 'Alice', 'age': 30}" + + doc = handler._parse_literal_dict(text) + assert doc == {"name": "Alice", "age": 30} + + def test_parse_literal_dict_invalid(self): + """Test _parse_literal_dict with invalid dict.""" + handler = InsertHandler() + text = "not a dict" + + with pytest.raises(ValueError) as exc_info: + handler._parse_literal_dict(text) + + assert "failed to parse" in str(exc_info.value).lower() + + def test_parse_literal_dict_non_dict_value(self): + """Test _parse_literal_dict when value is not a dict.""" + handler = InsertHandler() + text = "['list', 'not', 'dict']" + + with pytest.raises(ValueError) as exc_info: + handler._parse_literal_dict(text) + + assert "must be an object" in str(exc_info.value).lower() + + def test_parse_literal_list_valid(self): + """Test _parse_literal_list with valid list of dicts.""" + handler = InsertHandler() + text = "[{'name': 'Alice'}, {'name': 'Bob'}]" + + docs = handler._parse_literal_list(text) + assert len(docs) == 2 + assert docs[0] == {"name": "Alice"} + assert docs[1] == {"name": "Bob"} + + def test_parse_literal_list_invalid(self): + """Test _parse_literal_list with invalid syntax.""" + handler = InsertHandler() + text = "not valid" + + with pytest.raises(ValueError) as exc_info: + handler._parse_literal_list(text) + + assert "failed to parse" in str(exc_info.value).lower() + + def test_parse_literal_list_non_dict_items(self): + """Test _parse_literal_list when items are not dicts.""" + handler = InsertHandler() + text = "['string1', 'string2']" + + with pytest.raises(ValueError) as exc_info: + handler._parse_literal_list(text) + + assert "must contain objects" in str(exc_info.value).lower() + + def test_detect_parameter_style_qmark(self): + """Test _detect_parameter_style with qmark parameters.""" + handler = InsertHandler() + docs = [{"name": "?", "age": "?"}] + + style, count = handler._detect_parameter_style(docs) + assert style == "qmark" + assert count == 2 + + def test_detect_parameter_style_named(self): + """Test _detect_parameter_style with named parameters.""" + handler = InsertHandler() + docs = [{"name": ":name", "age": ":age"}] + + style, count = handler._detect_parameter_style(docs) + assert style == "named" + assert count == 2 + + def test_detect_parameter_style_none(self): + """Test _detect_parameter_style with no parameters.""" + handler = InsertHandler() + docs = [{"name": "Alice", "age": 30}] + + style, count = handler._detect_parameter_style(docs) + assert style is None + assert count == 0 + + def test_detect_parameter_style_mixed_error(self): + """Test _detect_parameter_style raises on mixed styles.""" + handler = InsertHandler() + docs = [{"name": "?", "age": ":age"}] # Mixed qmark and named + + with pytest.raises(ValueError) as exc_info: + handler._detect_parameter_style(docs) + + assert "mixed parameter styles" in str(exc_info.value).lower() diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py new file mode 100644 index 0000000..01c377d --- /dev/null +++ b/tests/test_query_builder.py @@ -0,0 +1,432 @@ +# -*- coding: utf-8 -*- +import pytest + +from pymongosql.sql.query_builder import MongoQueryBuilder, QueryExecutionPlan + + +class TestQueryExecutionPlan: + """Test QueryExecutionPlan dataclass.""" + + def test_to_dict(self): + """Test to_dict conversion.""" + plan = QueryExecutionPlan( + collection="users", + filter_stage={"age": {"$gt": 18}}, + projection_stage={"name": 1, "email": 1}, + sort_stage=[{"name": 1}], + limit_stage=10, + skip_stage=5, + ) + + result = plan.to_dict() + assert result["collection"] == "users" + assert result["filter"] == {"age": {"$gt": 18}} + assert result["projection"] == {"name": 1, "email": 1} + assert result["sort"] == [{"name": 1}] + assert result["limit"] == 10 + assert result["skip"] == 5 + + def test_validate_success(self): + """Test validate returns True for valid plan.""" + plan = QueryExecutionPlan(collection="products", limit_stage=100, skip_stage=0) + + assert plan.validate() is True + + def test_validate_negative_limit(self): + """Test validate fails for negative limit.""" + plan = QueryExecutionPlan(collection="products", limit_stage=-1) + + assert plan.validate() is False + + def test_validate_invalid_limit_type(self): + """Test validate fails for non-integer limit.""" + plan = QueryExecutionPlan(collection="products", limit_stage="10") # String instead of int + + assert plan.validate() is False + + def test_validate_negative_skip(self): + """Test validate fails for negative skip.""" + plan = QueryExecutionPlan(collection="products", skip_stage=-5) + + assert plan.validate() is False + + def test_validate_invalid_skip_type(self): + """Test validate fails for non-integer skip.""" + plan = QueryExecutionPlan(collection="products", skip_stage=5.5) # Float instead of int + + assert plan.validate() is False + + def test_copy(self): + """Test copy creates independent copy.""" + original = QueryExecutionPlan( + collection="orders", + filter_stage={"status": "active"}, + projection_stage={"total": 1}, + column_aliases={"total": "amount"}, + sort_stage=[{"date": -1}], + limit_stage=50, + skip_stage=10, + ) + + copy = original.copy() + + # Verify all fields copied + assert copy.collection == original.collection + assert copy.filter_stage == original.filter_stage + assert copy.projection_stage == original.projection_stage + assert copy.column_aliases == original.column_aliases + assert copy.sort_stage == original.sort_stage + assert copy.limit_stage == original.limit_stage + assert copy.skip_stage == original.skip_stage + + # Verify it's independent (modify copy doesn't affect original) + copy.collection = "new_collection" + copy.filter_stage["new_key"] = "new_value" + assert original.collection == "orders" + assert "new_key" not in original.filter_stage + + +class TestMongoQueryBuilder: + """Test MongoQueryBuilder class.""" + + def test_collection_valid(self): + """Test setting collection name.""" + builder = MongoQueryBuilder() + result = builder.collection("users") + + assert builder._execution_plan.collection == "users" + assert result is builder # Fluent interface + + def test_collection_empty_string(self): + """Test collection with empty string adds error.""" + builder = MongoQueryBuilder() + builder.collection("") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "cannot be empty" in errors[0].lower() + + def test_collection_whitespace_only(self): + """Test collection with whitespace only adds error.""" + builder = MongoQueryBuilder() + builder.collection(" ") + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_filter_valid_dict(self): + """Test filter with valid dictionary.""" + builder = MongoQueryBuilder() + builder.filter({"age": {"$gt": 25}}) + + assert builder._execution_plan.filter_stage == {"age": {"$gt": 25}} + + def test_filter_invalid_type(self): + """Test filter with non-dict adds error.""" + builder = MongoQueryBuilder() + builder.filter("not a dict") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be a dictionary" in errors[0].lower() + + def test_filter_multiple_calls(self): + """Test multiple filter calls update conditions.""" + builder = MongoQueryBuilder() + builder.filter({"age": {"$gt": 25}}) + builder.filter({"status": "active"}) + + assert builder._execution_plan.filter_stage["age"] == {"$gt": 25} + assert builder._execution_plan.filter_stage["status"] == "active" + + def test_project_with_list(self): + """Test project with list of field names.""" + builder = MongoQueryBuilder() + builder.project(["name", "email", "age"]) + + expected = {"name": 1, "email": 1, "age": 1} + assert builder._execution_plan.projection_stage == expected + + def test_project_with_dict(self): + """Test project with dictionary.""" + builder = MongoQueryBuilder() + builder.project({"name": 1, "email": 1, "_id": 0}) + + assert builder._execution_plan.projection_stage == {"name": 1, "email": 1, "_id": 0} + + def test_project_with_invalid_type(self): + """Test project with invalid type adds error.""" + builder = MongoQueryBuilder() + builder.project("name, email") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be a list" in errors[0].lower() or "dictionary" in errors[0].lower() + + def test_sort_valid_specs(self): + """Test sort with valid specifications.""" + builder = MongoQueryBuilder() + builder.sort([{"name": 1}, {"age": -1}]) + + assert builder._execution_plan.sort_stage == [{"name": 1}, {"age": -1}] + + def test_sort_invalid_type(self): + """Test sort with non-list adds error.""" + builder = MongoQueryBuilder() + builder.sort({"name": 1}) # Dict instead of list + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be a list" in errors[0].lower() + + def test_sort_invalid_spec_multiple_keys(self): + """Test sort with multi-key dict adds error.""" + builder = MongoQueryBuilder() + builder.sort([{"name": 1, "age": -1}]) # Two keys in one dict + + errors = builder.get_errors() + assert len(errors) > 0 + assert "single-key dict" in errors[0].lower() + + def test_sort_invalid_direction(self): + """Test sort with invalid direction adds error.""" + builder = MongoQueryBuilder() + builder.sort([{"name": 2}]) # Direction must be 1 or -1 + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be 1 or -1" in errors[0].lower() + + def test_sort_empty_field_name(self): + """Test sort with empty field name adds error.""" + builder = MongoQueryBuilder() + builder.sort([{"": 1}]) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "non-empty string" in errors[0].lower() + + def test_sort_non_string_field(self): + """Test sort with non-string field adds error.""" + builder = MongoQueryBuilder() + builder.sort([{123: 1}]) + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_limit_valid(self): + """Test limit with valid value.""" + builder = MongoQueryBuilder() + builder.limit(100) + + assert builder._execution_plan.limit_stage == 100 + + def test_limit_negative(self): + """Test limit with negative value adds error.""" + builder = MongoQueryBuilder() + builder.limit(-10) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "non-negative" in errors[0].lower() + + def test_limit_non_integer(self): + """Test limit with non-integer adds error.""" + builder = MongoQueryBuilder() + builder.limit(10.5) + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_skip_valid(self): + """Test skip with valid value.""" + builder = MongoQueryBuilder() + builder.skip(50) + + assert builder._execution_plan.skip_stage == 50 + + def test_skip_negative(self): + """Test skip with negative value adds error.""" + builder = MongoQueryBuilder() + builder.skip(-5) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "non-negative" in errors[0].lower() + + def test_skip_non_integer(self): + """Test skip with non-integer adds error.""" + builder = MongoQueryBuilder() + builder.skip("10") + + errors = builder.get_errors() + assert len(errors) > 0 + + def test_column_aliases_valid(self): + """Test column_aliases with valid dict.""" + builder = MongoQueryBuilder() + builder.column_aliases({"user_name": "name", "user_email": "email"}) + + assert builder._execution_plan.column_aliases == {"user_name": "name", "user_email": "email"} + + def test_column_aliases_invalid_type(self): + """Test column_aliases with non-dict adds error.""" + builder = MongoQueryBuilder() + builder.column_aliases(["name", "email"]) + + errors = builder.get_errors() + assert len(errors) > 0 + assert "must be a dictionary" in errors[0].lower() + + def test_where_equality(self): + """Test where with equality operator.""" + builder = MongoQueryBuilder() + builder.where("status", "=", "active") + + assert builder._execution_plan.filter_stage == {"status": {"$eq": "active"}} + + def test_where_greater_than(self): + """Test where with greater than operator.""" + builder = MongoQueryBuilder() + builder.where("age", ">", 18) + + assert builder._execution_plan.filter_stage == {"age": {"$gt": 18}} + + def test_where_less_than_or_equal(self): + """Test where with less than or equal operator.""" + builder = MongoQueryBuilder() + builder.where("price", "<=", 100.0) + + assert builder._execution_plan.filter_stage == {"price": {"$lte": 100.0}} + + def test_where_not_equal(self): + """Test where with not equal operator.""" + builder = MongoQueryBuilder() + builder.where("status", "!=", "deleted") + + assert builder._execution_plan.filter_stage == {"status": {"$ne": "deleted"}} + + def test_where_unsupported_operator(self): + """Test where with unsupported operator adds error.""" + builder = MongoQueryBuilder() + builder.where("field", "INVALID", "value") + + errors = builder.get_errors() + assert len(errors) > 0 + assert "unsupported operator" in errors[0].lower() + + def test_where_in(self): + """Test where_in method.""" + builder = MongoQueryBuilder() + builder.where_in("category", ["books", "music", "movies"]) + + assert builder._execution_plan.filter_stage == {"category": {"$in": ["books", "music", "movies"]}} + + def test_where_between(self): + """Test where_between method.""" + builder = MongoQueryBuilder() + builder.where_between("age", 18, 65) + + assert builder._execution_plan.filter_stage == {"age": {"$gte": 18, "$lte": 65}} + + def test_where_like(self): + """Test where_like method converts SQL pattern to regex.""" + builder = MongoQueryBuilder() + builder.where_like("name", "John%") + + filter_stage = builder._execution_plan.filter_stage + assert "name" in filter_stage + assert "$regex" in filter_stage["name"] + assert filter_stage["name"]["$regex"] == "John.*" + assert filter_stage["name"]["$options"] == "i" + + def test_where_like_with_underscore(self): + """Test where_like converts underscore to dot.""" + builder = MongoQueryBuilder() + builder.where_like("code", "A_C") + + filter_stage = builder._execution_plan.filter_stage + assert filter_stage["code"]["$regex"] == "A.C" + + def test_validate_success(self): + """Test validate returns True when collection is set.""" + builder = MongoQueryBuilder() + builder.collection("users") + + assert builder.validate() is True + + def test_validate_missing_collection(self): + """Test validate returns False when collection is missing.""" + builder = MongoQueryBuilder() + + assert builder.validate() is False + errors = builder.get_errors() + assert len(errors) > 0 + assert "collection name is required" in errors[0].lower() + + def test_build_success(self): + """Test build returns execution plan when valid.""" + builder = MongoQueryBuilder() + builder.collection("users").filter({"age": {"$gt": 18}}).limit(10) + + plan = builder.build() + + assert isinstance(plan, QueryExecutionPlan) + assert plan.collection == "users" + assert plan.filter_stage == {"age": {"$gt": 18}} + assert plan.limit_stage == 10 + + def test_build_validation_failure(self): + """Test build raises ValueError when validation fails.""" + builder = MongoQueryBuilder() + # Don't set collection + + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "validation failed" in str(exc_info.value).lower() + + def test_reset(self): + """Test reset clears builder state.""" + builder = MongoQueryBuilder() + builder.collection("users").filter({"age": {"$gt": 18}}).limit(10) + + # Add an error + builder.collection("") + + # Reset + builder.reset() + + assert builder._execution_plan.collection is None + assert builder._execution_plan.filter_stage == {} + assert builder._execution_plan.limit_stage is None + assert len(builder.get_errors()) == 0 + + def test_str_representation(self): + """Test __str__ method.""" + builder = MongoQueryBuilder() + builder.collection("products").filter({"price": {"$lt": 100}}).project(["name", "price"]) + + str_repr = str(builder) + + assert "MongoQueryBuilder" in str_repr + assert "collection=products" in str_repr + + def test_fluent_interface_chaining(self): + """Test all methods return self for chaining.""" + builder = MongoQueryBuilder() + + result = ( + builder.collection("orders") + .filter({"status": "pending"}) + .project(["id", "total"]) + .sort([{"date": -1}]) + .limit(100) + .skip(50) + .column_aliases({"id": "order_id"}) + ) + + assert result is builder + assert builder._execution_plan.collection == "orders" + assert builder._execution_plan.limit_stage == 100 + assert builder._execution_plan.skip_stage == 50 diff --git a/tests/test_query_handler.py b/tests/test_query_handler.py new file mode 100644 index 0000000..f301f12 --- /dev/null +++ b/tests/test_query_handler.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +from pymongosql.sql.query_handler import FromHandler, QueryParseResult, SelectHandler, WhereHandler + + +class TestQueryParseResult: + """Test QueryParseResult dataclass.""" + + def test_for_visitor_factory(self): + """Test factory method creates fresh instance.""" + result = QueryParseResult.for_visitor() + assert result.filter_conditions == {} + assert result.has_errors is False + assert result.error_message is None + assert result.collection is None + assert result.projection == {} + assert result.column_aliases == {} + assert result.sort_fields == [] + assert result.limit_value is None + assert result.offset_value is None + + def test_merge_expression_with_filters(self): + """Test merge_expression merges filter conditions.""" + result1 = QueryParseResult(filter_conditions={"age": {"$gt": 18}}) + result2 = QueryParseResult(filter_conditions={"status": "active"}) + + result1.merge_expression(result2) + + # Should combine with $and + assert "$and" in result1.filter_conditions + assert result1.filter_conditions["$and"] == [{"age": {"$gt": 18}}, {"status": "active"}] + + def test_merge_expression_no_existing_filter(self): + """Test merge_expression when no existing filter.""" + result1 = QueryParseResult() + result2 = QueryParseResult(filter_conditions={"status": "active"}) + + result1.merge_expression(result2) + + assert result1.filter_conditions == {"status": "active"} + + def test_merge_expression_with_errors(self): + """Test merge_expression propagates errors.""" + result1 = QueryParseResult() + result2 = QueryParseResult(has_errors=True, error_message="Test error") + + result1.merge_expression(result2) + + assert result1.has_errors is True + assert result1.error_message == "Test error" + + def test_mongo_filter_property_getter(self): + """Test mongo_filter property (backward compatibility).""" + result = QueryParseResult(filter_conditions={"age": 25}) + assert result.mongo_filter == {"age": 25} + + def test_mongo_filter_property_setter(self): + """Test mongo_filter property setter (backward compatibility).""" + result = QueryParseResult() + result.mongo_filter = {"age": 30} + assert result.filter_conditions == {"age": 30} + + +class TestSelectHandler: + """Test SelectHandler class.""" + + def test_can_handle_projection_items(self): + """Test can_handle returns True for projectionItems context.""" + handler = SelectHandler() + + class MockContext: + def projectionItems(self): + return True + + assert handler.can_handle(MockContext()) is True + + def test_can_handle_no_projection_items(self): + """Test can_handle returns False when no projectionItems.""" + handler = SelectHandler() + + class MockContext: + pass + + assert handler.can_handle(MockContext()) is False + + def test_extract_field_and_alias_simple_field(self): + """Test _extract_field_and_alias with simple field.""" + handler = SelectHandler() + + class MockChild: + def getText(self): + return "field_name" + + class MockItem: + children = [MockChild()] + + field_name, alias = handler._extract_field_and_alias(MockItem()) + assert field_name == "field_name" + assert alias is None + + def test_extract_field_and_alias_with_as_keyword(self): + """Test _extract_field_and_alias with AS keyword.""" + handler = SelectHandler() + + class MockField: + def getText(self): + return "field_name" + + class MockAS: + def getText(self): + return "AS" + + class MockAlias: + def getText(self): + return "field_alias" + + class MockItem: + children = [MockField(), MockAS(), MockAlias()] + + field_name, alias = handler._extract_field_and_alias(MockItem()) + assert field_name == "field_name" + assert alias == "field_alias" + + def test_extract_field_and_alias_without_as_keyword(self): + """Test _extract_field_and_alias without AS keyword.""" + handler = SelectHandler() + + class MockField: + def getText(self): + return "field_name" + + class MockAlias: + def getText(self): + return "alias_name" + + class MockItem: + children = [MockField(), MockAlias()] + + field_name, alias = handler._extract_field_and_alias(MockItem()) + assert field_name == "field_name" + assert alias == "alias_name" + + def test_extract_field_and_alias_no_children(self): + """Test _extract_field_and_alias when no children.""" + handler = SelectHandler() + + class MockItem: + def __str__(self): + return "simple_item" + + field_name, alias = handler._extract_field_and_alias(MockItem()) + assert alias is None + + +class TestFromHandler: + """Test FromHandler class.""" + + def test_can_handle_table_reference(self): + """Test can_handle returns True for tableReference context.""" + handler = FromHandler() + + class MockContext: + def tableReference(self): + return True + + assert handler.can_handle(MockContext()) is True + + def test_can_handle_no_table_reference(self): + """Test can_handle returns False when no tableReference.""" + handler = FromHandler() + + class MockContext: + pass + + assert handler.can_handle(MockContext()) is False + + def test_handle_visitor_extracts_collection(self): + """Test handle_visitor extracts collection name.""" + handler = FromHandler() + parse_result = QueryParseResult() + + class MockTableRef: + def getText(self): + return "test_collection" + + class MockContext: + def tableReference(self): + return MockTableRef() + + ctx = MockContext() + collection = handler.handle_visitor(ctx, parse_result) + + assert collection == "test_collection" + assert parse_result.collection == "test_collection" + + def test_handle_visitor_no_table_reference(self): + """Test handle_visitor when no tableReference.""" + handler = FromHandler() + parse_result = QueryParseResult() + + class MockContext: + def tableReference(self): + return None + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + assert result is None + + +class TestWhereHandler: + """Test WhereHandler class.""" + + def test_can_handle_expr_select(self): + """Test can_handle returns True for exprSelect context.""" + handler = WhereHandler() + + class MockContext: + def exprSelect(self): + return True + + assert handler.can_handle(MockContext()) is True + + def test_can_handle_no_expr_select(self): + """Test can_handle returns False when no exprSelect.""" + handler = WhereHandler() + + class MockContext: + pass + + assert handler.can_handle(MockContext()) is False + + def test_handle_visitor_no_expression(self): + """Test handle_visitor when no exprSelect.""" + handler = WhereHandler() + parse_result = QueryParseResult() + + class MockContext: + def exprSelect(self): + return None + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + assert result == {} + + def test_handle_visitor_with_exception_fallback(self): + """Test handle_visitor falls back to text search on exception.""" + handler = WhereHandler() + parse_result = QueryParseResult() + + class MockExpr: + def getText(self): + return "field = value" + + class MockContext: + def exprSelect(self): + return MockExpr() + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + # Should fallback to text search when expression handler fails + # The actual behavior depends on expression handler implementation + assert isinstance(result, dict) diff --git a/tests/test_sqlalchemy_dialect.py b/tests/test_sqlalchemy_dialect.py index 52d84b5..91a0530 100644 --- a/tests/test_sqlalchemy_dialect.py +++ b/tests/test_sqlalchemy_dialect.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import unittest from typing import Callable from unittest.mock import Mock, patch diff --git a/tests/test_sqlalchemy_dml.py b/tests/test_sqlalchemy_dml.py index 00b1951..2e1bf8e 100644 --- a/tests/test_sqlalchemy_dml.py +++ b/tests/test_sqlalchemy_dml.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import pytest from tests.conftest import HAS_SQLALCHEMY, Base diff --git a/tests/test_sqlalchemy_query.py b/tests/test_sqlalchemy_query.py index 091ac89..e4758ad 100644 --- a/tests/test_sqlalchemy_query.py +++ b/tests/test_sqlalchemy_query.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import pytest from tests.conftest import HAS_SQLALCHEMY, Base diff --git a/tests/test_update_builder.py b/tests/test_update_builder.py new file mode 100644 index 0000000..32dfe03 --- /dev/null +++ b/tests/test_update_builder.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +import pytest + +from pymongosql.sql.update_builder import MongoUpdateBuilder, UpdateExecutionPlan + + +class TestUpdateExecutionPlan: + """Test UpdateExecutionPlan dataclass.""" + + def test_to_dict(self): + """Test to_dict conversion.""" + plan = UpdateExecutionPlan( + collection="users", update_fields={"name": "John", "age": 30}, filter_conditions={"id": 123} + ) + + result = plan.to_dict() + assert result["collection"] == "users" + assert result["filter"] == {"id": 123} + assert result["update"] == {"$set": {"name": "John", "age": 30}} + + def test_validate_success(self): + """Test validate returns True for valid plan.""" + plan = UpdateExecutionPlan(collection="products", update_fields={"price": 99.99}) + + assert plan.validate() is True + + def test_validate_no_update_fields(self): + """Test validate fails when no update fields.""" + plan = UpdateExecutionPlan(collection="products", update_fields={}) + + assert plan.validate() is False + + def test_validate_empty_filter_allowed(self): + """Test validate allows empty filter (update all).""" + plan = UpdateExecutionPlan(collection="products", update_fields={"status": "active"}, filter_conditions={}) + + assert plan.validate() is True + + def test_copy(self): + """Test copy creates independent copy.""" + original = UpdateExecutionPlan( + collection="orders", update_fields={"status": "shipped", "total": 100}, filter_conditions={"id": 456} + ) + + copy = original.copy() + + # Verify all fields copied + assert copy.collection == original.collection + assert copy.update_fields == original.update_fields + assert copy.filter_conditions == original.filter_conditions + + # Verify it's independent + copy.collection = "new_collection" + copy.update_fields["new_field"] = "value" + assert original.collection == "orders" + assert "new_field" not in original.update_fields + + def test_copy_with_empty_fields(self): + """Test copy handles empty dicts.""" + original = UpdateExecutionPlan(collection="test", update_fields={"field": "value"}) + + copy = original.copy() + assert copy.filter_conditions == {} + + def test_get_mongo_update_doc(self): + """Test get_mongo_update_doc returns $set document.""" + plan = UpdateExecutionPlan(collection="users", update_fields={"email": "user@example.com", "verified": True}) + + update_doc = plan.get_mongo_update_doc() + assert update_doc == {"$set": {"email": "user@example.com", "verified": True}} + + def test_parameter_style_default(self): + """Test default parameter style is qmark.""" + plan = UpdateExecutionPlan(collection="test", update_fields={"a": "b"}) + assert plan.parameter_style == "qmark" + + +class TestMongoUpdateBuilder: + """Test MongoUpdateBuilder class.""" + + def test_collection(self): + """Test setting collection name.""" + builder = MongoUpdateBuilder() + result = builder.collection("users") + + assert builder._plan.collection == "users" + assert result is builder # Fluent interface + + def test_update_fields(self): + """Test setting update fields.""" + builder = MongoUpdateBuilder() + builder.update_fields({"name": "Alice", "age": 25}) + + assert builder._plan.update_fields == {"name": "Alice", "age": 25} + + def test_update_fields_empty_dict(self): + """Test update_fields with empty dict doesn't update.""" + builder = MongoUpdateBuilder() + builder.update_fields({}) + + assert builder._plan.update_fields == {} + + def test_update_fields_none(self): + """Test update_fields with None doesn't update.""" + builder = MongoUpdateBuilder() + builder._plan.update_fields = {"existing": "field"} + builder.update_fields(None) + + # Should preserve existing + assert builder._plan.update_fields == {"existing": "field"} + + def test_filter_conditions(self): + """Test setting filter conditions.""" + builder = MongoUpdateBuilder() + builder.filter_conditions({"status": "active", "age": {"$gt": 18}}) + + assert builder._plan.filter_conditions == {"status": "active", "age": {"$gt": 18}} + + def test_filter_conditions_empty(self): + """Test filter_conditions with empty dict doesn't update.""" + builder = MongoUpdateBuilder() + builder.filter_conditions({}) + + assert builder._plan.filter_conditions == {} + + def test_filter_conditions_none(self): + """Test filter_conditions with None doesn't update.""" + builder = MongoUpdateBuilder() + builder._plan.filter_conditions = {"existing": "filter"} + builder.filter_conditions(None) + + # Should preserve existing + assert builder._plan.filter_conditions == {"existing": "filter"} + + def test_parameter_style(self): + """Test setting parameter style.""" + builder = MongoUpdateBuilder() + builder.parameter_style("named") + + assert builder._plan.parameter_style == "named" + + def test_build_success(self): + """Test build returns execution plan when valid.""" + builder = MongoUpdateBuilder() + builder.collection("products").update_fields({"price": 49.99}) + + plan = builder.build() + + assert isinstance(plan, UpdateExecutionPlan) + assert plan.collection == "products" + assert plan.update_fields == {"price": 49.99} + + def test_build_validation_failure(self): + """Test build raises ValueError when validation fails.""" + builder = MongoUpdateBuilder() + builder.collection("products") + # Don't set update_fields + + with pytest.raises(ValueError) as exc_info: + builder.build() + + assert "invalid update plan" in str(exc_info.value).lower() + + def test_fluent_interface_chaining(self): + """Test all methods return self for chaining.""" + builder = MongoUpdateBuilder() + + result = ( + builder.collection("orders") + .update_fields({"status": "shipped"}) + .filter_conditions({"id": 123}) + .parameter_style("qmark") + ) + + assert result is builder + assert builder._plan.collection == "orders" + assert builder._plan.update_fields == {"status": "shipped"} + assert builder._plan.filter_conditions == {"id": 123} + assert builder._plan.parameter_style == "qmark" diff --git a/tests/test_update_handler.py b/tests/test_update_handler.py new file mode 100644 index 0000000..dc6dcab --- /dev/null +++ b/tests/test_update_handler.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- +from pymongosql.sql.update_handler import UpdateHandler, UpdateParseResult + + +class TestUpdateParseResult: + """Test UpdateParseResult dataclass.""" + + def test_for_visitor_factory(self): + """Test factory method creates fresh instance.""" + result = UpdateParseResult.for_visitor() + assert result.collection is None + assert result.update_fields == {} + assert result.filter_conditions == {} + assert result.has_errors is False + assert result.error_message is None + + def test_validate_missing_collection(self): + """Test validation fails when collection is missing.""" + result = UpdateParseResult(update_fields={"name": "value"}) + is_valid = result.validate() + + assert is_valid is False + assert result.has_errors is True + assert result.error_message == "Collection name is required" + + def test_validate_missing_update_fields(self): + """Test validation fails when update fields are missing.""" + result = UpdateParseResult(collection="test_collection") + is_valid = result.validate() + + assert is_valid is False + assert result.has_errors is True + assert result.error_message == "At least one field to update is required" + + def test_validate_success(self): + """Test validation passes when all required fields set.""" + result = UpdateParseResult(collection="users", update_fields={"name": "John"}) + is_valid = result.validate() + + assert is_valid is True + assert result.has_errors is False + + def test_to_dict(self): + """Test to_dict conversion.""" + result = UpdateParseResult( + collection="users", + update_fields={"age": 30, "status": "active"}, + filter_conditions={"id": 123}, + has_errors=False, + error_message=None, + ) + + result_dict = result.to_dict() + assert result_dict["collection"] == "users" + assert result_dict["update_fields"] == {"age": 30, "status": "active"} + assert result_dict["filter_conditions"] == {"id": 123} + assert result_dict["has_errors"] is False + assert result_dict["error_message"] is None + + def test_repr(self): + """Test string representation.""" + result = UpdateParseResult( + collection="products", update_fields={"price": 99.99}, filter_conditions={"sku": "ABC123"}, has_errors=False + ) + + repr_str = repr(result) + assert "UpdateParseResult" in repr_str + assert "collection=products" in repr_str + assert "has_errors=False" in repr_str + + +class TestUpdateHandler: + """Test UpdateHandler class.""" + + def test_can_handle_update_context(self): + """Test can_handle returns True for UPDATE context.""" + handler = UpdateHandler() + + class MockUpdateContext: + def UPDATE(self): + return True + + ctx = MockUpdateContext() + assert handler.can_handle(ctx) is True + + def test_can_handle_non_update_context(self): + """Test can_handle returns False for non-UPDATE context.""" + handler = UpdateHandler() + + class MockOtherContext: + pass + + ctx = MockOtherContext() + assert handler.can_handle(ctx) is False + + def test_handle_visitor_with_table_reference(self): + """Test handle_visitor extracts collection from tableBaseReference.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockSource: + def getText(self): + return "test_collection" + + class MockTableRef: + source = MockSource() + + class MockContext: + def tableBaseReference(self): + return MockTableRef() + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + assert result.collection == "test_collection" + assert result.has_errors is False + + def test_handle_visitor_without_table_reference(self): + """Test handle_visitor when no tableBaseReference present.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockContext: + def tableBaseReference(self): + return None + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + # Should not set collection + assert result.collection is None + + def test_handle_visitor_with_error(self): + """Test handle_visitor logs warning on exception but continues.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockTableRef: + def __getattribute__(self, name): + raise RuntimeError("Test error") + + class MockContext: + def tableBaseReference(self): + return MockTableRef() + + ctx = MockContext() + result = handler.handle_visitor(ctx, parse_result) + + # The handler logs warning but doesn't set error flag + assert result.collection is None + + def test_extract_collection_from_table_ref_with_fallback(self): + """Test _extract_collection_from_table_ref uses getText fallback.""" + handler = UpdateHandler() + + class MockTableRef: + def getText(self): + return "fallback_collection" + + ctx = MockTableRef() + collection = handler._extract_collection_from_table_ref(ctx) + + assert collection == "fallback_collection" + + def test_extract_collection_with_exception(self): + """Test _extract_collection_from_table_ref handles exceptions.""" + handler = UpdateHandler() + + class MockTableRef: + def getText(self): + raise ValueError("Error extracting") + + ctx = MockTableRef() + collection = handler._extract_collection_from_table_ref(ctx) + + assert collection is None + + def test_handle_set_command_single_assignment(self): + """Test handle_set_command with single assignment.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockPathSimple: + def getText(self): + return "field_name" + + class MockExpr: + def getText(self): + return "'field_value'" + + class MockAssignment: + def pathSimple(self): + return MockPathSimple() + + def expr(self): + return MockExpr() + + class MockContext: + def setAssignment(self): + return [MockAssignment()] + + ctx = MockContext() + result = handler.handle_set_command(ctx, parse_result) + + assert result.update_fields == {"field_name": "field_value"} + assert result.has_errors is False + + def test_handle_set_command_multiple_assignments(self): + """Test handle_set_command with multiple assignments.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockAssignment1: + def pathSimple(self): + class P: + def getText(self): + return "name" + + return P() + + def expr(self): + class E: + def getText(self): + return "'Alice'" + + return E() + + class MockAssignment2: + def pathSimple(self): + class P: + def getText(self): + return "age" + + return P() + + def expr(self): + class E: + def getText(self): + return "30" + + return E() + + class MockContext: + def setAssignment(self): + return [MockAssignment1(), MockAssignment2()] + + ctx = MockContext() + result = handler.handle_set_command(ctx, parse_result) + + assert result.update_fields == {"name": "Alice", "age": 30} + + def test_handle_set_command_with_error(self): + """Test handle_set_command handles exceptions.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockContext: + def setAssignment(self): + raise RuntimeError("SET error") + + ctx = MockContext() + result = handler.handle_set_command(ctx, parse_result) + + assert result.has_errors is True + assert "SET error" in result.error_message + + def test_parse_value_string_single_quote(self): + """Test _parse_value with single-quoted string.""" + handler = UpdateHandler() + value = handler._parse_value("'hello world'") + assert value == "hello world" + + def test_parse_value_string_double_quote(self): + """Test _parse_value with double-quoted string.""" + handler = UpdateHandler() + value = handler._parse_value('"hello world"') + assert value == "hello world" + + def test_parse_value_null(self): + """Test _parse_value with null.""" + handler = UpdateHandler() + assert handler._parse_value("null") is None + assert handler._parse_value("NULL") is None + + def test_parse_value_boolean(self): + """Test _parse_value with booleans.""" + handler = UpdateHandler() + assert handler._parse_value("true") is True + assert handler._parse_value("TRUE") is True + assert handler._parse_value("false") is False + assert handler._parse_value("FALSE") is False + + def test_parse_value_integer(self): + """Test _parse_value with integer.""" + handler = UpdateHandler() + assert handler._parse_value("42") == 42 + assert handler._parse_value("-10") == -10 + + def test_parse_value_float(self): + """Test _parse_value with float.""" + handler = UpdateHandler() + assert handler._parse_value("3.14") == 3.14 + assert handler._parse_value("-2.5") == -2.5 + + def test_parse_value_parameter_qmark(self): + """Test _parse_value with qmark parameter.""" + handler = UpdateHandler() + assert handler._parse_value("?") == "?" + + def test_parse_value_parameter_named(self): + """Test _parse_value with named parameter.""" + handler = UpdateHandler() + assert handler._parse_value(":name") == ":name" + + def test_parse_value_unquoted_string(self): + """Test _parse_value with unquoted string.""" + handler = UpdateHandler() + value = handler._parse_value("unquoted") + assert value == "unquoted" + + def test_handle_where_clause_no_expression(self): + """Test handle_where_clause when no expression present.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockWhereContext: + arg = None + + def expr(self): + return None + + ctx = MockWhereContext() + result = handler.handle_where_clause(ctx, parse_result) + + # Should return empty dict (update all) + assert result == {} + assert parse_result.filter_conditions == {} + + def test_handle_where_clause_with_error(self): + """Test handle_where_clause logs error but returns empty dict.""" + handler = UpdateHandler() + parse_result = UpdateParseResult() + + class MockExpr: + def getText(self): + raise Exception("WHERE error") + + class MockWhereContext: + arg = None + + def expr(self): + return MockExpr() + + ctx = MockWhereContext() + result = handler.handle_where_clause(ctx, parse_result) + + # Returns empty dict on error + assert result == {}