diff --git a/src/main.rs b/src/main.rs index decac78..51d8753 100644 --- a/src/main.rs +++ b/src/main.rs @@ -550,7 +550,21 @@ async fn process_loop(state: Arc>, client: Client) { match invoke_claude(&room_id, &room_name, &chat_msgs, seen_idx, &model).await { Ok(Some(response)) => { - if let Some(target_room) = client.get_room(&response.room) { + let target_room = match &response.target { + ResponseTarget::Room(rid) => client.get_room(rid), + ResponseTarget::Dm(user) => match find_or_create_dm(&client, user).await { + Ok(r) => Some(r), + Err(e) => { + tracing::error!(user = %user, "failed to get/create DM: {e}"); + None + } + }, + }; + let target_label = match &response.target { + ResponseTarget::Room(rid) => rid.to_string(), + ResponseTarget::Dm(user) => format!("dm:{user}"), + }; + if let Some(target_room) = target_room { let content = RoomMessageEventContent::text_plain(&response.body); match target_room.send(content).await { Ok(_) => { @@ -560,7 +574,7 @@ async fn process_loop(state: Arc>, client: Client) { state.last_shown.insert(room_id.clone(), eid); } tracing::info!( - room = %response.room, + target = %target_label, "sent response ({} budget remaining)", state.rate_budget ); @@ -570,7 +584,7 @@ async fn process_loop(state: Arc>, client: Client) { Err(e) => tracing::error!("failed to send: {e}"), } } else { - tracing::warn!(room = %response.room, "target room not found"); + tracing::warn!(target = %target_label, "target not available"); } } Ok(None) => { @@ -590,6 +604,22 @@ async fn process_loop(state: Arc>, client: Client) { } } +/// Find an existing DM room with the given user, or create one. +async fn find_or_create_dm(client: &Client, user_id: &UserId) -> anyhow::Result { + for room in client.joined_rooms() { + if room.is_direct().await.unwrap_or(false) + && room + .direct_targets() + .iter() + .any(|t| t.as_str() == user_id.as_str()) + { + return Ok(room); + } + } + tracing::info!(user = %user_id, "creating new DM room"); + Ok(client.create_dm(user_id).await?) +} + async fn send_read_receipt(room: &Room, event_id: Option) { let Some(eid) = event_id else { return; @@ -685,8 +715,13 @@ async fn fetch_message( )) } +enum ResponseTarget { + Room(OwnedRoomId), + Dm(OwnedUserId), +} + struct ClaudeResponse { - room: OwnedRoomId, + target: ResponseTarget, body: String, } @@ -804,19 +839,31 @@ fn parse_response(raw: &str, default_room: &OwnedRoomId) -> Option().ok()); + + let target = if let Some(user) = dm { + ResponseTarget::Dm(user) + } else { + let room = frontmatter + .lines() + .find(|l| l.starts_with("room:")) + .and_then(|l| l.strip_prefix("room:")) + .and_then(|r| r.trim().parse().ok()) + .unwrap_or_else(|| default_room.clone()); + ResponseTarget::Room(room) + }; if body.is_empty() { return None; } return Some(ClaudeResponse { - room, + target, body: body.to_owned(), }); } @@ -827,7 +874,7 @@ fn parse_response(raw: &str, default_room: &OwnedRoomId) -> Option assert_eq!(r.as_str(), expected), + ResponseTarget::Dm(_) => panic!("expected room target, got dm"), + } + } + #[test] fn parse_frontmatter_response() { let raw = "---\nroom: !other:server\n---\nhello world"; let resp = parse_response(raw, &test_room()).unwrap(); - assert_eq!(resp.room.as_str(), "!other:server"); + assert_room(&resp, "!other:server"); assert_eq!(resp.body, "hello world"); } @@ -858,7 +912,7 @@ mod tests { fn parse_plain_response() { let raw = "just a message"; let resp = parse_response(raw, &test_room()).unwrap(); - assert_eq!(resp.room, test_room()); + assert_room(&resp, "!test:example.com"); assert_eq!(resp.body, "just a message"); } @@ -872,6 +926,27 @@ mod tests { fn parse_default_room() { let raw = "---\n---\nhello"; let resp = parse_response(raw, &test_room()).unwrap(); - assert_eq!(resp.room, test_room()); + assert_room(&resp, "!test:example.com"); + } + + #[test] + fn parse_dm_response() { + let raw = "---\ndm: @alice:example.com\n---\nhi alice"; + let resp = parse_response(raw, &test_room()).unwrap(); + match &resp.target { + ResponseTarget::Dm(u) => assert_eq!(u.as_str(), "@alice:example.com"), + ResponseTarget::Room(_) => panic!("expected dm target"), + } + assert_eq!(resp.body, "hi alice"); + } + + #[test] + fn parse_dm_takes_precedence_over_room() { + let raw = "---\nroom: !other:server\ndm: @bob:example.com\n---\nhello"; + let resp = parse_response(raw, &test_room()).unwrap(); + match &resp.target { + ResponseTarget::Dm(u) => assert_eq!(u.as_str(), "@bob:example.com"), + ResponseTarget::Room(_) => panic!("expected dm target"), + } } }