diff --git a/src/wsproto/frame_protocol.py b/src/wsproto/frame_protocol.py index 288276d..4b44d07 100644 --- a/src/wsproto/frame_protocol.py +++ b/src/wsproto/frame_protocol.py @@ -13,6 +13,8 @@ from enum import IntEnum from typing import TYPE_CHECKING, NamedTuple +from .utilities import LocalProtocolError + if TYPE_CHECKING: from collections.abc import Generator @@ -588,13 +590,14 @@ def received_frames(self) -> Generator[Frame, None, None]: def close(self, code: int | None = None, reason: str | None = None) -> bytearray: payload = bytearray() - if code is CloseReason.NO_STATUS_RCVD: + if code == CloseReason.NO_STATUS_RCVD: code = None if code is None and reason: msg = "cannot specify a reason without a code" raise TypeError(msg) if code in LOCAL_ONLY_CLOSE_REASONS: - code = CloseReason.NORMAL_CLOSURE + msg = f"cannot send a close frame with local-only code {code}" + raise LocalProtocolError(msg) if code is not None: payload += bytearray(struct.pack("!H", code)) if reason is not None: diff --git a/tests/test_connection.py b/tests/test_connection.py index 92921dc..30a1893 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -90,6 +90,12 @@ def test_close_whilst_closing() -> None: client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE)) +def test_local_only_close_reason_rejected() -> None: + client = Connection(CLIENT) + with pytest.raises(LocalProtocolError): + client.send(CloseConnection(code=1006)) + + def test_send_after_close() -> None: client = Connection(CLIENT) client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE)) diff --git a/tests/test_frame_protocol.py b/tests/test_frame_protocol.py index 5443348..2c42481 100644 --- a/tests/test_frame_protocol.py +++ b/tests/test_frame_protocol.py @@ -10,6 +10,7 @@ from wsproto import extensions as wpext from wsproto import frame_protocol as fp +from wsproto.utilities import LocalProtocolError class TestBuffer: @@ -1047,10 +1048,14 @@ def test_no_status_rcvd_close_reason(self) -> None: data = proto.close(code=fp.CloseReason.NO_STATUS_RCVD) assert data == b"\x88\x00" - def test_local_only_close_reason(self) -> None: + @pytest.mark.parametrize( + "code", + [fp.CloseReason.ABNORMAL_CLOSURE, fp.CloseReason.TLS_HANDSHAKE_FAILED], + ) + def test_local_only_close_reason(self, code: fp.CloseReason) -> None: proto = fp.FrameProtocol(client=False, extensions=[]) - data = proto.close(code=fp.CloseReason.ABNORMAL_CLOSURE) - assert data == b"\x88\x02\x03\xe8" + with pytest.raises(LocalProtocolError): + proto.close(code=code) def test_ping_without_payload(self) -> None: proto = fp.FrameProtocol(client=False, extensions=[])