Skip to content

Commit 44b738d

Browse files
perf: Cache transpose to allow performant retranspose (#635)
1 parent 3ffc1d2 commit 44b738d

File tree

5 files changed

+105
-55
lines changed

bigframes/core/blocks.py

+58-22
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
import os
2828
import random
2929
import typing
30-
from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple
30+
from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
3131
import warnings
3232

3333
import google.cloud.bigquery as bigquery
@@ -105,6 +105,8 @@ def __init__(
105105
index_columns: Iterable[str],
106106
column_labels: typing.Union[pd.Index, typing.Iterable[Label]],
107107
index_labels: typing.Union[pd.Index, typing.Iterable[Label], None] = None,
108+
*,
109+
transpose_cache: Optional[Block] = None,
108110
):
109111
"""Construct a block object, will create default index if no index columns specified."""
110112
index_columns = list(index_columns)
@@ -144,6 +146,7 @@ def __init__(
144146
# TODO(kemppeterson) Add a cache for corr to parallel the single-column stats.
145147

146148
self._stats_cache[" ".join(self.index_columns)] = {}
149+
self._transpose_cache: Optional[Block] = transpose_cache
147150

148151
@classmethod
149152
def from_local(cls, data: pd.DataFrame, session: bigframes.Session) -> Block:
@@ -716,6 +719,15 @@ def with_column_labels(
716719
index_labels=self.index.names,
717720
)
718721

722+
def with_transpose_cache(self, transposed: Block):
723+
return Block(
724+
self._expr,
725+
index_columns=self.index_columns,
726+
column_labels=self._column_labels,
727+
index_labels=self.index.names,
728+
transpose_cache=transposed,
729+
)
730+
719731
def with_index_labels(self, value: typing.Sequence[Label]) -> Block:
720732
if len(value) != len(self.index_columns):
721733
raise ValueError(
@@ -804,18 +816,35 @@ def multi_apply_window_op(
804816
def multi_apply_unary_op(
805817
self,
806818
columns: typing.Sequence[str],
807-
op: ops.UnaryOp,
819+
op: Union[ops.UnaryOp, ex.Expression],
808820
) -> Block:
821+
if isinstance(op, ops.UnaryOp):
822+
input_varname = guid.generate_guid()
823+
expr = op.as_expr(input_varname)
824+
else:
825+
input_varnames = op.unbound_variables
826+
assert len(input_varnames) == 1
827+
expr = op
828+
input_varname = input_varnames[0]
829+
809830
block = self
810-
for i, col_id in enumerate(columns):
831+
for col_id in columns:
811832
label = self.col_id_to_label[col_id]
812-
block, result_id = block.apply_unary_op(
813-
col_id,
814-
op,
815-
result_label=label,
833+
block, result_id = block.project_expr(
834+
expr.bind_all_variables({input_varname: ex.free_var(col_id)}),
835+
label=label,
816836
)
817837
block = block.copy_values(result_id, col_id)
818838
block = block.drop_columns([result_id])
839+
# Special case, we can preserve transpose cache for full-frame unary ops
840+
if (self._transpose_cache is not None) and set(self.value_columns) == set(
841+
columns
842+
):
843+
transpose_columns = self._transpose_cache.value_columns
844+
new_transpose_cache = self._transpose_cache.multi_apply_unary_op(
845+
transpose_columns, op
846+
)
847+
block = block.with_transpose_cache(new_transpose_cache)
819848
return block
820849

821850
def apply_window_op(
@@ -922,20 +951,17 @@ def aggregate_all_and_stack(
922951
(ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
923952
for col_id in self.value_columns
924953
]
925-
index_col_ids = [
926-
guid.generate_guid() for i in range(self.column_labels.nlevels)
927-
]
928-
result_expr = self.expr.aggregate(aggregations, dropna=dropna).unpivot(
929-
row_labels=self.column_labels.to_list(),
930-
index_col_ids=index_col_ids,
931-
unpivot_columns=tuple([(value_col_id, tuple(self.value_columns))]),
932-
)
954+
index_id = guid.generate_guid()
955+
result_expr = self.expr.aggregate(
956+
aggregations, dropna=dropna
957+
).assign_constant(index_id, None, None)
958+
# Transpose as last operation so that final block has valid transpose cache
933959
return Block(
934960
result_expr,
935-
index_columns=index_col_ids,
936-
column_labels=[None],
937-
index_labels=self.column_labels.names,
938-
)
961+
index_columns=[index_id],
962+
column_labels=self.column_labels,
963+
index_labels=[None],
964+
).transpose(original_row_index=pd.Index([None]))
939965
else: # axis_n == 1
940966
# using offsets as identity to group on.
941967
# TODO: Allow to promote identity/total_order columns instead for better perf
@@ -1575,10 +1601,19 @@ def melt(
15751601
index_columns=[index_id],
15761602
)
15771603

1578-
def transpose(self) -> Block:
1579-
"""Transpose the block. Will fail if dtypes aren't coercible to a common type or too many rows"""
1604+
def transpose(self, *, original_row_index: Optional[pd.Index] = None) -> Block:
1605+
"""Transpose the block. Will fail if dtypes aren't coercible to a common type or too many rows.
1606+
Can provide the original_row_index directly if it is already known, otherwise a query is needed.
1607+
"""
1608+
if self._transpose_cache is not None:
1609+
return self._transpose_cache.with_transpose_cache(self)
1610+
15801611
original_col_index = self.column_labels
1581-
original_row_index = self.index.to_pandas()
1612+
original_row_index = (
1613+
original_row_index
1614+
if original_row_index is not None
1615+
else self.index.to_pandas()
1616+
)
15821617
original_row_count = len(original_row_index)
15831618
if original_row_count > bigframes.constants.MAX_COLUMNS:
15841619
raise NotImplementedError(
@@ -1619,6 +1654,7 @@ def transpose(self) -> Block:
16191654
result.with_column_labels(original_row_index)
16201655
.order_by([ordering.ascending_over(result.index_columns[-1])])
16211656
.drop_levels([result.index_columns[-1]])
1657+
.with_transpose_cache(self)
16221658
)
16231659

16241660
def _create_stack_column(

bigframes/core/compile/compiled.py

+7-10
Original file line number | Diff line number | Diff line change
@@ -823,17 +823,14 @@ def to_sql(
823823
for col in baked_ir.column_ids
824824
]
825825
selection = ", ".join(map(lambda col_id: f"`{col_id}`", output_columns))
826-
order_by_clause = baked_ir._ordering_clause(
827-
baked_ir._ordering.all_ordering_columns
828-
)
829826

830-
sql = textwrap.dedent(
831-
f"SELECT {selection}\n"
832-
"FROM (\n"
833-
f"{sql}\n"
834-
")\n"
835-
f"{order_by_clause}\n"
836-
)
827+
sql = textwrap.dedent(f"SELECT {selection}\n" "FROM (\n" f"{sql}\n" ")\n")
828+
# Single row frames may not have any ordering columns
829+
if len(baked_ir._ordering.all_ordering_columns) > 0:
830+
order_by_clause = baked_ir._ordering_clause(
831+
baked_ir._ordering.all_ordering_columns
832+
)
833+
sql += f"{order_by_clause}\n"
837834
else:
838835
sql = ibis_bigquery.Backend().compile(
839836
self._to_ibis_expr(

bigframes/core/ordering.py

+2
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,8 @@ def _truncate_ordering(
167167
truncated_refs.append(order_part)
168168
if columns_seen.issuperset(must_see):
169169
return tuple(truncated_refs)
170+
if len(must_see) == 0:
171+
return ()
170172
raise ValueError("Ordering did not contain all total_order_cols")
171173

172174
def with_reverse(self):

bigframes/dataframe.py

+24-23
Original file line number | Diff line number | Diff line change
@@ -693,18 +693,19 @@ def _apply_binop(
693693
def _apply_scalar_binop(
694694
self, other: float | int, op: ops.BinaryOp, reverse: bool = False
695695
) -> DataFrame:
696-
block = self._block
697-
for column_id, label in zip(
698-
self._block.value_columns, self._block.column_labels
699-
):
700-
expr = (
701-
op.as_expr(ex.const(other), column_id)
702-
if reverse
703-
else op.as_expr(column_id, ex.const(other))
696+
if reverse:
697+
expr = op.as_expr(
698+
left_input=ex.const(other),
699+
right_input=bigframes.core.guid.generate_guid(),
704700
)
705-
block, _ = block.project_expr(expr, label)
706-
block = block.drop_columns([column_id])
707-
return DataFrame(block)
701+
else:
702+
expr = op.as_expr(
703+
left_input=bigframes.core.guid.generate_guid(),
704+
right_input=ex.const(other),
705+
)
706+
return DataFrame(
707+
self._block.multi_apply_unary_op(self._block.value_columns, expr)
708+
)
708709

709710
def _apply_series_binop_axis_0(
710711
self,
@@ -1974,7 +1975,7 @@ def any(
19741975
else:
19751976
frame = self._drop_non_bool()
19761977
block = frame._block.aggregate_all_and_stack(agg_ops.any_op, axis=axis)
1977-
return bigframes.series.Series(block.select_column("values"))
1978+
return bigframes.series.Series(block)
19781979

19791980
def all(
19801981
self, axis: typing.Union[str, int] = 0, *, bool_only: bool = False
@@ -1984,7 +1985,7 @@ def all(
19841985
else:
19851986
frame = self._drop_non_bool()
19861987
block = frame._block.aggregate_all_and_stack(agg_ops.all_op, axis=axis)
1987-
return bigframes.series.Series(block.select_column("values"))
1988+
return bigframes.series.Series(block)
19881989

19891990
def sum(
19901991
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -1994,7 +1995,7 @@ def sum(
19941995
else:
19951996
frame = self._drop_non_numeric()
19961997
block = frame._block.aggregate_all_and_stack(agg_ops.sum_op, axis=axis)
1997-
return bigframes.series.Series(block.select_column("values"))
1998+
return bigframes.series.Series(block)
19981999

19992000
def mean(
20002001
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -2004,7 +2005,7 @@ def mean(
20042005
else:
20052006
frame = self._drop_non_numeric()
20062007
block = frame._block.aggregate_all_and_stack(agg_ops.mean_op, axis=axis)
2007-
return bigframes.series.Series(block.select_column("values"))
2008+
return bigframes.series.Series(block)
20082009

20092010
def median(
20102011
self, *, numeric_only: bool = False, exact: bool = True
@@ -2019,7 +2020,7 @@ def median(
20192020
return result
20202021
else:
20212022
block = frame._block.aggregate_all_and_stack(agg_ops.median_op)
2022-
return bigframes.series.Series(block.select_column("values"))
2023+
return bigframes.series.Series(block)
20232024

20242025
def quantile(
20252026
self, q: Union[float, Sequence[float]] = 0.5, *, numeric_only: bool = False
@@ -2052,7 +2053,7 @@ def std(
20522053
else:
20532054
frame = self._drop_non_numeric()
20542055
block = frame._block.aggregate_all_and_stack(agg_ops.std_op, axis=axis)
2055-
return bigframes.series.Series(block.select_column("values"))
2056+
return bigframes.series.Series(block)
20562057

20572058
def var(
20582059
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -2062,7 +2063,7 @@ def var(
20622063
else:
20632064
frame = self._drop_non_numeric()
20642065
block = frame._block.aggregate_all_and_stack(agg_ops.var_op, axis=axis)
2065-
return bigframes.series.Series(block.select_column("values"))
2066+
return bigframes.series.Series(block)
20662067

20672068
def min(
20682069
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -2072,7 +2073,7 @@ def min(
20722073
else:
20732074
frame = self._drop_non_numeric()
20742075
block = frame._block.aggregate_all_and_stack(agg_ops.min_op, axis=axis)
2075-
return bigframes.series.Series(block.select_column("values"))
2076+
return bigframes.series.Series(block)
20762077

20772078
def max(
20782079
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -2082,7 +2083,7 @@ def max(
20822083
else:
20832084
frame = self._drop_non_numeric()
20842085
block = frame._block.aggregate_all_and_stack(agg_ops.max_op, axis=axis)
2085-
return bigframes.series.Series(block.select_column("values"))
2086+
return bigframes.series.Series(block)
20862087

20872088
def prod(
20882089
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
@@ -2092,7 +2093,7 @@ def prod(
20922093
else:
20932094
frame = self._drop_non_numeric()
20942095
block = frame._block.aggregate_all_and_stack(agg_ops.product_op, axis=axis)
2095-
return bigframes.series.Series(block.select_column("values"))
2096+
return bigframes.series.Series(block)
20962097

20972098
product = prod
20982099
product.__doc__ = inspect.getdoc(vendored_pandas_frame.DataFrame.prod)
@@ -2103,11 +2104,11 @@ def count(self, *, numeric_only: bool = False) -> bigframes.series.Series:
21032104
else:
21042105
frame = self._drop_non_numeric()
21052106
block = frame._block.aggregate_all_and_stack(agg_ops.count_op)
2106-
return bigframes.series.Series(block.select_column("values"))
2107+
return bigframes.series.Series(block)
21072108

21082109
def nunique(self) -> bigframes.series.Series:
21092110
block = self._block.aggregate_all_and_stack(agg_ops.nunique_op)
2110-
return bigframes.series.Series(block.select_column("values"))
2111+
return bigframes.series.Series(block)
21112112

21122113
def agg(
21132114
self, func: str | typing.Sequence[str]

tests/system/small/test_dataframe.py

+14
Original file line number | Diff line number | Diff line change
@@ -2524,6 +2524,20 @@ def test_df_transpose_error():
25242524
dataframe.DataFrame([[1, "hello"], [2, "world"]]).transpose()
25252525

25262526

2527+
def test_df_transpose_repeated_uses_cache():
2528+
bf_df = dataframe.DataFrame([[1, 2.5], [2, 3.5]])
2529+
pd_df = pandas.DataFrame([[1, 2.5], [2, 3.5]])
2530+
# Transposing many times so that operation will fail from complexity if not using cache
2531+
for i in range(10):
2532+
# Cache still works even with simple scalar binop
2533+
bf_df = bf_df.transpose() + i
2534+
pd_df = pd_df.transpose() + i
2535+
2536+
pd.testing.assert_frame_equal(
2537+
pd_df, bf_df.to_pandas(), check_dtype=False, check_index_type=False
2538+
)
2539+
2540+
25272541
@pytest.mark.parametrize(
25282542
("ordered"),
25292543
[

0 commit comments

Comments (0)