From 496bb5484a79f41ff02aa99df9a9bc93eb6acdc0 Mon Sep 17 00:00:00 2001
From: Damocles
Date: Fri, 1 May 2026 13:29:52 +0200
Subject: [PATCH] token bucket rate limiting with burst + synthetic notice on
 delay

---
 src/claude.rs   | 20 ++++++++++--
 src/handlers.rs |  8 ++---
 src/main.rs     | 84 ++++++++++++++++++++++++++++++++++++++-----------
 src/socket.rs   |  7 +++--
 src/types.rs    | 19 +++++++++--
 5 files changed, 108 insertions(+), 30 deletions(-)

diff --git a/src/claude.rs b/src/claude.rs
index 271fddf..06ea397 100644
--- a/src/claude.rs
+++ b/src/claude.rs
@@ -29,7 +29,9 @@ pub struct MatrixTurn {
 
 /// Build a matrix_turn envelope for one room. If `include_history` is false,
 /// the `previously_seen` array is empty (shard already has that context from
-/// earlier turns in this session).
+/// earlier turns in this session). If `delay_notice_seconds` is `Some(n)`, a
+/// synthetic Notice event is prepended to `new_events` informing the shard
+/// that this turn was held by rate limiting for that many seconds.
 pub fn build_turn(
     source_room: &OwnedRoomId,
     room_name: &str,
@@ -37,6 +39,7 @@ pub fn build_turn(
     seen_idx: usize,
     read_markers: &HashMap<OwnedEventId, Vec<OwnedUserId>>,
     include_history: bool,
+    delay_notice_seconds: Option<u64>,
 ) -> MatrixTurn {
     let mut senders: Vec<&OwnedUserId> = timeline
         .iter()
@@ -56,7 +59,7 @@ pub fn build_turn(
     } else {
         Vec::new()
     };
-    let new_events: Vec<WireEvent> = timeline[seen..]
+    let mut new_events: Vec<WireEvent> = timeline[seen..]
         .iter()
         .map(|i| wire_event_from(i, read_markers))
         .collect();
@@ -66,6 +69,19 @@ pub fn build_turn(
         .map(|d| d.as_secs() as i64)
         .unwrap_or(0);
 
+    if let Some(secs) = delay_notice_seconds {
+        new_events.insert(
+            0,
+            WireEvent::Notice {
+                text: format!(
+                    "rate_limit: events were held for {secs}s before reaching you. context may be slightly stale; respond accordingly."
+                ),
+                ts: now,
+                ts_human: format!("{} UTC", format_ts(now)),
+            },
+        );
+    }
+
     MatrixTurn {
         kind: "matrix_turn",
         now,
diff --git a/src/handlers.rs b/src/handlers.rs
index 7af6e4c..91b3035 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -59,8 +59,8 @@ pub async fn on_room_message(
 
     if !is_self {
         let mut state = state.lock().await;
-        if !state.pending_rooms.contains(&room_id) {
-            state.pending_rooms.push(room_id);
+        if !state.pending_rooms.iter().any(|(r, _)| r == &room_id) {
+            state.pending_rooms.push((room_id, std::time::Instant::now()));
         }
     }
 }
@@ -86,8 +86,8 @@ pub async fn on_reaction(
         "reaction"
     );
 
-    if !is_self && !state.pending_rooms.contains(&room_id) {
-        state.pending_rooms.push(room_id);
+    if !is_self && !state.pending_rooms.iter().any(|(r, _)| r == &room_id) {
+        state.pending_rooms.push((room_id, std::time::Instant::now()));
     }
 }
 
diff --git a/src/main.rs b/src/main.rs
index 1b45bb2..1ce41ae 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -62,6 +62,9 @@ async fn main() -> anyhow::Result<()> {
     let session_max_events = config
         .session_max_events
         .unwrap_or(types::DEFAULT_SESSION_MAX_EVENTS);
+    let rate_burst_capacity = config
+        .rate_burst_capacity
+        .unwrap_or(types::DEFAULT_RATE_BURST_CAPACITY);
 
     let (client, sync_token) = if session_file.exists() {
         session::restore_session(&session_file).await?
@@ -89,15 +92,20 @@ async fn main() -> anyhow::Result<()> {
         own_user_id,
         last_shown: std::collections::HashMap::new(),
         pending_rooms: Vec::new(),
-        rate_budget: rate_limit_per_min,
+        // Start with a full bucket so the first event after startup is fast
+        rate_budget: rate_burst_capacity as f64,
         rate_limit_per_min,
-        last_rate_reset: std::time::Instant::now(),
+        rate_burst_capacity,
+        last_rate_check: std::time::Instant::now(),
         model,
         max_history,
         session_idle_minutes,
         session_max_events,
     }));
 
+    // Notify dispatcher when new events arrive (instant wake-up)
+    let dispatch_notify = Arc::new(tokio::sync::Notify::new());
+
     // Start MCP socket listener for tool calls from the shard
     let socket_path = paths::state_dir().join("daemon.sock");
     let socket_client = client.clone();
@@ -110,11 +118,12 @@ async fn main() -> anyhow::Result<()> {
     let processor_state = state.clone();
     let processor_client = client.clone();
+    let processor_notify = dispatch_notify.clone();
     tokio::spawn(async move {
-        process_loop(processor_state, processor_client, socket_path).await;
+        process_loop(processor_state, processor_client, socket_path, processor_notify).await;
     });
 
-    sync(client, sync_token, &session_file, state).await
+    sync(client, sync_token, &session_file, state, dispatch_notify).await
 }
 
 async fn sync(
     client: Client,
@@ -122,6 +131,7 @@ async fn sync(
     initial_sync_token: Option<String>,
     session_file: &std::path::Path,
     state: Arc<Mutex<DaemonState>>,
+    notify: Arc<tokio::sync::Notify>,
 ) -> anyhow::Result<()> {
     let has_token = initial_sync_token.is_some();
     if has_token {
@@ -141,7 +151,7 @@ async fn sync(
     // received while we were down trigger the queue. On first start we skip
     // this to avoid backlogging every historical message.
     if has_token {
-        register_event_handlers(&client, state.clone());
+        register_event_handlers(&client, state.clone(), notify.clone());
     }
 
     loop {
@@ -158,7 +168,7 @@ async fn sync(
     }
 
     if !has_token {
-        register_event_handlers(&client, state.clone());
+        register_event_handlers(&client, state.clone(), notify.clone());
     }
 
     tracing::info!("synced, listening for messages");
@@ -168,24 +178,34 @@ async fn sync(
     bail!("sync loop exited unexpectedly")
 }
 
-fn register_event_handlers(client: &Client, state: Arc<Mutex<DaemonState>>) {
+fn register_event_handlers(
+    client: &Client,
+    state: Arc<Mutex<DaemonState>>,
+    notify: Arc<tokio::sync::Notify>,
+) {
     let msg_state = state.clone();
+    let msg_notify = notify.clone();
     client.add_event_handler(
         move |event: matrix_sdk::ruma::events::room::message::OriginalSyncRoomMessageEvent,
               room: Room| {
             let state = msg_state.clone();
+            let notify = msg_notify.clone();
             async move {
                 handlers::on_room_message(event, room, state).await;
+                notify.notify_one();
             }
         },
     );
     let react_state = state.clone();
+    let react_notify = notify.clone();
     client.add_event_handler(
         move |event: matrix_sdk::ruma::events::reaction::OriginalSyncReactionEvent,
               room: Room| {
             let state = react_state.clone();
+            let notify = react_notify.clone();
             async move {
                 handlers::on_reaction(event, room, state).await;
+                notify.notify_one();
             }
         },
     );
@@ -195,7 +215,16 @@ fn register_event_handlers(client: &Client, state: Arc<Mutex<DaemonState>>) {
 
 /// The dispatcher loop: owns one long-running ShardSession across rooms,
 /// drains pending_rooms, runs turns, manages refresh.
-async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_path: PathBuf) {
+///
+/// Uses a token bucket on the input side: the bucket fills at
+/// `rate_limit_per_min` and caps at `rate_burst_capacity`. Events queue in
+/// `pending_rooms` until the budget covers one. Output is never throttled.
+async fn process_loop(
+    state: Arc<Mutex<DaemonState>>,
+    client: Client,
+    socket_path: PathBuf,
+    notify: Arc<tokio::sync::Notify>,
+) {
     let mcp_config_path = match claude::write_mcp_config(&socket_path).await {
         Ok(p) => p,
         Err(e) => {
@@ -207,15 +236,21 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_pat
     let mut session: Option<shard::ShardSession> = None;
 
     loop {
-        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+        // Wait for an event signal OR a tick (tick lets us reap idle session).
+        tokio::select! {
+            _ = notify.notified() => {}
+            _ = tokio::time::sleep(std::time::Duration::from_secs(2)) => {}
+        }
 
-        let (room_id, model, idle_minutes, max_events) = {
+        let (popped, model, idle_minutes, max_events) = {
             let mut s = state.lock().await;
-            if s.last_rate_reset.elapsed() >= std::time::Duration::from_secs(60) {
-                s.rate_budget = s.rate_limit_per_min;
-                s.last_rate_reset = std::time::Instant::now();
-            }
-            if s.rate_budget == 0 {
+            // Refill bucket based on elapsed time since last check.
+            let elapsed = s.last_rate_check.elapsed().as_secs_f64();
+            let new_tokens = elapsed * (s.rate_limit_per_min as f64) / 60.0;
+            s.rate_budget = (s.rate_budget + new_tokens).min(s.rate_burst_capacity as f64);
+            s.last_rate_check = std::time::Instant::now();
+            if s.rate_budget < 1.0 {
+                tracing::debug!(budget = s.rate_budget, "bucket empty, holding");
                 continue;
             }
             (
@@ -227,7 +262,7 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_pat
         };
 
         // No work? Check if existing session has aged out and reap it.
-        let Some(room_id) = room_id else {
+        let Some((room_id, queued_at)) = popped else {
            if let Some(sess) = &mut session {
                if sess
                    .should_refresh(
@@ -279,10 +314,21 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_pat
            }
        }
 
+        // Compute delay since this room first entered the queue. If
+        // significant (>30s), surface it to the shard via a synthetic notice.
+        let delay = queued_at.elapsed();
+        let delay_notice = if delay.as_secs() >= 30 {
+            Some(delay.as_secs())
+        } else {
+            None
+        };
+
        // Process the room. If the turn fails, drop the session and let next
        // iteration respawn.
        let sess = session.as_mut().unwrap();
-        if let Err(e) = process_room(&state, &client, &room_id, &room, sess).await {
+        if let Err(e) =
+            process_room(&state, &client, &room_id, &room, sess, delay_notice).await
+        {
            tracing::error!(room = %room_id, "turn failed, dropping session: {e}");
            if let Some(s) = session.take() {
                s.shutdown().await;
@@ -297,6 +343,7 @@ async fn process_room(
     room_id: &OwnedRoomId,
     room: &Room,
     session: &mut shard::ShardSession,
+    delay_notice_seconds: Option<u64>,
 ) -> anyhow::Result<()> {
     // Snapshot last_shown for this room so we can mark seen vs new.
     let in_memory = {
@@ -404,6 +451,7 @@ async fn process_room(
         seen_idx,
         &read_markers,
         include_history,
+        delay_notice_seconds,
     );
 
     let turn_text = claude::turn_to_text(&turn);
@@ -422,7 +470,7 @@ async fn process_room(
     {
         let mut state = state.lock().await;
-        state.rate_budget = state.rate_budget.saturating_sub(1);
+        state.rate_budget = (state.rate_budget - 1.0).max(0.0);
         if let Some(eid) = new_last_event_id.clone() {
             state.last_shown.insert(room_id.clone(), eid);
         }
 
diff --git a/src/socket.rs b/src/socket.rs
index b1ac783..32efaae 100644
--- a/src/socket.rs
+++ b/src/socket.rs
@@ -407,9 +407,10 @@ async fn fetch_event(
     // The "earlier handle" is the oldest event we got, used by the shard
     // to page further back via another fetch_event call.
     let earlier_handle = if context_before > 0 && before.len() > context_before as usize {
-        before.first().map(|e| match e {
-            WireEvent::Message { event_id, .. } => event_id.clone(),
-            WireEvent::Reaction { target_event_id, .. } => target_event_id.clone(),
+        before.first().and_then(|e| match e {
+            WireEvent::Message { event_id, .. } => Some(event_id.clone()),
+            WireEvent::Reaction { target_event_id, .. } => Some(target_event_id.clone()),
+            WireEvent::Notice { .. } => None,
         })
     } else {
         None
diff --git a/src/types.rs b/src/types.rs
index 47a550d..af75430 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -32,6 +32,13 @@ pub enum WireEvent {
         target_event_id_short: String,
         key: String,
     },
+    /// Synthetic event from the daemon (not a Matrix event). Currently used
+    /// to tell the shard "you were rate-limited; events held for X seconds."
+    Notice {
+        text: String,
+        ts: i64,
+        ts_human: String,
+    },
 }
 
 #[derive(Debug, Serialize)]
@@ -56,6 +63,7 @@ pub struct FetchEventResult {
 
 pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
 pub const DEFAULT_MAX_HISTORY: usize = 20;
 pub const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1;
+pub const DEFAULT_RATE_BURST_CAPACITY: u32 = 3;
 pub const DEFAULT_SESSION_IDLE_MINUTES: u64 = 10;
 pub const DEFAULT_SESSION_MAX_EVENTS: u32 = 100;
@@ -65,6 +73,7 @@ pub struct Config {
     pub username: String,
     pub password: String,
     pub rate_limit_per_min: Option<u32>,
+    pub rate_burst_capacity: Option<u32>,
     pub model: Option<String>,
     pub max_history: Option<usize>,
     pub session_idle_minutes: Option<u64>,
@@ -132,10 +141,14 @@ pub struct DaemonState {
     /// Per-room: the latest event_id that's been "shown" to Claude. Events
     /// after this are "new" on the next invocation. Cleared on daemon restart.
     pub last_shown: HashMap<OwnedRoomId, OwnedEventId>,
-    pub pending_rooms: Vec<OwnedRoomId>,
-    pub rate_budget: u32,
+    /// Rooms with unprocessed events. The Instant is when the room first
+    /// entered the queue (or re-entered it after being drained). Used to
+    /// surface rate-limit delays to the shard via a synthetic notice event.
+    pub pending_rooms: Vec<(OwnedRoomId, std::time::Instant)>,
+    pub rate_budget: f64,
     pub rate_limit_per_min: u32,
-    pub last_rate_reset: std::time::Instant,
+    pub rate_burst_capacity: u32,
+    pub last_rate_check: std::time::Instant,
     pub model: String,
     pub max_history: usize,
     pub session_idle_minutes: u64,
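
Note on the refill arithmetic this patch introduces: each dispatcher iteration adds
elapsed_seconds * rate_limit_per_min / 60 tokens to the budget, clamps it to
rate_burst_capacity, and a processed room spends one whole token. Below is a minimal
standalone sketch of that policy for reference; the TokenBucket type and its method
names are illustrative only and are not part of this patch or the daemon's code.

use std::time::Instant;

/// Illustrative token bucket mirroring the refill/consume logic in process_loop.
struct TokenBucket {
    tokens: f64,         // current budget (rate_budget)
    rate_per_min: f64,   // refill rate (rate_limit_per_min)
    burst: f64,          // cap (rate_burst_capacity)
    last_check: Instant, // last refill time (last_rate_check)
}

impl TokenBucket {
    fn new(rate_per_min: u32, burst: u32) -> Self {
        Self {
            // Start full, as the daemon does at startup.
            tokens: burst as f64,
            rate_per_min: rate_per_min as f64,
            burst: burst as f64,
            last_check: Instant::now(),
        }
    }

    /// Refill from elapsed time, then try to spend one token for one turn.
    fn try_consume(&mut self) -> bool {
        let elapsed = self.last_check.elapsed().as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.rate_per_min / 60.0).min(self.burst);
        self.last_check = Instant::now();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

fn main() {
    let mut bucket = TokenBucket::new(1, 3);
    // With rate 1/min and burst 3: three immediate turns pass, the fourth is held.
    for i in 0..4 {
        println!("turn {i}: allowed = {}", bucket.try_consume());
    }
}

The patch splits the same arithmetic into a refill-and-check at the top of the loop and
a spend after a successful turn, but the behaviour is equivalent.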