Skip to content

Commit 6eb19a7

Browse files
authored
feat: Series.str.split (#675)
* feat: Series.str.split * add more tests * format fix
1 parent 2fd1b81 commit 6eb19a7

File tree

6 files changed

+116
-3
lines changed

6 files changed

+116
-3
lines changed

bigframes/core/compile/scalar_op_compiler.py

+5
Original file line number | Diff line number | Diff line change
@@ -588,6 +588,11 @@ def endswith_op_impl(x: ibis_types.Value, op: ops.EndsWithOp):
588588
return any_match if any_match is not None else ibis_types.literal(False)
589589

590590

591+
@scalar_op_compiler.register_unary_op(ops.StringSplitOp, pass_op=True)
592+
def stringsplit_op_impl(x: ibis_types.Value, op: ops.StringSplitOp):
593+
return typing.cast(ibis_types.StringValue, x).split(op.pat)
594+
595+
591596
@scalar_op_compiler.register_unary_op(ops.ZfillOp, pass_op=True)
592597
def zfill_op_impl(x: ibis_types.Value, op: ops.ZfillOp):
593598
str_value = typing.cast(ibis_types.StringValue, x)

bigframes/dtypes.py

+6
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,12 @@ def bigframes_dtype_to_ibis_dtype(
405405
return BIGFRAMES_TO_IBIS[bigframes_dtype]
406406

407407

408+
def bigframes_dtype_to_arrow_dtype(
409+
bigframes_dtype: Union[DtypeString, Dtype, np.dtype[Any]]
410+
) -> pa.DataType:
411+
return ibis_dtype_to_arrow_dtype(bigframes_dtype_to_ibis_dtype(bigframes_dtype))
412+
413+
408414
def literal_to_ibis_scalar(
409415
literal, force_dtype: typing.Optional[Dtype] = None, validate: bool = True
410416
):

bigframes/operations/__init__.py

+14-3
Original file line number | Diff line number | Diff line change
@@ -386,6 +386,19 @@ def output_type(self, *input_types):
386386
return op_typing.STRING_PREDICATE.output_type(input_types[0])
387387

388388

389+
@dataclasses.dataclass(frozen=True)
390+
class StringSplitOp(UnaryOp):
391+
name: typing.ClassVar[str] = "str_split"
392+
pat: typing.Sequence[str]
393+
394+
def output_type(self, *input_types):
395+
input_type = input_types[0]
396+
if not isinstance(input_type, pd.StringDtype):
397+
raise TypeError("field accessor input must be a string type")
398+
arrow_type = dtypes.bigframes_dtype_to_arrow_dtype(input_type)
399+
return pd.ArrowDtype(pa.list_(arrow_type))
400+
401+
389402
@dataclasses.dataclass(frozen=True)
390403
class EndsWithOp(UnaryOp):
391404
name: typing.ClassVar[str] = "str_endswith"
@@ -463,9 +476,7 @@ def output_type(self, *input_types):
463476
raise TypeError("field accessor input must be a struct type")
464477

465478
pa_result_type = pa_type[self.name_or_index].type
466-
# TODO: Directly convert from arrow to pandas type
467-
ibis_result_type = dtypes.arrow_dtype_to_ibis_dtype(pa_result_type)
468-
return dtypes.ibis_dtype_to_bigframes_dtype(ibis_result_type)
479+
return dtypes.arrow_dtype_to_bigframes_dtype(pa_result_type)
469480

470481

471482
@dataclasses.dataclass(frozen=True)

bigframes/operations/strings.py

+12
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,18 @@ def endswith(
247247
pat = (pat,)
248248
return self._apply_unary_op(ops.EndsWithOp(pat=pat))
249249

250+
def split(
251+
self,
252+
pat: str = " ",
253+
regex: Union[bool, None] = None,
254+
) -> series.Series:
255+
if regex is True or (regex is None and len(pat) > 1):
256+
raise NotImplementedError(
257+
"Regular expressions aren't currently supported. Please set "
258+
+ f"`regex=False` and try again. {constants.FEEDBACK_LINK}"
259+
)
260+
return self._apply_unary_op(ops.StringSplitOp(pat=pat))
261+
250262
def zfill(self, width: int) -> series.Series:
251263
return self._apply_unary_op(ops.ZfillOp(width=width))
252264

tests/system/small/operations/test_strings.py

+31
Original file line number | Diff line number | Diff line change
@@ -531,3 +531,34 @@ def test_str_rjust(scalars_dfs):
531531
pd_result,
532532
bf_result,
533533
)
534+
535+
536+
@pytest.mark.parametrize(
537+
("pat", "regex"),
538+
[
539+
pytest.param(" ", None, id="one_char"),
540+
pytest.param("ll", False, id="two_chars"),
541+
pytest.param(
542+
" ",
543+
True,
544+
id="one_char_reg",
545+
marks=pytest.mark.xfail(raises=NotImplementedError),
546+
),
547+
pytest.param(
548+
"ll",
549+
None,
550+
id="two_chars_reg",
551+
marks=pytest.mark.xfail(raises=NotImplementedError),
552+
),
553+
],
554+
)
555+
def test_str_split_raise_errors(scalars_dfs, pat, regex):
556+
scalars_df, scalars_pandas_df = scalars_dfs
557+
col_name = "string_col"
558+
bf_result = scalars_df[col_name].str.split(pat=pat, regex=regex).to_pandas()
559+
pd_result = scalars_pandas_df[col_name].str.split(pat=pat, regex=regex)
560+
561+
# TODO(b/336880368): Allow for NULL values for ARRAY columns in BigQuery.
562+
pd_result = pd_result.apply(lambda x: [] if pd.isnull(x) is True else x)
563+
564+
assert_series_equal(pd_result, bf_result, check_dtype=False)

third_party/bigframes_vendored/pandas/core/strings/accessor.py

+48
Original file line number | Diff line number | Diff line change
@@ -940,6 +940,54 @@ def endswith(
940940
"""
941941
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
942942

943+
def split(
944+
self,
945+
pat: str = " ",
946+
regex: typing.Union[bool, None] = None,
947+
):
948+
"""
949+
Split strings around given separator/delimiter.
950+
951+
**Examples:**
952+
953+
>>> import bigframes.pandas as bpd
954+
>>> import numpy as np
955+
>>> bpd.options.display.progress_bar = None
956+
957+
>>> s = bpd.Series(
958+
... [
959+
... "a regular sentence",
960+
... "https://docs.python.org/index.html",
961+
... np.nan
962+
... ]
963+
... )
964+
>>> s.str.split()
965+
0 ['a' 'regular' 'sentence']
966+
1 ['https://docs.python.org/index.html']
967+
2 []
968+
dtype: list<item: string>[pyarrow]
969+
970+
The pat parameter can be used to split by other characters.
971+
972+
>>> s.str.split("//", regex=False)
973+
0 ['a regular sentence']
974+
1 ['https:' 'docs.python.org/index.html']
975+
2 []
976+
dtype: list<item: string>[pyarrow]
977+
978+
Args:
979+
pat (str, default " "):
980+
String to split on. If not specified, split on a single space character (" ").
981+
regex (bool, default None):
982+
Determines if the passed-in pattern is a regular expression. Regular
983+
expressions aren't currently supported. Please set `regex=False` when
984+
`pat` is longer than one character.
985+
986+
Returns:
987+
bigframes.series.Series: Type matches caller.
988+
"""
989+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
990+
943991
def match(self, pat: str, case: bool = True, flags: int = 0):
944992
"""
945993
Determine if each string starts with a match of a regular expression.

0 commit comments

Comments
 (0)