token bucket rate limiting with burst + synthetic notice on delay

parent d4b8aa731b · commit 496bb5484a

5 changed files with 108 additions and 30 deletions
```diff
@@ -29,7 +29,9 @@ pub struct MatrixTurn {
 /// Build a matrix_turn envelope for one room. If `include_history` is false,
 /// the `previously_seen` array is empty (shard already has that context from
-/// earlier turns in this session).
+/// earlier turns in this session). If `delay_notice_seconds` is `Some(n)`, a
+/// synthetic Notice event is prepended to `new_events` informing the shard
+/// that this turn was held by rate limiting for that many seconds.
 pub fn build_turn(
     source_room: &OwnedRoomId,
     room_name: &str,
@@ -37,6 +39,7 @@ pub fn build_turn(
     seen_idx: usize,
     read_markers: &HashMap<OwnedEventId, Vec<OwnedUserId>>,
     include_history: bool,
+    delay_notice_seconds: Option<u64>,
 ) -> MatrixTurn {
     let mut senders: Vec<&OwnedUserId> = timeline
         .iter()
@@ -56,7 +59,7 @@ pub fn build_turn(
     } else {
         Vec::new()
     };
-    let new_events: Vec<WireEvent> = timeline[seen..]
+    let mut new_events: Vec<WireEvent> = timeline[seen..]
         .iter()
         .map(|i| wire_event_from(i, read_markers))
         .collect();
@@ -66,6 +69,19 @@ pub fn build_turn(
         .map(|d| d.as_secs() as i64)
         .unwrap_or(0);
 
+    if let Some(secs) = delay_notice_seconds {
+        new_events.insert(
+            0,
+            WireEvent::Notice {
+                text: format!(
+                    "rate_limit: events were held for {secs}s before reaching you. context may be slightly stale; respond accordingly."
+                ),
+                ts: now,
+                ts_human: format!("{} UTC", format_ts(now)),
+            },
+        );
+    }
+
     MatrixTurn {
         kind: "matrix_turn",
         now,
```
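The prepend contract in isolation: `None` leaves `new_events` untouched, `Some(n)` puts a `Notice` at index 0 so the delay warning is the first thing the shard reads. A minimal sketch with a stand-in event type (the real `WireEvent` and `format_ts` live elsewhere in the crate):

```rust
enum WireEvent {
    Notice { text: String, ts: i64, ts_human: String },
    // Message and Reaction variants elided here
}

fn apply_delay_notice(new_events: &mut Vec<WireEvent>, delay_notice_seconds: Option<u64>, now: i64) {
    if let Some(secs) = delay_notice_seconds {
        // O(n) front insert, but new_events only holds events since the last turn.
        new_events.insert(0, WireEvent::Notice {
            text: format!("rate_limit: events were held for {secs}s before reaching you. context may be slightly stale; respond accordingly."),
            ts: now,
            ts_human: format!("{now} UTC"), // the real code formats via format_ts(now)
        });
    }
}
```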
```diff
@@ -59,8 +59,8 @@ pub async fn on_room_message(
 
     if !is_self {
         let mut state = state.lock().await;
-        if !state.pending_rooms.contains(&room_id) {
-            state.pending_rooms.push(room_id);
+        if !state.pending_rooms.iter().any(|(r, _)| r == &room_id) {
+            state.pending_rooms.push((room_id, std::time::Instant::now()));
         }
     }
 }
@@ -86,8 +86,8 @@ pub async fn on_reaction(
         "reaction"
     );
 
-    if !is_self && !state.pending_rooms.contains(&room_id) {
-        state.pending_rooms.push(room_id);
+    if !is_self && !state.pending_rooms.iter().any(|(r, _)| r == &room_id) {
+        state.pending_rooms.push((room_id, std::time::Instant::now()));
     }
 }
 
```
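The queue discipline both handlers now share: a room is enqueued at most once, and the `Instant` is captured only on first entry, so a burst of messages in the same room measures delay from the first unprocessed event rather than the latest. A sketch with `String` standing in for `OwnedRoomId`:

```rust
use std::time::Instant;

fn enqueue(pending_rooms: &mut Vec<(String, Instant)>, room_id: String) {
    // Dedup by room; an already-queued room keeps its original queued-at time.
    if !pending_rooms.iter().any(|(r, _)| r == &room_id) {
        pending_rooms.push((room_id, Instant::now()));
    }
}
```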
src/main.rs (84 changed lines)

```diff
@@ -62,6 +62,9 @@ async fn main() -> anyhow::Result<()> {
     let session_max_events = config
         .session_max_events
         .unwrap_or(types::DEFAULT_SESSION_MAX_EVENTS);
+    let rate_burst_capacity = config
+        .rate_burst_capacity
+        .unwrap_or(types::DEFAULT_RATE_BURST_CAPACITY);
 
     let (client, sync_token) = if session_file.exists() {
         session::restore_session(&session_file).await?
@@ -89,15 +92,20 @@ async fn main() -> anyhow::Result<()> {
         own_user_id,
         last_shown: std::collections::HashMap::new(),
         pending_rooms: Vec::new(),
-        rate_budget: rate_limit_per_min,
+        // Start with a full bucket so the first event after startup is fast
+        rate_budget: rate_burst_capacity as f64,
         rate_limit_per_min,
-        last_rate_reset: std::time::Instant::now(),
+        rate_burst_capacity,
+        last_rate_check: std::time::Instant::now(),
         model,
         max_history,
         session_idle_minutes,
         session_max_events,
     }));
 
+    // Notify dispatcher when new events arrive (instant wake-up)
+    let dispatch_notify = Arc::new(tokio::sync::Notify::new());
+
     // Start MCP socket listener for tool calls from the shard
     let socket_path = paths::state_dir().join("daemon.sock");
     let socket_client = client.clone();
@@ -110,11 +118,12 @@ async fn main() -> anyhow::Result<()> {
 
     let processor_state = state.clone();
     let processor_client = client.clone();
+    let processor_notify = dispatch_notify.clone();
     tokio::spawn(async move {
-        process_loop(processor_state, processor_client, socket_path).await;
+        process_loop(processor_state, processor_client, socket_path, processor_notify).await;
     });
 
-    sync(client, sync_token, &session_file, state).await
+    sync(client, sync_token, &session_file, state, dispatch_notify).await
 }
 
 async fn sync(
@@ -122,6 +131,7 @@ async fn sync(
     initial_sync_token: Option<String>,
     session_file: &std::path::Path,
     state: Arc<Mutex<DaemonState>>,
+    notify: Arc<tokio::sync::Notify>,
 ) -> anyhow::Result<()> {
     let has_token = initial_sync_token.is_some();
     if has_token {
@@ -141,7 +151,7 @@ async fn sync(
     // received while we were down trigger the queue. On first start we skip
     // this to avoid backlogging every historical message.
     if has_token {
-        register_event_handlers(&client, state.clone());
+        register_event_handlers(&client, state.clone(), notify.clone());
     }
 
     loop {
@@ -158,7 +168,7 @@ async fn sync(
     }
 
     if !has_token {
-        register_event_handlers(&client, state.clone());
+        register_event_handlers(&client, state.clone(), notify.clone());
     }
 
     tracing::info!("synced, listening for messages");
@@ -168,24 +178,34 @@ async fn sync(
     bail!("sync loop exited unexpectedly")
 }
 
-fn register_event_handlers(client: &Client, state: Arc<Mutex<DaemonState>>) {
+fn register_event_handlers(
+    client: &Client,
+    state: Arc<Mutex<DaemonState>>,
+    notify: Arc<tokio::sync::Notify>,
+) {
     let msg_state = state.clone();
+    let msg_notify = notify.clone();
     client.add_event_handler(
         move |event: matrix_sdk::ruma::events::room::message::OriginalSyncRoomMessageEvent,
               room: Room| {
             let state = msg_state.clone();
+            let notify = msg_notify.clone();
             async move {
                 handlers::on_room_message(event, room, state).await;
+                notify.notify_one();
             }
         },
    );
 
     let react_state = state.clone();
+    let react_notify = notify.clone();
     client.add_event_handler(
         move |event: matrix_sdk::ruma::events::reaction::OriginalSyncReactionEvent, room: Room| {
             let state = react_state.clone();
+            let notify = react_notify.clone();
             async move {
                 handlers::on_reaction(event, room, state).await;
+                notify.notify_one();
             }
         },
     );
@@ -195,7 +215,16 @@ fn register_event_handlers(client: &Client, state: Arc<Mutex<DaemonState>>) {
 
 /// The dispatcher loop: owns one long-running ShardSession across rooms,
 /// drains pending_rooms, runs turns, manages refresh.
-async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_path: PathBuf) {
+///
+/// Uses a token bucket on the input side: bucket fills at `rate_per_min`, caps
+/// at `rate_burst_capacity`. Events queue in `pending_rooms` until budget
+/// covers one. Output is never throttled.
+async fn process_loop(
+    state: Arc<Mutex<DaemonState>>,
+    client: Client,
+    socket_path: PathBuf,
+    notify: Arc<tokio::sync::Notify>,
+) {
     let mcp_config_path = match claude::write_mcp_config(&socket_path).await {
         Ok(p) => p,
         Err(e) => {
```
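The doc comment above describes a bucket that `process_loop` keeps inline on `DaemonState`. Extracted into a standalone type for clarity — a sketch only, and note one deliberate difference: here the `-1.0` consume happens inside `try_take`, whereas the real code defers it until the turn actually succeeds, in `process_room`:

```rust
use std::time::Instant;

struct TokenBucket {
    budget: f64,        // current tokens; starts full so startup is fast
    rate_per_min: f64,  // steady-state refill rate
    capacity: f64,      // burst cap; refill never exceeds this
    last_check: Instant,
}

impl TokenBucket {
    /// Refill by elapsed wall time, then take one whole token if available.
    fn try_take(&mut self) -> bool {
        let elapsed = self.last_check.elapsed().as_secs_f64();
        self.budget = (self.budget + elapsed * self.rate_per_min / 60.0).min(self.capacity);
        self.last_check = Instant::now();
        if self.budget < 1.0 {
            return false; // empty: caller leaves the room queued and retries later
        }
        self.budget -= 1.0;
        true
    }
}
```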
```diff
@@ -207,15 +236,21 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_path: PathBuf) {
     let mut session: Option<shard::ShardSession> = None;
 
     loop {
-        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+        // Wait for an event signal OR a tick (tick lets us reap idle session).
+        tokio::select! {
+            _ = notify.notified() => {}
+            _ = tokio::time::sleep(std::time::Duration::from_secs(2)) => {}
+        }
 
-        let (room_id, model, idle_minutes, max_events) = {
+        let (popped, model, idle_minutes, max_events) = {
             let mut s = state.lock().await;
-            if s.last_rate_reset.elapsed() >= std::time::Duration::from_secs(60) {
-                s.rate_budget = s.rate_limit_per_min;
-                s.last_rate_reset = std::time::Instant::now();
-            }
-            if s.rate_budget == 0 {
+            // Refill bucket based on elapsed time since last check.
+            let elapsed = s.last_rate_check.elapsed().as_secs_f64();
+            let new_tokens = elapsed * (s.rate_limit_per_min as f64) / 60.0;
+            s.rate_budget = (s.rate_budget + new_tokens).min(s.rate_burst_capacity as f64);
+            s.last_rate_check = std::time::Instant::now();
+            if s.rate_budget < 1.0 {
+                tracing::debug!(budget = s.rate_budget, "bucket empty, holding");
                 continue;
             }
             (
```
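The `select!` above replaces a fixed 1s poll with event-driven wake-up plus a slow tick. The pattern in isolation (tokio's `Notify` coalesces signals into at most one stored permit, so one wake-up can drain several queued rooms; the tick keeps idle-session reaping and bucket refill alive when no events arrive):

```rust
use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;

async fn dispatcher(notify: Arc<Notify>) {
    loop {
        tokio::select! {
            _ = notify.notified() => { /* a handler queued a room: wake immediately */ }
            _ = tokio::time::sleep(Duration::from_secs(2)) => { /* periodic housekeeping tick */ }
        }
        // rate-limit check and turn dispatch go here
    }
}

// Producer side (an event handler): wake the dispatcher without blocking.
fn on_event(notify: &Notify) {
    notify.notify_one();
}
```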
```diff
@@ -227,7 +262,7 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_path: PathBuf) {
         };
 
         // No work? Check if existing session has aged out and reap it.
-        let Some(room_id) = room_id else {
+        let Some((room_id, queued_at)) = popped else {
             if let Some(sess) = &mut session {
                 if sess
                     .should_refresh(
@@ -279,10 +314,21 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client, socket_path: PathBuf) {
             }
         }
 
+        // Compute delay since this room first entered the queue. If
+        // significant (>30s), surface it to the shard via a synthetic notice.
+        let delay = queued_at.elapsed();
+        let delay_notice = if delay.as_secs() >= 30 {
+            Some(delay.as_secs())
+        } else {
+            None
+        };
+
         // Process the room. If the turn fails, drop the session and let next
         // iteration respawn.
         let sess = session.as_mut().unwrap();
-        if let Err(e) = process_room(&state, &client, &room_id, &room, sess).await {
+        if let Err(e) =
+            process_room(&state, &client, &room_id, &room, sess, delay_notice).await
+        {
             tracing::error!(room = %room_id, "turn failed, dropping session: {e}");
             if let Some(s) = session.take() {
                 s.shutdown().await;
```
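The 30-second threshold is hard-coded: shorter holds are treated as normal scheduling jitter and produce no notice. The same rule as a tiny function, behavior matching the inline code above:

```rust
use std::time::Duration;

fn delay_notice(delay: Duration) -> Option<u64> {
    // Sub-30s holds are noise; don't bother the shard with them.
    (delay.as_secs() >= 30).then_some(delay.as_secs())
}
```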
```diff
@@ -297,6 +343,7 @@ async fn process_room(
     room_id: &OwnedRoomId,
     room: &Room,
     session: &mut shard::ShardSession,
+    delay_notice_seconds: Option<u64>,
 ) -> anyhow::Result<()> {
     // Snapshot last_shown for this room so we can mark seen vs new.
     let in_memory = {
@@ -404,6 +451,7 @@ async fn process_room(
         seen_idx,
         &read_markers,
         include_history,
+        delay_notice_seconds,
     );
     let turn_text = claude::turn_to_text(&turn);
 
@@ -422,7 +470,7 @@ async fn process_room(
 
     {
         let mut state = state.lock().await;
-        state.rate_budget = state.rate_budget.saturating_sub(1);
+        state.rate_budget = (state.rate_budget - 1.0).max(0.0);
         if let Some(eid) = new_last_event_id.clone() {
             state.last_shown.insert(room_id.clone(), eid);
         }
```
```diff
@@ -407,9 +407,10 @@ async fn fetch_event(
     // The "earlier handle" is the oldest event we got, used by the shard
     // to page further back via another fetch_event call.
     let earlier_handle = if context_before > 0 && before.len() > context_before as usize {
-        before.first().map(|e| match e {
-            WireEvent::Message { event_id, .. } => event_id.clone(),
-            WireEvent::Reaction { target_event_id, .. } => target_event_id.clone(),
+        before.first().and_then(|e| match e {
+            WireEvent::Message { event_id, .. } => Some(event_id.clone()),
+            WireEvent::Reaction { target_event_id, .. } => Some(target_event_id.clone()),
+            WireEvent::Notice { .. } => None,
         })
     } else {
         None
```
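`map` had to become `and_then` because the new `Notice` variant is synthetic: it has no Matrix event_id, so it cannot serve as a paging handle. A reduced sketch of the distinction (stand-in enum; the real `WireEvent` carries more fields):

```rust
enum WireEvent {
    Message { event_id: String },
    Reaction { target_event_id: String },
    Notice { text: String },
}

fn earlier_handle(first: Option<&WireEvent>) -> Option<String> {
    first.and_then(|e| match e {
        WireEvent::Message { event_id } => Some(event_id.clone()),
        WireEvent::Reaction { target_event_id } => Some(target_event_id.clone()),
        WireEvent::Notice { .. } => None, // synthetic: nothing in Matrix to page back from
    })
}
```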
src/types.rs (19 changed lines)

```diff
@@ -32,6 +32,13 @@ pub enum WireEvent {
         target_event_id_short: String,
         key: String,
     },
+    /// Synthetic event from the daemon (not a Matrix event). Currently used
+    /// to tell the shard "you were rate-limited; events held for X seconds."
+    Notice {
+        text: String,
+        ts: i64,
+        ts_human: String,
+    },
 }
 
 #[derive(Debug, Serialize)]
```
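How `Notice` looks on the wire depends on the enum's serde attributes, which this diff doesn't show. Assuming an internally tagged representation (a common choice for event streams — purely an assumption here), it would serialize roughly like this:

```rust
use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")] // assumed, not taken from the diff
enum WireEvent {
    Notice { text: String, ts: i64, ts_human: String },
}

// serde_json::to_string would then yield something like:
// {"type":"notice","text":"rate_limit: events were held for 42s ...","ts":1700000000,"ts_human":"2023-11-14 22:13:20 UTC"}
```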
```diff
@@ -56,6 +63,7 @@ pub struct FetchEventResult {
 pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
 pub const DEFAULT_MAX_HISTORY: usize = 20;
 pub const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1;
+pub const DEFAULT_RATE_BURST_CAPACITY: u32 = 3;
 pub const DEFAULT_SESSION_IDLE_MINUTES: u64 = 10;
 pub const DEFAULT_SESSION_MAX_EVENTS: u32 = 100;
 
```
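With these defaults the arithmetic is easy to sanity-check: the bucket refills one token per minute and holds at most three, so a freshly started daemon can run three turns back-to-back, after which each further turn waits up to a minute, and an empty bucket takes three minutes to rebuild the full burst:

```rust
let rate_per_min = 1.0_f64; // DEFAULT_RATE_LIMIT_PER_MIN
let capacity = 3.0_f64;     // DEFAULT_RATE_BURST_CAPACITY

let secs_per_token = 60.0 / rate_per_min;      // 60s per turn at steady state
let empty_to_full = capacity * secs_per_token; // 180s from empty to full burst
assert_eq!((secs_per_token, empty_to_full), (60.0, 180.0));
```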
```diff
@@ -65,6 +73,7 @@ pub struct Config {
     pub username: String,
     pub password: String,
     pub rate_limit_per_min: Option<u32>,
+    pub rate_burst_capacity: Option<u32>,
     pub model: Option<String>,
     pub max_history: Option<usize>,
     pub session_idle_minutes: Option<u64>,
@@ -132,10 +141,14 @@ pub struct DaemonState {
     /// Per-room: the latest event_id that's been "shown" to Claude. Events
     /// after this are "new" on the next invocation. Cleared on daemon restart.
     pub last_shown: HashMap<OwnedRoomId, OwnedEventId>,
-    pub pending_rooms: Vec<OwnedRoomId>,
-    pub rate_budget: u32,
+    /// Rooms with unprocessed events. The Instant is when the room first
+    /// entered the queue (or last became empty, then refilled). Used to
+    /// surface rate-limit delays to the shard via a synthetic notice event.
+    pub pending_rooms: Vec<(OwnedRoomId, std::time::Instant)>,
+    pub rate_budget: f64,
     pub rate_limit_per_min: u32,
-    pub last_rate_reset: std::time::Instant,
+    pub rate_burst_capacity: u32,
+    pub last_rate_check: std::time::Instant,
     pub model: String,
     pub max_history: usize,
     pub session_idle_minutes: u64,
```