use chrono::{NaiveDateTime, Utc}; #[macro_use] extern crate rocket; use rocket::futures::channel::mpsc::channel; use rocket::{fairing::AdHoc, trace::error, Rocket, State}; use rocket_db_pools::{ rocket::{ figment::{ util::map, value::{Map, Value}, }, form::Form, fs::NamedFile, launch, Responder, }, Connection, }; use rocket_db_pools::diesel::mysql::{Mysql, MysqlValue}; use rocket_db_pools::diesel::serialize::IsNull; use rocket_db_pools::diesel::sql_types::Text; use rocket_db_pools::diesel::MysqlPool; use rocket_db_pools::diesel::{deserialize, serialize}; use rocket_db_pools::Database; use ::ws::Message; use rocket_ws::WebSocket; use server::{SharedData, Worker}; use worker::detection::{detect_scanner, get_dns_client, Scanners}; use std::path::PathBuf; use std::{env, fmt}; use std::{io::Write, net::SocketAddr}; use uuid::Uuid; use serde::{Deserialize, Deserializer, Serialize}; use dns_ptr_resolver::{get_ptr, ResolvedResult}; pub mod models; pub mod schema; pub mod server; pub mod worker; use crate::models::*; #[derive(Database)] #[database("snow_scanner_db")] pub struct SnowDb(MysqlPool); type DbConn = Connection; trait IsStatic { fn is_static(self: &Self) -> bool; } impl IsStatic for Scanners { fn is_static(self: &Self) -> bool { match self { Scanners::Censys => true, Scanners::InternetMeasurement => true, _ => false, } } } impl<'de> Deserialize<'de> for Scanners { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let s = >::deserialize(deserializer)?; let k: &str = s[0].as_str(); match k { "stretchoid" => Ok(Scanners::Stretchoid), "binaryedge" => Ok(Scanners::Binaryedge), "stretchoid.txt" => Ok(Scanners::Stretchoid), "binaryedge.txt" => Ok(Scanners::Binaryedge), "censys.txt" => Ok(Scanners::Censys), "internet-measurement.com.txt" => Ok(Scanners::InternetMeasurement), v => Err(serde::de::Error::custom(format!( "Unknown value: {}", v.to_string() ))), } } } impl fmt::Display for Scanners { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "{}", match self { Self::Stretchoid => "stretchoid", Self::Binaryedge => "binaryedge", Self::Censys => "censys", Self::InternetMeasurement => "internet-measurement.com", } ) } } impl serialize::ToSql for Scanners { fn to_sql(&self, out: &mut serialize::Output) -> serialize::Result { match *self { Self::Stretchoid => out.write_all(b"stretchoid")?, Self::Binaryedge => out.write_all(b"binaryedge")?, Self::Censys => out.write_all(b"censys")?, Self::InternetMeasurement => out.write_all(b"internet-measurement.com")?, }; Ok(IsNull::No) } } impl deserialize::FromSql for Scanners { fn from_sql(bytes: MysqlValue) -> deserialize::Result { let value = >::from_sql(bytes)?; let value = &value as &str; let value: Result = value.try_into(); match value { Ok(d) => Ok(d), Err(err) => Err(err.into()), } } } impl TryInto for &str { type Error = String; fn try_into(self) -> Result { match self { "stretchoid" => Ok(Scanners::Stretchoid), "binaryedge" => Ok(Scanners::Binaryedge), "internet-measurement.com" => Ok(Scanners::InternetMeasurement), value => Err(format!("Invalid value: {value}")), } } } async fn handle_ip(mut conn: DbConn, ip: String) -> Result> { let query_address = ip.parse().expect("To parse"); let ptr_result: Result = std::thread::spawn(move || { let client = get_dns_client(); let ptr_result: ResolvedResult = if let Ok(res) = get_ptr(query_address, client) { res } else { return Err(()); }; Ok(ptr_result) }) .join() .unwrap(); if ptr_result.is_err() { return Err(None); } let result = ptr_result.unwrap(); match detect_scanner(&result) { Ok(Some(scanner_type)) => { match Scanner::find_or_new(query_address, scanner_type, result.result, &mut conn).await { Ok(scanner) => Ok(scanner), Err(_) => Err(None), } } Ok(None) => Err(None), Err(_) => Err(Some(result)), } } static FORM: &str = r#" Wdes - snow scanner

