configurable rate limit, paths module, verify and bootstrap binaries
This commit is contained in:
parent
888eddf093
commit
0a1246d1f8
7 changed files with 310 additions and 33 deletions
76
src/bin/bootstrap_cross_signing.rs
Normal file
76
src/bin/bootstrap_cross_signing.rs
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
use std::path::Path;
|
||||
|
||||
use anyhow::Context;
|
||||
use matrix_sdk::{
|
||||
Client, authentication::matrix::MatrixSession, config::SyncSettings,
|
||||
encryption::CrossSigningResetAuthType, ruma::api::client::uiaa,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Config {
|
||||
homeserver: String,
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PersistedSession {
|
||||
homeserver: String,
|
||||
user_session: MatrixSession,
|
||||
}
|
||||
|
||||
fn workspace_dir() -> &'static str {
|
||||
if Path::new("/workspace/config.json").exists() {
|
||||
"/workspace"
|
||||
} else {
|
||||
"/persist/damocles-lab"
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let ws = workspace_dir();
|
||||
let config: Config =
|
||||
serde_json::from_str(&fs::read_to_string(format!("{ws}/config.json")).await?)?;
|
||||
let session: PersistedSession =
|
||||
serde_json::from_str(&fs::read_to_string(format!("{ws}/state/session.json")).await?)?;
|
||||
|
||||
let db_path = format!("{ws}/state/db");
|
||||
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&session.homeserver)
|
||||
.sqlite_store(&db_path, None)
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
client.restore_session(session.user_session).await?;
|
||||
|
||||
// sync once to get current state
|
||||
client.sync_once(SyncSettings::default()).await?;
|
||||
|
||||
eprintln!("bootstrapping cross-signing...");
|
||||
|
||||
if let Some(handle) = client.encryption().reset_cross_signing().await? {
|
||||
match handle.auth_type() {
|
||||
CrossSigningResetAuthType::Uiaa(uiaa_info) => {
|
||||
let user_id = client.user_id().context("not logged in")?.to_owned();
|
||||
let mut password = uiaa::Password::new(user_id.into(), config.password.clone());
|
||||
password.session = uiaa_info.session.clone();
|
||||
handle
|
||||
.auth(Some(uiaa::AuthData::Password(password)))
|
||||
.await?;
|
||||
}
|
||||
CrossSigningResetAuthType::OAuth(oauth) => {
|
||||
eprintln!("approve at: {}", oauth.approval_url);
|
||||
handle.auth(None).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("cross-signing bootstrapped successfully");
|
||||
Ok(())
|
||||
}
|
||||
170
src/bin/verify.rs
Normal file
170
src/bin/verify.rs
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
// Interactive device verification.
|
||||
// 1. run this in the lab
|
||||
// 2. trigger verification from another (verified) device in Element
|
||||
// 3. compare the emojis between Element and this terminal
|
||||
// 4. type "yes" if they match
|
||||
// 5. verification completes
|
||||
|
||||
use std::io::{self, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures_util::stream::StreamExt;
|
||||
use matrix_sdk::{
|
||||
Client,
|
||||
authentication::matrix::MatrixSession,
|
||||
config::SyncSettings,
|
||||
encryption::verification::{
|
||||
Emoji, SasState, SasVerification, Verification, VerificationRequest,
|
||||
VerificationRequestState, format_emojis,
|
||||
},
|
||||
ruma::events::{
|
||||
key::verification::request::ToDeviceKeyVerificationRequestEvent,
|
||||
room::message::{MessageType, OriginalSyncRoomMessageEvent},
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PersistedSession {
|
||||
homeserver: String,
|
||||
user_session: MatrixSession,
|
||||
}
|
||||
|
||||
fn workspace_dir() -> PathBuf {
|
||||
if Path::new("/workspace/config.json").exists() {
|
||||
PathBuf::from("/workspace")
|
||||
} else {
|
||||
PathBuf::from("/persist/damocles-lab")
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_confirmation(sas: SasVerification, emoji: [Emoji; 7]) {
|
||||
println!(
|
||||
"\nDo the emojis match in Element?\n{}",
|
||||
format_emojis(emoji)
|
||||
);
|
||||
print!("\nType `yes` to confirm, anything else to cancel: ");
|
||||
io::stdout().flush().expect("flush stdout");
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input).expect("read input");
|
||||
|
||||
match input.trim().to_lowercase().as_ref() {
|
||||
"yes" | "y" | "ok" => {
|
||||
sas.confirm().await.expect("confirm sas");
|
||||
println!("confirmed - waiting for other side...");
|
||||
}
|
||||
_ => {
|
||||
sas.cancel().await.expect("cancel sas");
|
||||
println!("cancelled");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn sas_handler(sas: SasVerification) {
|
||||
println!(
|
||||
"starting SAS verification with {} {}",
|
||||
sas.other_device().user_id(),
|
||||
sas.other_device().device_id()
|
||||
);
|
||||
sas.accept().await.expect("accept sas");
|
||||
|
||||
let mut stream = sas.changes();
|
||||
while let Some(state) = stream.next().await {
|
||||
match state {
|
||||
SasState::KeysExchanged { emojis, .. } => {
|
||||
let emoji = emojis.expect("emoji-only verifications").emojis;
|
||||
tokio::spawn(wait_for_confirmation(sas.clone(), emoji));
|
||||
}
|
||||
SasState::Done { .. } => {
|
||||
let dev = sas.other_device();
|
||||
println!(
|
||||
"✅ verified {} ({})",
|
||||
dev.device_id(),
|
||||
dev.display_name().unwrap_or("no name")
|
||||
);
|
||||
std::process::exit(0);
|
||||
}
|
||||
SasState::Cancelled(info) => {
|
||||
println!("❌ cancelled: {}", info.reason());
|
||||
std::process::exit(1);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_handler(request: VerificationRequest) {
|
||||
println!("verification request from {}", request.other_user_id());
|
||||
request.accept().await.expect("accept request");
|
||||
|
||||
let mut stream = request.changes();
|
||||
while let Some(state) = stream.next().await {
|
||||
match state {
|
||||
VerificationRequestState::Transitioned { verification } => {
|
||||
if let Verification::SasV1(sas) = verification {
|
||||
tokio::spawn(sas_handler(sas));
|
||||
break;
|
||||
}
|
||||
}
|
||||
VerificationRequestState::Done | VerificationRequestState::Cancelled(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let ws = workspace_dir();
|
||||
let session: PersistedSession =
|
||||
serde_json::from_str(&fs::read_to_string(ws.join("state/session.json")).await?)?;
|
||||
let db_path = ws.join("state/db");
|
||||
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&session.homeserver)
|
||||
.sqlite_store(&db_path, None)
|
||||
.build()
|
||||
.await?;
|
||||
client.restore_session(session.user_session).await?;
|
||||
|
||||
let user_id = client.user_id().context("not logged in")?.to_owned();
|
||||
let device_id = client.device_id().context("no device id")?.to_owned();
|
||||
println!("logged in as {user_id} ({device_id})");
|
||||
println!("waiting for verification request from another device...");
|
||||
println!("(in Element: settings -> sessions -> click '{device_id}' -> verify)\n");
|
||||
|
||||
// to-device events (out-of-room verification request)
|
||||
client.add_event_handler(
|
||||
|ev: ToDeviceKeyVerificationRequestEvent, client: Client| async move {
|
||||
if let Some(request) = client
|
||||
.encryption()
|
||||
.get_verification_request(&ev.sender, &ev.content.transaction_id)
|
||||
.await
|
||||
{
|
||||
tokio::spawn(request_handler(request));
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// in-room verification request
|
||||
client.add_event_handler(
|
||||
|ev: OriginalSyncRoomMessageEvent, client: Client| async move {
|
||||
if let MessageType::VerificationRequest(_) = &ev.content.msgtype {
|
||||
if let Some(request) = client
|
||||
.encryption()
|
||||
.get_verification_request(&ev.sender, &ev.event_id)
|
||||
.await
|
||||
{
|
||||
tokio::spawn(request_handler(request));
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
client.sync(SyncSettings::new()).await?;
|
||||
Ok(())
|
||||
}
|
||||
70
src/main.rs
70
src/main.rs
|
|
@ -27,6 +27,7 @@ struct Config {
|
|||
homeserver: String,
|
||||
username: String,
|
||||
password: String,
|
||||
rate_limit_per_min: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
@ -50,11 +51,12 @@ struct DaemonState {
|
|||
room_history: std::collections::HashMap<OwnedRoomId, Vec<ChatMessage>>,
|
||||
pending_rooms: Vec<OwnedRoomId>,
|
||||
rate_budget: u32,
|
||||
rate_limit_per_min: u32,
|
||||
last_rate_reset: std::time::Instant,
|
||||
}
|
||||
|
||||
const MAX_HISTORY: usize = 20;
|
||||
const RATE_LIMIT_PER_MIN: u32 = 2;
|
||||
const DEFAULT_RATE_LIMIT_PER_MIN: u32 = 1;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
|
|
@ -65,28 +67,34 @@ async fn main() -> anyhow::Result<()> {
|
|||
tracing::info!("damocles-daemon starting");
|
||||
|
||||
let state_dir = paths::state_dir();
|
||||
let identity_dir = paths::identity_dir();
|
||||
fs::create_dir_all(&state_dir).await?;
|
||||
fs::create_dir_all(&identity_dir).await?;
|
||||
fs::create_dir_all(paths::identity_dir()).await?;
|
||||
fs::create_dir_all(state_dir.join("rooms")).await?;
|
||||
fs::create_dir_all(state_dir.join("people")).await?;
|
||||
|
||||
let session_file = paths::session_path();
|
||||
let db_path = paths::db_path();
|
||||
|
||||
let config = load_config().await?;
|
||||
let rate_limit_per_min = config
|
||||
.rate_limit_per_min
|
||||
.unwrap_or(DEFAULT_RATE_LIMIT_PER_MIN);
|
||||
|
||||
let (client, sync_token) = if session_file.exists() {
|
||||
restore_session(&session_file).await?
|
||||
} else {
|
||||
let config = load_config().await?;
|
||||
(login(&config, &db_path, &session_file).await?, None)
|
||||
};
|
||||
|
||||
let own_user_id = client.user_id().context("not logged in")?.to_owned();
|
||||
tracing::info!(user = %own_user_id, "ready");
|
||||
tracing::info!(user = %own_user_id, rate_limit = rate_limit_per_min, "ready");
|
||||
|
||||
let state = Arc::new(Mutex::new(DaemonState {
|
||||
own_user_id,
|
||||
room_history: std::collections::HashMap::new(),
|
||||
pending_rooms: Vec::new(),
|
||||
rate_budget: RATE_LIMIT_PER_MIN,
|
||||
rate_budget: rate_limit_per_min,
|
||||
rate_limit_per_min,
|
||||
last_rate_reset: std::time::Instant::now(),
|
||||
}));
|
||||
|
||||
|
|
@ -114,7 +122,7 @@ async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option<
|
|||
|
||||
let client = Client::builder()
|
||||
.homeserver_url(&session.homeserver)
|
||||
.sqlite_store(&session.db_path, None)
|
||||
.sqlite_store(paths::db_path(), None)
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
|
|
@ -259,7 +267,7 @@ async fn process_loop(state: Arc<Mutex<DaemonState>>, client: Client) {
|
|||
let mut state = state.lock().await;
|
||||
|
||||
if state.last_rate_reset.elapsed() >= std::time::Duration::from_secs(60) {
|
||||
state.rate_budget = RATE_LIMIT_PER_MIN;
|
||||
state.rate_budget = state.rate_limit_per_min;
|
||||
state.last_rate_reset = std::time::Instant::now();
|
||||
}
|
||||
|
||||
|
|
@ -339,26 +347,38 @@ async fn invoke_claude(
|
|||
writeln!(prompt, "{prefix}{}: {}", msg.sender, msg.body).unwrap();
|
||||
}
|
||||
|
||||
tracing::debug!("invoking claude with {} messages", history.len());
|
||||
tracing::info!("invoking claude with {} messages", history.len());
|
||||
tracing::debug!("prompt: {prompt}");
|
||||
|
||||
let output = tokio::process::Command::new("claude")
|
||||
.args([
|
||||
"--print",
|
||||
"--bare",
|
||||
"--add-dir",
|
||||
&identity_str,
|
||||
"--allowedTools",
|
||||
"Read Edit Write Glob Grep",
|
||||
])
|
||||
.current_dir(&identity_dir)
|
||||
.arg(&prompt)
|
||||
.output()
|
||||
.await
|
||||
.context("failed to run claude")?;
|
||||
use tokio::process::Command;
|
||||
let mut cmd = Command::new("claude");
|
||||
cmd.args([
|
||||
"--print",
|
||||
"--add-dir",
|
||||
&identity_str,
|
||||
"--allowedTools",
|
||||
"Read Edit Write Glob Grep",
|
||||
"-p",
|
||||
&prompt,
|
||||
]);
|
||||
cmd.current_dir(&identity_dir);
|
||||
cmd.stdin(std::process::Stdio::null());
|
||||
let output = cmd.output().await.context("failed to run claude")?;
|
||||
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
bail!("claude exited with {}: {stderr}", output.status);
|
||||
bail!(
|
||||
"claude exited with {}:\nstdout: {}\nstderr: {}",
|
||||
output.status,
|
||||
stdout,
|
||||
stderr
|
||||
);
|
||||
}
|
||||
|
||||
if !stderr.is_empty() {
|
||||
tracing::warn!("claude stderr: {stderr}");
|
||||
}
|
||||
|
||||
let raw = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue