From cd23ec1b804d2bf4c77ee52493c4bf6908de3c72 Mon Sep 17 00:00:00 2001 From: Andrew Rioux Date: Wed, 6 Sep 2023 19:44:13 -0400 Subject: [PATCH] feat: continuing work on downloading files --- .../src/commands/connect/commands/download.rs | 12 +-- .../src/commands/connect/shell.rs | 9 +- sparse-05/sparse-05-common/src/lib.rs | 7 +- sparse-05/sparse-05-server/src/connection.rs | 35 ++++++- .../src/connection/download_file.rs | 98 +++++++++++++++++++ 5 files changed, 146 insertions(+), 15 deletions(-) diff --git a/sparse-05/sparse-05-client/src/commands/connect/commands/download.rs b/sparse-05/sparse-05-client/src/commands/connect/commands/download.rs index c1d34e5..b1ba1c2 100644 --- a/sparse-05/sparse-05-client/src/commands/connect/commands/download.rs +++ b/sparse-05/sparse-05-client/src/commands/connect/commands/download.rs @@ -16,13 +16,13 @@ pub async fn download_file( remote_file: PathBuf, local_path: PathBuf, ) -> anyhow::Result<()> { - let mut file = fs::OpenOptions::new().read(true).open(&local_path).await?; - let file_size = file.metadata().await?.len(); + let mut file = fs::OpenOptions::new() + .write(true) + .create(true) + .open(&local_path) + .await?; - let command = Command::StartUploadFile( - remote_path, - (file_size / FILE_TRANSFER_PACKET_SIZE as u64) + 1, - ); + let command = Command::StartDownloadFile(remote_file); conn.send_command(command).await?; let id = loop { diff --git a/sparse-05/sparse-05-client/src/commands/connect/shell.rs b/sparse-05/sparse-05-client/src/commands/connect/shell.rs index 2d4db8c..24e1baa 100644 --- a/sparse-05/sparse-05-client/src/commands/connect/shell.rs +++ b/sparse-05/sparse-05-client/src/commands/connect/shell.rs @@ -215,9 +215,12 @@ pub(super) async fn shell( }), _, ) => { - if let Err(e) = - commands::download::download(Arc::clone(&connection), remote_file, local_path) - .await + if let Err(e) = commands::download::download_file( + Arc::clone(&connection), + remote_file, + local_path, + ) + .await { eprintln!("{e:?}") } diff --git a/sparse-05/sparse-05-common/src/lib.rs b/sparse-05/sparse-05-common/src/lib.rs index 6fcdee4..39cfd68 100644 --- a/sparse-05/sparse-05-common/src/lib.rs +++ b/sparse-05/sparse-05-common/src/lib.rs @@ -74,13 +74,14 @@ pub mod messages { OpenedTTY(u64), ClosedTTY(u64), - SendTTYData(u64, Vec), + SendTTYData(u64, Vec), UploadFileID(u64), UploadFileStatus(u64, Vec), - StartDownloadFile(u64), - DownloadFileSegment(u64, u64, Vec), + StartDownloadFile(u64, u64), + GetDownloadFileStatus(u64, Vec), + DownloadFileSegment(u64, u64, Vec), } #[derive(Serialize_repr, Deserialize_repr, Debug, Clone, Copy)] diff --git a/sparse-05/sparse-05-server/src/connection.rs b/sparse-05/sparse-05-server/src/connection.rs index 733ddd5..78ad0dd 100644 --- a/sparse-05/sparse-05-server/src/connection.rs +++ b/sparse-05/sparse-05-server/src/connection.rs @@ -217,6 +217,7 @@ fn authenticate( } mod command; +mod download_file; mod upload_file; fn handle_full_connection( @@ -232,7 +233,7 @@ where let commands = Arc::new(Mutex::new(HashMap::new())); let uploaded_files = Arc::new(Mutex::new(HashMap::new())); - /*let mut downloaded_files = HashMap::new();*/ + let download_files = Arc::new(Mutex::new(HashMap::new())); std::thread::scope(|s| -> anyhow::Result<()> { loop { @@ -327,8 +328,36 @@ where } } - Command::StartDownloadFile(_) => {} - Command::DownloadFileStatus(_, _) => {} + Command::StartDownloadFile(path) => { + let download_files_clone = download_files.clone(); + let Ok(mut lock) = download_files.lock() else { + continue; + }; + + let handler = match download_file::start_file_download( + &s, + path, + conninfo.clone(), + download_files_clone, + ) { + Ok(handler) => handler, + Err(e) => { + eprintln!("error starting file upload: {e:?}"); + continue; + } + }; + + lock.insert(handler.id, handler); + } + Command::DownloadFileStatus(id, needed) => { + let Ok(lock) = download_files.lock() else { + continue; + }; + + if let Some(handler) = lock.get(&id) { + let _ = handler.download_status.send(needed); + } + } Command::Disconnect => { break; diff --git a/sparse-05/sparse-05-server/src/connection/download_file.rs b/sparse-05/sparse-05-server/src/connection/download_file.rs index e69de29..6451fcf 100644 --- a/sparse-05/sparse-05-server/src/connection/download_file.rs +++ b/sparse-05/sparse-05-server/src/connection/download_file.rs @@ -0,0 +1,98 @@ +use std::{ + collections::HashMap, + fs::{self, OpenOptions}, + io::Write, + path::PathBuf, + sync::{ + atomic::{AtomicU64, Ordering}, + mpsc::{channel, Sender}, + Arc, Mutex, + }, + thread::Scope, + time::Duration, +}; + +use sparse_05_common::messages::{Response, FILE_BUFFER_BUFFER_SIZE, FILE_TRANSFER_PACKET_SIZE}; + +use super::ConnectionInformation; + +static CURRENT_FILE_DOWNLOAD_ID: AtomicU64 = AtomicU64::new(0); + +pub(super) struct DownloadFileHandler { + pub id: u64, + pub download_status: Sender>, +} + +pub(super) fn start_file_download<'a, 'b: 'a>( + s: &'a Scope<'a, 'b>, + file_path: PathBuf, + conninfo: ConnectionInformation, + download_file_map: Arc>>, +) -> anyhow::Result { + let (download_status, receive_download_status) = channel(); + + let id = CURRENT_FILE_DOWNLOAD_ID.fetch_add(1, Ordering::Relaxed); + let id_2 = id; + + s.spawn(move || -> anyhow::Result<()> { + let mut file = OpenOptions::new().read(true).open(&file_path)?; + let file_size = file.metadata()?.len(); + + conninfo.send(conninfo.encrypt_and_sign_resp(Response::StartDownloadFile( + id, + (file_size / FILE_TRANSFER_PACKET_SIZE as u64) + 1, + ))?)?; + + let mut current_packet_count = 0; + + while current_packet_count < packet_count { + let mut buffers: Vec>> = vec![None; FILE_BUFFER_BUFFER_SIZE]; + + loop { + let Ok((i, buffer)) = data_receiver.recv_timeout(Duration::from_millis(250)) else { + let up_to = receive_request_status.recv()?; + + let needed = buffers[..up_to as usize] + .iter() + .enumerate() + .flat_map(|(i, b)| match b { + Some(..) => None, + None => Some(i as u64), + }) + .collect::>(); + + let is_empty = needed.is_empty(); + + conninfo.send( + conninfo.encrypt_and_sign_resp(Response::UploadFileStatus(id, needed))?, + )?; + + if is_empty { + current_packet_count += up_to; + break; + } else { + continue; + } + }; + buffers[i as usize] = Some(buffer); + } + + for buffer in buffers { + let Some(buffer) = buffer else { break }; + target_file.write(&buffer)?; + } + } + + let Ok(mut lock) = download_file_map.lock() else { + return Ok(()); + }; + lock.remove(&id); + + Ok(()) + }); + + Ok(DownloadFileHandler { + id: id_2, + download_status, + }) +}