diff --git a/snow-scanner/src/event_bus.rs b/snow-scanner/src/event_bus.rs new file mode 100644 index 0000000..e68c399 --- /dev/null +++ b/snow-scanner/src/event_bus.rs @@ -0,0 +1,93 @@ +use rocket::futures::channel::mpsc as rocket_mpsc; +use rocket::futures::StreamExt; +use rocket::tokio; + +/// Handles all the raw events being streamed from balancers and parses and filters them into only the events we care about. +pub struct EventBus { + events_rx: rocket_mpsc::Receiver, + events_tx: rocket_mpsc::Sender, + bus_tx: tokio::sync::broadcast::Sender, +} + +impl EventBus { + pub fn new() -> Self { + let (events_tx, events_rx) = rocket_mpsc::channel(100); + let (bus_tx, _) = tokio::sync::broadcast::channel(100); + Self { + events_rx, + events_tx, + bus_tx, + } + } + + pub async fn run(&mut self) { + info!("EventBus started"); + loop { + tokio::select! { + Some(event) = self.events_rx.next() => { + info!("EventBus received: {event}"); + self.handle_event(event); + } + else => { + warn!("EventBus stopped"); + break; + } + } + } + } + + fn handle_event(&self, event: rocket_ws::Message) { + info!("Received event: {}", event); + if self.bus_tx.receiver_count() == 0 { + return; + } + match self.bus_tx.send(event) { + Ok(count) => { + info!("Event sent to {count} subscribers"); + } + Err(err) => { + error!("Error sending event to subscribers: {}", err); + } + } + } + + pub fn subscriber(&self) -> EventBusSubscriber { + EventBusSubscriber::new(self.bus_tx.clone()) + } + + pub fn writer(&self) -> EventBusWriter { + EventBusWriter::new(self.events_tx.clone()) + } +} + +pub type EventBusEvent = rocket_ws::Message; + +/// Enables subscriptions to the event bus +pub struct EventBusSubscriber { + bus_tx: tokio::sync::broadcast::Sender, +} + +/// Enables subscriptions to the event bus +pub struct EventBusWriter { + bus_tx: rocket_mpsc::Sender, +} + +impl EventBusWriter { + pub fn new(bus_tx: rocket_mpsc::Sender) -> Self { + Self { bus_tx } + } + + pub fn write(&self) -> rocket_mpsc::Sender { + self.bus_tx.clone() + } +} + +impl EventBusSubscriber { + pub fn new(bus_tx: tokio::sync::broadcast::Sender) -> Self { + Self { bus_tx } + } + + pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver { + self.bus_tx.subscribe() + } +} diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index 8ad5d14..38507e1 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -3,8 +3,8 @@ use chrono::{NaiveDateTime, Utc}; #[macro_use] extern crate rocket; -use rocket::futures::channel::mpsc::channel; -use rocket::{fairing::AdHoc, trace::error, Rocket, State}; +use event_bus::{EventBusSubscriber, EventBusWriter}; +use rocket::{fairing::AdHoc, futures::SinkExt, trace::error, Rocket, State}; use rocket_db_pools::{ rocket::{ figment::{ @@ -13,7 +13,7 @@ use rocket_db_pools::{ }, form::Form, fs::NamedFile, - launch, Responder, + Responder, }, Connection, }; @@ -25,9 +25,8 @@ use rocket_db_pools::diesel::MysqlPool; use rocket_db_pools::diesel::{deserialize, serialize}; use rocket_db_pools::Database; -use ::ws::Message; use rocket_ws::WebSocket; -use server::{SharedData, Worker}; +use server::Server; use worker::detection::{detect_scanner, get_dns_client, Scanners}; use std::path::PathBuf; @@ -39,6 +38,7 @@ use serde::{Deserialize, Deserializer, Serialize}; use dns_ptr_resolver::{get_ptr, ResolvedResult}; +pub mod event_bus; pub mod models; pub mod schema; pub mod server; @@ -460,55 +460,27 @@ async fn index() -> HtmlContents { } #[get("/ping")] -async fn pong() -> PlainText { +async fn pong(event_bus_writer: &State) -> PlainText { + let mut bus_tx = event_bus_writer.write(); + let _ = bus_tx + .send(rocket_ws::Message::Text("Groumpf!".to_string())) + .await; PlainText("pong".to_string()) } #[get("/ws")] -pub async fn ws(ws: WebSocket, shared: &State) -> rocket_ws::Channel<'static> { - use crate::rocket::futures::StreamExt; - use rocket::tokio; - use rocket_ws as ws; - use std::time::Duration; +pub async fn ws( + ws: WebSocket, + event_bus: &State, + event_bus_writer: &State, +) -> rocket_ws::Channel<'static> { + use rocket::futures::channel::mpsc as rocket_mpsc; - let (tx, mut rx) = channel::(1); - - SharedData::add_worker(tx, &shared.workers); - - SharedData::send_to_all(&shared.workers, "I am new !"); - - let channel = ws.channel(move |mut stream: ws::stream::DuplexStream| { - Box::pin(async move { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - - tokio::spawn(async move { - let mut worker = Worker::initial(&mut stream); - loop { - tokio::select! { - _ = interval.tick() => { - // Send message every X seconds - if let Ok(true) = worker.tick().await { - break; - } - } - Some(message) = rx.next() => { - println!("Received message from other client: {:?}", message); - let _ = worker.send(message).await; - }, - Ok(false) = worker.poll() => { - // Continue the loop - } - else => { - break; - } - } - } - }); - - tokio::signal::ctrl_c().await.unwrap(); - Ok(()) - }) - }); + let (_, ws_receiver) = rocket_mpsc::channel::(1); + let bus_rx = event_bus.subscribe(); + let bus_tx = event_bus_writer.write(); + let channel: rocket_ws::Channel = + ws.channel(|stream| Server::handle(stream, bus_rx, bus_tx, ws_receiver)); channel } @@ -535,8 +507,8 @@ async fn report_counts<'a>(rocket: Rocket) -> Rocket _ { +#[rocket::main] +async fn main() -> Result<(), rocket::Error> { let server_address: SocketAddr = if let Ok(env) = env::var("SERVER_ADDRESS") { env.parse() .expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)") @@ -569,18 +541,30 @@ fn rocket() -> _ { .merge(("port", server_address.port())) .merge(("databases", map!["snow_scanner_db" => db])); - rocket::custom(config_figment) + let mut event_bus = event_bus::EventBus::new(); + let event_subscriber = event_bus.subscriber(); + let event_writer = event_bus.writer(); + + let _ = rocket::custom(config_figment) .attach(SnowDb::init()) .attach(AdHoc::on_ignite("Report counts", report_counts)) .attach(AdHoc::on_shutdown("Close Websockets", |r| { Box::pin(async move { - if let Some(clients) = r.state::() { - SharedData::shutdown_to_all(&clients.workers); + if let Some(writer) = r.state::() { + Server::shutdown_to_all(writer); } }) })) - .manage(SharedData::init()) + .attach(AdHoc::on_liftoff("Run websocket client manager", |_| { + Box::pin(async move { + rocket::tokio::spawn(async move { + event_bus.run().await; + }); + }) + })) .manage(AppConfigs { static_data_dir }) + .manage(event_subscriber) + .manage(event_writer) .mount( "/", routes![ @@ -594,4 +578,7 @@ fn rocket() -> _ { ws, ], ) + .launch() + .await; + Ok(()) } diff --git a/snow-scanner/src/server.rs b/snow-scanner/src/server.rs index 62f0e7c..edad8c8 100644 --- a/snow-scanner/src/server.rs +++ b/snow-scanner/src/server.rs @@ -1,30 +1,105 @@ use cidr::IpCidr; -use hickory_resolver::Name; use rocket::futures::{stream::Next, SinkExt, StreamExt}; use rocket_ws::{frame::CloseFrame, Message}; -use std::{collections::HashMap, net::IpAddr, ops::Deref, str::FromStr, sync::Mutex}; +use std::{pin::Pin, str::FromStr}; -use crate::worker::{ - detection::detect_scanner_from_name, - modules::{Network, WorkerMessages}, +use crate::{ + event_bus::{EventBusEvent, EventBusWriter}, + worker::modules::{Network, WorkerMessages}, }; -use crate::{DbConn, Scanner}; -use rocket::futures::channel::mpsc::Sender; +use rocket::futures::channel::mpsc as rocket_mpsc; -pub type WorkersList = Vec>; +pub struct WsChat {} -pub struct SharedData { - pub workers: Mutex, -} +impl WsChat { + pub async fn work( + mut stream: rocket_ws::stream::DuplexStream, + mut bus_rx: rocket::tokio::sync::broadcast::Receiver, + mut bus_tx: rocket_mpsc::Sender, + mut ws_receiver: rocket_mpsc::Receiver, + ) { + use crate::rocket::futures::StreamExt; + use rocket::tokio; -impl SharedData { - pub fn init() -> SharedData { - SharedData { - workers: Mutex::new(vec![]), + let _ = bus_tx + .send(rocket_ws::Message::Text("I am new !".to_string())) + .await; + //SharedData::send_to_all(&workers, "I am new !"); + let mut worker = Worker::initial(&mut stream); + let mut interval = rocket::tokio::time::interval(std::time::Duration::from_secs(60)); + loop { + tokio::select! { + _ = interval.tick() => { + // Send message every X seconds + if let Ok(true) = worker.tick().await { + break; + } + } + result = bus_rx.recv() => { + let message = match result { + Ok(message) => message, + Err(err) => { + error!("Bus error: {err}"); + continue; + } + }; + if let Err(err) = worker.send(message).await { + error!("Error sending event to Event bus WebSocket: {}", err); + break; + } + } + Some(message) = ws_receiver.next() => { + info!("Received message from other client: {:?}", message); + let _ = worker.send(message).await; + }, + Ok(false) = worker.poll() => { + // Continue the loop + } + else => { + break; + } + } } } +} - pub fn add_worker(tx: Sender, workers: &Mutex) -> () { +pub struct Server {} + +type HandleBox = Pin< + Box> + std::marker::Send>, +>; + +impl Server { + pub fn handle( + stream: rocket_ws::stream::DuplexStream, + bus_rx: rocket::tokio::sync::broadcast::Receiver, + bus_tx: rocket_mpsc::Sender, + ws_receiver: rocket_mpsc::Receiver, + ) -> HandleBox { + use rocket::tokio; + + //SharedData::add_worker(tx.clone(), &shared.workers); + //move |mut stream: ws::stream::DuplexStream| { + Box::pin(async move { + let work_fn = WsChat::work( + stream, + bus_rx, + bus_tx, + ws_receiver, + //workers + ); + tokio::spawn(work_fn); + + tokio::signal::ctrl_c().await.unwrap(); + Ok(()) + }) + } + + pub fn new() -> Server { + Server {} + } + + /*pub fn add_worker(tx: rocket_mpsc::Sender, workers: &Mutex) -> () { let workers_lock = workers.try_lock(); if let Ok(mut workers) = workers_lock { workers.push(tx); @@ -33,33 +108,24 @@ impl SharedData { } else { error!("Unable to lock workers"); } + }*/ + + pub fn shutdown_to_all(server: &EventBusWriter) -> () { + let res = server.write().try_send(Message::Close(Some(CloseFrame { + code: rocket_ws::frame::CloseCode::Away, + reason: "Server stop".into(), + }))); + match res { + Ok(_) => { + info!("Worker did receive stop signal."); + } + Err(err) => { + error!("Send error: {err}"); + } + }; } - pub fn shutdown_to_all(workers: &Mutex) -> () { - let workers_lock = workers.try_lock(); - if let Ok(ref workers) = workers_lock { - workers.iter().for_each(|tx| { - let res = tx.clone().try_send(Message::Close(Some(CloseFrame { - code: rocket_ws::frame::CloseCode::Away, - reason: "Server stop".into(), - }))); - match res { - Ok(_) => { - info!("Worker did receive stop signal."); - } - Err(err) => { - error!("Send error: {err}"); - } - }; - }); - info!("Currently {} workers online.", workers.len()); - std::mem::drop(workers_lock); - } else { - error!("Unable to lock workers"); - } - } - - pub fn send_to_all(workers: &Mutex, message: &str) -> () { + /*pub fn send_to_all(workers: &Mutex, message: &str) -> () { let workers_lock = workers.try_lock(); if let Ok(ref workers) = workers_lock { workers.iter().for_each(|tx| { @@ -78,48 +144,7 @@ impl SharedData { } else { error!("Unable to lock workers"); } - } -} - -pub struct Server { - pub clients: HashMap, - pub new_scanners: HashMap, -} - -impl Server { - pub async fn commit(&mut self, conn: &mut DbConn) -> &Server { - for (name, query_address) in self.new_scanners.clone() { - let scanner_name = Name::from_str(name.as_str()).unwrap(); - - match detect_scanner_from_name(&scanner_name) { - Ok(Some(scanner_type)) => { - match Scanner::find_or_new( - query_address, - scanner_type, - Some(scanner_name), - conn, - ) - .await - { - Ok(scanner) => { - // Got saved - self.new_scanners.remove(&name); - info!( - "Saved {scanner_type}: {name} for {query_address}: {:?}", - scanner.ip_ptr - ); - } - Err(err) => { - error!("Unable to find or new {:?}", err); - } - }; - } - Ok(None) => {} - Err(_) => {} - } - } - self - } + }*/ } pub struct Worker<'a> { @@ -196,7 +221,7 @@ impl<'a> Worker<'a> { }; Ok(false) } - Some(Err(err)) => { + Some(Err(_)) => { info!("Connection closed"); let close_frame = rocket_ws::frame::CloseFrame { code: rocket_ws::frame::CloseCode::Normal,