refactor: simplified route query code

fighting the borrow checker
This commit is contained in:
Andrew Rioux 2023-05-01 10:52:23 -04:00
parent c16bf366b7
commit cfdf8f7e86
4 changed files with 162 additions and 30 deletions

View File

@ -29,22 +29,53 @@ async fn main() -> anyhow::Result<()> {
let socket = netlink::Socket::new()?; let socket = netlink::Socket::new()?;
let routes = socket.get_routes()?; let routes = socket.get_routes()?;
let mut routes_inner = routes.iter().collect::<Vec<_>>();
let neighs = socket.get_neigh()?; let neighs = socket.get_neigh()?;
let links = socket.get_links()?; let links = socket.get_links()?;
let addrs = socket.get_addrs()?;
let srcip = route::get_srcip_for_dstip(&routes, target) routes_inner.sort_by(|r1, r2| {
r2.dst().map(|a| a.cidrlen())
.partial_cmp(&r1.dst().map(|a| a.cidrlen()))
.unwrap_or(std::cmp::Ordering::Equal)
});
println!("-=- Addrs -=-");
for addr in addrs.iter() {
println!("addr: {:?}, {}", addr.local(), addr.ifindex());
}
println!("-=- Routes -=-");
for route in routes_inner.iter() {
println!("route: {:?}", route.dst());
for hop in route.hop_iter() {
println!("\thop: {:?}, {}", hop.gateway(), hop.ifindex());
}
}
println!("-=- Neighs -=-");
for neigh in neighs.iter() {
println!("neigh: {:?}, {:?}, {}", neigh.dst(), neigh.lladdr(), neigh.ifindex());
}
println!("-=- Links -=-");
for link in links.iter() {
println!("link {:?}: {:?}, {}", link.name(), link.addr(), link.ifindex());
}
let (srcip, srcmac, dstmac) = route::get_macs_and_src_for_ip(&addrs, &routes, &neighs, &links, target)
.ok_or(anyhow!("unable to find a route to the IP"))?;
/*let srcip = route::get_srcip_for_dstip(&routes, target)
.ok_or(anyhow!("Unable to find a route to the IP"))?; .ok_or(anyhow!("Unable to find a route to the IP"))?;
let (target_link, dst_mac) = route::get_neigh_for_addr(&routes, &neighs, &links, &srcip.into()) let (srcip, target_link, dst_mac) = route::get_neigh_for_addr(&routes, &neighs, &links, &srcip.into())
.ok_or(anyhow!("Unable to find local interface to use"))?; .ok_or(anyhow!("Unable to find local interface to use"))?;*/
let src_mac = target_link.addr().hw_address(); dbg!(srcmac);
dbg!(dstmac);
dbg!(srcip);
dbg!(target);
( ( srcmac, dstmac, srcip )
TryInto::<[u8;6]>::try_into(src_mac).unwrap(),
TryInto::<[u8;6]>::try_into(dst_mac).unwrap(),
srcip
)
}; };
let mut interfaces = pcap_sys::PcapDevIterator::new()?; let mut interfaces = pcap_sys::PcapDevIterator::new()?;

View File

