diff --git a/docs/migration.md b/docs/migration.md index 46ec205ee..1c965562d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -364,6 +364,10 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` `Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. +### `call_tool` can return `InputRequiredResult` (opt-in) + +For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 1ab8209b1..362042ba5 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable, Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeVar, overload import anyio from typing_extensions import deprecated @@ -30,6 +30,8 @@ EmptyResult, GetPromptResult, Implementation, + InputRequiredResult, + InputResponses, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -374,6 +376,7 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None """Unsubscribe from resource updates.""" return await self.session.unsubscribe_resource(uri, meta=meta) + @overload async def call_tool( self, name: str, @@ -381,8 +384,38 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: InputResponses | None = None, + request_state: str | None = None, meta: RequestParamsMeta | None = None, - ) -> CallToolResult: + allow_input_required: Literal[False] = False, + ) -> CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool, + ) -> CallToolResult | InputRequiredResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool = False, + ) -> CallToolResult | InputRequiredResult: """Call a tool on the server. Args: @@ -390,17 +423,32 @@ async def call_tool( arguments: Arguments to pass to the tool read_timeout_seconds: Timeout for the tool call progress_callback: Callback for progress updates + input_responses: Responses to a prior `InputRequiredResult.input_requests` + request_state: Opaque state echoed from a prior `InputRequiredResult` meta: Additional metadata for the request + allow_input_required: When ``False`` (default), an `InputRequiredResult` + from the server raises `RuntimeError`; when ``True``, it is returned + so the caller can resolve the requests and retry. Returns: - The tool result. + The tool result. When ``allow_input_required=True``, may instead be an + `InputRequiredResult` carrying the server's input requests and opaque + ``request_state`` for the retry. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. """ + # TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6). return await self.session.call_tool( name=name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + input_responses=input_responses, + request_state=request_state, meta=meta, + allow_input_required=allow_input_required, ) async def list_prompts( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8ac3e2288..902cec80b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass from types import TracebackType -from typing import Any, Protocol, cast +from typing import Any, Literal, Protocol, cast, overload import anyio import anyio.abc @@ -173,6 +173,10 @@ async def _default_logging_callback( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) +_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter( + types.CallToolResult | types.InputRequiredResult +) + class ClientSession: """Client half of an MCP connection, running on a `Dispatcher`. @@ -269,7 +273,7 @@ async def __aexit__( async def send_request( self, request: types.ClientRequest, - result_type: type[ReceiveResultT], + result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT], request_read_timeout_seconds: float | None = None, metadata: ClientMessageMetadata | None = None, progress_callback: ProgressFnT | None = None, @@ -308,6 +312,8 @@ async def send_request( _methods.validate_server_result(method, version, raw) except KeyError: pass + if isinstance(result_type, TypeAdapter): + return result_type.validate_python(raw, by_name=False) return result_type.model_validate(raw, by_name=False) async def send_notification(self, notification: types.ClientNotification) -> None: @@ -596,6 +602,34 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None types.EmptyResult, ) + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: Literal[False] = False, + ) -> types.CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool, + ) -> types.CallToolResult | types.InputRequiredResult: ... + async def call_tool( self, name: str, @@ -603,22 +637,48 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, meta: RequestParamsMeta | None = None, - ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + allow_input_required: bool = False, + ) -> types.CallToolResult | types.InputRequiredResult: + """Send a tools/call request with optional progress callback support. + + Args: + input_responses: Responses to a prior `InputRequiredResult.input_requests`. + request_state: Opaque state echoed from a prior `InputRequiredResult`. + allow_input_required: When ``False`` (default), an `InputRequiredResult` + from the server raises `RuntimeError`; when ``True``, it is returned + so the caller can resolve the requests and retry. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. + """ result = await self.send_request( types.CallToolRequest( - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta), + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + input_responses=input_responses, + request_state=request_state, + _meta=meta, + ), ), - types.CallToolResult, + _CallToolResultAdapter, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, ) - if not result.is_error: + if isinstance(result, types.CallToolResult) and not result.is_error: await self._validate_tool_result(name, result) + if isinstance(result, types.InputRequiredResult) and not allow_input_required: + raise RuntimeError( + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." + ) return result async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..211733d6a 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, Literal, TypeAlias, overload import anyio import httpx @@ -190,6 +190,7 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools + @overload async def call_tool( self, name: str, @@ -197,9 +198,44 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, meta: types.RequestParamsMeta | None = None, - ) -> types.CallToolResult: - """Executes a tool given its name and arguments.""" + allow_input_required: Literal[False] = False, + ) -> types.CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: types.RequestParamsMeta | None = None, + allow_input_required: bool, + ) -> types.CallToolResult | types.InputRequiredResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: types.RequestParamsMeta | None = None, + allow_input_required: bool = False, + ) -> types.CallToolResult | types.InputRequiredResult: + """Executes a tool given its name and arguments. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. + """ session = self._tool_to_session[name] session_tool_name = self.tools[name].name return await session.call_tool( @@ -207,7 +243,10 @@ async def call_tool( arguments=arguments, read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + input_responses=input_responses, + request_state=request_state, meta=meta, + allow_input_required=allow_input_required, ) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c24a4569c..f46d6e606 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -12,7 +12,9 @@ from mcp import MCPError, types from mcp.client import ClientRequestContext +from mcp.client.client import Client from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.server import Server, ServerRequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest from mcp.shared.message import SessionMessage @@ -1656,3 +1658,59 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data() await session.discover() assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION assert [m for m, _ in dispatcher.calls] == ["server/discover"] + + +@pytest.mark.anyio +async def test_call_tool_returns_input_required_result_when_server_requests_input() -> None: + # `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later); + # `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly. + async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: + return types.InputRequiredResult(request_state="s") + + server = Server("test") + server.add_request_handler("tools/call", types.CallToolRequestParams, handler) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + result = await client.call_tool("ask", allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) + assert result.request_state == "s" + + +@pytest.mark.anyio +async def test_call_tool_threads_input_responses_and_request_state_into_params() -> None: + captured: list[types.CallToolRequestParams] = [] + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + captured.append(params) + return CallToolResult(content=[]) + + async def on_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) + + server = Server("test", on_call_tool=on_call_tool, on_list_tools=on_list_tools) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + await client.call_tool( + "ask", + input_responses={"k": types.ElicitResult(action="decline")}, + request_state="s", + ) + assert captured[0].input_responses == {"k": types.ElicitResult(action="decline")} + assert captured[0].request_state == "s" + + +@pytest.mark.anyio +async def test_client_call_tool_raises_on_input_required_without_opt_in() -> None: + async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: + return types.InputRequiredResult(request_state="s") + + server = Server("test") + server.add_request_handler("tools/call", types.CallToolRequestParams, handler) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + with pytest.raises(RuntimeError): + await client.call_tool("t") + result = await client.call_tool("t", allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..faa4281e3 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -82,10 +82,27 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov arguments={"name": "value1", "args": {}}, read_timeout_seconds=None, progress_callback=None, + input_responses=None, + request_state=None, meta=None, + allow_input_required=False, ) +@pytest.mark.anyio +async def test_client_session_group_call_tool_forwards_allow_input_required(): + mock_session = mock.AsyncMock() + mcp_session_group = ClientSessionGroup() + mcp_session_group._tools = {"my_tool": types.Tool(name="my_tool", input_schema={})} + mcp_session_group._tool_to_session = {"my_tool": mock_session} + mock_session.call_tool.return_value = types.InputRequiredResult(request_state="s") + + result = await mcp_session_group.call_tool(name="my_tool", arguments={}, allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) + assert result.request_state == "s" + assert mock_session.call_tool.call_args.kwargs["allow_input_required"] is True + + @pytest.mark.anyio async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting to a server and aggregating components."""