256 lines
9.5 KiB
Rust
256 lines
9.5 KiB
Rust
//! Per-agent socket listener. Each socket file's existence on disk
//! authenticates the caller: connecting to `<.../agents/foo/mcp.sock>` means
//! you are `foo`.
|
|
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, Result};
|
|
use hive_sh4re::{AgentRequest, AgentResponse, Message};
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{UnixListener, UnixStream};
|
|
use tokio::task::JoinHandle;
|
|
|
|
use crate::coordinator::Coordinator;
|
|
|
|
pub struct AgentSocket {
|
|
pub path: PathBuf,
|
|
pub handle: JoinHandle<()>,
|
|
}
|
|
|
|
pub fn start(
|
|
agent: &str,
|
|
socket_path: &Path,
|
|
coord: Arc<Coordinator>,
|
|
) -> Result<AgentSocket> {
|
|
let agent = agent.to_owned();
|
|
if let Some(parent) = socket_path.parent() {
|
|
std::fs::create_dir_all(parent)
|
|
.with_context(|| format!("create agent socket dir {}", parent.display()))?;
|
|
}
|
|
if socket_path.exists() {
|
|
std::fs::remove_file(socket_path).context("remove stale agent socket")?;
|
|
}
|
|
let listener = UnixListener::bind(socket_path)
|
|
.with_context(|| format!("bind agent socket {}", socket_path.display()))?;
|
|
tracing::info!(%agent, socket = %socket_path.display(), "agent socket listening");
|
|
|
|
let path = socket_path.to_path_buf();
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok((stream, _)) => {
|
|
let agent = agent.clone();
|
|
let coord = coord.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = serve(stream, agent, coord).await {
|
|
tracing::warn!(error = ?e, "agent connection failed");
|
|
}
|
|
});
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = ?e, "agent listener accept failed; exiting");
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
Ok(AgentSocket { path, handle })
|
|
}
|
|
|
|
async fn serve(stream: UnixStream, agent: String, coord: Arc<Coordinator>) -> Result<()> {
|
|
let (read, mut write) = stream.into_split();
|
|
let mut reader = BufReader::new(read);
|
|
let mut line = String::new();
|
|
loop {
|
|
line.clear();
|
|
let n = reader.read_line(&mut line).await?;
|
|
if n == 0 {
|
|
return Ok(());
|
|
}
|
|
let resp = match serde_json::from_str::<AgentRequest>(line.trim()) {
|
|
Ok(req) => dispatch(&req, &agent, &coord).await,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("parse error: {e}"),
|
|
},
|
|
};
|
|
let mut payload = serde_json::to_string(&resp)?;
|
|
payload.push('\n');
|
|
write.write_all(payload.as_bytes()).await?;
|
|
write.flush().await?;
|
|
}
|
|
}
|
|
|
|
/// Upper bound on the long-poll window a caller may request for `recv`.
///
/// Larger requests are clamped to this value. The 180s figure was chosen to
/// stay below typical idle-connection limits while still letting an agent
/// park its turn until a message arrives.
const RECV_LONG_POLL_MAX: std::time::Duration = std::time::Duration::from_secs(180);

