Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer`

`Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it.

### `call_tool` can return `InputRequiredResult` (opt-in)

For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`.

### `McpError` renamed to `MCPError`

The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK.
Expand Down
54 changes: 51 additions & 3 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Awaitable, Callable, Mapping
from contextlib import AsyncExitStack
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Literal, TypeVar
from typing import Any, Literal, TypeVar, overload

import anyio
from typing_extensions import deprecated
Expand All @@ -30,6 +30,8 @@
EmptyResult,
GetPromptResult,
Implementation,
InputRequiredResult,
InputResponses,
ListPromptsResult,
ListResourcesResult,
ListResourceTemplatesResult,
Expand Down Expand Up @@ -374,33 +376,79 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
"""Unsubscribe from resource updates."""
return await self.session.unsubscribe_resource(uri, meta=meta)

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
) -> CallToolResult:
allow_input_required: Literal[False] = False,
) -> CallToolResult: ...

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
allow_input_required: bool,
) -> CallToolResult | InputRequiredResult: ...

async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
allow_input_required: bool = False,
) -> CallToolResult | InputRequiredResult:
"""Call a tool on the server.

Args:
name: The name of the tool to call
arguments: Arguments to pass to the tool
read_timeout_seconds: Timeout for the tool call
progress_callback: Callback for progress updates
input_responses: Responses to a prior `InputRequiredResult.input_requests`
request_state: Opaque state echoed from a prior `InputRequiredResult`
meta: Additional metadata for the request
allow_input_required: When ``False`` (default), an `InputRequiredResult`
from the server raises `RuntimeError`; when ``True``, it is returned
so the caller can resolve the requests and retry.

Returns:
The tool result.
The tool result. When ``allow_input_required=True``, may instead be an
`InputRequiredResult` carrying the server's input requests and opaque
``request_state`` for the retry.

Raises:
RuntimeError: If the server returns an `InputRequiredResult` and
``allow_input_required`` is ``False``.
"""
# TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6).
return await self.session.call_tool(
name=name,
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
input_responses=input_responses,
request_state=request_state,
meta=meta,
allow_input_required=allow_input_required,
)

async def list_prompts(
Expand Down
74 changes: 67 additions & 7 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Protocol, cast
from typing import Any, Literal, Protocol, cast, overload

import anyio
import anyio.abc
Expand Down Expand Up @@ -173,6 +173,10 @@ async def _default_logging_callback(

ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)

_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter(
types.CallToolResult | types.InputRequiredResult
)


class ClientSession:
"""Client half of an MCP connection, running on a `Dispatcher`.
Expand Down Expand Up @@ -269,7 +273,7 @@ async def __aexit__(
async def send_request(
self,
request: types.ClientRequest,
result_type: type[ReceiveResultT],
result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT],
request_read_timeout_seconds: float | None = None,
metadata: ClientMessageMetadata | None = None,
progress_callback: ProgressFnT | None = None,
Expand Down Expand Up @@ -308,6 +312,8 @@ async def send_request(
_methods.validate_server_result(method, version, raw)
except KeyError:
pass
if isinstance(result_type, TypeAdapter):
return result_type.validate_python(raw, by_name=False)
return result_type.model_validate(raw, by_name=False)
Comment thread
claude[bot] marked this conversation as resolved.

async def send_notification(self, notification: types.ClientNotification) -> None:
Expand Down Expand Up @@ -596,29 +602,83 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None
types.EmptyResult,
)

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
allow_input_required: Literal[False] = False,
) -> types.CallToolResult: ...

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
allow_input_required: bool,
) -> types.CallToolResult | types.InputRequiredResult: ...

async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: RequestParamsMeta | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
allow_input_required: bool = False,
) -> types.CallToolResult | types.InputRequiredResult:
"""Send a tools/call request with optional progress callback support.

Args:
input_responses: Responses to a prior `InputRequiredResult.input_requests`.
request_state: Opaque state echoed from a prior `InputRequiredResult`.
allow_input_required: When ``False`` (default), an `InputRequiredResult`
from the server raises `RuntimeError`; when ``True``, it is returned
so the caller can resolve the requests and retry.

Raises:
RuntimeError: If the server returns an `InputRequiredResult` and
``allow_input_required`` is ``False``.
"""

result = await self.send_request(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta),
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
input_responses=input_responses,
request_state=request_state,
_meta=meta,
),
),
types.CallToolResult,
_CallToolResultAdapter,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)

if not result.is_error:
if isinstance(result, types.CallToolResult) and not result.is_error:
await self._validate_tool_result(name, result)

if isinstance(result, types.InputRequiredResult) and not allow_input_required:
raise RuntimeError(
"Server returned InputRequiredResult; pass allow_input_required=True to receive it "
"and retry call_tool(..., input_responses=..., request_state=result.request_state)."
)
return result

async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
Expand Down
45 changes: 42 additions & 3 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from types import TracebackType
from typing import Any, TypeAlias
from typing import Any, Literal, TypeAlias, overload

import anyio
import httpx
Expand Down Expand Up @@ -190,24 +190,63 @@ def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: types.RequestParamsMeta | None = None,
) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
allow_input_required: Literal[False] = False,
) -> types.CallToolResult: ...

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: types.RequestParamsMeta | None = None,
allow_input_required: bool,
) -> types.CallToolResult | types.InputRequiredResult: ...

async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: float | None = None,
progress_callback: ProgressFnT | None = None,
*,
input_responses: types.InputResponses | None = None,
request_state: str | None = None,
meta: types.RequestParamsMeta | None = None,
allow_input_required: bool = False,
) -> types.CallToolResult | types.InputRequiredResult:
"""Executes a tool given its name and arguments.

Raises:
RuntimeError: If the server returns an `InputRequiredResult` and
``allow_input_required`` is ``False``.
"""
session = self._tool_to_session[name]
session_tool_name = self.tools[name].name
return await session.call_tool(
session_tool_name,
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
input_responses=input_responses,
request_state=request_state,
meta=meta,
allow_input_required=allow_input_required,
)

async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
Expand Down
58 changes: 58 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from mcp import MCPError, types
from mcp.client import ClientRequestContext
from mcp.client.client import Client
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest
from mcp.shared.message import SessionMessage
Expand Down Expand Up @@ -1656,3 +1658,59 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data()
await session.discover()
assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION
assert [m for m, _ in dispatcher.calls] == ["server/discover"]


@pytest.mark.anyio
async def test_call_tool_returns_input_required_result_when_server_requests_input() -> None:
# `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later);
# `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly.
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
return types.InputRequiredResult(request_state="s")

