Skip to content

Commit 290f95d

Browse files
authored
feat: support list of numerics in pandas.cut (#580)
An internal user encountered this missing overload
1 parent 8fc26c4 commit 290f95d

File tree

4 files changed

+93
-7
lines changed

4 files changed

+93
-7
lines changed

bigframes/core/reshape/__init__.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import typing
17-
from typing import Iterable, Literal, Optional, Tuple, Union
17+
from typing import Iterable, Literal, Optional, Union
1818

1919
import pandas as pd
2020

@@ -113,7 +113,7 @@ def cut(
113113
bins: Union[
114114
int,
115115
pd.IntervalIndex,
116-
Iterable[Tuple[Union[int, float], Union[int, float]]],
116+
Iterable,
117117
],
118118
*,
119119
labels: Optional[bool] = None,
@@ -125,9 +125,29 @@ def cut(
125125
if isinstance(bins, pd.IntervalIndex):
126126
as_index: pd.IntervalIndex = bins
127127
bins = tuple((bin.left.item(), bin.right.item()) for bin in bins)
128-
else:
128+
elif len(list(bins)) == 0:
129+
raise ValueError("`bins` iterable should have at least one item")
130+
elif isinstance(list(bins)[0], tuple):
129131
as_index = pd.IntervalIndex.from_tuples(list(bins))
130132
bins = tuple(bins)
133+
elif pd.api.types.is_number(list(bins)[0]):
134+
bins_list = list(bins)
135+
if len(bins_list) < 2:
136+
raise ValueError(
137+
"`bins` iterable of numeric breaks should have"
138+
" at least two items"
139+
)
140+
as_index = pd.IntervalIndex.from_breaks(bins_list)
141+
single_type = all([isinstance(n, type(bins_list[0])) for n in bins_list])
142+
numeric_type = type(bins_list[0]) if single_type else float
143+
bins = tuple(
144+
[
145+
(numeric_type(bins_list[i]), numeric_type(bins_list[i + 1]))
146+
for i in range(len(bins_list) - 1)
147+
]
148+
)
149+
else:
150+
raise ValueError("`bins` iterable should contain tuples or numerics")
131151

132152
if as_index.is_overlapping:
133153
raise ValueError("Overlapping IntervalIndex is not accepted.")

bigframes/operations/aggregations.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import abc
1818
import dataclasses
1919
import typing
20-
from typing import ClassVar, Hashable, Optional, Tuple
20+
from typing import ClassVar, Iterable, Optional
2121

2222
import pandas as pd
2323
import pyarrow as pa
@@ -213,7 +213,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
213213
@dataclasses.dataclass(frozen=True)
214214
class CutOp(UnaryWindowOp):
215215
# TODO: Unintuitive, refactor into multiple ops?
216-
bins: typing.Union[int, Tuple[Tuple[Hashable, Hashable], ...]]
216+
bins: typing.Union[int, Iterable]
217217
labels: Optional[bool]
218218

219219
@property
@@ -232,7 +232,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
232232
interval_dtype = (
233233
pa.float64()
234234
if isinstance(self.bins, int)
235-
else dtypes.infer_literal_arrow_type(self.bins[0][0])
235+
else dtypes.infer_literal_arrow_type(list(self.bins)[0][0])
236236
)
237237
pa_type = pa.struct(
238238
[

tests/system/small/test_pandas.py

+52
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,58 @@ def test_cut_default_labels(scalars_dfs):
424424
)
425425

426426

427+
@pytest.mark.parametrize(
428+
("breaks",),
429+
[
430+
([0, 5, 10, 15, 20, 100, 1000],), # ints
431+
([0.5, 10.5, 15.5, 20.5, 100.5, 1000.5],), # floats
432+
([0, 5, 10.5, 15.5, 20, 100, 1000.5],), # mixed
433+
],
434+
)
435+
def test_cut_numeric_breaks(scalars_dfs, breaks):
436+
scalars_df, scalars_pandas_df = scalars_dfs
437+
438+
pd_result = pd.cut(scalars_pandas_df["float64_col"], breaks)
439+
bf_result = bpd.cut(scalars_df["float64_col"], breaks).to_pandas()
440+
441+
# Convert to match data format
442+
pd_result_converted = pd.Series(
443+
[
444+
{"left_exclusive": interval.left, "right_inclusive": interval.right}
445+
if pd.notna(val)
446+
else pd.NA
447+
for val, interval in zip(
448+
pd_result, pd_result.cat.categories[pd_result.cat.codes]
449+
)
450+
],
451+
name=pd_result.name,
452+
)
453+
454+
pd.testing.assert_series_equal(
455+
bf_result, pd_result_converted, check_index=False, check_dtype=False
456+
)
457+
458+
459+
@pytest.mark.parametrize(
460+
("bins",),
461+
[
462+
(-1,), # negative integer bins argument
463+
([],), # empty iterable of bins
464+
(["notabreak"],), # iterable of wrong type
465+
([1],), # numeric breaks with only one numeric
466+
# this is supported by pandas but not by
467+
# the bigquery operation and a bigframes workaround
468+
# is not yet available. Should return column
469+
# of structs with all NaN values.
470+
],
471+
)
472+
def test_cut_errors(scalars_dfs, bins):
473+
scalars_df, _ = scalars_dfs
474+
475+
with pytest.raises(ValueError):
476+
bpd.cut(scalars_df["float64_col"], bins)
477+
478+
427479
@pytest.mark.parametrize(
428480
("bins",),
429481
[

third_party/bigframes_vendored/pandas/core/reshape/tile.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,20 @@ def cut(
7676
3 {'left_exclusive': 5, 'right_inclusive': 20}
7777
dtype: struct<left_exclusive: int64, right_inclusive: int64>[pyarrow]
7878
79+
Cut with an iterable of ints:
80+
81+
>>> bins_ints = [0, 1, 5, 20]
82+
>>> bpd.cut(s, bins=bins_ints)
83+
0 <NA>
84+
1 {'left_exclusive': 0, 'right_inclusive': 1}
85+
2 {'left_exclusive': 1, 'right_inclusive': 5}
86+
3 {'left_exclusive': 5, 'right_inclusive': 20}
87+
dtype: struct<left_exclusive: int64, right_inclusive: int64>[pyarrow]
88+
7989
Args:
8090
x (Series):
8191
The input Series to be binned. Must be 1-dimensional.
82-
bins (int, pd.IntervalIndex, Iterable[Tuple[Union[int, float], Union[int, float]]]):
92+
bins (int, pd.IntervalIndex, Iterable):
8393
The criteria to bin by.
8494
8595
int: Defines the number of equal-width bins in the range of `x`. The
@@ -88,6 +98,10 @@ def cut(
8898
8999
pd.IntervalIndex or Iterable of tuples: Defines the exact bins to be used.
90100
It's important to ensure that these bins are non-overlapping.
101+
102+
Iterable of numerics: Defines the exact bins by using the interval
103+
between each item and its following item. The items must be monotonically
104+
increasing.
91105
labels (None):
92106
Specifies the labels for the returned bins. Must be the same length as
93107
the resulting bins. If False, returns only integer indicators of the

0 commit comments

Comments
 (0)