/// Translates the optional `wait_seconds` of a `recv` request into a timeout.
///
/// * `None` yields a zero duration, i.e. "peek, don't wait"; `Some(0)` yields
///   the same zero duration.
/// * `Some(n)` yields `n` seconds, capped at [`RECV_LONG_POLL_MAX`].
fn recv_timeout(wait_seconds: Option<u64>) -> std::time::Duration {
    use std::time::Duration;

    wait_seconds.map_or(Duration::ZERO, |secs| {
        Duration::from_secs(secs).min(RECV_LONG_POLL_MAX)
    })
}
|
|
|
|
async fn dispatch(req: &AgentRequest, agent: &str, coord: &Arc<Coordinator>) -> AgentResponse {
|
|
let broker = &coord.broker;
|
|
match req {
|
|
AgentRequest::Send { to, body } => {
|
|
// Handle broadcast sends (recipient = "*")
|
|
if to == "*" {
|
|
let errors = coord.broadcast_send(agent, body);
|
|
if errors.is_empty() {
|
|
AgentResponse::Ok
|
|
} else {
|
|
AgentResponse::Err {
|
|
message: format!("broadcast failed for agents: {}", errors.join(", ")),
|
|
}
|
|
}
|
|
} else {
|
|
// Normal unicast send
|
|
match broker.send(&Message {
|
|
from: agent.to_owned(),
|
|
to: to.clone(),
|
|
body: body.clone(),
|
|
}) {
|
|
Ok(()) => AgentResponse::Ok,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
AgentRequest::Recv { wait_seconds } => match broker
|
|
.recv_blocking(agent, recv_timeout(*wait_seconds))
|
|
.await
|
|
{
|
|
Ok(Some(msg)) => AgentResponse::Message {
|
|
from: msg.from,
|
|
body: msg.body,
|
|
},
|
|
Ok(None) => AgentResponse::Empty,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::Status => match broker.count_pending(agent) {
|
|
Ok(unread) => AgentResponse::Status { unread },
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::OperatorMsg { body } => match broker.send(&Message {
|
|
from: hive_sh4re::OPERATOR_RECIPIENT.to_owned(),
|
|
to: agent.to_owned(),
|
|
body: body.clone(),
|
|
}) {
|
|
Ok(()) => AgentResponse::Ok,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::Wake { from, body } => match broker.send(&Message {
|
|
from: from.clone(),
|
|
to: agent.to_owned(),
|
|
body: body.clone(),
|
|
}) {
|
|
Ok(()) => AgentResponse::Ok,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::Recent { limit } => match broker.recent_for(agent, *limit) {
|
|
Ok(rows) => AgentResponse::Recent { rows },
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::AskOperator {
|
|
question,
|
|
options,
|
|
multi,
|
|
ttl_seconds,
|
|
} => {
|
|
let deadline_at = ttl_seconds.and_then(|s| {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.ok()
|
|
.and_then(|d| i64::try_from(d.as_secs()).ok())
|
|
.unwrap_or(0);
|
|
i64::try_from(s).ok().map(|s| now + s)
|
|
});
|
|
match coord
|
|
.questions
|
|
.submit(agent, question, options, *multi, deadline_at)
|
|
{
|
|
Ok(id) => {
|
|
tracing::info!(%id, %agent, ?deadline_at, "agent question queued");
|
|
if let Some(ttl) = *ttl_seconds {
|
|
crate::manager_server::spawn_question_watchdog(coord, id, ttl);
|
|
}
|
|
AgentResponse::QuestionQueued { id }
|
|
}
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
}
|
|
}
|
|
AgentRequest::Remind {
|
|
message,
|
|
timing,
|
|
file_path,
|
|
} => {
|
|
use hive_sh4re::ReminderTiming;
|
|
|
|
// Calculate the due_at timestamp, propagating errors instead of silently
|
|
// defaulting to epoch 1970 on overflow/conversion failure.
|
|
let due_at_result: Result<i64> = match timing {
|
|
ReminderTiming::InSeconds { seconds } => {
|
|
let now = std::time::SystemTime::now();
|
|
let future = match now.checked_add(std::time::Duration::from_secs(*seconds)) {
|
|
Some(t) => t,
|
|
None => return AgentResponse::Err {
|
|
message: format!("InSeconds overflow: {seconds}s exceeds system time range"),
|
|
},
|
|
};
|
|
let duration = match future.duration_since(std::time::UNIX_EPOCH) {
|
|
Ok(d) => d,
|
|
Err(e) => return AgentResponse::Err {
|
|
message: format!("system time before UNIX_EPOCH: {e}"),
|
|
},
|
|
};
|
|
match i64::try_from(duration.as_secs()) {
|
|
Ok(ts) => Ok(ts),
|
|
Err(e) => return AgentResponse::Err {
|
|
message: format!("unix timestamp exceeds i64 range: {e}"),
|
|
},
|
|
}
|
|
}
|
|
ReminderTiming::At { unix_timestamp } => Ok(*unix_timestamp),
|
|
};
|
|
|
|
match due_at_result {
|
|
Ok(due_at) => {
|
|
match broker.store_reminder(agent, message, file_path.as_deref(), due_at) {
|
|
Ok(id) => {
|
|
tracing::info!(%id, %agent, %due_at, "reminder scheduled");
|
|
AgentResponse::Ok
|
|
}
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("failed to store reminder: {e:#}"),
|
|
},
|
|
}
|
|
}
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("invalid reminder timing: {e:#}"),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|