Skip to content

Commit 2110cd4

Browse files
author
Peng Ren
committed
Added basic insert support
1 parent f664c9f commit 2110cd4

File tree

12 files changed

+658
-72
lines changed

12 files changed

+658
-72
lines changed

pymongosql/cursor.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, connection: "Connection", mode: str = "standard", **kwargs) -
2929
self._kwargs = kwargs
3030
self._result_set: Optional[ResultSet] = None
3131
self._result_set_class = ResultSet
32-
self._current_execution_plan: Optional[QueryExecutionPlan] = None
32+
self._current_execution_plan: Optional[Any] = None
3333
self._is_closed = False
3434

3535
@property
@@ -103,12 +103,32 @@ def execute(self: _T, operation: str, parameters: Optional[Any] = None) -> _T:
103103
self._current_execution_plan = strategy.execution_plan
104104

105105
# Create result set from command result
106-
self._result_set = self._result_set_class(
107-
command_result=result,
108-
execution_plan=self._current_execution_plan,
109-
database=self.connection.database,
110-
**self._kwargs,
111-
)
106+
# For SELECT/QUERY operations, use the execution plan directly
107+
if isinstance(self._current_execution_plan, QueryExecutionPlan):
108+
execution_plan_for_rs = self._current_execution_plan
109+
self._result_set = self._result_set_class(
110+
command_result=result,
111+
execution_plan=execution_plan_for_rs,
112+
database=self.connection.database,
113+
**self._kwargs,
114+
)
115+
else:
116+
# For INSERT and other non-query operations, create a minimal synthetic result
117+
# since INSERT commands don't return a cursor structure
118+
stub_plan = QueryExecutionPlan(collection=self._current_execution_plan.collection)
119+
self._result_set = self._result_set_class(
120+
command_result={
121+
"cursor": {
122+
"id": 0,
123+
"firstBatch": [],
124+
}
125+
},
126+
execution_plan=stub_plan,
127+
database=self.connection.database,
128+
**self._kwargs,
129+
)
130+
# Store the actual insert result for reference
131+
self._result_set._insert_result = result
112132

113133
return self
114134

pymongosql/executor.py

Lines changed: 86 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from pymongo.errors import PyMongoError
88

99
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
10+
from .helper import SQLHelper
11+
from .sql.insert_builder import InsertExecutionPlan
1012
from .sql.parser import SQLParser
1113
from .sql.query_builder import QueryExecutionPlan
1214

@@ -30,7 +32,7 @@ class ExecutionStrategy(ABC):
3032

3133
@property
3234
@abstractmethod
33-
def execution_plan(self) -> QueryExecutionPlan:
35+
def execution_plan(self) -> Union[QueryExecutionPlan, InsertExecutionPlan]:
3436
"""Name of the execution plan"""
3537
pass
3638

@@ -60,7 +62,7 @@ def supports(self, context: ExecutionContext) -> bool:
6062
pass
6163

6264

63-
class StandardExecution(ExecutionStrategy):
65+
class StandardQueryExecution(ExecutionStrategy):
6466
"""Standard execution strategy for simple SELECT queries without subqueries"""
6567

6668
@property
@@ -70,7 +72,8 @@ def execution_plan(self) -> QueryExecutionPlan:
7072

7173
def supports(self, context: ExecutionContext) -> bool:
7274
"""Support simple queries without subqueries"""
73-
return "standard" in context.execution_mode.lower()
75+
normalized = context.query.lstrip().upper()
76+
return "standard" in context.execution_mode.lower() and normalized.startswith("SELECT")
7477

7578
def _parse_sql(self, sql: str) -> QueryExecutionPlan:
7679
"""Parse SQL statement and return QueryExecutionPlan"""
@@ -91,29 +94,7 @@ def _parse_sql(self, sql: str) -> QueryExecutionPlan:
9194

9295
def _replace_placeholders(self, obj: Any, parameters: Sequence[Any]) -> Any:
9396
"""Recursively replace ? placeholders with parameter values in filter/projection dicts"""
94-
param_index = [0] # Use list to allow modification in nested function
95-
96-
def replace_recursive(value: Any) -> Any:
97-
if isinstance(value, str):
98-
# Replace ? with the next parameter value
99-
if value == "?":
100-
if param_index[0] < len(parameters):
101-
result = parameters[param_index[0]]
102-
param_index[0] += 1
103-
return result
104-
else:
105-
raise ProgrammingError(
106-
f"Not enough parameters provided: expected at least {param_index[0] + 1}"
107-
)
108-
return value
109-
elif isinstance(value, dict):
110-
return {k: replace_recursive(v) for k, v in value.items()}
111-
elif isinstance(value, list):
112-
return [replace_recursive(item) for item in value]
113-
else:
114-
return value
115-
116-
return replace_recursive(obj)
97+
return SQLHelper.replace_placeholders_generic(obj, parameters, "qmark")
11798

