diff --git a/snow-scanner/Cargo.toml b/snow-scanner/Cargo.toml index 1a03ead..e778891 100644 --- a/snow-scanner/Cargo.toml +++ b/snow-scanner/Cargo.toml @@ -38,17 +38,19 @@ members = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ws2 = "0.2.5" -log2 = "0.1.11" -actix-web = "4" -actix-files = "0.6.6" +rocket = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d"} +# contrib/db_pools +rocket_db_pools = { git = "https://github.com/rwf2/Rocket/", rev = "3bf9ef02d6e803fe9f753777f5a829dda6d2453d", default-features = false, features = ["diesel_mysql"] } # mariadb-dev on Alpine # "mysqlclient-src" "mysql_backend" -diesel = { version = "2.2.0", default-features = false, features = ["mysql", "chrono", "uuid", "r2d2"] } +diesel = { version = "^2", default-features = false, features = ["mysql", "chrono", "uuid"] } + +ws = { package = "rocket_ws", version = "0.1.1" } + dns-ptr-resolver = {git = "https://github.com/wdes/dns-ptr-resolver.git"} hickory-resolver = { version = "0.24.1", default-features = false, features = ["tokio-runtime", "dns-over-h3", "dns-over-https", "dns-over-quic"]} chrono = "0.4.38" uuid = { version = "1.10.0", default-features = false, features = ["v7", "serde", "std"] } -cidr = "0.2.2" -serde = "1.0.210" +cidr = "0.3.0" +serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.128" diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index e7c6d54..ea8ac84 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -1,28 +1,38 @@ -use actix_files::NamedFile; -use actix_web::error::ErrorInternalServerError; -use actix_web::http::header::ContentType; -use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer}; -use log2::*; - use chrono::{NaiveDateTime, Utc}; -use diesel::deserialize::{self}; -use diesel::mysql::{Mysql, MysqlValue}; -use diesel::sql_types::Text; -use diesel::r2d2::ConnectionManager; -use diesel::r2d2::Pool; +#[macro_use] +extern crate rocket; + +use rocket::{fairing::AdHoc, Build, Rocket, State}; +use rocket_db_pools::{ + rocket::{ + figment::{ + util::map, + value::{Map, Value}, + }, + form::Form, + fs::NamedFile, + launch, Responder, + }, + Connection, +}; + +use rocket_db_pools::diesel::mysql::{Mysql, MysqlValue}; +use rocket_db_pools::diesel::serialize::IsNull; +use rocket_db_pools::diesel::sql_types::Text; +use rocket_db_pools::diesel::{deserialize, serialize}; +use rocket_db_pools::diesel::MysqlPool; +use rocket_db_pools::Database; + use worker::detection::{detect_scanner, get_dns_client, Scanners}; -use std::collections::HashMap; -use std::io::Write; +use std::{io::Write, net::SocketAddr}; use std::path::PathBuf; use std::{env, fmt}; use uuid::Uuid; use serde::{Deserialize, Deserializer, Serialize}; -use diesel::serialize::IsNull; -use diesel::{serialize, MysqlConnection}; use dns_ptr_resolver::{get_ptr, ResolvedResult}; pub mod models; @@ -31,10 +41,12 @@ pub mod server; pub mod worker; use crate::models::*; -use crate::server::Server; -/// Short-hand for the database pool type to use throughout the app. -type DbPool = Pool>; +#[derive(Database)] +#[database("snow_scanner_db")] +pub struct SnowDb(MysqlPool); + +type DbConn = Connection; trait IsStatic { fn is_static(self: &Self) -> bool; @@ -103,16 +115,29 @@ impl serialize::ToSql for Scanners { impl deserialize::FromSql for Scanners { fn from_sql(bytes: MysqlValue) -> deserialize::Result { let value = >::from_sql(bytes)?; - match &value as &str { - "stretchoid" => Ok(Scanners::Stretchoid), - "binaryedge" => Ok(Scanners::Binaryedge), - "internet-measurement.com" => Ok(Scanners::InternetMeasurement), - _ => Err("Unrecognized enum variant".into()), + let value = &value as &str; + let value: Result = value.try_into(); + match value { + Ok(d) => Ok(d), + Err(err) => Err(err.into()), } } } -async fn handle_ip(pool: web::Data, ip: String) -> Result> { +impl TryInto for &str { + type Error = String; + + fn try_into(self) -> Result { + match self { + "stretchoid" => Ok(Scanners::Stretchoid), + "binaryedge" => Ok(Scanners::Binaryedge), + "internet-measurement.com" => Ok(Scanners::InternetMeasurement), + value => Err(format!("Invalid value: {value}")), + } + } +} + +async fn handle_ip(mut conn: DbConn, ip: String) -> Result> { let query_address = ip.parse().expect("To parse"); let ptr_result: Result = std::thread::spawn(move || { @@ -135,17 +160,11 @@ async fn handle_ip(pool: web::Data, ip: String) -> Result { - // use web::block to offload blocking Diesel queries without blocking server thread - web::block(move || { - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - match Scanner::find_or_new(query_address, scanner_type, result.result, conn) { - Ok(scanner) => Ok(scanner), - Err(_) => Err(None), - } - }) - .await - .unwrap() + match Scanner::find_or_new(query_address, scanner_type, result.result, &mut conn).await + { + Ok(scanner) => Ok(scanner), + Err(_) => Err(None), + } } Ok(None) => Err(None), @@ -176,55 +195,63 @@ static FORM: &str = r#" "#; -#[derive(Serialize, Deserialize)] -pub struct ScanParams { - username: String, - ips: String, +#[derive(FromForm, Serialize, Deserialize)] +pub struct ScanParams<'r> { + username: &'r str, + ips: &'r str, } -async fn handle_scan(pool: web::Data, params: web::Form) -> HttpResponse { - if params.username.len() < 4 { - return plain_contents("Invalid username".to_string()); +#[derive(Responder)] +enum MultiReply { + #[response(status = 500, content_type = "text")] + Error(ServerError), + #[response(status = 422)] + FormError(PlainText), + #[response(status = 404)] + NotFound(String), + #[response(status = 200)] + Content(HtmlContents), + #[response(status = 200)] + FileContents(NamedFile), +} + +#[post("/scan", data = "
")] +async fn handle_scan(mut db: DbConn, form: Form>) -> MultiReply { + if form.username.len() < 4 { + return MultiReply::FormError(PlainText("Invalid username".to_string())); } let task_group_id: Uuid = Uuid::now_v7(); - // use web::block to offload blocking Diesel queries without blocking server thread - let _ = web::block(move || { - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - for ip in params.ips.lines() { - let scan_task = ScanTask { - task_group_id: task_group_id.to_string(), - cidr: ip.to_string(), - created_by_username: params.username.clone(), - created_at: Utc::now().naive_utc(), - updated_at: None, - started_at: None, - still_processing_at: None, - ended_at: None, - }; - match scan_task.save(conn) { - Ok(_) => error!("Added {}", ip.to_string()), - Err(err) => error!("Not added: {:?}", err), - } + for ip in form.ips.lines() { + let scan_task = ScanTask { + task_group_id: task_group_id.to_string(), + cidr: ip.to_string(), + created_by_username: form.username.to_string(), + created_at: Utc::now().naive_utc(), + updated_at: None, + started_at: None, + still_processing_at: None, + ended_at: None, + }; + match scan_task.save(&mut db).await { + Ok(_) => error!("Added {}", ip.to_string()), + Err(err) => error!("Not added: {:?}", err), } - }) - .await - // map diesel query errors to a 500 error response - .map_err(|err| ErrorInternalServerError(err)); + } - html_contents(format!("New task added: {} !", task_group_id)) + MultiReply::Content(HtmlContents(format!("New task added: {} !", task_group_id))) } -#[derive(Serialize, Deserialize)] +#[derive(FromForm, Serialize, Deserialize)] pub struct ReportParams { ip: String, } -async fn handle_report(pool: web::Data, params: web::Form) -> HttpResponse { - match handle_ip(pool, params.ip.clone()).await { - Ok(scanner) => html_contents(match scanner.scanner_name { +#[post("/report", data = "")] +async fn handle_report(db: DbConn, form: Form) -> HtmlContents { + match handle_ip(db, form.ip.clone()).await { + Ok(scanner) => HtmlContents(match scanner.scanner_name { Scanners::Binaryedge => match scanner.last_checked_at { Some(date) => format!( "Reported a binaryedge ninja! {} known as {} since {date}.", @@ -252,9 +279,9 @@ async fn handle_report(pool: web::Data, params: web::Form) _ => format!("Not supported"), }), - Err(ptr_result) => html_contents(format!( + Err(ptr_result) => HtmlContents(format!( "The IP {} resolved as {:?} did not match known scanners patterns.", - params.ip, + form.ip, match ptr_result { Some(res) => res.result, None => None, @@ -267,87 +294,86 @@ struct SecurePath { pub data: String, } -impl<'de> Deserialize<'de> for SecurePath { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = ::deserialize(deserializer)?; +impl TryInto for &str { + type Error = String; + + fn try_into(self) -> Result { // A-Z a-z 0-9 // . - _ - if s.chars() + if self + .chars() .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') { - return Ok(SecurePath { data: s }); + return Ok(SecurePath { + data: self.to_string(), + }); } - Err(serde::de::Error::custom(format!( - "Invalid value: {}", - s.to_string() - ))) + Err(format!("Invalid value: {}", self.to_string())) } } +#[get("/collections//")] async fn handle_get_collection( - path: web::Path<(SecurePath, SecurePath)>, - req: HttpRequest, - static_data_dir: actix_web::web::Data, -) -> actix_web::Result { - let (vendor_name, file_name) = path.into_inner(); + vendor_name: &str, + file_name: &str, + app_configs: &State, +) -> MultiReply { + let vendor_name: Result = vendor_name.try_into(); + let vendor_name = match vendor_name { + Ok(secure_path) => secure_path.data, + Err(err) => return MultiReply::FormError(PlainText(err.to_string())), + }; + let file_name: Result = file_name.try_into(); + let file_name = match file_name { + Ok(secure_path) => secure_path.data, + Err(err) => return MultiReply::FormError(PlainText(err.to_string())), + }; let mut path: PathBuf = PathBuf::new(); - let static_data_dir: String = static_data_dir.into_inner().to_string(); + let static_data_dir: String = app_configs.static_data_dir.clone(); path.push(static_data_dir); path.push("collections"); - path.push(vendor_name.data); - path.push(file_name.data); - match NamedFile::open(path) { - Ok(file) => Ok(file.into_response(&req)), - Err(err) => Ok(HttpResponse::NotFound() - .content_type(ContentType::plaintext()) - .body(format!("File not found: {}.\n", err))), + path.push(vendor_name); + path.push(file_name); + match NamedFile::open(path).await { + Ok(file) => MultiReply::FileContents(file), + Err(err) => MultiReply::NotFound(err.to_string()), } } +#[get("/scanners/")] async fn handle_list_scanners( - pool: web::Data, - path: web::Path, - req: HttpRequest, - static_data_dir: actix_web::web::Data, -) -> actix_web::Result { - let scanner_name = path.into_inner(); - let static_data_dir: String = static_data_dir.into_inner().to_string(); + mut db: DbConn, + scanner_name: &str, + app_configs: &State, +) -> MultiReply { + let static_data_dir: String = app_configs.static_data_dir.clone(); + let scanner_name: Result = scanner_name.try_into(); + let scanner_name = match scanner_name { + Ok(scanner_name) => scanner_name, + Err(err) => return MultiReply::FormError(PlainText(err.to_string())), + }; if scanner_name.is_static() { let mut path: PathBuf = PathBuf::new(); path.push(static_data_dir); path.push("scanners"); path.push(scanner_name.to_string()); - return match NamedFile::open(path) { - Ok(file) => Ok(file.into_response(&req)), - Err(err) => Ok(HttpResponse::NotFound() - .content_type(ContentType::plaintext()) - .body(format!("File not found: {}.\n", err))), + return match NamedFile::open(path).await { + Ok(file) => MultiReply::FileContents(file), + Err(err) => MultiReply::NotFound(err.to_string()), }; } - // use web::block to offload blocking Diesel queries without blocking server thread - let scanners_list = web::block(move || { - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - match Scanner::list_names(scanner_name, conn) { - Ok(data) => Ok(data), - Err(err) => Err(err), - } - }) - .await - // map diesel query errors to a 500 error response - .map_err(|err| ErrorInternalServerError(err)) - .unwrap(); + let scanners_list = match Scanner::list_names(scanner_name, &mut db).await { + Ok(data) => Ok(data), + Err(err) => Err(err), + }; if let Ok(scanners) = scanners_list { - Ok(html_contents(scanners.join("\n"))) + MultiReply::Content(HtmlContents(scanners.join("\n"))) } else { - Ok(server_error("Unable to list scanners".to_string())) + MultiReply::Error(ServerError("Unable to list scanners".to_string())) } } @@ -376,23 +402,16 @@ static SCAN_TASKS_FOOT: &str = r#" "#; -async fn handle_list_scan_tasks(pool: web::Data) -> HttpResponse { +#[get("/scan/tasks")] +async fn handle_list_scan_tasks(mut db: Connection) -> MultiReply { let mut html_data: Vec = vec![SCAN_TASKS_HEAD.to_string()]; - // use web::block to offload blocking Diesel queries without blocking server thread - let scan_tasks_list = web::block(move || { - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - match ScanTask::list(conn) { - Ok(data) => Ok(data), - Err(err) => Err(err), - } - }) - .await - // map diesel query errors to a 500 error response - .map_err(|err| ErrorInternalServerError(err)); + let scan_tasks_list = match ScanTask::list(&mut db).await { + Ok(data) => Ok(data), + Err(err) => Err(err), + }; - if let Ok(scan_tasks) = scan_tasks_list.unwrap() { + if let Ok(scan_tasks) = scan_tasks_list { for row in scan_tasks { let cidr: String = row.cidr; let started_at: Option = row.started_at; @@ -413,69 +432,81 @@ async fn handle_list_scan_tasks(pool: web::Data) -> HttpResponse { html_data.push(SCAN_TASKS_FOOT.to_string()); - html_contents(html_data.join("\n")) + MultiReply::Content(HtmlContents(html_data.join("\n"))) } else { - return server_error("Unable to list scan tasks".to_string()); + return MultiReply::Error(ServerError("Unable to list scan tasks".to_string())); } } -fn get_connection(database_url: &str) -> DbPool { - let manager = ConnectionManager::::new(database_url); - // Refer to the `r2d2` documentation for more methods to use - // when building a connection pool - Pool::builder() - .max_size(5) - .test_on_check_out(true) - .build(manager) - .expect("Could not build connection pool") +#[derive(Responder)] +#[response(status = 200, content_type = "text")] +pub struct PlainText(String); + +#[derive(Responder)] +#[response(status = 200, content_type = "html")] +pub struct HtmlContents(String); + +#[derive(Responder)] +#[response(status = 500, content_type = "html")] +pub struct ServerError(String); + +#[get("/")] +async fn index() -> HtmlContents { + HtmlContents(FORM.to_string()) } -fn plain_contents(data: String) -> HttpResponse { - HttpResponse::Ok() - .content_type(ContentType::plaintext()) - .body(data) +#[get("/ping")] +async fn pong() -> PlainText { + PlainText("pong".to_string()) } -fn html_contents(data: String) -> HttpResponse { - HttpResponse::Ok() - .content_type(ContentType::html()) - .body(data) +#[get("/ws")] +pub async fn ws() -> PlainText { + info!("establish_ws_connection"); + PlainText("ok".to_string()) + // Ok(HttpResponse::Unauthorized().json(e)) + /* + match result { + Ok(response) => Ok(response.into()), + Err(e) => { + error!("ws connection error: {:?}", e); + Err(e) + }, + }*/ } -fn server_error(data: String) -> HttpResponse { - HttpResponse::InternalServerError() - .content_type(ContentType::html()) - .body(data) +struct AppConfigs { + static_data_dir: String, } -async fn index() -> HttpResponse { - html_contents(FORM.to_string()) +async fn report_counts(rocket: Rocket) -> Rocket { + use rocket_db_pools::diesel::AsyncConnectionWrapper; + + let conn = SnowDb::fetch(&rocket) + .expect("database is attached") + .get().await + .unwrap_or_else(|e| { + span_error!("failed to connect to MySQL database" => error!("{e}")); + panic!("aborting launch"); + }); + + let _: AsyncConnectionWrapper<_> = conn.into(); + info!("Connected to the DB"); + + rocket } -async fn pong() -> HttpResponse { - plain_contents("pong".to_string()) -} - -#[actix_web::main] -async fn main() -> std::io::Result<()> { - let _log2 = log2::stdout() - .module(false) - .level(match env::var("RUST_LOG") { - Ok(level) => level, - Err(_) => "debug".to_string(), - }) - .start(); - - let server_address: String = if let Ok(env) = env::var("SERVER_ADDRESS") { - env +#[launch] +fn rocket() -> _ { + let server_address: SocketAddr = if let Ok(env) = env::var("SERVER_ADDRESS") { + env.parse().expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)") } else { - "127.0.0.1:8000".to_string() + "127.0.0.1:8000".parse().expect("The default address should be valid") }; - let worker_server_address: String = if let Ok(env) = env::var("WORKER_SERVER_ADDRESS") { - env - } else { - "127.0.0.1:8800".to_string() + let static_data_dir: String = match env::var("STATIC_DATA_DIR") { + Ok(val) => val, + Err(_) => "../data/".to_string(), }; let db_url: String = if let Ok(env) = env::var("DB_URL") { @@ -485,40 +516,36 @@ async fn main() -> std::io::Result<()> { "mysql://localhost".to_string() }; - let pool = get_connection(db_url.as_str()); - - // note that obtaining a connection from the pool is also potentially blocking - let conn = &mut pool.get().unwrap(); - let names = Scanner::list_names(Scanners::Stretchoid, conn); - match names { - Ok(names) => info!("Found {} Stretchoid scanners", names.len()), - Err(err) => error!("Unable to get names: {}", err), + let db: Map<_, Value> = map! { + "url" => db_url.into(), + "pool_size" => 10.into(), + "timeout" => 5.into(), }; - let server = HttpServer::new(move || { - let static_data_dir: String = match env::var("STATIC_DATA_DIR") { - Ok(val) => val, - Err(_) => "../data/".to_string(), - }; + let config_figment = rocket::Config::figment() + .merge(("address", server_address.ip().to_string())) + .merge(("port", server_address.port())) + .merge(("databases", map!["snow_scanner_db" => db])); + + rocket::custom(config_figment) + .attach(SnowDb::init()) + .attach(AdHoc::on_ignite("Report counts", report_counts)) + .manage(AppConfigs { static_data_dir }) + .mount( + "/", + routes![ + index, + pong, + handle_report, + handle_scan, + handle_list_scan_tasks, + handle_list_scanners, + handle_get_collection, + ws, + ], + ) + /* - App::new() - .app_data(web::Data::new(pool.clone())) - .app_data(actix_web::web::Data::new(static_data_dir)) - .route("/", web::get().to(index)) - .route("/ping", web::get().to(pong)) - .route("/report", web::post().to(handle_report)) - .route("/scan", web::post().to(handle_scan)) - .route("/scan/tasks", web::get().to(handle_list_scan_tasks)) - .route( - "/scanners/{scanner_name}", - web::get().to(handle_list_scanners), - ) - .route( - "/collections/{vendor_name}/{file_name}", - web::get().to(handle_get_collection), - ) - }) - .bind(&server_address); match server { Ok(server) => { match ws2::listen(worker_server_address.as_str()) { @@ -544,13 +571,5 @@ async fn main() -> std::io::Result<()> { } Err(err) => error!("Unable to listen on {worker_server_address}: {err}"), }; - - info!("Now listening on {}", server_address); - server.run().await - } - Err(err) => { - error!("Could not bind the server to {}", server_address); - Err(err) - } - } + }*/ } diff --git a/snow-scanner/src/models.rs b/snow-scanner/src/models.rs index b06545e..6fe3bc3 100644 --- a/snow-scanner/src/models.rs +++ b/snow-scanner/src/models.rs @@ -1,11 +1,9 @@ use std::net::IpAddr; -use crate::Scanners; +use crate::{DbConn, Scanners}; use chrono::{NaiveDateTime, Utc}; -use diesel::dsl::insert_into; -use diesel::prelude::*; -use diesel::result::Error as DieselError; use hickory_resolver::Name; +use rocket_db_pools::diesel::{dsl::insert_into, prelude::*, result::Error as DieselError}; use crate::schema::scan_tasks::dsl::scan_tasks; use crate::schema::scanners::dsl::scanners; @@ -25,14 +23,14 @@ pub struct Scanner { } impl Scanner { - pub fn find_or_new( + pub async fn find_or_new( query_address: IpAddr, scanner_name: Scanners, ptr: Option, - conn: &mut MysqlConnection, + conn: &mut DbConn, ) -> Result { let ip_type = if query_address.is_ipv6() { 6 } else { 4 }; - let scanner_row_result = Scanner::find(query_address.to_string(), ip_type, conn); + let scanner_row_result = Scanner::find(query_address.to_string(), ip_type, conn).await; let scanner_row = match scanner_row_result { Ok(scanner_row) => scanner_row, Err(_) => return Err(()), @@ -58,31 +56,31 @@ impl Scanner { last_checked_at: None, } }; - match scanner.save(conn) { + match scanner.save(conn).await { Ok(scanner) => Ok(scanner), Err(_) => Err(()), } } - pub fn find( + pub async fn find( ip_address: String, ip_type: u8, - conn: &mut MysqlConnection, + conn: &mut DbConn, ) -> Result, DieselError> { use crate::schema::scanners; - scanners .select(Scanner::as_select()) .filter(scanners::ip.eq(ip_address)) .filter(scanners::ip_type.eq(ip_type)) .order((scanners::ip_type.desc(), scanners::created_at.desc())) .first(conn) + .await .optional() } - pub fn list_names( + pub async fn list_names( scanner_name: Scanners, - conn: &mut MysqlConnection, + conn: &mut DbConn, ) -> Result, DieselError> { use crate::schema::scanners; use crate::schema::scanners::ip; @@ -92,16 +90,20 @@ impl Scanner { .filter(scanners::scanner_name.eq(scanner_name.to_string())) .order((scanners::ip_type.desc(), scanners::created_at.desc())) .load::(conn) + .await } - pub fn save(self: Scanner, conn: &mut MysqlConnection) -> Result { - let new_scanner = NewScanner::from_scanner(&self); - match insert_into(scanners) + pub async fn save(self: Scanner, conn: &mut DbConn) -> Result { + use crate::schema::scanners; + + let new_scanner: NewScanner = NewScanner::from_scanner(&self).await; + match insert_into(scanners::table) .values(&new_scanner) .on_conflict(diesel::dsl::DuplicatedKeys) .do_update() .set(&new_scanner) .execute(conn) + .await { Ok(_) => Ok(self), Err(err) => Err(err), @@ -124,7 +126,7 @@ pub struct NewScanner { } impl NewScanner { - pub fn from_scanner<'x>(scanner: &Scanner) -> NewScanner { + pub async fn from_scanner<'x>(scanner: &Scanner) -> NewScanner { NewScanner { ip: scanner.ip.to_string(), ip_type: scanner.ip_type, @@ -165,21 +167,22 @@ pub struct ScanTaskitem { } impl ScanTask { - pub fn list_not_started(conn: &mut MysqlConnection) -> Result, DieselError> { + pub async fn list_not_started(mut conn: DbConn) -> Result, DieselError> { use crate::schema::scan_tasks; let res = scan_tasks .select(ScanTaskitem::as_select()) .filter(scan_tasks::started_at.is_null()) .order((scan_tasks::created_at.asc(),)) - .load::(conn); + .load::(&mut conn) + .await; match res { Ok(rows) => Ok(rows), Err(err) => Err(err), } } - pub fn list(conn: &mut MysqlConnection) -> Result, DieselError> { + pub async fn list(conn: &mut DbConn) -> Result, DieselError> { use crate::schema::scan_tasks; let res = scan_tasks @@ -188,21 +191,26 @@ impl ScanTask { scan_tasks::created_at.desc(), scan_tasks::task_group_id.asc(), )) - .load::(conn); + .load::(conn) + .await; match res { Ok(rows) => Ok(rows), Err(err) => Err(err), } } - pub fn save(self: &ScanTask, conn: &mut MysqlConnection) -> Result<(), DieselError> { - let new_scan_task = NewScanTask::from_scan_task(self); - match insert_into(scan_tasks) + pub async fn save(self: &ScanTask, conn: &mut DbConn) -> Result<(), DieselError> { + use crate::schema::scan_tasks; + + let new_scan_task: NewScanTask = NewScanTask::from_scan_task(self).await; + + match insert_into(scan_tasks::table) .values(&new_scan_task) .on_conflict(diesel::dsl::DuplicatedKeys) .do_update() .set(&new_scan_task) .execute(conn) + .await { Ok(_) => Ok(()), Err(err) => Err(err), @@ -225,7 +233,7 @@ pub struct NewScanTask { } impl NewScanTask { - pub fn from_scan_task<'x>(scan_task: &ScanTask) -> NewScanTask { + pub async fn from_scan_task<'x>(scan_task: &ScanTask) -> NewScanTask { NewScanTask { task_group_id: scan_task.task_group_id.to_string(), cidr: scan_task.cidr.to_owned(), diff --git a/snow-scanner/src/server.rs b/snow-scanner/src/server.rs index 8560760..6ab4c53 100644 --- a/snow-scanner/src/server.rs +++ b/snow-scanner/src/server.rs @@ -1,17 +1,7 @@ -use cidr::IpCidr; -use diesel::MysqlConnection; use hickory_resolver::Name; -use log2::*; use std::{collections::HashMap, net::IpAddr, str::FromStr}; -use ws2::{Pod, WebSocket}; -use crate::{ - worker::{ - detection::detect_scanner_from_name, - modules::{Network, WorkerMessages}, - }, - DbPool, Scanner, -}; +use crate::{worker::detection::detect_scanner_from_name, DbConn, Scanner}; pub struct Server { pub clients: HashMap, @@ -19,12 +9,7 @@ pub struct Server { } impl Server { - pub fn cleanup(&self, _: &ws2::Server) -> &Server { - // TODO: implement check not logged in - &self - } - - pub fn commit(&mut self, conn: &mut MysqlConnection) -> &Server { + pub async fn commit(&mut self, conn: &mut DbConn) -> &Server { for (name, query_address) in self.new_scanners.clone() { let scanner_name = Name::from_str(name.as_str()).unwrap(); @@ -35,7 +20,9 @@ impl Server { scanner_type, Some(scanner_name), conn, - ) { + ) + .await + { Ok(scanner) => { // Got saved self.new_scanners.remove(&name); @@ -92,6 +79,7 @@ impl Worker { } } +/* impl ws2::Handler for Server { fn on_open(&mut self, ws: &WebSocket) -> Pod { info!("New client: {ws}"); @@ -165,4 +153,4 @@ impl ws2::Handler for Server { } result } -} +}*/ diff --git a/snow-scanner/src/worker/Cargo.toml b/snow-scanner/src/worker/Cargo.toml index 9e346d8..5436f6a 100644 --- a/snow-scanner/src/worker/Cargo.toml +++ b/snow-scanner/src/worker/Cargo.toml @@ -15,7 +15,7 @@ path = "worker.rs" [dependencies] tungstenite = { version = "0.24.0", default-features = true, features = ["native-tls"] } log2 = "0.1.11" -diesel = { version = "2.2.0", default-features = false, features = [] } +diesel = { version = "2", default-features = false, features = [] } dns-ptr-resolver = {git = "https://github.com/wdes/dns-ptr-resolver.git"} hickory-resolver = { version = "0.24.1", default-features = false, features = ["tokio-runtime", "dns-over-h3", "dns-over-https", "dns-over-quic"]} chrono = "0.4.38"