diff --git a/Cargo.lock b/Cargo.lock index da0504a..911bfb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -606,6 +606,7 @@ dependencies = [ "anyhow", "futures-util", "matrix-sdk", + "rmcp", "serde", "serde_json", "tokio", @@ -764,6 +765,12 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "117240f60069e65410b3ae1bb213295bd828f707b5bec6596a1afc8793ce0cbc" +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + [[package]] name = "ed25519" version = "2.2.3" @@ -973,6 +980,21 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.32" @@ -980,6 +1002,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -2405,6 +2428,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "pastey" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a" + [[package]] name = "pbkdf2" version = "0.12.2" @@ -2737,6 +2766,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "regex" version = "1.12.3" @@ -2821,6 +2870,41 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67d69668de0b0ccd9cc435f700f3b39a7861863cf37a15e1f304ea78688a4826" +dependencies = [ + "async-trait", + "base64", + "chrono", + "futures", + "pastey", + "pin-project-lite", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48fdc01c81097b0aed18633e676e269fefa3a78ec1df56b4fe597c1241b92025" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn", +] + [[package]] name = "rmp" version = "0.8.15" @@ -3122,6 +3206,32 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "chrono", + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -3208,6 +3318,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_html_form" version = "0.2.8" diff --git a/Cargo.toml b/Cargo.toml index 5f31216..54318c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1" futures-util = "0.3" +rmcp = { version = "1", features = ["server", "macros", "transport-io"] } [[bin]] name = "damocles-daemon" @@ -29,6 +30,10 @@ path = "src/bin/send.rs" name = "verify" path = "src/bin/verify.rs" +[[bin]] +name = "damocles-mcp" +path = "src/bin/mcp.rs" + [[bin]] name = "bootstrap-cross-signing" path = "src/bin/bootstrap_cross_signing.rs" diff --git a/src/bin/mcp.rs b/src/bin/mcp.rs new file mode 100644 index 0000000..ac18b69 --- /dev/null +++ b/src/bin/mcp.rs @@ -0,0 +1,259 @@ +//! MCP stdio server bridging Claude's tool calls to the daemon's Unix socket. +//! +//! Launched by claude CLI via `--mcp-config`. Reads `DAMOCLES_SOCKET` and +//! `DAMOCLES_SOURCE_ROOM` from environment (set by the daemon in the config). + +use std::io::{BufRead, Write as _}; +use std::os::unix::net::UnixStream; + +use anyhow::{Context, Result}; +use rmcp::{ + ErrorData as McpError, ServiceExt, + handler::server::wrapper::Parameters, + model::{CallToolResult, Content}, + schemars::{self, JsonSchema}, + tool, tool_router, + transport::stdio, +}; +use serde::Deserialize; +use tokio::sync::Mutex; + +mod protocol_inline { + //! Inline copy of the daemon protocol types. The MCP binary is a separate + //! entrypoint but lives in the same crate, so it can't `use crate::protocol`. + //! We duplicate the minimal types here. + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize)] + #[serde(tag = "method")] + pub enum DaemonRequest { + #[serde(rename = "send_message")] + SendMessage { room_id: String, body: String }, + + #[serde(rename = "send_dm")] + SendDm { user_id: String, body: String }, + + #[serde(rename = "send_reaction")] + SendReaction { + room_id: String, + event_id: String, + key: String, + }, + + #[serde(rename = "list_rooms")] + ListRooms {}, + + #[serde(rename = "list_room_members")] + ListRoomMembers { room_id: String }, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct DaemonResponse { + pub success: bool, + pub data: Option, + pub error: Option, + } +} + +use protocol_inline::{DaemonRequest, DaemonResponse}; + +// --------------------------------------------------------------------------- +// Tool parameter types (schemars derives JSON Schema for claude) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize, JsonSchema)] +struct SendMessageParams { + /// The message text to send. + body: String, + /// Target room ID (e.g. !abc:server). Defaults to the room that triggered + /// this invocation if omitted. + #[serde(default)] + room_id: Option, +} + +#[derive(Debug, Deserialize, JsonSchema)] +struct SendDmParams { + /// The Matrix user ID to DM (e.g. @alice:server). + user_id: String, + /// The message text to send. + body: String, +} + +#[derive(Debug, Deserialize, JsonSchema)] +struct SendReactionParams { + /// The event ID to react to. Can be the shortened form shown in the + /// timeline (e.g. $abc123de...). + event_id: String, + /// The reaction emoji (e.g. fire, eyes, heart). + key: String, +} + +#[derive(Debug, Deserialize, JsonSchema)] +struct ListRoomMembersParams { + /// The room ID to list members for. + room_id: String, +} + +// --------------------------------------------------------------------------- +// MCP server struct +// --------------------------------------------------------------------------- + +struct MatrixBridge { + socket: Mutex, + source_room: String, +} + +impl MatrixBridge { + fn new(socket: UnixStream, source_room: String) -> Self { + Self { + socket: Mutex::new(socket), + source_room, + } + } + + /// Send a request to the daemon and read the response. + async fn call(&self, request: &DaemonRequest) -> Result { + let mut socket = self.socket.lock().await; + let mut json = serde_json::to_string(request).map_err(|e| { + McpError::internal_error(format!("serialize request: {e}"), None) + })?; + json.push('\n'); + socket.write_all(json.as_bytes()).map_err(|e| { + McpError::internal_error(format!("socket write: {e}"), None) + })?; + socket.flush().map_err(|e| { + McpError::internal_error(format!("socket flush: {e}"), None) + })?; + + let mut reader = std::io::BufReader::new(&*socket); + let mut line = String::new(); + reader.read_line(&mut line).map_err(|e| { + McpError::internal_error(format!("socket read: {e}"), None) + })?; + + serde_json::from_str::(&line).map_err(|e| { + McpError::internal_error(format!("parse response: {e}"), None) + }) + } + + fn response_to_result(resp: DaemonResponse) -> Result { + if resp.success { + let text = match resp.data { + Some(serde_json::Value::String(s)) => s, + Some(v) => serde_json::to_string_pretty(&v).unwrap_or_default(), + None => "ok".to_owned(), + }; + Ok(CallToolResult::success(vec![Content::text(text)])) + } else { + let msg = resp.error.unwrap_or_else(|| "unknown error".to_owned()); + Ok(CallToolResult::error(vec![Content::text(msg)])) + } + } +} + +// --------------------------------------------------------------------------- +// Tool definitions +// --------------------------------------------------------------------------- + +#[tool_router(server_handler)] +impl MatrixBridge { + #[tool(description = "Send a message to a Matrix room. Defaults to the room that triggered this invocation.")] + async fn send_message( + &self, + Parameters(params): Parameters, + ) -> Result { + let room_id = params + .room_id + .unwrap_or_else(|| self.source_room.clone()); + let resp = self + .call(&DaemonRequest::SendMessage { + room_id, + body: params.body, + }) + .await?; + Self::response_to_result(resp) + } + + #[tool(description = "Send a direct message to a Matrix user. Creates the DM room if needed.")] + async fn send_dm( + &self, + Parameters(params): Parameters, + ) -> Result { + let resp = self + .call(&DaemonRequest::SendDm { + user_id: params.user_id, + body: params.body, + }) + .await?; + Self::response_to_result(resp) + } + + #[tool(description = "React to a message with an emoji. Use the event ID shown in the timeline.")] + async fn send_reaction( + &self, + Parameters(params): Parameters, + ) -> Result { + let resp = self + .call(&DaemonRequest::SendReaction { + room_id: self.source_room.clone(), + event_id: params.event_id, + key: params.key, + }) + .await?; + Self::response_to_result(resp) + } + + #[tool(description = "List all Matrix rooms the bot has joined.")] + async fn list_rooms(&self) -> Result { + let resp = self.call(&DaemonRequest::ListRooms {}).await?; + Self::response_to_result(resp) + } + + #[tool(description = "List members of a Matrix room.")] + async fn list_room_members( + &self, + Parameters(params): Parameters, + ) -> Result { + let resp = self + .call(&DaemonRequest::ListRoomMembers { + room_id: params.room_id, + }) + .await?; + Self::response_to_result(resp) + } +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +#[tokio::main] +async fn main() -> Result<()> { + // MCP servers MUST log to stderr (stdout is the MCP transport) + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::DEBUG.into()), + ) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + let socket_path = + std::env::var("DAMOCLES_SOCKET").context("DAMOCLES_SOCKET env var not set")?; + let source_room = + std::env::var("DAMOCLES_SOURCE_ROOM").context("DAMOCLES_SOURCE_ROOM env var not set")?; + + tracing::info!(%socket_path, %source_room, "damocles-mcp starting"); + + let socket = UnixStream::connect(&socket_path) + .with_context(|| format!("failed to connect to daemon socket at {socket_path}"))?; + + let bridge = MatrixBridge::new(socket, source_room); + let service = bridge.serve(stdio()).await.inspect_err(|e| { + tracing::error!("mcp serve error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +} diff --git a/src/claude.rs b/src/claude.rs index 61c5d67..4499ab0 100644 --- a/src/claude.rs +++ b/src/claude.rs @@ -1,13 +1,20 @@ use std::collections::HashMap; use std::fmt::Write as _; +use std::path::Path; use anyhow::{Context, bail}; use matrix_sdk::ruma::{OwnedEventId, OwnedRoomId, OwnedUserId}; +use serde_json::json; use crate::paths; use crate::timeline::render_timeline_item; -use crate::types::{ClaudeDoc, ResponseTarget, TimelineItem}; +use crate::types::TimelineItem; +/// Invoke claude with MCP tools for Matrix interaction. +/// +/// Instead of parsing `=== type` output, the shard calls MCP tools +/// (send_message, send_reaction, etc.) which the daemon handles via the Unix +/// socket. Any text claude prints to stdout is logged as internal thought. pub async fn invoke_claude( source_room: &OwnedRoomId, room_name: &str, @@ -15,10 +22,88 @@ pub async fn invoke_claude( seen_idx: usize, model: &str, read_markers: &HashMap>, -) -> anyhow::Result> { + socket_path: &Path, +) -> anyhow::Result<()> { let identity_dir = paths::identity_dir(); let identity_str = identity_dir.to_string_lossy(); + let prompt = build_prompt(source_room, room_name, timeline, seen_idx, read_markers); + + // Write MCP config pointing to our bridge binary + daemon socket + let mcp_config = build_mcp_config(socket_path, source_room)?; + let mcp_config_path = paths::state_dir().join("mcp.json"); + tokio::fs::write(&mcp_config_path, &mcp_config).await?; + + let new_msg_count = timeline[seen_idx..] + .iter() + .filter(|t| matches!(t, TimelineItem::Message { .. })) + .count(); + let new_react_count = timeline.len().saturating_sub(seen_idx) - new_msg_count; + tracing::info!( + "invoking claude: {} new ({} msg + {} react), {} seen", + timeline.len().saturating_sub(seen_idx), + new_msg_count, + new_react_count, + seen_idx + ); + tracing::trace!("full prompt:\n{prompt}"); + + use tokio::process::Command; + let mcp_config_str = mcp_config_path.to_string_lossy(); + let mut cmd = Command::new("claude"); + cmd.args([ + "--print", + "--model", + model, + "--add-dir", + &identity_str, + "--allowedTools", + "Read,Edit,Write,Glob,Grep,mcp__matrix__send_message,mcp__matrix__send_dm,mcp__matrix__send_reaction,mcp__matrix__list_rooms,mcp__matrix__list_room_members", + "--mcp-config", + &mcp_config_str, + "-p", + &prompt, + ]); + cmd.current_dir(&identity_dir); + cmd.stdin(std::process::Stdio::null()); + let output = cmd.output().await.context("failed to run claude")?; + + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + + if !output.status.success() { + bail!( + "claude exited with {}:\nstdout: {}\nstderr: {}", + output.status, + stdout, + stderr + ); + } + + if !stderr.is_empty() { + tracing::warn!("claude stderr: {stderr}"); + } + + // With MCP, stdout is just the shard's internal monologue - log it + let text = stdout.trim(); + if !text.is_empty() { + tracing::info!( + "claude thought: {}", + text.chars().take(200).collect::() + ); + tracing::trace!("full claude output: {text}"); + } + + Ok(()) +} + +fn build_prompt( + source_room: &OwnedRoomId, + room_name: &str, + timeline: &[TimelineItem], + seen_idx: usize, + read_markers: &HashMap>, +) -> String { let mut prompt = String::new(); writeln!(prompt, "[room_id: {source_room}]").unwrap(); writeln!(prompt, "[room_name: {room_name}]").unwrap(); @@ -28,7 +113,7 @@ pub async fn invoke_claude( ) .unwrap(); - // Collect unique non-self participants (message senders + reactors) + // Collect unique non-self participants let mut senders: Vec<&OwnedUserId> = timeline .iter() .filter(|t| !t.is_self()) @@ -66,354 +151,28 @@ pub async fn invoke_claude( } } - let new_msg_count = new - .iter() - .filter(|t| matches!(t, TimelineItem::Message { .. })) - .count(); - let new_react_count = new.len() - new_msg_count; - tracing::info!( - "invoking claude: {} new ({} msg + {} react), {} seen", - new.len(), - new_msg_count, - new_react_count, - old.len() - ); - tracing::trace!("full prompt:\n{prompt}"); - - use tokio::process::Command; - let mut cmd = Command::new("claude"); - cmd.args([ - "--print", - "--model", - model, - "--add-dir", - &identity_str, - "--allowedTools", - "Read Edit Write Glob Grep", - "-p", - &prompt, - ]); - cmd.current_dir(&identity_dir); - cmd.stdin(std::process::Stdio::null()); - let output = cmd.output().await.context("failed to run claude")?; - - let stderr = String::from_utf8_lossy(&output.stderr); - let stdout = String::from_utf8_lossy(&output.stdout); - - if !output.status.success() { - bail!( - "claude exited with {}:\nstdout: {}\nstderr: {}", - output.status, - stdout, - stderr - ); - } - - if !stderr.is_empty() { - tracing::warn!("claude stderr: {stderr}"); - } - - let raw = String::from_utf8_lossy(&output.stdout).to_string(); - Ok(parse_response(&raw, source_room)) + prompt } -/// Parse Claude's stdout into a list of documents. -/// -/// Format: each doc starts with a line `=== [arg]`. Body is everything -/// until the next `===` line or EOF. Types: -/// - `=== thought` -> `ClaudeDoc::Thought` (logged, not sent) -/// - `=== room []` -> `ClaudeDoc::Message` to that room (or source room if no arg) -/// - `=== dm ` -> `ClaudeDoc::Message` as DM -/// - `=== skip` -> `ClaudeDoc::Skip` (no-op) -/// -/// Anything before the first `===` line is treated as a preamble thought. -/// Bare text with no `===` is treated as a single message to default_room. -pub fn parse_response(raw: &str, default_room: &OwnedRoomId) -> Vec { - let trimmed = raw.trim(); - if trimmed.is_empty() { - return Vec::new(); - } +/// Build the MCP config JSON that tells claude how to launch damocles-mcp. +fn build_mcp_config(socket_path: &Path, source_room: &OwnedRoomId) -> anyhow::Result { + let mcp_bin = std::env::current_exe()? + .parent() + .context("no parent dir for current exe")? + .join("damocles-mcp"); - let mut docs = Vec::new(); - let mut current_header: Option = None; - let mut current_body = String::new(); - let mut preamble = String::new(); - - for line in trimmed.lines() { - if let Some(header) = line.strip_prefix("===") { - if let Some(h) = current_header.take() { - if let Some(doc) = build_doc(&h, current_body.trim(), default_room) { - docs.push(doc); + let config = json!({ + "mcpServers": { + "matrix": { + "command": mcp_bin.to_string_lossy(), + "args": [], + "env": { + "DAMOCLES_SOCKET": socket_path.to_string_lossy(), + "DAMOCLES_SOURCE_ROOM": source_room.as_str() } - current_body.clear(); - } else { - let p = preamble.trim(); - if !p.is_empty() { - docs.push(ClaudeDoc::Thought(p.to_owned())); - } - preamble.clear(); } - current_header = Some(header.trim().to_owned()); - } else if current_header.is_some() { - current_body.push_str(line); - current_body.push('\n'); - } else { - preamble.push_str(line); - preamble.push('\n'); } - } + }); - if let Some(h) = current_header { - if let Some(doc) = build_doc(&h, current_body.trim(), default_room) { - docs.push(doc); - } - } else { - let p = preamble.trim(); - if !p.is_empty() { - docs.push(ClaudeDoc::Message { - target: ResponseTarget::Room(default_room.clone()), - body: p.to_owned(), - }); - } - } - - docs -} - -fn build_doc(header: &str, body: &str, default_room: &OwnedRoomId) -> Option { - let mut parts = header.splitn(2, char::is_whitespace); - let kind = parts.next().unwrap_or("").trim(); - let arg = parts.next().unwrap_or("").trim(); - - match kind { - "skip" => Some(ClaudeDoc::Skip), - "thought" => { - if body.is_empty() { - None - } else { - Some(ClaudeDoc::Thought(body.to_owned())) - } - } - "room" => { - if body.is_empty() { - return None; - } - let target = if arg.is_empty() { - ResponseTarget::Room(default_room.clone()) - } else { - match arg.parse::() { - Ok(rid) => ResponseTarget::Room(rid), - Err(_) => return None, - } - }; - Some(ClaudeDoc::Message { - target, - body: body.to_owned(), - }) - } - "dm" => { - if body.is_empty() { - return None; - } - match arg.parse::() { - Ok(uid) => Some(ClaudeDoc::Message { - target: ResponseTarget::Dm(uid), - body: body.to_owned(), - }), - Err(_) => None, - } - } - "react" => { - let mut header_parts = arg.splitn(2, char::is_whitespace); - let eid_arg = header_parts.next().unwrap_or("").trim(); - let key_in_header = header_parts.next().unwrap_or("").trim(); - if eid_arg.is_empty() { - return None; - } - let key = if !key_in_header.is_empty() { - key_in_header.to_owned() - } else if !body.is_empty() { - body.to_owned() - } else { - return None; - }; - Some(ClaudeDoc::Reaction { - target_id_arg: eid_arg.to_owned(), - key, - }) - } - _ => { - if body.is_empty() { - None - } else { - Some(ClaudeDoc::Thought(format!( - "[unknown header '{header}'] {body}" - ))) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_room() -> OwnedRoomId { - "!test:example.com".parse().unwrap() - } - - fn first_message(docs: &[ClaudeDoc]) -> (&ResponseTarget, &str) { - for d in docs { - if let ClaudeDoc::Message { target, body } = d { - return (target, body.as_str()); - } - } - panic!("no message doc found"); - } - - fn assert_room(target: &ResponseTarget, expected: &str) { - match target { - ResponseTarget::Room(r) => assert_eq!(r.as_str(), expected), - ResponseTarget::Dm(_) => panic!("expected room target, got dm"), - } - } - - #[test] - fn parse_room_with_arg() { - let raw = "=== room !other:server\nhello world"; - let docs = parse_response(raw, &test_room()); - let (target, body) = first_message(&docs); - assert_room(target, "!other:server"); - assert_eq!(body, "hello world"); - } - - #[test] - fn parse_room_no_arg_uses_default() { - let raw = "=== room\nhi"; - let docs = parse_response(raw, &test_room()); - let (target, _) = first_message(&docs); - assert_room(target, "!test:example.com"); - } - - #[test] - fn parse_skip() { - let raw = "=== skip"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 1); - assert!(matches!(docs[0], ClaudeDoc::Skip)); - } - - #[test] - fn parse_plain_text_no_header() { - let raw = "just a message"; - let docs = parse_response(raw, &test_room()); - let (target, body) = first_message(&docs); - assert_room(target, "!test:example.com"); - assert_eq!(body, "just a message"); - } - - #[test] - fn parse_empty() { - assert!(parse_response("", &test_room()).is_empty()); - assert!(parse_response(" \n ", &test_room()).is_empty()); - } - - #[test] - fn parse_dm() { - let raw = "=== dm @alice:example.com\nhi alice"; - let docs = parse_response(raw, &test_room()); - let (target, body) = first_message(&docs); - match target { - ResponseTarget::Dm(u) => assert_eq!(u.as_str(), "@alice:example.com"), - ResponseTarget::Room(_) => panic!("expected dm target"), - } - assert_eq!(body, "hi alice"); - } - - #[test] - fn parse_thought() { - let raw = "=== thought\nthinking about whether to reply..."; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 1); - match &docs[0] { - ClaudeDoc::Thought(s) => assert_eq!(s, "thinking about whether to reply..."), - _ => panic!("expected thought"), - } - } - - #[test] - fn parse_multi_doc() { - let raw = "\ -=== thought -let me check notes - -=== room !x:y -hi - -=== dm @u:s -private - -=== skip -"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 4); - assert!(matches!(docs[0], ClaudeDoc::Thought(_))); - assert!(matches!( - docs[1], - ClaudeDoc::Message { - target: ResponseTarget::Room(_), - .. - } - )); - assert!(matches!( - docs[2], - ClaudeDoc::Message { - target: ResponseTarget::Dm(_), - .. - } - )); - assert!(matches!(docs[3], ClaudeDoc::Skip)); - } - - #[test] - fn parse_preamble_becomes_thought() { - let raw = "preamble line\n=== room !x:y\nhello"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 2); - assert!(matches!(docs[0], ClaudeDoc::Thought(_))); - assert!(matches!(docs[1], ClaudeDoc::Message { .. })); - } - - #[test] - fn parse_react_with_key_in_header() { - let raw = "=== react $abc12345… šŸ‘€"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 1); - match &docs[0] { - ClaudeDoc::Reaction { target_id_arg, key } => { - assert_eq!(target_id_arg, "$abc12345…"); - assert_eq!(key, "šŸ‘€"); - } - _ => panic!("expected reaction"), - } - } - - #[test] - fn parse_react_with_key_in_body() { - let raw = "=== react $abc12345…\nšŸ”„"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 1); - match &docs[0] { - ClaudeDoc::Reaction { key, .. } => assert_eq!(key, "šŸ”„"), - _ => panic!("expected reaction"), - } - } - - #[test] - fn parse_unknown_header_becomes_thought() { - let raw = "=== mystery foo\nbody"; - let docs = parse_response(raw, &test_room()); - assert_eq!(docs.len(), 1); - assert!(matches!(docs[0], ClaudeDoc::Thought(_))); - } + serde_json::to_string_pretty(&config).context("serialize mcp config") } diff --git a/src/main.rs b/src/main.rs index d5712f9..90dfa8c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,14 @@ mod claude; mod handlers; mod paths; +mod protocol; mod session; +mod socket; mod timeline; mod types; use std::collections::HashSet; +use std::path::PathBuf; use std::sync::Arc; use anyhow::{Context, bail}; @@ -15,10 +18,7 @@ use matrix_sdk::{ ruma::{ OwnedEventId, OwnedRoomId, api::client::filter::FilterDefinition, - events::{ - reaction::ReactionEventContent, receipt::ReceiptThread, relation::Annotation, - room::message::RoomMessageEventContent, - }, + events::receipt::ReceiptThread, }, }; use tokio::fs; @@ -26,8 +26,7 @@ use tokio::sync::Mutex; use tracing_subscriber::EnvFilter; use types::{ - ClaudeDoc, DEFAULT_MAX_HISTORY, DEFAULT_MODEL, DEFAULT_RATE_LIMIT_PER_MIN, DaemonState, - ResponseTarget, TimelineItem, + DEFAULT_MAX_HISTORY, DEFAULT_MODEL, DEFAULT_RATE_LIMIT_PER_MIN, DaemonState, TimelineItem, }; #[tokio::main] @@ -90,10 +89,20 @@ async fn main() -> anyhow::Result<()> { max_history, })); + // Start MCP socket listener for tool calls from the shard + let socket_path = paths::state_dir().join("daemon.sock"); + let socket_client = client.clone(); + let socket_path_clone = socket_path.clone(); + tokio::spawn(async move { + if let Err(e) = socket::start_listener(&socket_path_clone, socket_client).await { + tracing::error!("mcp socket listener failed: {e}"); + } + }); + let processor_state = state.clone(); let processor_client = client.clone(); tokio::spawn(async move { - process_loop(processor_state, processor_client).await; + process_loop(processor_state, processor_client, &socket_path).await; }); sync(client, sync_token, &session_file, state).await @@ -157,7 +166,7 @@ async fn sync( bail!("sync loop exited unexpectedly") } -async fn process_loop(state: Arc>, client: Client) { +async fn process_loop(state: Arc>, client: Client, socket_path: &PathBuf) { loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; @@ -185,7 +194,7 @@ async fn process_loop(state: Arc>, client: Client) { continue; }; - if let Err(e) = process_room(&state, &client, &room_id, &room).await { + if let Err(e) = process_room(&state, &client, &room_id, &room, socket_path).await { tracing::error!(room = %room_id, "failed to process room: {e}"); } } @@ -196,6 +205,7 @@ async fn process_room( client: &Client, room_id: &OwnedRoomId, room: &Room, + socket_path: &PathBuf, ) -> anyhow::Result<()> { // Snapshot last_shown for this room so we can mark seen vs new. let in_memory = { @@ -302,76 +312,29 @@ async fn process_room( tracing::debug!(room = %room_id, "failed to send typing start: {e}"); } - let invoke_result = - claude::invoke_claude(room_id, &room_name, &tl, seen_idx, &model, &read_markers).await; + let invoke_result = claude::invoke_claude( + room_id, + &room_name, + &tl, + seen_idx, + &model, + &read_markers, + socket_path, + ) + .await; if let Err(e) = room.typing_notice(false).await { tracing::debug!(room = %room_id, "failed to send typing stop: {e}"); } - let docs = invoke_result?; - - for doc in docs { - match doc { - ClaudeDoc::Skip => { - tracing::debug!(room = %room_id, "claude doc: skip"); - } - ClaudeDoc::Thought(body) => { - tracing::info!(room = %room_id, thought = %body.chars().take(120).collect::(), "claude doc: thought"); - tracing::trace!("full thought: {body}"); - } - ClaudeDoc::Message { target, body } => { - let target_room = match &target { - ResponseTarget::Room(rid) => client.get_room(rid), - ResponseTarget::Dm(user) => { - match handlers::find_or_create_dm(client, user).await { - Ok(r) => Some(r), - Err(e) => { - tracing::error!(user = %user, "failed to get/create DM: {e}"); - None - } - } - } - }; - let target_label = match &target { - ResponseTarget::Room(rid) => rid.to_string(), - ResponseTarget::Dm(user) => format!("dm:{user}"), - }; - if let Some(target_room) = target_room { - let content = RoomMessageEventContent::text_plain(&body); - match target_room.send(content).await { - Ok(_) => { - let mut state = state.lock().await; - state.rate_budget = state.rate_budget.saturating_sub(1); - tracing::info!( - target = %target_label, - "sent response ({} budget remaining)", - state.rate_budget - ); - } - Err(e) => tracing::error!("failed to send: {e}"), - } - } else { - tracing::warn!(target = %target_label, "target not available"); - } - } - ClaudeDoc::Reaction { target_id_arg, key } => { - let Some(full_eid) = timeline::resolve_event_id(&tl, &target_id_arg) else { - tracing::warn!(arg = %target_id_arg, "react: target event id not found in timeline"); - continue; - }; - let content = - ReactionEventContent::new(Annotation::new(full_eid.clone(), key.clone())); - match room.send(content).await { - Ok(_) => tracing::info!(target = %full_eid, %key, "sent reaction"), - Err(e) => tracing::error!("failed to send reaction: {e}"), - } - } - } + if let Err(e) = invoke_result { + tracing::error!(room = %room_id, "claude invocation failed: {e}"); } + // Decrement rate budget per invocation (not per message - MCP handles sends) { let mut state = state.lock().await; + state.rate_budget = state.rate_budget.saturating_sub(1); if let Some(eid) = new_last_event_id.clone() { state.last_shown.insert(room_id.clone(), eid); } diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..fb67d4e --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,53 @@ +use serde::{Deserialize, Serialize}; + +/// Request from MCP server to daemon over Unix socket (ndjson). +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "method")] +pub enum DaemonRequest { + #[serde(rename = "send_message")] + SendMessage { room_id: String, body: String }, + + #[serde(rename = "send_dm")] + SendDm { user_id: String, body: String }, + + #[serde(rename = "send_reaction")] + SendReaction { + room_id: String, + event_id: String, + key: String, + }, + + #[serde(rename = "list_rooms")] + ListRooms {}, + + #[serde(rename = "list_room_members")] + ListRoomMembers { room_id: String }, +} + +/// Response from daemon to MCP server. +#[derive(Debug, Serialize, Deserialize)] +pub struct DaemonResponse { + pub success: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl DaemonResponse { + pub fn ok(data: impl Serialize) -> Self { + Self { + success: true, + data: Some(serde_json::to_value(data).unwrap_or_default()), + error: None, + } + } + + pub fn err(msg: impl Into) -> Self { + Self { + success: false, + data: None, + error: Some(msg.into()), + } + } +} diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..436c653 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,185 @@ +use std::path::Path; + +use matrix_sdk::{ + Client, + ruma::{ + OwnedRoomId, OwnedUserId, + events::{ + reaction::ReactionEventContent, + relation::Annotation, + room::message::RoomMessageEventContent, + }, + }, +}; +use serde_json::json; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::{UnixListener, UnixStream}; + +use crate::handlers; +use crate::protocol::{DaemonRequest, DaemonResponse}; +use crate::timeline; + +pub async fn start_listener(socket_path: &Path, client: Client) -> anyhow::Result<()> { + let _ = tokio::fs::remove_file(socket_path).await; + let listener = UnixListener::bind(socket_path)?; + tracing::info!(path = %socket_path.display(), "mcp socket listener started"); + + loop { + let (stream, _) = listener.accept().await?; + let client = client.clone(); + tokio::spawn(async move { + if let Err(e) = handle_connection(stream, client).await { + tracing::warn!("mcp socket connection error: {e}"); + } + }); + } +} + +async fn handle_connection(stream: UnixStream, client: Client) -> anyhow::Result<()> { + let (reader, mut writer) = stream.into_split(); + let mut lines = BufReader::new(reader).lines(); + + while let Some(line) = lines.next_line().await? { + let response = match serde_json::from_str::(&line) { + Ok(request) => { + tracing::debug!(?request, "mcp socket request"); + handle_request(request, &client).await + } + Err(e) => DaemonResponse::err(format!("invalid request: {e}")), + }; + tracing::debug!(?response, "mcp socket response"); + let mut json = serde_json::to_string(&response)?; + json.push('\n'); + writer.write_all(json.as_bytes()).await?; + writer.flush().await?; + } + + Ok(()) +} + +async fn handle_request(request: DaemonRequest, client: &Client) -> DaemonResponse { + match request { + DaemonRequest::SendMessage { room_id, body } => send_message(client, &room_id, &body).await, + DaemonRequest::SendDm { user_id, body } => send_dm(client, &user_id, &body).await, + DaemonRequest::SendReaction { + room_id, + event_id, + key, + } => send_reaction(client, &room_id, &event_id, &key).await, + DaemonRequest::ListRooms {} => list_rooms(client).await, + DaemonRequest::ListRoomMembers { room_id } => list_room_members(client, &room_id).await, + } +} + +async fn send_message(client: &Client, room_id: &str, body: &str) -> DaemonResponse { + let rid = match room_id.parse::() { + Ok(r) => r, + Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")), + }; + let Some(room) = client.get_room(&rid) else { + return DaemonResponse::err(format!("room {rid} not found")); + }; + let content = RoomMessageEventContent::text_plain(body); + match room.send(content).await { + Ok(_) => { + tracing::info!(room = %rid, "mcp: sent message"); + DaemonResponse::ok(format!("sent to {rid}")) + } + Err(e) => DaemonResponse::err(format!("send failed: {e}")), + } +} + +async fn send_dm(client: &Client, user_id: &str, body: &str) -> DaemonResponse { + let uid = match user_id.parse::() { + Ok(u) => u, + Err(e) => return DaemonResponse::err(format!("invalid user_id: {e}")), + }; + let room = match handlers::find_or_create_dm(client, &uid).await { + Ok(r) => r, + Err(e) => return DaemonResponse::err(format!("failed to get/create DM: {e}")), + }; + let content = RoomMessageEventContent::text_plain(body); + match room.send(content).await { + Ok(_) => { + tracing::info!(user = %uid, "mcp: sent DM"); + DaemonResponse::ok(format!("DM sent to {uid}")) + } + Err(e) => DaemonResponse::err(format!("send DM failed: {e}")), + } +} + +async fn send_reaction( + client: &Client, + room_id: &str, + event_id: &str, + key: &str, +) -> DaemonResponse { + let rid = match room_id.parse::() { + Ok(r) => r, + Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")), + }; + let Some(room) = client.get_room(&rid) else { + return DaemonResponse::err(format!("room {rid} not found")); + }; + let own_user = match client.user_id() { + Some(u) => u.to_owned(), + None => return DaemonResponse::err("not logged in".to_owned()), + }; + + // Load timeline to resolve possibly-shortened event id + let tl = match timeline::load_timeline(&room, 50, &own_user).await { + Ok(t) => t, + Err(e) => return DaemonResponse::err(format!("failed to load timeline: {e}")), + }; + let Some(full_eid) = timeline::resolve_event_id(&tl, event_id) else { + return DaemonResponse::err(format!("event {event_id} not found in timeline")); + }; + + let content = ReactionEventContent::new(Annotation::new(full_eid.clone(), key.to_owned())); + match room.send(content).await { + Ok(_) => { + tracing::info!(target = %full_eid, %key, "mcp: sent reaction"); + DaemonResponse::ok(format!("reacted {key} to {full_eid}")) + } + Err(e) => DaemonResponse::err(format!("send reaction failed: {e}")), + } +} + +async fn list_rooms(client: &Client) -> DaemonResponse { + let mut rooms = Vec::new(); + for room in client.joined_rooms() { + let name = room + .display_name() + .await + .map_or_else(|_| room.room_id().to_string(), |n| n.to_string()); + rooms.push(json!({ + "room_id": room.room_id().as_str(), + "name": name, + })); + } + DaemonResponse::ok(rooms) +} + +async fn list_room_members(client: &Client, room_id: &str) -> DaemonResponse { + let rid = match room_id.parse::() { + Ok(r) => r, + Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")), + }; + let Some(room) = client.get_room(&rid) else { + return DaemonResponse::err(format!("room {rid} not found")); + }; + let members = match room.members(matrix_sdk::RoomMemberships::JOIN).await { + Ok(m) => m, + Err(e) => return DaemonResponse::err(format!("failed to list members: {e}")), + }; + let list: Vec<_> = members + .iter() + .map(|m| { + json!({ + "user_id": m.user_id().as_str(), + "display_name": m.display_name().unwrap_or_default(), + }) + }) + .collect(); + DaemonResponse::ok(list) +} diff --git a/src/types.rs b/src/types.rs index 599e1ea..d2cef7d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -89,24 +89,3 @@ pub struct DaemonState { pub max_history: usize, } -pub enum ResponseTarget { - Room(OwnedRoomId), - Dm(OwnedUserId), -} - -/// One document within Claude's multi-doc output. Each doc has its own -/// frontmatter; the daemon routes based on which fields are present. -pub enum ClaudeDoc { - /// A chat message to send. - Message { - target: ResponseTarget, - body: String, - }, - /// A reaction to a message. `target_id_arg` is the event id (possibly - /// shortened) the agent saw in the prompt; daemon expands by prefix match. - Reaction { target_id_arg: String, key: String }, - /// Agent's internal monologue. Not sent to chat. Logged to tracing. - Thought(String), - /// Explicit "do nothing for this slot". Useful as a placeholder. - Skip, -}