Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ Parameters are substituted into the MongoDB filter during execution, providing p

### INSERT Statements

PyMongoSQL supports inserting documents into MongoDB collections using PartiQL-style object and bag literals.
PyMongoSQL supports inserting documents into MongoDB collections using both PartiQL-style object literals and standard SQL INSERT VALUES syntax.

#### PartiQL-Style Object Literals

**Single Document**

Expand All @@ -239,12 +241,44 @@ cursor.execute(
```python
# Positional parameters using ? placeholders
cursor.execute(
"INSERT INTO Music {'title': ?, 'artist': ?, 'year': ?}",
"INSERT INTO Music {'title': '?', 'artist': '?', 'year': '?'}",
["Song D", "Diana", 2020]
)
```

> **Note**: For parameterized INSERT, use positional parameters (`?`). Named placeholders (`:name`) are supported for SELECT, UPDATE, and DELETE queries.
#### Standard SQL INSERT VALUES

**Single Row with Column List**

```python
cursor.execute(
"INSERT INTO Music (title, artist, year) VALUES ('Song E', 'Eve', 2022)"
)
```

**Multiple Rows**

```python
cursor.execute(
"INSERT INTO Music (title, artist, year) VALUES ('Song F', 'Frank', 2023), ('Song G', 'Grace', 2024)"
)
```

**Parameterized INSERT VALUES**

```python
# Positional parameters (?)
cursor.execute(
"INSERT INTO Music (title, artist, year) VALUES (?, ?, ?)",
["Song H", "Henry", 2025]
)

# Named parameters (:name)
cursor.execute(
"INSERT INTO Music (title, artist) VALUES (:title, :artist)",
{"title": "Song I", "artist": "Iris"}
)
```

### UPDATE Statements

Expand Down
2 changes: 1 addition & 1 deletion pymongosql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from .connection import Connection

__version__: str = "0.3.0"
__version__: str = "0.3.1"

# Globals https://www.python.org/dev/peps/pep-0249/#globals
apilevel: str = "2.0"
Expand Down
133 changes: 46 additions & 87 deletions pymongosql/sql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
from typing import Any, Dict, Union

from ..error import SqlSyntaxError
from .builder import BuilderFactory
from .delete_builder import DeleteExecutionPlan
from .delete_handler import DeleteParseResult
from .handler import BaseHandler, HandlerFactory
from .insert_builder import InsertExecutionPlan
from .insert_handler import InsertParseResult
from .partiql.PartiQLLexer import PartiQLLexer
from .partiql.PartiQLParser import PartiQLParser
from .partiql.PartiQLParserVisitor import PartiQLParserVisitor
from .query_builder import QueryExecutionPlan
from .query_handler import QueryParseResult
from .update_builder import UpdateExecutionPlan
from .update_handler import UpdateParseResult

_logger = logging.getLogger(__name__)
Expand All @@ -37,7 +32,7 @@ class MongoSQLParserVisitor(PartiQLParserVisitor):

def __init__(self) -> None:
super().__init__()
self._parse_result = QueryParseResult.for_visitor()
self._query_parse_result = QueryParseResult.for_visitor()
self._insert_parse_result = InsertParseResult.for_visitor()
self._delete_parse_result = DeleteParseResult.for_visitor()
self._update_parse_result = UpdateParseResult.for_visitor()
Expand All @@ -58,86 +53,27 @@ def _initialize_handlers(self) -> Dict[str, BaseHandler]:
}

@property
def parse_result(self) -> QueryParseResult:
"""Get the current parse result"""
return self._parse_result

def parse_to_execution_plan(
self,
) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan, UpdateExecutionPlan]:
"""Convert the parse result to an execution plan using BuilderFactory."""
def parse_result(self) -> Union[QueryParseResult, InsertParseResult, DeleteParseResult, UpdateParseResult]:
"""Get the current parse result based on the current operation"""
if self._current_operation == "insert":
return self._build_insert_plan()
return self._insert_parse_result
elif self._current_operation == "delete":
return self._build_delete_plan()
return self._delete_parse_result
elif self._current_operation == "update":
return self._build_update_plan()
return self._update_parse_result
else:
return self._query_parse_result

return self._build_query_plan()

