feat: added download and upload commands

redid actions to better support different clients
This commit is contained in:
Andrew Rioux
2025-03-03 20:03:04 -05:00
parent e0bd5c3b06
commit 7f9ea12b6a
21 changed files with 401 additions and 88 deletions

View File

@@ -21,3 +21,6 @@ axum-msgpack = "0.4.0"
chrono = { version = "0.4.39", features = ["serde"] }
sparse-actions = { path = "../sparse-actions" }
http-body-util = "0.1.2"
tokio-util = { version = "0.7.13", features = ["io"] }
uuid = { version = "1.15.1", features = ["v4"] }

View File

@@ -1,5 +1,5 @@
use std::{
collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}
collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}, path::PathBuf
};
use rcgen::{CertificateParams, KeyPair};
@@ -54,7 +54,8 @@ impl std::ops::Deref for BeaconListenerMap {
pub async fn start_all_listeners(
beacon_listener_map: BeaconListenerMap,
db: SqlitePool,
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>,
file_store: PathBuf,
) -> Result<(), crate::error::Error> {
rustls::crypto::ring::default_provider().install_default().expect("could not set up rustls");
@@ -70,6 +71,7 @@ pub async fn start_all_listeners(
listener.listener_id,
db.clone(),
beacon_event_broadcast.clone(),
file_store.clone()
)
.await?;
}
@@ -88,7 +90,8 @@ pub async fn start_listener(
beacon_listener_map: BeaconListenerMap,
listener_id: i64,
db: SqlitePool,
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>
beacon_event_broadcast: tokio::sync::broadcast::Sender::<BeaconEvent>,
file_store: PathBuf
) -> Result<(), crate::error::Error> {
{
let Ok(blm_handle) = beacon_listener_map.read() else {
@@ -111,7 +114,7 @@ pub async fn start_listener(
.fetch_one(&db)
.await?;
let app = router::get_router(db, beacon_event_broadcast.clone());
let app = router::get_router(db, beacon_event_broadcast.clone(), file_store);
let ca_cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone());

View File

@@ -1,9 +1,15 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, path::PathBuf};
use axum::{extract::{State, ConnectInfo, Path}, routing::post, Router};
use axum::{
extract::{State, ConnectInfo, Path, Request},
routing::{get, post},
Router
};
use axum_msgpack::MsgPack;
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use tokio::{io::AsyncWriteExt, sync::broadcast};
use tokio_util::io::ReaderStream;
use tokio_stream::StreamExt;
use sparse_actions::messages;
@@ -13,6 +19,7 @@ use crate::{BeaconEvent, error};
pub struct ListenerState {
db: SqlitePool,
event_publisher: broadcast::Sender<BeaconEvent>,
file_store: PathBuf
}
pub async fn handle_checkin(
@@ -210,17 +217,52 @@ pub async fn handle_command_result(
Ok(())
}
pub fn get_router(db: SqlitePool, event_publisher: broadcast::Sender<BeaconEvent>) -> Router<()> {
pub async fn download_file(
State(state): State<ListenerState>,
Path(file_id): Path<String>
) -> Result<axum::body::Body, error::Error> {
let mut file_path = state.file_store.clone();
file_path.push(file_id);
let file = tokio::fs::File::open(file_path).await?;
let stream = ReaderStream::new(file);
Ok(axum::body::Body::from_stream(stream))
}
pub async fn upload_file(
State(state): State<ListenerState>,
request: Request
) -> Result<MsgPack<sparse_actions::actions::FileId>, error::Error> {
let file_id = uuid::Uuid::new_v4();
let mut target_file_path = state.file_store.clone();
target_file_path.push(file_id.to_string());
let mut target_file = tokio::fs::OpenOptions::new()
.write(true)
.create(true)
.open(target_file_path)
.await?;
let mut body = request.into_body().into_data_stream();
while let Some(Ok(chunk)) = body.next().await {
target_file.write_all(&chunk).await?;
}
Ok(MsgPack(sparse_actions::actions::FileId(file_id)))
}
pub fn get_router(db: SqlitePool, event_publisher: broadcast::Sender<BeaconEvent>, file_store: PathBuf) -> Router<()> {
Router::new()
.route(
"/checkin",
post(handle_checkin),
)
.route("/files/download/:fileid", post(|| async {}))
.route("/files/upload", post(|| async {}))
.route("/files/download/:fileid", get(download_file))
.route("/files/upload", post(upload_file))
.route(
"/finish/:beaconid/:commandid",
post(handle_command_result),
)
.with_state(ListenerState { db, event_publisher })
.with_state(ListenerState { db, event_publisher, file_store })
}