feat: set up basic sessions

This commit is contained in:
Andrew Rioux
2025-01-28 03:10:43 -05:00
parent bee66a8d6c
commit bf879bb081
23 changed files with 862 additions and 106 deletions

View File

@@ -2,8 +2,9 @@ use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
components::{A, Route, Router, Routes},
StaticSegment,
path
};
use serde::{Serialize, Deserialize};
#[server]
pub async fn test_retrieve() -> Result<u64, ServerFnError> {
@@ -20,6 +21,58 @@ pub async fn test_retrieve() -> Result<u64, ServerFnError> {
Ok(since_the_epoch)
}
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
user_id: i64,
user_name: String,
}
#[server]
async fn me() -> Result<Option<User>, ServerFnError> {
let session: crate::db::user::AuthSession = leptos_axum::extract().await?;
Ok(session.user.map(|user| User {
user_id: user.user_id,
user_name: user.user_name
}))
}
#[server]
async fn login(username: String, password: String, next: Option<String>) -> Result<(), ServerFnError> {
use leptos::server_fn::error::NoCustomError;
let mut session: crate::db::user::AuthSession = leptos_axum::extract().await?;
let user = match session.authenticate((username, password).clone()).await {
Ok(Some(user)) => user,
Ok(None) => return Err(ServerFnError::<NoCustomError>::ServerError("Invalid credentials".to_string())),
Err(e) => return Err(server_fn::server_fn_error!(e).into())
};
if let Err(e) = session.login(&user).await {
return Err(server_fn::server_fn_error!(e).into());
}
if let Some(target) = next {
leptos_axum::redirect(&target);
}
Ok(())
}
#[server]
async fn logout() -> Result<(), ServerFnError> {
let mut session: crate::db::user::AuthSession = leptos_axum::extract().await?;
match session.logout().await {
Ok(_) => {
leptos_axum::redirect("/login");
Ok(())
}
Err(e) => Err(server_fn::server_fn_error!(e).into())
}
}
pub fn shell(options: LeptosOptions) -> impl IntoView {
view! {
<!DOCTYPE html>
@@ -42,34 +95,71 @@ pub fn shell(options: LeptosOptions) -> impl IntoView {
pub fn App() -> impl IntoView {
provide_meta_context();
let user = Resource::new(|| (), |_| async { me().await });
view! {
<Stylesheet id="leptos" href="/pkg/sparse-server.css"/>
// sets the document title
<Title text="Welcome to Leptos"/>
<Title text="Sparse Control"/>
// content for this welcome page
<Router>
<nav>
<h1>"Sparse control"</h1>
<A href="/">"Home"</A>
<Suspense fallback=|| ()>
<A href="/beacons">"Beacon management"</A>
<A href="/users">"Users"</A>
{move || user
.get()
.map(|err| err.ok())
.flatten()
.flatten()
.map(|_| view! {
<a
href="#"
on:click=move |_| {
leptos::task::spawn_local(async move {
let _ = logout().await;
user.refetch();
});
}
>
"Log out"
</a>
})}
{move || user
.get()
.map(|err| err.ok())
.flatten()
.flatten()
.is_none()
.then(|| view! {
<A href="/login">"Log in"</A>
})}
</Suspense>
</nav>
<aside class="beacons">
</aside>
<Routes fallback=|| "Page not found.".into_view()>
<Route path=StaticSegment("") view=HomePage/>
<Route path=StaticSegment("/users") view=crate::users::UserView/>
<Route path=path!("users") view=crate::users::UserView />
<Route path=path!("login") view=move || view! { <LoginPage /> } />
<Route path=path!("") view=HomePage/>
</Routes>
</Router>
}
}
#[component]
fn LoginPage() -> impl IntoView {
}
/// Renders the home page of your application.
#[component]
fn HomePage() -> impl IntoView {
use leptos_use::{UseWebSocketReturn, use_websocket};
// Creates a reactive value to update the button
let count = RwSignal::new(0);
@@ -91,23 +181,31 @@ fn HomePage() -> impl IntoView {
let pending = request_time.pending();
let text_input = RwSignal::new("".to_owned());
#[cfg_attr(feature = "ssr", allow(unused_variables))]
let (messages, set_messages) = signal(Vec::<String>::new());
cfg_if::cfg_if! {
if #[cfg(feature = "hydrate")] {
use leptos_use::{UseWebSocketReturn, use_websocket};
let UseWebSocketReturn { send, message, .. } = use_websocket::<String, String, codee::string::FromToStringCodec>("/ws");
let UseWebSocketReturn { send, message, .. } = use_websocket::<String, String, codee::string::FromToStringCodec>("/ws");
Effect::new(move |_| {
message.with(move |message| {
if let Some(m) = message {
leptos::logging::log!("got update: {}", m);
set_messages.update(|messages: &mut Vec<_>| messages.push(format!("msg: {}", m)));
}
})
});
Effect::new(move |_| {
message.with(move |message| {
if let Some(m) = message {
leptos::logging::log!("got update: {}", m);
set_messages.update(|messages: &mut Vec<_>| messages.push(format!("msg: {}", m)));
}
})
});
let send_message = move |_| {
send(&text_input.get());
text_input.set("".to_string());
};
let send_message = move |_| {
send(&text_input.get());
text_input.set("".to_string());
};
} else {
let send_message = move |_| {};
}
}
view! {
<main class="main">

View File

@@ -1,4 +1,5 @@
use std::path::PathBuf;
use std::{net::SocketAddrV4, path::PathBuf};
use structopt::StructOpt;
pub mod user;
@@ -25,7 +26,15 @@ pub struct Options {
#[structopt()]
pub enum Command {
/// Run the web and API server
Serve {},
Serve {
/// Address to bind to for the management interface
#[structopt(default_value = "127.0.0.1:3000")]
management_address: SocketAddrV4,
/// Public address to bind to for the beacons to call back to
#[structopt(default_value = "127.0.0.1:5000")]
bind_address: SocketAddrV4,
},
/// Extract the public key and print it to standard out
ExtractPubKey {},

View File

@@ -1,7 +1,7 @@
use std::process::ExitCode;
use futures_util::StreamExt;
use sqlx::{Database, query, sqlite::SqlitePool};
use sqlx::{query, sqlite::SqlitePool};
use crate::cli::UserCommand as UC;

View File

@@ -1,15 +1,2 @@
#[cfg(feature = "ssr")]
pub mod user;
pub struct User {
pub user_id: i16,
pub user_name: String,
pub password_salt: String,
pub password_hash: String,
}
pub struct Sessions {
pub session_id: String,
pub user_id: i16,
pub expires: chrono::DateTime<chrono::offset::Local>,
}

View File

@@ -1,29 +1,77 @@
use pbkdf2::{pbkdf2_hmac_array, password_hash::{rand_core::OsRng, SaltString}};
use sha2::Sha256;
#[derive(Clone)]
pub struct User {
pub user_id: i64,
pub user_name: String,
password_hash: String,
pub last_active: Option<i64>
}
use async_trait::async_trait;
use pbkdf2::{Pbkdf2, password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, rand_core::OsRng, SaltString}};
use axum_login::{AuthUser, AuthnBackend, UserId};
use sqlx::SqlitePool;
use crate::error::Error;
const PASSWORD_ITERATIONS: u32 = 100_000;
impl std::fmt::Debug for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("User")
.field("user_id", &self.user_id)
.field("user_name", &self.user_name)
.field("password_hash", &"[redacted]")
.finish()
}
}
impl AuthUser for User {
type Id = i64;
fn id(&self) -> Self::Id {
self.user_id
}
fn session_auth_hash(&self) -> &[u8] {
self.password_hash.as_bytes()
}
}
async fn hash_password(pass: &[u8]) -> Result<String, Error> {
Ok(tokio::task::spawn_blocking({
let pass = pass.to_owned();
let salt = SaltString::generate(&mut OsRng);
move || Pbkdf2.hash_password(
&*pass,
&salt,
).map(|hash| hash.to_string())
}).await??)
}
async fn verify_password(pass: &str, hash: &str) -> Result<bool, Error> {
Ok(tokio::task::spawn_blocking({
let pass = pass.to_owned();
let hash = hash.to_owned();
move ||
PasswordHash::new(&*hash)
.map(|parsed| Pbkdf2.verify_password(
&pass.as_bytes(),
&parsed
).is_ok())
}).await??)
}
pub async fn reset_password<'a, E>(pool: E, id: i16, password: String) -> Result<(), crate::error::Error>
where
E: sqlx::SqliteExecutor<'a>
{
let salt = SaltString::generate(&mut OsRng);
let key = pbkdf2_hmac_array::<Sha256, 20>(
password.as_bytes(),
salt.as_str().as_bytes(),
PASSWORD_ITERATIONS
);
let salt_string = hex::encode(salt.as_str().as_bytes());
let password_string = hex::encode(&key[..]);
let password_string = hash_password(
password.as_bytes()
).await?;
sqlx::query!(
"UPDATE users SET password_hash = ?, password_salt = ? WHERE user_id = ?",
"UPDATE users SET password_hash = ? WHERE user_id = ?",
password_string,
salt_string,
id
)
.execute(pool)
@@ -52,7 +100,7 @@ where
tracing::info!("Creating new user {}", name);
let new_id = sqlx::query!(
r#"INSERT INTO users (user_name, password_salt, password_hash) VALUES (?, "", "")"#,
r#"INSERT INTO users (user_name, password_hash) VALUES (?, "")"#,
name
)
.execute(&mut *tx)
@@ -65,3 +113,68 @@ where
Ok(())
}
#[derive(Clone)]
pub struct Backend(SqlitePool);
impl Backend {
pub fn new(db: SqlitePool) -> Self {
Self(db)
}
}
#[async_trait]
impl AuthnBackend for Backend {
type User = User;
type Credentials = (String, String);
type Error = Error;
async fn authenticate(
&self,
creds: Self::Credentials
) -> Result<Option<Self::User>, Self::Error> {
let user: Option<Self::User> = sqlx::query_as!(
User,
"SELECT * FROM users WHERE user_name = ?",
creds.0
)
.fetch_optional(&self.0)
.await?;
let Some(user) = user else { return Ok(None); };
let good_hash = verify_password(
&user.password_hash,
&creds.1
).await?;
if good_hash {
let now = chrono::Utc::now().timestamp();
sqlx::query!(
"UPDATE users SET last_active = ?",
now
)
.execute(&self.0)
.await?;
Ok(Some(user))
} else {
Ok(None)
}
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
let user: Option<Self::User> = sqlx::query_as!(
User,
"SELECT * FROM users WHERE user_id = ?",
user_id
)
.fetch_optional(&self.0)
.await?;
Ok(user)
}
}
pub type AuthSession = axum_login::AuthSession<Backend>;

View File

@@ -1,13 +1,21 @@
#[derive(Debug)]
pub enum Error {
Generic(String),
UserCreate(String),
#[cfg(feature = "ssr")]
Sqlx(sqlx::Error),
#[cfg(feature = "ssr")]
TokioJoin(tokio::task::JoinError),
#[cfg(feature = "ssr")]
Pbkdf2(pbkdf2::password_hash::errors::Error),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Generic(err) => {
write!(f, "generic error: {err}")
}
Error::UserCreate(err) => {
write!(f, "user create error: {err}")
}
@@ -15,6 +23,14 @@ impl std::fmt::Display for Error {
Error::Sqlx(err) => {
write!(f, "sqlx error: {err:?}")
}
#[cfg(feature = "ssr")]
Error::TokioJoin(err) => {
write!(f, "tokio join error: {err:?}")
}
#[cfg(feature = "ssr")]
Error::Pbkdf2(err) => {
write!(f, "password hash error: {err:?}")
}
}
}
}
@@ -24,14 +40,38 @@ impl std::error::Error for Error {
match self {
#[cfg(feature = "ssr")]
Error::Sqlx(err) => Some(err),
#[cfg(feature = "ssr")]
Error::TokioJoin(err) => Some(err),
_ => None,
}
}
}
impl std::str::FromStr for Error {
type Err = Self;
fn from_str(err: &str) -> Result<Self, Self::Err> {
Ok(Self::Generic(err.to_string()))
}
}
#[cfg(feature = "ssr")]
impl From<sqlx::Error> for Error {
fn from(err: sqlx::Error) -> Self {
Self::Sqlx(err)
}
}
#[cfg(feature = "ssr")]
impl From<tokio::task::JoinError> for Error {
fn from(err: tokio::task::JoinError) -> Self {
Self::TokioJoin(err)
}
}
#[cfg(feature = "ssr")]
impl From<pbkdf2::password_hash::errors::Error> for Error {
fn from(err: pbkdf2::password_hash::errors::Error) -> Self {
Self::Pbkdf2(err)
}
}

View File

@@ -1,10 +1,15 @@
#[cfg(feature = "ssr")]
pub(crate) mod beacons {
#[allow(dead_code)]
pub const LINUX_BEACON: &'static [u8] = include_bytes!(std::env!("SPARSE_BEACON_LINUX"));
#[allow(dead_code)]
pub const FREEBSD_BEACON: &'static [u8] = include_bytes!(std::env!("SPARSE_BEACON_FREEBSD"));
#[allow(dead_code)]
pub const WINDOWS_BEACON: &'static [u8] = include_bytes!(std::env!("SPARSE_BEACON_WINDOWS"));
#[allow(dead_code)]
pub const LINUX_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_LINUX"));
#[allow(dead_code)]
pub const FREEBSD_INSTALLER: &'static [u8] = include_bytes!(std::env!("SPARSE_INSTALLER_FREEBSD"));
}
@@ -78,9 +83,9 @@ async fn main() -> anyhow::Result<std::process::ExitCode> {
tracing::info!("Done running database migrations!");
match options.command.clone() {
Some(cli::Command::Serve { }) => {
Some(cli::Command::Serve { management_address, bind_address }) => {
tracing::info!("Performing requested action, acting as web server");
webserver::serve_web(options, pool).await
webserver::serve_web(management_address, bind_address, pool).await
}
Some(cli::Command::ExtractPubKey { }) => {
Ok(ExitCode::SUCCESS)
@@ -89,8 +94,13 @@ async fn main() -> anyhow::Result<std::process::ExitCode> {
cli::user::handle_user_command(command, pool).await
}
None => {
use std::net::{Ipv4Addr, SocketAddrV4};
tracing::info!("Performing default action of acting as web server");
webserver::serve_web(options, pool).await
let default_management_ip = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3000);
let default_beacon_ip = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 5000);
webserver::serve_web(default_management_ip, default_beacon_ip, pool).await
}
}
}

View File

@@ -70,6 +70,7 @@ pub fn RenderUser(refresh_user_list: Action<(), ()>, user: PubUser) -> impl Into
let UseIntervalReturn { counter, .. } = use_interval(1000);
let (time_ago, set_time_ago) = signal(user.last_active.map(|active| format_delta(Utc::now() - active)));
#[cfg(feature = "hydrate")]
Effect::watch(
move || counter.get(),
move |_, _, _| {
@@ -194,7 +195,7 @@ async fn list_users() -> Result<Vec<PubUser>, ServerFnError> {
let users = sqlx::query_as!(
DbUser,
"SELECT user_id, user_name, (SELECT MAX(expires) FROM sessions s WHERE s.user_id = u.user_id) as last_active FROM users u"
"SELECT user_id, user_name, last_active FROM users"
)
.fetch(&pool)
.map(|user| user.map(|u| PubUser {

View File

@@ -1,12 +1,14 @@
use std::process::ExitCode;
use std::{net::SocketAddrV4, process::ExitCode};
use sqlx::sqlite::SqlitePool;
use axum::Router;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use sparse_server::app::*;
use tokio::{signal, task::AbortHandle};
use tower_sessions::{Expiry, SessionManagerLayer, session_store::ExpiredDeletion};
use tower_sessions_sqlx_store::SqliteStore;
use sparse_server::app::*;
pub async fn websocket(ws: axum::extract::ws::WebSocketUpgrade) -> axum::response::Response {
tracing::info!("Handling websocket request to /ws");
@@ -14,7 +16,6 @@ pub async fn websocket(ws: axum::extract::ws::WebSocketUpgrade) -> axum::respons
}
async fn handle_websocket(mut socket: axum::extract::ws::WebSocket) {
use futures_util::StreamExt;
use tracing::info;
let mut count = 0;
@@ -43,13 +44,34 @@ async fn handle_websocket(mut socket: axum::extract::ws::WebSocket) {
}
}
pub async fn serve_web(options: crate::cli::Options, db: SqlitePool) -> anyhow::Result<ExitCode> {
pub async fn serve_web(management_address: SocketAddrV4, _bind_address: SocketAddrV4, db: SqlitePool) -> anyhow::Result<ExitCode> {
let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr;
let leptos_options = conf.leptos_options;
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
let session_store = SqliteStore::new(db.clone());
session_store.migrate().await?;
let deletion_task = tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60))
);
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(time::Duration::minutes(20)));
let backend = crate::db::user::Backend::new(db.clone());
let auth_layer = axum_login::AuthManagerLayerBuilder::new(backend, session_layer).build();
let compression_layer = tower_http::compression::CompressionLayer::new()
.gzip(true)
.deflate(true)
.br(true)
.zstd(true);
let app = Router::new()
.route("/ws", axum::routing::any(websocket))
.leptos_routes_with_context(
@@ -61,13 +83,50 @@ pub async fn serve_web(options: crate::cli::Options, db: SqlitePool) -> anyhow::
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options);
.with_state(leptos_options)
.layer(auth_layer)
.layer(compression_layer);
// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
tracing::info!("listening on http://{}", &addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app.into_make_service()).await?;
let management_listener = tokio::net::TcpListener::bind(&management_address).await?;
tracing::info!("management interface listening on http://{}", &management_address);
axum::serve(management_listener, app.into_make_service())
.with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle()))
.await?;
deletion_task.await??;
Ok(ExitCode::SUCCESS)
}
async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
tracing::info!("Received Ctrl-C");
deletion_task_abort_handle.abort()
},
_ = terminate => {
tracing::info!("Received terminate command");
deletion_task_abort_handle.abort()
},
}
}