Skip to content

Commit da9b136

Browse files
fix: Improve escaping of literals and identifiers (#682)
1 parent 96243f2 commit da9b136

File tree

7 files changed

+178
-92
lines changed

7 files changed

+178
-92
lines changed

bigframes/core/blocks.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -2096,7 +2096,7 @@ def _get_rows_as_json_values(self) -> Block:
20962096
)
20972097

20982098
column_names.append(serialized_column_name)
2099-
column_names_csv = sql.csv(column_names, quoted=True)
2099+
column_names_csv = sql.csv(map(sql.simple_literal, column_names))
21002100

21012101
# index columns count
21022102
index_columns_count = len(self.index_columns)
@@ -2108,22 +2108,22 @@ def _get_rows_as_json_values(self) -> Block:
21082108

21092109
# types of the columns to serialize for the row
21102110
column_types = list(self.index.dtypes) + list(self.dtypes)
2111-
column_types_csv = sql.csv([str(typ) for typ in column_types], quoted=True)
2111+
column_types_csv = sql.csv(
2112+
[sql.simple_literal(str(typ)) for typ in column_types]
2113+
)
21122114

21132115
# row dtype to use for deserializing the row as pandas series
21142116
pandas_row_dtype = bigframes.dtypes.lcd_type(*column_types)
21152117
if pandas_row_dtype is None:
21162118
pandas_row_dtype = "object"
2117-
pandas_row_dtype = sql.quote(str(pandas_row_dtype))
2119+
pandas_row_dtype = sql.simple_literal(str(pandas_row_dtype))
21182120

21192121
# create a json column representing row through SQL manipulation
21202122
row_json_column_name = guid.generate_guid()
21212123
select_columns = (
21222124
[ordering_column_name] + list(self.index_columns) + [row_json_column_name]
21232125
)
2124-
select_columns_csv = sql.csv(
2125-
[sql.column_reference(col) for col in select_columns]
2126-
)
2126+
select_columns_csv = sql.csv([sql.identifier(col) for col in select_columns])
21272127
json_sql = f"""\
21282128
With T0 AS (
21292129
{textwrap.indent(expr_sql, " ")}
@@ -2136,7 +2136,7 @@ def _get_rows_as_json_values(self) -> Block:
21362136
"values", [{column_references_csv}],
21372137
"indexlength", {index_columns_count},
21382138
"dtype", {pandas_row_dtype}
2139-
) AS {row_json_column_name} FROM T0
2139+
) AS {sql.identifier(row_json_column_name)} FROM T0
21402140
)
21412141
SELECT {select_columns_csv} FROM T1
21422142
"""

bigframes/core/compile/compiled.py

+5-23
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
import abc
1717
import functools
1818
import itertools
19-
import textwrap
2019
import typing
21-
from typing import Collection, Iterable, Literal, Optional, Sequence
20+
from typing import Collection, Literal, Optional, Sequence
2221

2322
import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops
2423
import ibis
@@ -40,6 +39,7 @@
4039
OrderingExpression,
4140
)
4241
import bigframes.core.schema as schemata
42+
import bigframes.core.sql
4343
from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
4444
import bigframes.dtypes
4545
import bigframes.operations.aggregations as agg_ops
@@ -821,15 +821,13 @@ def to_sql(
821821
)
822822
)
823823
output_columns = [
824-
col_id_overrides.get(col) if (col in col_id_overrides) else col
825-
for col in baked_ir.column_ids
824+
col_id_overrides.get(col, col) for col in baked_ir.column_ids
826825
]
827-
selection = ", ".join(map(lambda col_id: f"`{col_id}`", output_columns))
826+
sql = bigframes.core.sql.select_from(output_columns, sql)
828827

