diff --git a/snow-scanner/Cargo.toml b/snow-scanner/Cargo.toml index e778891..fca07a6 100644 --- a/snow-scanner/Cargo.toml +++ b/snow-scanner/Cargo.toml @@ -38,9 +38,10 @@ members = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rocket = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d"} -# contrib/db_pools +rocket = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d"} +rocket_ws = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d"} rocket_db_pools = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d", default-features = false, features = ["diesel_mysql"] } + # mariadb-dev on Alpine # "mysqlclient-src" "mysql_backend" diesel = { version = "^2", default-features = false, features = ["mysql", "chrono", "uuid"] } diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index ea8ac84..13ff24e 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -3,7 +3,7 @@ use chrono::{NaiveDateTime, Utc}; #[macro_use] extern crate rocket; -use rocket::{fairing::AdHoc, Build, Rocket, State}; +use rocket::{fairing::AdHoc, futures::SinkExt, trace::error, Build, Rocket, State}; use rocket_db_pools::{ rocket::{ figment::{ @@ -20,15 +20,17 @@ use rocket_db_pools::{ use rocket_db_pools::diesel::mysql::{Mysql, MysqlValue}; use rocket_db_pools::diesel::serialize::IsNull; use rocket_db_pools::diesel::sql_types::Text; -use rocket_db_pools::diesel::{deserialize, serialize}; use rocket_db_pools::diesel::MysqlPool; +use rocket_db_pools::diesel::{deserialize, serialize}; use rocket_db_pools::Database; +use rocket_ws::WebSocket; +use server::Worker; use worker::detection::{detect_scanner, get_dns_client, Scanners}; -use std::{io::Write, net::SocketAddr}; use std::path::PathBuf; use std::{env, fmt}; +use std::{io::Write, net::SocketAddr}; use uuid::Uuid; use serde::{Deserialize, Deserializer, Serialize}; @@ -460,19 +462,56 @@ async fn pong() -> PlainText { PlainText("pong".to_string()) } +/* + let mut ws_server_handles = Server { + clients: HashMap::new(), + new_scanners: HashMap::new(), + }; + info!("Worker server is listening on: {worker_server_address}"); + loop { + match ws_server.process(&mut ws_server_handles, 0.5) { + Ok(_) => {} + Err(err) => error!("Processing error: {err}"), + } + ws_server_handles.cleanup(&ws_server); + ws_server_handles.commit(conn); +}*/ + #[get("/ws")] -pub async fn ws() -> PlainText { - info!("establish_ws_connection"); - PlainText("ok".to_string()) - // Ok(HttpResponse::Unauthorized().json(e)) - /* - match result { - Ok(response) => Ok(response.into()), - Err(e) => { - error!("ws connection error: {:?}", e); - Err(e) - }, - }*/ +pub async fn ws(ws: WebSocket) -> rocket_ws::Channel<'static> { + use rocket::futures::StreamExt; + use rocket::tokio; + use rocket_ws as ws; + use std::time::Duration; + + ws.channel(move |mut stream: ws::stream::DuplexStream| { + Box::pin(async move { + let mut interval = tokio::time::interval(Duration::from_secs(10)); + + tokio::spawn(async move { + let mut worker = Worker::initial(&mut stream); + loop { + tokio::select! { + _ = interval.tick() => { + // Send message every 10 seconds + //let reading = get_latest_readings().await.unwrap(); + //let _ = stream.send(ws::Message::Text(json!(reading).to_string())).await; + // info!("Sent message"); + } + Ok(false) = worker.poll() => { + // Continue the loop + } + else => { + break; + } + } + } + }); + + tokio::signal::ctrl_c().await.unwrap(); + Ok(()) + }) + }) } struct AppConfigs { @@ -484,7 +523,8 @@ async fn report_counts(rocket: Rocket) -> Rocket { let conn = SnowDb::fetch(&rocket) .expect("database is attached") - .get().await + .get() + .await .unwrap_or_else(|e| { span_error!("failed to connect to MySQL database" => error!("{e}")); panic!("aborting launch"); @@ -499,9 +539,12 @@ async fn report_counts(rocket: Rocket) -> Rocket { #[launch] fn rocket() -> _ { let server_address: SocketAddr = if let Ok(env) = env::var("SERVER_ADDRESS") { - env.parse().expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)") + env.parse() + .expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)") } else { - "127.0.0.1:8000".parse().expect("The default address should be valid") + "127.0.0.1:8000" + .parse() + .expect("The default address should be valid") }; let static_data_dir: String = match env::var("STATIC_DATA_DIR") { @@ -544,32 +587,4 @@ fn rocket() -> _ { ws, ], ) - /* - - match server { - Ok(server) => { - match ws2::listen(worker_server_address.as_str()) { - Ok(mut ws_server) => { - std::thread::spawn(move || { - let pool = get_connection(db_url.as_str()); - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - let mut ws_server_handles = Server { - clients: HashMap::new(), - new_scanners: HashMap::new(), - }; - info!("Worker server is listening on: {worker_server_address}"); - loop { - match ws_server.process(&mut ws_server_handles, 0.5) { - Ok(_) => {} - Err(err) => error!("Processing error: {err}"), - } - ws_server_handles.cleanup(&ws_server); - ws_server_handles.commit(conn); - } - }); - } - Err(err) => error!("Unable to listen on {worker_server_address}: {err}"), - }; - }*/ } diff --git a/snow-scanner/src/server.rs b/snow-scanner/src/server.rs index 6ab4c53..d05b4ac 100644 --- a/snow-scanner/src/server.rs +++ b/snow-scanner/src/server.rs @@ -1,14 +1,21 @@ +use cidr::IpCidr; use hickory_resolver::Name; +use rocket::futures::{stream::Next, SinkExt, StreamExt}; +use rocket_ws::Message; use std::{collections::HashMap, net::IpAddr, str::FromStr}; -use crate::{worker::detection::detect_scanner_from_name, DbConn, Scanner}; +use crate::worker::{ + detection::detect_scanner_from_name, + modules::{Network, WorkerMessages}, +}; +use crate::{DbConn, Scanner}; -pub struct Server { - pub clients: HashMap, +pub struct Server<'a> { + pub clients: HashMap>, pub new_scanners: HashMap, } -impl Server { +impl<'a> Server<'a> { pub async fn commit(&mut self, conn: &mut DbConn) -> &Server { for (name, query_address) in self.new_scanners.clone() { let scanner_name = Name::from_str(name.as_str()).unwrap(); @@ -44,18 +51,81 @@ impl Server { } } -#[derive(Debug, Clone)] -pub struct Worker { - pub authenticated: bool, - pub login: Option, +pub struct Worker<'a> { + authenticated: bool, + login: Option, + stream: &'a mut rocket_ws::stream::DuplexStream, } -impl Worker { - pub fn initial() -> Worker { +impl<'a> Worker<'a> { + pub fn initial(stream: &mut rocket_ws::stream::DuplexStream) -> Worker { info!("New worker"); Worker { authenticated: false, login: None, + stream, + } + } + + pub async fn send(&mut self, msg: Message) -> Result<(), rocket_ws::result::Error> { + self.stream.send(msg).await + } + + pub fn next(&mut self) -> Next<'_, rocket_ws::stream::DuplexStream> { + self.stream.next() + } + + pub async fn poll(&mut self) -> Result { + let message = self.next(); + + match message.await { + Some(Ok(message)) => { + match message { + rocket_ws::Message::Text(_) => match self.on_message(&message).await { + Ok(_) => {} + Err(err) => error!("Processing error: {err}"), + }, + rocket_ws::Message::Binary(data) => { + // Handle Binary message + info!("Received Binary message: {:?}", data); + } + rocket_ws::Message::Close(close_frame) => { + // Handle Close message + info!("Received Close message: {:?}", close_frame); + let close_frame = rocket_ws::frame::CloseFrame { + code: rocket_ws::frame::CloseCode::Normal, + reason: "Client disconected".to_string().into(), + }; + let _ = self.stream.close(Some(close_frame)).await; + return Ok(true); + } + rocket_ws::Message::Ping(ping_data) => { + match self.send(rocket_ws::Message::Pong(ping_data)).await { + Ok(_) => {} + Err(err) => error!("Processing error: {err}"), + } + } + rocket_ws::Message::Pong(pong_data) => { + // Handle Pong message + info!("Received Pong message: {:?}", pong_data); + } + _ => { + info!("Received other message: {:?}", message); + } + }; + Ok(false) + } + Some(Err(err)) => { + info!("Connection closed"); + let close_frame = rocket_ws::frame::CloseFrame { + code: rocket_ws::frame::CloseCode::Normal, + reason: "Client disconected".to_string().into(), + }; + let _ = self.stream.close(Some(close_frame)).await; + // The connection is closed by the client + Ok(true) + } + None => Ok(false), } } @@ -77,63 +147,43 @@ impl Worker { self.authenticated = true; self } -} -/* -impl ws2::Handler for Server { - fn on_open(&mut self, ws: &WebSocket) -> Pod { - info!("New client: {ws}"); - let worker = Worker::initial(); - // Add the client - self.clients.insert(ws.id(), worker); - Ok(()) - } - - fn on_close(&mut self, ws: &WebSocket) -> Pod { - info!("Client /quit: {ws}"); - // Drop the client - self.clients.remove(&ws.id()); - Ok(()) - } - - fn on_message(&mut self, ws: &WebSocket, msg: String) -> Pod { - let client = self.clients.get_mut(&ws.id()); - if client.is_none() { - // Impossible, close in case - return ws.close(); - } - let worker: &mut Worker = client.unwrap(); - - info!("on message: {msg}, {ws}"); + pub async fn on_message(&mut self, msg: &Message) -> Result<(), String> { + info!("on message: {msg}"); let mut worker_reply: Option = None; - let worker_request: WorkerMessages = msg.clone().into(); + let worker_request: WorkerMessages = match msg.clone().try_into() { + Ok(data) => data, + Err(err) => return Err(err), + }; let result = match worker_request { WorkerMessages::AuthenticateRequest { login } => { - if !worker.is_authenticated() { - worker.authenticate(login); + if !self.is_authenticated() { + self.authenticate(login); return Ok(()); } else { - error!("Already authenticated: {ws}"); + error!("Already authenticated"); return Ok(()); } } WorkerMessages::ScannerFoundResponse { name, address } => { info!("Detected {name} for {address}"); - self.new_scanners.insert(name, address); + //self.new_scanners.insert(name, address); Ok(()) } WorkerMessages::GetWorkRequest {} => { + let net = IpCidr::from_str("52.189.78.0/24").unwrap(); worker_reply = Some(WorkerMessages::DoWorkRequest { - neworks: vec![Network(IpCidr::from_str("52.189.78.0/24")?)], + neworks: vec![Network(net)], }); Ok(()) } WorkerMessages::DoWorkRequest { .. } | WorkerMessages::Invalid { .. } => { - error!("Unable to understand: {msg}, {ws}"); + error!("Unable to understand: {msg}"); // Unable to understand, close the connection - return ws.close(); + //return ws.close(); + Err("Unable to understand: {msg}}") } /*msg => { error!("No implemented: {:#?}", msg); Ok(()) @@ -143,14 +193,14 @@ impl ws2::Handler for Server { // it has a request to send if let Some(worker_reply) = worker_reply { let msg_string: String = worker_reply.to_string(); - match ws.send(msg_string) { + match self.send(rocket_ws::Message::Text(msg_string)).await { Ok(_) => match worker_reply { WorkerMessages::DoWorkRequest { .. } => {} msg => error!("No implemented: {:#?}", msg), }, - Err(err) => error!("Error sending reply to {ws}: {err}"), + Err(err) => error!("Error sending reply: {err}"), } } - result + Ok(result?) } -}*/ +} diff --git a/snow-scanner/src/worker/Cargo.toml b/snow-scanner/src/worker/Cargo.toml index 5436f6a..bb69f2d 100644 --- a/snow-scanner/src/worker/Cargo.toml +++ b/snow-scanner/src/worker/Cargo.toml @@ -14,6 +14,7 @@ path = "worker.rs" [dependencies] tungstenite = { version = "0.24.0", default-features = true, features = ["native-tls"] } +rocket_ws = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d", default-features = true} log2 = "0.1.11" diesel = { version = "2", default-features = false, features = [] } dns-ptr-resolver = {git = "https://github.com/wdes/dns-ptr-resolver.git"} diff --git a/snow-scanner/src/worker/modules.rs b/snow-scanner/src/worker/modules.rs index 1cd0d26..3a257af 100644 --- a/snow-scanner/src/worker/modules.rs +++ b/snow-scanner/src/worker/modules.rs @@ -1,6 +1,7 @@ use std::{net::IpAddr, str::FromStr}; use cidr::IpCidr; +use rocket_ws::Message as RocketMessage; use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[derive(Debug, Clone, PartialEq)] @@ -66,6 +67,20 @@ impl Into for String { } } +impl TryInto for RocketMessage { + type Error = String; + + fn try_into(self) -> Result { + match self { + RocketMessage::Text(data) => { + let data: WorkerMessages = data.into(); + Ok(data) + } + _ => Err("Only text is supported".to_string()), + } + } +} + #[cfg(test)] mod tests { use cidr::IpCidr; diff --git a/snow-scanner/src/worker/worker.rs b/snow-scanner/src/worker/worker.rs index 7f99685..0ac4589 100644 --- a/snow-scanner/src/worker/worker.rs +++ b/snow-scanner/src/worker/worker.rs @@ -190,24 +190,6 @@ impl Worker { } } -/*impl ws2::Handler for Worker { - fn on_open(&mut self, ws: &WebSocket) -> Pod { - info!("Connected to: {ws}, starting to work"); - Ok(()) - } - - fn on_close(&mut self, ws: &WebSocket) -> Pod { - info!("End of the work day: {ws}"); - Ok(()) - } - - fn on_message(&mut self, ws: &WebSocket, msg: String) -> Pod { - let server_request: WorkerMessages = msg.clone().into(); - self.receive_request(ws, server_request); - Ok(()) - } -}*/ - fn main() -> () { let _log2 = log2::stdout() .module(true)