diff --git a/snow-scanner/Cargo.toml b/snow-scanner/Cargo.toml index ce1dbf8..b219dd1 100644 --- a/snow-scanner/Cargo.toml +++ b/snow-scanner/Cargo.toml @@ -30,13 +30,15 @@ path = "src/main.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rouille = "3.6.2" +actix-web = "4" +actix-files = "0.6.6" hmac = "0.12.1" sha2 = "0.10.8" hex = "0.4.3" -diesel = { version = "2.2.0", default-features = false, features = ["mysql", "chrono", "uuid"] } +diesel = { version = "2.2.0", default-features = false, features = ["mysql", "chrono", "uuid", "r2d2"] } dns-ptr-resolver = "1.2.0" hickory-client = { version = "0.24.1", default-features = false } chrono = "0.4.38" uuid = { version = "1.10.0", default-features = false, features = ["v7", "serde", "std"] } cidr = "0.2.2" +serde = "1.0.210" diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index b760079..11663ee 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -1,25 +1,30 @@ -#![feature(trivial_bounds)] -#[macro_use] -extern crate rouille; +use actix_files::NamedFile; +use actix_web::error::ErrorInternalServerError; +use actix_web::http::header::ContentType; +use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer}; use chrono::{NaiveDateTime, Utc}; use diesel::deserialize::{self, FromSqlRow}; use diesel::mysql::{Mysql, MysqlValue}; use diesel::sql_types::Text; -use hmac::{Hmac, Mac}; -use rouille::{Request, Response, ResponseBody}; -use sha2::Sha256; + +use diesel::r2d2::ConnectionManager; +use diesel::r2d2::Pool; + use std::io::Write; +use std::path::PathBuf; use std::str::FromStr; -use std::{env, fmt, thread}; +use std::{env, fmt}; use uuid::Uuid; +use serde::{Deserialize, Deserializer, Serialize}; + use hickory_client::client::SyncClient; use hickory_client::rr::Name; use hickory_client::tcp::TcpClientConnection; use diesel::serialize::IsNull; -use diesel::{serialize, Connection, MysqlConnection}; +use diesel::{serialize, MysqlConnection}; use dns_ptr_resolver::{get_ptr, ResolvedResult}; pub mod models; @@ -27,8 +32,11 @@ pub mod schema; use crate::models::*; +/// Short-hand for the database pool type to use throughout the app. +type DbPool = Pool>; + // Create alias for HMAC-SHA256 -type HmacSha256 = Hmac; +// type HmacSha256 = Hmac; #[derive(Debug, Clone, Copy, FromSqlRow)] pub enum Scanners { @@ -52,20 +60,24 @@ impl IsStatic for Scanners { } } -#[derive(Debug, PartialEq, Eq)] -struct ParseScannerError; - -impl FromStr for Scanners { - type Err = ParseScannerError; - fn from_str(input: &str) -> Result { - match input { +impl<'de> Deserialize<'de> for Scanners { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + let k: &str = s[0].as_str(); + match k { "stretchoid" => Ok(Scanners::Stretchoid), "binaryedge" => Ok(Scanners::Binaryedge), "stretchoid.txt" => Ok(Scanners::Stretchoid), "binaryedge.txt" => Ok(Scanners::Binaryedge), "censys.txt" => Ok(Scanners::Censys), "internet-measurement.com.txt" => Ok(Scanners::InternetMeasurement), - _ => Err(ParseScannerError {}), + v => Err(serde::de::Error::custom(format!( + "Unknown value: {}", + v.to_string() + ))), } } } @@ -128,7 +140,7 @@ fn detect_scanner(ptr_result: &ResolvedResult) -> Result { } } -fn handle_ip(conn: &mut MysqlConnection, ip: String) -> Result> { +async fn handle_ip(pool: web::Data, ip: String) -> Result> { let query_address = ip.parse().expect("To parse"); let client = get_dns_client(); @@ -155,10 +167,18 @@ fn handle_ip(conn: &mut MysqlConnection, ip: String) -> Result Ok(scanner), - Err(_) => Err(None), - } + // use web::block to offload blocking Diesel queries without blocking server thread + web::block(move || { + // note that obtaining a connection from the pool is also potentially blocking + let conn = &mut pool.get().unwrap(); + + match scanner.save(conn) { + Ok(scanner) => Ok(scanner), + Err(_) => Err(None), + } + }) + .await + .unwrap() } Err(_) => Err(Some(ptr_result)), @@ -188,50 +208,55 @@ static FORM: &str = r#" "#; -fn handle_scan(conn: &mut MysqlConnection, request: &Request) -> Response { - let data = try_or_400!(post_input!(request, { - username: String, - ips: String, - })); +#[derive(Serialize, Deserialize)] +pub struct ScanParams { + username: String, + ips: String, +} - if data.username.len() < 4 { - return Response { - status_code: 422, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string("Invalid username"), - upgrade: None, - }; +async fn handle_scan(pool: web::Data, params: web::Form) -> HttpResponse { + if params.username.len() < 4 { + return plain_contents("Invalid username".to_string()); } let task_group_id: Uuid = Uuid::now_v7(); - for ip in data.ips.lines() { - let scan_task = ScanTask { - task_group_id: task_group_id, - cidr: ip.to_string(), - created_by_username: data.username.clone(), - created_at: Utc::now().naive_utc(), - updated_at: None, - started_at: None, - still_processing_at: None, - ended_at: None, - }; - match scan_task.save(conn) { - Ok(_) => println!("Added {}", ip.to_string()), - Err(err) => eprintln!("Not added: {:?}", err), + // use web::block to offload blocking Diesel queries without blocking server thread + let _ = web::block(move || { + // note that obtaining a connection from the pool is also potentially blocking + let conn = &mut pool.get().unwrap(); + for ip in params.ips.lines() { + let scan_task = ScanTask { + task_group_id: task_group_id.to_string(), + cidr: ip.to_string(), + created_by_username: params.username.clone(), + created_at: Utc::now().naive_utc(), + updated_at: None, + started_at: None, + still_processing_at: None, + ended_at: None, + }; + match scan_task.save(conn) { + Ok(_) => println!("Added {}", ip.to_string()), + Err(err) => eprintln!("Not added: {:?}", err), + } } - } + }) + .await + // map diesel query errors to a 500 error response + .map_err(|err| ErrorInternalServerError(err)); - rouille::Response::html(format!("New task added: {} !", task_group_id)) + html_contents(format!("New task added: {} !", task_group_id)) } -fn handle_report(conn: &mut MysqlConnection, request: &Request) -> Response { - let data = try_or_400!(post_input!(request, { - ip: String, - })); +#[derive(Serialize, Deserialize)] +pub struct ReportParams { + ip: String, +} - match handle_ip(conn, data.ip.clone()) { - Ok(scanner) => rouille::Response::html(match scanner.scanner_name { +async fn handle_report(pool: web::Data, params: web::Form) -> HttpResponse { + match handle_ip(pool, params.ip.clone()).await { + Ok(scanner) => html_contents(match scanner.scanner_name { Scanners::Binaryedge => format!( "Reported an escaped ninja! {} known as {:?}.", scanner.ip, scanner.ip_ptr @@ -243,9 +268,9 @@ fn handle_report(conn: &mut MysqlConnection, request: &Request) -> Response { _ => format!("Not supported"), }), - Err(ptr_result) => rouille::Response::html(format!( + Err(ptr_result) => html_contents(format!( "The IP {} resolved as {:?} did not match known scanners patterns.", - data.ip, + params.ip, match ptr_result { Some(res) => res.result, None => None, @@ -254,63 +279,65 @@ fn handle_report(conn: &mut MysqlConnection, request: &Request) -> Response { } } -fn handle_get_collection(request: &Request, static_data_dir: &str) -> Response { - // The `match_assets` function tries to find a file whose name corresponds to the URL - // of the request. The second parameter (`"."`) tells where the files to look for are - // located. - // In order to avoid potential security threats, `match_assets` will never return any - // file outside of this directory even if the URL is for example `/../../foo.txt`. - let response = rouille::match_assets(&request, static_data_dir); +async fn handle_get_collection( + path: web::Path<(String, String)>, + req: HttpRequest, + static_data_dir: actix_web::web::Data, +) -> actix_web::Result { + let (vendor_name, file_name) = path.into_inner(); - if response.is_success() { - return response; + let mut path: PathBuf = PathBuf::new(); + let static_data_dir: String = static_data_dir.into_inner().to_string(); + path.push(static_data_dir); + path.push(vendor_name.to_string()); + path.push(file_name.to_string()); + match NamedFile::open(path) { + Ok(file) => Ok(file.into_response(&req)), + Err(err) => Ok(HttpResponse::NotFound() + .content_type(ContentType::plaintext()) + .body(format!("File not found: {}.\n", err))), } - return Response { - status_code: 404, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string("File not found.\n"), - upgrade: None, - }; } -fn handle_list_scanners( - conn: &mut MysqlConnection, - static_data_dir: &str, - scanner_name: Scanners, - request: &Request, -) -> Response { +async fn handle_list_scanners( + pool: web::Data, + path: web::Path, + req: HttpRequest, + static_data_dir: actix_web::web::Data, +) -> actix_web::Result { + let scanner_name = path.into_inner(); + let static_data_dir: String = static_data_dir.into_inner().to_string(); if scanner_name.is_static() { - // The `match_assets` function tries to find a file whose name corresponds to the URL - // of the request. The second parameter (`"."`) tells where the files to look for are - // located. - // In order to avoid potential security threats, `match_assets` will never return any - // file outside of this directory even if the URL is for example `/../../foo.txt`. - let response = rouille::match_assets(&request, static_data_dir); + let mut path: PathBuf = PathBuf::new(); + path.push(static_data_dir); + path.push(scanner_name.to_string()); - if response.is_success() { - return response; - } - return Response { - status_code: 404, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string("File not found.\n"), - upgrade: None, + return match NamedFile::open(path) { + Ok(file) => Ok(file.into_response(&req)), + Err(err) => Ok(HttpResponse::NotFound() + .content_type(ContentType::plaintext()) + .body(format!("File not found: {}.\n", err))), }; } - if let Ok(scanners) = Scanner::list_names(scanner_name, conn) { - Response { - status_code: 200, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string(scanners.join("\n")), - upgrade: None, + + // use web::block to offload blocking Diesel queries without blocking server thread + let scanners_list = web::block(move || { + // note that obtaining a connection from the pool is also potentially blocking + let conn = &mut pool.get().unwrap(); + match Scanner::list_names(scanner_name, conn) { + Ok(data) => Ok(data), + Err(err) => Err(err), } + }) + .await + // map diesel query errors to a 500 error response + .map_err(|err| ErrorInternalServerError(err)) + .unwrap(); + + if let Ok(scanners) = scanners_list { + Ok(html_contents(scanners.join("\n"))) } else { - Response { - status_code: 500, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string("Unable to list scanners"), - upgrade: None, - } + Ok(server_error("Unable to list scanners".to_string())) } } @@ -339,10 +366,23 @@ static SCAN_TASKS_FOOT: &str = r#" "#; -fn handle_list_scan_tasks(conn: &mut MysqlConnection) -> Response { +async fn handle_list_scan_tasks(pool: web::Data) -> HttpResponse { let mut html_data: Vec = vec![SCAN_TASKS_HEAD.to_string()]; - if let Ok(scan_tasks) = ScanTask::list(conn) { + // use web::block to offload blocking Diesel queries without blocking server thread + let scan_tasks_list = web::block(move || { + // note that obtaining a connection from the pool is also potentially blocking + let conn = &mut pool.get().unwrap(); + match ScanTask::list(conn) { + Ok(data) => Ok(data), + Err(err) => Err(err), + } + }) + .await + // map diesel query errors to a 500 error response + .map_err(|err| ErrorInternalServerError(err)); + + if let Ok(scan_tasks) = scan_tasks_list.unwrap() { for row in scan_tasks { let cidr: String = row.cidr; let started_at: Option = row.started_at; @@ -350,38 +390,34 @@ fn handle_list_scan_tasks(conn: &mut MysqlConnection) -> Response { let ended_at: Option = row.ended_at; html_data.push(format!( " - - {cidr} - {:#?} - {:#?} - {:#?} - - ", + + {cidr} + {:#?} + {:#?} + {:#?} + + ", started_at, still_processing_at, ended_at )); } html_data.push(SCAN_TASKS_FOOT.to_string()); - Response { - status_code: 200, - headers: vec![("Content-Type".into(), "text/html; charset=utf-8".into())], - data: ResponseBody::from_string(html_data.join("\n")), - upgrade: None, - } + html_contents(html_data.join("\n")) } else { - Response { - status_code: 500, - headers: vec![("Content-Type".into(), "text/plain; charset=utf-8".into())], - data: ResponseBody::from_string("Unable to list scan tasks"), - upgrade: None, - } + return server_error("Unable to list scan tasks".to_string()); } } -fn get_connection(database_url: &str) -> MysqlConnection { - MysqlConnection::establish(&database_url) - .unwrap_or_else(|_| panic!("Error connecting to {}", database_url)) +fn get_connection(database_url: &str) -> DbPool { + let manager = ConnectionManager::::new(database_url); + // Refer to the `r2d2` documentation for more methods to use + // when building a connection pool + Pool::builder() + .max_size(30) + .test_on_check_out(true) + .build(manager) + .expect("Could not build connection pool") } fn get_dns_client() -> SyncClient { @@ -391,28 +427,122 @@ fn get_dns_client() -> SyncClient { SyncClient::new(dns_conn) } -fn main() -> Result<(), ()> { +fn plain_contents(data: String) -> HttpResponse { + HttpResponse::Ok() + .content_type(ContentType::plaintext()) + .body(data) +} + +fn html_contents(data: String) -> HttpResponse { + HttpResponse::Ok() + .content_type(ContentType::html()) + .body(data) +} + +fn server_error(data: String) -> HttpResponse { + HttpResponse::InternalServerError() + .content_type(ContentType::html()) + .body(data) +} + +async fn index() -> HttpResponse { + html_contents(FORM.to_string()) +} + +async fn pong() -> HttpResponse { + plain_contents("pong".to_string()) +} + +#[actix_web::main] +async fn main() -> std::io::Result<()> { let server_address: String = if let Ok(env) = env::var("SERVER_ADDRESS") { env } else { "localhost:8000".to_string() }; - println!("Now listening on {}", server_address); - let db_url: String = if let Ok(env) = env::var("DB_URL") { env } else { - "./snow-scanner.sqlite".to_string() + eprintln!("Missing ENV: DB_URL"); + "mysql://localhost".to_string() }; - let static_data_dir: String = match env::var("STATIC_DATA_DIR") { - Ok(val) => val, - Err(_) => "../data/".to_string(), - }; + let server = HttpServer::new(move || { + let static_data_dir: String = match env::var("STATIC_DATA_DIR") { + Ok(val) => val, + Err(_) => "../data/".to_string(), + }; + + let pool = get_connection(db_url.as_str()); + App::new() + .app_data(web::Data::new(pool.clone())) + .app_data(actix_web::web::Data::new(static_data_dir)) + .route("/", web::get().to(index)) + .route("/ping", web::get().to(pong)) + .route("/report", web::post().to(handle_report)) + .route("/scan", web::post().to(handle_scan)) + .route("/scan/tasks", web::get().to(handle_list_scan_tasks)) + .route( + "/scanners/{scanner_name}", + web::get().to(handle_list_scanners), + ) + .route( + "/collections/{vendor_name}/{file_name}", + web::get().to(handle_get_collection), + ) + }) + .bind(&server_address); + match server { + Ok(server) => { + println!("Now listening on {}", server_address); + server.run().await + } + Err(err) => { + eprintln!("Could not bind the server to {}", server_address); + Err(err) + } + } +} +/* +(POST) (/register) => { + let data = try_or_400!(post_input!(request, { + email: String, + })); + + // We just print what was received on stdout. Of course in a real application + // you probably want to process the data, eg. store it in a database. + println!("Received data: {:?}", data); + + + let mut mac = HmacSha256::new_from_slice(b"my secret and secure key") + .expect("HMAC can take key of any size"); + mac.update(data.email.as_bytes()); + + // `result` has type `CtOutput` which is a thin wrapper around array of + // bytes for providing constant time equality check + let result = mac.finalize(); + // To get underlying array use `into_bytes`, but be careful, since + // incorrect use of the code value may permit timing attacks which defeats + // the security provided by the `CtOutput` + let code_bytes = result.into_bytes(); + rouille::Response::html(format!("Success! {}.", hex::encode(code_bytes))) +}, + +(GET) (/{api_key: String}/scanners/{scanner_name: String}) => { + let mut mac = HmacSha256::new_from_slice(b"my secret and secure key") + .expect("HMAC can take key of any size"); + + mac.update(b"williamdes@wdes.fr"); + + println!("{}", api_key); + let hex_key = hex::decode(&api_key).unwrap(); + // `verify_slice` will return `Ok(())` if code is correct, `Err(MacError)` otherwise + mac.verify_slice(&hex_key).unwrap(); + + rouille::Response::empty_404() +}, - let conn = &mut get_connection(db_url.as_str()); - /* thread::spawn(move || { let conn = &mut get_connection(db_url.as_str()); // Reset scan tasks @@ -470,69 +600,3 @@ fn main() -> Result<(), ()> { thread::sleep(two_hundred_millis); } });*/ - - rouille::start_server(server_address, move |request| { - router!(request, - (GET) (/) => { - rouille::Response::html(FORM) - }, - - (GET) (/ping) => { - rouille::Response::text("pong") - }, - - (POST) (/report) => {handle_report(conn, &request)}, - (POST) (/scan) => {handle_scan(conn, &request)}, - (GET) (/scan/tasks) => { - handle_list_scan_tasks(conn) - }, - - (POST) (/register) => { - let data = try_or_400!(post_input!(request, { - email: String, - })); - - // We just print what was received on stdout. Of course in a real application - // you probably want to process the data, eg. store it in a database. - println!("Received data: {:?}", data); - - - let mut mac = HmacSha256::new_from_slice(b"my secret and secure key") - .expect("HMAC can take key of any size"); - mac.update(data.email.as_bytes()); - - // `result` has type `CtOutput` which is a thin wrapper around array of - // bytes for providing constant time equality check - let result = mac.finalize(); - // To get underlying array use `into_bytes`, but be careful, since - // incorrect use of the code value may permit timing attacks which defeats - // the security provided by the `CtOutput` - let code_bytes = result.into_bytes(); - rouille::Response::html(format!("Success! {}.", hex::encode(code_bytes))) - }, - - (GET) (/scanners/{scanner_name: Scanners}) => { - handle_list_scanners(conn, &static_data_dir, scanner_name, &request) - }, - (GET) (/collections/{vendor_name: String}/{file_name: String}) => { - handle_get_collection(&request, &static_data_dir) - }, - (GET) (/{api_key: String}/scanners/{scanner_name: String}) => { - let mut mac = HmacSha256::new_from_slice(b"my secret and secure key") - .expect("HMAC can take key of any size"); - - mac.update(b"williamdes@wdes.fr"); - - println!("{}", api_key); - let hex_key = hex::decode(&api_key).unwrap(); - // `verify_slice` will return `Ok(())` if code is correct, `Err(MacError)` otherwise - mac.verify_slice(&hex_key).unwrap(); - - rouille::Response::empty_404() - }, - // The code block is called if none of the other blocks matches the request. - // We return an empty response with a 404 status code. - _ => rouille::Response::empty_404() - ) - }); -} diff --git a/snow-scanner/src/models.rs b/snow-scanner/src/models.rs index 8f3b2c0..3e911b9 100644 --- a/snow-scanner/src/models.rs +++ b/snow-scanner/src/models.rs @@ -1,12 +1,8 @@ -use std::str::FromStr; - use crate::Scanners; use chrono::NaiveDateTime; -use diesel::deserialize::FromSqlRow; use diesel::dsl::insert_into; use diesel::prelude::*; use diesel::result::Error as DieselError; -use uuid::Uuid; use crate::schema::scan_tasks::dsl::scan_tasks; use crate::schema::scanners::dsl::scanners; @@ -88,7 +84,7 @@ impl NewScanner { #[diesel(table_name = crate::schema::scan_tasks)] #[diesel(check_for_backend(diesel::mysql::Mysql))] pub struct ScanTask { - pub task_group_id: uuid::Uuid, + pub task_group_id: String, pub cidr: String, pub created_by_username: String, pub created_at: NaiveDateTime,