feat: added tcp

sorry Judah
This commit is contained in:
Andrew Rioux 2025-02-12 17:49:31 -05:00
parent e388b2eefa
commit f9ff9f266a
Signed by: andrew.rioux
GPG Key ID: 9B8BAC47C17ABB94
37 changed files with 1939 additions and 902 deletions

230
Cargo.lock generated
View File

@ -599,6 +599,16 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186dce98367766de751c42c4f03970fc60fc012296e706ccbb9d5df9b6c1e271" checksum = "186dce98367766de751c42c4f03970fc60fc012296e706ccbb9d5df9b6c1e271"
[[package]]
name = "colored"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "concurrent-queue" name = "concurrent-queue"
version = "2.5.0" version = "2.5.0"
@ -683,6 +693,16 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "core-foundation"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.7" version = "0.8.7"
@ -825,6 +845,38 @@ dependencies = [
"syn 2.0.96", "syn 2.0.96",
] ]
[[package]]
name = "defmt"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86f6162c53f659f65d00619fe31f14556a6e9f8752ccc4a41bd177ffcf3d6130"
dependencies = [
"bitflags 1.3.2",
"defmt-macros",
]
[[package]]
name = "defmt-macros"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d135dd939bad62d7490b0002602d35b358dce5fd9233a709d3c1ef467d4bde6"
dependencies = [
"defmt-parser",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn 2.0.96",
]
[[package]]
name = "defmt-parser"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3983b127f13995e68c1e29071e5d115cd96f215ccb5e6812e3728cd6f92653b3"
dependencies = [
"thiserror 2.0.11",
]
[[package]] [[package]]
name = "der" name = "der"
version = "0.7.9" version = "0.7.9"
@ -1254,6 +1306,15 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hash32"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.14.5" version = "0.14.5"
@ -1280,6 +1341,16 @@ dependencies = [
"hashbrown 0.15.2", "hashbrown 0.15.2",
] ]
[[package]]
name = "heapless"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad"
dependencies = [
"hash32",
"stable_deref_trait",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.3.3" version = "0.3.3"
@ -1438,6 +1509,25 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"smallvec", "smallvec",
"tokio", "tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
] ]
[[package]] [[package]]
@ -1447,13 +1537,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel",
"futures-util", "futures-util",
"http", "http",
"http-body", "http-body",
"hyper", "hyper",
"pin-project-lite", "pin-project-lite",
"socket2",
"tokio", "tokio",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
@ -2013,6 +2106,12 @@ version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "managed"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d"
[[package]] [[package]]
name = "manyhow" name = "manyhow"
version = "0.11.4" version = "0.11.4"
@ -2233,6 +2332,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "num_threads"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "object" name = "object"
version = "0.36.7" version = "0.36.7"
@ -2267,6 +2375,12 @@ version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]] [[package]]
name = "or_poisoned" name = "or_poisoned"
version = "0.1.0" version = "0.1.0"
@ -2357,6 +2471,7 @@ dependencies = [
"packets", "packets",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"windows",
] ]
[[package]] [[package]]
@ -2386,18 +2501,18 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.8" version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e2ec53ad785f4d35dac0adea7f7dc6f1bb277ad84a680c7afefeae05d1f5916" checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d"
dependencies = [ dependencies = [
"pin-project-internal", "pin-project-internal",
] ]
[[package]] [[package]]
name = "pin-project-internal" name = "pin-project-internal"
version = "1.1.8" version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d56a66c0c55993aa927429d0f8a0abfd74f084e4d9c192cffed01e418d83eefb" checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2875,12 +2990,25 @@ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"log", "log",
"once_cell", "once_cell",
"ring",
"rustls-pki-types", "rustls-pki-types",
"rustls-webpki", "rustls-webpki",
"subtle", "subtle",
"zeroize", "zeroize",
] ]
[[package]]
name = "rustls-native-certs"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "2.2.0" version = "2.2.0"
@ -2929,12 +3057,44 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.2.0" version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316"
dependencies = [
"bitflags 2.8.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.6.0" version = "0.6.0"
@ -3134,6 +3294,18 @@ dependencies = [
"rand_core 0.6.4", "rand_core 0.6.4",
] ]
[[package]]
name = "simple_logger"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c5dfa5e08767553704aa0ffd9d9794d527103c736aba9854773851fd7497eb"
dependencies = [
"colored",
"log",
"time",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
@ -3161,6 +3333,21 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "smoltcp"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad095989c1533c1c266d9b1e8d70a1329dd3723c3edac6d03bbd67e7bf6f4bb"
dependencies = [
"bitflags 1.3.2",
"byteorder",
"cfg-if",
"defmt",
"heapless",
"log",
"managed",
]
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.5.8" version = "0.5.8"
@ -3182,7 +3369,23 @@ dependencies = [
name = "sparse-beacon" name = "sparse-beacon"
version = "0.7.0" version = "0.7.0"
dependencies = [ dependencies = [
"async-trait",
"futures",
"hyper",
"hyper-rustls",
"hyper-util",
"nl-sys",
"packets",
"pcap-sys", "pcap-sys",
"pin-project",
"rand 0.9.0",
"simple_logger",
"smoltcp",
"sparse-actions",
"thiserror 2.0.11",
"tokio",
"tower-service",
"tracing",
] ]
[[package]] [[package]]
@ -3281,6 +3484,8 @@ dependencies = [
name = "sparse-windows-beacon" name = "sparse-windows-beacon"
version = "2.0.0" version = "2.0.0"
dependencies = [ dependencies = [
"anyhow",
"pcap-sys",
"windows", "windows",
"winreg", "winreg",
] ]
@ -3752,7 +3957,9 @@ checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21"
dependencies = [ dependencies = [
"deranged", "deranged",
"itoa", "itoa",
"libc",
"num-conv", "num-conv",
"num_threads",
"powerfmt", "powerfmt",
"serde", "serde",
"time-core", "time-core",
@ -4078,6 +4285,12 @@ dependencies = [
"tracing-serde", "tracing-serde",
] ]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]] [[package]]
name = "tungstenite" name = "tungstenite"
version = "0.24.0" version = "0.24.0"
@ -4257,6 +4470,15 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"

View File

@ -92,6 +92,7 @@ extern "C" {
pub fn rtnl_link_get_name(link: *mut rtnl_link) -> *const c_char; pub fn rtnl_link_get_name(link: *mut rtnl_link) -> *const c_char;
pub fn rtnl_link_get_ifindex(link: *mut rtnl_link) -> c_int; pub fn rtnl_link_get_ifindex(link: *mut rtnl_link) -> c_int;
pub fn rtnl_link_get_type(link: *mut rtnl_link) -> *const c_char; pub fn rtnl_link_get_type(link: *mut rtnl_link) -> *const c_char;
pub fn rtnl_link_get_mtu(link: *mut rtnl_link) -> c_uint;
pub fn rtnl_route_alloc_cache( pub fn rtnl_route_alloc_cache(
sock: *mut nl_sock, sock: *mut nl_sock,

View File

@ -91,6 +91,11 @@ impl Link {
} }
} }
/// Returns the MTU of the link
pub fn mtu(&self) -> u32 {
unsafe { rtnl_link_get_mtu(self.link) }
}
/// Determines the type of link. Ethernet devices are "veth or eth" /// Determines the type of link. Ethernet devices are "veth or eth"
pub fn ltype(&self) -> Option<String> { pub fn ltype(&self) -> Option<String> {
unsafe { unsafe {

View File

@ -78,6 +78,10 @@ impl<'a> EthernetPkt<'a> {
data: self.data.to_vec(), data: self.data.to_vec(),
} }
} }
pub fn raw(&self) -> &[u8] {
self.data
}
} }
pub enum Layer3Pkt<'a> { pub enum Layer3Pkt<'a> {
@ -425,6 +429,10 @@ impl EthernetPacket {
pub fn pkt(&'_ self) -> EthernetPkt<'_> { pub fn pkt(&'_ self) -> EthernetPkt<'_> {
EthernetPkt { data: &self.data } EthernetPkt { data: &self.data }
} }
pub fn from_raw(data: Vec<u8>) -> EthernetPacket {
Self { data }
}
} }
#[derive(Clone)] #[derive(Clone)]

View File

@ -31,4 +31,8 @@ tokio = { version = "1.21.2", features = [
"rt-multi-thread", "rt-multi-thread",
] } ] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
packets = { path = "../packets" } packets = { path = "../packets" }
[target.'cfg(windows)'.dependencies]
windows = { version = "0.59.0", features = ["Win32_System_Threading"] }

View File

@ -31,6 +31,7 @@ pub enum Error {
InvalidPcapFd, InvalidPcapFd,
Io(std::io::Error), Io(std::io::Error),
Libc(Errno), Libc(Errno),
IncorrectDeviceState(crate::State, crate::State),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -61,6 +62,10 @@ impl Display for Error {
Error::InvalidPcapFd => write!(f, "internal pcap file descriptor error"), Error::InvalidPcapFd => write!(f, "internal pcap file descriptor error"),
Error::Io(io) => write!(f, "std::io error ({io})"), Error::Io(io) => write!(f, "std::io error ({io})"),
Error::Libc(err) => write!(f, "libc error ({err})"), Error::Libc(err) => write!(f, "libc error ({err})"),
Error::IncorrectDeviceState(des, cur) => write!(
f,
"device in incorrect state (desired: {des:?}; current: {cur:?})"
),
} }
} }
} }

View File

@ -181,3 +181,8 @@ extern "C" {
pkt_data: *mut *mut c_char, pkt_data: *mut *mut c_char,
) -> c_int; ) -> c_int;
} }
#[cfg(target_os = "windows")]
extern "C" {
pub fn pcap_getevent(p: *mut PcapDev) -> windows::Win32::Foundation::HANDLE;
}

View File