829-
sql = textwrap.dedent(f"SELECT {selection}\n" "FROM (\n" f"{sql}\n" ")\n")
830828
# Single row frames may not have any ordering columns
831829
if len(baked_ir._ordering.all_ordering_columns) > 0:
832-
order_by_clause = baked_ir._ordering_clause(
830+
order_by_clause = bigframes.core.sql.ordering_clause(
833831
baked_ir._ordering.all_ordering_columns
834832
)
835833
sql += f"{order_by_clause}\n"
@@ -843,22 +841,6 @@ def to_sql(
843841
)
844842
return typing.cast(str, sql)
845843

846-
def _ordering_clause(self, ordering: Iterable[OrderingExpression]) -> str:
847-
parts = []
848-
for col_ref in ordering:
849-
asc_desc = "ASC" if col_ref.direction.is_ascending else "DESC"
850-
null_clause = "NULLS LAST" if col_ref.na_last else "NULLS FIRST"
851-
ordering_expr = col_ref.scalar_expression
852-
# We don't know how to compile scalar expressions in isolation
853-
if ordering_expr.is_const:
854-
# Probably shouldn't have constants in ordering definition, but best to ignore if somehow they end up here.
855-
continue
856-
if not isinstance(ordering_expr, ex.UnboundVariableExpression):
857-
raise ValueError("Expected direct column reference.")
858-
part = f"`{ordering_expr.id}` {asc_desc} {null_clause}"
859-
parts.append(part)
860-
return f"ORDER BY {' ,'.join(parts)}"
861-
862844
def _to_ibis_expr(
863845
self,
864846
*,

bigframes/core/sql.py

+129-26
Original file line numberDiff line numberDiff line change
@@ -11,49 +11,152 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
"""
1617
Utility functions for SQL construction.
1718
"""
1819

19-
from typing import Iterable
20+
import datetime
21+
import math
22+
import textwrap
23+
from typing import Iterable, TYPE_CHECKING
2024

25+
# Literals and identifiers matching this pattern can be unquoted
26+
unquoted = r"^[A-Za-z_][A-Za-z_0-9]*$"
2127

22-
def quote(value: str):
23-
"""Return quoted input string."""
2428

25-
# Let's use repr which also escapes any special characters
26-
#
27-
# >>> for val in [
28-
# ... "123",
29-
# ... "str with no special chars",
30-
# ... "str with special chars.,'\"/\\"
31-
# ... ]:
32-
# ... print(f"{val} -> {repr(val)}")
33-
# ...
34-
# 123 -> '123'
35-
# str with no special chars -> 'str with no special chars'
36-
# str with special chars.,'"/\ -> 'str with special chars.,\'"/\\'
29+
if TYPE_CHECKING:
30+
import google.cloud.bigquery as bigquery
3731

38-
return repr(value)
32+
import bigframes.core.ordering
3933

4034

41-
def column_reference(column_name: str):
35+
### Writing SQL Values (literals, column references, table references, etc.)
36+
def simple_literal(value: str | int | bool | float):
37+
"""Return quoted input string."""
38+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
39+
if isinstance(value, str):
40+
# Single quoting seems to work nicer with ibis than double quoting
41+
return f"'{escape_special_characters(value)}'"
42+
elif isinstance(value, (bool, int)):
43+
return str(value)
44+
elif isinstance(value, float):
45+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals
46+
if math.isnan(value):
47+
return 'CAST("nan" as FLOAT)'
48+
if value == math.inf:
49+
return 'CAST("+inf" as FLOAT)'
50+
if value == -math.inf:
51+
return 'CAST("-inf" as FLOAT)'
52+
return str(value)
53+
else:
54+
raise ValueError(f"Cannot produce literal for {value}")
55+
56+
57+
def multi_literal(*values: str):
58+
literal_strings = [simple_literal(i) for i in values]
59+
return "(" + ", ".join(literal_strings) + ")"
60+
61+
62+
def identifier(id: str) -> str:
4263
"""Return a string representing column reference in a SQL."""
64+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers
65+
# Just always escape, otherwise need to check against every reserved sql keyword
66+
return f"`{escape_special_characters(id)}`"
67+
68+
69+
def escape_special_characters(value: str):
70+
"""Escapes all special charactesrs"""
71+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals
72+
trans_table = str.maketrans(
73+
{
74+
"\a": r"\a",
75+
"\b": r"\b",
76+
"\f": r"\f",
77+
"\n": r"\n",
78+
"\r": r"\r",
79+
"\t": r"\t",
80+
"\v": r"\v",
81+
"\\": r"\\",
82+
"?": r"\?",
83+
'"': r"\"",
84+
"'": r"\'",
85+
"`": r"\`",
86+
}
87+
)
88+
return value.translate(trans_table)
89+
90+
91+
def cast_as_string(column_name: str) -> str:
92+
"""Return a string representing string casting of a column."""
4393

44-
return f"`{column_name}`"
94+
return f"CAST({identifier(column_name)} AS STRING)"
4595

4696

47-
def cast_as_string(column_name: str):
48-
"""Return a string representing string casting of a column."""
97+
def csv(values: Iterable[str]) -> str:
98+
"""Return a string of comma separated values."""
99+
return ", ".join(values)
49100

50-
return f"CAST({column_reference(column_name)} AS STRING)"
51101

102+
def table_reference(table_ref: bigquery.TableReference) -> str:
103+
return f"`{escape_special_characters(table_ref.project)}`.`{escape_special_characters(table_ref.dataset_id)}`.`{escape_special_characters(table_ref.table_id)}`"
52104

53-
def csv(values: Iterable[str], quoted=False):
54-
"""Return a string of comma separated values."""
55105

56-
if quoted:
57-
values = [quote(val) for val in values]
106+
def infix_op(opname: str, left_arg: str, right_arg: str):
107+
# Maybe should add parentheses??
108+
return f"{left_arg} {opname} {right_arg}"
58109

59-
return ", ".join(values)
110+
111+
### Writing SELECT expressions
112+
def select_from(columns: Iterable[str], subquery: str, distinct: bool = False):
113+
selection = ", ".join(map(identifier, columns))
114+
distinct_clause = "DISTINCT " if distinct else ""
115+
116+
return textwrap.dedent(
117+
f"SELECT {distinct_clause}{selection}\nFROM (\n" f"{subquery}\n" ")\n"
118+
)
119+
120+
121+
def select_table(table_ref: bigquery.TableReference):
122+
return textwrap.dedent(f"SELECT * FROM {table_reference(table_ref)}")
123+
124+
125+
def is_distinct_sql(columns: Iterable[str], table_sql: str) -> str:
126+
is_unique_sql = f"""WITH full_table AS (
127+
{select_from(columns, table_sql)}
128+
),
129+
distinct_table AS (
130+
{select_from(columns, table_sql, distinct=True)}
131+
)
132+
133+
SELECT (SELECT COUNT(*) FROM full_table) AS `total_count`,
134+
(SELECT COUNT(*) FROM distinct_table) AS `distinct_count`
135+
"""
136+
return is_unique_sql
137+
138+
139+
def ordering_clause(
140+
ordering: Iterable[bigframes.core.ordering.OrderingExpression],
141+
) -> str:
142+
import bigframes.core.expression
143+
144+
parts = []
145+
for col_ref in ordering:
146+
asc_desc = "ASC" if col_ref.direction.is_ascending else "DESC"
147+
null_clause = "NULLS LAST" if col_ref.na_last else "NULLS FIRST"
148+
ordering_expr = col_ref.scalar_expression
149+
# We don't know how to compile scalar expressions in isolation
150+
if ordering_expr.is_const:
151+
# Probably shouldn't have constants in ordering definition, but best to ignore if somehow they end up here.
152+
continue
153+
assert isinstance(
154+
ordering_expr, bigframes.core.expression.UnboundVariableExpression
155+
)
156+
part = f"`{ordering_expr.id}` {asc_desc} {null_clause}"
157+
parts.append(part)
158+
return f"ORDER BY {' ,'.join(parts)}"
159+
160+
161+
def snapshot_clause(time_travel_timestamp: datetime.datetime):
162+
return f"FOR SYSTEM_TIME AS OF TIMESTAMP({repr(time_travel_timestamp.isoformat())})"

bigframes/session/_io/bigquery/__init__.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import bigframes
3333
from bigframes.core import log_adapter
34+
import bigframes.core.sql
3435
import bigframes.formatting_helpers as formatting_helpers
3536

3637
IO_ORDERING_ID = "bqdf_row_nums"
@@ -353,7 +354,7 @@ def to_query(
353354
else:
354355
select_clause = "SELECT *"
355356

356-
where_clause = ""
357+
filter_string = ""
357358
if filters:
358359
valid_operators: Mapping[third_party_pandas_gbq.FilterOps, str] = {
359360
"in": "IN",
@@ -373,12 +374,11 @@ def to_query(
373374
):
374375
filters = typing.cast(third_party_pandas_gbq.FiltersType, [filters])
375376

376-
or_expressions = []
377377
for group in filters:
378378
if not isinstance(group, Iterable):
379379
group = [group]
380380

381-
and_expressions = []
381+
and_expression = ""
382382
for filter_item in group:
383383
if not isinstance(filter_item, tuple) or (len(filter_item) != 3):
384384
raise ValueError(
@@ -397,17 +397,29 @@ def to_query(
397397

398398
operator_str = valid_operators[operator]
399399

400+
column_ref = bigframes.core.sql.identifier(column)
400401
if operator_str in ["IN", "NOT IN"]:
401-
value_list = ", ".join([repr(v) for v in value])
402-
expression = f"`{column}` {operator_str} ({value_list})"
402+
value_literal = bigframes.core.sql.multi_literal(*value)
403403
else:
404-
expression = f"`{column}` {operator_str} {repr(value)}"
405-
and_expressions.append(expression)
406-
407-
or_expressions.append(" AND ".join(and_expressions))
404+
value_literal = bigframes.core.sql.simple_literal(value)
405+
expression = bigframes.core.sql.infix_op(
406+
operator_str, column_ref, value_literal
407+
)
408+
if and_expression:
409+
and_expression = bigframes.core.sql.infix_op(
410+
"AND", and_expression, expression
411+
)
412+
else:
413+
and_expression = expression
408414

409-
if or_expressions:
410-
where_clause = " WHERE " + " OR ".join(or_expressions)
415+
if filter_string:
416+
filter_string = bigframes.core.sql.infix_op(
417+
"OR", filter_string, and_expression
418+
)
419+
else:
420+
filter_string = and_expression
411421

412-
full_query = f"{select_clause} FROM {sub_query} AS sub{where_clause}"
413-
return full_query
422+
if filter_string:
423+
return f"{select_clause} FROM {sub_query} AS sub WHERE {filter_string}"
424+
else:
425+
return f"{select_clause} FROM {sub_query} AS sub"

0 commit comments

Comments
 (0)