def _build_query_plan(self) -> QueryExecutionPlan:
"""Build a query execution plan from SELECT parsing."""
builder = BuilderFactory.create_query_builder().collection(self._parse_result.collection)

builder.filter(self._parse_result.filter_conditions).project(self._parse_result.projection).column_aliases(
self._parse_result.column_aliases
).sort(self._parse_result.sort_fields).limit(self._parse_result.limit_value).skip(
self._parse_result.offset_value
)

return builder.build()

def _build_insert_plan(self) -> InsertExecutionPlan:
"""Build an INSERT execution plan from INSERT parsing."""
if self._insert_parse_result.has_errors:
raise SqlSyntaxError(self._insert_parse_result.error_message or "INSERT parsing failed")

builder = BuilderFactory.create_insert_builder().collection(self._insert_parse_result.collection)

documents = self._insert_parse_result.insert_documents or []
builder.insert_documents(documents)

if self._insert_parse_result.parameter_style:
builder.parameter_style(self._insert_parse_result.parameter_style)

if self._insert_parse_result.parameter_count > 0:
builder.parameter_count(self._insert_parse_result.parameter_count)

return builder.build()

def _build_delete_plan(self) -> DeleteExecutionPlan:
"""Build a DELETE execution plan from DELETE parsing."""
_logger.debug(
f"Building DELETE plan with collection: {self._delete_parse_result.collection}, "
f"filters: {self._delete_parse_result.filter_conditions}"
)
builder = BuilderFactory.create_delete_builder().collection(self._delete_parse_result.collection)

if self._delete_parse_result.filter_conditions:
builder.filter_conditions(self._delete_parse_result.filter_conditions)

return builder.build()

def _build_update_plan(self) -> UpdateExecutionPlan:
"""Build an UPDATE execution plan from UPDATE parsing."""
_logger.debug(
f"Building UPDATE plan with collection: {self._update_parse_result.collection}, "
f"update_fields: {self._update_parse_result.update_fields}, "
f"filters: {self._update_parse_result.filter_conditions}"
)
builder = BuilderFactory.create_update_builder().collection(self._update_parse_result.collection)

if self._update_parse_result.update_fields:
builder.update_fields(self._update_parse_result.update_fields)

if self._update_parse_result.filter_conditions:
builder.filter_conditions(self._update_parse_result.filter_conditions)

return builder.build()
@property
def current_operation(self) -> str:
"""Get the current operation type (select, insert, delete, or update)"""
return self._current_operation

def visitRoot(self, ctx: PartiQLParser.RootContext) -> Any:
"""Visit root node and process child nodes"""
_logger.debug("Starting to parse SQL query")
# Reset to default SELECT operation at the start of each query
self._current_operation = "select"
try:
result = self.visitChildren(ctx)
return result
Expand All @@ -149,7 +85,7 @@ def visitSelectAll(self, ctx: PartiQLParser.SelectAllContext) -> Any:
"""Handle SELECT * statements"""
_logger.debug("Processing SELECT ALL statement")
# SELECT * means no projection filter (return all fields)
self._parse_result.projection = {}
self._query_parse_result.projection = {}
return self.visitChildren(ctx)

def visitSelectItems(self, ctx: PartiQLParser.SelectItemsContext) -> Any:
Expand All @@ -158,7 +94,7 @@ def visitSelectItems(self, ctx: PartiQLParser.SelectItemsContext) -> Any:
try:
handler = self._handlers["select"]
if handler:
result = handler.handle_visitor(ctx, self._parse_result)
result = handler.handle_visitor(ctx, self._query_parse_result)
return result
return self.visitChildren(ctx)
except Exception as e:
Expand All @@ -171,7 +107,7 @@ def visitFromClause(self, ctx: PartiQLParser.FromClauseContext) -> Any:
try:
handler = self._handlers["from"]
if handler:
result = handler.handle_visitor(ctx, self._parse_result)
result = handler.handle_visitor(ctx, self._query_parse_result)
_logger.debug(f"Extracted collection: {result}")
return result
return self.visitChildren(ctx)
Expand All @@ -185,7 +121,7 @@ def visitWhereClauseSelect(self, ctx: PartiQLParser.WhereClauseSelectContext) ->
try:
handler = self._handlers["where"]
if handler:
result = handler.handle_visitor(ctx, self._parse_result)
result = handler.handle_visitor(ctx, self._query_parse_result)
_logger.debug(f"Extracted filter conditions: {result}")
return result
return self.visitChildren(ctx)
Expand All @@ -197,20 +133,43 @@ def visitInsertStatement(self, ctx: PartiQLParser.InsertStatementContext) -> Any
"""Handle INSERT statements via the insert handler."""
_logger.debug("Processing INSERT statement")
self._current_operation = "insert"
# Reset insert parse result for this statement
self._insert_parse_result = InsertParseResult.for_visitor()
handler = self._handlers.get("insert")
if handler:
return handler.handle_visitor(ctx, self._insert_parse_result)
handler.handle_visitor(ctx, self._insert_parse_result)
# Continue visiting children to process columnList and values
self.visitChildren(ctx)
return self._insert_parse_result
return self.visitChildren(ctx)

