328 lines
10 KiB
Rust
328 lines
10 KiB
Rust
use std::{
|
|
net::Ipv4Addr,
|
|
pin::Pin,
|
|
sync::{Arc, Mutex},
|
|
task::{Context, Poll},
|
|
};
|
|
|
|
use hyper_util::client::legacy::connect;
|
|
use smoltcp::{
|
|
iface::{Config, Interface, SocketHandle, SocketSet},
|
|
socket::tcp::{RecvError, SendError, Socket, SocketBuffer, State},
|
|
time::Instant,
|
|
wire::{EthernetAddress, IpCidr, Ipv4Address},
|
|
};
|
|
use tokio::{
|
|
io::{AsyncRead, AsyncWrite},
|
|
sync::broadcast,
|
|
task::{spawn_blocking, JoinHandle},
|
|
};
|
|
|
|
use sparse_actions::payload_types::Parameters;
|
|
|
|
use crate::{adapter, error};
|
|
|
|
pub struct NetInterfaceHandle {
|
|
net: Arc<Mutex<(SocketSet<'static>, crate::socket::RawSocket, Interface)>>,
|
|
tcp_handle: SocketHandle,
|
|
|
|
close_background: broadcast::Sender<()>,
|
|
background_process: JoinHandle<()>,
|
|
}
|
|
|
|
impl Drop for NetInterfaceHandle {
|
|
fn drop(&mut self) {
|
|
let _ = self.close_background.send(());
|
|
self.background_process.abort();
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for NetInterfaceHandle {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
tbuf: &mut tokio::io::ReadBuf<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
let this = self.get_mut();
|
|
let Ok(mut inner) = this.net.lock() else {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::BrokenPipe,
|
|
"mutex for tcp connection is poisoned",
|
|
)));
|
|
};
|
|
let (ref mut s_guard, _, _) = *inner;
|
|
|
|
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
|
|
|
|
let has_data = socket.can_recv();
|
|
while socket.can_recv() {
|
|
let buf = match socket.recv(|buf| (buf.len(), buf.to_vec())) {
|
|
Ok(v) => v,
|
|
Err(RecvError::InvalidState) => {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::NetworkDown,
|
|
"received InvalidState from smoltcp",
|
|
)));
|
|
}
|
|
Err(RecvError::Finished) => {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::BrokenPipe,
|
|
"tried reading from finished connection",
|
|
)));
|
|
}
|
|
};
|
|
tbuf.put_slice(&buf);
|
|
}
|
|
|
|
socket.register_recv_waker(cx.waker());
|
|
|
|
if has_data {
|
|
Poll::Ready(Ok(()))
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for NetInterfaceHandle {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
src: &[u8],
|
|
) -> Poll<std::io::Result<usize>> {
|
|
let this = self.get_mut();
|
|
let Ok(mut inner) = this.net.lock() else {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::BrokenPipe,
|
|
"mutex for tcp connection is poisoned",
|
|
)));
|
|
};
|
|
let (ref mut s_guard, _, _) = *inner;
|
|
|
|
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
|
|
|
|
socket.register_send_waker(cx.waker());
|
|
|
|
if socket.can_send() {
|
|
let to_send = socket.send_capacity().min(src.len());
|
|
match socket.send_slice(&src[..to_send]) {
|
|
Ok(s) => Poll::Ready(Ok(s)),
|
|
Err(SendError::InvalidState) => {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::NetworkDown,
|
|
"received InvalidState from smoltcp",
|
|
)))
|
|
}
|
|
}
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
let this = self.get_mut();
|
|
let Ok(mut inner) = this.net.lock() else {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::BrokenPipe,
|
|
"mutex for tcp connection is poisoned",
|
|
)));
|
|
};
|
|
let (ref mut s_guard, _, _) = *inner;
|
|
|
|
let socket = s_guard.get_mut::<Socket>(this.tcp_handle);
|
|
socket.close();
|
|
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
/// Lets hyper's legacy client use a [`NetInterfaceHandle`] as its transport.
impl connect::Connection for NetInterfaceHandle {
    /// Reports the default, empty connection metadata from `Connected::new()`.
    fn connected(&self) -> connect::Connected {
        connect::Connected::new()
    }
}
|
|
|
|
pub async fn setup_network<T>(
|
|
adapter: T,
|
|
parameters: Parameters,
|
|
) -> Result<NetInterfaceHandle, error::BeaconError<T::Error>>
|
|
where
|
|
T: adapter::BeaconAdapter + Clone + Send + 'static,
|
|
{
|
|
let net_info = tokio::task::spawn_blocking({
|
|
let adapter = adapter.clone();
|
|
move || adapter.networking_info()
|
|
})
|
|
.await??;
|
|
|
|
let (interface, gateway_ip, mac_address, source_ip, netmask) = match unsafe {
|
|
parameters.source_ip.custom_networking.mode
|
|
} {
|
|
0 => {
|
|
// custom networking
|
|
let interface_name = unsafe {
|
|
¶meters.source_ip.custom_networking.interface
|
|
[..parameters.source_ip.custom_networking.interface_len as usize]
|
|
};
|
|
|
|
let interface = if interface_name.is_empty() {
|
|
let Some(default_route) = net_info.routes.iter().find(|r| r.network.1 == 0) else {
|
|
return Err(error::BeaconError::NoDefaultRoute);
|
|
};
|
|
|
|
&net_info.interfaces[default_route.interface_index]
|
|
} else {
|
|
net_info
|
|
.interfaces
|
|
.iter()
|
|
.find(|intf| intf.name == interface_name)
|
|
.ok_or(error::BeaconError::NoDefaultRoute)?
|
|
};
|
|
|
|
unsafe {
|
|
(
|
|
interface,
|
|
Ipv4Addr::new(
|
|
parameters.source_ip.custom_networking.gateway.a,
|
|
parameters.source_ip.custom_networking.gateway.b,
|
|
parameters.source_ip.custom_networking.gateway.c,
|
|
parameters.source_ip.custom_networking.gateway.d,
|
|
),
|
|
parameters.source_ip.custom_networking.source_mac.clone(),
|
|
Ipv4Addr::new(
|
|
parameters.source_ip.custom_networking.source_ip.a,
|
|
parameters.source_ip.custom_networking.source_ip.b,
|
|
parameters.source_ip.custom_networking.source_ip.c,
|
|
parameters.source_ip.custom_networking.source_ip.d,
|
|
),
|
|
parameters.source_ip.custom_networking.netmask as u8,
|
|
)
|
|
}
|
|
}
|
|
1 => {
|
|
// host networking
|
|
let Some(default_route) = net_info.routes.iter().find(|r| r.network.1 == 0) else {
|
|
return Err(error::BeaconError::NoDefaultRoute);
|
|
};
|
|
|
|
let default_route_if = &net_info.interfaces[default_route.interface_index];
|
|
|
|
(
|
|
default_route_if,
|
|
default_route.gateway.0,
|
|
default_route_if.mac_addr.clone(),
|
|
unsafe {
|
|
Ipv4Addr::new(
|
|
parameters.source_ip.use_host_networking.source_ip.a,
|
|
parameters.source_ip.use_host_networking.source_ip.b,
|
|
parameters.source_ip.use_host_networking.source_ip.c,
|
|
parameters.source_ip.use_host_networking.source_ip.d,
|
|
)
|
|
},
|
|
default_route.gateway.1,
|
|
)
|
|
}
|
|
_ => panic!("Corrupted parameters present!"),
|
|
};
|
|
|
|
let go_promisc = mac_address != [0, 0, 0, 0, 0, 0];
|
|
let mac_address = Some(mac_address)
|
|
.filter(|smac| smac != &[0, 0, 0, 0, 0, 0])
|
|
.unwrap_or(interface.mac_addr);
|
|
|
|
let local_port = 49152 + rand::random::<u16>() % 16384;
|
|
let mut device = crate::socket::RawSocket::new::<T>(interface, go_promisc, local_port)?;
|
|
|
|
let mut config = Config::new(EthernetAddress(mac_address).into());
|
|
config.random_seed = rand::random();
|
|
|
|
let mut iface = Interface::new(config, &mut device, Instant::now());
|
|
iface.update_ip_addrs(|addrs| {
|
|
addrs
|
|
.push(IpCidr::new(source_ip.into(), netmask))
|
|
.expect("could not add new IP address");
|
|
});
|
|
iface
|
|
.routes_mut()
|
|
.add_default_ipv4_route(gateway_ip.into())
|
|
.expect("did not expect route table to be full");
|
|
|
|
let tcp_rx_buffer = SocketBuffer::new(vec![0; 8192]);
|
|
let tcp_tx_buffer = SocketBuffer::new(vec![0; 8192]);
|
|
let tcp_socket = Socket::new(tcp_rx_buffer, tcp_tx_buffer);
|
|
|
|
let mut sockets = SocketSet::new(vec![]);
|
|
let tcp_handle = sockets.add(tcp_socket);
|
|
|
|
let mut active = false;
|
|
let ready_wait = device.get_ready_wait_callback();
|
|
|
|
let (close_background, mut close_background_recv) = broadcast::channel(1);
|
|
|
|
let destination = (
|
|
Ipv4Address::new(
|
|
parameters.destination_ip.a,
|
|
parameters.destination_ip.b,
|
|
parameters.destination_ip.c,
|
|
parameters.destination_ip.d,
|
|
),
|
|
parameters.destination_port,
|
|
);
|
|
|
|
while !active {
|
|
let timestamp = Instant::now();
|
|
iface.poll(timestamp, &mut device, &mut sockets);
|
|
|
|
let cx = iface.context();
|
|
|
|
let socket = sockets.get_mut::<Socket>(tcp_handle);
|
|
if !socket.is_active() {
|
|
socket.connect(cx, destination, local_port)?;
|
|
}
|
|
active = socket.is_active() && socket.state() == State::Established;
|
|
|
|
ready_wait.wait(iface.poll_delay(timestamp, &sockets).map(Into::into))?;
|
|
}
|
|
if cfg!(debug_assertions) {
|
|
println!("Connected!");
|
|
}
|
|
|
|
let net = Arc::new(Mutex::new((sockets, device, iface)));
|
|
|
|
let background_process = spawn_blocking({
|
|
let net = Arc::clone(&net);
|
|
|
|
move || loop {
|
|
if close_background_recv.try_recv().is_ok() {
|
|
break;
|
|
}
|
|
|
|
let delay = {
|
|
let Ok(mut guard) = net.lock() else {
|
|
continue;
|
|
};
|
|
let (ref mut s_guard, ref mut d_guard, ref mut i_guard) = *guard;
|
|
|
|
let timestamp = Instant::now();
|
|
i_guard.poll(timestamp, d_guard, s_guard);
|
|
|
|
i_guard.poll_delay(timestamp, s_guard)
|
|
};
|
|
|
|
let _ = ready_wait.wait(delay.map(Into::into));
|
|
}
|
|
});
|
|
|
|
Ok(NetInterfaceHandle {
|
|
net,
|
|
tcp_handle,
|
|
|
|
close_background,
|
|
background_process,
|
|
})
|
|
}
|