hyperhive/hive-c0re/src/agent_server.rs

256 lines
9.5 KiB
Rust

//! Per-agent socket listener. Each socket file's existence on disk
//! authenticates the caller: connecting to `<.../agents/foo/mcp.sock>` means
//! you are `foo`.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Context, Result};
use hive_sh4re::{AgentRequest, AgentResponse, Message};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::task::JoinHandle;
use crate::coordinator::Coordinator;
/// A live per-agent socket listener, as returned by `start`.
pub struct AgentSocket {
    /// Filesystem path the Unix socket was bound at.
    pub path: PathBuf,
    /// Join handle of the accept-loop task spawned by `start`. Per-connection
    /// tasks are spawned separately and are not tracked by this handle.
    pub handle: JoinHandle<()>,
}
pub fn start(
agent: &str,
socket_path: &Path,
coord: Arc<Coordinator>,
) -> Result<AgentSocket> {
let agent = agent.to_owned();
if let Some(parent) = socket_path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("create agent socket dir {}", parent.display()))?;
}
if socket_path.exists() {
std::fs::remove_file(socket_path).context("remove stale agent socket")?;
}
let listener = UnixListener::bind(socket_path)
.with_context(|| format!("bind agent socket {}", socket_path.display()))?;
tracing::info!(%agent, socket = %socket_path.display(), "agent socket listening");
let path = socket_path.to_path_buf();
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, _)) => {
let agent = agent.clone();
let coord = coord.clone();
tokio::spawn(async move {
if let Err(e) = serve(stream, agent, coord).await {
tracing::warn!(error = ?e, "agent connection failed");
}
});
}
Err(e) => {
tracing::warn!(error = ?e, "agent listener accept failed; exiting");
return;
}
}
}
});
Ok(AgentSocket { path, handle })
}
/// Drive one client connection using newline-delimited JSON.
///
/// Each input line is one `AgentRequest`. For each line, exactly one
/// newline-terminated `AgentResponse` is written back.
///
/// Returns `Ok(())` when the peer closes its side (EOF). If a line fails to
/// parse, the client gets an `Err` response and the connection stays open.
/// I/O or serialization failures end the connection with an error.
async fn serve(stream: UnixStream, agent: String, coord: Arc<Coordinator>) -> Result<()> {
    let (read_half, mut write_half) = stream.into_split();
    let mut requests = BufReader::new(read_half);
    let mut buf = String::new();
    // Requests are handled strictly one at a time. A long-polling `Recv`
    // therefore holds up any later request sent on the same connection.
    while requests.read_line(&mut buf).await? != 0 {
        let reply = match serde_json::from_str::<AgentRequest>(buf.trim()) {
            Err(err) => AgentResponse::Err {
                message: format!("parse error: {err}"),
            },
            Ok(request) => dispatch(&request, &agent, &coord).await,
        };
        let mut frame = serde_json::to_string(&reply)?;
        frame.push('\n');
        write_half.write_all(frame.as_bytes()).await?;
        write_half.flush().await?;
        buf.clear();
    }
    Ok(())
}
/// Ceiling on a single `Recv` long-poll. Longer requested waits are clamped
/// to this value. 180s stays under typical TCP/proxy idle limits while still
/// letting an agent park its turn until a message arrives.
const RECV_LONG_POLL_MAX: std::time::Duration = std::time::Duration::from_secs(180);