11899
def _execute_execution_plan(
119100
self,
@@ -202,10 +183,87 @@ def execute(
202183
return self._execute_execution_plan(self._execution_plan, connection.database, processed_params)
203184

204185

186+
class InsertExecution(ExecutionStrategy):
187+
"""Execution strategy for INSERT statements."""
188+
189+
@property
190+
def execution_plan(self) -> InsertExecutionPlan:
191+
return self._execution_plan
192+
193+
def supports(self, context: ExecutionContext) -> bool:
194+
return context.query.lstrip().upper().startswith("INSERT")
195+
196+
def _parse_sql(self, sql: str) -> InsertExecutionPlan:
197+
try:
198+
parser = SQLParser(sql)
199+
plan = parser.get_execution_plan()
200+
201+
if not isinstance(plan, InsertExecutionPlan):
202+
raise SqlSyntaxError("Expected INSERT execution plan")
203+
204+
if not plan.validate():
205+
raise SqlSyntaxError("Generated insert plan is invalid")
206+
207+
return plan
208+
except SqlSyntaxError:
209+
raise
210+
except Exception as e:
211+
_logger.error(f"SQL parsing failed: {e}")
212+
raise SqlSyntaxError(f"Failed to parse SQL: {e}")
213+
214+
def _replace_placeholders(
215+
self,
216+
documents: Sequence[Dict[str, Any]],
217+
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]],
218+
style: Optional[str],
219+
) -> Sequence[Dict[str, Any]]:
220+
return SQLHelper.replace_placeholders_generic(documents, parameters, style)
221+
222+
def _execute_execution_plan(
223+
self,
224+
execution_plan: InsertExecutionPlan,
225+
db: Any,
226+
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
227+
) -> Optional[Dict[str, Any]]:
228+
try:
229+
if not execution_plan.collection:
230+
raise ProgrammingError("No collection specified in insert")
231+
232+
docs = execution_plan.insert_documents or []
233+
docs = self._replace_placeholders(docs, parameters, execution_plan.parameter_style)
234+
235+
command = {"insert": execution_plan.collection, "documents": docs}
236+
237+
_logger.debug(f"Executing MongoDB insert command: {command}")
238+
239+
return db.command(command)
240+
except PyMongoError as e:
241+
_logger.error(f"MongoDB insert failed: {e}")
242+
raise DatabaseError(f"Insert execution failed: {e}")
243+
except (ProgrammingError, DatabaseError, OperationalError):
244+
# Re-raise our own errors without wrapping
245+
raise
246+
except Exception as e:
247+
_logger.error(f"Unexpected error during insert execution: {e}")
248+
raise OperationalError(f"Insert execution error: {e}")
249+
250+
def execute(
251+
self,
252+
context: ExecutionContext,
253+
connection: Any,
254+
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
255+
) -> Optional[Dict[str, Any]]:
256+
_logger.debug(f"Using insert execution for query: {context.query[:100]}")
257+
258+
self._execution_plan = self._parse_sql(context.query)
259+
260+
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
261+
262+
205263
class ExecutionPlanFactory:
206264
"""Factory for creating appropriate execution strategy based on query context"""
207265

208-
_strategies = [StandardExecution()]
266+
_strategies = [InsertExecution(), StandardQueryExecution()]
209267

210268
@classmethod
211269
def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy:
@@ -216,7 +274,7 @@ def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy:
216274
return strategy
217275

218276
# Fallback to standard execution
219-
return StandardExecution()
277+
return StandardQueryExecution()
220278

221279
@classmethod
222280
def register_strategy(cls, strategy: ExecutionStrategy) -> None:

