Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,38 @@
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.infer import infer_type_arguments
from mypy.nodes import (
ARG_POS,
ARG_STAR2,
SYMBOL_FUNCBASE_TYPES,
ArgKind,
Argument,
CallExpr,
Expression,
MemberExpr,
NameExpr,
Var,
)
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
ANY_STRATEGY,
AnyType,
BoolTypeQuery,
CallableType,
Instance,
Overloaded,
ParamSpecFlavor,
ParamSpecType,
Type,
TypeOfAny,
TypeVarId,
TypeVarType,
UnboundType,
UnionType,
UnpackType,
get_proper_type,
)

Expand All @@ -41,6 +49,7 @@
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}

PARTIAL: Final = "functools.partial"
PLACEHOLDER: Final = "functools.Placeholder"


class _MethodInfo:
Expand Down Expand Up @@ -134,6 +143,22 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
return comparison_methods


def _is_functools_placeholder(expr: Expression) -> bool:
return isinstance(expr, (NameExpr, MemberExpr)) and expr.fullname == PLACEHOLDER


class _HasUnpack(BoolTypeQuery):
def __init__(self) -> None:
super().__init__(ANY_STRATEGY)

def visit_unpack_type(self, t: UnpackType) -> bool:
return True


def _has_unpack(typ: Type) -> bool:
return typ.accept(_HasUnpack())


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.partial"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
Expand Down Expand Up @@ -184,6 +209,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
actual_arg_kinds = []
actual_arg_names = []
actual_types = []
placeholder_actuals = []
seen_args = set()
for i, param in enumerate(ctx.args[1:], start=1):
for j, a in enumerate(param):
Expand All @@ -198,6 +224,9 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])
actual_types.append(ctx.arg_types[i][j])
placeholder_actuals.append(
ctx.arg_kinds[i][j].is_positional() and _is_functools_placeholder(a)
)

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
Expand All @@ -215,8 +244,20 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
continue
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

defaulted_arg_types = list(fn_type.arg_types)
for i, actuals in enumerate(formal_to_actual):
if any(placeholder_actuals[j] for j in actuals):
# functools.Placeholder is a positional sentinel introduced in Python 3.14.
# It occupies the formal slot but does not bind it, so make the validation
# call accept the sentinel while preserving the original type for the
# resulting partial signature below.
defaulted_arg_types[i] = actual_types[
next(j for j in actuals if placeholder_actuals[j])
]

# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
defaulted = fn_type.copy_modified(
arg_types=defaulted_arg_types,
arg_kinds=[
(
ArgKind.ARG_OPT
Expand Down Expand Up @@ -273,10 +314,30 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
partial_kinds = []
partial_types = []
partial_names = []
inferred_type_vars: dict[TypeVarId, Type] = {}
if any(placeholder_actuals) and len(bound.arg_types) == len(fn_type.arg_types):
for i, actuals in enumerate(formal_to_actual):
if not actuals or any(placeholder_actuals[j] for j in actuals):
continue
if _has_unpack(fn_type.arg_types[i]) or _has_unpack(bound.arg_types[i]):
# TypeVarTuple/Unpack constraints are handled by check_call() above. Calling
# infer_type_arguments() directly on an UnpackType trips the constraint builder's
# internal "unpack should be handled at a higher level" guard.
continue
inferred_args = infer_type_arguments(
fn_type.variables, fn_type.arg_types[i], bound.arg_types[i]
)
for type_var, inferred_arg in zip(fn_type.variables, inferred_args):
if inferred_arg is not None and mypy.checker.is_valid_inferred_type(
inferred_arg, ctx.api.options
):
inferred_type_vars[type_var.id] = inferred_arg
# We need to fully apply any positional arguments (they cannot be respecified)
# However, keyword arguments can be respecified, so just give them a default
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
if any(placeholder_actuals[j] for j in actuals):
arg_type = expand_type(fn_type.arg_types[i], inferred_type_vars)
elif len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if not mypy.checker.is_valid_inferred_type(arg_type, ctx.api.options):
arg_type = fn_type.arg_types[i] # bit of a hack
Expand All @@ -285,10 +346,16 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
arg_type = fn_type.arg_types[i]

if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
if (
not actuals
or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2)
or any(placeholder_actuals[j] for j in actuals)
):
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
partial_names.append(
None if any(placeholder_actuals[j] for j in actuals) else fn_type.arg_names[i]
)
else:
assert actuals
if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
Expand Down
53 changes: 53 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def bar(a: int, b: str, c: float) -> None: ...
p(bar, 1, "a", 3.0) # OK
p(bar, 1, "a", 3.0, kwarg="asdf") # OK
p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]"
p2 = functools.partial(foo, bar, 1) # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int], None]"
p2("a", 3.0, kwarg="asdf") # E: Argument 1 to "foo" has incompatible type "str"; expected "int" \
# E: Argument 2 to "foo" has incompatible type "float"; expected "int"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialUnion]
Expand Down Expand Up @@ -726,3 +729,53 @@ def outer_c(arg: Tc) -> None:
use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \
# N: "partial[str].__call__" has type "def __call__(__self, *args: Any, **kwargs: Any) -> str"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialPlaceholder]
import functools
from functools import partial, Placeholder as _
from typing import TypeVar

T = TypeVar("T")


def foo(a: int, b: str, c: bool) -> tuple[int, str, bool]: ...


p = partial(foo, _, "x", _)
reveal_type(p) # N: Revealed type is "functools.partial[tuple[builtins.int, builtins.str, builtins.bool]]"
reveal_type(p(1, True)) # N: Revealed type is "tuple[builtins.int, builtins.str, builtins.bool]"
p("bad", True) # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
p(1, 1) # E: Argument 2 to "foo" has incompatible type "int"; expected "bool"
p(a=1, c=True) # E: Unexpected keyword argument "a" for "foo" \
# E: Unexpected keyword argument "c" for "foo"


def same(a: T, b: T) -> T: ...
def same_list(a: T, b: list[T]) -> T: ...


generic = partial(same, _, 1)
reveal_type(generic) # N: Revealed type is "functools.partial[builtins.int]"
generic(2)
generic("bad") # E: Argument 1 to "same" has incompatible type "str"; expected "int"

nested_generic = partial(same_list, _, [1])
reveal_type(nested_generic) # N: Revealed type is "functools.partial[builtins.int]"
nested_generic(2)
nested_generic("bad") # E: Argument 1 to "same_list" has incompatible type "str"; expected "int"

module_attr = partial(foo, functools.Placeholder, "x", functools.Placeholder)
reveal_type(module_attr(1, True)) # N: Revealed type is "tuple[builtins.int, builtins.str, builtins.bool]"
partial(foo, a=_) # E: Argument "a" to "foo" has incompatible type "_PlaceholderType"; expected "int"
[file functools.pyi]
from typing import Any, Callable, Final, Generic, TypeVar

_T = TypeVar("_T")

class _PlaceholderType: ...
Placeholder: Final[_PlaceholderType]

class partial(Generic[_T]):
def __new__(cls, func: Callable[..., _T], /, *args: Any, **kwargs: Any) -> partial[_T]: ...
def __call__(self, *args: Any, **kwargs: Any) -> _T: ...
[builtins fixtures/tuple.pyi]
Loading