diff --git a/snow-scanner/src/main.rs b/snow-scanner/src/main.rs index b633eb2..71e1e53 100644 --- a/snow-scanner/src/main.rs +++ b/snow-scanner/src/main.rs @@ -9,7 +9,7 @@ use rocket::{ fairing::AdHoc, futures::SinkExt, http::Status, - request::{FromRequest, Outcome, Request}, + request::{FromParam, FromRequest, Outcome, Request}, trace::error, Rocket, State, }; @@ -111,6 +111,22 @@ impl IsStatic for Scanners { } } +impl FromParam<'_> for Scanners { + type Error = String; + + fn from_param(param: &'_ str) -> Result { + match param { + "stretchoid" => Ok(Scanners::Stretchoid), + "binaryedge" => Ok(Scanners::Binaryedge), + "stretchoid.txt" => Ok(Scanners::Stretchoid), + "binaryedge.txt" => Ok(Scanners::Binaryedge), + "censys.txt" => Ok(Scanners::Censys), + "internet-measurement.com.txt" => Ok(Scanners::InternetMeasurement), + v => Err(format!("Unknown value: {v}")), + } + } +} + impl<'de> Deserialize<'de> for Scanners { fn deserialize(deserializer: D) -> Result where @@ -364,47 +380,39 @@ struct SecurePath { pub data: String, } -impl TryInto for &str { +impl FromParam<'_> for SecurePath { type Error = String; - fn try_into(self) -> Result { + fn from_param(param: &'_ str) -> Result { // A-Z a-z 0-9 // . - _ - if self + if param .chars() .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') { return Ok(SecurePath { - data: self.to_string(), + data: param.to_string(), }); } - Err(format!("Invalid value: {}", self.to_string())) + Err(format!( + "Invalid path value (forbidden chars): {}", + param.to_string() + )) } } #[get("/collections//")] async fn handle_get_collection( - vendor_name: &str, - file_name: &str, + vendor_name: SecurePath, + file_name: SecurePath, app_configs: &State, ) -> MultiReply { - let vendor_name: Result = vendor_name.try_into(); - let vendor_name = match vendor_name { - Ok(secure_path) => secure_path.data, - Err(err) => return MultiReply::FormError(PlainText(err.to_string())), - }; - let file_name: Result = file_name.try_into(); - let file_name = match file_name { - Ok(secure_path) => secure_path.data, - Err(err) => return MultiReply::FormError(PlainText(err.to_string())), - }; - let mut path: PathBuf = PathBuf::new(); let static_data_dir: String = app_configs.static_data_dir.clone(); path.push(static_data_dir); path.push("collections"); - path.push(vendor_name); - path.push(file_name); + path.push(vendor_name.data); + path.push(file_name.data); match NamedFile::open(path).await { Ok(file) => MultiReply::FileContents(file), Err(err) => MultiReply::NotFound(err.to_string()), @@ -414,15 +422,10 @@ async fn handle_get_collection( #[get("/scanners/")] async fn handle_list_scanners( mut db: DbConn, - scanner_name: &str, + scanner_name: Scanners, app_configs: &State, ) -> MultiReply { let static_data_dir: String = app_configs.static_data_dir.clone(); - let scanner_name: Result = scanner_name.try_into(); - let scanner_name = match scanner_name { - Ok(scanner_name) => scanner_name, - Err(err) => return MultiReply::FormError(PlainText(err.to_string())), - }; if scanner_name.is_static() { let mut path: PathBuf = PathBuf::new(); path.push(static_data_dir);