Skip to content

Commit 71b4053

Browse files
authored
feat: support dataframe where method (#1166)
* feat: support dataframe where method * fix the cond_id * fix to set df.column.name and df.column.names
1 parent b2816a5 commit 71b4053

File tree

3 files changed

+257
-0
lines changed

3 files changed

+257
-0
lines changed

bigframes/dataframe.py

+57
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,63 @@ def itertuples(
22412241
for item in df.itertuples(index=index, name=name):
22422242
yield item
22432243

2244+
def where(self, cond, other=None):
2245+
if isinstance(other, bigframes.series.Series):
2246+
raise ValueError("Seires is not a supported replacement type!")
2247+
2248+
if self.columns.nlevels > 1 or self.index.nlevels > 1:
2249+
raise NotImplementedError(
2250+
"The dataframe.where() method does not support multi-index and/or multi-column."
2251+
)
2252+
2253+
aligned_block, (_, _) = self._block.join(cond._block, how="left")
2254+
# No left join is needed when 'other' is None or constant.
2255+
if isinstance(other, bigframes.dataframe.DataFrame):
2256+
aligned_block, (_, _) = aligned_block.join(other._block, how="left")
2257+
self_len = len(self._block.value_columns)
2258+
cond_len = len(cond._block.value_columns)
2259+
2260+
ids = aligned_block.value_columns[:self_len]
2261+
labels = aligned_block.column_labels[:self_len]
2262+
self_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
2263+
2264+
if isinstance(cond, bigframes.series.Series) and cond.name in self_col:
2265+
# This is when 'cond' is a valid series.
2266+
y = aligned_block.value_columns[self_len]
2267+
cond_col = {x: ex.deref(y) for x in self_col.keys()}
2268+
else:
2269+
# This is when 'cond' is a dataframe.
2270+
ids = aligned_block.value_columns[self_len : self_len + cond_len]
2271+
labels = aligned_block.column_labels[self_len : self_len + cond_len]
2272+
cond_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
2273+
2274+
if isinstance(other, DataFrame):
2275+
other_len = len(self._block.value_columns)
2276+
ids = aligned_block.value_columns[-other_len:]
2277+
labels = aligned_block.column_labels[-other_len:]
2278+
other_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
2279+
else:
2280+
# This is when 'other' is None or constant.
2281+
labels = aligned_block.column_labels[:self_len]
2282+
other_col = {x: ex.const(other) for x in labels} # type: ignore
2283+
2284+
result_series = {}
2285+
for x, self_id in self_col.items():
2286+
cond_id = cond_col[x] if x in cond_col else ex.const(False)
2287+
other_id = other_col[x] if x in other_col else ex.const(None)
2288+
result_block, result_id = aligned_block.project_expr(
2289+
ops.where_op.as_expr(self_id, cond_id, other_id)
2290+
)
2291+
series = bigframes.series.Series(
2292+
result_block.select_column(result_id).with_column_labels([x])
2293+
)
2294+
result_series[x] = series
2295+
2296+
result = DataFrame(result_series)
2297+
result.columns.name = self.columns.name
2298+
result.columns.names = self.columns.names
2299+
return result
2300+
22442301
def dropna(
22452302
self,
22462303
*,

tests/system/small/test_dataframe.py

+108
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,114 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
322322
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False)
323323

324324

325+
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
326+
# Condition is dataframe, other is None (as default).
327+
cond_bf = scalars_df_index["int64_col"] > 0
328+
cond_pd = scalars_pandas_df_index["int64_col"] > 0
329+
bf_result = scalars_df_index.where(cond_bf).to_pandas()
330+
pd_result = scalars_pandas_df_index.where(cond_pd)
331+
pandas.testing.assert_frame_equal(bf_result, pd_result)
332+
333+
334+
def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
335+
# Test when a dataframe has multi-index or multi-columns.
336+
columns = ["int64_col", "float64_col"]
337+
dataframe_bf = scalars_df_index[columns]
338+
339+
dataframe_bf.columns = pd.MultiIndex.from_tuples(
340+
[("str1", 1), ("str2", 2)], names=["STR", "INT"]
341+
)
342+
cond_bf = dataframe_bf["str1"] > 0
343+
344+
with pytest.raises(NotImplementedError) as context:
345+
dataframe_bf.where(cond_bf).to_pandas()
346+
assert (
347+
str(context.value)
348+
== "The dataframe.where() method does not support multi-index and/or multi-column."
349+
)
350+
351+
352+
def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
353+
# Condition is a series, other is a constant.
354+
columns = ["int64_col", "float64_col"]
355+
dataframe_bf = scalars_df_index[columns]
356+
dataframe_pd = scalars_pandas_df_index[columns]
357+
dataframe_bf.columns.name = "test_name"
358+
dataframe_pd.columns.name = "test_name"
359+
360+
cond_bf = dataframe_bf["int64_col"] > 0
361+
cond_pd = dataframe_pd["int64_col"] > 0
362+
other = 0
363+
364+
bf_result = dataframe_bf.where(cond_bf, other).to_pandas()
365+
pd_result = dataframe_pd.where(cond_pd, other)
366+
pandas.testing.assert_frame_equal(bf_result, pd_result)
367+
368+
369+
def test_where_series_cond_dataframe_other(scalars_df_index, scalars_pandas_df_index):
370+
# Condition is a series, other is a dataframe.
371+
columns = ["int64_col", "float64_col"]
372+
dataframe_bf = scalars_df_index[columns]
373+
dataframe_pd = scalars_pandas_df_index[columns]
374+
375+
cond_bf = dataframe_bf["int64_col"] > 0
376+
cond_pd = dataframe_pd["int64_col"] > 0
377+
other_bf = -dataframe_bf
378+
other_pd = -dataframe_pd
379+
380+
bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas()
381+
pd_result = dataframe_pd.where(cond_pd, other_pd)
382+
pandas.testing.assert_frame_equal(bf_result, pd_result)
383+
384+
385+
def test_where_dataframe_cond(scalars_df_index, scalars_pandas_df_index):
386+
# Condition is a dataframe, other is None.
387+
columns = ["int64_col", "float64_col"]
388+
dataframe_bf = scalars_df_index[columns]
389+
dataframe_pd = scalars_pandas_df_index[columns]
390+
391+
cond_bf = dataframe_bf > 0
392+
cond_pd = dataframe_pd > 0
393+
394+
bf_result = dataframe_bf.where(cond_bf, None).to_pandas()
395+
pd_result = dataframe_pd.where(cond_pd, None)
396+
pandas.testing.assert_frame_equal(bf_result, pd_result)
397+
398+
399+
def test_where_dataframe_cond_const_other(scalars_df_index, scalars_pandas_df_index):
400+
# Condition is a dataframe, other is a constant.
401+
columns = ["int64_col", "float64_col"]
402+
dataframe_bf = scalars_df_index[columns]
403+
dataframe_pd = scalars_pandas_df_index[columns]
404+
405+
cond_bf = dataframe_bf > 0
406+
cond_pd = dataframe_pd > 0
407+
other_bf = 10
408+
other_pd = 10
409+
410+
bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas()
411+
pd_result = dataframe_pd.where(cond_pd, other_pd)
412+
pandas.testing.assert_frame_equal(bf_result, pd_result)
413+
414+
415+
def test_where_dataframe_cond_dataframe_other(
416+
scalars_df_index, scalars_pandas_df_index
417+
):
418+
# Condition is a dataframe, other is a dataframe.
419+
columns = ["int64_col", "float64_col"]
420+
dataframe_bf = scalars_df_index[columns]
421+
dataframe_pd = scalars_pandas_df_index[columns]
422+
423+
cond_bf = dataframe_bf > 0
424+
cond_pd = dataframe_pd > 0
425+
other_bf = dataframe_bf * 2
426+
other_pd = dataframe_pd * 2
427+
428+
bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas()
429+
pd_result = dataframe_pd.where(cond_pd, other_pd)
430+
pandas.testing.assert_frame_equal(bf_result, pd_result)
431+
432+
325433
def test_drop_column(scalars_dfs):
326434
scalars_df, scalars_pandas_df = scalars_dfs
327435
col_name = "int64_col"

third_party/bigframes_vendored/pandas/core/frame.py

+92
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,98 @@ def items(self):
19561956
"""
19571957
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
19581958

1959+
def where(self, cond, other):
1960+
"""Replace values where the condition is False.
1961+
1962+
**Examples:**
1963+
1964+
>>> import bigframes.pandas as bpd
1965+
>>> bpd.options.display.progress_bar = None
1966+
1967+
>>> df = bpd.DataFrame({'a': [20, 10, 0], 'b': [0, 10, 20]})
1968+
>>> df
1969+
a b
1970+
0 20 0
1971+
1 10 10
1972+
2 0 20
1973+
<BLANKLINE>
1974+
[3 rows x 2 columns]
1975+
1976+
You can filter the values in the dataframe based on a condition. The
1977+
values matching the condition would be kept, and not matching would be
1978+
replaced. The default replacement value is ``NA``. For example, when the
1979+
condition is a dataframe:
1980+
1981+
>>> df.where(df > 0)
1982+
a b
1983+
0 20 <NA>
1984+
1 10 10
1985+
2 <NA> 20
1986+
<BLANKLINE>
1987+
[3 rows x 2 columns]
1988+
1989+
You can specify a custom replacement value for non-matching values.
1990+
1991+
>>> df.where(df > 0, -1)
1992+
a b
1993+
0 20 -1
1994+
1 10 10
1995+
2 -1 20
1996+
<BLANKLINE>
1997+
[3 rows x 2 columns]
1998+
1999+
Besides dataframe, the condition can be a series too. For example:
2000+
2001+
>>> df.where(df['a'] > 10, -1)
2002+
a b
2003+
0 20 0
2004+
1 -1 -1
2005+
2 -1 -1
2006+
<BLANKLINE>
2007+
[3 rows x 2 columns]
2008+
2009+
As for the replacement, it can be a dataframe too. For example:
2010+
2011+
>>> df.where(df > 10, -df)
2012+
a b
2013+
0 20 0
2014+
1 -10 -10
2015+
2 0 20
2016+
<BLANKLINE>
2017+
[3 rows x 2 columns]
2018+
2019+
>>> df.where(df['a'] > 10, -df)
2020+
a b
2021+
0 20 0
2022+
1 -10 -10
2023+
2 0 -20
2024+
<BLANKLINE>
2025+
[3 rows x 2 columns]
2026+
2027+
Please note, replacement doesn't support Series for now. In pandas, when
2028+
specifying a Series as replacement, the axis value should be specified
2029+
at the same time, which is not supported in bigframes DataFrame.
2030+
2031+
Args:
2032+
cond (bool Series/DataFrame, array-like, or callable):
2033+
Where cond is True, keep the original value. Where False, replace
2034+
with corresponding value from other. If cond is callable, it is
2035+
computed on the Series/DataFrame and returns boolean
2036+
Series/DataFrame or array. The callable must not change input
2037+
Series/DataFrame (though pandas doesn’t check it).
2038+
other (scalar, DataFrame, or callable):
2039+
Entries where cond is False are replaced with corresponding value
2040+
from other. If other is callable, it is computed on the
2041+
DataFrame and returns scalar or DataFrame. The callable must not
2042+
change input DataFrame (though pandas doesn’t check it). If not
2043+
specified, entries will be filled with the corresponding NULL
2044+
value (np.nan for numpy dtypes, pd.NA for extension dtypes).
2045+
2046+
Returns:
2047+
DataFrame: DataFrame after the replacement.
2048+
"""
2049+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
2050+
19592051
# ----------------------------------------------------------------------
19602052
# Sorting
19612053

0 commit comments

Comments
 (0)