damocles-daemon/src/socket.rs

432 lines
16 KiB
Rust

use std::path::Path;
use matrix_sdk::{
Client,
room::reply::{EnforceThread, Reply},
ruma::{
OwnedRoomId, OwnedUserId,
events::{
reaction::ReactionEventContent,
relation::Annotation,
room::message::RoomMessageEventContent,
},
},
};
use crate::claude::{short_eid, wire_event_from};
use crate::types::{FetchEventResult, MemberInfo, RoomInfo, WireEvent};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use crate::handlers;
use crate::protocol::{DaemonRequest, DaemonResponse};
use crate::timeline;
/// Binds the MCP Unix socket at `socket_path` and serves connections forever.
///
/// Any file already present at `socket_path` (e.g. a stale socket from a
/// previous run) is removed first; failure to remove it is ignored. Each
/// accepted connection is handled on its own spawned task with a clone of
/// `client`.
///
/// # Errors
/// Returns an error only if the socket cannot be bound. A failed `accept` is
/// logged and the listener keeps running after a short pause instead of
/// tearing down the whole MCP socket.
pub async fn start_listener(socket_path: &Path, client: Client) -> anyhow::Result<()> {
    // Best-effort cleanup; a missing file is the normal case.
    let _ = tokio::fs::remove_file(socket_path).await;
    let listener = UnixListener::bind(socket_path)?;
    tracing::info!(path = %socket_path.display(), "mcp socket listener started");
    loop {
        let stream = match listener.accept().await {
            Ok((stream, _addr)) => stream,
            Err(e) => {
                // Accept errors are usually transient (aborted connection,
                // fd limit). Don't let one kill the listener, but back off so
                // a persistent error doesn't spin the CPU.
                tracing::warn!("mcp socket accept failed: {e}");
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };
        let client = client.clone();
        tokio::spawn(async move {
            if let Err(e) = handle_connection(stream, client).await {
                tracing::warn!("mcp socket connection error: {e}");
            }
        });
    }
}
/// Serves one MCP client connection speaking newline-delimited JSON.
///
/// Each input line is parsed as a [`DaemonRequest`] and answered with exactly
/// one JSON-encoded [`DaemonResponse`] line, flushed immediately. A line that
/// fails to parse gets an `invalid request: …` error response rather than
/// closing the connection. Returns `Ok(())` once the peer stops sending.
///
/// # Errors
/// Propagates socket read/write failures and response serialization failures.
async fn handle_connection(stream: UnixStream, client: Client) -> anyhow::Result<()> {
    let (read_half, mut write_half) = stream.into_split();
    let mut incoming = BufReader::new(read_half).lines();
    loop {
        let Some(raw) = incoming.next_line().await? else {
            break;
        };
        let reply = match serde_json::from_str::<DaemonRequest>(&raw) {
            Err(e) => DaemonResponse::err(format!("invalid request: {e}")),
            Ok(request) => {
                tracing::debug!(?request, "mcp socket request");
                handle_request(request, &client).await
            }
        };
        tracing::debug!(response = ?reply, "mcp socket response");
        // One response per line: serialize, terminate, flush.
        let mut frame = serde_json::to_string(&reply)?;
        frame.push('\n');
        write_half.write_all(frame.as_bytes()).await?;
        write_half.flush().await?;
    }
    Ok(())
}
/// Routes a parsed [`DaemonRequest`] to its handler.
///
/// Every handler reports failures in-band as an error [`DaemonResponse`], so
/// dispatch itself cannot fail.
async fn handle_request(request: DaemonRequest, client: &Client) -> DaemonResponse {
    match request {
        // Read-only queries.
        DaemonRequest::ListRooms {} => list_rooms(client).await,
        DaemonRequest::ListRoomMembers { room_id } => list_room_members(client, &room_id).await,
        DaemonRequest::GetRoomHistory { room_id, limit } => {
            get_room_history(client, &room_id, limit).await
        }
        DaemonRequest::FetchEvent {
            room_id,
            event_id,
            context_before,
        } => fetch_event(client, &room_id, &event_id, context_before).await,
        // Requests that post events to a room.
        DaemonRequest::SendMessage { room_id, body } => {
            send_message(client, &room_id, &body).await
        }
        DaemonRequest::SendDm { user_id, body } => send_dm(client, &user_id, &body).await,
        DaemonRequest::SendReply {
            room_id,
            event_id,
            body,
        } => send_reply(client, &room_id, &event_id, &body).await,
        DaemonRequest::SendReaction {
            room_id,
            event_id,
            key,
        } => send_reaction(client, &room_id, &event_id, &key).await,
    }
}
/// Sends `body`, rendered as Markdown, to the room `room_id`.
///
/// Returns an error response if `room_id` does not parse, the client does not
/// know the room, or the send fails.
async fn send_message(client: &Client, room_id: &str, body: &str) -> DaemonResponse {
    let target: OwnedRoomId = match room_id.parse() {
        Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
        Ok(parsed) => parsed,
    };
    let room = match client.get_room(&target) {
        Some(found) => found,
        None => return DaemonResponse::err(format!("room {target} not found")),
    };
    let outcome = room.send(RoomMessageEventContent::text_markdown(body)).await;
    if let Err(e) = outcome {
        return DaemonResponse::err(format!("send failed: {e}"));
    }
    tracing::info!(room = %target, "mcp: sent message");
    DaemonResponse::ok(format!("sent to {target}"))
}
/// Sends `body`, rendered as Markdown, as a direct message to `user_id`.
///
/// The DM room is obtained via `handlers::find_or_create_dm` (presumably
/// reusing an existing DM or creating one — see that helper). Returns an error
/// response if `user_id` does not parse, the DM room can't be obtained, or the
/// send fails.
async fn send_dm(client: &Client, user_id: &str, body: &str) -> DaemonResponse {
    let recipient: OwnedUserId = match user_id.parse() {
        Err(e) => return DaemonResponse::err(format!("invalid user_id: {e}")),
        Ok(parsed) => parsed,
    };
    let dm_room = match handlers::find_or_create_dm(client, &recipient).await {
        Err(e) => return DaemonResponse::err(format!("failed to get/create DM: {e}")),
        Ok(found) => found,
    };
    let outcome = dm_room.send(RoomMessageEventContent::text_markdown(body)).await;
    if let Err(e) = outcome {
        return DaemonResponse::err(format!("send DM failed: {e}"));
    }
    tracing::info!(user = %recipient, "mcp: sent DM");
    DaemonResponse::ok(format!("DM sent to {recipient}"))
}
async fn send_reaction(
client: &Client,
room_id: &str,
event_id: &str,
key: &str,
) -> DaemonResponse {
let rid = match room_id.parse::<OwnedRoomId>() {
Ok(r) => r,
Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
};
let Some(room) = client.get_room(&rid) else {
return DaemonResponse::err(format!("room {rid} not found"));
};
let own_user = match client.user_id() {
Some(u) => u.to_owned(),
None => return DaemonResponse::err("not logged in".to_owned()),
};
// Load timeline to resolve possibly-shortened event id
let tl = match timeline::load_timeline(&room, 50, &own_user).await {
Ok(t) => t,
Err(e) => return DaemonResponse::err(format!("failed to load timeline: {e}")),
};
let Some(full_eid) = timeline::resolve_event_id(&tl, event_id) else {
return DaemonResponse::err(format!("event {event_id} not found in timeline"));
};
let content = ReactionEventContent::new(Annotation::new(full_eid.clone(), key.to_owned()));
match room.send(content).await {
Ok(_) => {
tracing::info!(target = %full_eid, %key, "mcp: sent reaction");
DaemonResponse::ok(format!("reacted {key} to {full_eid}"))
}
Err(e) => DaemonResponse::err(format!("send reaction failed: {e}")),
}
}
async fn list_rooms(client: &Client) -> DaemonResponse {
let mut rooms = Vec::new();
for room in client.joined_rooms() {
let name = room
.display_name()
.await
.map_or_else(|_| room.room_id().to_string(), |n| n.to_string());
rooms.push(RoomInfo {
room_id: room.room_id().as_str().to_owned(),
name,
});
}
DaemonResponse::ok(rooms)
}
/// Lists the joined members of `room_id` as [`MemberInfo`] entries.
///
/// A member without a display name gets an empty string. Returns an error
/// response if the room id is invalid, the room is unknown, or the member
/// list can't be loaded.
async fn list_room_members(client: &Client, room_id: &str) -> DaemonResponse {
    let target: OwnedRoomId = match room_id.parse() {
        Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
        Ok(parsed) => parsed,
    };
    let room = match client.get_room(&target) {
        Some(found) => found,
        None => return DaemonResponse::err(format!("room {target} not found")),
    };
    let joined = match room.members(matrix_sdk::RoomMemberships::JOIN).await {
        Err(e) => return DaemonResponse::err(format!("failed to list members: {e}")),
        Ok(list) => list,
    };
    let mut infos: Vec<MemberInfo> = Vec::with_capacity(joined.len());
    for member in &joined {
        infos.push(MemberInfo {
            user_id: member.user_id().as_str().to_owned(),
            display_name: member.display_name().map(str::to_owned).unwrap_or_default(),
        });
    }
    DaemonResponse::ok(infos)
}
async fn send_reply(client: &Client, room_id: &str, event_id: &str, body: &str) -> DaemonResponse {
let rid = match room_id.parse::<OwnedRoomId>() {
Ok(r) => r,
Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
};
let Some(room) = client.get_room(&rid) else {
return DaemonResponse::err(format!("room {rid} not found"));
};
let own_user = match client.user_id() {
Some(u) => u.to_owned(),
None => return DaemonResponse::err("not logged in".to_owned()),
};
// Resolve possibly-shortened event id against recent timeline
let tl = match timeline::load_timeline(&room, 50, &own_user).await {
Ok(t) => t,
Err(e) => return DaemonResponse::err(format!("failed to load timeline: {e}")),
};
let Some(full_eid) = timeline::resolve_event_id(&tl, event_id) else {
return DaemonResponse::err(format!("event {event_id} not found in timeline"));
};
let content = RoomMessageEventContent::text_markdown(body).into();
let reply = Reply {
event_id: full_eid.clone(),
enforce_thread: EnforceThread::MaybeThreaded,
};
let reply_content = match room.make_reply_event(content, reply).await {
Ok(c) => c,
Err(e) => return DaemonResponse::err(format!("make_reply_event failed: {e}")),
};
match room.send(reply_content).await {
Ok(_) => {
tracing::info!(target = %full_eid, "mcp: sent reply");
DaemonResponse::ok(format!("replied to {full_eid}"))
}
Err(e) => DaemonResponse::err(format!("send reply failed: {e}")),
}
}
/// Returns up to `limit` recent timeline items of `room_id` as [`WireEvent`]s,
/// annotated with read markers.
///
/// `limit` defaults to 20 and is capped at 100. If the locally loaded
/// timeline is shorter than `limit`, the room's event cache is paginated
/// backwards (at most 5 rounds) and the timeline is reloaded after each
/// round. A pagination error stops backfilling and returns what is
/// available; a timeline load error produces an error response.
async fn get_room_history(
    client: &Client,
    room_id: &str,
    limit: Option<usize>,
) -> DaemonResponse {
    let rid = match room_id.parse::<OwnedRoomId>() {
        Ok(r) => r,
        Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
    };
    let Some(room) = client.get_room(&rid) else {
        return DaemonResponse::err(format!("room {rid} not found"));
    };
    let own_user = match client.user_id() {
        Some(u) => u.to_owned(),
        None => return DaemonResponse::err("not logged in".to_owned()),
    };
    let limit = limit.unwrap_or(20).min(100);
    // Bind the second element (rather than discarding it with `_`) so it
    // stays alive for as long as `cache` is used below.
    let Ok((cache, _cache_handles)) = room.event_cache().await else {
        return DaemonResponse::err("event cache not available".to_owned());
    };
    let mut tl = match timeline::load_timeline(&room, limit, &own_user).await {
        Ok(t) => t,
        Err(e) => return DaemonResponse::err(format!("failed to load timeline: {e}")),
    };
    // Backfill via /messages while the cache is short, at most 5 rounds.
    for _ in 0..5 {
        if tl.len() >= limit {
            break;
        }
        // `limit` is capped at 100 above, so this cast cannot truncate.
        let batch = (limit - tl.len()) as u16;
        let reached_start = match cache.pagination().run_backwards_once(batch).await {
            Ok(outcome) => outcome.reached_start,
            Err(e) => {
                tracing::warn!("backfill failed: {e}");
                break;
            }
        };
        // Reload *before* acting on `reached_start`: the page that hit the
        // start of the room may still have delivered events.
        tl = match timeline::load_timeline(&room, limit, &own_user).await {
            Ok(t) => t,
            Err(e) => return DaemonResponse::err(format!("reload after backfill failed: {e}")),
        };
        if reached_start {
            break;
        }
    }
    let read_markers = timeline::compute_read_markers(&room, &tl, &own_user).await;
    let items: Vec<WireEvent> = tl
        .iter()
        .map(|i| wire_event_from(i, &read_markers))
        .collect();
    DaemonResponse::ok(items)
}
/// Fetch a specific event by ID via the homeserver `/context` endpoint.
///
/// `event_id` may be a possibly-shortened id (resolved against the 50 most
/// recent timeline items) or a full event id. Returns the rendered target
/// event plus up to `context_before` (default 0, capped at 50) rendered
/// events preceding it, in chronological order. Only text messages and
/// reactions are rendered; other events are omitted from the output.
///
/// When more history exists beyond that window, `earlier_handle` carries the
/// full event id of the next-older event, so the shard can page further
/// backward by calling fetch_event again with that id. The handle event is
/// not included in `context_before`, because it becomes the target of the
/// next call.
async fn fetch_event(
    client: &Client,
    room_id: &str,
    event_id: &str,
    context_before: Option<u32>,
) -> DaemonResponse {
    use matrix_sdk::ruma::events::AnySyncTimelineEvent;
    use matrix_sdk::ruma::events::room::message::MessageType;
    let rid = match room_id.parse::<OwnedRoomId>() {
        Ok(r) => r,
        Err(e) => return DaemonResponse::err(format!("invalid room_id: {e}")),
    };
    let Some(room) = client.get_room(&rid) else {
        return DaemonResponse::err(format!("room {rid} not found"));
    };
    let own_user = match client.user_id() {
        Some(u) => u.to_owned(),
        None => return DaemonResponse::err("not logged in".to_owned()),
    };
    // Resolve possibly-shortened event id against recent timeline first;
    // fall back to parsing as a full id.
    let resolve_tl = timeline::load_timeline(&room, 50, &own_user)
        .await
        .unwrap_or_default();
    let full_eid = match timeline::resolve_event_id(&resolve_tl, event_id) {
        Some(eid) => eid,
        None => match event_id.parse::<matrix_sdk::ruma::OwnedEventId>() {
            Ok(eid) => eid,
            Err(e) => {
                return DaemonResponse::err(format!("invalid or unknown event_id: {e}"));
            }
        },
    };
    let context_before = context_before.unwrap_or(0).min(50);
    let wanted = context_before as usize;
    // Per the spec, `/context`'s `limit` covers events_before *and*
    // events_after combined, and servers split it between the two (Synapse,
    // for example, gives only half to events_before). Ask for twice what we
    // need so the "before" half can hold `wanted` events plus one extra for
    // the paging handle.
    let request_size = (context_before + 1) * 2;
    let response = match room
        .event_with_context(
            &full_eid,
            false,
            matrix_sdk::ruma::UInt::from(request_size),
            None,
        )
        .await
    {
        Ok(r) => r,
        Err(e) => return DaemonResponse::err(format!("event_with_context failed: {e}")),
    };
    let render = |raw: &matrix_sdk::deserialized_responses::TimelineEvent| -> Option<WireEvent> {
        let deserialized = raw.raw().deserialize().ok()?;
        let AnySyncTimelineEvent::MessageLike(msg) = deserialized else {
            return None;
        };
        match msg {
            matrix_sdk::ruma::events::AnySyncMessageLikeEvent::RoomMessage(
                matrix_sdk::ruma::events::SyncMessageLikeEvent::Original(orig),
            ) => {
                let MessageType::Text(text) = &orig.content.msgtype else {
                    return None;
                };
                let ms: u64 = orig.origin_server_ts.0.into();
                let ts = (ms / 1000) as i64;
                let in_reply_to = match &orig.content.relates_to {
                    Some(matrix_sdk::ruma::events::room::message::Relation::Reply { in_reply_to }) => {
                        Some(in_reply_to.event_id.as_str().to_owned())
                    }
                    _ => None,
                };
                Some(WireEvent::Message {
                    event_id: orig.event_id.as_str().to_owned(),
                    event_id_short: short_eid(orig.event_id.as_str()),
                    sender: orig.sender.as_str().to_owned(),
                    is_self: orig.sender == own_user,
                    ts,
                    ts_human: format!("{} UTC", crate::timeline::format_ts(ts)),
                    body: text.body.clone(),
                    in_reply_to,
                    read_by: Vec::new(),
                    edit_history: Vec::new(),
                })
            }
            matrix_sdk::ruma::events::AnySyncMessageLikeEvent::Reaction(
                matrix_sdk::ruma::events::SyncMessageLikeEvent::Original(orig),
            ) => {
                let ms: u64 = orig.origin_server_ts.0.into();
                let ts = (ms / 1000) as i64;
                Some(WireEvent::Reaction {
                    sender: orig.sender.as_str().to_owned(),
                    is_self: orig.sender == own_user,
                    ts,
                    ts_human: format!("{} UTC", crate::timeline::format_ts(ts)),
                    target_event_id: orig.content.relates_to.event_id.as_str().to_owned(),
                    target_event_id_short: short_eid(orig.content.relates_to.event_id.as_str()),
                    key: orig.content.relates_to.key.clone(),
                })
            }
            _ => None,
        }
    };
    // events_before is newest-first per the /context spec. Keep only the
    // newest `wanted + 1` raw events: `wanted` for context, one as the handle.
    let before_raw = &response.events_before;
    let window = &before_raw[..before_raw.len().min(wanted + 1)];
    // Derive the handle from the raw oldest event's own id. The rendered form
    // is unsuitable: a reaction only exposes its *target's* id, and
    // unrenderable events (state, non-text) would otherwise end paging early.
    let earlier_handle = if wanted > 0 && window.len() > wanted {
        window
            .last()
            .and_then(|ev| ev.raw().get_field::<String>("event_id").ok().flatten())
    } else {
        None
    };
    let mut context_events: Vec<WireEvent> =
        window.iter().take(wanted).filter_map(render).collect();
    // Newest-first -> chronological.
    context_events.reverse();
    let target = response.event.as_ref().and_then(render);
    DaemonResponse::ok(FetchEventResult {
        event: target,
        context_before: context_events,
        earlier_handle,
    })
}