Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ name = "test_progress_subscriber"
required-features = ["server", "client", "macros"]
path = "tests/test_progress_subscriber.rs"

[[test]]
name = "test_request_timeout_progress"
required-features = ["server", "client", "macros"]
path = "tests/test_request_timeout_progress.rs"

[[test]]
name = "test_elicitation"
required-features = ["elicitation", "client", "server"]
Expand Down
250 changes: 229 additions & 21 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
#[cfg(feature = "local")]
pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;

#[cfg(feature = "server")]
use crate::model::ClientNotification;
#[cfg(feature = "server")]
use crate::model::ServerJsonRpcMessage;
#[cfg(feature = "client")]
use crate::model::ServerNotification;
use crate::{
error::ErrorData as McpError,
model::{
Expand Down Expand Up @@ -299,7 +303,37 @@ impl ProgressTokenProvider for AtomicU32Provider {
}
}

#[doc(hidden)]
pub trait ProgressNotificationToken {
fn progress_token(&self) -> Option<&ProgressToken>;
}

#[cfg(feature = "server")]
impl ProgressNotificationToken for ClientNotification {
fn progress_token(&self) -> Option<&ProgressToken> {
match self {
ClientNotification::ProgressNotification(notification) => {
Some(&notification.params.progress_token)
}
_ => None,
}
}
}

#[cfg(feature = "client")]
impl ProgressNotificationToken for ServerNotification {
fn progress_token(&self) -> Option<&ProgressToken> {
match self {
ServerNotification::ProgressNotification(notification) => {
Some(&notification.params.progress_token)
}
_ => None,
}
}
}

type Responder<T> = tokio::sync::oneshot::Sender<T>;
type ProgressTimeoutWatchers = Arc<tokio::sync::RwLock<HashMap<ProgressToken, mpsc::Sender<()>>>>;

/// A handle to a remote request
///
Expand All @@ -314,40 +348,126 @@ pub struct RequestHandle<R: ServiceRole> {
pub peer: Peer<R>,
pub id: RequestId,
pub progress_token: ProgressToken,
progress_timeout_watchers: ProgressTimeoutWatchers,
progress_reset_rx: Option<mpsc::Receiver<()>>,
}

impl<R: ServiceRole> RequestHandle<R> {
pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";
pub async fn await_response(self) -> Result<R::PeerResp, ServiceError> {
if let Some(timeout) = self.options.timeout {
let timeout_result = tokio::time::timeout(timeout, async move {
self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
})
.await;
match timeout_result {
Ok(response) => response,
pub const REQUEST_MAX_TOTAL_TIMEOUT_REASON: &str = "maximum total timeout exceeded";

pub async fn await_response(mut self) -> Result<R::PeerResp, ServiceError> {
let timeout = self.options.timeout;
let max_total_timeout = self.options.max_total_timeout;
let reset_timeout_on_progress = self.options.reset_timeout_on_progress;

let has_progress_reset_rx = self.progress_reset_rx.is_some();
let progress_token = self.progress_token.clone();

let result = match (timeout, max_total_timeout, reset_timeout_on_progress) {
(Some(timeout), None, false) => match tokio::time::timeout(timeout, &mut self.rx).await
{
Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?,
Err(_) => {
let error = Err(ServiceError::Timeout { timeout });
// cancel this request
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id,
reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()),
},
method: crate::model::CancelledNotificationMethod,
extensions: Default::default(),
};
let _ = self.peer.send_notification(notification.into()).await;
self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON)
.await;
error
}
},
(None, None, _) => (&mut self.rx)
.await
.map_err(|_e| ServiceError::TransportClosed)?,
_ => {
self.await_response_with_progress_timeout(
timeout,
max_total_timeout,
reset_timeout_on_progress,
)
.await
}
};

Self::cleanup_progress_timeout_watcher(
&self.peer.progress_timeout_watchers,
&progress_token,
has_progress_reset_rx,
)
Comment thread
ContextVM-org marked this conversation as resolved.
.await;
result
}

async fn send_timeout_cancel_notification(&self, reason: &str) {
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id.clone(),
reason: Some(reason.to_owned()),
},
method: crate::model::CancelledNotificationMethod,
extensions: Default::default(),
};
let _ = self.peer.send_notification(notification.into()).await;
}

async fn await_response_with_progress_timeout(
&mut self,
timeout: Option<Duration>,
max_total_timeout: Option<Duration>,
reset_timeout_on_progress: bool,
) -> Result<R::PeerResp, ServiceError> {
let mut idle_sleep = timeout.map(tokio::time::sleep).map(Box::pin);
let mut max_total_sleep = max_total_timeout.map(tokio::time::sleep).map(Box::pin);

loop {
tokio::select! {
biased;

response = &mut self.rx => {
return response.map_err(|_e| ServiceError::TransportClosed)?;
}
_ = async {
if let Some(sleep) = idle_sleep.as_mut() {
sleep.as_mut().await;
}
}, if idle_sleep.is_some() => {
let timeout = timeout.expect("idle timeout exists when idle sleep exists");
self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await;
return Err(ServiceError::Timeout { timeout });
}
_ = async {
if let Some(sleep) = max_total_sleep.as_mut() {
sleep.as_mut().await;
}
}, if max_total_sleep.is_some() => {
let timeout = max_total_timeout.expect("max total timeout exists when max total sleep exists");
self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await;
return Err(ServiceError::Timeout { timeout });
}
progress = async {
match self.progress_reset_rx.as_mut() {
Some(rx) => rx.recv().await,
None => None,
}
}, if reset_timeout_on_progress && timeout.is_some() && self.progress_reset_rx.is_some() => {
if progress.is_some() {
if let (Some(timeout), Some(sleep)) = (timeout, idle_sleep.as_mut()) {
sleep.as_mut().reset(tokio::time::Instant::now() + timeout);
}
}
}
}
} else {
self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
}
}

