2025-02-22 16:04:15 -05:00

189 lines
5.5 KiB
Rust

use std::{
collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}
};
use rcgen::{CertificateParams, KeyPair};
use rustls::{RootCertStore, server::WebPkiClientVerifier};
use sqlx::SqlitePool;
use tokio::{sync::broadcast, task::JoinHandle};
pub mod error;
mod router;
/// Event published on a beacon listener's broadcast channel.
///
/// NOTE(review): the `String` payloads are presumably beacon identifiers, sent by the router
/// (which is handed the listener's sender) — confirm against `router::get_router`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BeaconEvent {
    /// A beacon was seen for the first time.
    NewBeacon(String),
    /// An already-known beacon checked in.
    Checkin(String),
}
pub struct BeaconListenerHandle {
join_handle: JoinHandle<()>,
events_broadcast: broadcast::Sender<BeaconEvent>,
}
impl BeaconListenerHandle {
pub fn is_finished(&self) -> bool {
self.join_handle.is_finished()
}
pub fn abort(&self) {
self.join_handle.abort()
}
pub fn event_subscribe(&self) -> broadcast::Receiver<BeaconEvent> {
self.events_broadcast.subscribe()
}
}
/// Shared registry of running beacon listeners, keyed by `listener_id`.
///
/// Clones share the same underlying `Arc`, so every clone sees the same map. The type
/// dereferences to the inner `Arc<RwLock<...>>`, letting callers use `.read()` / `.write()`.
#[derive(Clone, Default)]
pub struct BeaconListenerMap(Arc<RwLock<HashMap<i64, BeaconListenerHandle>>>);

impl std::ops::Deref for BeaconListenerMap {
    type Target = Arc<RwLock<HashMap<i64, BeaconListenerHandle>>>;

    fn deref(&self) -> &Arc<RwLock<HashMap<i64, BeaconListenerHandle>>> {
        let BeaconListenerMap(shared) = self;
        shared
    }
}
/// Starts a beacon listener for every row of the `beacon_listener` table.
///
/// First makes sure a process-wide rustls crypto provider is installed (the `ring` provider,
/// unless one is already present), then calls [`start_listener`] for each `listener_id`
/// in query order.
///
/// # Errors
/// Returns the first error from the listener query or from [`start_listener`]. Listeners
/// after the failing one are not started; those already started keep running.
pub async fn start_all_listeners(
    beacon_listener_map: BeaconListenerMap,
    db: SqlitePool,
    beacon_event_broadcast: tokio::sync::broadcast::Sender<BeaconEvent>,
) -> Result<(), crate::error::Error> {
    // `install_default` only fails when a process-wide provider is already installed (e.g.
    // this function ran before). A provider being present is all we need, so don't panic.
    if rustls::crypto::ring::default_provider()
        .install_default()
        .is_err()
    {
        tracing::debug!("rustls crypto provider already installed; keeping the existing one");
    }

    let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener")
        .fetch_all(&db)
        .await?;
    tracing::info!("Starting {} listener(s)...", listener_ids.len());
    for listener in listener_ids {
        start_listener(
            beacon_listener_map.clone(),
            listener.listener_id,
            db.clone(),
            beacon_event_broadcast.clone(),
        )
        .await?;
    }
    Ok(())
}
/// Columns of a `beacon_listener` row needed to start one listener.
///
/// Populated by `sqlx::query_as!` in `start_listener`; field names must match the selected
/// column names.
struct Listener {
    /// TCP port to bind; stored as `i64` and narrowed to `u16` before binding.
    port: i64,
    /// Host name placed in the generated server certificate (its subject alternative name).
    domain_name: String,
    /// DER-encoded CA certificate; signs the server certificate and anchors client-cert checks.
    certificate: Vec<u8>,
    /// DER-encoded CA private key (loaded with the ECDSA P-256 / SHA-256 algorithm).
    privkey: Vec<u8>,
}
pub async fn start_listener(
beacon_listener_map: BeaconListenerMap,
listener_id: i64,
db: SqlitePool,
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>
) -> Result<(), crate::error::Error> {
{
let Ok(blm_handle) = beacon_listener_map.read() else {
return Err(crate::error::Error::Generic(
"Could not acquire write lock on beacon listener map".to_string(),
));
};
if blm_handle.get(&listener_id).is_some() {
return Err(crate::error::Error::Generic(
"Beacon listener already started".to_string(),
));
}
}
let listener = sqlx::query_as!(
Listener,
"SELECT port, domain_name, certificate, privkey FROM beacon_listener WHERE listener_id = ?",
listener_id
)
.fetch_one(&db)
.await?;
let sender = broadcast::Sender::new(128);
let app = router::get_router(db, sender.clone());
let ca_cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone());
let (keypair, cert) = {
let ca_keypair = KeyPair::from_der_and_sign_algo(
&rustls_pki_types::PrivateKeyDer::try_from(&*listener.privkey)
.map_err(|_| rcgen::Error::CouldNotParseCertificate)?,
&rcgen::PKCS_ECDSA_P256_SHA256,
)?;
let ca_params = CertificateParams::from_ca_cert_der(&(*listener.certificate).into())
.map_err(|_| rcgen::Error::CouldNotParseCertificate)?;
let ca_cert = ca_params.self_signed(&ca_keypair)?;
let server_key = KeyPair::generate()?;
let Ok(server_params) = CertificateParams::new(vec![listener.domain_name.clone()]) else {
return Err(crate::error::Error::Generic(format!(
"Could not generate new server keychain"
)));
};
let server_cert = server_params.signed_by(&server_key, &ca_cert, &ca_keypair)?;
let keypair = match rustls::pki_types::PrivateKeyDer::try_from(server_key.serialize_der()) {
Ok(pk) => pk,
Err(e) => {
return Err(crate::error::Error::Generic(format!(
"Could not parse private key: {e}"
)));
}
};
let cert = server_cert.into();
(keypair, cert)
};
let mut root_store = RootCertStore::empty();
root_store.add(ca_cert)?;
let client_verifier = WebPkiClientVerifier::builder(root_store.into()).build()?;
let mut tls_config = rustls::ServerConfig::builder()
.with_client_cert_verifier(client_verifier)
.with_single_cert(vec![cert], keypair)?;
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16));
tracing::debug!(
"Starting listener {}, {}, on port {}",
listener_id,
listener.domain_name,
listener.port
);
let join_handle = tokio::task::spawn(async move {
let res = axum_server::tls_rustls::bind_rustls(
addr,
axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(tls_config)),
)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await;
if let Err(e) = res {
tracing::error!("error running sparse listener: {e:?}");
}
});
let Ok(mut blm_handle) = beacon_listener_map.write() else {
return Err(crate::error::Error::Generic(
"Could not acquire write lock on beacon listener map".to_string(),
));
};
blm_handle.insert(listener_id, BeaconListenerHandle { join_handle, events_broadcast: sender });
Ok(())
}