Skip to content

Commit 0fae2e0

Browse files
feat: (Series | DataFrame).plot.bar (#1152)
* feat: (Series | DataFrame).plot.bar * add warning message * fix mypy * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent de923d0 commit 0fae2e0

File tree

5 files changed

+121
-10
lines changed

5 files changed

+121
-10
lines changed

bigframes/operations/_matplotlib/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
PLOT_TYPES = typing.Union[type[core.SamplingPlot], type[hist.HistPlot]]
2121

2222
PLOT_CLASSES: dict[str, PLOT_TYPES] = {
23-
"hist": hist.HistPlot,
24-
"line": core.LinePlot,
2523
"area": core.AreaPlot,
24+
"bar": core.BarPlot,
25+
"line": core.LinePlot,
2626
"scatter": core.ScatterPlot,
27+
"hist": hist.HistPlot,
2728
}
2829

2930

bigframes/operations/_matplotlib/core.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import abc
1616
import typing
17+
import warnings
1718

1819
import bigframes_vendored.constants as constants
1920
import pandas as pd
@@ -46,10 +47,15 @@ def result(self):
4647

4748

4849
class SamplingPlot(MPLPlot):
49-
@abc.abstractproperty
50+
@property
51+
@abc.abstractmethod
5052
def _kind(self):
5153
pass
5254

55+
@property
56+
def _sampling_warning_msg(self) -> typing.Optional[str]:
57+
return None
58+
5359
def __init__(self, data, **kwargs) -> None:
5460
self.kwargs = kwargs
5561
self.data = data
@@ -61,6 +67,15 @@ def generate(self) -> None:
6167
def _compute_sample_data(self, data):
6268
# TODO: Cache the sampling data in the PlotAccessor.
6369
sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N)
70+
if self._sampling_warning_msg is not None:
71+
total_n = data.shape[0]
72+
if sampling_n < total_n:
73+
warnings.warn(
74+
self._sampling_warning_msg.format(
75+
sampling_n=sampling_n, total_n=total_n
76+
)
77+
)
78+
6479
sampling_random_state = self.kwargs.pop(
6580
"sampling_random_state", DEFAULT_SAMPLING_STATE
6681
)
@@ -74,18 +89,33 @@ def _compute_plot_data(self):
7489
return self._compute_sample_data(self.data)
7590

7691

77-
class LinePlot(SamplingPlot):
78-
@property
79-
def _kind(self) -> typing.Literal["line"]:
80-
return "line"
81-
82-
8392
class AreaPlot(SamplingPlot):
8493
@property
8594
def _kind(self) -> typing.Literal["area"]:
8695
return "area"
8796

8897

98+
class BarPlot(SamplingPlot):
99+
@property
100+
def _kind(self) -> typing.Literal["bar"]:
101+
return "bar"
102+
103+
@property
104+
def _sampling_warning_msg(self) -> typing.Optional[str]:
105+
return (
106+
"To optimize plotting performance, your data has been downsampled to {sampling_n} "
107+
"rows from the original {total_n} rows. This may result in some data points "
108+
"not being displayed. For a more comprehensive view, consider pre-processing "
109+
"your data by aggregating it or selecting the top categories."
110+
)
111+
112+
113+
class LinePlot(SamplingPlot):
114+
@property
115+
def _kind(self) -> typing.Literal["line"]:
116+
return "line"
117+
118+
89119
class ScatterPlot(SamplingPlot):
90120
@property
91121
def _kind(self) -> typing.Literal["scatter"]:

bigframes/operations/plotting.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
class PlotAccessor(vendordt.PlotAccessor):
2424
__doc__ = vendordt.PlotAccessor.__doc__
2525

26-
_common_kinds = ("line", "area", "hist")
26+
_common_kinds = ("line", "area", "hist", "bar")
2727
_dataframe_kinds = ("scatter",)
2828
_all_kinds = _common_kinds + _dataframe_kinds
2929

