Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.
Merged
Prev Previous commit
Next Next commit
refactor based on pr comments
  • Loading branch information
TrevorBergeron committed Apr 15, 2024
commit dd8f9ff796f5b16495a16451c4fdf65c196a40c4
85 changes: 50 additions & 35 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,26 +364,49 @@ def unpivot(
unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
index_col_id (str): The column id to be used for the row labels.
dtype (dtype or list of dtype): Dtype to use for the unpivot columns. If list, must be equal in number to unpivot_columns.

Returns:
ArrayValue: The unpivoted ArrayValue
"""
# 1. construct array_value from row_labels, with offsets
unpivot_offset_id = bigframes.core.guid.generate_guid("unpivot_offsets_")
explode_offsets_id = bigframes.core.guid.generate_guid("unpivot_offsets_")
labels_array = self._create_labels_array(row_labels, index_col_ids)
labels_array = labels_array.promote_offsets(explode_offsets_id)

rows = []
for row_offset in range(len(row_labels)):
row_label = row_labels[row_offset]
row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
row = {index_col_ids[i]: row_label[i] for i in range(len(index_col_ids))}
rows.append(row)
# Unpivot creates N output rows for each input row, labels disambiguate these N rows
joined_array = self._cross_join_w_labels(labels_array, how)

labels_array = ArrayValue.from_pyarrow(
pa.Table.from_pylist(rows), session=self.session
).promote_offsets(unpivot_offset_id)
# Build the output rows as a switch statment that selects between the N input columns
unpivot_exprs = []
for col_id, input_ids in unpivot_columns:
# row explode offset used to choose the input column
cases = tuple(
(
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
ex.free_var(id_or_null)
if (id_or_null is not None)
else ex.const(None),
)
for i, id_or_null in enumerate(input_ids)
)
col_expr = ops.switch_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))

index_exprs = ((ex.free_var(id), id) for id in index_col_ids)
# passthrough columns are unchanged, just repeated N times each
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
return ArrayValue(
nodes.ProjectionNode(
child=joined_array.node,
assignments=(*index_exprs, *unpivot_exprs, *passthrough_exprs),
)
)

# 2. cross join labels_array with main table
def _cross_join_w_labels(
self, labels_array: ArrayValue, how: typing.Literal["left", "right"]
) -> ArrayValue:
"""
Convert each row in self to N rows, one for each label in labels array.
"""
table_join_side = (
join_def.JoinSide.LEFT if how == "left" else join_def.JoinSide.RIGHT
)
Expand All @@ -403,30 +426,22 @@ def unpivot(
joined_array = self.join(labels_array, join_def=join)
else:
joined_array = labels_array.join(self, join_def=join)
return joined_array

# 3. Build output columns by switching between input columns
unpivot_exprs = []
for col_id, input_ids in unpivot_columns:
cases = tuple(
(
ops.eq_op.as_expr(unpivot_offset_id, ex.const(i)),
ex.free_var(id_or_null)
if (id_or_null is not None)
else ex.const(None),
)
for i, id_or_null in enumerate(input_ids)
)
col_expr = ops.switch_op.as_expr(*cases)
unpivot_exprs.append((col_expr, col_id))
def _create_labels_array(
self,
row_labels: typing.Sequence[typing.Hashable],
col_ids: typing.Sequence[str],
) -> ArrayValue:
"""Create an ArrayValue from a list of label tuples."""
rows = []
for row_offset in range(len(row_labels)):
row_label = row_labels[row_offset]
row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
row = {col_ids[i]: row_label[i] for i in range(len(col_ids))}
rows.append(row)

index_exprs = ((ex.free_var(id), id) for id in index_col_ids)
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
return ArrayValue(
nodes.ProjectionNode(
child=joined_array.node,
assignments=(*index_exprs, *unpivot_exprs, *passthrough_exprs),
)
)
return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)

def join(
self,
Expand Down
72 changes: 18 additions & 54 deletions bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,29 @@ def order_preserving(self) -> bool:
...


# These classes can be used to create simple ops that don't take local parameters
# All is needed is a unique name, and to register an implementation in ibis_mappings.py
@dataclasses.dataclass(frozen=True)
class UnaryOp:
class NaryOp:
@property
def name(self) -> str:
raise NotImplementedError("RowOp abstract base class has no implementation")

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
raise NotImplementedError("Abstract operation has no output type")

@property
def order_preserving(self) -> bool:
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
return False


# These classes can be used to create simple ops that don't take local parameters
# All is needed is a unique name, and to register an implementation in ibis_mappings.py
@dataclasses.dataclass(frozen=True)
class UnaryOp(NaryOp):
@property
def arguments(self) -> int:
return 1

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
raise NotImplementedError("Abstract operation has no output type")

def as_expr(
self, input_id: typing.Union[str, bigframes.core.expression.Expression] = "arg"
) -> bigframes.core.expression.Expression:
Expand All @@ -69,25 +77,13 @@ def as_expr(
self, (_convert_expr_input(input_id),)
)

@property
def order_preserving(self) -> bool:
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
return False


@dataclasses.dataclass(frozen=True)
class BinaryOp:
@property
def name(self) -> str:
raise NotImplementedError("RowOp abstract base class has no implementation")

class BinaryOp(NaryOp):
@property
def arguments(self) -> int:
return 2

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
raise NotImplementedError("Abstract operation has no output type")

def as_expr(
self,
left_input: typing.Union[str, bigframes.core.expression.Expression] = "arg1",
Expand All @@ -103,25 +99,13 @@ def as_expr(
),
)

@property
def order_preserving(self) -> bool:
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
return False


@dataclasses.dataclass(frozen=True)
class TernaryOp:
@property
def name(self) -> str:
raise NotImplementedError("RowOp abstract base class has no implementation")

class TernaryOp(NaryOp):
@property
def arguments(self) -> int:
return 3

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
raise NotImplementedError("Abstract operation has no output type")

def as_expr(
self,
input1: typing.Union[str, bigframes.core.expression.Expression] = "arg1",
Expand All @@ -139,26 +123,6 @@ def as_expr(
),
)

@property
def order_preserving(self) -> bool:
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
return False


@dataclasses.dataclass(frozen=True)
class NaryOp:
@property
def name(self) -> str:
raise NotImplementedError("RowOp abstract base class has no implementation")

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
raise NotImplementedError("Abstract operation has no output type")

@property
def order_preserving(self) -> bool:
"""Whether the row operation preserves total ordering. Can be pruned from ordering expressions."""
return False


def _convert_expr_input(
input: typing.Union[str, bigframes.core.expression.Expression]
Expand Down Expand Up @@ -677,7 +641,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
clip_op = ClipOp()


class SwitchOp(NaryOp):
class CaseWhenOp(NaryOp):
name: typing.ClassVar[str] = "switch"

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
Expand Down Expand Up @@ -713,7 +677,7 @@ def as_expr(
)


switch_op = SwitchOp()
switch_op = CaseWhenOp()


# Just parameterless unary ops for now
Expand Down