111 lines
3.8 KiB
Rust
111 lines
3.8 KiB
Rust
//! Per-agent socket listener. Each agent gets its own socket file, and that
//! file is what authenticates the caller: connecting to
//! `<.../agents/foo/mcp.sock>` means you are `foo`. Requests never carry a
//! sender identity of their own.
|
|
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, Result};
|
|
use hive_sh4re::{AgentRequest, AgentResponse, Message};
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{UnixListener, UnixStream};
|
|
use tokio::task::JoinHandle;
|
|
|
|
use crate::broker::Broker;
|
|
|
|
pub struct AgentSocket {
|
|
pub path: PathBuf,
|
|
pub handle: JoinHandle<()>,
|
|
}
|
|
|
|
pub fn start(agent: &str, socket_path: &Path, broker: Arc<Broker>) -> Result<AgentSocket> {
|
|
let agent = agent.to_owned();
|
|
if let Some(parent) = socket_path.parent() {
|
|
std::fs::create_dir_all(parent)
|
|
.with_context(|| format!("create agent socket dir {}", parent.display()))?;
|
|
}
|
|
if socket_path.exists() {
|
|
std::fs::remove_file(socket_path).context("remove stale agent socket")?;
|
|
}
|
|
let listener = UnixListener::bind(socket_path)
|
|
.with_context(|| format!("bind agent socket {}", socket_path.display()))?;
|
|
tracing::info!(%agent, socket = %socket_path.display(), "agent socket listening");
|
|
|
|
let path = socket_path.to_path_buf();
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok((stream, _)) => {
|
|
let agent = agent.clone();
|
|
let broker = broker.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = serve(stream, agent, broker).await {
|
|
tracing::warn!(error = ?e, "agent connection failed");
|
|
}
|
|
});
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = ?e, "agent listener accept failed; exiting");
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
Ok(AgentSocket { path, handle })
|
|
}
|
|
|
|
async fn serve(stream: UnixStream, agent: String, broker: Arc<Broker>) -> Result<()> {
|
|
let (read, mut write) = stream.into_split();
|
|
let mut reader = BufReader::new(read);
|
|
let mut line = String::new();
|
|
loop {
|
|
line.clear();
|
|
let n = reader.read_line(&mut line).await?;
|
|
if n == 0 {
|
|
return Ok(());
|
|
}
|
|
let resp = match serde_json::from_str::<AgentRequest>(line.trim()) {
|
|
Ok(req) => dispatch(&req, &agent, &broker),
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("parse error: {e}"),
|
|
},
|
|
};
|
|
let mut payload = serde_json::to_string(&resp)?;
|
|
payload.push('\n');
|
|
write.write_all(payload.as_bytes()).await?;
|
|
write.flush().await?;
|
|
}
|
|
}
|
|
|
|
fn dispatch(req: &AgentRequest, agent: &str, broker: &Broker) -> AgentResponse {
|
|
match req {
|
|
AgentRequest::Send { to, body } => {
|
|
match broker.send(&Message {
|
|
from: agent.to_owned(),
|
|
to: to.clone(),
|
|
body: body.clone(),
|
|
}) {
|
|
Ok(()) => AgentResponse::Ok,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
}
|
|
}
|
|
AgentRequest::Recv => match broker.recv(agent) {
|
|
Ok(Some(msg)) => AgentResponse::Message {
|
|
from: msg.from,
|
|
body: msg.body,
|
|
},
|
|
Ok(None) => AgentResponse::Empty,
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
AgentRequest::Status => match broker.count_pending(agent) {
|
|
Ok(unread) => AgentResponse::Status { unread },
|
|
Err(e) => AgentResponse::Err {
|
|
message: format!("{e:#}"),
|
|
},
|
|
},
|
|
}
|
|
}
|