From 854d457caf8a32cd25fdab73d703873e1770fe84 Mon Sep 17 00:00:00 2001 From: William Desportes Date: Sat, 22 Jun 2024 20:58:35 +0200 Subject: [PATCH] Move to a mutex --- snow-scanner/src/main.rs | 51 +++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index 69494b5..4a427e7 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -9,6 +9,7 @@ use rusqlite::{named_params, Connection, OpenFlags, Result, ToSql}; use sha2::Sha256; use std::fmt; use std::str::FromStr; +use std::sync::Mutex; use hickory_client::client::SyncClient; use hickory_client::rr::Name; @@ -73,8 +74,12 @@ fn save_scanner(conn: &Connection, scanner: &Scanner) -> Result<(), ()> { ":last_checked_at": &scanner.last_checked_at }, ) { - Ok(_) => Ok(()), - Err(_) => Err(()), + Ok(_) => { + Ok(()) + }, + Err(_) => { + Err(()) + }, } } @@ -118,11 +123,14 @@ static FORM: &str = r#" "#; -fn handle_report( - conn: &Connection, - client: SyncClient, - request: &Request, -) -> Response { +fn handle_scan(conn: &Mutex, request: &Request) -> Response { + let data = try_or_400!(post_input!(request, { + ips: String, + })); + rouille::Response::html(data.ips.split('\n').collect::>().join("
")) +} + +fn handle_report(conn: &Mutex, request: &Request) -> Response { let data = try_or_400!(post_input!(request, { ip: String, })); @@ -132,6 +140,7 @@ fn handle_report( println!("Received data: {:?}", data); let query_address = data.ip.parse().expect("To parse"); + let client = get_dns_client(); let ptr_result = get_ptr(query_address, client).unwrap(); match detect_scanner(&ptr_result) { @@ -146,7 +155,8 @@ fn handle_report( last_seen_at: Utc::now().to_string(), last_checked_at: Utc::now().to_string(), }; - save_scanner(&conn, &scanner).unwrap(); + let db = conn.lock().unwrap(); + save_scanner(&db, &scanner).unwrap(); rouille::Response::html(match scanner_name { Scanners::Binaryedge => format!( "Reported an escaped ninja! {} {:?}.", @@ -168,8 +178,9 @@ fn handle_report( } } -fn handle_list_scanners(conn: &Connection, scanner_name: String) -> Response { - let mut stmt = conn.prepare("SELECT ip FROM scanners WHERE scanner_name = :scanner_name ORDER BY ip_type, created_at").unwrap(); +fn handle_list_scanners(conn: &Mutex, scanner_name: String) -> Response { + let db = conn.lock().unwrap(); + let mut stmt = db.prepare("SELECT ip FROM scanners WHERE scanner_name = :scanner_name ORDER BY ip_type, created_at").unwrap(); let mut rows = stmt .query(named_params! { ":scanner_name": scanner_name }) .unwrap(); @@ -214,16 +225,23 @@ fn get_connection() -> Connection { conn } +fn get_dns_client() -> SyncClient { + let server = "1.1.1.1:53".parse().expect("To parse"); + let dns_conn = + TcpClientConnection::with_timeout(server, std::time::Duration::new(5, 0)).unwrap(); + SyncClient::new(dns_conn) +} + fn main() -> Result<()> { println!("Now listening on localhost:8000"); - let server = "1.1.1.1:53".parse().expect("To parse"); - let conn = TcpClientConnection::with_timeout(server, std::time::Duration::new(5, 0)).unwrap(); + let conn = Mutex::new(get_connection()); + conn.lock() + .unwrap() + .execute("SELECT 0 WHERE 0;", named_params! {}) + .expect("Failed to initialize database"); rouille::start_server("localhost:8000", move |request| { - let client = SyncClient::new(conn); - let conn = get_connection(); - router!(request, (GET) (/) => { rouille::Response::html(FORM) @@ -233,7 +251,8 @@ fn main() -> Result<()> { rouille::Response::text("pong") }, - (POST) (/report) => {handle_report(&conn, client, &request)}, + (POST) (/report) => {handle_report(&conn, &request)}, + (POST) (/scan) => {handle_scan(&conn, &request)}, (POST) (/register) => { let data = try_or_400!(post_input!(request, {