diff --git a/src/main.rs b/src/main.rs index 23fd631..decac78 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,9 +35,11 @@ struct Config { password: String, rate_limit_per_min: Option, model: Option, + max_history: Option, } const DEFAULT_MODEL: &str = "claude-sonnet-4-6"; +const DEFAULT_MAX_HISTORY: usize = 20; #[derive(Debug, Serialize, Deserialize)] struct PersistedSession { @@ -67,9 +69,9 @@ struct DaemonState { rate_limit_per_min: u32, last_rate_reset: std::time::Instant, model: String, + max_history: usize, } -const MAX_HISTORY: usize = 20; const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1; #[tokio::main] @@ -97,6 +99,7 @@ async fn main() -> anyhow::Result<()> { .model .clone() .unwrap_or_else(|| DEFAULT_MODEL.to_owned()); + let max_history = config.max_history.unwrap_or(DEFAULT_MAX_HISTORY); let (client, sync_token) = if session_file.exists() { restore_session(&session_file).await? @@ -126,6 +129,7 @@ async fn main() -> anyhow::Result<()> { rate_limit_per_min, last_rate_reset: std::time::Instant::now(), model, + max_history, })); let processor_state = state.clone(); @@ -480,8 +484,17 @@ async fn process_loop(state: Arc>, client: Client) { .await .map_or_else(|_| room_id.to_string(), |n| n.to_string()); + let (own_user, model, max_history) = { + let state = state.lock().await; + ( + state.own_user_id.clone(), + state.model.clone(), + state.max_history, + ) + }; + // Load recent history from matrix-sdk's persistent event cache - let history = match load_recent_messages(&room, MAX_HISTORY).await { + let mut history = match load_recent_messages(&room, max_history).await { Ok(h) => h, Err(e) => { tracing::error!(room = %room_id, "failed to load history: {e}"); @@ -489,14 +502,37 @@ async fn process_loop(state: Arc>, client: Client) { } }; - let (own_user, model) = { - let state = state.lock().await; - (state.own_user_id.clone(), state.model.clone()) - }; + // For any new messages that reply to events outside the window, fetch + // the replied-to event from cache and prepend it as extra context. + let seen_idx_initial = prev_last_shown + .as_ref() + .and_then(|id| history.iter().position(|(eid, _, _, _, _)| eid == id)) + .map_or(0, |pos| pos + 1); + let in_window: std::collections::HashSet = history + .iter() + .map(|(eid, _, _, _, _)| eid.clone()) + .collect(); + let mut reply_targets: Vec = Vec::new(); + for (_, _, _, _, in_reply_to) in history.iter().skip(seen_idx_initial) { + if let Some(target) = in_reply_to { + if !in_window.contains(target) && !reply_targets.contains(target) { + reply_targets.push(target.clone()); + } + } + } + if !reply_targets.is_empty() { + if let Ok((cache, _h)) = room.event_cache().await { + for target in &reply_targets { + if let Some(found) = fetch_message(&cache, target).await { + history.insert(0, found); + } + } + } + } let chat_msgs: Vec = history .iter() - .map(|(_, sender, body, ts)| ChatMessage { + .map(|(_, sender, body, ts, _)| ChatMessage { sender: sender.clone(), body: body.clone(), is_self: sender == &own_user, @@ -507,10 +543,10 @@ async fn process_loop(state: Arc>, client: Client) { // Determine seen split: everything before (and including) prev_last_shown is "seen" let seen_idx = prev_last_shown .as_ref() - .and_then(|id| history.iter().position(|(eid, _, _, _)| eid == id)) + .and_then(|id| history.iter().position(|(eid, _, _, _, _)| eid == id)) .map_or(0, |pos| pos + 1); - let new_last_event_id = history.last().map(|(eid, _, _, _)| eid.clone()); + let new_last_event_id = history.last().map(|(eid, _, _, _, _)| eid.clone()); match invoke_claude(&room_id, &room_name, &chat_msgs, seen_idx, &model).await { Ok(Some(response)) => { @@ -568,16 +604,17 @@ async fn send_read_receipt(room: &Room, event_id: Option) { /// Load the last N text messages from the room's persistent event cache. /// Returns oldest-first list of (event_id, sender, body, ts_secs). +/// Returns oldest-first list of (event_id, sender, body, ts_secs, in_reply_to). async fn load_recent_messages( room: &Room, limit: usize, -) -> anyhow::Result> { +) -> anyhow::Result)>> { use matrix_sdk::ruma::events::AnySyncTimelineEvent; let (cache, _handles) = room.event_cache().await?; let events = cache.events().await; - let mut out: Vec<(OwnedEventId, OwnedUserId, String, i64)> = Vec::new(); + let mut out: Vec<(OwnedEventId, OwnedUserId, String, i64, Option)> = Vec::new(); for ev in events.iter().rev() { if out.len() >= limit { break; @@ -594,11 +631,18 @@ async fn load_recent_messages( if let MessageType::Text(text) = &orig.content.msgtype { let ts_ms: u64 = orig.origin_server_ts.0.into(); let ts_secs: i64 = i64::try_from(ts_ms).unwrap_or(0) / 1000; + let in_reply_to = match &orig.content.relates_to { + Some(matrix_sdk::ruma::events::room::message::Relation::Reply { + in_reply_to, + }) => Some(in_reply_to.event_id.clone()), + _ => None, + }; out.push(( orig.event_id.clone(), orig.sender.clone(), text.body.clone(), ts_secs, + in_reply_to, )); } } @@ -609,6 +653,38 @@ async fn load_recent_messages( Ok(out) } +/// Fetch a single text message by event_id from the room's event cache. +async fn fetch_message( + cache: &matrix_sdk::event_cache::RoomEventCache, + event_id: &matrix_sdk::ruma::EventId, +) -> Option<(OwnedEventId, OwnedUserId, String, i64, Option)> { + use matrix_sdk::ruma::events::AnySyncTimelineEvent; + + let ev = cache.find_event(event_id).await?; + let deserialized = ev.raw().deserialize().ok()?; + let AnySyncTimelineEvent::MessageLike(msg) = deserialized else { + return None; + }; + let matrix_sdk::ruma::events::AnySyncMessageLikeEvent::RoomMessage( + matrix_sdk::ruma::events::SyncMessageLikeEvent::Original(orig), + ) = msg + else { + return None; + }; + let MessageType::Text(text) = &orig.content.msgtype else { + return None; + }; + let ts_ms: u64 = orig.origin_server_ts.0.into(); + let ts_secs: i64 = i64::try_from(ts_ms).unwrap_or(0) / 1000; + Some(( + orig.event_id.clone(), + orig.sender.clone(), + text.body.clone(), + ts_secs, + None, + )) +} + struct ClaudeResponse { room: OwnedRoomId, body: String,