/// Cancel this request
pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
Self::cleanup_progress_timeout_watcher(
&self.progress_timeout_watchers,
&self.progress_token,
self.progress_reset_rx.is_some(),
)
.await;
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id,
Expand All @@ -359,6 +479,19 @@ impl<R: ServiceRole> RequestHandle<R> {
self.peer.send_notification(notification.into()).await?;
Ok(())
}

async fn cleanup_progress_timeout_watcher(
progress_timeout_watchers: &ProgressTimeoutWatchers,
progress_token: &ProgressToken,
has_progress_reset_rx: bool,
) {
if has_progress_reset_rx {
progress_timeout_watchers
.write()
.await
.remove(progress_token);
}
}
}

#[derive(Debug)]
Expand All @@ -384,6 +517,7 @@ pub struct Peer<R: ServiceRole> {
tx: mpsc::Sender<PeerSinkMessage<R>>,
request_id_provider: Arc<dyn RequestIdProvider>,
progress_token_provider: Arc<dyn ProgressTokenProvider>,
progress_timeout_watchers: ProgressTimeoutWatchers,
info: Arc<std::sync::RwLock<Option<Arc<R::PeerInfo>>>>,
}

Expand All @@ -403,12 +537,33 @@ type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;
pub struct PeerRequestOptions {
pub timeout: Option<Duration>,
pub meta: Option<Meta>,
/// Reset the request timeout when a matching progress notification is received.
pub reset_timeout_on_progress: bool,
/// Maximum total time to wait for the request, regardless of progress notifications.
pub max_total_timeout: Option<Duration>,
}

impl PeerRequestOptions {
pub fn no_options() -> Self {
Self::default()
}

pub fn with_timeout(timeout: Duration) -> Self {
Self {
timeout: Some(timeout),
..Self::default()
}
}

pub fn reset_timeout_on_progress(mut self) -> Self {
self.reset_timeout_on_progress = true;
self
}

pub fn with_max_total_timeout(mut self, timeout: Duration) -> Self {
self.max_total_timeout = Some(timeout);
self
}
}

impl<R: ServiceRole> Peer<R> {
Expand All @@ -423,6 +578,7 @@ impl<R: ServiceRole> Peer<R> {
tx,
request_id_provider,
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
progress_timeout_watchers: Default::default(),
info: Arc::new(std::sync::RwLock::new(peer_info.map(Arc::new))),
},
rx,
Expand Down Expand Up @@ -468,22 +624,68 @@ impl<R: ServiceRole> Peer<R> {
request.get_meta_mut().extend(meta);
}
let (responder, receiver) = tokio::sync::oneshot::channel();
self.tx
let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() {
let (sender, receiver) = mpsc::channel(1);
self.progress_timeout_watchers
.write()
.await
.insert(progress_token.clone(), sender);
Some(receiver)
} else {
None
};
if self
.tx
.send(PeerSinkMessage::Request {
request,
id: id.clone(),
responder,
})
.await
.map_err(|_m| ServiceError::TransportClosed)?;
.is_err()
{
if progress_reset_rx.is_some() {
self.progress_timeout_watchers
.write()
.await
.remove(&progress_token);
}
return Err(ServiceError::TransportClosed);
}
Ok(RequestHandle {
id,
rx: receiver,
progress_token,
options,
peer: self.clone(),
progress_timeout_watchers: self.progress_timeout_watchers.clone(),
progress_reset_rx,
})
}

async fn notify_progress_timeout_watcher(&self, progress_token: &ProgressToken) {
let sender = self
.progress_timeout_watchers
.read()
.await
.get(progress_token)
.cloned();
if let Some(sender) = sender {
match sender.try_send(()) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::trace!(?progress_token, "progress timeout watcher channel is full");
}
Err(mpsc::error::TrySendError::Closed(_)) => {
self.progress_timeout_watchers
.write()
.await
.remove(progress_token);
}
}
}
}

/// Snapshot of the peer's handshake info.
pub fn peer_info(&self) -> Option<Arc<R::PeerInfo>> {
self.info.read().expect("peer info lock poisoned").clone()
Expand Down Expand Up @@ -700,6 +902,7 @@ pub fn serve_directly<R, S, T, E, A>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: IntoTransport<R, E, A>,
E: std::error::Error + Send + Sync + 'static,
Expand All @@ -716,6 +919,7 @@ pub fn serve_directly_with_ct<R, S, T, E, A>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: IntoTransport<R, E, A>,
E: std::error::Error + Send + Sync + 'static,
Expand Down Expand Up @@ -756,6 +960,7 @@ fn serve_inner<R, S, T>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: Transport<R> + 'static,
{
Expand Down Expand Up @@ -1002,6 +1207,9 @@ where
}
Err(notification) => notification,
};
if let Some(progress_token) = notification.progress_token() {
peer.notify_progress_timeout_watcher(progress_token).await;
}
{
let service = shared_service.clone();
let mut extensions = Extensions::new();
Expand Down
Loading
Loading