/// Turn a `Recv` request's optional `wait_seconds` into the long-poll timeout.
///
/// * `None` or `Some(0)` gives a zero timeout. This is a "peek, don't wait"
///   check, so an agent can cheaply ask whether anything is pending without
///   blocking its turn.
/// * `Some(n)` gives `n` seconds, clamped to `RECV_LONG_POLL_MAX`.
fn recv_timeout(wait_seconds: Option<u64>) -> std::time::Duration {
    wait_seconds.map_or(std::time::Duration::ZERO, |secs| {
        RECV_LONG_POLL_MAX.min(std::time::Duration::from_secs(secs))
    })
}
async fn dispatch(req: &AgentRequest, agent: &str, coord: &Arc<Coordinator>) -> AgentResponse {
let broker = &coord.broker;
match req {
AgentRequest::Send { to, body } => {
// Handle broadcast sends (recipient = "*")
if to == "*" {
let errors = coord.broadcast_send(agent, body);
if errors.is_empty() {
AgentResponse::Ok
} else {
AgentResponse::Err {
message: format!("broadcast failed for agents: {}", errors.join(", ")),
}
}
} else {
// Normal unicast send
match broker.send(&Message {
from: agent.to_owned(),
to: to.clone(),
body: body.clone(),
}) {
Ok(()) => AgentResponse::Ok,
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
}
}
}
AgentRequest::Recv { wait_seconds } => match broker
.recv_blocking(agent, recv_timeout(*wait_seconds))
.await
{
Ok(Some(msg)) => AgentResponse::Message {
from: msg.from,
body: msg.body,
},
Ok(None) => AgentResponse::Empty,
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
},
AgentRequest::Status => match broker.count_pending(agent) {
Ok(unread) => AgentResponse::Status { unread },
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
},
AgentRequest::OperatorMsg { body } => match broker.send(&Message {
from: hive_sh4re::OPERATOR_RECIPIENT.to_owned(),
to: agent.to_owned(),
body: body.clone(),
}) {
Ok(()) => AgentResponse::Ok,
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
},
AgentRequest::Wake { from, body } => match broker.send(&Message {
from: from.clone(),
to: agent.to_owned(),
body: body.clone(),
}) {
Ok(()) => AgentResponse::Ok,
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
},
AgentRequest::Recent { limit } => match broker.recent_for(agent, *limit) {
Ok(rows) => AgentResponse::Recent { rows },
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
},
AgentRequest::AskOperator {
question,
options,
multi,
ttl_seconds,
} => {
let deadline_at = ttl_seconds.and_then(|s| {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()
.and_then(|d| i64::try_from(d.as_secs()).ok())
.unwrap_or(0);
i64::try_from(s).ok().map(|s| now + s)
});
match coord
.questions
.submit(agent, question, options, *multi, deadline_at)
{
Ok(id) => {
tracing::info!(%id, %agent, ?deadline_at, "agent question queued");
if let Some(ttl) = *ttl_seconds {
crate::manager_server::spawn_question_watchdog(coord, id, ttl);
}
AgentResponse::QuestionQueued { id }
}
Err(e) => AgentResponse::Err {
message: format!("{e:#}"),
},
}
}
AgentRequest::Remind {
message,
timing,
file_path,
} => {
use hive_sh4re::ReminderTiming;
// Calculate the due_at timestamp, propagating errors instead of silently
// defaulting to epoch 1970 on overflow/conversion failure.
let due_at_result: Result<i64> = match timing {
ReminderTiming::InSeconds { seconds } => {
let now = std::time::SystemTime::now();
let future = match now.checked_add(std::time::Duration::from_secs(*seconds)) {
Some(t) => t,
None => return AgentResponse::Err {
message: format!("InSeconds overflow: {seconds}s exceeds system time range"),
},
};
let duration = match future.duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d,
Err(e) => return AgentResponse::Err {
message: format!("system time before UNIX_EPOCH: {e}"),
},
};
match i64::try_from(duration.as_secs()) {
Ok(ts) => Ok(ts),
Err(e) => return AgentResponse::Err {
message: format!("unix timestamp exceeds i64 range: {e}"),
},
}
}
ReminderTiming::At { unix_timestamp } => Ok(*unix_timestamp),
};
match due_at_result {
Ok(due_at) => {
match broker.store_reminder(agent, message, file_path.as_deref(), due_at) {
Ok(id) => {
tracing::info!(%id, %agent, %due_at, "reminder scheduled");
AgentResponse::Ok
}
Err(e) => AgentResponse::Err {
message: format!("failed to store reminder: {e:#}"),
},
}
}
Err(e) => AgentResponse::Err {
message: format!("invalid reminder timing: {e:#}"),
},
}
}
}
}