diff --git a/src/claude.rs b/src/claude.rs index 6e09cc8..271fddf 100644 --- a/src/claude.rs +++ b/src/claude.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::path::Path; -use anyhow::{Context, bail}; +use anyhow::Context; use matrix_sdk::ruma::{OwnedEventId, OwnedRoomId, OwnedUserId}; use serde::Serialize; @@ -9,98 +9,8 @@ use crate::paths; use crate::timeline::format_ts; use crate::types::{TimelineItem, WireEvent}; -/// Invoke claude with MCP tools. The shard receives a JSON `matrix_turn` -/// describing the room and new events, and calls MCP tools (which carry an -/// explicit room_id) for any actions. Claude's stdout is logged as thought. -pub async fn invoke_claude( - source_room: &OwnedRoomId, - room_name: &str, - timeline: &[TimelineItem], - seen_idx: usize, - model: &str, - read_markers: &HashMap>, - socket_path: &Path, -) -> anyhow::Result<()> { - let identity_dir = paths::identity_dir(); - let identity_str = identity_dir.to_string_lossy(); - - let turn = build_turn(source_room, room_name, timeline, seen_idx, read_markers); - let prompt = format!( - "{TURN_PREAMBLE}\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(&turn).unwrap() - ); - - let mcp_config = build_mcp_config(socket_path)?; - let mcp_config_path = paths::state_dir().join("mcp.json"); - tokio::fs::write(&mcp_config_path, &mcp_config).await?; - - let new_msg_count = timeline[seen_idx..] - .iter() - .filter(|t| matches!(t, TimelineItem::Message { .. })) - .count(); - let new_react_count = timeline.len().saturating_sub(seen_idx) - new_msg_count; - tracing::info!( - "invoking claude: {} new ({} msg + {} react), {} seen", - timeline.len().saturating_sub(seen_idx), - new_msg_count, - new_react_count, - seen_idx - ); - tracing::trace!("full prompt:\n{prompt}"); - - use tokio::process::Command; - let mcp_config_str = mcp_config_path.to_string_lossy(); - let mut cmd = Command::new("claude"); - cmd.args([ - "--print", - "--model", - model, - "--add-dir", - &identity_str, - "--allowedTools", - "Read,Edit,Write,Glob,Grep,mcp__matrix__send_message,mcp__matrix__send_dm,mcp__matrix__send_reaction,mcp__matrix__send_reply,mcp__matrix__list_rooms,mcp__matrix__list_room_members,mcp__matrix__get_room_history,mcp__matrix__fetch_event", - "--mcp-config", - &mcp_config_str, - "-p", - &prompt, - ]); - cmd.current_dir(&identity_dir); - cmd.stdin(std::process::Stdio::null()); - let output = cmd.output().await.context("failed to run claude")?; - - let stderr = String::from_utf8_lossy(&output.stderr); - let stdout = String::from_utf8_lossy(&output.stdout); - - if !output.status.success() { - bail!( - "claude exited with {}:\nstdout: {}\nstderr: {}", - output.status, - stdout, - stderr - ); - } - - if !stderr.is_empty() { - tracing::warn!("claude stderr: {stderr}"); - } - - let text = stdout.trim(); - if !text.is_empty() { - tracing::info!( - "claude thought: {}", - text.chars().take(200).collect::() - ); - tracing::trace!("full claude output: {text}"); - } - - Ok(()) -} - -const TURN_PREAMBLE: &str = "New matrix events for you. JSON envelope follows. \ -The room_id and other fields are explicit - use them in your tool calls."; - #[derive(Debug, Serialize)] -struct MatrixTurn { +pub struct MatrixTurn { #[serde(rename = "type")] kind: &'static str, /// Current wall clock at turn-build time (unix seconds). @@ -111,16 +21,22 @@ struct MatrixTurn { room_name: String, room_notes_path: String, people_in_room: Vec, + /// Older context. Empty when the shard already saw this room within the + /// current session (state lives in the conversation). previously_seen: Vec, new_events: Vec, } -fn build_turn( +/// Build a matrix_turn envelope for one room. If `include_history` is false, +/// the `previously_seen` array is empty (shard already has that context from +/// earlier turns in this session). +pub fn build_turn( source_room: &OwnedRoomId, room_name: &str, timeline: &[TimelineItem], seen_idx: usize, read_markers: &HashMap>, + include_history: bool, ) -> MatrixTurn { let mut senders: Vec<&OwnedUserId> = timeline .iter() @@ -132,10 +48,14 @@ fn build_turn( let people_in_room: Vec = senders.iter().map(|s| s.to_string()).collect(); let seen = seen_idx.min(timeline.len()); - let previously_seen: Vec = timeline[..seen] - .iter() - .map(|i| wire_event_from(i, read_markers)) - .collect(); + let previously_seen: Vec = if include_history { + timeline[..seen] + .iter() + .map(|i| wire_event_from(i, read_markers)) + .collect() + } else { + Vec::new() + }; let new_events: Vec = timeline[seen..] .iter() .map(|i| wire_event_from(i, read_markers)) @@ -159,6 +79,15 @@ fn build_turn( } } +/// Wrap a turn for stream-json delivery: the body is just the JSON envelope +/// inside a ```json code fence so claude recognizes it as structured data. +pub fn turn_to_text(turn: &MatrixTurn) -> String { + format!( + "```json\n{}\n```", + serde_json::to_string_pretty(turn).unwrap() + ) +} + pub fn wire_event_from( item: &TimelineItem, read_markers: &HashMap>, @@ -232,8 +161,8 @@ struct McpServer { env: std::collections::BTreeMap, } -/// Build the MCP config JSON that tells claude how to launch damocles-mcp. -fn build_mcp_config(socket_path: &Path) -> anyhow::Result { +/// Write the MCP config JSON to state/mcp.json and return the path. +pub async fn write_mcp_config(socket_path: &Path) -> anyhow::Result { let mcp_bin = std::env::current_exe()? .parent() .context("no parent dir for current exe")? @@ -256,5 +185,10 @@ fn build_mcp_config(socket_path: &Path) -> anyhow::Result { ); let config = McpConfig { mcp_servers }; - serde_json::to_string_pretty(&config).context("serialize mcp config") + let json = serde_json::to_string_pretty(&config).context("serialize mcp config")?; + let path = paths::state_dir().join("mcp.json"); + tokio::fs::write(&path, &json).await?; + Ok(path) } + +pub const ALLOWED_TOOLS: &str = "Read,Edit,Write,Glob,Grep,mcp__matrix__send_message,mcp__matrix__send_dm,mcp__matrix__send_reaction,mcp__matrix__send_reply,mcp__matrix__list_rooms,mcp__matrix__list_room_members,mcp__matrix__get_room_history,mcp__matrix__fetch_event"; diff --git a/src/main.rs b/src/main.rs index 5f80a1a..1b45bb2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod handlers; mod paths; mod protocol; mod session; +mod shard; mod socket; mod timeline; mod types; @@ -55,6 +56,12 @@ async fn main() -> anyhow::Result<()> { .clone() .unwrap_or_else(|| DEFAULT_MODEL.to_owned()); let max_history = config.max_history.unwrap_or(DEFAULT_MAX_HISTORY); + let session_idle_minutes = config + .session_idle_minutes + .unwrap_or(types::DEFAULT_SESSION_IDLE_MINUTES); + let session_max_events = config + .session_max_events + .unwrap_or(types::DEFAULT_SESSION_MAX_EVENTS); let (client, sync_token) = if session_file.exists() { session::restore_session(&session_file).await? @@ -87,6 +94,8 @@ async fn main() -> anyhow::Result<()> { last_rate_reset: std::time::Instant::now(), model, max_history, + session_idle_minutes, + session_max_events, })); // Start MCP socket listener for tool calls from the shard @@ -102,7 +111,7 @@ async fn main() -> anyhow::Result<()> { let processor_state = state.clone(); let processor_client = client.clone(); tokio::spawn(async move { - process_loop(processor_state, processor_client, &socket_path).await; + process_loop(processor_state, processor_client, socket_path).await; }); sync(client, sync_token, &session_file, state).await @@ -184,26 +193,54 @@ fn register_event_handlers(client: &Client, state: Arc>) { client.add_event_handler(handlers::on_stripped_state_member); } -async fn process_loop(state: Arc>, client: Client, socket_path: &PathBuf) { +/// The dispatcher loop: owns one long-running ShardSession across rooms, +/// drains pending_rooms, runs turns, manages refresh. +async fn process_loop(state: Arc>, client: Client, socket_path: PathBuf) { + let mcp_config_path = match claude::write_mcp_config(&socket_path).await { + Ok(p) => p, + Err(e) => { + tracing::error!("failed to write mcp config: {e}"); + return; + } + }; + + let mut session: Option = None; + loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; - let room_id = { - let mut state = state.lock().await; - - if state.last_rate_reset.elapsed() >= std::time::Duration::from_secs(60) { - state.rate_budget = state.rate_limit_per_min; - state.last_rate_reset = std::time::Instant::now(); + let (room_id, model, idle_minutes, max_events) = { + let mut s = state.lock().await; + if s.last_rate_reset.elapsed() >= std::time::Duration::from_secs(60) { + s.rate_budget = s.rate_limit_per_min; + s.last_rate_reset = std::time::Instant::now(); } - - if state.rate_budget == 0 { + if s.rate_budget == 0 { continue; } - - state.pending_rooms.pop() + ( + s.pending_rooms.pop(), + s.model.clone(), + s.session_idle_minutes, + s.session_max_events, + ) }; + // No work? Check if existing session has aged out and reap it. let Some(room_id) = room_id else { + if let Some(sess) = &mut session { + if sess + .should_refresh( + std::time::Duration::from_secs(idle_minutes * 60), + max_events, + ) + .is_some() + { + if let Some(s) = session.take() { + s.shutdown().await; + } + } + } continue; }; @@ -212,8 +249,44 @@ async fn process_loop(state: Arc>, client: Client, socket_pat continue; }; - if let Err(e) = process_room(&state, &client, &room_id, &room, socket_path).await { - tracing::error!(room = %room_id, "failed to process room: {e}"); + // Refresh check before we use the session + if let Some(sess) = &mut session { + if let Some(reason) = sess.should_refresh( + std::time::Duration::from_secs(idle_minutes * 60), + max_events, + ) { + tracing::info!("shard refresh: {reason}"); + if let Some(s) = session.take() { + s.shutdown().await; + } + } + } + + // Spawn fresh if needed + if session.is_none() { + match shard::ShardSession::spawn(shard::SpawnConfig { + model: &model, + mcp_config_path: &mcp_config_path, + allowed_tools: claude::ALLOWED_TOOLS, + }) + .await + { + Ok(s) => session = Some(s), + Err(e) => { + tracing::error!("failed to spawn shard: {e}"); + continue; + } + } + } + + // Process the room. If the turn fails, drop the session and let next + // iteration respawn. + let sess = session.as_mut().unwrap(); + if let Err(e) = process_room(&state, &client, &room_id, &room, sess).await { + tracing::error!(room = %room_id, "turn failed, dropping session: {e}"); + if let Some(s) = session.take() { + s.shutdown().await; + } } } } @@ -223,7 +296,7 @@ async fn process_room( client: &Client, room_id: &OwnedRoomId, room: &Room, - socket_path: &PathBuf, + session: &mut shard::ShardSession, ) -> anyhow::Result<()> { // Snapshot last_shown for this room so we can mark seen vs new. let in_memory = { @@ -260,13 +333,9 @@ async fn process_room( .await .map_or_else(|_| room_id.to_string(), |n| n.to_string()); - let (own_user, model, max_history) = { + let (own_user, max_history) = { let state = state.lock().await; - ( - state.own_user_id.clone(), - state.model.clone(), - state.max_history, - ) + (state.own_user_id.clone(), state.max_history) }; let mut tl = timeline::load_timeline(room, max_history, &own_user).await?; @@ -326,30 +395,31 @@ async fn process_room( let read_markers = timeline::compute_read_markers(room, &tl, &own_user).await; - if let Err(e) = room.typing_notice(true).await { - tracing::debug!(room = %room_id, "failed to send typing start: {e}"); - } - - let invoke_result = claude::invoke_claude( + // First time this room appears in this shard session? Include history. + let include_history = !session.rooms_seen.contains(room_id); + let turn = claude::build_turn( room_id, &room_name, &tl, seen_idx, - &model, &read_markers, - socket_path, - ) - .await; + include_history, + ); + let turn_text = claude::turn_to_text(&turn); + + if let Err(e) = room.typing_notice(true).await { + tracing::debug!(room = %room_id, "failed to send typing start: {e}"); + } + + let result = session.run_turn(&turn_text).await; if let Err(e) = room.typing_notice(false).await { tracing::debug!(room = %room_id, "failed to send typing stop: {e}"); } - if let Err(e) = invoke_result { - tracing::error!(room = %room_id, "claude invocation failed: {e}"); - } + result?; + session.rooms_seen.insert(room_id.clone()); - // Decrement rate budget per invocation (not per message - MCP handles sends) { let mut state = state.lock().await; state.rate_budget = state.rate_budget.saturating_sub(1); diff --git a/src/shard.rs b/src/shard.rs new file mode 100644 index 0000000..76ccfe2 --- /dev/null +++ b/src/shard.rs @@ -0,0 +1,307 @@ +//! Long-running claude process for the shard. Spawned once, fed turns via +//! stdin (stream-json), emits events via stdout. Survives across many turns +//! within configured limits, then respawns fresh. + +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::time::{Duration, Instant}; + +use anyhow::{Context, bail}; +use matrix_sdk::ruma::OwnedRoomId; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin, Command}; +use tokio::sync::mpsc; + +use crate::paths; + +/// One stream-json event from claude's stdout. Only fields we care about. +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum StreamEvent { + #[serde(rename = "system")] + System { + #[serde(default)] + subtype: String, + #[serde(default)] + session_id: Option, + }, + #[serde(rename = "assistant")] + Assistant { + #[serde(default)] + message: serde_json::Value, + }, + #[serde(rename = "result")] + Result { + #[serde(default)] + subtype: String, + #[serde(default)] + is_error: bool, + #[serde(default)] + result: String, + #[serde(default)] + stop_reason: Option, + }, + #[serde(rename = "rate_limit_event")] + RateLimit {}, + #[serde(other)] + Other, +} + +#[derive(Debug, Serialize)] +struct UserMessage<'a> { + #[serde(rename = "type")] + kind: &'static str, + message: UserMessageBody<'a>, +} + +#[derive(Debug, Serialize)] +struct UserMessageBody<'a> { + role: &'static str, + content: Vec>, +} + +#[derive(Debug, Serialize)] +struct UserContent<'a> { + #[serde(rename = "type")] + kind: &'static str, + text: &'a str, +} + +/// A live claude shard process. Owns stdin and an mpsc receiver for stdout +/// events. When it exits or is dropped, the child is cleaned up. +pub struct ShardSession { + child: Child, + stdin: ChildStdin, + events: mpsc::Receiver, + /// When the session was spawned. + started: Instant, + /// Last time a turn finished. + last_used: Instant, + /// Number of turns processed. + turn_count: u32, + /// Rooms we've sent at least one turn for in this session. Used to decide + /// whether to include `previously_seen` context in a turn. + pub rooms_seen: HashSet, + /// Mtimes of identity/CHANGELOG files at session start - for refresh. + mtime_snapshot: Vec<(PathBuf, std::time::SystemTime)>, +} + +pub struct SpawnConfig<'a> { + pub model: &'a str, + pub mcp_config_path: &'a Path, + pub allowed_tools: &'a str, +} + +impl ShardSession { + pub async fn spawn(cfg: SpawnConfig<'_>) -> anyhow::Result { + let identity_dir = paths::identity_dir(); + + let mut cmd = Command::new("claude"); + cmd.args([ + "--print", + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--model", + cfg.model, + "--add-dir", + &identity_dir.to_string_lossy(), + "--allowedTools", + cfg.allowed_tools, + "--mcp-config", + &cfg.mcp_config_path.to_string_lossy(), + ]); + cmd.current_dir(&identity_dir); + cmd.stdin(Stdio::piped()); + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::piped()); + cmd.kill_on_drop(true); + + let mut child = cmd.spawn().context("spawn claude")?; + let stdin = child.stdin.take().context("claude stdin missing")?; + let stdout = child.stdout.take().context("claude stdout missing")?; + let stderr = child.stderr.take().context("claude stderr missing")?; + + let (tx, rx) = mpsc::channel(256); + + // stdout reader: parse stream-json line by line, push events + let tx_out = tx.clone(); + tokio::spawn(async move { + let mut lines = BufReader::new(stdout).lines(); + while let Ok(Some(line)) = lines.next_line().await { + if line.trim().is_empty() { + continue; + } + let ev: StreamEvent = match serde_json::from_str(&line) { + Ok(v) => v, + Err(e) => { + tracing::debug!(?e, line = %line, "shard: failed to parse stream-json line"); + continue; + } + }; + if tx_out.send(ev).await.is_err() { + break; + } + } + tracing::info!("shard: stdout stream closed"); + }); + + // stderr drainer: log claude's stderr at warn level + tokio::spawn(async move { + let mut lines = BufReader::new(stderr).lines(); + while let Ok(Some(line)) = lines.next_line().await { + tracing::warn!("claude stderr: {line}"); + } + }); + + let mtime_snapshot = snapshot_identity_mtimes(); + tracing::info!("shard: spawned"); + + Ok(Self { + child, + stdin, + events: rx, + started: Instant::now(), + last_used: Instant::now(), + turn_count: 0, + rooms_seen: HashSet::new(), + mtime_snapshot, + }) + } + + /// Send one user-turn JSON to claude and wait for the next `result`. + /// Returns the result text (claude's final assistant output). + pub async fn run_turn(&mut self, turn_body: &str) -> anyhow::Result { + let msg = UserMessage { + kind: "user", + message: UserMessageBody { + role: "user", + content: vec![UserContent { + kind: "text", + text: turn_body, + }], + }, + }; + let mut json = serde_json::to_string(&msg)?; + json.push('\n'); + self.stdin.write_all(json.as_bytes()).await?; + self.stdin.flush().await?; + + // Drain events until we hit a `result` (turn end). + loop { + let ev = self + .events + .recv() + .await + .context("shard: stdout closed before result")?; + match ev { + StreamEvent::System { .. } | StreamEvent::RateLimit {} | StreamEvent::Other => {} + StreamEvent::Assistant { message } => { + log_assistant_text(&message); + } + StreamEvent::Result { + is_error, + result, + stop_reason, + .. + } => { + self.last_used = Instant::now(); + self.turn_count += 1; + if is_error { + bail!("turn ended with is_error=true: {result}"); + } + tracing::info!( + turn = self.turn_count, + stop_reason = ?stop_reason, + "shard: turn complete" + ); + return Ok(result); + } + } + } + } + + pub fn should_refresh(&mut self, idle: Duration, max_turns: u32) -> Option<&'static str> { + if self.last_used.elapsed() > idle { + return Some("idle gap exceeded"); + } + if self.turn_count >= max_turns { + return Some("max turn count reached"); + } + if mtimes_changed(&self.mtime_snapshot) { + return Some("identity files changed"); + } + if self.child.try_wait().ok().flatten().is_some() { + return Some("child process exited"); + } + None + } + + pub async fn shutdown(mut self) { + // Close stdin → claude exits gracefully on EOF + drop(self.stdin); + // Bounded wait for clean exit + let _ = tokio::time::timeout(Duration::from_secs(3), self.child.wait()).await; + let _ = self.child.kill().await; + tracing::info!( + turns = self.turn_count, + uptime = ?self.started.elapsed(), + "shard: session ended" + ); + } +} + +fn log_assistant_text(message: &serde_json::Value) { + let Some(content) = message.get("content").and_then(|c| c.as_array()) else { + return; + }; + for item in content { + let Some(kind) = item.get("type").and_then(|t| t.as_str()) else { + continue; + }; + match kind { + "text" => { + if let Some(text) = item.get("text").and_then(|t| t.as_str()) { + let preview: String = text.chars().take(200).collect(); + tracing::info!("shard text: {preview}"); + tracing::trace!("shard full text: {text}"); + } + } + "thinking" => { + if let Some(text) = item.get("thinking").and_then(|t| t.as_str()) { + tracing::trace!("shard thinking: {text}"); + } + } + _ => {} + } + } +} + +fn snapshot_identity_mtimes() -> Vec<(PathBuf, std::time::SystemTime)> { + let id = paths::identity_dir(); + let state = paths::state_dir(); + let candidates = [ + id.join("CLAUDE.md"), + id.join("SYSTEM.md"), + id.join("notes.md"), + state.join("CHANGELOG.md"), + ]; + candidates + .into_iter() + .filter_map(|p| std::fs::metadata(&p).and_then(|m| m.modified()).ok().map(|t| (p, t))) + .collect() +} + +fn mtimes_changed(snapshot: &[(PathBuf, std::time::SystemTime)]) -> bool { + for (path, prev) in snapshot { + if let Ok(now) = std::fs::metadata(path).and_then(|m| m.modified()) { + if &now != prev { + return true; + } + } + } + false +} diff --git a/src/types.rs b/src/types.rs index c662177..47a550d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -56,6 +56,8 @@ pub struct FetchEventResult { pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6"; pub const DEFAULT_MAX_HISTORY: usize = 20; pub const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1; +pub const DEFAULT_SESSION_IDLE_MINUTES: u64 = 10; +pub const DEFAULT_SESSION_MAX_EVENTS: u32 = 100; #[derive(Debug, Deserialize)] pub struct Config { @@ -65,6 +67,8 @@ pub struct Config { pub rate_limit_per_min: Option, pub model: Option, pub max_history: Option, + pub session_idle_minutes: Option, + pub session_max_events: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -134,5 +138,7 @@ pub struct DaemonState { pub last_rate_reset: std::time::Instant, pub model: String, pub max_history: usize, + pub session_idle_minutes: u64, + pub session_max_events: u32, }