@@ -208,6 +208,12 @@ def explicitly_ordered(self) -> bool:
208208 """
209209 ...
210210
211+ @functools .cached_property
212+ def height (self ) -> int :
213+ if len (self .child_nodes ) == 0 :
214+ return 0
215+ return max (child .height for child in self .child_nodes ) + 1
216+
211217 @functools .cached_property
212218 def total_variables (self ) -> int :
213219 return self .variables_introduced + sum (
@@ -284,6 +290,34 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
284290 return self .transform_children (lambda x : x .prune (used_cols ))
285291
286292
293+ class AdditiveNode :
294+ """Definition of additive - if you drop added_fields, you end up with the descendent.
295+
296+ .. code-block:: text
297+
298+ AdditiveNode (fields: a, b, c; added_fields: c)
299+ |
300+ | additive_base
301+ V
302+ BigFrameNode (fields: a, b)
303+
304+ """
305+
306+ @property
307+ @abc .abstractmethod
308+ def added_fields (self ) -> Tuple [Field , ...]:
309+ ...
310+
311+ @property
312+ @abc .abstractmethod
313+ def additive_base (self ) -> BigFrameNode :
314+ ...
315+
316+ @abc .abstractmethod
317+ def replace_additive_base (self , BigFrameNode ):
318+ ...
319+
320+
287321@dataclasses .dataclass (frozen = True , eq = False )
288322class UnaryNode (BigFrameNode ):
289323 child : BigFrameNode
@@ -381,6 +415,106 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
381415 return self
382416
383417
418+ @dataclasses .dataclass (frozen = True , eq = False )
419+ class InNode (BigFrameNode , AdditiveNode ):
420+ """
421+ Special Join Type that only returns rows from the left side, as well as adding a bool column indicating whether a match exists on the right side.
422+
423+ Modelled separately from join node, as this operation preserves row identity.
424+ """
425+
426+ left_child : BigFrameNode
427+ right_child : BigFrameNode
428+ left_col : ex .DerefOp
429+ right_col : ex .DerefOp
430+ indicator_col : bfet_ids .ColumnId
431+
432+ def _validate (self ):
433+ assert not (
434+ set (self .left_child .ids ) & set (self .right_child .ids )
435+ ), "Join ids collide"
436+
437+ @property
438+ def row_preserving (self ) -> bool :
439+ return False
440+
441+ @property
442+ def non_local (self ) -> bool :
443+ return True
444+
445+ @property
446+ def child_nodes (self ) -> typing .Sequence [BigFrameNode ]:
447+ return (self .left_child , self .right_child )
448+
449+ @property
450+ def order_ambiguous (self ) -> bool :
451+ return False
452+
453+ @property
454+ def explicitly_ordered (self ) -> bool :
455+ # Preserves left ordering always
456+ return True
457+
458+ @property
459+ def added_fields (self ) -> Tuple [Field , ...]:
460+ return (Field (self .indicator_col , bigframes .dtypes .BOOL_DTYPE ),)
461+
462+ @property
463+ def fields (self ) -> Iterable [Field ]:
464+ return itertools .chain (
465+ self .left_child .fields ,
466+ self .added_fields ,
467+ )
468+
469+ @functools .cached_property
470+ def variables_introduced (self ) -> int :
471+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
472+ return 1
473+
474+ @property
475+ def joins (self ) -> bool :
476+ return True
477+
478+ @property
479+ def row_count (self ) -> Optional [int ]:
480+ return self .left_child .row_count
481+
482+ @property
483+ def node_defined_ids (self ) -> Tuple [bfet_ids .ColumnId , ...]:
484+ return (self .indicator_col ,)
485+
486+ @property
487+ def additive_base (self ) -> BigFrameNode :
488+ return self .left_child
489+
490+ def replace_additive_base (self , node : BigFrameNode ):
491+ return dataclasses .replace (self , left_child = node )
492+
493+ def transform_children (
494+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
495+ ) -> BigFrameNode :
496+ transformed = dataclasses .replace (
497+ self , left_child = t (self .left_child ), right_child = t (self .right_child )
498+ )
499+ if self == transformed :
500+ # reusing existing object speeds up eq, and saves a small amount of memory
501+ return self
502+ return transformed
503+
504+ def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
505+ return self
506+
507+ def remap_vars (
508+ self , mappings : Mapping [bfet_ids .ColumnId , bfet_ids .ColumnId ]
509+ ) -> BigFrameNode :
510+ return dataclasses .replace (
511+ self , indicator_col = mappings .get (self .indicator_col , self .indicator_col )
512+ )
513+
514+ def remap_refs (self , mappings : Mapping [bfet_ids .ColumnId , bfet_ids .ColumnId ]):
515+ return dataclasses .replace (self , left_col = self .left_col .remap_column_refs (mappings , allow_partial_bindings = True ), right_col = self .right_col .remap_column_refs (mappings , allow_partial_bindings = True )) # type: ignore
516+
517+
384518@dataclasses .dataclass (frozen = True , eq = False )
385519class JoinNode (BigFrameNode ):
386520 left_child : BigFrameNode
@@ -926,7 +1060,7 @@ class CachedTableNode(ReadTableNode):
9261060
9271061# Unary nodes
9281062@dataclasses .dataclass (frozen = True , eq = False )
929- class PromoteOffsetsNode (UnaryNode ):
1063+ class PromoteOffsetsNode (UnaryNode , AdditiveNode ):
9301064 col_id : bigframes .core .identifiers .ColumnId
9311065
9321066 @property
@@ -959,6 +1093,13 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
9591093 def added_fields (self ) -> Tuple [Field , ...]:
9601094 return (Field (self .col_id , bigframes .dtypes .INT_DTYPE ),)
9611095
1096+ @property
1097+ def additive_base (self ) -> BigFrameNode :
1098+ return self .child
1099+
1100+ def replace_additive_base (self , node : BigFrameNode ):
1101+ return dataclasses .replace (self , child = node )
1102+
9621103 def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
9631104 if self .col_id not in used_cols :
9641105 return self .child .prune (used_cols )
@@ -1171,7 +1312,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
11711312
11721313
11731314@dataclasses .dataclass (frozen = True , eq = False )
1174- class ProjectionNode (UnaryNode ):
1315+ class ProjectionNode (UnaryNode , AdditiveNode ):
11751316 """Assigns new variables (without modifying existing ones)"""
11761317
11771318 assignments : typing .Tuple [
@@ -1212,6 +1353,13 @@ def row_count(self) -> Optional[int]:
12121353 def node_defined_ids (self ) -> Tuple [bfet_ids .ColumnId , ...]:
12131354 return tuple (id for _ , id in self .assignments )
12141355
1356+ @property
1357+ def additive_base (self ) -> BigFrameNode :
1358+ return self .child
1359+
1360+ def replace_additive_base (self , node : BigFrameNode ):
1361+ return dataclasses .replace (self , child = node )
1362+
12151363 def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
12161364 pruned_assignments = tuple (i for i in self .assignments if i [1 ] in used_cols )
12171365 if len (pruned_assignments ) == 0 :
@@ -1378,7 +1526,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
13781526
13791527
13801528@dataclasses .dataclass (frozen = True , eq = False )
1381- class WindowOpNode (UnaryNode ):
1529+ class WindowOpNode (UnaryNode , AdditiveNode ):
13821530 expression : ex .Aggregation
13831531 window_spec : window .WindowSpec
13841532 output_name : bigframes .core .identifiers .ColumnId
@@ -1438,6 +1586,13 @@ def inherits_order(self) -> bool:
14381586 ) and self .expression .op .implicitly_inherits_order
14391587 return op_inherits_order or self .window_spec .row_bounded
14401588
1589+ @property
1590+ def additive_base (self ) -> BigFrameNode :
1591+ return self .child
1592+
1593+ def replace_additive_base (self , node : BigFrameNode ):
1594+ return dataclasses .replace (self , child = node )
1595+
14411596 def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
14421597 if self .output_name not in used_cols :
14431598 return self .child .prune (used_cols )
0 commit comments