use std::{ collections::HashMap, sync::{Arc, RwLock}, }; use axum::routing::{get, post, Router}; use sqlx::SqlitePool; use tokio::task::JoinHandle; pub mod error; pub struct BeaconListenerHandle { join_handle: JoinHandle<()> } impl BeaconListenerHandle { pub fn is_finished(&self) -> bool { self.join_handle.is_finished() } pub fn abort(&self) { self.join_handle.abort() } } #[derive(Clone, Default)] pub struct BeaconListenerMap(Arc>>); impl std::ops::Deref for BeaconListenerMap { type Target = Arc>>; fn deref(&self) -> &Self::Target { &self.0 } } pub async fn start_all_listeners(beacon_listener_map: BeaconListenerMap, db: SqlitePool) -> Result<(), crate::error::Error> { let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener") .fetch_all(&db) .await?; tracing::info!("Starting {} listener(s)...", listener_ids.len()); for listener in listener_ids { start_listener(beacon_listener_map.clone(), listener.listener_id, db.clone()).await?; } Ok(()) } #[derive(Clone)] struct ListenerState { db: SqlitePool } struct Listener { listener_id: i64, port: i64, public_ip: String, domain_name: String, certificate: Vec, privkey: Vec } pub async fn start_listener(beacon_listener_map: BeaconListenerMap, listener_id: i64, db: SqlitePool) -> Result<(), crate::error::Error> { { let Ok(blm_handle) = beacon_listener_map.read() else { return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string())); }; if blm_handle.get(&listener_id).is_some() { return Err(crate::error::Error::Generic("Beacon listener already started".to_string())); } } let listener = sqlx::query_as!(Listener, "SELECT * FROM beacon_listener WHERE listener_id = ?", listener_id) .fetch_one(&db) .await?; let app: Router<()> = Router::new() .route("/register_beacon", post(|| async { tracing::info!("Beacon attempting to register"); })) .route("/test", get(|| async { tracing::info!("Hello"); "hi there" })) .with_state(ListenerState { db }); let hidden_app = Router::new().nest("/hidden_sparse", app); let keypair = match rustls::pki_types::PrivateKeyDer::try_from(listener.privkey.clone()) { Ok(pk) => pk, Err(e) => { return Err(crate::error::Error::Generic(format!("Could not parse private key: {e}"))); } }; let cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone()); let mut tls_config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert], keypair)?; tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16)); tracing::debug!("Starting listener {}, {}, on port {}", listener_id, listener.domain_name, listener.port); let join_handle = tokio::task::spawn(async move { let res = axum_server::tls_rustls::bind_rustls( addr, axum_server::tls_rustls::RustlsConfig::from_config( Arc::new(tls_config) ) ) .serve(hidden_app.into_make_service()) .await; if let Err(e) = res { tracing::error!("error running sparse listener: {e:?}"); } }); let Ok(mut blm_handle) = beacon_listener_map.write() else { return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string())); }; blm_handle.insert(listener_id, BeaconListenerHandle { join_handle }); Ok(()) }