server = Server("test")
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
with anyio.fail_after(5):
async with Client(server, mode="2026-07-28") as client:
result = await client.call_tool("ask", allow_input_required=True)
assert isinstance(result, types.InputRequiredResult)
assert result.request_state == "s"


@pytest.mark.anyio
async def test_call_tool_threads_input_responses_and_request_state_into_params() -> None:
captured: list[types.CallToolRequestParams] = []

async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
captured.append(params)
return CallToolResult(content=[])

async def on_list_tools(
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
) -> types.ListToolsResult:
return types.ListToolsResult(tools=[])

server = Server("test", on_call_tool=on_call_tool, on_list_tools=on_list_tools)
with anyio.fail_after(5):
async with Client(server, mode="2026-07-28") as client:
await client.call_tool(
"ask",
input_responses={"k": types.ElicitResult(action="decline")},
request_state="s",
)
assert captured[0].input_responses == {"k": types.ElicitResult(action="decline")}
assert captured[0].request_state == "s"


@pytest.mark.anyio
async def test_client_call_tool_raises_on_input_required_without_opt_in() -> None:
async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult:
return types.InputRequiredResult(request_state="s")

server = Server("test")
server.add_request_handler("tools/call", types.CallToolRequestParams, handler)
with anyio.fail_after(5):
async with Client(server, mode="2026-07-28") as client:
with pytest.raises(RuntimeError):
await client.call_tool("t")
result = await client.call_tool("t", allow_input_required=True)
assert isinstance(result, types.InputRequiredResult)
Loading
Loading