@ -16,13 +16,12 @@
use std::{ use std::{
ffi::{CStr, CString}, ffi::{CStr, CString},
ptr, slice, ptr, slice,
time::Duration,
}; };
pub mod error; pub mod error;
mod ffi; mod ffi;
pub use packets; pub use packets;
#[cfg(target_os = "linux")]
pub mod stream;
pub mod consts { pub mod consts {
pub use super::ffi::{ pub use super::ffi::{
@ -102,37 +101,33 @@ impl std::iter::Iterator for PcapDevIterator {
} }
} }
pub trait State {} #[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub trait Activated: State {} pub enum State {
pub trait NotListening: Activated {} Disabled,
pub trait Listening: Activated {} Activated,
pub trait Disabled: State {} Listening,
}
pub enum DevActivated {} impl State {
impl State for DevActivated {} fn is_activated(&self) -> bool {
impl Activated for DevActivated {} match self {
impl NotListening for DevActivated {} Self::Disabled => false,
Self::Activated | Self::Listening => true,
pub enum DevDisabled {} }
impl State for DevDisabled {} }
impl Disabled for DevDisabled {} }
pub enum DevListening {}
impl State for DevListening {}
impl Activated for DevListening {}
impl Listening for DevListening {}
pub struct BpfProgram {} pub struct BpfProgram {}
pub struct Interface<T: State> { pub struct Interface {
dev_name: CString, dev_name: CString,
dev: *mut ffi::PcapDev, dev: *mut ffi::PcapDev,
marker: std::marker::PhantomData<T>,
absorbed: bool, absorbed: bool,
nonblocking: bool, nonblocking: bool,
state: State,
} }
impl<T: State> Drop for Interface<T> { impl Drop for Interface {
fn drop(&mut self) { fn drop(&mut self) {
if !self.absorbed { if !self.absorbed {
unsafe { ffi::pcap_close(self.dev) }; unsafe { ffi::pcap_close(self.dev) };
@ -140,11 +135,21 @@ impl<T: State> Drop for Interface<T> {
} }
} }
unsafe impl<T: State> Send for Interface<T> {} unsafe impl Send for Interface {}
unsafe impl<T: State> Sync for Interface<T> {} unsafe impl Sync for Interface {}
impl<T: State> Interface<T> { struct ListenHandler<'a, F>
pub fn new(name: &str) -> error::Result<Interface<DevDisabled>> { where
F: FnMut(&Interface, packets::EthernetPkt) -> error::Result<bool>,
{
packet_handler: F,
break_on_fail: bool,
fail_error: Option<error::Error>,
interface: &'a Interface,
}
impl Interface {
pub fn new(name: &str) -> error::Result<Interface> {
let mut errbuf = [0i8; ffi::PCAP_ERRBUF_SIZE]; let mut errbuf = [0i8; ffi::PCAP_ERRBUF_SIZE];
let dev_name = CString::new(name)?; let dev_name = CString::new(name)?;
@ -154,12 +159,12 @@ impl<T: State> Interface<T> {
Err(&errbuf)?; Err(&errbuf)?;
} }
Ok(Interface::<DevDisabled> { Ok(Interface {
dev_name, dev_name,
dev, dev,
marker: std::marker::PhantomData,
absorbed: false, absorbed: false,
nonblocking: false, nonblocking: false,
state: State::Disabled,
}) })
} }
@ -200,10 +205,15 @@ impl<T: State> Interface<T> {
pub fn name(&self) -> &str { pub fn name(&self) -> &str {
std::str::from_utf8(self.dev_name.as_bytes()).unwrap() std::str::from_utf8(self.dev_name.as_bytes()).unwrap()
} }
pub fn set_promisc(&mut self, promisc: bool) -> error::Result<()> {
if self.state != State::Disabled {
return Err(error::Error::IncorrectDeviceState(
State::Disabled,
self.state,
));
} }
impl<T: Disabled> Interface<T> {
pub fn set_promisc(&mut self, promisc: bool) -> error::Result<()> {
if unsafe { ffi::pcap_set_promisc(self.dev, i32::from(promisc)) } != 0 { if unsafe { ffi::pcap_set_promisc(self.dev, i32::from(promisc)) } != 0 {
Err(unsafe { ffi::pcap_geterr(self.dev) })?; Err(unsafe { ffi::pcap_geterr(self.dev) })?;
} }
@ -212,6 +222,13 @@ impl<T: Disabled> Interface<T> {
} }
pub fn set_buffer_size(&mut self, bufsize: i32) -> error::Result<()> { pub fn set_buffer_size(&mut self, bufsize: i32) -> error::Result<()> {
if self.state != State::Disabled {
return Err(error::Error::IncorrectDeviceState(
State::Disabled,
self.state,
));
}
if unsafe { ffi::pcap_set_buffer_size(self.dev, bufsize) } != 0 { if unsafe { ffi::pcap_set_buffer_size(self.dev, bufsize) } != 0 {
Err(unsafe { ffi::pcap_geterr(self.dev) })?; Err(unsafe { ffi::pcap_geterr(self.dev) })?;
} }
@ -220,6 +237,13 @@ impl<T: Disabled> Interface<T> {
} }
pub fn set_timeout(&mut self, timeout: i32) -> error::Result<()> { pub fn set_timeout(&mut self, timeout: i32) -> error::Result<()> {
if self.state != State::Disabled {
return Err(error::Error::IncorrectDeviceState(
State::Disabled,
self.state,
));
}
if unsafe { ffi::pcap_set_timeout(self.dev, timeout) } != 0 { if unsafe { ffi::pcap_set_timeout(self.dev, timeout) } != 0 {
Err(unsafe { ffi::pcap_geterr(self.dev) })?; Err(unsafe { ffi::pcap_geterr(self.dev) })?;
} }
@ -227,26 +251,33 @@ impl<T: Disabled> Interface<T> {
Ok(()) Ok(())
} }
pub fn activate(mut self) -> error::Result<Interface<DevActivated>> { pub fn activate(&mut self) -> error::Result<()> {
if self.state != State::Disabled {
return Err(error::Error::IncorrectDeviceState(
State::Disabled,
self.state,
));
}
if unsafe { ffi::pcap_activate(self.dev) } != 0 { if unsafe { ffi::pcap_activate(self.dev) } != 0 {
Err(unsafe { ffi::pcap_geterr(self.dev) })?; Err(unsafe { ffi::pcap_geterr(self.dev) })?;
} }
self.absorbed = true; self.absorbed = true;
self.state = State::Activated;
Ok(Interface::<DevActivated> { Ok(())
dev_name: self.dev_name.clone(),
dev: self.dev,
marker: std::marker::PhantomData,
absorbed: false,
nonblocking: self.nonblocking,
})
}
} }
impl<T: Activated> Interface<T> { pub fn datalink(&self) -> error::Result<i32> {
pub fn datalink(&self) -> i32 { if !self.state.is_activated() {
unsafe { ffi::pcap_datalink(self.dev) } return Err(error::Error::IncorrectDeviceState(
State::Activated,
self.state,
));
}
Ok(unsafe { ffi::pcap_datalink(self.dev) })
} }
pub fn set_filter( pub fn set_filter(
@ -255,6 +286,13 @@ impl<T: Activated> Interface<T> {
optimize: bool, optimize: bool,
mask: Option<u32>, mask: Option<u32>,
) -> error::Result<Box<ffi::BpfProgram>> { ) -> error::Result<Box<ffi::BpfProgram>> {
if !self.state.is_activated() {
return Err(error::Error::IncorrectDeviceState(
State::Activated,
self.state,
));
}
let mut bpf = ffi::BpfProgram { let mut bpf = ffi::BpfProgram {
bf_len: 0, bf_len: 0,
bpf_insn: ptr::null(), bpf_insn: ptr::null(),
@ -290,6 +328,13 @@ impl<T: Activated> Interface<T> {
} }
pub fn sendpacket(&self, packet: packets::EthernetPkt) -> error::Result<()> { pub fn sendpacket(&self, packet: packets::EthernetPkt) -> error::Result<()> {
if !self.state.is_activated() {
return Err(error::Error::IncorrectDeviceState(
State::Activated,
self.state,
));
}
if unsafe { if unsafe {
ffi::pcap_sendpacket( ffi::pcap_sendpacket(
self.dev, self.dev,
@ -304,12 +349,37 @@ impl<T: Activated> Interface<T> {
Ok(()) Ok(())
} }
pub fn next_packet(&mut self) -> error::Result<packets::EthernetPacket> { pub fn next_packet(&self) -> error::Result<packets::EthernetPacket> {
if !self.state.is_activated() {
return Err(error::Error::IncorrectDeviceState(
State::Activated,
self.state,
));
}
let mut header: *mut ffi::PktHeader = ptr::null_mut(); let mut header: *mut ffi::PktHeader = ptr::null_mut();
let mut data: *mut libc::c_char = ptr::null_mut(); let mut data: *mut libc::c_char = ptr::null_mut();
if unsafe { ffi::pcap_next_ex(self.dev, &mut header as *mut _, &mut data as *mut _) < 1 } { let res =
return unsafe { Err(ffi::pcap_geterr(self.dev))? }; unsafe { ffi::pcap_next_ex(self.dev, &mut header as *mut _, &mut data as *mut _) };
match res {
1 => {} // no problems
0 => {
// timeout
return Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"pcap timeout",
))
.map_err(Into::into);
}
-1 => {
// actual error
return unsafe { Err(ffi::pcap_geterr(self.dev)).map_err(Into::into) };
}
_ => {
panic!("Unrecognized value returned from pcap_next_ex");
}
} }
let rdata = unsafe { slice::from_raw_parts(data as *mut u8, (*header).caplen as usize) }; let rdata = unsafe { slice::from_raw_parts(data as *mut u8, (*header).caplen as usize) };
@ -319,34 +389,29 @@ impl<T: Activated> Interface<T> {
Ok(packets::EthernetPkt { data: rdata }.to_owned()) Ok(packets::EthernetPkt { data: rdata }.to_owned())
} }
}
struct ListenHandler<'a, F>
where
F: FnMut(&Interface<DevListening>, packets::EthernetPkt) -> error::Result<bool>,
{
packet_handler: F,
break_on_fail: bool,
fail_error: Option<error::Error>,
interface: &'a Interface<DevListening>,
}
impl<T: NotListening> Interface<T> {
pub fn listen<F>( pub fn listen<F>(
&self, &self,
packet_handler: F, packet_handler: F,
break_on_fail: bool, break_on_fail: bool,
packet_count: i32, packet_count: i32,
) -> (Option<error::Error>, i32) ) -> error::Result<(Option<error::Error>, i32)>
where where
F: FnMut(&Interface<DevListening>, packets::EthernetPkt) -> error::Result<bool>, F: FnMut(&Interface, packets::EthernetPkt) -> error::Result<bool>,
{ {
if self.state == State::Listening {
return Err(error::Error::IncorrectDeviceState(
State::Activated,
self.state,
));
}
unsafe extern "C" fn cback<F>( unsafe extern "C" fn cback<F>(
user: *mut libc::c_void, user: *mut libc::c_void,
header: *const ffi::PktHeader, header: *const ffi::PktHeader,
data: *const u8, data: *const u8,
) where ) where
F: FnMut(&Interface<DevListening>, packets::EthernetPkt) -> error::Result<bool>, F: FnMut(&Interface, packets::EthernetPkt) -> error::Result<bool>,
{ {
let info = &mut *(user as *mut ListenHandler<F>); let info = &mut *(user as *mut ListenHandler<F>);
@ -380,9 +445,9 @@ impl<T: NotListening> Interface<T> {
let interface = Interface { let interface = Interface {
dev_name: self.dev_name.clone(), dev_name: self.dev_name.clone(),
dev: self.dev, dev: self.dev,
marker: std::marker::PhantomData,
absorbed: true, absorbed: true,
nonblocking: self.nonblocking, nonblocking: self.nonblocking,
state: State::Listening,
}; };
let mut info = ListenHandler::<F> { let mut info = ListenHandler::<F> {
@ -401,20 +466,102 @@ impl<T: NotListening> Interface<T> {
) )
}; };
(info.fail_error, count) Ok((info.fail_error, count))
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "windows")]
pub fn stream(mut self) -> error::Result<stream::InterfaceStream<DevActivated>> { pub fn get_wait_ready_callback(&self) -> WaitHandle {
self.set_non_blocking(true)?; let handle = unsafe { ffi::pcap_getevent(self.dev) };
WaitHandle(handle)
}
Ok(stream::InterfaceStream { #[cfg(not(target_os = "windows"))]
inner: tokio::io::unix::AsyncFd::with_interest( pub fn get_wait_ready_callback(&self) -> WaitHandle {
stream::InternalInterfaceStream::<DevActivated>::new(unsafe { let fd = unsafe { ffi::pcap_get_selectable_fd(self.dev) };
std::mem::transmute(self) WaitHandle(fd)
})?, }
tokio::io::Interest::READABLE,
)?, pub fn wait_ready(&self, timeout: Option<Duration>) -> error::Result<()> {
}) self.get_wait_ready_callback().wait(timeout)
}
}
#[cfg(windows)]
#[derive(Clone)]
pub struct WaitHandle(windows::Win32::Foundation::HANDLE);
#[cfg(unix)]
#[derive(Clone)]
pub struct WaitHandle(libc::c_int);
unsafe impl Send for WaitHandle {}
unsafe impl Sync for WaitHandle {}
impl WaitHandle {
#[cfg(windows)]
pub fn wait(&self, timeout: Option<Duration>) -> error::Result<()> {
use windows::Win32::System::Threading::{WaitForSingleObject, INFINITE};
let timeout = timeout
.map(|t| (t.as_millis() & 0xFFFFFFFF) as u32)
.unwrap_or(50);
unsafe {
if WaitForSingleObject(self.0, timeout).0 != 0 {
Err(std::io::Error::last_os_error()).map_err(Into::into)
} else {
Ok(())
}
}
}
#[cfg(unix)]
pub fn wait(&self, timeout: Option<Duration>) -> error::Result<()> {
unsafe {
use std::mem::MaybeUninit;
let mut readfds = {
let mut readfds = MaybeUninit::<libc::fd_set>::uninit();
libc::FD_ZERO(readfds.as_mut_ptr());
libc::FD_SET(self.0, readfds.as_mut_ptr());
readfds.assume_init()
};
let mut writefds = {
let mut writefds = MaybeUninit::<libc::fd_set>::uninit();
libc::FD_ZERO(writefds.as_mut_ptr());
libc::FD_SET(self.0, writefds.as_mut_ptr());
writefds.assume_init()
};
let mut exceptfds = {
let mut exceptfds = MaybeUninit::<libc::fd_set>::uninit();
libc::FD_ZERO(exceptfds.as_mut_ptr());
exceptfds.assume_init()
};
let mut c_timeout = libc::timeval {
tv_sec: 0,
tv_usec: 50_000,
};
if let Some(t) = timeout {
c_timeout.tv_sec = t.as_secs() as libc::time_t;
c_timeout.tv_usec = (t.as_micros() % 1_000_000) as libc::suseconds_t;
}
let res = libc::select(
1,
&mut readfds,
&mut writefds,
&mut exceptfds,
&mut c_timeout as *mut _,
);
if res == -1 {
Err(std::io::Error::last_os_error()).map_err(Into::into)
} else {
Ok(())
}
}
} }
} }

View File

@ -1,343 +0,0 @@
// Copyright (C) 2023 Andrew Rioux
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
use std::{
collections::HashMap,
os::fd::{AsRawFd, RawFd},
pin::Pin,
task::{self, Poll},
};
use futures::{ready, StreamExt};
use tokio::io::unix::AsyncFd;
use tokio_stream::StreamMap;
use super::{
error, ffi, packets, Activated, DevActivated, DevDisabled, Disabled, Interface, NotListening,
PcapDevIterator, State,
};
pub(crate) struct InternalInterfaceStream<T: Activated> {
interface: Interface<T>,
fd: RawFd,
}
impl<T: Activated> InternalInterfaceStream<T> {
pub(crate) fn new(interface: Interface<T>) -> error::Result<InternalInterfaceStream<T>> {
let fd = unsafe { ffi::pcap_get_selectable_fd(interface.dev) };
if fd == -1 {
return Err(error::Error::InvalidPcapFd);
}
Ok(Self { interface, fd })
}
}
impl<T: Activated> AsRawFd for InternalInterfaceStream<T> {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
pub struct InterfaceStream<T: Activated> {
pub(crate) inner: AsyncFd<InternalInterfaceStream<T>>,
}
impl<T: Activated> InterfaceStream<T> {
pub fn sendpacket(&mut self, packet: packets::EthernetPkt) -> error::Result<()> {
self.inner.get_mut().interface.sendpacket(packet)
}
pub fn set_filter(
&mut self,
filter: &str,
optimize: bool,
mask: Option<u32>,
) -> error::Result<Box<ffi::BpfProgram>> {
self.inner
.get_mut()
.interface
.set_filter(filter, optimize, mask)
}
}
impl<T: Activated> Unpin for InterfaceStream<T> {}
impl<T: Activated> futures::Stream for InterfaceStream<T> {
type Item = error::Result<packets::EthernetPacket>;
fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
let stream = Pin::into_inner(self);
loop {
let mut guard = ready!(stream.inner.poll_read_ready_mut(cx))?;
match guard.try_io(|inner| match inner.get_mut().interface.next_packet() {
Ok(p) => Ok(Ok(p)),
Err(e) => Ok(Err(e)),
}) {
Ok(result) => {
return Poll::Ready(Some(result?));
}
Err(_would_block) => continue,
}
}
}
}
pub fn new_aggregate_interface_filtered<F>(
crash: bool,
mut f: F,
) -> error::Result<AggregateInterface<DevDisabled>>
where
F: FnMut(&str) -> bool,
{
let interfaces = if crash {
PcapDevIterator::new()?
.filter(|s| (f)(s))
.map(|if_name| {
let new_name = if_name.clone();
Interface::<DevDisabled>::new(&if_name)
.map(|interface| (if_name, interface))
.map_err(|e| e.add_ifname(&new_name))
})
.collect::<error::Result<HashMap<_, _>>>()?
} else {
PcapDevIterator::new()?
.filter(|s| (f)(s))
.filter_map(|if_name| {
let new_name = if_name.clone();
Interface::<DevDisabled>::new(&if_name)
.map(|interface| (if_name, interface))
.ok()
.or_else(|| {
println!("{} failed to create device", new_name);
None
})
})
.collect::<HashMap<_, _>>()
};
Ok(AggregateInterface { interfaces, crash })
}
pub fn new_aggregate_interface(crash: bool) -> error::Result<AggregateInterface<DevDisabled>> {
new_aggregate_interface_filtered(crash, |_| true)
}
pub struct AggregateInterface<T: State> {
interfaces: HashMap<String, Interface<T>>,
crash: bool,
}
impl<T: State> AggregateInterface<T> {
pub fn set_non_blocking(&mut self, nonblocking: bool) -> error::Result<()> {
for (n, i) in self.interfaces.iter_mut() {
i.set_non_blocking(nonblocking)
.map_err(|e| e.add_ifname(n))?;
}
Ok(())
}
pub fn lookupnets(&self) -> error::Result<HashMap<&str, (u32, u32)>> {
self.interfaces
.iter()
.map(|(name, interface)| {
interface
.lookupnet()
.map(|net| (&**name, net))
.map_err(|e| e.add_ifname(&name))
})
.collect::<error::Result<_>>()
}
pub fn get_ifnames(&self) -> Vec<&str> {
self.interfaces.keys().map(|n| &**n).collect::<_>()
}
}
impl<T: Disabled> AggregateInterface<T> {
pub fn set_promisc(&mut self, promisc: bool) -> error::Result<()> {
for (n, i) in self.interfaces.iter_mut() {
i.set_promisc(promisc).map_err(|e| e.add_ifname(n))?;
}
Ok(())
}
pub fn set_buffer_size(&mut self, bufsize: i32) -> error::Result<()> {
for (n, i) in self.interfaces.iter_mut() {
i.set_buffer_size(bufsize).map_err(|e| e.add_ifname(n))?;
}
Ok(())
}
pub fn set_timeout(&mut self, timeout: i32) -> error::Result<()> {
for (n, i) in self.interfaces.iter_mut() {
i.set_timeout(timeout).map_err(|e| e.add_ifname(n))?;
}
Ok(())
}
pub fn activate(self) -> error::Result<AggregateInterface<DevActivated>> {
Ok(AggregateInterface {
interfaces: if self.crash {
self.interfaces
.into_iter()
.map(|(name, interface)| {
let new_name = name.clone();
interface
.activate()
.map(|interface| (name, interface))
.map_err(|e| e.add_ifname(&new_name))
})
.collect::<error::Result<_>>()?
} else {
self.interfaces
.into_iter()
.filter_map(|(name, interface)| {
let name_clone = name.clone();
interface
.activate()
.map(|interface| (name, interface))
.ok()
.or_else(|| {
println!("{} failed to activate", name_clone);
None
})
})
.collect::<_>()
},
crash: self.crash,
})
}
}
impl<T: Activated> AggregateInterface<T> {
pub fn datalinks(&self) -> HashMap<&str, i32> {
self.interfaces
.iter()
.map(|(name, interface)| (&**name, interface.datalink()))
.collect::<_>()
}
pub fn prune<F>(&mut self, mut f: F)
where
F: FnMut(&str, &mut Interface<T>) -> bool,
{
let to_prune = self
.interfaces
.iter_mut()
.filter_map(|(k, v)| if (f)(k, v) { Some(k.clone()) } else { None })
.collect::<Vec<_>>();
for name in to_prune {
self.interfaces.remove(&name);
}
}
pub fn set_filter(
&mut self,
filter: &str,
optimize: bool,
mask: Option<u32>,
) -> error::Result<HashMap<&str, Box<ffi::BpfProgram>>> {
if self.crash {
self.interfaces
.iter_mut()
.map(|(name, interface)| {
interface
.set_filter(filter, optimize, mask)
.map(|bpf| (&**name, bpf))
.map_err(|e| e.add_ifname(&name))
})
.collect::<error::Result<_>>()
} else {
Ok(self
.interfaces
.iter_mut()
.filter_map(|(name, interface)| {
let name_clone = name.clone();
interface
.set_filter(filter, optimize, mask)
.map(|bpf| (&**name, bpf))
.ok()
.or_else(|| {
println!("{} failed to set filter", name_clone);
None
})
})
.collect::<_>())
}
}
pub fn sendpacket(&self, ifname: &str, packet: packets::EthernetPkt) -> error::Result<()> {
if let Some(interface) = self.interfaces.get(ifname) {
interface
.sendpacket(packet)
.map_err(|e| e.add_ifname(ifname))?;
}
Ok(())
}
}
impl<T: NotListening> AggregateInterface<T> {
pub fn stream(self) -> error::Result<AggregateInterfaceStream<DevActivated>> {
Ok(AggregateInterfaceStream {
streams: self
.interfaces
.into_iter()
.map(|(ifname, interface)| {
let new_name = ifname.clone();
interface
.stream()
.map(|stream| (ifname, stream))
.map_err(|e| e.add_ifname(&new_name))
})
.collect::<error::Result<_>>()?,
})
}
}
pub struct AggregateInterfaceStream<T: Activated> {
streams: StreamMap<String, InterfaceStream<T>>,
}
impl<T: Activated> AggregateInterfaceStream<T> {
pub fn get_ifnames(&self) -> Vec<&str> {
self.streams.keys().map(|n| &**n).collect::<_>()
}
pub fn sendpacket(&mut self, ifname: &str, packet: packets::EthernetPkt) -> error::Result<()> {
if let Some(interface) = self.streams.values_mut().find(|interface| {
interface.inner.get_ref().interface.dev_name.as_bytes() == ifname.as_bytes()
}) {
interface.sendpacket(packet)?;
}
Ok(())
}
}
impl<T: Activated> futures::Stream for AggregateInterfaceStream<T> {
type Item = (String, error::Result<packets::EthernetPacket>);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
self.streams.poll_next_unpin(cx)
}
}

View File

@ -5,4 +5,21 @@ edition = "2021"
publish = false publish = false
[dependencies] [dependencies]
pcap-sys = { path = "../pcap-sys" } hyper = { version = "1.6.0", features = ["client", "http1", "http2"] }
smoltcp = { version = "0.12.0", default-features = false, features = ["async", "log", "medium-ethernet", "proto-ipv4", "proto-ipv4-fragmentation", "socket-raw", "socket-tcp", "std"] }
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["fs", "io-std", "io-util", "net", "process", "rt", "sync", "tokio-macros"] }
async-trait = "0.1.86"
tracing = "0.1.41"
rand = "0.9.0"
pin-project = "1.1.9"
hyper-util = { version = "0.1.10", features = ["client", "client-legacy", "http1", "http2", "service", "tokio"] }
hyper-rustls = { version = "0.27.5", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] }
tower-service = "0.3.3"
futures = "0.3.31"
simple_logger = "5.0.0"
pcap-sys = { version = "0.1.0", path = "../pcap-sys" }
sparse-actions = { version = "2.0.0", path = "../sparse-actions" }
packets = { version = "0.1.0", path = "../packets" }
nl-sys = { version = "0.1.0", path = "../nl-sys" }

View File

@ -0,0 +1,29 @@
use std::net::Ipv4Addr;
use crate::error;
#[derive(Debug)]
pub struct BeaconRoute {
pub network: (Ipv4Addr, u8),
pub gateway: (Ipv4Addr, u8),
pub interface_index: usize,
}
pub struct BeaconNetworkingInfo {
pub routes: Vec<BeaconRoute>,
pub interfaces: Vec<BeaconInterface>,
}
#[derive(Debug)]
pub struct BeaconInterface {
pub name: Vec<u8>,
pub mtu: u16,
pub mac_addr: [u8; 6],
}
#[async_trait::async_trait]
pub trait BeaconAdapter {
fn interface_name_from_interface(interface: &BeaconInterface) -> Vec<u8>;
fn networking_info(&self) -> Result<BeaconNetworkingInfo, error::BeaconError>;
}

View File

@ -0,0 +1,337 @@
use std::{
future::Future,
net::Ipv4Addr,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use futures::ready;
use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet},
socket::tcp::{RecvError, SendError, Socket, SocketBuffer, State},
time::Instant,
wire::{EthernetAddress, IpCidr, Ipv4Address},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
task::{spawn, spawn_blocking, JoinHandle},
};
use sparse_actions::payload_types::Parameters;
use crate::{adapter, error};
pub struct NetInterfaceHandle {
net: Arc<Mutex<(SocketSet<'static>, crate::socket::RawSocket, Interface)>>,
tcp_handle: SocketHandle,
background_process: JoinHandle<()>,
}
impl Drop for NetInterfaceHandle {
fn drop(&mut self) {
self.background_process.abort();
}
}
impl AsyncRead for NetInterfaceHandle {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
let Ok(mut inner) = this.net.lock() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"mutex for tcp connection is poisoned",
)));
};
let (ref mut s_guard, _, _) = *inner;
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
let has_data = socket.can_recv();
while socket.can_recv() {
let buf = match socket.recv(|buf| (buf.len(), buf.to_vec())) {
Ok(v) => v,
Err(RecvError::InvalidState) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NetworkDown,
"received InvalidState from smoltcp",
)));
}
Err(RecvError::Finished) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"tried reading from finished connection",
)));
}
};
tbuf.put_slice(&buf);
}
socket.register_recv_waker(cx.waker());
if has_data {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl AsyncWrite for NetInterfaceHandle {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
src: &[u8],
) -> Poll<std::io::Result<usize>> {
let this = self.get_mut();
let Ok(mut inner) = this.net.lock() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"mutex for tcp connection is poisoned",
)));
};
let (ref mut s_guard, _, _) = *inner;
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
socket.register_send_waker(cx.waker());
if socket.can_send() {
let to_send = socket.send_capacity().min(src.len());
match socket.send_slice(&src[..to_send]) {
Ok(s) => Poll::Ready(Ok(s)),
Err(SendError::InvalidState) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NetworkDown,
"received InvalidState from smoltcp",
)))
}
}
} else {
Poll::Pending
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
let Ok(mut inner) = this.net.lock() else {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"mutex for tcp connection is poisoned",
)));
};
let (ref mut s_guard, _, _) = *inner;
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
socket.close();
Poll::Ready(Ok(()))
}
}
pub async fn setup_network<T>(
adapter: &T,
parameters: &Parameters,
) -> Result<NetInterfaceHandle, error::BeaconError>
where
T: adapter::BeaconAdapter + Clone + Send + 'static,
{
let net_info = tokio::task::spawn_blocking({
let adapter = adapter.clone();
move || adapter.networking_info()
})
.await??;
let (interface, gateway_ip, mac_address, source_ip, netmask) = match unsafe {
parameters.source_ip.custom_networking.mode
} {
0 => {
// custom networking
let interface_name = unsafe {
&parameters.source_ip.custom_networking.interface
[..parameters.source_ip.custom_networking.interface_len as usize]
};
let interface = if interface_name.is_empty() {
let Some(default_route) = net_info.routes.iter().find(|r| r.network.1 == 0) else {
return Err(error::BeaconError::NoDefaultRoute);
};
&net_info.interfaces[default_route.interface_index]
} else {
net_info
.interfaces
.iter()
.find(|intf| intf.name == interface_name)
.ok_or(error::BeaconError::NoDefaultRoute)?
};
unsafe {
(
interface,
Ipv4Addr::new(
parameters.source_ip.custom_networking.gateway.a,
parameters.source_ip.custom_networking.gateway.b,
parameters.source_ip.custom_networking.gateway.c,
parameters.source_ip.custom_networking.gateway.d,
),
parameters.source_ip.custom_networking.source_mac.clone(),
Ipv4Addr::new(
parameters.source_ip.custom_networking.source_ip.a,
parameters.source_ip.custom_networking.source_ip.b,
parameters.source_ip.custom_networking.source_ip.c,
parameters.source_ip.custom_networking.source_ip.d,
),
parameters.source_ip.custom_networking.netmask as u8,
)
}
}
1 => {
// host networking
let Some(default_route) = net_info.routes.iter().find(|r| r.network.1 == 0) else {
return Err(error::BeaconError::NoDefaultRoute);
};
let default_route_if = &net_info.interfaces[default_route.interface_index];
(
default_route_if,
default_route.gateway.0,
default_route_if.mac_addr.clone(),
unsafe {
Ipv4Addr::new(
parameters.source_ip.use_host_networking.source_ip.a,
parameters.source_ip.use_host_networking.source_ip.b,
parameters.source_ip.use_host_networking.source_ip.c,
parameters.source_ip.use_host_networking.source_ip.d,
)
},
default_route.gateway.1,
)
}
_ => panic!("Corrupted parameters present!"),
};
let go_promisc = mac_address != [0, 0, 0, 0, 0, 0];
let mac_address = Some(mac_address)
.filter(|smac| smac != &[0, 0, 0, 0, 0, 0])
.unwrap_or(interface.mac_addr);
let local_port = 49152 + rand::random::<u16>() % 16384;
let mut device = crate::socket::RawSocket::new::<T>(interface, go_promisc, local_port)?;
let mut config = Config::new(EthernetAddress(mac_address).into());
config.random_seed = rand::random();
let mut iface = Interface::new(config, &mut device, Instant::now());
iface.update_ip_addrs(|addrs| {
addrs
.push(IpCidr::new(source_ip.into(), netmask))
.expect("could not add new IP address");
});
iface
.routes_mut()
.add_default_ipv4_route(gateway_ip.into())
.expect("did not expect route table to be full");
let tcp_rx_buffer = SocketBuffer::new(vec![0; 8192]);
let tcp_tx_buffer = SocketBuffer::new(vec![0; 8192]);
let tcp_socket = Socket::new(tcp_rx_buffer, tcp_tx_buffer);
let mut sockets = SocketSet::new(vec![]);
let tcp_handle = sockets.add(tcp_socket);
let mut active = false;
let ready_wait = device.get_ready_wait_callback();
let destination = (
Ipv4Address::new(
parameters.destination_ip.a,
parameters.destination_ip.b,
parameters.destination_ip.c,
parameters.destination_ip.d,
),
8080, //parameters.destination_port,
);
while !active {
let timestamp = Instant::now();
iface.poll(timestamp, &mut device, &mut sockets);
let cx = iface.context();
let socket = sockets.get_mut::<Socket>(tcp_handle);
if !socket.is_active() {
socket.connect(cx, destination, local_port)?;
}
active = socket.is_active() && socket.state() == State::Established;
ready_wait.wait(iface.poll_delay(timestamp, &sockets).map(Into::into))?;
}
let net = Arc::new(Mutex::new((sockets, device, iface)));
let background_process = spawn({
let net = Arc::clone(&net);
async move {
loop {
let delay = {
let Ok(mut guard) = net.lock() else {
continue;
};
let (ref mut s_guard, ref mut d_guard, ref mut i_guard) = *guard;
let timestamp = Instant::now();
i_guard.poll(timestamp, d_guard, s_guard);
i_guard.poll_delay(timestamp, s_guard)
};
let _ = ready_wait.wait(delay.map(Into::into));
}
}
});
Ok(NetInterfaceHandle {
net,
tcp_handle,
background_process,
})
}
pub async fn perform_callback<T>(
adapter: &T,
parameters: &Parameters,
) -> Result<(), error::BeaconError>
where
T: adapter::BeaconAdapter + Clone + Send + 'static,
{
println!("Attempting net connection...");
let mut net_handle = setup_network(adapter, parameters).await?;
println!("Got connection!");
let mut buffer = vec![0u8; 4096];
net_handle.write(&*b"Hello there\n").await?;
while let Ok(v) = net_handle.read(&mut buffer).await {
println!("Received {v} bytes: {:?}", &buffer[..v]);
net_handle.write(&buffer[..v]).await?;
}
println!("Finishing connection");
Ok(())
}

