Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 233 additions & 22 deletions bigframes/core/compile/sqlglot/expressions/datetime_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,14 @@
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
from bigframes.core.compile.sqlglot import sqlglot_types
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resampling origin, in microseconds.

    Args:
        y: Typed expression the origin is derived from for the
            ``"start_day"`` and ``"start"`` origins.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        A literal ``0`` for ``"epoch"``; otherwise a ``UNIX_MICROS(...)`` call
        over ``y`` cast to TIMESTAMPTZ (via DATE first for ``"start_day"``).

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)

    if origin == "start_day":
        # Casting through DATE drops the time-of-day part before converting
        # back to a timestamp.
        as_date = sge.Cast(
            this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE)
        )
        day_start = sge.Cast(
            this=as_date, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
        )
        return sge.func("UNIX_MICROS", day_start)

    if origin == "start":
        as_timestamp = sge.Cast(
            this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
        )
        return sge.func("UNIX_MICROS", as_timestamp)

    raise ValueError(f"Origin {origin} not supported")


@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True)
def datetime_to_integer_label_op(
x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
Expand Down Expand Up @@ -317,6 +296,20 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resampling origin, in microseconds.

    Args:
        y: Typed expression the origin is derived from for the
            ``"start_day"`` and ``"start"`` origins.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        A literal ``0`` for ``"epoch"``; otherwise a ``UNIX_MICROS(...)`` call
        over ``y`` cast to TIMESTAMP (via DATE first for ``"start_day"``).

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)

    if origin == "start_day":
        # Casting through DATE drops the time-of-day part before converting
        # back to a timestamp.
        as_date = sge.Cast(this=y.expr, to="DATE")
        day_start = sge.Cast(this=as_date, to="TIMESTAMP")
        return sge.func("UNIX_MICROS", day_start)

    if origin == "start":
        as_timestamp = sge.Cast(this=y.expr, to="TIMESTAMP")
        return sge.func("UNIX_MICROS", as_timestamp)

    raise ValueError(f"Origin {origin} not supported")


@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.hour_op`` into an EXTRACT of the HOUR part of ``expr``."""
    hour_part = sge.Identifier(this="HOUR")
    return sge.Extract(this=hour_part, expression=expr.expr)
Expand Down Expand Up @@ -436,3 +429,221 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression:
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.year_op`` into an EXTRACT of the YEAR part of ``expr``."""
    year_part = sge.Identifier(this="YEAR")
    return sge.Extract(this=year_part, expression=expr.expr)


@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile ``IntegerLabelToDatetimeOp`` into a datetime label expression.

    Args:
        x: Integer label expression.
        y: Datetime expression used to derive the resampling origin and the
            output type.
        op: The operation, carrying ``freq`` and ``origin``.

    Returns:
        The expression built by the fixed-frequency compiler when
        ``op.freq.nanos`` is available, otherwise the one built by the
        non-fixed-frequency compiler.

    Raises:
        ValueError: Propagated from whichever compiler is selected (for
            example an unsupported origin or an unsupported rule code).
    """
    # Determine if the frequency is fixed by checking if 'op.freq.nanos' is
    # defined. Only the probe is guarded: wrapping the whole fixed-frequency
    # build would also swallow unrelated ValueErrors (such as an unsupported
    # origin) and re-raise them from the non-fixed path with a misleading
    # message.
    try:
        _ = op.freq.nanos
    except ValueError:
        return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op)
    return _integer_label_to_datetime_op_fixed_frequency(x, y, op)


def _integer_label_to_datetime_op_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Build the datetime label for a fixed frequency (microseconds to days).

    The label is ``TIMESTAMP_MICROS(x * step_us + origin_us)`` cast to the
    type of ``y``, where ``step_us`` is ``op.freq.nanos`` expressed in
    microseconds and ``origin_us`` comes from ``_calculate_resample_first``.

    Args:
        x: Integer label expression.
        y: Datetime expression supplying the origin and the output type.
        op: The operation, carrying ``freq`` and ``origin``.

    Returns:
        The label expression cast to the SQL type matching ``y.dtype``.
    """
    step_us = op.freq.nanos / 1000
    origin_us = _calculate_resample_first(y, op.origin)  # type: ignore

    # The arithmetic is done in BIGNUMERIC and cast back to INT64 afterwards
    # (presumably to avoid INT64 overflow in the intermediate product).
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
        expression=sge.convert(int(step_us)),
    )
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_us, to="BIGNUMERIC"),
    )
    label_timestamp = sge.func(
        "TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64")
    )
    return sge.Cast(
        this=label_timestamp, to=sqlglot_types.from_bigframes_dtype(y.dtype)
    )


