From 9dbf4d8e13efbc09b9dd0371974054338f27d5f9 Mon Sep 17 00:00:00 2001 From: Henry Lee Date: Fri, 1 May 2026 20:52:53 +0800 Subject: [PATCH 1/2] fix: avoid closing stdio server streams --- src/mcp/server/stdio.py | 37 +++++++++++++++++++++++++++++-------- tests/server/test_stdio.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff..8b31cb845 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,9 +17,10 @@ async def run_server(): ``` """ +import os import sys from contextlib import asynccontextmanager -from io import TextIOWrapper +from io import TextIOWrapper, UnsupportedOperation import anyio import anyio.lowlevel @@ -29,6 +30,17 @@ async def run_server(): from mcp.shared.message import SessionMessage +def _wrap_standard_stream(stream, mode: str, *, errors: str | None = None) -> tuple[anyio.AsyncFile[str], bool]: + """Wrap a standard stream without taking ownership of the original handle.""" + try: + fd = os.dup(stream.fileno()) + except (AttributeError, OSError, UnsupportedOperation): + return anyio.wrap_file(TextIOWrapper(stream.buffer, encoding="utf-8", errors=errors)), False + + binary = os.fdopen(fd, mode, closefd=True) + return anyio.wrap_file(TextIOWrapper(binary, encoding="utf-8", errors=errors)), True + + @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): """Server transport for stdio: this communicates with an MCP client by reading @@ -37,11 +49,14 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # Purposely not using context managers for these, as we don't want to close # standard process handles. Encoding of stdin/stdout as text streams on # python is platform-dependent (Windows is particularly problematic), so we - # re-wrap the underlying binary stream to ensure UTF-8. + # re-wrap duplicate file descriptors to ensure UTF-8 without taking + # ownership of the original standard streams. + close_stdin = False + close_stdout = False if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) + stdin, close_stdin = _wrap_standard_stream(sys.stdin, "rb", errors="replace") if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + stdout, close_stdout = _wrap_standard_stream(sys.stdout, "wb") read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) @@ -71,7 +86,13 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + finally: + if close_stdin: + await stdin.aclose() + if close_stdout: + await stdout.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a99356..63dace778 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,3 +1,4 @@ +import gc import io import sys from io import TextIOWrapper @@ -92,3 +93,30 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): second = await read_stream.receive() assert isinstance(second, SessionMessage) assert second.message == valid + + +@pytest.mark.anyio +async def test_stdio_server_does_not_close_standard_streams(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Default stdio wrapping must not close the process stdin/stdout handles.""" + message = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + stdin_path = tmp_path / "stdin.jsonl" + stdout_path = tmp_path / "stdout.jsonl" + stdin_path.write_text(message.model_dump_json(by_alias=True, exclude_none=True) + "\n", encoding="utf-8") + + with stdin_path.open("r", encoding="utf-8") as fake_stdin, stdout_path.open("w+", encoding="utf-8") as fake_stdout: + monkeypatch.setattr(sys, "stdin", fake_stdin) + monkeypatch.setattr(sys, "stdout", fake_stdout) + + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert received.message == message + + gc.collect() + + assert not fake_stdin.closed + assert not fake_stdout.closed + fake_stdout.write("still open\n") + fake_stdout.flush() From df31fecda4d4b055eb061a802096e7c9228812ff Mon Sep 17 00:00:00 2001 From: Henry Lee Date: Fri, 1 May 2026 21:06:30 +0800 Subject: [PATCH 2/2] test: satisfy stdio pyright checks --- src/mcp/server/stdio.py | 16 +++++++++++++++- tests/server/test_stdio.py | 3 ++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 8b31cb845..0ce395caa 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -19,8 +19,10 @@ async def run_server(): import os import sys +from collections.abc import Callable from contextlib import asynccontextmanager from io import TextIOWrapper, UnsupportedOperation +from typing import BinaryIO, Literal, Protocol import anyio import anyio.lowlevel @@ -30,7 +32,19 @@ async def run_server(): from mcp.shared.message import SessionMessage -def _wrap_standard_stream(stream, mode: str, *, errors: str | None = None) -> tuple[anyio.AsyncFile[str], bool]: +class _TextStreamWithBuffer(Protocol): + @property + def buffer(self) -> BinaryIO: ... + + fileno: Callable[[], int] + + +def _wrap_standard_stream( + stream: _TextStreamWithBuffer, + mode: Literal["rb", "wb"], + *, + errors: str | None = None, +) -> tuple[anyio.AsyncFile[str], bool]: """Wrap a standard stream without taking ownership of the original handle.""" try: fd = os.dup(stream.fileno()) diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 63dace778..d4ce860ec 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -2,6 +2,7 @@ import io import sys from io import TextIOWrapper +from pathlib import Path import anyio import pytest @@ -96,7 +97,7 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): @pytest.mark.anyio -async def test_stdio_server_does_not_close_standard_streams(monkeypatch: pytest.MonkeyPatch, tmp_path): +async def test_stdio_server_does_not_close_standard_streams(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): """Default stdio wrapping must not close the process stdin/stdout handles.""" message = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") stdin_path = tmp_path / "stdin.jsonl"