Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5cdb753
Reject ParamSpec-typed callables calls with insufficient arguments
sterliakov Jun 4, 2024
3b2297f
Reuse params preprocessing logic for generic functions
sterliakov Jun 4, 2024
a32ad3f
Only perform deep expansion on overloads when ParamSpec is present
sterliakov Jun 4, 2024
63995e3
Tidy up code a bit
sterliakov Jun 10, 2024
be2c49b
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Aug 21, 2024
0a658a7
Always pick ParamSpec-containing overloads as plausible candidates
sterliakov Aug 21, 2024
63f9438
Remove parameters thaat are no longer used
sterliakov Aug 21, 2024
786fb55
Merge branch 'master' into bugfix/st-paramspec-missing-args
sterliakov Sep 12, 2024
1903402
Undo style change to make it easier to review, feel free to add in se…
hauntsaninja Sep 24, 2024
512a722
Support ParamSpec + functools.partial
sterliakov Jun 10, 2024
558a35f
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Sep 25, 2024
01390c0
Fix lost error in test
sterliakov Sep 25, 2024
b23efe8
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 17, 2024
bf60ec1
Remove param_spec_args_bound - new version produces worse error messa…
sterliakov Oct 17, 2024
a660f1f
Add test scenario
sterliakov Oct 17, 2024
5d46d85
Fix typing, use `immutable` for storing binding information
sterliakov Oct 17, 2024
bbaf9de
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 24, 2024
4d55122
Fix duplicated testcase
sterliakov Oct 24, 2024
b077063
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 25, 2024
aa1391c
Replace naive check by arg kinds with more robust type-aware check
sterliakov Oct 25, 2024
31a7492
Deduplicate isinstance() checks
sterliakov Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Reuse params preprocessing logic for generic functions
  • Loading branch information
sterliakov committed Jun 4, 2024
commit 3b2297f1b370549695dae895045282831b901a64
74 changes: 47 additions & 27 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,33 +1716,9 @@ def check_callable_call(
callee = callee.copy_modified(ret_type=fresh_ret_type)

if callee.is_generic():
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
callee, formal_to_actual = self.adjust_generic_callable_params_mapping(
callee, args, arg_kinds, arg_names, formal_to_actual, context
)
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
if need_refresh:
# Argument kinds etc. may have changed due to
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
# number of arguments; recalculate actual-to-formal map
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee.arg_kinds,
callee.arg_names,
lambda i: self.accept(args[i]),
)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
)
if need_refresh:
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee.arg_kinds,
callee.arg_names,
lambda i: self.accept(args[i]),
)

param_spec = callee.param_spec()
if (
Expand Down Expand Up @@ -2633,7 +2609,7 @@ def check_overload_call(
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
arg_types, arg_kinds, arg_names, callee
args, arg_types, arg_kinds, arg_names, callee, context
)

# Step 2: If the arguments contain a union, we try performing union math first,
Expand Down Expand Up @@ -2751,12 +2727,52 @@ def check_overload_call(
self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context)
return result

def adjust_generic_callable_params_mapping(
self,
callee: CallableType,
args: list[Expression],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: Context,
) -> tuple[CallableType, list[list[int]]]:
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
)
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
if need_refresh:
# Argument kinds etc. may have changed due to
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
# number of arguments; recalculate actual-to-formal map
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee.arg_kinds,
callee.arg_names,
lambda i: self.accept(args[i]),
)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
)
if need_refresh:
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee.arg_kinds,
callee.arg_names,
lambda i: self.accept(args[i]),
)
return callee, formal_to_actual

def plausible_overload_call_targets(
self,
args: list[Expression],
arg_types: list[Type],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
overload: Overloaded,
context: Context,
) -> list[CallableType]:
"""Returns all overload call targets that having matching argument counts.

Expand Down Expand Up @@ -2790,6 +2806,10 @@ def has_shape(typ: Type) -> bool:
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i]
)
if typ.is_generic():
typ, formal_to_actual = self.adjust_generic_callable_params_mapping(
typ, args, arg_kinds, arg_names, formal_to_actual, context
)

with self.msg.filter_errors():
if self.check_argument_count(
Expand Down
64 changes: 62 additions & 2 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2211,12 +2211,22 @@ from typing import Callable

_P = ParamSpec("_P")

def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None:
def run(predicate: Callable[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
predicate() # E: Too few arguments
predicate(*args) # E: Too few arguments
predicate(**kwargs) # E: Too few arguments
predicate(*args, **kwargs)

def fn() -> None: ...
def fn_args(x: int) -> None: ...
def fn_posonly(x: int, /) -> None: ...

run(fn)
run(fn_args, 1)
run(fn_args, x=1)
run(fn_posonly, 1)
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecConcatenateInsufficientArgs]
Expand All @@ -2225,7 +2235,7 @@ from typing import Callable

_P = ParamSpec("_P")

def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None:
def run(predicate: Callable[Concatenate[int, _P], None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
predicate() # E: Too few arguments
predicate(1) # E: Too few arguments
predicate(1, *args) # E: Too few arguments
Expand All @@ -2234,6 +2244,22 @@ def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs
predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int"
predicate(1, *args, **kwargs)

def fn() -> None: ...
def fn_args(x: int, y: str) -> None: ...
def fn_posonly(x: int, /) -> None: ...
def fn_posonly_args(x: int, /, y: str) -> None: ...

run(fn) # E: Argument 1 to "run" has incompatible type "Callable[[], None]"; expected "Callable[[int], None]"
run(fn_args, 1, 'a') # E: Too many arguments for "run" \
# E: Argument 2 to "run" has incompatible type "int"; expected "str"
run(fn_args, y='a')
run(fn_args, 'a')
run(fn_posonly)
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"
run(fn_posonly_args) # E: Missing positional argument "y" in call to "run"
run(fn_posonly_args, 'a')
run(fn_posonly_args, y='a')

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecConcatenateInsufficientArgsInDecorator]
Expand All @@ -2255,3 +2281,37 @@ def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]:
def foo(s: str, s2: str) -> None: ...

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecOverload]
from typing_extensions import ParamSpec, Concatenate
from typing import Callable, overload, NoReturn, TypeVar, Union

P = ParamSpec("P")
T = TypeVar("T")

@overload
def capture(
sync_fn: Callable[P, NoReturn],
*args: P.args,
**kwargs: P.kwargs,
) -> int: ...
@overload
def capture(
sync_fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Union[T, int]: ...
def capture(
sync_fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Union[T, int]:
return sync_fn(*args, **kwargs)

def fn() -> str: return ''
def err() -> NoReturn: ...

reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.int]"
reveal_type(capture(err)) # N: Revealed type is "builtins.int"

[builtins fixtures/paramspec.pyi]