2025-02-02 02:37:53 -05:00

135 lines
3.8 KiB
Rust

use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use axum::routing::{get, post, Router};
use sqlx::SqlitePool;
use tokio::task::JoinHandle;
pub mod error;
/// Handle to a spawned beacon listener task.
///
/// Wraps the Tokio [`JoinHandle`] of the task so that the listener can be
/// checked for completion or aborted.
pub struct BeaconListenerHandle {
    join_handle: JoinHandle<()>,
}
impl BeaconListenerHandle {
    /// Reports whether the wrapped listener task has finished.
    pub fn is_finished(&self) -> bool {
        let Self { join_handle } = self;
        join_handle.is_finished()
    }

    /// Requests that the wrapped listener task be aborted.
    pub fn abort(&self) {
        let Self { join_handle } = self;
        join_handle.abort();
    }
}
/// Shared registry of started beacon listeners, keyed by listener ID.
///
/// The map sits behind an [`Arc`], so every clone of this wrapper refers to
/// the same underlying registry.
#[derive(Default, Clone)]
pub struct BeaconListenerMap(Arc<RwLock<HashMap<i64, BeaconListenerHandle>>>);
impl std::ops::Deref for BeaconListenerMap {
type Target = Arc<RwLock<HashMap<i64, BeaconListenerHandle>>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub async fn start_all_listeners(beacon_listener_map: BeaconListenerMap, db: SqlitePool) -> Result<(), crate::error::Error> {
let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener")
.fetch_all(&db)
.await?;
tracing::info!("Starting {} listener(s)...", listener_ids.len());
for listener in listener_ids {
start_listener(beacon_listener_map.clone(), listener.listener_id, db.clone()).await?;
}
Ok(())
}
/// State attached to a listener's axum router.
#[derive(Clone)]
struct ListenerState {
    /// Database pool made available to the router.
    db: SqlitePool,
}
/// One row of the `beacon_listener` table, as loaded by `start_listener`.
struct Listener {
    /// Identifier matched by the `listener_id = ?` lookup.
    listener_id: i64,
    /// TCP port the listener binds to.
    port: i64,
    public_ip: String,
    /// Domain name of the listener; included in the start-up log line.
    domain_name: String,
    /// Certificate bytes, wrapped in rustls' `CertificateDer` when serving.
    certificate: Vec<u8>,
    /// Private key bytes, parsed into rustls' `PrivateKeyDer` when serving.
    privkey: Vec<u8>,
}
pub async fn start_listener(beacon_listener_map: BeaconListenerMap, listener_id: i64, db: SqlitePool) -> Result<(), crate::error::Error> {
{
let Ok(blm_handle) = beacon_listener_map.read() else {
return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string()));
};
if blm_handle.get(&listener_id).is_some() {
return Err(crate::error::Error::Generic("Beacon listener already started".to_string()));
}
}
let listener = sqlx::query_as!(Listener, "SELECT * FROM beacon_listener WHERE listener_id = ?", listener_id)
.fetch_one(&db)
.await?;
let app: Router<()> = Router::new()
.route("/register_beacon", post(|| async {
tracing::info!("Beacon attempting to register");
}))
.route("/test", get(|| async {
tracing::info!("Hello");
"hi there"
}))
.with_state(ListenerState {
db
});
let hidden_app = Router::new().nest("/hidden_sparse", app);
let keypair = match rustls::pki_types::PrivateKeyDer::try_from(listener.privkey.clone()) {
Ok(pk) => pk,
Err(e) => {
return Err(crate::error::Error::Generic(format!("Could not parse private key: {e}")));
}
};
let cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone());
let mut tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert], keypair)?;
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16));
tracing::debug!("Starting listener {}, {}, on port {}", listener_id, listener.domain_name, listener.port);
let join_handle = tokio::task::spawn(async move {
let res = axum_server::tls_rustls::bind_rustls(
addr,
axum_server::tls_rustls::RustlsConfig::from_config(
Arc::new(tls_config)
)
)
.serve(hidden_app.into_make_service())
.await;
if let Err(e) = res {
tracing::error!("error running sparse listener: {e:?}");
}
});
let Ok(mut blm_handle) = beacon_listener_map.write() else {
return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string()));
};
blm_handle.insert(listener_id, BeaconListenerHandle {
join_handle
});
Ok(())
}