def _integer_label_to_datetime_op_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    This function handles non-fixed frequency conversions for units ranging
    from weeks to years.

    It dispatches on ``op.freq.rule_code`` to a per-frequency helper and casts
    the helper's result to the SQL type matching ``y.dtype``.

    Args:
        x: Integer label expression.
        y: Datetime expression supplying the origin and the output type.
        op: The operation, carrying ``freq`` (``rule_code`` and ``n``).

    Returns:
        The label expression cast to the SQL type matching ``y.dtype``.

    Raises:
        ValueError: If ``op.freq.rule_code`` is not a supported rule code.
    """
    rule_code = op.freq.rule_code
    n = op.freq.n
    if rule_code == "W-SUN":  # Weekly
        x_label = _integer_label_to_datetime_op_weekly_freq(x, y, n)
    elif rule_code in ("ME", "M"):  # Monthly
        x_label = _integer_label_to_datetime_op_monthly_freq(x, y, n)
    elif rule_code in ("QE-DEC", "Q-DEC"):  # Quarterly
        x_label = _integer_label_to_datetime_op_quarterly_freq(x, y, n)
    elif rule_code in ("YE-DEC", "A-DEC", "Y-DEC"):  # Yearly
        x_label = _integer_label_to_datetime_op_yearly_freq(x, y, n)
    else:
        raise ValueError(rule_code)
    return sge.Cast(this=x_label, to=sqlglot_types.from_bigframes_dtype(y.dtype))


def _datetime_at_month_start(
    year: sge.Expression, month: sge.Expression
) -> sge.Expression:
    """Build ``DATETIME(year, month, 1, 0, 0, 0)``."""
    return sge.Anonymous(
        this="DATETIME",
        expressions=[
            year,
            month,
            sge.convert(1),
            sge.convert(0),
            sge.convert(0),
            sge.convert(0),
        ],
    )


def _next_month_parts(
    year: sge.Expression, month: sge.Expression
) -> tuple[sge.Expression, sge.Expression]:
    """Build the (year, month) expressions of the month following ``month``.

    When ``month`` equals 12 the result rolls over to month 1 of ``year + 1``;
    otherwise it is ``month + 1`` of the same year.
    """
    one = sge.convert(1)
    twelve = sge.convert(12)
    next_year = sge.Case(
        ifs=[
            sge.If(
                this=sge.EQ(this=month, expression=twelve),
                true=sge.Add(this=year, expression=one),
            )
        ],
        default=year,
    )
    next_month = sge.Case(
        ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
        default=sge.Add(this=month, expression=one),
    )
    return next_year, next_month


def _integer_label_to_datetime_op_weekly_freq(
    x: TypedExpr, y: TypedExpr, n: int
) -> sge.Expression:
    """Build the label for the weekly ("W-SUN") frequency.

    The origin is ``y`` truncated to WEEK(MONDAY) plus 6 days, and the label
    is ``n`` weeks (in microseconds) times ``x`` added to that origin.
    """
    us = n * 7 * 24 * 60 * 60 * 1000000
    first = sge.func(
        "UNIX_MICROS",
        sge.Add(
            this=sge.TimestampTrunc(
                this=sge.Cast(this=y.expr, to="TIMESTAMP"),
                unit=sge.Var(this="WEEK(MONDAY)"),
            ),
            expression=sge.Interval(
                this=sge.convert(6), unit=sge.Identifier(this="DAY")
            ),
        ),
    )
    total_us = sge.Add(
        this=sge.Mul(
            this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
            expression=sge.convert(us),
        ),
        expression=sge.Cast(this=first, to="BIGNUMERIC"),
    )
    # NOTE(review): the caller casts to the same type again; the inner cast is
    # kept so the generated SQL is unchanged.
    return sge.Cast(
        this=sge.func("TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64")),
        to=sqlglot_types.from_bigframes_dtype(y.dtype),
    )


def _integer_label_to_datetime_op_monthly_freq(
    x: TypedExpr, y: TypedExpr, n: int
) -> sge.Expression:
    """Build the label for the monthly ("ME"/"M") frequency.

    Months are counted as ``year * 12 + month - 1``; the label is the first
    day of the month after the target month, minus one day.
    """
    one = sge.convert(1)
    twelve = sge.convert(12)
    first = sge.Sub(
        this=sge.Add(
            this=sge.Mul(
                this=sge.Extract(this="YEAR", expression=y.expr),
                expression=twelve,
            ),
            expression=sge.Extract(this="MONTH", expression=y.expr),
        ),
        expression=one,
    )
    x_val = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
    )
    year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)),
        to="INT64",
    )
    month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one)
    next_year, next_month = _next_month_parts(year, month)
    next_month_date = sge.func(
        "TIMESTAMP", _datetime_at_month_start(next_year, next_month)
    )
    return sge.Sub(
        this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
    )


def _integer_label_to_datetime_op_quarterly_freq(
    x: TypedExpr, y: TypedExpr, n: int
) -> sge.Expression:
    """Build the label for the quarterly ("QE-DEC"/"Q-DEC") frequency.

    Quarters are counted as ``year * 4 + quarter - 1``; the label is the first
    day of the month after the quarter's last month, minus one day.
    """
    one = sge.convert(1)
    three = sge.convert(3)
    four = sge.convert(4)
    first = sge.Sub(
        this=sge.Add(
            this=sge.Mul(
                this=sge.Extract(this="YEAR", expression=y.expr),
                expression=four,
            ),
            expression=sge.Extract(this="QUARTER", expression=y.expr),
        ),
        expression=one,
    )
    x_val = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
    )
    year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)),
        to="INT64",
    )
    # Last month of the quarter: (quarter_index + 1) * 3, i.e. 3, 6, 9 or 12.
    month = sge.Mul(
        this=sge.Paren(
            this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one)
        ),
        expression=three,
    )
    next_year, next_month = _next_month_parts(year, month)
    # NOTE(review): unlike the monthly and yearly helpers, the DATETIME is not
    # wrapped in TIMESTAMP(...) here; kept as-is so the generated SQL is
    # unchanged. Confirm whether the difference is intentional.
    next_month_date = _datetime_at_month_start(next_year, next_month)
    return sge.Sub(
        this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
    )


def _integer_label_to_datetime_op_yearly_freq(
    x: TypedExpr, y: TypedExpr, n: int
) -> sge.Expression:
    """Build the label for the yearly ("YE-DEC"/"A-DEC"/"Y-DEC") frequency.

    The label is January 1st of the year after the target year, minus one day.
    """
    one = sge.convert(1)
    first = sge.Extract(this="YEAR", expression=y.expr)
    x_val = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
    )
    next_year = sge.Add(this=x_val, expression=one)
    next_month_date = sge.func(
        "TIMESTAMP", _datetime_at_month_start(next_year, one)
    )
    return sge.Sub(
        this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`,
CAST(DATETIME(
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64) + 1
ELSE CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64)
END,
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN 1
ELSE (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 + 1
END,
1,
0,
0,
0
) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`,
`bfcol_3` AS `non_fixed_freq`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`
FROM `bfcte_1`
Loading
Loading