Skip to content

Commit 9674c56

Browse files
authored
refactor: fix group_by compiler with dtype convertions (#2350)
This change can resolve the `bigframes.ml.metrics.roc_auc_score` doctests failures in #2248. Fixes internal issue 417774347 🦕
1 parent 7efdda8 commit 9674c56

File tree

2 files changed

+60
-9
lines changed

2 files changed

+60
-9
lines changed

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
from bigframes.core import utils, window_spec
2121
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
22+
import bigframes.core.expression as ex
2223
import bigframes.core.ordering as ordering_spec
24+
import bigframes.dtypes as dtypes
2325

2426

2527
def apply_window_if_present(
@@ -52,10 +54,7 @@ def apply_window_if_present(
5254
order = sge.Order(expressions=order_by)
5355

5456
group_by = (
55-
[
56-
scalar_compiler.scalar_op_compiler.compile_expression(key)
57-
for key in window.grouping_keys
58-
]
57+
[_compile_group_by_key(key) for key in window.grouping_keys]
5958
if window.grouping_keys
6059
else None
6160
)
@@ -164,3 +163,18 @@ def _get_window_bounds(
164163

165164
side = "PRECEDING" if value < 0 else "FOLLOWING"
166165
return sge.convert(abs(value)), side
166+
167+
168+
def _compile_group_by_key(key: ex.Expression) -> sge.Expression:
169+
expr = scalar_compiler.scalar_op_compiler.compile_expression(key)
170+
# The group_by keys has been rewritten by bind_schema_to_node
171+
assert isinstance(key, ex.ResolvedDerefOp)
172+
173+
# Some types need to be converted to another type to enable groupby
174+
if key.dtype == dtypes.FLOAT_DTYPE:
175+
expr = sge.Cast(this=expr, to="STRING")
176+
elif key.dtype == dtypes.GEO_DTYPE:
177+
expr = sge.Cast(this=expr, to="BYTES")
178+
elif key.dtype == dtypes.JSON_DTYPE:
179+
expr = sge.func("TO_JSON_STRING", expr)
180+
return expr

tests/unit/core/compile/sqlglot/aggregations/test_windows.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import pytest
1919
import sqlglot.expressions as sge
2020

21+
from bigframes import dtypes
2122
from bigframes.core import window_spec
2223
from bigframes.core.compile.sqlglot.aggregations.windows import (
2324
apply_window_if_present,
2425
get_window_order_by,
2526
)
2627
import bigframes.core.expression as ex
28+
import bigframes.core.identifiers as ids
2729
import bigframes.core.ordering as ordering
2830

2931

@@ -82,16 +84,37 @@ def test_apply_window_if_present_row_bounded_no_ordering_raises(self):
8284
),
8385
)
8486

85-
def test_apply_window_if_present_unbounded_grouping_no_ordering(self):
87+
def test_apply_window_if_present_grouping_no_ordering(self):
8688
result = apply_window_if_present(
8789
sge.Var(this="value"),
8890
window_spec.WindowSpec(
89-
grouping_keys=(ex.deref("col1"),),
91+
grouping_keys=(
92+
ex.ResolvedDerefOp(
93+
ids.ColumnId("col1"),
94+
dtype=dtypes.STRING_DTYPE,
95+
is_nullable=True,
96+
),
97+
ex.ResolvedDerefOp(
98+
ids.ColumnId("col2"),
99+
dtype=dtypes.FLOAT_DTYPE,
100+
is_nullable=True,
101+
),
102+
ex.ResolvedDerefOp(
103+
ids.ColumnId("col3"),
104+
dtype=dtypes.JSON_DTYPE,
105+
is_nullable=True,
106+
),
107+
ex.ResolvedDerefOp(
108+
ids.ColumnId("col4"),
109+
dtype=dtypes.GEO_DTYPE,
110+
is_nullable=True,
111+
),
112+
),
90113
),
91114
)
92115
self.assertEqual(
93116
result.sql(dialect="bigquery"),
94-
"value OVER (PARTITION BY `col1`)",
117+
"value OVER (PARTITION BY `col1`, CAST(`col2` AS STRING), TO_JSON_STRING(`col3`), CAST(`col4` AS BYTES))",
95118
)
96119

97120
def test_apply_window_if_present_range_bounded(self):
@@ -126,8 +149,22 @@ def test_apply_window_if_present_all_params(self):
126149
result = apply_window_if_present(
127150
sge.Var(this="value"),
128151
window_spec.WindowSpec(
129-
grouping_keys=(ex.deref("col1"),),
130-
ordering=(ordering.OrderingExpression(ex.deref("col2")),),
152+
grouping_keys=(
153+
ex.ResolvedDerefOp(
154+
ids.ColumnId("col1"),
155+
dtype=dtypes.STRING_DTYPE,
156+
is_nullable=True,
157+
),
158+
),
159+
ordering=(
160+
ordering.OrderingExpression(
161+
ex.ResolvedDerefOp(
162+
ids.ColumnId("col2"),
163+
dtype=dtypes.STRING_DTYPE,
164+
is_nullable=True,
165+
)
166+
),
167+
),
131168
bounds=window_spec.RowsWindowBounds(start=-1, end=0),
132169
),
133170
)

0 commit comments

Comments
 (0)