@ -17,7 +17,7 @@ use std::{ptr, marker::PhantomData};
use libc::{AF_UNSPEC, AF_INET}; use libc::{AF_UNSPEC, AF_INET};
use crate::{nl_ffi::*, error, route::{Link, Neigh, Route}}; use crate::{nl_ffi::*, error, route::{Link, Neigh, Route, RtAddr}};
/// A netlink socket used to communicate with the kernel /// A netlink socket used to communicate with the kernel
pub struct Socket { pub struct Socket {
@ -89,6 +89,23 @@ impl Socket {
}) })
} }
} }
pub fn get_addrs(&self) -> error::Result<Cache<RtAddr>> {
unsafe {
let mut addr_cache = ptr::null_mut::<nl_cache>();
let ret = rtnl_addr_alloc_cache(self.sock, &mut addr_cache as *mut _);
if ret < 0 {
return Err(error::Error::new(ret));
}
Ok(Cache {
cache: addr_cache,
dt: PhantomData
})
}
}
} }
impl Drop for Socket { impl Drop for Socket {

View File

@ -31,6 +31,7 @@ nl_obj!(nl_cache);
nl_obj!(nl_addr); nl_obj!(nl_addr);
nl_obj!(nl_object); nl_obj!(nl_object);
nl_obj!(nl_list_head); nl_obj!(nl_list_head);
nl_obj!(rtnl_addr);
nl_obj!(rtnl_link); nl_obj!(rtnl_link);
nl_obj!(rtnl_neigh); nl_obj!(rtnl_neigh);
nl_obj!(rtnl_route); nl_obj!(rtnl_route);
@ -62,6 +63,11 @@ extern "C" {
pub fn nl_cache_get_next(obj: *mut nl_object) -> *mut nl_object; pub fn nl_cache_get_next(obj: *mut nl_object) -> *mut nl_object;
pub fn nl_cache_destroy_and_free(obj: *mut nl_cache) -> c_void; pub fn nl_cache_destroy_and_free(obj: *mut nl_cache) -> c_void;
pub fn rtnl_addr_alloc_cache(sock: *mut nl_sock, result: *mut *mut nl_cache) -> c_int;
pub fn rtnl_addr_get_ifindex(addr: *mut rtnl_addr) -> c_int;
pub fn rtnl_addr_get_family(addr: *mut rtnl_addr) -> c_int;
pub fn rtnl_addr_get_local(addr: *mut rtnl_addr) -> *mut nl_addr;
pub fn rtnl_neigh_alloc_cache(sock: *mut nl_sock, result: *mut *mut nl_cache) -> c_int; pub fn rtnl_neigh_alloc_cache(sock: *mut nl_sock, result: *mut *mut nl_cache) -> c_int;
pub fn rtnl_neigh_get(cache: *mut nl_cache, ifindex: c_int, dst: *mut nl_addr) -> *mut rtnl_neigh; pub fn rtnl_neigh_get(cache: *mut nl_cache, ifindex: c_int, dst: *mut nl_addr) -> *mut rtnl_neigh;
pub fn rtnl_neigh_get_dst(neigh: *mut rtnl_neigh) -> *mut nl_addr; pub fn rtnl_neigh_get_dst(neigh: *mut rtnl_neigh) -> *mut nl_addr;

View File

@ -21,6 +21,39 @@ use crate::{error, netlink::{Cache, self}};
use super::nl_ffi::*; use super::nl_ffi::*;
/// Represents an address assigned to a link
pub struct RtAddr {
addr: *mut rtnl_addr
}
impl RtAddr {
pub fn local(&self) -> Option<Addr> {
unsafe {
let addr = rtnl_addr_get_local(self.addr);
if addr.is_null() {
return None;
}
Some(Addr { addr })
}
}
pub fn ifindex(&self) -> i32 {
unsafe { rtnl_addr_get_ifindex(self.addr) }
}
pub fn family(&self) -> i32 {
unsafe { rtnl_addr_get_family(self.addr) }
}
}
impl From<*mut nl_object> for RtAddr {
fn from(value: *mut nl_object) -> Self {
RtAddr { addr: value as *mut _ }
}
}
/// Represents a network link, which can represent a network device /// Represents a network link, which can represent a network device
pub struct Link { pub struct Link {
@ -45,7 +78,7 @@ impl Link {
} }
} }
/// Determines the type of link. Ethernet devices are "veth" /// Determines the type of link. Ethernet devices are "veth or eth"
pub fn ltype(&self) -> Option<String> { pub fn ltype(&self) -> Option<String> {
unsafe { unsafe {
let ltype = rtnl_link_get_type(self.link); let ltype = rtnl_link_get_type(self.link);
@ -96,32 +129,85 @@ impl From<*mut nl_object> for Link {
} }
} }
pub fn get_macs_and_src_for_ip(addrs: &Cache<RtAddr>, routes: &Cache<Route>, neighs: &Cache<Neigh>, links: &Cache<Link>, addr: Ipv4Addr) -> Option<(Ipv4Addr, [u8; 6], [u8; 6])> {
let mut sorted_routes = routes.iter().collect::<Vec<_>>();
sorted_routes.sort_by(|r1, r2| {
r2.dst().map(|a| a.cidrlen())
.partial_cmp(&r1.dst().map(|a| a.cidrlen()))
.unwrap_or(std::cmp::Ordering::Equal)
});
let ip_int = u32::from(addr);
let route = sorted_routes
.iter()
.find(|route| {
let Some(dst) = route.dst() else { return false };
let mask = if dst.cidrlen() != 0 {
(0xFFFFFFFFu32.overflowing_shr(32 - dst.cidrlen())).0.overflowing_shl(32 - dst.cidrlen()).0
} else {
0
};
let Ok(dst_addr): Result<Ipv4Addr, _> = (&dst).try_into() else { return false };
let dst_addr: u32 = dst_addr.into();
(mask & dst_addr) == (mask & ip_int)
})?;
let link_ind = route
.hop_iter()
.next()?
.ifindex();
let link = netlink::get_link_by_index(links, link_ind)?;
let neigh = neighs
.iter()
.find(|n| n.ifindex() == link.ifindex())?;
let srcip = addrs
.iter()
.find(|a| a.ifindex() == link.ifindex())?;
Some((
(&srcip.local()?).try_into().ok()?,
link.addr().hw_address().try_into().ok()?,
neigh.lladdr().hw_address().try_into().ok()?
))
}
/// Gets the neighbor record for the source IP specified, or get the default address /// Gets the neighbor record for the source IP specified, or get the default address
pub fn get_neigh_for_addr(routes: &Cache<Route>, neighs: &Cache<Neigh>, links: &Cache<Link>, addr: &Addr) -> Option<(Link, [u8; 6])> { pub fn get_neigh_for_addr(routes: &Cache<Route>, neighs: &Cache<Neigh>, links: &Cache<Link>, addr: &Addr) -> Option<(Ipv4Addr, Link, [u8; 6])> {
for link in links.iter() { for link in links.iter() {
let Some(neigh) = link.get_neigh(&neighs, addr) else { continue; }; let Some(neigh) = link.get_neigh(&neighs, addr) else { continue; };
return Some((link, neigh)); return Some((addr.try_into().ok()?, link, neigh));
} }
// No good neighbors were found above, try to use the default address // No good neighbors were found above, try to use the default address
println!("here");
if let Some(def_neigh) = get_default_route(routes) { if let Some(def_neigh) = get_default_route(routes) {
println!("Found default route, trying to get link for it"); println!("Found default route, trying to get link for it");
if let Some((link, neigh)) = neighs if let Some((laddr, link, neigh)) = neighs
.iter() .iter()
.filter_map(|n| { .filter_map(|n| {
let Some(link) = netlink::get_link_by_index(links, n.ifindex()) else { let Some(link) = netlink::get_link_by_index(links, n.ifindex()) else {
return None; return None;
}; };
if Some(n.ifindex()) != def_neigh.hop_iter().next().map(|h| h.ifindex()) { let Some(first_hop) = def_neigh.hop_iter().next() else {
return None
};
if n.ifindex() != first_hop.ifindex() {
return None; return None;
} }
Some((link, n.lladdr())) Some(((&first_hop.gateway()?).try_into().ok()?, link, n.lladdr()))
}) })
.next() { .next() {
return Some((link, neigh.hw_address().try_into().ok()?)) return Some((laddr, link, neigh.hw_address().try_into().ok()?))
} }
} }
@ -250,10 +336,10 @@ impl From<Ipv4Addr> for Addr {
} }
} }
impl TryFrom<Addr> for Ipv4Addr { impl TryFrom<&Addr> for Ipv4Addr {
type Error = error::Error; type Error = error::Error;
fn try_from(value: Addr) -> Result<Self, Self::Error> { fn try_from(value: &Addr) -> Result<Self, Self::Error> {
if value.len() != 4 { if value.len() != 4 {
return Err(error::Error::new(15 /* NL_AF_MISMATCH */)); return Err(error::Error::new(15 /* NL_AF_MISMATCH */));
} }
@ -295,13 +381,6 @@ impl Route {
} }
} }
/// Determines which interface is affected by this route
pub fn ifindex(&self) -> c_int {
unsafe {
rtnl_route_get_iif(self.route)
}
}
/// Returns the amount of hops are in this route /// Returns the amount of hops are in this route
pub fn nexthop_len(&self) -> c_int { pub fn nexthop_len(&self) -> c_int {
unsafe { unsafe {
@ -410,7 +489,7 @@ pub fn get_srcip_for_dstip(routes: &Cache<Route>, ip: Ipv4Addr) -> Option<Ipv4Ad
0 0
}; };
let Ok(dst_addr): Result<Ipv4Addr, _> = dst.try_into() else { return false }; let Ok(dst_addr): Result<Ipv4Addr, _> = (&dst).try_into() else { return false };
let dst_addr: u32 = dst_addr.into(); let dst_addr: u32 = dst_addr.into();
(mask & dst_addr) == (mask & ip_int) (mask & dst_addr) == (mask & ip_int)
@ -422,7 +501,6 @@ pub fn get_srcip_for_dstip(routes: &Cache<Route>, ip: Ipv4Addr) -> Option<Ipv4Ad
.and_then(|hop| hop.gateway()) .and_then(|hop| hop.gateway())
.or(route.dst()) .or(route.dst())
}) })
.filter_map(|gateway| gateway.try_into().ok()) .filter_map(|gateway| (&gateway).try_into().ok())
.next() .next()
} }