configurable max_history; auto-pull replied-to messages into prompt context
This commit is contained in:
parent
31ad42f637
commit
26d0e07199
1 changed files with 87 additions and 11 deletions
98
src/main.rs
98
src/main.rs
|
|
@ -35,9 +35,11 @@ struct Config {
|
||||||
password: String,
|
password: String,
|
||||||
rate_limit_per_min: Option<u32>,
|
rate_limit_per_min: Option<u32>,
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
max_history: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
|
const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
|
||||||
|
const DEFAULT_MAX_HISTORY: usize = 20;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct PersistedSession {
|
struct PersistedSession {
|
||||||
|
|
@ -67,9 +69,9 @@ struct DaemonState {
|
||||||
rate_limit_per_min: u32,
|
rate_limit_per_min: u32,
|
||||||
last_rate_reset: std::time::Instant,
|
last_rate_reset: std::time::Instant,
|
||||||
model: String,
|
model: String,
|
||||||
|
max_history: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
const MAX_HISTORY: usize = 20;
|
|
||||||
const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1;
|
const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|
@ -97,6 +99,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
.model
|
.model
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| DEFAULT_MODEL.to_owned());
|
.unwrap_or_else(|| DEFAULT_MODEL.to_owned());
|
||||||
|
let max_history = config.max_history.unwrap_or(DEFAULT_MAX_HISTORY);
|
||||||
|
|
||||||
let (client, sync_token) = if session_file.exists() {
|
let (client, sync_token) = if session_file.exists() {
|
||||||
restore_session(&session_file).await?
|
restore_session(&session_file).await?
|
||||||
|
|
@ -126,6 +129,7 @@ async fn main() -> anyhow::Result<()> {
|
||||||
rate_limit_per_min,
|
rate_limit_per_min,
|
||||||
last_rate_reset: std::time::Instant::now(),
|
last_rate_reset: std::time::Instant::now(),
|
||||||
model,
|
model,
|
||||||
|
max_history,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let processor_state = state.clone();
|
let processor_state = state.clone();
|
||||||
|
|
@ -480,8 +484,17 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client) {
|
||||||
.await
|
.await
|
||||||
.map_or_else(|_| room_id.to_string(), |n| n.to_string());
|
.map_or_else(|_| room_id.to_string(), |n| n.to_string());
|
||||||
|
|
||||||
|
let (own_user, model, max_history) = {
|
||||||
|
let state = state.lock().await;
|
||||||
|
(
|
||||||
|
state.own_user_id.clone(),
|
||||||
|
state.model.clone(),
|
||||||
|
state.max_history,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
// Load recent history from matrix-sdk's persistent event cache
|
// Load recent history from matrix-sdk's persistent event cache
|
||||||
let history = match load_recent_messages(&room, MAX_HISTORY).await {
|
let mut history = match load_recent_messages(&room, max_history).await {
|
||||||
Ok(h) => h,
|
Ok(h) => h,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(room = %room_id, "failed to load history: {e}");
|
tracing::error!(room = %room_id, "failed to load history: {e}");
|
||||||
|
|
@ -489,14 +502,37 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let (own_user, model) = {
|
// For any new messages that reply to events outside the window, fetch
|
||||||
let state = state.lock().await;
|
// the replied-to event from cache and prepend it as extra context.
|
||||||
(state.own_user_id.clone(), state.model.clone())
|
let seen_idx_initial = prev_last_shown
|
||||||
};
|
.as_ref()
|
||||||
|
.and_then(|id| history.iter().position(|(eid, _, _, _, _)| eid == id))
|
||||||
|
.map_or(0, |pos| pos + 1);
|
||||||
|
let in_window: std::collections::HashSet<OwnedEventId> = history
|
||||||
|
.iter()
|
||||||
|
.map(|(eid, _, _, _, _)| eid.clone())
|
||||||
|
.collect();
|
||||||
|
let mut reply_targets: Vec<OwnedEventId> = Vec::new();
|
||||||
|
for (_, _, _, _, in_reply_to) in history.iter().skip(seen_idx_initial) {
|
||||||
|
if let Some(target) = in_reply_to {
|
||||||
|
if !in_window.contains(target) && !reply_targets.contains(target) {
|
||||||
|
reply_targets.push(target.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !reply_targets.is_empty() {
|
||||||
|
if let Ok((cache, _h)) = room.event_cache().await {
|
||||||
|
for target in &reply_targets {
|
||||||
|
if let Some(found) = fetch_message(&cache, target).await {
|
||||||
|
history.insert(0, found);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let chat_msgs: Vec<ChatMessage> = history
|
let chat_msgs: Vec<ChatMessage> = history
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, sender, body, ts)| ChatMessage {
|
.map(|(_, sender, body, ts, _)| ChatMessage {
|
||||||
sender: sender.clone(),
|
sender: sender.clone(),
|
||||||
body: body.clone(),
|
body: body.clone(),
|
||||||
is_self: sender == &own_user,
|
is_self: sender == &own_user,
|
||||||
|
|
@ -507,10 +543,10 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client) {
|
||||||
// Determine seen split: everything before (and including) prev_last_shown is "seen"
|
// Determine seen split: everything before (and including) prev_last_shown is "seen"
|
||||||
let seen_idx = prev_last_shown
|
let seen_idx = prev_last_shown
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|id| history.iter().position(|(eid, _, _, _)| eid == id))
|
.and_then(|id| history.iter().position(|(eid, _, _, _, _)| eid == id))
|
||||||
.map_or(0, |pos| pos + 1);
|
.map_or(0, |pos| pos + 1);
|
||||||
|
|
||||||
let new_last_event_id = history.last().map(|(eid, _, _, _)| eid.clone());
|
let new_last_event_id = history.last().map(|(eid, _, _, _, _)| eid.clone());
|
||||||
|
|
||||||
match invoke_claude(&room_id, &room_name, &chat_msgs, seen_idx, &model).await {
|
match invoke_claude(&room_id, &room_name, &chat_msgs, seen_idx, &model).await {
|
||||||
Ok(Some(response)) => {
|
Ok(Some(response)) => {
|
||||||
|
|
@ -568,16 +604,17 @@ async fn send_read_receipt(room: &Room, event_id: Option<OwnedEventId>) {
|
||||||
|
|
||||||
/// Load the last N text messages from the room's persistent event cache.
|
/// Load the last N text messages from the room's persistent event cache.
|
||||||
/// Returns oldest-first list of (event_id, sender, body, ts_secs).
|
/// Returns oldest-first list of (event_id, sender, body, ts_secs).
|
||||||
|
/// Returns oldest-first list of (event_id, sender, body, ts_secs, in_reply_to).
|
||||||
async fn load_recent_messages(
|
async fn load_recent_messages(
|
||||||
room: &Room,
|
room: &Room,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> anyhow::Result<Vec<(OwnedEventId, OwnedUserId, String, i64)>> {
|
) -> anyhow::Result<Vec<(OwnedEventId, OwnedUserId, String, i64, Option<OwnedEventId>)>> {
|
||||||
use matrix_sdk::ruma::events::AnySyncTimelineEvent;
|
use matrix_sdk::ruma::events::AnySyncTimelineEvent;
|
||||||
|
|
||||||
let (cache, _handles) = room.event_cache().await?;
|
let (cache, _handles) = room.event_cache().await?;
|
||||||
let events = cache.events().await;
|
let events = cache.events().await;
|
||||||
|
|
||||||
let mut out: Vec<(OwnedEventId, OwnedUserId, String, i64)> = Vec::new();
|
let mut out: Vec<(OwnedEventId, OwnedUserId, String, i64, Option<OwnedEventId>)> = Vec::new();
|
||||||
for ev in events.iter().rev() {
|
for ev in events.iter().rev() {
|
||||||
if out.len() >= limit {
|
if out.len() >= limit {
|
||||||
break;
|
break;
|
||||||
|
|
@ -594,11 +631,18 @@ async fn load_recent_messages(
|
||||||
if let MessageType::Text(text) = &orig.content.msgtype {
|
if let MessageType::Text(text) = &orig.content.msgtype {
|
||||||
let ts_ms: u64 = orig.origin_server_ts.0.into();
|
let ts_ms: u64 = orig.origin_server_ts.0.into();
|
||||||
let ts_secs: i64 = i64::try_from(ts_ms).unwrap_or(0) / 1000;
|
let ts_secs: i64 = i64::try_from(ts_ms).unwrap_or(0) / 1000;
|
||||||
|
let in_reply_to = match &orig.content.relates_to {
|
||||||
|
Some(matrix_sdk::ruma::events::room::message::Relation::Reply {
|
||||||
|
in_reply_to,
|
||||||
|
}) => Some(in_reply_to.event_id.clone()),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
out.push((
|
out.push((
|
||||||
orig.event_id.clone(),
|
orig.event_id.clone(),
|
||||||
orig.sender.clone(),
|
orig.sender.clone(),
|
||||||
text.body.clone(),
|
text.body.clone(),
|
||||||
ts_secs,
|
ts_secs,
|
||||||
|
in_reply_to,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -609,6 +653,38 @@ async fn load_recent_messages(
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetch a single text message by event_id from the room's event cache.
|
||||||
|
async fn fetch_message(
|
||||||
|
cache: &matrix_sdk::event_cache::RoomEventCache,
|
||||||
|
event_id: &matrix_sdk::ruma::EventId,
|
||||||
|
) -> Option<(OwnedEventId, OwnedUserId, String, i64, Option<OwnedEventId>)> {
|
||||||
|
use matrix_sdk::ruma::events::AnySyncTimelineEvent;
|
||||||
|
|
||||||
|
let ev = cache.find_event(event_id).await?;
|
||||||
|
let deserialized = ev.raw().deserialize().ok()?;
|
||||||
|
let AnySyncTimelineEvent::MessageLike(msg) = deserialized else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
let matrix_sdk::ruma::events::AnySyncMessageLikeEvent::RoomMessage(
|
||||||
|
matrix_sdk::ruma::events::SyncMessageLikeEvent::Original(orig),
|
||||||
|
) = msg
|
||||||
|
else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
let MessageType::Text(text) = &orig.content.msgtype else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
let ts_ms: u64 = orig.origin_server_ts.0.into();
|
||||||
|
let ts_secs: i64 = i64::try_from(ts_ms).unwrap_or(0) / 1000;
|
||||||
|
Some((
|
||||||
|
orig.event_id.clone(),
|
||||||
|
orig.sender.clone(),
|
||||||
|
text.body.clone(),
|
||||||
|
ts_secs,
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
struct ClaudeResponse {
|
struct ClaudeResponse {
|
||||||
room: OwnedRoomId,
|
room: OwnedRoomId,
|
||||||
body: String,
|
body: String,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue