diff --git a/CHANGELOG.md b/CHANGELOG.md index f35c4827..247626cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- Non-protocol subclasses of `Protocol` now ignore the + `__init__` method inherited from protocol base classes. - Fix setting of `__required_keys__` and `__optional_keys__` when inheriting keys with the same name. - Fix incorrect behaviour on Python 3.9 and Python 3.10 that meant that diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index c7025321..287e522b 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -20,6 +20,7 @@ import typing import warnings from collections import defaultdict +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from unittest import TestCase, main, skipIf, skipUnless @@ -3698,12 +3699,59 @@ def __init__(self, x: T) -> None: def test_init_called(self): T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): def __init__(self): self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') + class B: + def __init__(self): + self.test = 'OK' + + class D1(B, P[T]): + pass + + self.assertEqual(D1[int]().test, 'OK') + + class D2(P[T], B): + pass + + self.assertEqual(D2[int]().test, 'OK') + + def test_super_call_init(self): + class P(Protocol): + x: int + + class Foo(P): + def __init__(self): + super().__init__() + + Foo() # Previously triggered RecursionError + + def test_inherit_from_protocol(self): + # Dataclasses inheriting from protocol should preserve their own `__init__`. + # See bpo-45081. + + class P(Protocol): + a: int + + @dataclass + class C(P): + a: int + + self.assertEqual(C(5).a, 5) + + @dataclass + class D(P): + def __init__(self, a): + self.a = a * 2 + + self.assertEqual(D(5).a, 10) + def test_protocols_bad_subscripts(self): T = TypeVar('T') S = TypeVar('S') diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 05d4522c..c4976556 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -715,10 +715,6 @@ def _allow_reckless_class_checks(depth=2): """ return _caller(depth) in {'abc', 'functools', None} - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') - def _type_check_issubclass_arg_1(arg): """Raise TypeError if `arg` is not an instance of `type` in `issubclass(arg, )`. @@ -882,7 +878,7 @@ def __init_subclass__(cls, *args, **kwargs): # Prohibit instantiation for protocol classes if cls._is_protocol and cls.__init__ is Protocol.__init__: - cls.__init__ = _no_init + cls.__init__ = typing._no_init_or_replace_init # Breakpoint: https://github.com/python/cpython/pull/113401