View File

@ -0,0 +1,19 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum BeaconError {
#[error("io error")]
Io(#[from] std::io::Error),
#[error("pcap error")]
Pcap(#[from] pcap_sys::error::Error),
#[error("utf8 decoding error")]
Utf8(#[from] std::str::Utf8Error),
#[error("task join error")]
Join(#[from] tokio::task::JoinError),
#[error("could not find default route")]
NoDefaultRoute,
#[error("connection error")]
Connect(#[from] smoltcp::socket::tcp::ConnectError),
#[error("netlink error")]
Nl(#[from] nl_sys::error::Error),
}

View File

@ -1 +1,16 @@
pub fn run_beacon_step() {} use sparse_actions::payload_types::Parameters;
pub mod adapter;
mod callback;
pub mod error;
mod socket;
pub use error::BeaconError;
pub async fn run_beacon_step<A>(host_adapter: A, params: Parameters) -> Result<(), BeaconError>
where
A: adapter::BeaconAdapter + Clone + Send + 'static,
{
callback::perform_callback(&host_adapter, &params).await?;
Ok(())
}

113
sparse-beacon/src/main.rs Normal file
View File

@ -0,0 +1,113 @@
use std::{io::SeekFrom, net::Ipv4Addr};
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use nl_sys::netlink;
use sparse_actions::payload_types::{Parameters, XOR_KEY};
use sparse_beacon::{
adapter::{BeaconAdapter, BeaconInterface, BeaconNetworkingInfo, BeaconRoute},
error,
};
#[derive(Clone)]
struct LinuxAdapter;
#[async_trait::async_trait]
impl BeaconAdapter for LinuxAdapter {
fn interface_name_from_interface(interface: &BeaconInterface) -> Vec<u8> {
interface.name.clone()
}
fn networking_info(&self) -> Result<BeaconNetworkingInfo, error::BeaconError> {
let nlsock = netlink::Socket::new()?;
let routes = nlsock.get_routes()?;
let links = nlsock.get_links()?;
let links_vec = links.iter().collect::<Vec<_>>();
Ok(BeaconNetworkingInfo {
routes: routes
.iter()
.filter_map(|r| {
let dst = r.dst()?;
let dst4: Ipv4Addr = (&dst).try_into().ok()?;
let next_hop = r.nexthop(0)?;
let gateway = next_hop.gateway()?;
let gateway4: Ipv4Addr = (&gateway).try_into().ok()?;
let gateway_int = u32::from(gateway4);
let src_cidr = routes.iter().find_map(|r| {
let dst = r.dst()?;
let dst4: Ipv4Addr = (&dst).try_into().ok()?;
if dst.cidrlen() == 0 {
return None;
}
let mask = (0xFFFFFFFFu32.overflowing_shr(32 - dst.cidrlen()))
.0
.overflowing_shl(32 - dst.cidrlen())
.0;
if (mask & u32::from(dst4)) == (mask & gateway_int) {
Some(dst.cidrlen())
} else {
None
}
})?;
Some(BeaconRoute {
network: (dst4, dst.cidrlen() as u8),
gateway: (gateway4, src_cidr as u8),
interface_index: links_vec
.iter()
.position(|l| l.ifindex() == next_hop.ifindex())?,
})
})
.collect(),
interfaces: links
.iter()
.filter_map(|l| {
let mac_addr = l.addr().hw_address();
Some(BeaconInterface {
name: l.name().as_bytes().to_owned(),
mtu: (l.mtu() & 0xFFFF) as u16,
mac_addr: mac_addr.try_into().ok()?,
})
})
.collect(),
})
}
}
#[tokio::main]
async fn main() -> Result<(), sparse_beacon::BeaconError> {
let installer = std::env::args()
.skip(1)
.next()
.expect("Could not get a reference to a sparse installer");
let mut installer_file = tokio::fs::OpenOptions::new()
.read(true)
.open(installer)
.await?;
let parameters_size = std::mem::size_of::<Parameters>() as i64;
installer_file.seek(SeekFrom::End(-parameters_size)).await?;
let mut parameters_buffer = Vec::with_capacity(parameters_size as usize);
installer_file.read_to_end(&mut parameters_buffer).await?;
for b in parameters_buffer.iter_mut() {
*b = *b ^ (XOR_KEY as u8);
}
let parameters: Parameters =
unsafe { std::mem::transmute(*(parameters_buffer.as_ptr() as *const Parameters)) };
sparse_beacon::run_beacon_step(LinuxAdapter, parameters).await?;
Ok(())
}

132
sparse-beacon/src/socket.rs Normal file
View File

@ -0,0 +1,132 @@
use smoltcp::phy::{self, Device, DeviceCapabilities, Medium};
use pcap_sys::Interface;
use crate::{adapter, error};
struct SocketInner {
lower: Interface,
}
pub struct RawSocket {
inner: SocketInner,
mtu: usize,
}
impl RawSocket {
pub fn new<T: adapter::BeaconAdapter>(
a_interface: &adapter::BeaconInterface,
promisc: bool,
port: u16,
) -> Result<Self, error::BeaconError> {
let name_raw = T::interface_name_from_interface(&a_interface);
let name = std::str::from_utf8(&name_raw)?;
let mut lower = Interface::new(name)?;
let mtu = a_interface.mtu as usize + if cfg!(unix) { 14 } else { 0 };
dbg!(promisc);
lower.set_promisc(promisc)?;
lower.set_buffer_size(mtu as i32)?;
lower.set_non_blocking(true)?;
lower.set_buffer_size(8192)?;
lower.set_timeout(10)?;
lower.activate()?;
lower.set_filter(&format!("arp or (inbound and tcp port {port})"), true, None)?;
Ok(Self {
inner: SocketInner { lower },
mtu,
})
}
pub fn get_ready_wait_callback(&self) -> pcap_sys::WaitHandle {
self.inner.lower.get_wait_ready_callback()
}
}
impl Device for RawSocket {
type RxToken<'a>
= RxToken
where
Self: 'a;
type TxToken<'a>
= TxToken<'a>
where
Self: 'a;
fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.max_transmission_unit = self.mtu;
caps.medium = Medium::Ethernet;
caps
}
fn receive(
&mut self,
_timestamp: smoltcp::time::Instant,
) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
match self.inner.lower.next_packet() {
Ok(p) => {
let rx = RxToken {
buffer: p.pkt().raw().to_vec(),
};
let tx = TxToken { inner: &self.inner };
Some((rx, tx))
}
Err(pcap_sys::error::Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
None
}
Err(e) => {
panic!("{}", e);
}
}
}
fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
Some(TxToken { inner: &self.inner })
}
}
pub struct TxToken<'a> {
inner: &'a SocketInner,
}
impl phy::TxToken for TxToken<'_> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut buffer = vec![0; len];
let result = f(&mut buffer);
let packet = packets::EthernetPacket::from_raw(buffer);
match self.inner.lower.sendpacket(packet.pkt()) {
Ok(_) => {}
Err(pcap_sys::error::Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
println!("Failed to send due to non blocking mode");
}
Err(err) => panic!("{}", err),
}
drop(packet);
result
}
}
pub struct RxToken {
buffer: Vec<u8>,
}
impl phy::RxToken for RxToken {
fn consume<R, F>(self, f: F) -> R
where
F: FnOnce(&[u8]) -> R,
{
f(&self.buffer[..])
}
}

