diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index fcaf75c..bf44feb 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -7,6 +7,7 @@ use cidr::IpCidr; use event_bus::{EventBusSubscriber, EventBusWriter, EventBusWriterEvent}; use rocket::{ fairing::AdHoc, + form::FromFormField, futures::SinkExt, http::Status, request::{FromParam, FromRequest, Outcome, Request}, @@ -41,6 +42,7 @@ use worker::detection::{detect_scanner, get_dns_client, validate_ip, Scanners}; use std::{ env, fmt, + net::IpAddr, ops::{Deref, DerefMut}, }; use std::{io::Write, net::SocketAddr}; @@ -111,6 +113,30 @@ impl IsStatic for Scanners { } } +#[derive(serde::Deserialize, Clone)] +struct SafeIpAddr { + pub addr: IpAddr, +} + +impl FromFormField<'_> for SafeIpAddr { + fn from_value(field: rocket::form::ValueField<'_>) -> rocket::form::Result<'_, Self> { + let ip = field.value; + let query_address = IpAddr::from_str(ip); + match query_address { + Ok(ip) => { + if !validate_ip(ip) { + return Err(rocket::form::Error::validation(format!( + "Invalid IP address: {ip}" + )) + .into()); + } + Ok(SafeIpAddr { addr: ip }) + } + Err(err) => Err(rocket::form::Error::validation(format!("Invalid IP: {err}")).into()), + } + } +} + impl FromParam<'_> for Scanners { type Error = String; @@ -163,7 +189,7 @@ impl fmt::Display for Scanners { Self::Binaryedge => "binaryedge", Self::Censys => "censys", Self::InternetMeasurement => "internet-measurement.com", - Self::Shadowserver => "shadowserver.txt", + Self::Shadowserver => "shadowserver", } ) } @@ -176,7 +202,7 @@ impl serialize::ToSql for Scanners { Self::Binaryedge => out.write_all(b"binaryedge")?, Self::Censys => out.write_all(b"censys")?, Self::InternetMeasurement => out.write_all(b"internet-measurement.com")?, - Self::Shadowserver => out.write_all(b"shadowserver.txt")?, + Self::Shadowserver => out.write_all(b"shadowserver")?, }; Ok(IsNull::No) @@ -203,14 +229,15 @@ impl TryInto for &str { "stretchoid" => Ok(Scanners::Stretchoid), "binaryedge" => Ok(Scanners::Binaryedge), "internet-measurement.com" => Ok(Scanners::InternetMeasurement), + "shadowserver" => Ok(Scanners::Shadowserver), value => Err(format!("Invalid value: {value}")), } } } -async fn handle_ip(mut conn: DbConn, ip: String) -> Result> { - let query_address = ip.parse().expect("To parse"); - +async fn handle_ip( + query_address: IpAddr, +) -> Result<(IpAddr, Option, ResolvedResult), ()> { let ptr_result: Result = std::thread::spawn(move || { let client = get_dns_client(); let ptr_result: ResolvedResult = if let Ok(res) = get_ptr(query_address, client) { @@ -223,27 +250,19 @@ async fn handle_ip(mut conn: DbConn, ip: String) -> Result { - if !validate_ip(query_address) { - error!("Invalid IP address: {ip}"); - return Err(None); + match ptr_result { + Ok(result) => match detect_scanner(&result) { + Ok(Some(scanner_type)) => { + if !validate_ip(query_address) { + error!("Invalid IP address: {query_address}"); + return Err(()); + } + Ok((query_address, Some(scanner_type), result)) } - match Scanner::find_or_new(query_address, scanner_type, result.result, &mut conn).await - { - Ok(scanner) => Ok(scanner), - Err(_) => Err(None), - } - } - Ok(None) => Err(None), - - Err(_) => Err(Some(result)), + Ok(None) => Ok((query_address, None, result)), + Err(err) => Err(err), + }, + Err(err) => Err(err), } } @@ -282,11 +301,15 @@ enum MultiReply { Error(ServerError), #[response(status = 422)] FormError(PlainText), + #[response(status = 422)] + HtmlFormError(HtmlContents), #[response(status = 404)] NotFound(String), #[response(status = 200)] Content(HtmlContents), #[response(status = 200)] + TextContent(PlainText), + #[response(status = 200)] FileContents(NamedFile), } @@ -345,50 +368,89 @@ async fn handle_scan( MultiReply::Content(HtmlContents(format!("New task added: {} !", task_group_id))) } -#[derive(FromForm, Serialize, Deserialize)] +#[derive(FromForm, Deserialize)] pub struct ReportParams { - ip: String, + ip: SafeIpAddr, +} + +fn reply_contents_for_scanner_found(scanner: Scanner) -> HtmlContents { + HtmlContents(match scanner.scanner_name { + Scanners::Binaryedge => match scanner.last_checked_at { + Some(date) => format!( + "Reported a binaryedge ninja! {} known as {} since {date}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + None => format!( + "Reported a binaryedge ninja! {} known as {}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + }, + Scanners::Stretchoid => match scanner.last_checked_at { + Some(date) => format!( + "Reported a stretchoid agent! {} known as {} since {date}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + None => format!( + "Reported a stretchoid agent! {} known as {}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + }, + Scanners::Shadowserver => match scanner.last_checked_at { + Some(date) => format!( + "Reported a cloudy shadowserver ! {} known as {} since {date}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + None => format!( + "Reported a cloudy shadowserver ! {} known as {}.", + scanner.ip, + scanner.ip_ptr.unwrap_or("".to_string()) + ), + }, + _ => format!("Not supported"), + }) } #[post("/report", data = "
")] -async fn handle_report(db: DbConn, form: Form) -> HtmlContents { - match handle_ip(db, form.ip.clone()).await { - Ok(scanner) => HtmlContents(match scanner.scanner_name { - Scanners::Binaryedge => match scanner.last_checked_at { - Some(date) => format!( - "Reported a binaryedge ninja! {} known as {} since {date}.", - scanner.ip, - scanner.ip_ptr.unwrap_or("".to_string()) - ), - None => format!( - "Reported a binaryedge ninja! {} known as {}.", - scanner.ip, - scanner.ip_ptr.unwrap_or("".to_string()) - ), +async fn handle_report(mut db: DbConn, form: Form) -> MultiReply { + match handle_ip(form.ip.addr).await { + Ok((query_address, scanner_type, result)) => match scanner_type { + Some(scanner_type) => match Scanner::find_or_new( + query_address, + scanner_type, + result.result.clone(), + &mut db, + ) + .await + { + Ok(scanner) => MultiReply::Content(reply_contents_for_scanner_found(scanner)), + Err(err) => MultiReply::Error(ServerError(format!( + "The IP {} resolved as {} could not be saved, server error: {err}.", + form.ip.addr, + match result.result { + Some(res) => res.to_string(), + None => "No value".to_string(), + } + ))), }, - Scanners::Stretchoid => match scanner.last_checked_at { - Some(date) => format!( - "Reported a stretchoid agent! {} known as {} since {date}.", - scanner.ip, - scanner.ip_ptr.unwrap_or("".to_string()) - ), - None => format!( - "Reported a stretchoid agent! {} known as {}.", - scanner.ip, - scanner.ip_ptr.unwrap_or("".to_string()) - ), - }, - _ => format!("Not supported"), - }), + None => MultiReply::HtmlFormError(HtmlContents(format!( + "The IP {} resolved as {:?} did not match known scanners patterns.", + form.ip.addr, + match result.result { + Some(res) => res.to_string(), + None => "No value".to_string(), + } + ))), + }, - Err(ptr_result) => HtmlContents(format!( - "The IP {} resolved as {:?} did not match known scanners patterns.", - form.ip, - match ptr_result { - Some(res) => res.result, - None => None, - } - )), + Err(_) => MultiReply::Error(ServerError(format!( + "The IP {} did encounter en error at resolve time.", + form.ip.addr + ))), } } @@ -447,7 +509,9 @@ async fn handle_list_scanners( path.push(static_data_dir); path.push("scanners"); path.push(match scanner_name { - Scanners::Stretchoid | Scanners::Binaryedge | Scanners::Shadowserver => panic!("This should not happen"), + Scanners::Stretchoid | Scanners::Binaryedge | Scanners::Shadowserver => { + panic!("This should not happen") + } Scanners::Censys => "censys.txt".to_string(), Scanners::InternetMeasurement => "internet-measurement.com.txt".to_string(), }); @@ -464,7 +528,7 @@ async fn handle_list_scanners( }; if let Ok(scanners) = scanners_list { - MultiReply::Content(HtmlContents(scanners.join("\n"))) + MultiReply::TextContent(PlainText(scanners.join("\n"))) } else { MultiReply::Error(ServerError("Unable to list scanners".to_string())) } @@ -679,3 +743,35 @@ async fn main() -> Result<(), rocket::Error> { .await; Ok(()) } + +#[cfg(test)] +mod test { + use super::*; + use hickory_resolver::{ + config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, + Name, Resolver, + }; + use std::time::Duration; + + #[test] + fn test_get_ptr() { + let server = NameServerConfigGroup::google(); + let config = ResolverConfig::from_parts(None, vec![], server); + let mut options = ResolverOpts::default(); + options.timeout = Duration::from_secs(5); + options.attempts = 1; // One try + + let resolver = Resolver::new(config, options).unwrap(); + + let query_address = "8.8.8.8".parse().expect("To parse"); + + assert_eq!( + get_ptr(query_address, resolver).unwrap(), + ResolvedResult { + query: Name::from_str_relaxed("8.8.8.8.in-addr.arpa.").unwrap(), + result: Some(Name::from_str_relaxed("dns.google.").unwrap()), + error: None, + } + ); + } +} diff --git a/snow-scanner/src/models.rs b/snow-scanner/src/models.rs index 6fe3bc3..ee1024a 100644 --- a/snow-scanner/src/models.rs +++ b/snow-scanner/src/models.rs @@ -28,12 +28,12 @@ impl Scanner { scanner_name: Scanners, ptr: Option, conn: &mut DbConn, - ) -> Result { + ) -> Result { let ip_type = if query_address.is_ipv6() { 6 } else { 4 }; let scanner_row_result = Scanner::find(query_address.to_string(), ip_type, conn).await; let scanner_row = match scanner_row_result { Ok(scanner_row) => scanner_row, - Err(_) => return Err(()), + Err(err) => return Err(err), }; let scanner = if let Some(mut scanner) = scanner_row { @@ -58,7 +58,7 @@ impl Scanner { }; match scanner.save(conn).await { Ok(scanner) => Ok(scanner), - Err(_) => Err(()), + Err(err) => Err(err), } } diff --git a/snow-scanner/src/worker/detection.rs b/snow-scanner/src/worker/detection.rs index 3092543..eceb533 100644 --- a/snow-scanner/src/worker/detection.rs +++ b/snow-scanner/src/worker/detection.rs @@ -10,7 +10,7 @@ use hickory_resolver::{Name, Resolver}; use crate::worker::ip_addr::is_global_hardcoded; -#[derive(Debug, Clone, Copy, FromSqlRow)] +#[derive(Debug, Clone, Copy, FromSqlRow, PartialEq)] pub enum Scanners { Stretchoid, Binaryedge, @@ -77,3 +77,34 @@ pub fn detect_scanner_from_name(name: &Name) -> Result, ()> { &_ => Ok(None), } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_detect_scanner_from_name() { + let ptr = Name::from_str("scan-47e.shadowserver.org.").unwrap(); + + assert_eq!( + detect_scanner_from_name(&ptr).unwrap(), + Some(Scanners::Shadowserver) + ); + } + + #[test] + fn test_detect_scanner() { + let cname_ptr = Name::from_str("111.0-24.197.62.64.in-addr.arpa.").unwrap(); + let ptr = Name::from_str("scan-47e.shadowserver.org.").unwrap(); + + assert_eq!( + detect_scanner(&ResolvedResult { + query: cname_ptr, + result: Some(ptr), + error: None + }) + .unwrap(), + Some(Scanners::Shadowserver) + ); + } +}