Skip to content

Commit fff3d45

Browse files
milkshakeiiiHenry J Solberg
and
Henry J Solberg
authored
fix: change return type of Series.loc[scalar] (#40)
* bug: change return type of `Series.loc[scalar]` Change-Id: Id60a7da3021972da5c8a28fb8f3620e10643c0ed * add scalar case and update return types * remove unneeded iloc in series getitem test * fix test_series_get_with_default_index * Run query manual for clarity/redundance --------- Co-authored-by: Henry J Solberg <[email protected]>
1 parent 781307e commit fff3d45

File tree

4 files changed

+62
-35
lines changed

4 files changed

+62
-35
lines changed

bigframes/core/indexers.py

+53-28
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Tuple
18+
from typing import Tuple, Union
1919

2020
import ibis
2121
import pandas as pd
@@ -29,20 +29,19 @@
2929
import bigframes.series
3030

3131
if typing.TYPE_CHECKING:
32-
LocSingleKey = typing.Union[bigframes.series.Series, indexes.Index, slice]
32+
LocSingleKey = Union[
33+
bigframes.series.Series, indexes.Index, slice, bigframes.core.scalar.Scalar
34+
]
3335

3436

3537
class LocSeriesIndexer:
3638
def __init__(self, series: bigframes.series.Series):
3739
self._series = series
3840

39-
def __getitem__(self, key) -> bigframes.series.Series:
40-
"""
41-
Only indexing by a boolean bigframes.series.Series or list of index entries is currently supported
42-
"""
43-
return typing.cast(
44-
bigframes.series.Series, _loc_getitem_series_or_dataframe(self._series, key)
45-
)
41+
def __getitem__(
42+
self, key
43+
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
44+
return _loc_getitem_series_or_dataframe(self._series, key)
4645

4746
def __setitem__(self, key, value) -> None:
4847
# TODO(swast): support MultiIndex
@@ -84,7 +83,7 @@ def __init__(self, series: bigframes.series.Series):
8483

8584
def __getitem__(
8685
self, key
87-
) -> bigframes.core.scalar.Scalar | bigframes.series.Series:
86+
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
8887
"""
8988
Index series using integer offsets. Currently supports index by key type:
9089
@@ -103,13 +102,17 @@ def __init__(self, dataframe: bigframes.dataframe.DataFrame):
103102
self._dataframe = dataframe
104103

105104
@typing.overload
106-
def __getitem__(self, key: LocSingleKey) -> bigframes.dataframe.DataFrame:
105+
def __getitem__(
106+
self, key: LocSingleKey
107+
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
107108
...
108109

109110
# Technically this is wrong since we can have duplicate column labels, but
110111
# this is expected to be rare.
111112
@typing.overload
112-
def __getitem__(self, key: Tuple[LocSingleKey, str]) -> bigframes.series.Series:
113+
def __getitem__(
114+
self, key: Tuple[LocSingleKey, str]
115+
) -> Union[bigframes.series.Series, bigframes.core.scalar.Scalar]:
113116
...
114117

115118
def __getitem__(self, key):
@@ -173,7 +176,7 @@ class ILocDataFrameIndexer:
173176
def __init__(self, dataframe: bigframes.dataframe.DataFrame):
174177
self._dataframe = dataframe
175178

176-
def __getitem__(self, key) -> bigframes.dataframe.DataFrame | pd.Series:
179+
def __getitem__(self, key) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
177180
"""
178181
Index dataframe using integer offsets. Currently supports index by key type:
179182
@@ -188,21 +191,26 @@ def __getitem__(self, key) -> bigframes.dataframe.DataFrame | pd.Series:
188191
@typing.overload
189192
def _loc_getitem_series_or_dataframe(
190193
series_or_dataframe: bigframes.series.Series, key
191-
) -> bigframes.series.Series:
194+
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
192195
...
193196

194197

195198
@typing.overload
196199
def _loc_getitem_series_or_dataframe(
197200
series_or_dataframe: bigframes.dataframe.DataFrame, key
198-
) -> bigframes.dataframe.DataFrame:
201+
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
199202
...
200203

201204

202205
def _loc_getitem_series_or_dataframe(
203-
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series,
206+
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
204207
key: LocSingleKey,
205-
) -> bigframes.dataframe.DataFrame | bigframes.series.Series:
208+
) -> Union[
209+
bigframes.dataframe.DataFrame,
210+
bigframes.series.Series,
211+
pd.Series,
212+
bigframes.core.scalar.Scalar,
213+
]:
206214
if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
207215
return series_or_dataframe[key]
208216
elif isinstance(key, bigframes.series.Series):
@@ -222,7 +230,7 @@ def _loc_getitem_series_or_dataframe(
222230
# TODO(henryjsolberg): support MultiIndex
223231
if len(key) == 0: # type: ignore
224232
return typing.cast(
225-
typing.Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
233+
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
226234
series_or_dataframe.iloc[0:0],
227235
)
228236

@@ -258,11 +266,22 @@ def _loc_getitem_series_or_dataframe(
258266
)
259267
keys_df = keys_df.set_index(index_name, drop=True)
260268
keys_df.index.name = None
261-
return _perform_loc_list_join(series_or_dataframe, keys_df)
269+
result = _perform_loc_list_join(series_or_dataframe, keys_df)
270+
pandas_result = result.to_pandas()
271+
# although loc[scalar_key] returns multiple results when scalar_key
272+
# is not unique, we download the results here and return the computed
273+
# individual result (as a scalar or pandas series) when the key is unique,
274+
# since we expect unique index keys to be more common. loc[[scalar_key]]
275+
# can be used to retrieve one-item DataFrames or Series.
276+
if len(pandas_result) == 1:
277+
return pandas_result.iloc[0]
278+
# when the key is not unique, we return a bigframes data type
279+
# as usual for methods that return dataframes/series
280+
return result
262281
else:
263282
raise TypeError(
264-
"Invalid argument type. loc currently only supports indexing with a "
265-
"boolean bigframes Series, a list of index entries or a single index entry. "
283+
"Invalid argument type. Expected bigframes.Series, bigframes.Index, "
284+
"list, : (empty slice), or scalar. "
266285
f"{constants.FEEDBACK_LINK}"
267286
)
268287

@@ -284,9 +303,9 @@ def _perform_loc_list_join(
284303

285304

286305
def _perform_loc_list_join(
287-
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series,
306+
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
288307
keys_df: bigframes.dataframe.DataFrame,
289-
) -> bigframes.series.Series | bigframes.dataframe.DataFrame:
308+
) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
290309
# right join based on the old index so that the matching rows from the user's
291310
# original dataframe will be duplicated and reordered appropriately
292311
original_index_names = series_or_dataframe.index.names
@@ -309,20 +328,26 @@ def _perform_loc_list_join(
309328
@typing.overload
310329
def _iloc_getitem_series_or_dataframe(
311330
series_or_dataframe: bigframes.series.Series, key
312-
) -> bigframes.series.Series | bigframes.core.scalar.Scalar:
331+
) -> Union[bigframes.series.Series, bigframes.core.scalar.Scalar]:
313332
...
314333

315334

316335
@typing.overload
317336
def _iloc_getitem_series_or_dataframe(
318337
series_or_dataframe: bigframes.dataframe.DataFrame, key
319-
) -> bigframes.dataframe.DataFrame | pd.Series:
338+
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
320339
...
321340

322341

323342
def _iloc_getitem_series_or_dataframe(
324-
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series, key
325-
) -> bigframes.dataframe.DataFrame | bigframes.series.Series | bigframes.core.scalar.Scalar | pd.Series:
343+
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
344+
key,
345+
) -> Union[
346+
bigframes.dataframe.DataFrame,
347+
bigframes.series.Series,
348+
bigframes.core.scalar.Scalar,
349+
pd.Series,
350+
]:
326351
if isinstance(key, int):
327352
internal_slice_result = series_or_dataframe._slice(key, key + 1, 1)
328353
result_pd_df = internal_slice_result.to_pandas()
@@ -334,7 +359,7 @@ def _iloc_getitem_series_or_dataframe(
334359
elif pd.api.types.is_list_like(key):
335360
if len(key) == 0:
336361
return typing.cast(
337-
typing.Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
362+
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
338363
series_or_dataframe.iloc[0:0],
339364
)
340365
df = series_or_dataframe

bigframes/ml/model_selection.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection."""
1818

1919

20+
import typing
2021
from typing import List, Union
2122

2223
from bigframes.ml import utils
@@ -79,9 +80,10 @@ def train_test_split(
7980
train_index = split_dfs[0].index
8081
test_index = split_dfs[1].index
8182

82-
split_dfs += [
83-
df.loc[index] for df in dfs[1:] for index in (train_index, test_index)
84-
]
83+
split_dfs += typing.cast(
84+
List[bpd.DataFrame],
85+
[df.loc[index] for df in dfs[1:] for index in (train_index, test_index)],
86+
)
8587

8688
# convert back to Series.
8789
results: List[Union[bpd.DataFrame, bpd.Series]] = []

tests/system/small/test_dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2081,7 +2081,7 @@ def test_loc_single_index_no_duplicate(scalars_df_index, scalars_pandas_df_index
20812081
bf_result = scalars_df_index.loc[index]
20822082
pd_result = scalars_pandas_df_index.loc[index]
20832083
pd.testing.assert_series_equal(
2084-
bf_result.to_pandas().iloc[0, :],
2084+
bf_result,
20852085
pd_result,
20862086
)
20872087

tests/system/small/test_series.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_series_get_with_default_index(scalars_dfs):
118118
scalars_df, scalars_pandas_df = scalars_dfs
119119
bf_result = scalars_df[col_name].get(key)
120120
pd_result = scalars_pandas_df[col_name].get(key)
121-
assert bf_result.to_pandas().iloc[0] == pd_result
121+
assert bf_result == pd_result
122122

123123

124124
@pytest.mark.parametrize(
@@ -157,7 +157,7 @@ def test_series___getitem___with_default_index(scalars_dfs):
157157
scalars_df, scalars_pandas_df = scalars_dfs
158158
bf_result = scalars_df[col_name][key]
159159
pd_result = scalars_pandas_df[col_name][key]
160-
assert bf_result.to_pandas().iloc[0] == pd_result
160+
assert bf_result == pd_result
161161

162162

163163
@pytest.mark.parametrize(
@@ -2652,7 +2652,7 @@ def test_loc_single_index_no_duplicate(scalars_df_index, scalars_pandas_df_index
26522652
index = -2345
26532653
bf_result = scalars_df_index.date_col.loc[index]
26542654
pd_result = scalars_pandas_df_index.date_col.loc[index]
2655-
assert bf_result.to_pandas().iloc[0] == pd_result
2655+
assert bf_result == pd_result
26562656

26572657

26582658
def test_series_bool_interpretation_error(scalars_df_index):

0 commit comments

Comments
 (0)