@@ -72,6 +72,14 @@ def area(
7272
):
7373
return self(kind="area", x=x, y=y, stacked=stacked, **kwargs)
7474

75+
def bar(
76+
self,
77+
x: typing.Optional[typing.Hashable] = None,
78+
y: typing.Optional[typing.Hashable] = None,
79+
**kwargs,
80+
):
81+
return self(kind="bar", x=x, y=y, **kwargs)
82+
7583
def scatter(
7684
self,
7785
x: typing.Optional[typing.Hashable] = None,

tests/system/small/operations/test_plotting.py

+12
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,18 @@ def test_area(scalars_dfs):
195195
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])
196196

197197

198+
def test_bar(scalars_dfs):
199+
scalars_df, scalars_pandas_df = scalars_dfs
200+
col_names = ["int64_col", "float64_col", "int64_too"]
201+
ax = scalars_df[col_names].plot.bar()
202+
pd_ax = scalars_pandas_df[col_names].plot.bar()
203+
tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks())
204+
tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks())
205+
for line, pd_line in zip(ax.lines, pd_ax.lines):
206+
# Compare y coordinates between the lines
207+
tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1])
208+
209+
198210
def test_scatter(scalars_dfs):
199211
scalars_df, scalars_pandas_df = scalars_dfs
200212
col_names = ["int64_col", "float64_col", "int64_too", "bool_col"]

third_party/bigframes_vendored/pandas/plotting/_core.py

+60
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,66 @@ def area(
215215
"""
216216
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
217217

218+
def bar(
219+
self,
220+
x: typing.Optional[typing.Hashable] = None,
221+
y: typing.Optional[typing.Hashable] = None,
222+
**kwargs,
223+
):
224+
"""
225+
Draw a vertical bar plot.
226+
227+
This function calls `pandas.plot` to generate a plot with a random sample
228+
of items. For consistent results, the random sampling is reproducible.
229+
Use the `sampling_random_state` parameter to modify the sampling seed.
230+
231+
**Examples:**
232+
233+
Basic plot.
234+
235+
>>> import bigframes.pandas as bpd
236+
>>> bpd.options.display.progress_bar = None
237+
>>> df = bpd.DataFrame({'lab':['A', 'B', 'C'], 'val':[10, 30, 20]})
238+
>>> ax = df.plot.bar(x='lab', y='val', rot=0)
239+
240+
Plot a whole dataframe to a bar plot. Each column is assigned a distinct color,
241+
and each row is nested in a group along the horizontal axis.
242+
243+
>>> speed = [0.1, 17.5, 40, 48, 52, 69, 88]
244+
>>> lifespan = [2, 8, 70, 1.5, 25, 12, 28]
245+
>>> index = ['snail', 'pig', 'elephant',
246+
... 'rabbit', 'giraffe', 'coyote', 'horse']
247+
>>> df = bpd.DataFrame({'speed': speed, 'lifespan': lifespan}, index=index)
248+
>>> ax = df.plot.bar(rot=0)
249+
250+
Plot stacked bar charts for the DataFrame.
251+
252+
>>> ax = df.plot.bar(stacked=True)
253+
254+
If you don’t like the default colours, you can specify how you’d like each column
255+
to be colored.
256+
257+
>>> axes = df.plot.bar(
258+
... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"}
259+
... )
260+
261+
Args:
262+
x (label or position, optional):
263+
Allows plotting of one column versus another. If not specified, the index
264+
of the DataFrame is used.
265+
y (label or position, optional):
266+
Allows plotting of one column versus another. If not specified, all numerical
267+
columns are used.
268+
**kwargs:
269+
Additional keyword arguments are documented in
270+
:meth:`DataFrame.plot`.
271+
272+
Returns:
273+
matplotlib.axes.Axes or numpy.ndarray:
274+
Area plot, or array of area plots if subplots is True.
275+
"""
276+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
277+
218278
def scatter(
219279
self,
220280
x: typing.Optional[typing.Hashable] = None,

0 commit comments

Comments
 (0)