use std::{ collections::HashMap, net::SocketAddr, sync::{Arc, RwLock} }; use rcgen::{CertificateParams, KeyPair}; use rustls::{RootCertStore, server::WebPkiClientVerifier}; use sqlx::SqlitePool; use tokio::{sync::broadcast, task::JoinHandle}; pub mod error; mod router; #[derive(Clone)] pub enum BeaconEvent { NewBeacon(String), Checkin(String) } pub struct BeaconListenerHandle { join_handle: JoinHandle<()>, events_broadcast: broadcast::Sender, } impl BeaconListenerHandle { pub fn is_finished(&self) -> bool { self.join_handle.is_finished() } pub fn abort(&self) { self.join_handle.abort() } pub fn event_subscribe(&self) -> broadcast::Receiver { self.events_broadcast.subscribe() } } #[derive(Clone, Default)] pub struct BeaconListenerMap(Arc>>); impl std::ops::Deref for BeaconListenerMap { type Target = Arc>>; fn deref(&self) -> &Self::Target { &self.0 } } pub async fn start_all_listeners( beacon_listener_map: BeaconListenerMap, db: SqlitePool, beacon_event_broadcast: tokio::sync::broadcast::Sender:: ) -> Result<(), crate::error::Error> { rustls::crypto::ring::default_provider().install_default().expect("could not set up rustls"); let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener") .fetch_all(&db) .await?; tracing::info!("Starting {} listener(s)...", listener_ids.len()); for listener in listener_ids { start_listener( beacon_listener_map.clone(), listener.listener_id, db.clone(), beacon_event_broadcast.clone(), ) .await?; } Ok(()) } struct Listener { port: i64, domain_name: String, certificate: Vec, privkey: Vec, } pub async fn start_listener( beacon_listener_map: BeaconListenerMap, listener_id: i64, db: SqlitePool, beacon_event_broadcast: tokio::sync::broadcast::Sender:: ) -> Result<(), crate::error::Error> { { let Ok(blm_handle) = beacon_listener_map.read() else { return Err(crate::error::Error::Generic( "Could not acquire write lock on beacon listener map".to_string(), )); }; if blm_handle.get(&listener_id).is_some() { return Err(crate::error::Error::Generic( "Beacon listener already started".to_string(), )); } } let listener = sqlx::query_as!( Listener, "SELECT port, domain_name, certificate, privkey FROM beacon_listener WHERE listener_id = ?", listener_id ) .fetch_one(&db) .await?; let app = router::get_router(db, beacon_event_broadcast.clone()); let ca_cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone()); let (keypair, cert) = { let ca_keypair = KeyPair::from_der_and_sign_algo( &rustls_pki_types::PrivateKeyDer::try_from(&*listener.privkey) .map_err(|_| rcgen::Error::CouldNotParseCertificate)?, &rcgen::PKCS_ECDSA_P256_SHA256, )?; let ca_params = CertificateParams::from_ca_cert_der(&(*listener.certificate).into()) .map_err(|_| rcgen::Error::CouldNotParseCertificate)?; let ca_cert = ca_params.self_signed(&ca_keypair)?; let server_key = KeyPair::generate()?; let Ok(server_params) = CertificateParams::new(vec![listener.domain_name.clone()]) else { return Err(crate::error::Error::Generic(format!( "Could not generate new server keychain" ))); }; let server_cert = server_params.signed_by(&server_key, &ca_cert, &ca_keypair)?; let keypair = match rustls::pki_types::PrivateKeyDer::try_from(server_key.serialize_der()) { Ok(pk) => pk, Err(e) => { return Err(crate::error::Error::Generic(format!( "Could not parse private key: {e}" ))); } }; let cert = server_cert.into(); (keypair, cert) }; let mut root_store = RootCertStore::empty(); root_store.add(ca_cert)?; let client_verifier = WebPkiClientVerifier::builder(root_store.into()).build()?; let mut tls_config = rustls::ServerConfig::builder() .with_client_cert_verifier(client_verifier) .with_single_cert(vec![cert], keypair)?; tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16)); tracing::debug!( "Starting listener {}, {}, on port {}", listener_id, listener.domain_name, listener.port ); let join_handle = tokio::task::spawn(async move { let res = axum_server::tls_rustls::bind_rustls( addr, axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(tls_config)), ) .serve(app.into_make_service_with_connect_info::()) .await; if let Err(e) = res { tracing::error!("error running sparse listener: {e:?}"); } }); let Ok(mut blm_handle) = beacon_listener_map.write() else { return Err(crate::error::Error::Generic( "Could not acquire write lock on beacon listener map".to_string(), )); }; blm_handle.insert(listener_id, BeaconListenerHandle { join_handle, events_broadcast: beacon_event_broadcast }); Ok(()) }