View File

@ -3,14 +3,14 @@ use std::{
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use axum::routing::{get, post, Router}; use axum::routing::{Router, get, post};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
pub mod error; pub mod error;
pub struct BeaconListenerHandle { pub struct BeaconListenerHandle {
join_handle: JoinHandle<()> join_handle: JoinHandle<()>,
} }
impl BeaconListenerHandle { impl BeaconListenerHandle {
@ -34,7 +34,10 @@ impl std::ops::Deref for BeaconListenerMap {
} }
} }
pub async fn start_all_listeners(beacon_listener_map: BeaconListenerMap, db: SqlitePool) -> Result<(), crate::error::Error> { pub async fn start_all_listeners(
beacon_listener_map: BeaconListenerMap,
db: SqlitePool,
) -> Result<(), crate::error::Error> {
let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener") let listener_ids = sqlx::query!("SELECT listener_id FROM beacon_listener")
.fetch_all(&db) .fetch_all(&db)
.await?; .await?;
@ -42,7 +45,12 @@ pub async fn start_all_listeners(beacon_listener_map: BeaconListenerMap, db: Sql
tracing::info!("Starting {} listener(s)...", listener_ids.len()); tracing::info!("Starting {} listener(s)...", listener_ids.len());
for listener in listener_ids { for listener in listener_ids {
start_listener(beacon_listener_map.clone(), listener.listener_id, db.clone()).await?; start_listener(
beacon_listener_map.clone(),
listener.listener_id,
db.clone(),
)
.await?;
} }
Ok(()) Ok(())
@ -50,7 +58,7 @@ pub async fn start_all_listeners(beacon_listener_map: BeaconListenerMap, db: Sql
#[derive(Clone)] #[derive(Clone)]
struct ListenerState { struct ListenerState {
db: SqlitePool db: SqlitePool,
} }
struct Listener { struct Listener {
@ -59,41 +67,59 @@ struct Listener {
public_ip: String, public_ip: String,
domain_name: String, domain_name: String,
certificate: Vec<u8>, certificate: Vec<u8>,
privkey: Vec<u8> privkey: Vec<u8>,
} }
pub async fn start_listener(beacon_listener_map: BeaconListenerMap, listener_id: i64, db: SqlitePool) -> Result<(), crate::error::Error> { pub async fn start_listener(
beacon_listener_map: BeaconListenerMap,
listener_id: i64,
db: SqlitePool,
) -> Result<(), crate::error::Error> {
{ {
let Ok(blm_handle) = beacon_listener_map.read() else { let Ok(blm_handle) = beacon_listener_map.read() else {
return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string())); return Err(crate::error::Error::Generic(
"Could not acquire write lock on beacon listener map".to_string(),
));
}; };
if blm_handle.get(&listener_id).is_some() { if blm_handle.get(&listener_id).is_some() {
return Err(crate::error::Error::Generic("Beacon listener already started".to_string())); return Err(crate::error::Error::Generic(
"Beacon listener already started".to_string(),
));
} }
} }
let listener = sqlx::query_as!(Listener, "SELECT * FROM beacon_listener WHERE listener_id = ?", listener_id) let listener = sqlx::query_as!(
Listener,
"SELECT * FROM beacon_listener WHERE listener_id = ?",
listener_id
)
.fetch_one(&db) .fetch_one(&db)
.await?; .await?;
let app: Router<()> = Router::new() let app: Router<()> = Router::new()
.route("/register_beacon", post(|| async { .route(
"/register_beacon",
post(|| async {
tracing::info!("Beacon attempting to register"); tracing::info!("Beacon attempting to register");
})) }),
.route("/test", get(|| async { )
.route(
"/test",
get(|| async {
tracing::info!("Hello"); tracing::info!("Hello");
"hi there" "hi there"
})) }),
.with_state(ListenerState { )
db .with_state(ListenerState { db });
});
let hidden_app = Router::new().nest("/hidden_sparse", app); let hidden_app = Router::new().nest("/hidden_sparse", app);
let keypair = match rustls::pki_types::PrivateKeyDer::try_from(listener.privkey.clone()) { let keypair = match rustls::pki_types::PrivateKeyDer::try_from(listener.privkey.clone()) {
Ok(pk) => pk, Ok(pk) => pk,
Err(e) => { Err(e) => {
return Err(crate::error::Error::Generic(format!("Could not parse private key: {e}"))); return Err(crate::error::Error::Generic(format!(
"Could not parse private key: {e}"
)));
} }
}; };
let cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone()); let cert = rustls::pki_types::CertificateDer::from(listener.certificate.clone());
@ -105,14 +131,17 @@ pub async fn start_listener(beacon_listener_map: BeaconListenerMap, listener_id:
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16)); let addr = std::net::SocketAddr::from(([0, 0, 0, 0], listener.port as u16));
tracing::debug!("Starting listener {}, {}, on port {}", listener_id, listener.domain_name, listener.port); tracing::debug!(
"Starting listener {}, {}, on port {}",
listener_id,
listener.domain_name,
listener.port
);
let join_handle = tokio::task::spawn(async move { let join_handle = tokio::task::spawn(async move {
let res = axum_server::tls_rustls::bind_rustls( let res = axum_server::tls_rustls::bind_rustls(
addr, addr,
axum_server::tls_rustls::RustlsConfig::from_config( axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(tls_config)),
Arc::new(tls_config)
)
) )
.serve(hidden_app.into_make_service()) .serve(hidden_app.into_make_service())
.await; .await;
@ -123,12 +152,12 @@ pub async fn start_listener(beacon_listener_map: BeaconListenerMap, listener_id:
}); });
let Ok(mut blm_handle) = beacon_listener_map.write() else { let Ok(mut blm_handle) = beacon_listener_map.write() else {
return Err(crate::error::Error::Generic("Could not acquire write lock on beacon listener map".to_string())); return Err(crate::error::Error::Generic(
"Could not acquire write lock on beacon listener map".to_string(),
));
}; };
blm_handle.insert(listener_id, BeaconListenerHandle { blm_handle.insert(listener_id, BeaconListenerHandle { join_handle });
join_handle
});
Ok(()) Ok(())
} }

View File

@ -0,0 +1,2 @@
-- Add migration script here
ALTER TABLE beacon_template ADD COLUMN source_interface blob DEFAULT '';

View File

@ -1,9 +1,9 @@
use leptos::{either::Either, prelude::*}; use leptos::{either::Either, prelude::*};
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title}; use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{ use leptos_router::{
components::{A, ParentRoute, Route, Router, Routes}, components::{ParentRoute, Route, Router, Routes, A},
hooks::use_query_map, hooks::use_query_map,
path path,
}; };
use crate::users::User; use crate::users::User;
@ -29,7 +29,7 @@ pub async fn me() -> Result<Option<User>, ServerFnError> {
Ok(user.map(|user| User { Ok(user.map(|user| User {
user_id: user.user_id, user_id: user.user_id,
user_name: user.user_name user_name: user.user_name,
})) }))
} }
@ -73,10 +73,7 @@ pub fn App() -> impl IntoView {
let (user_res, set_user_res) = signal(None::<User>); let (user_res, set_user_res) = signal(None::<User>);
let user = Resource::new( let user = Resource::new(move || login.version().get(), |_| async { me().await });
move || login.version().get(),
|_| async { me().await }
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
Effect::new(move || { Effect::new(move || {
@ -143,7 +140,12 @@ pub fn App() -> impl IntoView {
#[component] #[component]
fn LoginPage(login: ServerAction<Login>) -> impl IntoView { fn LoginPage(login: ServerAction<Login>) -> impl IntoView {
let next = move || use_query_map().read().get("next").unwrap_or("/".to_string()); let next = move || {
use_query_map()
.read()
.get("next")
.unwrap_or("/".to_string())
};
view! { view! {
<main class="login"> <main class="login">

View File

@ -36,7 +36,7 @@ pub struct BeaconResources {
listeners: Resource<Result<Vec<listeners::PubListener>, ServerFnError>>, listeners: Resource<Result<Vec<listeners::PubListener>, ServerFnError>>,
categories: Resource<Result<Vec<categories::Category>, ServerFnError>>, categories: Resource<Result<Vec<categories::Category>, ServerFnError>>,
configs: Resource<Result<Vec<configs::BeaconConfig>, ServerFnError>>, configs: Resource<Result<Vec<configs::BeaconConfig>, ServerFnError>>,
templates: Resource<Result<Vec<templates::BeaconTemplate>, ServerFnError>> templates: Resource<Result<Vec<templates::BeaconTemplate>, ServerFnError>>,
} }
pub fn provide_beacon_resources() { pub fn provide_beacon_resources() {
@ -56,40 +56,48 @@ pub fn provide_beacon_resources() {
let remove_template = ServerAction::<templates::RemoveTemplate>::new(); let remove_template = ServerAction::<templates::RemoveTemplate>::new();
let listeners = Resource::new( let listeners = Resource::new(
move || ( move || {
(
user.get(), user.get(),
add_listener.version().get(), add_listener.version().get(),
remove_listener.version().get(), remove_listener.version().get(),
), )
|_| async { listeners::get_listeners().await } },
|_| async { listeners::get_listeners().await },
); );
let categories = Resource::new( let categories = Resource::new(
move || ( move || {
(
user.get(), user.get(),
add_category.version().get(), add_category.version().get(),
remove_category.version().get(), remove_category.version().get(),
rename_category.version().get(), rename_category.version().get(),
), )
|_| async { categories::get_categories().await } },
|_| async { categories::get_categories().await },
); );
let configs = Resource::new( let configs = Resource::new(
move || ( move || {
(
user.get(), user.get(),
add_beacon_config.version().get(), add_beacon_config.version().get(),
remove_beacon_config.version().get(), remove_beacon_config.version().get(),
), )
|_| async { configs::get_beacon_configs().await } },
|_| async { configs::get_beacon_configs().await },
); );
let templates = Resource::new( let templates = Resource::new(
move || ( move || {
(
user.get(), user.get(),
add_template.version().get(), add_template.version().get(),
remove_template.version().get() remove_template.version().get(),
), )
|_| async { templates::get_templates().await } },
|_| async { templates::get_templates().await },
); );
provide_context(BeaconResources { provide_context(BeaconResources {
@ -106,7 +114,7 @@ pub fn provide_beacon_resources() {
listeners, listeners,
categories, categories,
configs, configs,
templates templates,
}); });
} }
@ -142,7 +150,7 @@ enum SortMethod {
Listener, Listener,
Config, Config,
Category, Category,
Template Template,
} }
impl std::str::FromStr for SortMethod { impl std::str::FromStr for SortMethod {
@ -154,7 +162,7 @@ impl std::str::FromStr for SortMethod {
"Config" => Ok(Self::Config), "Config" => Ok(Self::Config),
"Category" => Ok(Self::Category), "Category" => Ok(Self::Category),
"Template" => Ok(Self::Template), "Template" => Ok(Self::Template),
&_ => Err(()) &_ => Err(()),
} }
} }
} }
@ -167,7 +175,8 @@ impl std::string::ToString for SortMethod {
SM::Config => "Config", SM::Config => "Config",
SM::Category => "Category", SM::Category => "Category",
SM::Template => "Template", SM::Template => "Template",
}.to_string() }
.to_string()
} }
} }

View File

@ -1,18 +1,14 @@
use leptos::{either::Either, prelude::*}; use leptos::{either::Either, prelude::*};
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
use { use {crate::db::user, leptos::server_fn::error::NoCustomError, sqlx::SqlitePool};
sqlx::SqlitePool,
leptos::server_fn::error::NoCustomError,
crate::db::user,
};
use super::BeaconResources; use super::BeaconResources;
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct Category { pub struct Category {
pub category_id: i64, pub category_id: i64,
pub category_name: String pub category_name: String,
} }
#[server] #[server]
@ -20,28 +16,26 @@ pub async fn get_categories() -> Result<Vec<Category>, ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
Ok( Ok(sqlx::query_as!(Category, "SELECT * FROM beacon_category")
sqlx::query_as!(
Category,
"SELECT * FROM beacon_category"
)
.fetch_all(&db) .fetch_all(&db)
.await? .await?)
)
} }
#[server] #[server]
pub async fn add_category(name: String) -> Result<(), ServerFnError> { pub async fn add_category(name: String) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -61,15 +55,14 @@ pub async fn remove_category(id: i64) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
sqlx::query!( sqlx::query!("DELETE FROM beacon_category WHERE category_id = ?", id)
"DELETE FROM beacon_category WHERE category_id = ?",
id
)
.execute(&db) .execute(&db)
.await?; .await?;
@ -81,7 +74,9 @@ pub async fn rename_category(id: i64, name: String) -> Result<(), ServerFnError>
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -99,7 +94,11 @@ pub async fn rename_category(id: i64, name: String) -> Result<(), ServerFnError>
#[component] #[component]
pub fn CategoriesView() -> impl IntoView { pub fn CategoriesView() -> impl IntoView {
let BeaconResources { add_category, categories, .. } = expect_context(); let BeaconResources {
add_category,
categories,
..
} = expect_context();
view! { view! {
<div class="categories"> <div class="categories">
@ -148,7 +147,11 @@ pub fn CategoriesView() -> impl IntoView {
#[component] #[component]
fn DisplayCategories(categories: Vec<Category>) -> impl IntoView { fn DisplayCategories(categories: Vec<Category>) -> impl IntoView {
let BeaconResources { remove_category, rename_category, .. } = expect_context(); let BeaconResources {
remove_category,
rename_category,
..
} = expect_context();
let (target_rename_id, set_target_rename_id) = signal(0); let (target_rename_id, set_target_rename_id) = signal(0);
let target_rename_name = RwSignal::new("".to_owned()); let target_rename_name = RwSignal::new("".to_owned());
@ -157,7 +160,8 @@ fn DisplayCategories(categories: Vec<Category>) -> impl IntoView {
let categories_view = categories let categories_view = categories
.iter() .iter()
.map(|category| view! { .map(|category| {
view! {
<li> <li>
{category.category_id} {category.category_id}
": " ": "
@ -191,21 +195,24 @@ fn DisplayCategories(categories: Vec<Category>) -> impl IntoView {
"delete" "delete"
</button> </button>
</li> </li>
}
}) })
.collect_view(); .collect_view();
Effect::watch( Effect::watch(
move || ( move || {
(
rename_category.version().get(), rename_category.version().get(),
rename_category.value().get(), rename_category.value().get(),
dialog_ref.get() dialog_ref.get(),
), )
},
move |(_, res, dialog_ref), _, _| { move |(_, res, dialog_ref), _, _| {
if let (Some(Ok(())), Some(dialog)) = (res, dialog_ref) { if let (Some(Ok(())), Some(dialog)) = (res, dialog_ref) {
let _ = dialog.close(); let _ = dialog.close();
} }
}, },
false false,
); );
view! { view! {

View File

@ -2,12 +2,10 @@ use leptos::{either::Either, prelude::*};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
use { use {
std::str::FromStr, crate::db::user,
sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool},
leptos::server_fn::error::NoCustomError, leptos::server_fn::error::NoCustomError,
sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool},
crate::db::user std::str::FromStr,
}; };
use super::BeaconResources; use super::BeaconResources;
@ -32,7 +30,7 @@ impl FromRow<'_, SqliteRow> for BeaconConfigTypes {
)), )),
"cron" => Ok(Self::CronSchedule( "cron" => Ok(Self::CronSchedule(
row.try_get("cron_schedule")?, row.try_get("cron_schedule")?,
row.try_get("cron_mode")? row.try_get("cron_mode")?,
)), )),
type_name => Err(sqlx::Error::TypeNotFound { type_name => Err(sqlx::Error::TypeNotFound {
type_name: type_name.to_string(), type_name: type_name.to_string(),
@ -55,7 +53,9 @@ pub async fn get_beacon_configs() -> Result<Vec<BeaconConfig>, ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -78,7 +78,9 @@ pub async fn add_beacon_config(
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -93,10 +95,12 @@ pub async fn add_beacon_config(
.await?; .await?;
Ok(()) Ok(())
}, }
"regular" => { "regular" => {
if regular_interval < 1 { if regular_interval < 1 {
return Err(ServerFnError::<NoCustomError>::ServerError("Invalid interval provided".to_owned())) return Err(ServerFnError::<NoCustomError>::ServerError(
"Invalid interval provided".to_owned(),
));
} }
sqlx::query!( sqlx::query!(
@ -108,10 +112,12 @@ pub async fn add_beacon_config(
.await?; .await?;
Ok(()) Ok(())
}, }
"random" => { "random" => {
if random_min_time < 1 || random_max_time < random_min_time { if random_min_time < 1 || random_max_time < random_min_time {
return Err(ServerFnError::<NoCustomError>::ServerError("Invalid random interval provided".to_owned())) return Err(ServerFnError::<NoCustomError>::ServerError(
"Invalid random interval provided".to_owned(),
));
} }
sqlx::query!( sqlx::query!(
@ -124,19 +130,21 @@ pub async fn add_beacon_config(
.await?; .await?;
Ok(()) Ok(())
}, }
"cron" => { "cron" => {
if let Err(e) = cron::Schedule::from_str(&cron_schedule) { if let Err(e) = cron::Schedule::from_str(&cron_schedule) {
return Err(ServerFnError::<NoCustomError>::ServerError(format!( return Err(ServerFnError::<NoCustomError>::ServerError(format!(
"Could not parse cron expression: {}", "Could not parse cron expression: {}",
e e
))) )));
} }
match &*cron_mode { match &*cron_mode {
"local" | "utc" => {}, "local" | "utc" => {}
_ => { _ => {
return Err(ServerFnError::<NoCustomError>::ServerError("Unrecognized timezone specifier for cron".to_string())) return Err(ServerFnError::<NoCustomError>::ServerError(
"Unrecognized timezone specifier for cron".to_string(),
))
} }
} }
@ -150,10 +158,10 @@ pub async fn add_beacon_config(
.await?; .await?;
Ok(()) Ok(())
},
_ => {
Err(ServerFnError::<NoCustomError>::ServerError("Invalid mode supplied".to_owned()))
} }
_ => Err(ServerFnError::<NoCustomError>::ServerError(
"Invalid mode supplied".to_owned(),
)),
} }
} }
@ -162,15 +170,14 @@ pub async fn remove_beacon_config(id: i64) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
sqlx::query!( sqlx::query!("DELETE FROM beacon_config WHERE config_id = ?", id)
"DELETE FROM beacon_config WHERE config_id = ?",
id
)
.execute(&db) .execute(&db)
.await?; .await?;
@ -179,7 +186,11 @@ pub async fn remove_beacon_config(id: i64) -> Result<(), ServerFnError> {
#[component] #[component]
pub fn ConfigsView() -> impl IntoView { pub fn ConfigsView() -> impl IntoView {
let BeaconResources { add_beacon_config, configs, .. } = expect_context(); let BeaconResources {
add_beacon_config,
configs,
..
} = expect_context();
view! { view! {
<div class="config"> <div class="config">
@ -260,11 +271,15 @@ pub fn ConfigsView() -> impl IntoView {
#[component] #[component]
fn DisplayConfigs(configs: Vec<BeaconConfig>) -> impl IntoView { fn DisplayConfigs(configs: Vec<BeaconConfig>) -> impl IntoView {
let BeaconResources { remove_beacon_config, .. } = expect_context(); let BeaconResources {
remove_beacon_config,
..
} = expect_context();
let configs_view = configs let configs_view = configs
.iter() .iter()
.map(|config| view! { .map(|config| {
view! {
<li> <li>
{config.config_id} {config.config_id}
": " ": "
@ -296,6 +311,7 @@ fn DisplayConfigs(configs: Vec<BeaconConfig>) -> impl IntoView {
"delete" "delete"
</button> </button>
</li> </li>
}
}) })
.collect_view(); .collect_view();

View File

@ -2,17 +2,14 @@
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use leptos::{either::Either, prelude::*}; use leptos::{either::Either, prelude::*};
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
use { use {
sqlx::SqlitePool, crate::db::user,
leptos::server_fn::error::NoCustomError, leptos::server_fn::error::NoCustomError,
rcgen::{generate_simple_self_signed, CertifiedKey}, rcgen::{generate_simple_self_signed, CertifiedKey},
sparse_handler::BeaconListenerMap, sparse_handler::BeaconListenerMap,
sqlx::SqlitePool,
crate::db::user,
}; };
use super::BeaconResources; use super::BeaconResources;
@ -22,7 +19,7 @@ struct DbListener {
listener_id: i64, listener_id: i64,
port: i64, port: i64,
public_ip: String, public_ip: String,
domain_name: String domain_name: String,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -31,7 +28,7 @@ pub struct PubListener {
pub port: i64, pub port: i64,
pub public_ip: String, pub public_ip: String,
pub domain_name: String, pub domain_name: String,
pub active: bool pub active: bool,
} }
#[server] #[server]
@ -39,7 +36,9 @@ pub async fn get_listeners() -> Result<Vec<PubListener>, ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -66,13 +65,16 @@ pub async fn get_listeners() -> Result<Vec<PubListener>, ServerFnError> {
active: beacon_handles_handle active: beacon_handles_handle
.get(&b.listener_id) .get(&b.listener_id)
.map(|h| !h.is_finished()) .map(|h| !h.is_finished())
.unwrap_or(false) .unwrap_or(false),
}) })
.collect()) .collect())
} }
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
pub fn generate_cert_from_keypair(kp: &rcgen::KeyPair, names: Vec<String>) -> Result<rcgen::Certificate, rcgen::Error> { pub fn generate_cert_from_keypair(
kp: &rcgen::KeyPair,
names: Vec<String>,
) -> Result<rcgen::Certificate, rcgen::Error> {
use rcgen::CertificateParams; use rcgen::CertificateParams;
let mut params = CertificateParams::new(names)?; let mut params = CertificateParams::new(names)?;
@ -83,25 +85,33 @@ pub fn generate_cert_from_keypair(kp: &rcgen::KeyPair, names: Vec<String>) -> Re
} }
#[server] #[server]
pub async fn add_listener(public_ip: String, port: i16, domain_name: String) -> Result<(), ServerFnError> { pub async fn add_listener(
public_ip: String,
port: i16,
domain_name: String,
) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
if public_ip.parse::<Ipv4Addr>().is_err() { if public_ip.parse::<Ipv4Addr>().is_err() {
return Err(ServerFnError::<NoCustomError>::ServerError("Unable to parse public IP address".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"Unable to parse public IP address".to_owned(),
));
} }
let subject_alt_names = vec![public_ip.to_string(), domain_name.clone()]; let subject_alt_names = vec![public_ip.to_string(), domain_name.clone()];
let (key_pair, cert) = tokio::task::spawn_blocking(|| { let (key_pair, cert) = tokio::task::spawn_blocking(|| {
rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256) rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).and_then(|keypair| {
.and_then(|keypair| generate_cert_from_keypair(&keypair, subject_alt_names).map(|cert| (keypair, cert))
generate_cert_from_keypair(&keypair, subject_alt_names).map(|cert| (keypair, cert))) })
}).await??; })
.await??;
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -128,20 +138,26 @@ pub async fn remove_listener(listener_id: i64) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
{ {
let blm = expect_context::<BeaconListenerMap>(); let blm = expect_context::<BeaconListenerMap>();
let Ok(mut blm_handle) = blm.write() else { let Ok(mut blm_handle) = blm.write() else {
return Err(ServerFnError::<NoCustomError>::ServerError("Failed to get write handle for beacon listener map".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"Failed to get write handle for beacon listener map".to_owned(),
));
}; };
if let Some(bl) = blm_handle.get_mut(&listener_id) { if let Some(bl) = blm_handle.get_mut(&listener_id) {
bl.abort(); bl.abort();
} else { } else {
return Err(ServerFnError::<NoCustomError>::ServerError("Failed to get write handle for beacon listener map".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"Failed to get write handle for beacon listener map".to_owned(),
));
} }
blm_handle.remove(&listener_id); blm_handle.remove(&listener_id);
@ -164,21 +180,23 @@ pub async fn start_listener(listener_id: i64) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
sparse_handler::start_listener( sparse_handler::start_listener(expect_context(), listener_id, expect_context()).await?;
expect_context(),
listener_id,
expect_context()
).await?;
Ok(()) Ok(())
} }
#[component] #[component]
pub fn ListenersView() -> impl IntoView { pub fn ListenersView() -> impl IntoView {
let super::BeaconResources { add_listener, listeners, .. } = expect_context::<super::BeaconResources>(); let super::BeaconResources {
add_listener,
listeners,
..
} = expect_context::<super::BeaconResources>();
view! { view! {
<div class="listeners"> <div class="listeners">
@ -230,7 +248,11 @@ pub fn ListenersView() -> impl IntoView {
#[component] #[component]
fn DisplayListeners(listeners: Vec<PubListener>) -> impl IntoView { fn DisplayListeners(listeners: Vec<PubListener>) -> impl IntoView {
let BeaconResources { listeners: listener_resource, remove_listener, .. } = expect_context::<BeaconResources>(); let BeaconResources {
listeners: listener_resource,
remove_listener,
..
} = expect_context::<BeaconResources>();
let (error_msg, set_error_msg) = signal(None); let (error_msg, set_error_msg) = signal(None);
let start_listener_action = Action::new(move |&id: &i64| async move { let start_listener_action = Action::new(move |&id: &i64| async move {
@ -246,7 +268,8 @@ fn DisplayListeners(listeners: Vec<PubListener>) -> impl IntoView {
let listeners_view = listeners let listeners_view = listeners
.iter() .iter()
.map(|listener| view! { .map(|listener| {
view! {
<li> <li>
{listener.listener_id} {listener.listener_id}
": " ": "
@ -286,6 +309,7 @@ fn DisplayListeners(listeners: Vec<PubListener>) -> impl IntoView {
"delete" "delete"
</button> </button>
</li> </li>
}
}) })
.collect_view(); .collect_view();

View File

@ -1,13 +1,11 @@
use leptos::{either::Either, prelude::*}; use leptos::{either::Either, prelude::*};
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
use { use {
std::net::Ipv4Addr, crate::db::user,
sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool},
leptos::server_fn::error::NoCustomError, leptos::server_fn::error::NoCustomError,
sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool},
crate::db::user std::net::Ipv4Addr,
}; };
use crate::beacons::BeaconResources; use crate::beacons::BeaconResources;
@ -15,7 +13,7 @@ use crate::beacons::BeaconResources;
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub enum BeaconSourceMode { pub enum BeaconSourceMode {
Host, Host,
Custom(i64, String) Custom(i64, String),
} }
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
@ -25,7 +23,7 @@ impl FromRow<'_, SqliteRow> for BeaconSourceMode {
"host" => Ok(Self::Host), "host" => Ok(Self::Host),
"custom" => Ok(Self::Custom( "custom" => Ok(Self::Custom(
row.try_get("source_netmask")?, row.try_get("source_netmask")?,
row.try_get("source_gateway")? row.try_get("source_gateway")?,
)), )),
type_name => Err(sqlx::Error::TypeNotFound { type_name => Err(sqlx::Error::TypeNotFound {
type_name: type_name.to_string(), type_name: type_name.to_string(),
@ -48,7 +46,7 @@ pub struct BeaconTemplate {
config_id: i64, config_id: i64,
listener_id: i64, listener_id: i64,
default_category: Option<i64> default_category: Option<i64>,
} }
cfg_if::cfg_if! { cfg_if::cfg_if! {
@ -72,7 +70,8 @@ pub async fn add_template(
source_mac: String, source_mac: String,
source_mode: String, source_mode: String,
source_netmask: i64, source_netmask: i64,
source_gateway: String source_gateway: String,
source_interface: String,
) -> Result<(), ServerFnError> { ) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
@ -97,7 +96,11 @@ pub async fn add_template(
} }
let mac_parts = source_mac.split(":").collect::<Vec<_>>(); let mac_parts = source_mac.split(":").collect::<Vec<_>>();
if mac_parts.len() != 6 || mac_parts.iter().any(|p| p.len() != 2 || u8::from_str_radix(p, 16).is_err()) { if mac_parts.len() != 6
|| mac_parts
.iter()
.any(|p| p.len() != 2 || u8::from_str_radix(p, 16).is_err())
{
srverr!("Source MAC address is formatted incorrectly"); srverr!("Source MAC address is formatted incorrectly");
} }
@ -119,7 +122,7 @@ pub async fn add_template(
srverr!("Could not parse private key: {e}"); srverr!("Could not parse private key: {e}");
} }
}, },
&rcgen::PKCS_ECDSA_P256_SHA256 &rcgen::PKCS_ECDSA_P256_SHA256,
)?; )?;
let ca_params = CertificateParams::from_ca_cert_der(&(*listener.certificate).into())?; let ca_params = CertificateParams::from_ca_cert_der(&(*listener.certificate).into())?;
let ca_cert = ca_params.self_signed(&keypair)?; let ca_cert = ca_params.self_signed(&keypair)?;
@ -131,6 +134,8 @@ pub async fn add_template(
let client_key_der = client_key.serialize_der(); let client_key_der = client_key.serialize_der();
let client_cert_der = client_cert.der().to_vec(); let client_cert_der = client_cert.der().to_vec();
let interface = Some(source_interface).filter(|s| !s.is_empty());
match &*source_mode { match &*source_mode {
"host" => { "host" => {
let source_mac = Some(source_mac).filter(|mac| mac != "00:00:00:00:00:00"); let source_mac = Some(source_mac).filter(|mac| mac != "00:00:00:00:00:00");
@ -138,9 +143,11 @@ pub async fn add_template(
sqlx::query!( sqlx::query!(
r"INSERT INTO beacon_template r"INSERT INTO beacon_template
(template_name, operating_system, config_id, listener_id, source_ip, source_mac, source_mode, default_category, client_key, client_cert) (template_name, operating_system, config_id, listener_id, source_ip,
source_mac, source_mode, default_category, client_key, client_cert,
source_interface)
VALUES VALUES
(?, ?, ?, ?, ?, ?, 'host', ?, ?, ?)", (?, ?, ?, ?, ?, ?, 'host', ?, ?, ?, ?)",
template_name, template_name,
operating_system, operating_system,
config_id, config_id,
@ -149,22 +156,25 @@ pub async fn add_template(
source_mac, source_mac,
default_category, default_category,
client_key_der, client_key_der,
client_cert_der client_cert_der,
interface
) )
.execute(&db) .execute(&db)
.await?; .await?;
Ok(()) Ok(())
}, }
"custom" => { "custom" => {
let source_mac = Some(source_mac).filter(|mac| mac != "00:00:00:00:00:00"); let source_mac = Some(source_mac).filter(|mac| mac != "00:00:00:00:00:00");
let default_category = Some(default_category).filter(|dc| *dc != 0); let default_category = Some(default_category).filter(|dc| *dc != 0);
sqlx::query!( sqlx::query!(
r"INSERT INTO beacon_template r"INSERT INTO beacon_template
(template_name, operating_system, config_id, listener_id, source_ip, source_mac, source_mode, source_netmask, source_gateway, default_category, client_key, client_cert) (template_name, operating_system, config_id, listener_id, source_ip,
source_mac, source_mode, source_netmask, source_gateway, default_category,
client_key, client_cert, source_interface)
VALUES VALUES
(?, ?, ?, ?, ?, ?, 'host', ?, ?, ?, ?, ?)", (?, ?, ?, ?, ?, ?, 'host', ?, ?, ?, ?, ?, ?)",
template_name, template_name,
operating_system, operating_system,
config_id, config_id,
@ -175,13 +185,14 @@ pub async fn add_template(
source_gateway, source_gateway,
default_category, default_category,
client_key_der, client_key_der,
client_cert_der client_cert_der,
interface
) )
.execute(&db) .execute(&db)
.await?; .await?;
Ok(()) Ok(())
}, }
_other => { _other => {
srverr!("Invalid type of source mode provided"); srverr!("Invalid type of source mode provided");
} }
@ -198,7 +209,10 @@ pub async fn remove_template(template_id: i64) -> Result<(), ServerFnError> {
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
sqlx::query!("DELETE FROM beacon_template WHERE template_id = ?", template_id) sqlx::query!(
"DELETE FROM beacon_template WHERE template_id = ?",
template_id
)
.execute(&db) .execute(&db)
.await?; .await?;
@ -210,7 +224,9 @@ pub async fn get_templates() -> Result<Vec<BeaconTemplate>, ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
@ -222,7 +238,13 @@ pub async fn get_templates() -> Result<Vec<BeaconTemplate>, ServerFnError> {
#[component] #[component]
pub fn TemplatesView() -> impl IntoView { pub fn TemplatesView() -> impl IntoView {
let BeaconResources { configs, listeners, categories, templates, .. } = expect_context(); let BeaconResources {
configs,
listeners,
categories,
templates,
..
} = expect_context();
view! { view! {
<div class="templates"> <div class="templates">
@ -379,6 +401,8 @@ pub fn AddTemplateForm(
<input class="mode-custom" type="number" name="source_netmask" value="24"/> <input class="mode-custom" type="number" name="source_netmask" value="24"/>
<label class="mode-custom">"Network gateway"</label> <label class="mode-custom">"Network gateway"</label>
<input class="mode-custom" name="source_gateway"/> <input class="mode-custom" name="source_gateway"/>
<label class="mode-custom">"Network interface name"</label>
<input class="mode-custom" name="source_interface"/>
<div></div> <div></div>
<input type="submit" value="Submit" /> <input type="submit" value="Submit" />
</fieldset> </fieldset>
@ -391,9 +415,11 @@ pub fn DisplayTemplates(
configs: Vec<super::configs::BeaconConfig>, configs: Vec<super::configs::BeaconConfig>,
listeners: Vec<super::listeners::PubListener>, listeners: Vec<super::listeners::PubListener>,
categories: Vec<super::categories::Category>, categories: Vec<super::categories::Category>,
templates: Vec<BeaconTemplate> templates: Vec<BeaconTemplate>,
) -> impl IntoView { ) -> impl IntoView {
let BeaconResources { remove_template, .. } = expect_context(); let BeaconResources {
remove_template, ..
} = expect_context();
let templates_view = templates let templates_view = templates
.iter() .iter()

View File

@ -9,7 +9,7 @@ pub async fn handle_user_command(user_command: UC, db: SqlitePool) -> anyhow::Re
match user_command { match user_command {
UC::List {} => list_users(db).await, UC::List {} => list_users(db).await,
UC::Create { user_name } => create_user(db, user_name).await, UC::Create { user_name } => create_user(db, user_name).await,
UC::ResetPassword { user_id } => reset_password(&db, user_id).await UC::ResetPassword { user_id } => reset_password(&db, user_id).await,
} }
} }
@ -51,7 +51,7 @@ async fn create_user(db: SqlitePool, name: String) -> anyhow::Result<ExitCode> {
async fn reset_password<'a, E>(db: E, id: i16) -> anyhow::Result<ExitCode> async fn reset_password<'a, E>(db: E, id: i16) -> anyhow::Result<ExitCode>
where where
E: sqlx::SqliteExecutor<'a> E: sqlx::SqliteExecutor<'a>,
{ {
let password = get_password()?; let password = get_password()?;

View File

@ -1,7 +1,13 @@
use leptos::prelude::ServerFnError;
use leptos::{prelude::expect_context, server_fn::error::NoCustomError}; use leptos::{prelude::expect_context, server_fn::error::NoCustomError};
use leptos_axum::{extract, ResponseOptions}; use leptos_axum::{extract, ResponseOptions};
use leptos::prelude::ServerFnError; use pbkdf2::{
use pbkdf2::{Pbkdf2, password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, rand_core::{OsRng, RngCore}, SaltString}}; password_hash::{
rand_core::{OsRng, RngCore},
PasswordHash, PasswordHasher, PasswordVerifier, SaltString,
},
Pbkdf2,
};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use crate::error::Error; use crate::error::Error;
@ -11,7 +17,7 @@ pub struct User {
pub user_id: i64, pub user_id: i64,
pub user_name: String, pub user_name: String,
password_hash: String, password_hash: String,
pub last_active: Option<i64> pub last_active: Option<i64>,
} }
impl std::fmt::Debug for User { impl std::fmt::Debug for User {
@ -29,12 +35,13 @@ async fn hash_password(pass: &[u8]) -> Result<String, Error> {
let pass = pass.to_owned(); let pass = pass.to_owned();
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
move || move || {
Pbkdf2.hash_password( Pbkdf2
&*pass, .hash_password(&*pass, &salt)
&salt, .map(|hash| hash.serialize().as_str().to_string())
).map(|hash| hash.serialize().as_str().to_string()) }
}).await??) })
.await??)
} }
async fn verify_password(pass: &str, hash: &str) -> Result<bool, Error> { async fn verify_password(pass: &str, hash: &str) -> Result<bool, Error> {
@ -42,22 +49,23 @@ async fn verify_password(pass: &str, hash: &str) -> Result<bool, Error> {
let pass = pass.to_owned(); let pass = pass.to_owned();
let hash = hash.to_owned(); let hash = hash.to_owned();
move || move || {
PasswordHash::new(&*hash) PasswordHash::new(&*hash)
.map(|parsed| Pbkdf2.verify_password( .map(|parsed| Pbkdf2.verify_password(&pass.as_bytes(), &parsed).is_ok())
&pass.as_bytes(), }
&parsed })
).is_ok()) .await??)
}).await??)
} }
pub async fn reset_password<'a, E>(pool: E, id: i16, password: String) -> Result<(), crate::error::Error> pub async fn reset_password<'a, E>(
pool: E,
id: i16,
password: String,
) -> Result<(), crate::error::Error>
where where
E: sqlx::SqliteExecutor<'a> E: sqlx::SqliteExecutor<'a>,
{ {
let password_string = hash_password( let password_string = hash_password(password.as_bytes()).await?;
password.as_bytes()
).await?;
sqlx::query!( sqlx::query!(
"UPDATE users SET password_hash = ? WHERE user_id = ?", "UPDATE users SET password_hash = ? WHERE user_id = ?",
@ -70,16 +78,18 @@ where
Ok(()) Ok(())
} }
pub async fn create_user<'a, E>(acq: E, name: String, password: String) -> Result<(), crate::error::Error> pub async fn create_user<'a, E>(
acq: E,
name: String,
password: String,
) -> Result<(), crate::error::Error>
where where
E: sqlx::Acquire<'a, Database = sqlx::Sqlite> E: sqlx::Acquire<'a, Database = sqlx::Sqlite>,
{ {
let mut tx = acq.begin().await?; let mut tx = acq.begin().await?;
let previous_user_check = sqlx::query_scalar!( let previous_user_check =
"SELECT COUNT(*) FROM users WHERE user_name = ?", sqlx::query_scalar!("SELECT COUNT(*) FROM users WHERE user_name = ?", name)
name
)
.fetch_one(&mut *tx) .fetch_one(&mut *tx)
.await?; .await?;
@ -108,37 +118,30 @@ const SESSION_ID_KEY: &'static str = "session_id";
const SESSION_AGE: i64 = 30 * 60; const SESSION_AGE: i64 = 30 * 60;
pub async fn create_auth_session(username: String, password: String) -> Result<(), ServerFnError> { pub async fn create_auth_session(username: String, password: String) -> Result<(), ServerFnError> {
use axum_extra::extract::cookie::{Cookie, SameSite};
use axum::http::{header, HeaderValue}; use axum::http::{header, HeaderValue};
use axum_extra::extract::cookie::{Cookie, SameSite};
let db = expect_context::<SqlitePool>(); let db = expect_context::<SqlitePool>();
let resp = expect_context::<ResponseOptions>(); let resp = expect_context::<ResponseOptions>();
let user: Option<User> = sqlx::query_as!( let user: Option<User> =
User, sqlx::query_as!(User, "SELECT * FROM users WHERE user_name = ?", username)
"SELECT * FROM users WHERE user_name = ?",
username
)
.fetch_optional(&db) .fetch_optional(&db)
.await?; .await?;
let Some(user) = user else { let Some(user) = user else {
return Err(ServerFnError::<NoCustomError>::ServerError("Invalid credentials".to_string())); return Err(ServerFnError::<NoCustomError>::ServerError(
"Invalid credentials".to_string(),
));
}; };
let good_hash = verify_password( let good_hash = verify_password(&password, &user.password_hash).await?;
&password,
&user.password_hash
).await?;
if good_hash { if good_hash {
let now = chrono::Utc::now().timestamp(); let now = chrono::Utc::now().timestamp();
let expires = now + SESSION_AGE; let expires = now + SESSION_AGE;
sqlx::query!( sqlx::query!("UPDATE users SET last_active = ?", now)
"UPDATE users SET last_active = ?",
now
)
.execute(&db) .execute(&db)
.await?; .await?;
@ -146,7 +149,8 @@ pub async fn create_auth_session(username: String, password: String) -> Result<(
let mut key = [0u8; 32]; let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key); OsRng.fill_bytes(&mut key);
hex::encode(&key[..]) hex::encode(&key[..])
}).await?; })
.await?;
sqlx::query!( sqlx::query!(
"INSERT INTO sessions (session_id, user_id, expires) VALUES (?, ?, ?)", "INSERT INTO sessions (session_id, user_id, expires) VALUES (?, ?, ?)",
@ -168,7 +172,9 @@ pub async fn create_auth_session(username: String, password: String) -> Result<(
Ok(()) Ok(())
} else { } else {
Err(ServerFnError::<NoCustomError>::ServerError("Invalid credentials".to_string())) Err(ServerFnError::<NoCustomError>::ServerError(
"Invalid credentials".to_string(),
))
} }
} }
@ -184,10 +190,7 @@ pub async fn destroy_auth_session() -> Result<(), ServerFnError> {
let session_id = cookie.value(); let session_id = cookie.value();
sqlx::query!( sqlx::query!("DELETE FROM sessions WHERE session_id = ?", session_id)
"DELETE FROM sessions WHERE session_id = ?",
session_id
)
.execute(&db) .execute(&db)
.await?; .await?;
@ -241,10 +244,7 @@ pub async fn get_auth_session() -> Result<Option<User>, ServerFnError> {
.await?; .await?;
} }
sqlx::query!( sqlx::query!("DELETE FROM sessions WHERE expires < ?", now)
"DELETE FROM sessions WHERE expires < ?",
now
)
.execute(&db) .execute(&db)
.await?; .await?;

View File

@ -1,36 +1,37 @@
#[cfg(feature = "ssr")]
mod cli;
#[cfg(feature = "ssr")]
mod webserver;
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
mod beacons; mod beacons;
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
pub mod users; mod cli;
pub mod error;
pub mod db; pub mod db;
pub mod error;
#[cfg(feature = "ssr")]
pub mod users;
#[cfg(feature = "ssr")]
mod webserver;
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<std::process::ExitCode> { async fn main() -> anyhow::Result<std::process::ExitCode> {
use std::{path::PathBuf, process::ExitCode, str::FromStr}; use std::{path::PathBuf, process::ExitCode, str::FromStr};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
use structopt::StructOpt; use structopt::StructOpt;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use sqlx::sqlite::{SqlitePool, SqliteConnectOptions};
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
.unwrap_or_else(|_| format!("{}=debug,sparse_handler=debug", env!("CARGO_CRATE_NAME")).into()), format!("{}=debug,sparse_handler=debug", env!("CARGO_CRATE_NAME")).into()
}),
) )
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
let options = cli::Options::from_args(); let options = cli::Options::from_args();
let db_location = options.db_location.clone() let db_location = options
.db_location
.clone()
.or(std::env::var("DATABASE_URL") .or(std::env::var("DATABASE_URL")
.map(|p| p.replace("sqlite://", "")) .map(|p| p.replace("sqlite://", ""))
.map(PathBuf::from) .map(PathBuf::from)
@ -45,7 +46,7 @@ async fn main() -> anyhow::Result<std::process::ExitCode> {
if !options.init_ok { if !options.init_ok {
tracing::error!("Database doesn't exist, and initialization not allowed!"); tracing::error!("Database doesn't exist, and initialization not allowed!");
tracing::error!("{:?}", e); tracing::error!("{:?}", e);
return Ok(ExitCode::FAILURE) return Ok(ExitCode::FAILURE);
} }
tracing::info!("Database doesn't exist, readying initialization"); tracing::info!("Database doesn't exist, readying initialization");
@ -53,14 +54,13 @@ async fn main() -> anyhow::Result<std::process::ExitCode> {
let pool = SqlitePool::connect_with( let pool = SqlitePool::connect_with(
SqliteConnectOptions::from_str(&format!("sqlite://{}", db_location.to_string_lossy()))? SqliteConnectOptions::from_str(&format!("sqlite://{}", db_location.to_string_lossy()))?
.create_if_missing(options.init_ok) .create_if_missing(options.init_ok),
).await?; )
.await?;
tracing::info!("Running database migrations..."); tracing::info!("Running database migrations...");
sqlx::migrate!() sqlx::migrate!().run(&pool).await?;
.run(&pool)
.await?;
tracing::info!("Done running database migrations!"); tracing::info!("Done running database migrations!");
@ -69,12 +69,8 @@ async fn main() -> anyhow::Result<std::process::ExitCode> {
tracing::info!("Performing requested action, acting as web server"); tracing::info!("Performing requested action, acting as web server");
webserver::serve_web(management_address, pool).await webserver::serve_web(management_address, pool).await
} }
Some(cli::Command::ExtractPubKey { }) => { Some(cli::Command::ExtractPubKey {}) => Ok(ExitCode::SUCCESS),
Ok(ExitCode::SUCCESS) Some(cli::Command::User { command }) => cli::user::handle_user_command(command, pool).await,
}
Some(cli::Command::User { command }) => {
cli::user::handle_user_command(command, pool).await
}
None => { None => {
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};

View File

@ -1,21 +1,25 @@
use chrono::{DateTime, offset::Utc}; use chrono::{offset::Utc, DateTime};
use leptos::prelude::*; use leptos::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
use { use {crate::db::user, leptos::server_fn::error::NoCustomError, sqlx::SqlitePool};
sqlx::SqlitePool,
leptos::server_fn::error::NoCustomError,
crate::db::user
};
pub fn format_delta(time: chrono::TimeDelta) -> String { pub fn format_delta(time: chrono::TimeDelta) -> String {
let seconds = time.num_seconds(); let seconds = time.num_seconds();
match seconds { match seconds {
0..=59 => format!("{} second{} ago", seconds, if seconds == 1 {""} else {"s"}), 0..=59 => format!(
"{} second{} ago",
seconds,
if seconds == 1 { "" } else { "s" }
),
60..=3599 => { 60..=3599 => {
let minutes = seconds / 60; let minutes = seconds / 60;
format!("{} minute{} ago", minutes, if minutes == 1 {""} else {"s"}) format!(
"{} minute{} ago",
minutes,
if minutes == 1 { "" } else { "s" }
)
} }
3600..=86399 => { 3600..=86399 => {
let hours = seconds / 3600; let hours = seconds / 3600;
@ -44,14 +48,14 @@ impl std::cmp::PartialEq for User {
pub struct DbUser { pub struct DbUser {
user_id: i64, user_id: i64,
user_name: String, user_name: String,
last_active: Option<i64> last_active: Option<i64>,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct PubUser { pub struct PubUser {
user_id: i64, user_id: i64,
user_name: String, user_name: String,
last_active: Option<DateTime<Utc>> last_active: Option<DateTime<Utc>>,
} }
#[server] #[server]
@ -59,15 +63,14 @@ async fn delete_user(user_id: i64) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let pool = expect_context::<SqlitePool>(); let pool = expect_context::<SqlitePool>();
sqlx::query!( sqlx::query!("DELETE FROM users WHERE user_id = ?", user_id)
"DELETE FROM users WHERE user_id = ?",
user_id
)
.execute(&pool) .execute(&pool)
.await?; .await?;
@ -79,7 +82,9 @@ async fn reset_password(user_id: i64, password: String) -> Result<(), ServerFnEr
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let pool = expect_context::<SqlitePool>(); let pool = expect_context::<SqlitePool>();
@ -96,15 +101,21 @@ pub fn RenderUser(refresh_user_list: Action<(), ()>, user: PubUser) -> impl Into
#[cfg_attr(feature = "ssr", allow(unused_variables))] #[cfg_attr(feature = "ssr", allow(unused_variables))]
let UseIntervalReturn { counter, .. } = use_interval(1000); let UseIntervalReturn { counter, .. } = use_interval(1000);
#[cfg_attr(feature = "ssr", allow(unused_variables))] #[cfg_attr(feature = "ssr", allow(unused_variables))]
let (time_ago, set_time_ago) = signal(user.last_active.map(|active| format_delta(Utc::now() - active))); let (time_ago, set_time_ago) = signal(
user.last_active
.map(|active| format_delta(Utc::now() - active)),
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
Effect::watch( Effect::watch(
move || counter.get(), move || counter.get(),
move |_, _, _| { move |_, _, _| {
set_time_ago(user.last_active.map(|active| format_delta(Utc::now() - active))); set_time_ago(
user.last_active
.map(|active| format_delta(Utc::now() - active)),
);
}, },
false false,
); );
let dialog_ref = NodeRef::<leptos::html::Dialog>::new(); let dialog_ref = NodeRef::<leptos::html::Dialog>::new();
@ -220,23 +231,27 @@ async fn list_users() -> Result<Vec<PubUser>, ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
use futures::stream::StreamExt; use futures::stream::StreamExt;
let pool = expect_context::<SqlitePool>(); let pool = expect_context::<SqlitePool>();
let users = sqlx::query_as!( let users = sqlx::query_as!(DbUser, "SELECT user_id, user_name, last_active FROM users")
DbUser,
"SELECT user_id, user_name, last_active FROM users"
)
.fetch(&pool) .fetch(&pool)
.map(|user| user.map(|u| PubUser { .map(|user| {
user.map(|u| PubUser {
user_id: u.user_id, user_id: u.user_id,
user_name: u.user_name, user_name: u.user_name,
last_active: u.last_active.map(|ts| DateTime::from_timestamp(ts, 0)).flatten() last_active: u
})) .last_active
.map(|ts| DateTime::from_timestamp(ts, 0))
.flatten(),
})
})
.collect::<Vec<Result<_, _>>>() .collect::<Vec<Result<_, _>>>()
.await; .await;
@ -250,7 +265,9 @@ async fn add_user(name: String, password: String) -> Result<(), ServerFnError> {
let user = user::get_auth_session().await?; let user = user::get_auth_session().await?;
if user.is_none() { if user.is_none() {
return Err(ServerFnError::<NoCustomError>::ServerError("You are not signed in!".to_owned())); return Err(ServerFnError::<NoCustomError>::ServerError(
"You are not signed in!".to_owned(),
));
} }
let pool = expect_context::<SqlitePool>(); let pool = expect_context::<SqlitePool>();

View File

@ -1,9 +1,14 @@
use std::{net::SocketAddrV4, process::ExitCode}; use std::{net::SocketAddrV4, process::ExitCode};
use sqlx::sqlite::SqlitePool; use axum::{
use axum::{extract::{FromRef, Path, State}, response::IntoResponse, Router, routing::get}; extract::{FromRef, Path, State},
response::IntoResponse,
routing::get,
Router,
};
use leptos::prelude::*; use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes}; use leptos_axum::{generate_route_list, LeptosRoutes};
use sqlx::sqlite::SqlitePool;
use tokio::signal; use tokio::signal;
use sparse_server::app::*; use sparse_server::app::*;
@ -11,8 +16,10 @@ use sparse_server::app::*;
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
pub(crate) mod beacon_binaries { pub(crate) mod beacon_binaries {
pub const LINUX_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_LINUX")); pub const LINUX_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_LINUX"));
pub const FREEBSD_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_FREEBSD")); pub const FREEBSD_INSTALLER: &'static [u8] =
pub const WINDOWS_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_WINDOWS")); include_bytes!(std::env!("SPARSE_INSTALLER_FREEBSD"));
pub const WINDOWS_INSTALLER: &'static [u8] =
include_bytes!(std::env!("SPARSE_INSTALLER_WINDOWS"));
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -21,7 +28,11 @@ pub async fn get_installer(btype: &str) -> Result<Vec<u8>, crate::error::Error>
"linux" => "target/x86_64-unknown-linux-musl/debug/sparse-unix-installer", "linux" => "target/x86_64-unknown-linux-musl/debug/sparse-unix-installer",
"freebsd" => "target/x86_64-unknown-freebsd/debug/sparse-unix-installer", "freebsd" => "target/x86_64-unknown-freebsd/debug/sparse-unix-installer",
"windows" => "target/x86_64-pc-windows-gnu/debug/sparse-unix-installer", "windows" => "target/x86_64-pc-windows-gnu/debug/sparse-unix-installer",
other => return Err(crate::error::Error::Generic(format!("unknown beacon type: {other}"))), other => {
return Err(crate::error::Error::Generic(format!(
"unknown beacon type: {other}"
)))
}
}; };
Ok(tokio::fs::read(path).await?) Ok(tokio::fs::read(path).await?)
@ -32,20 +43,22 @@ pub async fn get_installer(btype: &str) -> Result<Vec<u8>, crate::error::Error>
"linux" => Ok(beacon_binaries::LINUX_INSTALLER.to_owned()), "linux" => Ok(beacon_binaries::LINUX_INSTALLER.to_owned()),
"windows" => Ok(beacon_binaries::WINDOWS_INSTALLER.to_owned()), "windows" => Ok(beacon_binaries::WINDOWS_INSTALLER.to_owned()),
"freebsd" => Ok(beacon_binaries::FREEBSD_INSTALLER.to_owned()), "freebsd" => Ok(beacon_binaries::FREEBSD_INSTALLER.to_owned()),
other => Err(crate::error::Error::Generic(format!("unknown beacon type: {other}"))) other => Err(crate::error::Error::Generic(format!(
"unknown beacon type: {other}"
))),
} }
} }
#[derive(FromRef, Clone, Debug)] #[derive(FromRef, Clone, Debug)]
pub struct AppState { pub struct AppState {
db: SqlitePool, db: SqlitePool,
leptos_options: leptos::config::LeptosOptions leptos_options: leptos::config::LeptosOptions,
} }
#[axum::debug_handler] #[axum::debug_handler]
pub async fn download_beacon_installer( pub async fn download_beacon_installer(
Path(template_id): Path<i64>, Path(template_id): Path<i64>,
State(db): State<AppState> State(db): State<AppState>,
) -> Result<impl IntoResponse, crate::error::Error> { ) -> Result<impl IntoResponse, crate::error::Error> {
use rand::{rngs::OsRng, TryRngCore}; use rand::{rngs::OsRng, TryRngCore};
use sparse_actions::payload_types::{Parameters_t, XOR_KEY}; use sparse_actions::payload_types::{Parameters_t, XOR_KEY};
@ -53,11 +66,13 @@ pub async fn download_beacon_installer(
let mut parameters_buffer = vec![0u8; std::mem::size_of::<Parameters_t>()]; let mut parameters_buffer = vec![0u8; std::mem::size_of::<Parameters_t>()];
let _ = OsRng.try_fill_bytes(&mut parameters_buffer); let _ = OsRng.try_fill_bytes(&mut parameters_buffer);
let parameters: &mut Parameters_t = unsafe { std::mem::transmute(parameters_buffer.as_mut_ptr()) }; let parameters: &mut Parameters_t =
unsafe { std::mem::transmute(parameters_buffer.as_mut_ptr()) };
let template = sqlx::query!( let template = sqlx::query!(
r"SELECT operating_system, source_ip, source_mac, source_mode, source_netmask, r"SELECT operating_system, source_ip, source_mac, source_mode, source_netmask,
source_gateway, port, public_ip, domain_name, certificate, client_cert, client_key source_gateway, port, public_ip, domain_name, certificate, client_cert, client_key,
source_interface
FROM beacon_template JOIN beacon_listener" FROM beacon_template JOIN beacon_listener"
) )
.fetch_one(&db.db) .fetch_one(&db.db)
@ -79,19 +94,29 @@ pub async fn download_beacon_installer(
.map(|by| u8::from_str_radix(by, 16)) .map(|by| u8::from_str_radix(by, 16))
.collect::<Result<Vec<u8>, _>>() .collect::<Result<Vec<u8>, _>>()
.map_err(|_| crate::error::Error::Generic("Could not parse source MAC address".to_string())) .map_err(|_| crate::error::Error::Generic("Could not parse source MAC address".to_string()))
.and_then( .and_then(|bytes| {
|bytes| bytes.try_into().map_err(|_| crate::error::Error::Generic("Could not parse source MAC address".to_string())) bytes.try_into().map_err(|_| {
)?; crate::error::Error::Generic("Could not parse source MAC address".to_string())
})
})?;
let src_octets = src_ip.octets(); let src_octets = src_ip.octets();
match (template.source_mode.as_deref(), template.source_netmask, template.source_gateway) { match (
template.source_mode.as_deref(),
template.source_netmask,
template.source_gateway,
) {
(Some("custom"), Some(nm), Some(ip)) => unsafe { (Some("custom"), Some(nm), Some(ip)) => unsafe {
let gateway = ip.parse::<std::net::Ipv4Addr>()?; let gateway = ip.parse::<std::net::Ipv4Addr>()?;
let gw_octets = gateway.octets(); let gw_octets = gateway.octets();
parameters.source_ip.custom_networking.mode = 0; parameters.source_ip.custom_networking.mode = 0;
parameters.source_ip.custom_networking.source_mac.copy_from_slice(&src_mac[..]);
parameters.source_ip.custom_networking.netmask = nm as u16; parameters.source_ip.custom_networking.netmask = nm as u16;
parameters
.source_ip
.custom_networking
.source_mac
.copy_from_slice(&src_mac[..]);
parameters.source_ip.custom_networking.source_ip.a = src_octets[0]; parameters.source_ip.custom_networking.source_ip.a = src_octets[0];
parameters.source_ip.custom_networking.source_ip.b = src_octets[1]; parameters.source_ip.custom_networking.source_ip.b = src_octets[1];
parameters.source_ip.custom_networking.source_ip.c = src_octets[2]; parameters.source_ip.custom_networking.source_ip.c = src_octets[2];
@ -100,17 +125,31 @@ pub async fn download_beacon_installer(
parameters.source_ip.custom_networking.gateway.b = gw_octets[1]; parameters.source_ip.custom_networking.gateway.b = gw_octets[1];
parameters.source_ip.custom_networking.gateway.c = gw_octets[2]; parameters.source_ip.custom_networking.gateway.c = gw_octets[2];
parameters.source_ip.custom_networking.gateway.d = gw_octets[3]; parameters.source_ip.custom_networking.gateway.d = gw_octets[3];
if let Some(intf) = &template.source_interface {
parameters.source_ip.custom_networking.interface[..intf.len()]
.copy_from_slice(&intf[..]);
parameters.source_ip.custom_networking.interface_len = intf.len() as u8;
} else {
parameters.source_ip.custom_networking.interface_len = 0;
} }
},
(Some("host"), _, _) => unsafe { (Some("host"), _, _) => unsafe {
parameters.source_ip.use_host_networking.mode = 1; parameters.source_ip.use_host_networking.mode = 1;
parameters.source_ip.use_host_networking.source_mac.copy_from_slice(&src_mac[..]); parameters
.source_ip
.use_host_networking
.source_mac
.copy_from_slice(&src_mac[..]);
parameters.source_ip.use_host_networking.source_ip.a = src_octets[0]; parameters.source_ip.use_host_networking.source_ip.a = src_octets[0];
parameters.source_ip.use_host_networking.source_ip.b = src_octets[1]; parameters.source_ip.use_host_networking.source_ip.b = src_octets[1];
parameters.source_ip.use_host_networking.source_ip.c = src_octets[2]; parameters.source_ip.use_host_networking.source_ip.c = src_octets[2];
parameters.source_ip.use_host_networking.source_ip.d = src_octets[3]; parameters.source_ip.use_host_networking.source_ip.d = src_octets[3];
} },
_ => { _ => {
return Err(crate::error::Error::Generic("Could not parse host networking configuration".to_string())); return Err(crate::error::Error::Generic(
"Could not parse host networking configuration".to_string(),
));
} }
} }
@ -144,10 +183,7 @@ pub async fn download_beacon_installer(
Ok(( Ok((
[ [
( (header::CONTENT_TYPE, "application/octet-stream".to_string()),
header::CONTENT_TYPE,
"application/octet-stream".to_string()
),
( (
header::CONTENT_DISPOSITION, header::CONTENT_DISPOSITION,
format!( format!(
@ -157,17 +193,17 @@ pub async fn download_beacon_installer(
} else { } else {
"" ""
} }
) ),
) ),
], ],
[ [&installer_bytes[..], &parameters_bytes[..]].concat(),
&installer_bytes[..],
&parameters_bytes[..]
].concat()
)) ))
} }
pub async fn serve_web(management_address: SocketAddrV4, db: SqlitePool) -> anyhow::Result<ExitCode> { pub async fn serve_web(
management_address: SocketAddrV4,
db: SqlitePool,
) -> anyhow::Result<ExitCode> {
let conf = get_configuration(None).unwrap(); let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options; let leptos_options = conf.leptos_options;
let routes = generate_route_list(App); let routes = generate_route_list(App);
@ -183,7 +219,7 @@ pub async fn serve_web(management_address: SocketAddrV4, db: SqlitePool) -> anyh
let state = AppState { let state = AppState {
leptos_options: leptos_options.clone(), leptos_options: leptos_options.clone(),
db: db.clone() db: db.clone(),
}; };
let app = Router::new() let app = Router::new()
@ -198,20 +234,26 @@ pub async fn serve_web(management_address: SocketAddrV4, db: SqlitePool) -> anyh
{ {
let leptos_options = leptos_options.clone(); let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone()) move || shell(leptos_options.clone())
} },
) )
.fallback(leptos_axum::file_and_error_handler::<leptos::config::LeptosOptions, _>(shell)) .fallback(leptos_axum::file_and_error_handler::<
leptos::config::LeptosOptions,
_,
>(shell))
.with_state(state) .with_state(state)
.layer( .layer(
tower::ServiceBuilder::new() tower::ServiceBuilder::new()
.layer(tower_http::trace::TraceLayer::new_for_http()) .layer(tower_http::trace::TraceLayer::new_for_http())
.layer(compression_layer) .layer(compression_layer),
); );
// run our app with hyper // run our app with hyper
// `axum::Server` is a re-export of `hyper::Server` // `axum::Server` is a re-export of `hyper::Server`
let management_listener = tokio::net::TcpListener::bind(&management_address).await?; let management_listener = tokio::net::TcpListener::bind(&management_address).await?;
tracing::info!("management interface listening on http://{}", &management_address); tracing::info!(
"management interface listening on http://{}",
&management_address
);
axum::serve(management_listener, app.into_make_service()) axum::serve(management_listener, app.into_make_service())
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())

View File

@ -94,11 +94,11 @@ where
let access_time = libc::timespec { let access_time = libc::timespec {
tv_sec: metadata.atime(), tv_sec: metadata.atime(),
tv_nsec: metadata.atime_nsec() tv_nsec: metadata.atime_nsec(),
}; };
let modify_time = libc::timespec { let modify_time = libc::timespec {
tv_sec: metadata.mtime(), tv_sec: metadata.mtime(),
tv_nsec: metadata.mtime_nsec() tv_nsec: metadata.mtime_nsec(),
}; };
unsafe { unsafe {
@ -320,7 +320,6 @@ where
cap_set_fd(binary.as_raw_fd(), current_caps); cap_set_fd(binary.as_raw_fd(), current_caps);
cap_free(current_caps); cap_free(current_caps);
} }
} }
Ok(()) Ok(())

View File

@ -1,10 +1,10 @@
use std::{ use std::{
fs::OpenOptions, fs::OpenOptions,
io::{prelude::*, Error, SeekFrom}, io::{Error, SeekFrom, prelude::*},
path::PathBuf, path::PathBuf,
}; };
use rand::{rngs::OsRng, TryRngCore}; use rand::{TryRngCore, rngs::OsRng};
use structopt::StructOpt; use structopt::StructOpt;
use sparse_actions::payload_types::{Parameters, XOR_KEY}; use sparse_actions::payload_types::{Parameters, XOR_KEY};

View File

@ -7,5 +7,7 @@ version.workspace = true
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
windows = { version = "0.59.0", features = ["Win32_System_SystemServices", "Win32_UI_WindowsAndMessaging"] } anyhow = "1.0.95"
pcap-sys = { version = "0.1.0", path = "../pcap-sys" }
windows = { version = "0.59.0", features = ["Win32_NetworkManagement_IpHelper", "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock", "Win32_System_SystemServices", "Win32_UI_WindowsAndMessaging"] }
winreg = "0.55" winreg = "0.55"

View File

@ -1,6 +1,6 @@
use windows::{ use windows::{
core::*,
Win32::{System::SystemServices::DLL_PROCESS_ATTACH, UI::WindowsAndMessaging::*}, Win32::{System::SystemServices::DLL_PROCESS_ATTACH, UI::WindowsAndMessaging::*},
core::*,
}; };
#[unsafe(no_mangle)] #[unsafe(no_mangle)]

View File

@ -0,0 +1,144 @@
fn main() -> anyhow::Result<()> {
let devs = pcap_sys::PcapDevIterator::new()?;
for dev in devs {
println!("{dev}");
}
unsafe {
use std::ffi::CStr;
use windows::Win32::{
NetworkManagement::IpHelper::{
GAA_FLAG_INCLUDE_GATEWAYS, GET_ADAPTERS_ADDRESSES_FLAGS, GetAdaptersAddresses,
IP_ADAPTER_ADDRESSES_LH,
},
Networking::WinSock::AF_INET,
};
let mut size_pointer: u32 = 0;
let err = GetAdaptersAddresses(
2,
GET_ADAPTERS_ADDRESSES_FLAGS(0),
None,
None,
&mut size_pointer as *mut _,
);
let mut address_buffer = vec![0; size_pointer as usize];
let err2 = GetAdaptersAddresses(
2,
GAA_FLAG_INCLUDE_GATEWAYS,
None,
Some(address_buffer.as_mut_ptr() as *mut _),
&mut size_pointer as *mut _,
);
if err2 != 0 {
eprintln!("Error code received for second one: {err2}");
Err(std::io::Error::last_os_error())?;
}
let mut current_address = address_buffer.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH;
fn pwstr_to_string(pwstr: *mut u16) -> String {
use std::ffi::OsString;
use std::os::windows::ffi::OsStringExt;
use std::slice;
if pwstr.is_null() {
return String::new();
}
// Find the length of the null-terminated UTF-16 string
let mut len = 0;
unsafe {
while *pwstr.add(len) != 0 {
len += 1;
}
// Convert UTF-16 slice to Rust String
let wide_slice = slice::from_raw_parts(pwstr, len);
OsString::from_wide(wide_slice)
.to_string_lossy()
.into_owned()
}
}
while !current_address.is_null() {
println!("-----");
println!(
"Name: {:?} ({:?})",
CStr::from_ptr((*current_address).AdapterName.0 as *const _),
pwstr_to_string((*current_address).FriendlyName.0)
);
println!("Mtu: {:?}", (*current_address).Mtu);
println!(
"Physical address: {:X?}",
&(*current_address).PhysicalAddress
[..(*current_address).PhysicalAddressLength as usize]
);
println!("IP addresses:");
let mut unicast_address = (*current_address).FirstUnicastAddress;
while !unicast_address.is_null() {
let address = (*(*unicast_address).Address.lpSockaddr).sa_data;
println!(
"\tIP address: {}.{}.{}.{}/{}",
address[2] as u8,
address[3] as u8,
address[4] as u8,
address[5] as u8,
(*unicast_address).OnLinkPrefixLength
);
unicast_address = (*unicast_address).Next;
}
println!("Gateways:");
let mut gateway = (*current_address).FirstGatewayAddress;
while !gateway.is_null() {
let address = (*(*gateway).Address.lpSockaddr).sa_data;
println!(
"\tIP address: {}.{}.{}.{}",
address[2] as u8, address[3] as u8, address[4] as u8, address[5] as u8
);
gateway = (*gateway).Next;
}
println!("Routes:");
let mut route = (*current_address).FirstPrefix;
while !route.is_null() {
let address = (*(*route).Address.lpSockaddr).sa_data;
println!(
"\tRoute: {}.{}.{}.{}/{}",
address[2] as u8,
address[3] as u8,
address[4] as u8,
address[5] as u8,
(*route).PrefixLength
);
route = (*route).Next;
}
println!("\n");
current_address = (*current_address).Next;
}
drop(address_buffer);
}
Ok(())
}

View File

@ -1,5 +1,5 @@
use std::{ use std::{
io::{prelude::*, Error}, io::{Error, prelude::*},
path::Path, path::Path,
}; };

View File

@ -1,10 +1,10 @@
use std::{ use std::{
fs::OpenOptions, fs::OpenOptions,
io::{prelude::*, Error, SeekFrom}, io::{Error, SeekFrom, prelude::*},
path::PathBuf, path::PathBuf,
}; };
use rand::{rngs::OsRng, TryRngCore}; use rand::{TryRngCore, rngs::OsRng};
use structopt::StructOpt; use structopt::StructOpt;
use sparse_actions::payload_types::{Parameters, XOR_KEY}; use sparse_actions::payload_types::{Parameters, XOR_KEY};
@ -103,7 +103,7 @@ fn main() -> Result<(), Error> {
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn install_winpcap(load_winpcap: bool) -> Result<(), Error> { fn install_winpcap(load_winpcap: bool) -> Result<(), Error> {
use winreg::{enums::*, RegKey, RegValue}; use winreg::{RegKey, RegValue, enums::*};
std::fs::write(r"C:\Windows\System32\wpcap.dll", WPCAP_DLL)?; std::fs::write(r"C:\Windows\System32\wpcap.dll", WPCAP_DLL)?;
std::fs::write(r"C:\Windows\System32\Packet.dll", PACKET_DLL)?; std::fs::write(r"C:\Windows\System32\Packet.dll", PACKET_DLL)?;
@ -138,9 +138,9 @@ fn install_winpcap(load_winpcap: bool) -> Result<(), Error> {
if load_winpcap { if load_winpcap {
unsafe { unsafe {
use windows::Win32::System::Services::{ use windows::Win32::System::Services::{
CreateServiceW, OpenSCManagerW, OpenServiceW, StartServiceW, SC_MANAGER_ALL_ACCESS, CreateServiceW, OpenSCManagerW, OpenServiceW, SC_MANAGER_ALL_ACCESS,
SERVICE_ALL_ACCESS, SERVICE_AUTO_START, SERVICE_ERROR_NORMAL, SERVICE_ALL_ACCESS, SERVICE_AUTO_START, SERVICE_ERROR_NORMAL,
SERVICE_KERNEL_DRIVER, SERVICE_START, SERVICE_KERNEL_DRIVER, SERVICE_START, StartServiceW,
}; };
use windows_strings::*; use windows_strings::*;
@ -159,7 +159,9 @@ fn install_winpcap(load_winpcap: bool) -> Result<(), Error> {
if let Ok(srvc) = npfsrvc { if let Ok(srvc) = npfsrvc {
println!("Service already installed, starting"); println!("Service already installed, starting");
println!("(If it fails because it's already running, that's fine, everything has worked)"); println!(
"(If it fails because it's already running, that's fine, everything has worked)"
);
StartServiceW(srvc, None)?; StartServiceW(srvc, None)?;
return Ok(()); return Ok(());
} }

View File

@ -16,9 +16,13 @@ typedef union SourceIp {
struct { struct {
char mode; // set to 1 char mode; // set to 1
unsigned char source_mac[6]; unsigned char source_mac[6];
unsigned char interface_len;
unsigned short netmask; unsigned short netmask;
ipaddr_t source_ip; ipaddr_t source_ip;
ipaddr_t gateway; ipaddr_t gateway;
// Windows references interfaces by GUID
// I'm too lazy to pack it more efficiently
unsigned char interface[40];
} custom_networking; } custom_networking;
} SourceIp_t; } SourceIp_t;