"#; #[derive(FromForm, Serialize, Deserialize)] pub struct ScanParams<'r> { username: &'r str, ips: &'r str, } #[derive(Responder)] enum MultiReply { #[response(status = 500, content_type = "text")] Error(ServerError), #[response(status = 422)] FormError(PlainText), #[response(status = 404)] NotFound(String), #[response(status = 200)] Content(HtmlContents), #[response(status = 200)] FileContents(NamedFile), } #[post("/scan", data = "
")] async fn handle_scan(mut db: DbConn, form: Form>) -> MultiReply { if form.username.len() < 4 { return MultiReply::FormError(PlainText("Invalid username".to_string())); } let task_group_id: Uuid = Uuid::now_v7(); for ip in form.ips.lines() { let scan_task = ScanTask { task_group_id: task_group_id.to_string(), cidr: ip.to_string(), created_by_username: form.username.to_string(), created_at: Utc::now().naive_utc(), updated_at: None, started_at: None, still_processing_at: None, ended_at: None, }; match scan_task.save(&mut db).await { Ok(_) => error!("Added {}", ip.to_string()), Err(err) => error!("Not added: {:?}", err), } } MultiReply::Content(HtmlContents(format!("New task added: {} !", task_group_id))) } #[derive(FromForm, Serialize, Deserialize)] pub struct ReportParams { ip: String, } #[post("/report", data = "")] async fn handle_report(db: DbConn, form: Form) -> HtmlContents { match handle_ip(db, form.ip.clone()).await { Ok(scanner) => HtmlContents(match scanner.scanner_name { Scanners::Binaryedge => match scanner.last_checked_at { Some(date) => format!( "Reported a binaryedge ninja! {} known as {} since {date}.", scanner.ip, scanner.ip_ptr.unwrap_or("".to_string()) ), None => format!( "Reported a binaryedge ninja! {} known as {}.", scanner.ip, scanner.ip_ptr.unwrap_or("".to_string()) ), }, Scanners::Stretchoid => match scanner.last_checked_at { Some(date) => format!( "Reported a stretchoid agent! {} known as {} since {date}.", scanner.ip, scanner.ip_ptr.unwrap_or("".to_string()) ), None => format!( "Reported a stretchoid agent! {} known as {}.", scanner.ip, scanner.ip_ptr.unwrap_or("".to_string()) ), }, _ => format!("Not supported"), }), Err(ptr_result) => HtmlContents(format!( "The IP {} resolved as {:?} did not match known scanners patterns.", form.ip, match ptr_result { Some(res) => res.result, None => None, } )), } } struct SecurePath { pub data: String, } impl TryInto for &str { type Error = String; fn try_into(self) -> Result { // A-Z a-z 0-9 // . - _ if self .chars() .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') { return Ok(SecurePath { data: self.to_string(), }); } Err(format!("Invalid value: {}", self.to_string())) } } #[get("/collections//")] async fn handle_get_collection( vendor_name: &str, file_name: &str, app_configs: &State, ) -> MultiReply { let vendor_name: Result = vendor_name.try_into(); let vendor_name = match vendor_name { Ok(secure_path) => secure_path.data, Err(err) => return MultiReply::FormError(PlainText(err.to_string())), }; let file_name: Result = file_name.try_into(); let file_name = match file_name { Ok(secure_path) => secure_path.data, Err(err) => return MultiReply::FormError(PlainText(err.to_string())), }; let mut path: PathBuf = PathBuf::new(); let static_data_dir: String = app_configs.static_data_dir.clone(); path.push(static_data_dir); path.push("collections"); path.push(vendor_name); path.push(file_name); match NamedFile::open(path).await { Ok(file) => MultiReply::FileContents(file), Err(err) => MultiReply::NotFound(err.to_string()), } } #[get("/scanners/")] async fn handle_list_scanners( mut db: DbConn, scanner_name: &str, app_configs: &State, ) -> MultiReply { let static_data_dir: String = app_configs.static_data_dir.clone(); let scanner_name: Result = scanner_name.try_into(); let scanner_name = match scanner_name { Ok(scanner_name) => scanner_name, Err(err) => return MultiReply::FormError(PlainText(err.to_string())), }; if scanner_name.is_static() { let mut path: PathBuf = PathBuf::new(); path.push(static_data_dir); path.push("scanners"); path.push(scanner_name.to_string()); return match NamedFile::open(path).await { Ok(file) => MultiReply::FileContents(file), Err(err) => MultiReply::NotFound(err.to_string()), }; } let scanners_list = match Scanner::list_names(scanner_name, &mut db).await { Ok(data) => Ok(data), Err(err) => Err(err), }; if let Ok(scanners) = scanners_list { MultiReply::Content(HtmlContents(scanners.join("\n"))) } else { MultiReply::Error(ServerError("Unable to list scanners".to_string())) } } static SCAN_TASKS_HEAD: &str = r#" Wdes - snow scanner | Scan tasks "#; static SCAN_TASKS_FOOT: &str = r#"
CIDR Started at Still processing at Ended at
"#; #[get("/scan/tasks")] async fn handle_list_scan_tasks(mut db: Connection) -> MultiReply { let mut html_data: Vec = vec![SCAN_TASKS_HEAD.to_string()]; let scan_tasks_list = match ScanTask::list(&mut db).await { Ok(data) => Ok(data), Err(err) => Err(err), }; if let Ok(scan_tasks) = scan_tasks_list { for row in scan_tasks { let cidr: String = row.cidr; let started_at: Option = row.started_at; let still_processing_at: Option = row.still_processing_at; let ended_at: Option = row.ended_at; html_data.push(format!( " {cidr} {:#?} {:#?} {:#?} ", started_at, still_processing_at, ended_at )); } html_data.push(SCAN_TASKS_FOOT.to_string()); MultiReply::Content(HtmlContents(html_data.join("\n"))) } else { return MultiReply::Error(ServerError("Unable to list scan tasks".to_string())); } } #[derive(Responder)] #[response(status = 200, content_type = "text")] pub struct PlainText(String); #[derive(Responder)] #[response(status = 200, content_type = "html")] pub struct HtmlContents(String); #[derive(Responder)] #[response(status = 500, content_type = "html")] pub struct ServerError(String); #[get("/")] async fn index() -> HtmlContents { HtmlContents(FORM.to_string()) } #[get("/ping")] async fn pong() -> PlainText { PlainText("pong".to_string()) } #[get("/ws")] pub async fn ws(ws: WebSocket, shared: &State) -> rocket_ws::Channel<'static> { use crate::rocket::futures::StreamExt; use rocket::tokio; use rocket_ws as ws; use std::time::Duration; let (tx, mut rx) = channel::(1); SharedData::add_worker(tx, &shared.workers); SharedData::send_to_all(&shared.workers, "I am new !"); let channel = ws.channel(move |mut stream: ws::stream::DuplexStream| { Box::pin(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); tokio::spawn(async move { let mut worker = Worker::initial(&mut stream); loop { tokio::select! { _ = interval.tick() => { // Send message every X seconds if let Ok(true) = worker.tick().await { break; } } Some(message) = rx.next() => { println!("Received message from other client: {:?}", message); let _ = worker.send(message).await; }, Ok(false) = worker.poll() => { // Continue the loop } else => { break; } } } }); tokio::signal::ctrl_c().await.unwrap(); Ok(()) }) }); channel } struct AppConfigs { static_data_dir: String, } async fn report_counts<'a>(rocket: Rocket) -> Rocket { use rocket_db_pools::diesel::AsyncConnectionWrapper; let conn = SnowDb::fetch(&rocket) .expect("database is attached") .get() .await .unwrap_or_else(|e| { span_error!("failed to connect to MySQL database" => error!("{e}")); panic!("aborting launch"); }); let _: AsyncConnectionWrapper<_> = conn.into(); info!("Connected to the DB"); rocket } #[launch] fn rocket() -> _ { let server_address: SocketAddr = if let Ok(env) = env::var("SERVER_ADDRESS") { env.parse() .expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)") } else { "127.0.0.1:8000" .parse() .expect("The default address should be valid") }; let static_data_dir: String = match env::var("STATIC_DATA_DIR") { Ok(val) => val, Err(_) => "../data/".to_string(), }; let db_url: String = if let Ok(env) = env::var("DB_URL") { env } else { error!("Missing ENV: DB_URL"); "mysql://localhost".to_string() }; let db: Map<_, Value> = map! { "url" => db_url.into(), "pool_size" => 10.into(), "timeout" => 5.into(), }; let config_figment = rocket::Config::figment() .merge(("address", server_address.ip().to_string())) .merge(("port", server_address.port())) .merge(("databases", map!["snow_scanner_db" => db])); rocket::custom(config_figment) .attach(SnowDb::init()) .attach(AdHoc::on_ignite("Report counts", report_counts)) .attach(AdHoc::on_shutdown("Close Websockets", |r| { Box::pin(async move { if let Some(clients) = r.state::() { SharedData::shutdown_to_all(&clients.workers); } }) })) .manage(SharedData::init()) .manage(AppConfigs { static_data_dir }) .mount( "/", routes![ index, pong, handle_report, handle_scan, handle_list_scan_tasks, handle_list_scanners, handle_get_collection, ws, ], ) }