Skip to content

BUG: Prevent abuse of kwargs in stat functions #12318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.18.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,8 @@ Other API Changes

- As part of the new API for :ref:`window functions <whatsnew_0180.enhancements.moments>` and :ref:`resampling <whatsnew_0180.breaking.resample>`, aggregation functions have been clarified, raising more informative error messages on invalid aggregations (:issue:`9052`). A full set of examples is presented in :ref:`groupby <groupby.aggregation>`.

- Statistical functions for ``NDFrame`` objects will now raise a ``TypeError`` if non-numpy-compatible arguments are passed in for ``**kwargs`` (:issue:`12301`)

.. _whatsnew_0180.deprecations:

Deprecations
Expand Down
20 changes: 20 additions & 0 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5207,12 +5207,29 @@ def _doc_parms(cls):
%(outname)s : %(name1)s\n"""


def _validate_kwargs(fname, kwargs, *compat_args):
"""
Checks whether parameters passed to the
**kwargs argument in a 'stat' function 'fname'
are valid parameters as specified in *compat_args

"""
list(map(kwargs.__delitem__, filter(
kwargs.__contains__, compat_args)))
if kwargs:
bad_arg = list(kwargs)[0] # first 'key' element
raise TypeError(("{fname}() got an unexpected "
"keyword argument '{arg}'".
format(fname=fname, arg=bad_arg)))


def _make_stat_function(name, name1, name2, axis_descr, desc, f):
@Substitution(outname=name, desc=desc, name1=name1, name2=name2,
axis_descr=axis_descr)
@Appender(_num_doc)
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
**kwargs):
_validate_kwargs(name, kwargs, 'out', 'dtype')
if skipna is None:
skipna = True
if axis is None:
Expand All @@ -5233,6 +5250,7 @@ def _make_stat_function_ddof(name, name1, name2, axis_descr, desc, f):
@Appender(_num_ddof_doc)
def stat_func(self, axis=None, skipna=None, level=None, ddof=1,
numeric_only=None, **kwargs):
_validate_kwargs(name, kwargs, 'out', 'dtype')
if skipna is None:
skipna = True
if axis is None:
Expand All @@ -5254,6 +5272,7 @@ def _make_cum_function(name, name1, name2, axis_descr, desc, accum_func,
@Appender("Return cumulative {0} over requested axis.".format(name) +
_cnum_doc)
def func(self, axis=None, dtype=None, out=None, skipna=True, **kwargs):
_validate_kwargs(name, kwargs, 'out', 'dtype')
if axis is None:
axis = self._stat_axis_number
else:
Expand Down Expand Up @@ -5288,6 +5307,7 @@ def _make_logical_function(name, name1, name2, axis_descr, desc, f):
@Appender(_bool_doc)
def logical_func(self, axis=None, bool_only=None, skipna=None, level=None,
**kwargs):
_validate_kwargs(name, kwargs, 'out', 'dtype')
if skipna is None:
skipna = True
if axis is None:
Expand Down
20 changes: 17 additions & 3 deletions pandas/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

from pandas.compat import range, zip
from pandas import compat
from pandas.util.testing import (assert_series_equal,
from pandas.util.testing import (assertRaisesRegexp,
assert_series_equal,
assert_frame_equal,
assert_panel_equal,
assert_panel4d_equal,
assert_almost_equal,
assert_equal)

import pandas.util.testing as tm


Expand Down Expand Up @@ -483,8 +485,6 @@ def test_split_compat(self):
self.assertTrue(len(np.array_split(o, 2)) == 2)

def test_unexpected_keyword(self): # GH8597
from pandas.util.testing import assertRaisesRegexp

df = DataFrame(np.random.randn(5, 2), columns=['jim', 'joe'])
ca = pd.Categorical([0, 0, 2, 2, 3, np.nan])
ts = df['joe'].copy()
Expand All @@ -502,6 +502,20 @@ def test_unexpected_keyword(self): # GH8597
with assertRaisesRegexp(TypeError, 'unexpected keyword'):
ts.fillna(0, in_place=True)

# See gh-12301
def test_stat_unexpected_keyword(self):
    # An unknown keyword must be rejected with a TypeError by each family
    # of statistical methods, one representative per factory function.
    obj = self._construct(5)
    starwars = 'Star Wars'

    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.max(epic=starwars)  # stat_function
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.sum(epic=starwars)  # stat_function
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.var(epic=starwars)  # stat_function_ddof
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        # ``sum`` is a plain reduction; ``cumsum`` is what actually goes
        # through the cumulative-function path.
        obj.cumsum(epic=starwars)  # cum_function
    with assertRaisesRegexp(TypeError, 'unexpected keyword'):
        obj.any(epic=starwars)  # logical_function


class TestSeries(tm.TestCase, Generic):
_typ = Series
Expand Down