def visitInsertStatementLegacy(self, ctx: PartiQLParser.InsertStatementLegacyContext) -> Any:
"""Handle legacy INSERT statements."""
_logger.debug("Processing INSERT legacy statement")
self._current_operation = "insert"
# Reset insert parse result for this statement
self._insert_parse_result = InsertParseResult.for_visitor()
handler = self._handlers.get("insert")
if handler:
return handler.handle_visitor(ctx, self._insert_parse_result)
return self.visitChildren(ctx)

def visitColumnList(self, ctx: PartiQLParser.ColumnListContext) -> Any:
"""Handle column list in INSERT statements."""
if self._current_operation == "insert":
handler = self._handlers.get("insert")
if handler:
return handler.handle_column_list(ctx, self._insert_parse_result)
return self.visitChildren(ctx)

def visitValues(self, ctx: PartiQLParser.ValuesContext) -> Any:
"""Handle VALUES clause in INSERT statements."""
if self._current_operation == "insert":
handler = self._handlers.get("insert")
if handler:
return handler.handle_values(ctx, self._insert_parse_result)
return self.visitChildren(ctx)

def visitFromClauseSimpleExplicit(self, ctx: PartiQLParser.FromClauseSimpleExplicitContext) -> Any:
"""Handle FROM clause (explicit form) in DELETE statements."""
if self._current_operation == "delete":
Expand Down Expand Up @@ -247,7 +206,7 @@ def visitWhereClause(self, ctx: PartiQLParser.WhereClauseContext) -> Any:
# For other operations, use the where handler
handler = self._handlers["where"]
if handler:
result = handler.handle_visitor(ctx, self._parse_result)
result = handler.handle_visitor(ctx, self._query_parse_result)
_logger.debug(f"Extracted filter conditions: {result}")
return result
return {}
Expand Down Expand Up @@ -284,7 +243,7 @@ def visitOrderByClause(self, ctx: PartiQLParser.OrderByClauseContext) -> Any:
# Convert to the expected format: List[Dict[str, int]]
sort_specs.append({field_name: direction})

self._parse_result.sort_fields = sort_specs
self._query_parse_result.sort_fields = sort_specs
_logger.debug(f"Extracted sort specifications: {sort_specs}")
return self.visitChildren(ctx)
except Exception as e:
Expand All @@ -299,7 +258,7 @@ def visitLimitClause(self, ctx: PartiQLParser.LimitClauseContext) -> Any:
limit_text = ctx.exprSelect().getText()
try:
limit_value = int(limit_text)
self._parse_result.limit_value = limit_value
self._query_parse_result.limit_value = limit_value
_logger.debug(f"Extracted limit value: {limit_value}")
except ValueError as e:
_logger.warning(f"Invalid LIMIT value '{limit_text}': {e}")
Expand All @@ -316,7 +275,7 @@ def visitOffsetByClause(self, ctx: PartiQLParser.OffsetByClauseContext) -> Any:
offset_text = ctx.exprSelect().getText()
try:
offset_value = int(offset_text)
self._parse_result.offset_value = offset_value
self._query_parse_result.offset_value = offset_value
_logger.debug(f"Extracted offset value: {offset_value}")
except ValueError as e:
_logger.warning(f"Invalid OFFSET value '{offset_text}': {e}")
Expand Down
Loading