Move to a mutex

This commit is contained in:
2024-06-22 20:58:35 +02:00
parent dfdc1db11a
commit 854d457caf

View File

@ -9,6 +9,7 @@ use rusqlite::{named_params, Connection, OpenFlags, Result, ToSql};
use sha2::Sha256; use sha2::Sha256;
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Mutex;
use hickory_client::client::SyncClient; use hickory_client::client::SyncClient;
use hickory_client::rr::Name; use hickory_client::rr::Name;
@ -73,8 +74,12 @@ fn save_scanner(conn: &Connection, scanner: &Scanner) -> Result<(), ()> {
":last_checked_at": &scanner.last_checked_at ":last_checked_at": &scanner.last_checked_at
}, },
) { ) {
Ok(_) => Ok(()), Ok(_) => {
Err(_) => Err(()), Ok(())
},
Err(_) => {
Err(())
},
} }
} }
@ -118,11 +123,14 @@ static FORM: &str = r#"
</html> </html>
"#; "#;
fn handle_report( fn handle_scan(conn: &Mutex<Connection>, request: &Request) -> Response {
conn: &Connection, let data = try_or_400!(post_input!(request, {
client: SyncClient<TcpClientConnection>, ips: String,
request: &Request, }));
) -> Response { rouille::Response::html(data.ips.split('\n').collect::<Vec<&str>>().join("<br>"))
}
fn handle_report(conn: &Mutex<Connection>, request: &Request) -> Response {
let data = try_or_400!(post_input!(request, { let data = try_or_400!(post_input!(request, {
ip: String, ip: String,
})); }));
@ -132,6 +140,7 @@ fn handle_report(
println!("Received data: {:?}", data); println!("Received data: {:?}", data);
let query_address = data.ip.parse().expect("To parse"); let query_address = data.ip.parse().expect("To parse");
let client = get_dns_client();
let ptr_result = get_ptr(query_address, client).unwrap(); let ptr_result = get_ptr(query_address, client).unwrap();
match detect_scanner(&ptr_result) { match detect_scanner(&ptr_result) {
@ -146,7 +155,8 @@ fn handle_report(
last_seen_at: Utc::now().to_string(), last_seen_at: Utc::now().to_string(),
last_checked_at: Utc::now().to_string(), last_checked_at: Utc::now().to_string(),
}; };
save_scanner(&conn, &scanner).unwrap(); let db = conn.lock().unwrap();
save_scanner(&db, &scanner).unwrap();
rouille::Response::html(match scanner_name { rouille::Response::html(match scanner_name {
Scanners::Binaryedge => format!( Scanners::Binaryedge => format!(
"Reported an escaped ninja! <b>{}</a> {:?}.", "Reported an escaped ninja! <b>{}</a> {:?}.",
@ -168,8 +178,9 @@ fn handle_report(
} }
} }
fn handle_list_scanners(conn: &Connection, scanner_name: String) -> Response { fn handle_list_scanners(conn: &Mutex<Connection>, scanner_name: String) -> Response {
let mut stmt = conn.prepare("SELECT ip FROM scanners WHERE scanner_name = :scanner_name ORDER BY ip_type, created_at").unwrap(); let db = conn.lock().unwrap();
let mut stmt = db.prepare("SELECT ip FROM scanners WHERE scanner_name = :scanner_name ORDER BY ip_type, created_at").unwrap();
let mut rows = stmt let mut rows = stmt
.query(named_params! { ":scanner_name": scanner_name }) .query(named_params! { ":scanner_name": scanner_name })
.unwrap(); .unwrap();
@ -214,16 +225,23 @@ fn get_connection() -> Connection {
conn conn
} }
fn get_dns_client() -> SyncClient<TcpClientConnection> {
let server = "1.1.1.1:53".parse().expect("To parse");
let dns_conn =
TcpClientConnection::with_timeout(server, std::time::Duration::new(5, 0)).unwrap();
SyncClient::new(dns_conn)
}
fn main() -> Result<()> { fn main() -> Result<()> {
println!("Now listening on localhost:8000"); println!("Now listening on localhost:8000");
let server = "1.1.1.1:53".parse().expect("To parse"); let conn = Mutex::new(get_connection());
let conn = TcpClientConnection::with_timeout(server, std::time::Duration::new(5, 0)).unwrap(); conn.lock()
.unwrap()
.execute("SELECT 0 WHERE 0;", named_params! {})
.expect("Failed to initialize database");
rouille::start_server("localhost:8000", move |request| { rouille::start_server("localhost:8000", move |request| {
let client = SyncClient::new(conn);
let conn = get_connection();
router!(request, router!(request,
(GET) (/) => { (GET) (/) => {
rouille::Response::html(FORM) rouille::Response::html(FORM)
@ -233,7 +251,8 @@ fn main() -> Result<()> {
rouille::Response::text("pong") rouille::Response::text("pong")
}, },
(POST) (/report) => {handle_report(&conn, client, &request)}, (POST) (/report) => {handle_report(&conn, &request)},
(POST) (/scan) => {handle_scan(&conn, &request)},
(POST) (/register) => { (POST) (/register) => {
let data = try_or_400!(post_input!(request, { let data = try_or_400!(post_input!(request, {