diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index b42581df..2f8a0a8b 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -48,6 +48,7 @@ def __init__( uds: str | None = None, network_backend: AsyncNetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + h2_ping_interval: float | None = None, ) -> None: self._origin = origin self._ssl_context = ssl_context @@ -57,6 +58,7 @@ def __init__( self._retries = retries self._local_address = local_address self._uds = uds + self._h2_ping_interval = h2_ping_interval self._network_backend: AsyncNetworkBackend = ( AutoBackend() if network_backend is None else network_backend @@ -89,6 +91,7 @@ async def handle_async_request(self, request: Request) -> Response: origin=self._origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + h2_ping_interval=self._h2_ping_interval, ) else: self._connection = AsyncHTTP11Connection( diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 5ef74e64..c1d10604 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -59,6 +59,7 @@ def __init__( uds: str | None = None, network_backend: AsyncNetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + h2_ping_interval: float | None = None, ) -> None: """ A connection pool for making HTTP requests. @@ -88,6 +89,10 @@ def __init__( network_backend: A backend instance to use for handling network I/O. socket_options: Socket options that have to be included in the TCP socket when the connection was established. + h2_ping_interval: Interval in seconds between HTTP/2 PING frames + sent to keep connections alive. Set to ``None`` to disable. + Falls back to the ``HTTPCORE_H2_PING_INTERVAL`` environment + variable if not specified. """ self._ssl_context = ssl_context self._proxy = proxy @@ -114,6 +119,7 @@ def __init__( AutoBackend() if network_backend is None else network_backend ) self._socket_options = socket_options + self._h2_ping_interval = h2_ping_interval # The mutable state on a connection pool is the queue of incoming requests, # and the set of connections that are servicing those requests. @@ -176,6 +182,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface: uds=self._uds, network_backend=self._network_backend, socket_options=self._socket_options, + h2_ping_interval=self._h2_ping_interval, ) @property diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index dbd0beeb..7adb0e77 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -2,6 +2,8 @@ import enum import logging +import os +import threading import time import types import typing @@ -48,6 +50,7 @@ def __init__( origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: float | None = None, + h2_ping_interval: float | None = None, ): self._origin = origin self._network_stream = stream @@ -64,6 +67,15 @@ def __init__( self._used_all_stream_ids = False self._connection_error = False + if h2_ping_interval is not None: + self._h2_ping_interval: float | None = h2_ping_interval + else: + env_val = os.environ.get("HTTPCORE_H2_PING_INTERVAL") + self._h2_ping_interval = float(env_val) if env_val else None + self._ping_thread: threading.Thread | None = None + self._ping_stop = threading.Event() + self._ping_write_lock = threading.Lock() + # Mapping from stream ID to response stream events. self._events: dict[ int, @@ -217,6 +229,48 @@ async def _send_connection_init(self, request: Request) -> None: self._h2_state.increment_flow_control_window(2**24) await self._write_outgoing_data(request) + if self._h2_ping_interval is not None: + self._start_ping_keepalive() + + def _start_ping_keepalive(self) -> None: + self._ping_stop.clear() + self._ping_thread = threading.Thread( + target=self._ping_keepalive_loop, daemon=True + ) + self._ping_thread.start() + logger.debug( + "HTTP/2 PING keepalive started (interval=%.0fs)", self._h2_ping_interval + ) + + def _ping_keepalive_loop(self) -> None: + """Background thread that sends periodic PING frames via the raw socket.""" + assert self._h2_ping_interval is not None + + raw_sock = self._network_stream.get_extra_info("socket") + if raw_sock is None: + raw_sock = self._network_stream.get_extra_info("ssl_object") + if raw_sock is None: # pragma: nocover + logger.debug("HTTP/2 PING keepalive: unable to obtain raw socket, stopping") + return + + while not self._ping_stop.wait(self._h2_ping_interval): + try: + if self.is_closed(): # pragma: nocover + break + with self._ping_write_lock: + if self.is_closed(): # pragma: nocover + break + opaque = int(time.monotonic_ns() & 0xFFFFFFFFFFFFFFFF).to_bytes( + 8, "big" + ) + self._h2_state.ping(opaque) + data_to_send = self._h2_state.data_to_send() + if data_to_send: + raw_sock.sendall(data_to_send) + logger.debug("HTTP/2 PING sent") + except Exception: # pragma: nocover + break + # Sending the request... async def _send_request_headers(self, request: Request, stream_id: int) -> None: @@ -424,6 +478,10 @@ async def _response_closed(self, stream_id: int) -> None: async def aclose(self) -> None: # Note that this method unilaterally closes the connection, and does # not have any kind of locking in place around it. + self._ping_stop.set() + if self._ping_thread is not None: + self._ping_thread.join(timeout=2) + self._ping_thread = None self._h2_state.close_connection() self._state = HTTPConnectionState.CLOSED await self._network_stream.aclose() diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 363f8be8..45c884d6 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -48,6 +48,7 @@ def __init__( uds: str | None = None, network_backend: NetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + h2_ping_interval: float | None = None, ) -> None: self._origin = origin self._ssl_context = ssl_context @@ -57,6 +58,7 @@ def __init__( self._retries = retries self._local_address = local_address self._uds = uds + self._h2_ping_interval = h2_ping_interval self._network_backend: NetworkBackend = ( SyncBackend() if network_backend is None else network_backend @@ -89,6 +91,7 @@ def handle_request(self, request: Request) -> Response: origin=self._origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + h2_ping_interval=self._h2_ping_interval, ) else: self._connection = HTTP11Connection( diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 4b26f9c6..3ec4509d 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -59,6 +59,7 @@ def __init__( uds: str | None = None, network_backend: NetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + h2_ping_interval: float | None = None, ) -> None: """ A connection pool for making HTTP requests. @@ -88,6 +89,10 @@ def __init__( network_backend: A backend instance to use for handling network I/O. socket_options: Socket options that have to be included in the TCP socket when the connection was established. + h2_ping_interval: Interval in seconds between HTTP/2 PING frames + sent to keep connections alive. Set to ``None`` to disable. + Falls back to the ``HTTPCORE_H2_PING_INTERVAL`` environment + variable if not specified. """ self._ssl_context = ssl_context self._proxy = proxy @@ -114,6 +119,7 @@ def __init__( SyncBackend() if network_backend is None else network_backend ) self._socket_options = socket_options + self._h2_ping_interval = h2_ping_interval # The mutable state on a connection pool is the queue of incoming requests, # and the set of connections that are servicing those requests. @@ -176,6 +182,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface: uds=self._uds, network_backend=self._network_backend, socket_options=self._socket_options, + h2_ping_interval=self._h2_ping_interval, ) @property diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index ddcc1890..3e8f40d0 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -2,6 +2,8 @@ import enum import logging +import os +import threading import time import types import typing @@ -48,6 +50,7 @@ def __init__( origin: Origin, stream: NetworkStream, keepalive_expiry: float | None = None, + h2_ping_interval: float | None = None, ): self._origin = origin self._network_stream = stream @@ -64,6 +67,15 @@ def __init__( self._used_all_stream_ids = False self._connection_error = False + if h2_ping_interval is not None: + self._h2_ping_interval: float | None = h2_ping_interval + else: + env_val = os.environ.get("HTTPCORE_H2_PING_INTERVAL") + self._h2_ping_interval = float(env_val) if env_val else None + self._ping_thread: threading.Thread | None = None + self._ping_stop = threading.Event() + self._ping_write_lock = threading.Lock() + # Mapping from stream ID to response stream events. self._events: dict[ int, @@ -217,6 +229,48 @@ def _send_connection_init(self, request: Request) -> None: self._h2_state.increment_flow_control_window(2**24) self._write_outgoing_data(request) + if self._h2_ping_interval is not None: + self._start_ping_keepalive() + + def _start_ping_keepalive(self) -> None: + self._ping_stop.clear() + self._ping_thread = threading.Thread( + target=self._ping_keepalive_loop, daemon=True + ) + self._ping_thread.start() + logger.debug( + "HTTP/2 PING keepalive started (interval=%.0fs)", self._h2_ping_interval + ) + + def _ping_keepalive_loop(self) -> None: + """Background thread that sends periodic PING frames via the raw socket.""" + assert self._h2_ping_interval is not None + + raw_sock = self._network_stream.get_extra_info("socket") + if raw_sock is None: + raw_sock = self._network_stream.get_extra_info("ssl_object") + if raw_sock is None: # pragma: nocover + logger.debug("HTTP/2 PING keepalive: unable to obtain raw socket, stopping") + return + + while not self._ping_stop.wait(self._h2_ping_interval): + try: + if self.is_closed(): # pragma: nocover + break + with self._ping_write_lock: + if self.is_closed(): # pragma: nocover + break + opaque = int(time.monotonic_ns() & 0xFFFFFFFFFFFFFFFF).to_bytes( + 8, "big" + ) + self._h2_state.ping(opaque) + data_to_send = self._h2_state.data_to_send() + if data_to_send: + raw_sock.sendall(data_to_send) + logger.debug("HTTP/2 PING sent") + except Exception: # pragma: nocover + break + # Sending the request... def _send_request_headers(self, request: Request, stream_id: int) -> None: @@ -424,6 +478,10 @@ def _response_closed(self, stream_id: int) -> None: def close(self) -> None: # Note that this method unilaterally closes the connection, and does # not have any kind of locking in place around it. + self._ping_stop.set() + if self._ping_thread is not None: + self._ping_thread.join(timeout=2) + self._ping_thread = None self._h2_state.close_connection() self._state = HTTPConnectionState.CLOSED self._network_stream.close() diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index b4ec6648..7ba93991 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -1,3 +1,8 @@ +import os +import time +import typing +from unittest.mock import patch + import hpack import hyperframe.frame import pytest @@ -380,3 +385,199 @@ async def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + +@pytest.mark.anyio +async def test_http2_ping_keepalive_thread_lifecycle(): + """ + When h2_ping_interval is set, a background PING thread should be started + after the connection is initialized and stopped when the connection is closed. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + conn = httpcore.AsyncHTTP2Connection( + origin=origin, stream=stream, h2_ping_interval=10.0 + ) + assert conn._h2_ping_interval == 10.0 + + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + assert conn._ping_thread is not None + assert conn._ping_thread.is_alive() + + await conn.aclose() + + assert conn._ping_thread is None or not conn._ping_thread.is_alive() + + +@pytest.mark.anyio +async def test_http2_no_ping_keepalive_by_default(): + """ + When h2_ping_interval is not set and the env var is absent, no PING thread + should be started. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("HTTPCORE_H2_PING_INTERVAL", None) + async with httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert conn._h2_ping_interval is None + assert conn._ping_thread is None + + +@pytest.mark.anyio +async def test_http2_ping_keepalive_env_var(): + """ + The HTTPCORE_H2_PING_INTERVAL environment variable should enable PING + keepalive when the constructor argument is not provided. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with patch.dict(os.environ, {"HTTPCORE_H2_PING_INTERVAL": "30"}): + conn = httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) + assert conn._h2_ping_interval == 30.0 + + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert conn._ping_thread is not None + + await conn.aclose() + + +@pytest.mark.anyio +async def test_http2_ping_keepalive_constructor_overrides_env(): + """ + An explicit h2_ping_interval constructor argument should take precedence + over the HTTPCORE_H2_PING_INTERVAL environment variable. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream([]) + + with patch.dict(os.environ, {"HTTPCORE_H2_PING_INTERVAL": "30"}): + conn = httpcore.AsyncHTTP2Connection( + origin=origin, stream=stream, h2_ping_interval=45.0 + ) + assert conn._h2_ping_interval == 45.0 + await conn.aclose() + + +@pytest.mark.anyio +async def test_http2_ping_keepalive_sends_ping_frames(): + """ + Verify that the PING keepalive loop actually generates PING frames + on the h2 state machine. + """ + written_data: typing.List[bytes] = [] + + class RecordingMockStream(httpcore.AsyncMockStream): + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + written_data.append(buffer) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "socket": + return RecordingSocket() + return super().get_extra_info(info) # pragma: nocover + + class RecordingSocket: + """Fake socket that records sendall calls for the async PING thread.""" + + def sendall(self, data: bytes) -> None: + written_data.append(data) + + def selected_alpn_protocol(self) -> str: # pragma: nocover + return "h2" + + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = RecordingMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + conn = httpcore.AsyncHTTP2Connection( + origin=origin, + stream=stream, + h2_ping_interval=0.1, + ) + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + # Wait for at least one PING to be sent + time.sleep(0.3) + + await conn.aclose() + + # Look for PING frames (type 0x06) in the written data + ping_frame_type = b"\x06" + ping_found = any(ping_frame_type in data for data in written_data) + assert ping_found, "Expected at least one PING frame to be written" diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 695359bd..485d521c 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -1,3 +1,8 @@ +import os +import time +import typing +from unittest.mock import patch + import hpack import hyperframe.frame import pytest @@ -380,3 +385,199 @@ def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + + +def test_http2_ping_keepalive_thread_lifecycle(): + """ + When h2_ping_interval is set, a background PING thread should be started + after the connection is initialized and stopped when the connection is closed. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + conn = httpcore.HTTP2Connection( + origin=origin, stream=stream, h2_ping_interval=10.0 + ) + assert conn._h2_ping_interval == 10.0 + + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + assert conn._ping_thread is not None + assert conn._ping_thread.is_alive() + + conn.close() + + assert conn._ping_thread is None or not conn._ping_thread.is_alive() + + + +def test_http2_no_ping_keepalive_by_default(): + """ + When h2_ping_interval is not set and the env var is absent, no PING thread + should be started. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("HTTPCORE_H2_PING_INTERVAL", None) + with httpcore.HTTP2Connection(origin=origin, stream=stream) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + assert conn._h2_ping_interval is None + assert conn._ping_thread is None + + + +def test_http2_ping_keepalive_env_var(): + """ + The HTTPCORE_H2_PING_INTERVAL environment variable should enable PING + keepalive when the constructor argument is not provided. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with patch.dict(os.environ, {"HTTPCORE_H2_PING_INTERVAL": "30"}): + conn = httpcore.HTTP2Connection(origin=origin, stream=stream) + assert conn._h2_ping_interval == 30.0 + + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert conn._ping_thread is not None + + conn.close() + + + +def test_http2_ping_keepalive_constructor_overrides_env(): + """ + An explicit h2_ping_interval constructor argument should take precedence + over the HTTPCORE_H2_PING_INTERVAL environment variable. + """ + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream([]) + + with patch.dict(os.environ, {"HTTPCORE_H2_PING_INTERVAL": "30"}): + conn = httpcore.HTTP2Connection( + origin=origin, stream=stream, h2_ping_interval=45.0 + ) + assert conn._h2_ping_interval == 45.0 + conn.close() + + + +def test_http2_ping_keepalive_sends_ping_frames(): + """ + Verify that the PING keepalive loop actually generates PING frames + on the h2 state machine. + """ + written_data: typing.List[bytes] = [] + + class RecordingMockStream(httpcore.MockStream): + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + written_data.append(buffer) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "socket": + return RecordingSocket() + return super().get_extra_info(info) # pragma: nocover + + class RecordingSocket: + """Fake socket that records sendall calls for the async PING thread.""" + + def sendall(self, data: bytes) -> None: + written_data.append(data) + + def selected_alpn_protocol(self) -> str: # pragma: nocover + return "h2" + + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = RecordingMockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + conn = httpcore.HTTP2Connection( + origin=origin, + stream=stream, + h2_ping_interval=0.1, + ) + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + + # Wait for at least one PING to be sent + time.sleep(0.3) + + conn.close() + + # Look for PING frames (type 0x06) in the written data + ping_frame_type = b"\x06" + ping_found = any(ping_frame_type in data for data in written_data) + assert ping_found, "Expected at least one PING frame to be written"