pymongosql/helper.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
"""
77

88
import logging
9-
from typing import Optional, Tuple
9+
from typing import Any, Optional, Sequence, Tuple
1010
from urllib.parse import parse_qs, urlparse
1111

12+
from .error import ProgrammingError
13+
1214
_logger = logging.getLogger(__name__)
1315

1416

@@ -95,3 +97,54 @@ def parse_connection_string(connection_string: Optional[str]) -> Tuple[Optional[
9597
except Exception as e:
9698
_logger.error(f"Failed to parse connection string: {e}")
9799
raise ValueError(f"Invalid connection string format: {e}")
100+
101+
102+
class SQLHelper:
103+
"""SQL-related helper utilities."""
104+
105+
@staticmethod
106+
def replace_placeholders_generic(value: Any, parameters: Any, style: Optional[str]) -> Any:
107+
"""Recursively replace placeholders in nested structures for qmark or named styles."""
108+
if style is None or parameters is None:
109+
return value
110+
111+
if style == "qmark":
112+
if not isinstance(parameters, Sequence) or isinstance(parameters, (str, bytes, dict)):
113+
raise ProgrammingError("Positional parameters must be provided as a sequence")
114+
115+
idx = [0]
116+
117+
def replace(val: Any) -> Any:
118+
if isinstance(val, str) and val == "?":
119+
if idx[0] >= len(parameters):
120+
raise ProgrammingError("Not enough parameters provided")
121+
out = parameters[idx[0]]
122+
idx[0] += 1
123+
return out
124+
if isinstance(val, dict):
125+
return {k: replace(v) for k, v in val.items()}
126+
if isinstance(val, list):
127+
return [replace(v) for v in val]
128+
return val
129+
130+
return replace(value)
131+
132+
if style == "named":
133+
if not isinstance(parameters, dict):
134+
raise ProgrammingError("Named parameters must be provided as a mapping")
135+
136+
def replace(val: Any) -> Any:
137+
if isinstance(val, str) and val.startswith(":"):
138+
key = val[1:]
139+
if key not in parameters:
140+
raise ProgrammingError(f"Missing named parameter: {key}")
141+
return parameters[key]
142+
if isinstance(val, dict):
143+
return {k: replace(v) for k, v in val.items()}
144+
if isinstance(val, list):
145+
return [replace(v) for v in val]
146+
return val
147+
148+
return replace(value)
149+
150+
return value

pymongosql/sql/ast.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _initialize_handlers(self) -> Dict[str, BaseHandler]:
4646
"select": HandlerFactory.get_visitor_handler("select"),
4747
"from": HandlerFactory.get_visitor_handler("from"),
4848
"where": HandlerFactory.get_visitor_handler("where"),
49+
"insert": HandlerFactory.get_visitor_handler("insert"),
4950
}
5051

5152
@property
@@ -74,6 +75,9 @@ def _build_query_plan(self) -> QueryExecutionPlan:
7475

7576
def _build_insert_plan(self) -> InsertExecutionPlan:
7677
"""Build an INSERT execution plan from INSERT parsing."""
78+
if self._insert_parse_result.has_errors:
79+
raise SqlSyntaxError(self._insert_parse_result.error_message or "INSERT parsing failed")
80+
7781
builder = BuilderFactory.create_insert_builder().collection(self._insert_parse_result.collection)
7882

7983
documents = self._insert_parse_result.insert_documents or []
@@ -145,6 +149,24 @@ def visitWhereClauseSelect(self, ctx: PartiQLParser.WhereClauseSelectContext) ->
145149
_logger.warning(f"Error processing WHERE clause: {e}")
146150
return self.visitChildren(ctx)
147151

152+
def visitInsertStatement(self, ctx: PartiQLParser.InsertStatementContext) -> Any:
153+
"""Handle INSERT statements via the insert handler."""
154+
_logger.debug("Processing INSERT statement")
155+
self._current_operation = "insert"
156+
handler = self._handlers.get("insert")
157+
if handler:
158+
return handler.handle_visitor(ctx, self._insert_parse_result)
159+
return self.visitChildren(ctx)
160+
161+
def visitInsertStatementLegacy(self, ctx: PartiQLParser.InsertStatementLegacyContext) -> Any:
162+
"""Handle legacy INSERT statements."""
163+
_logger.debug("Processing INSERT legacy statement")
164+
self._current_operation = "insert"
165+
handler = self._handlers.get("insert")
166+
if handler:
167+
return handler.handle_visitor(ctx, self._insert_parse_result)
168+
return self.visitChildren(ctx)
169+
148170
def visitOrderByClause(self, ctx: PartiQLParser.OrderByClauseContext) -> Any:
149171
"""Handle ORDER BY clause for sorting"""
150172
_logger.debug("Processing ORDER BY clause")

pymongosql/sql/insert_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def insert_documents(self, documents: List[Dict[str, Any]]) -> "MongoInsertBuild
8080

8181
def parameter_style(self, style: Optional[str]) -> "MongoInsertBuilder":
8282
"""Set parameter binding style for tracking."""
83-
if style and style not in ["qmark"]:
83+
if style and style not in ["qmark", "named"]:
8484
self._add_error(f"Invalid parameter style: {style}")
8585
return self
8686

0 commit comments

Comments
 (0)