Move to Rocket !

This commit is contained in:
2024-10-03 10:34:56 +02:00
parent 04aea8558f
commit 25df2642e9
5 changed files with 293 additions and 276 deletions

View File

@ -1,28 +1,38 @@
use actix_files::NamedFile;
use actix_web::error::ErrorInternalServerError;
use actix_web::http::header::ContentType;
use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer};
use log2::*;
use chrono::{NaiveDateTime, Utc};
use diesel::deserialize::{self};
use diesel::mysql::{Mysql, MysqlValue};
use diesel::sql_types::Text;
use diesel::r2d2::ConnectionManager;
use diesel::r2d2::Pool;
#[macro_use]
extern crate rocket;
use rocket::{fairing::AdHoc, Build, Rocket, State};
use rocket_db_pools::{
rocket::{
figment::{
util::map,
value::{Map, Value},
},
form::Form,
fs::NamedFile,
launch, Responder,
},
Connection,
};
use rocket_db_pools::diesel::mysql::{Mysql, MysqlValue};
use rocket_db_pools::diesel::serialize::IsNull;
use rocket_db_pools::diesel::sql_types::Text;
use rocket_db_pools::diesel::{deserialize, serialize};
use rocket_db_pools::diesel::MysqlPool;
use rocket_db_pools::Database;
use worker::detection::{detect_scanner, get_dns_client, Scanners};
use std::collections::HashMap;
use std::io::Write;
use std::{io::Write, net::SocketAddr};
use std::path::PathBuf;
use std::{env, fmt};
use uuid::Uuid;
use serde::{Deserialize, Deserializer, Serialize};
use diesel::serialize::IsNull;
use diesel::{serialize, MysqlConnection};
use dns_ptr_resolver::{get_ptr, ResolvedResult};
pub mod models;
@ -31,10 +41,12 @@ pub mod server;
pub mod worker;
use crate::models::*;
use crate::server::Server;
/// Short-hand for the database pool type to use throughout the app.
type DbPool = Pool<ConnectionManager<MysqlConnection>>;
#[derive(Database)]
#[database("snow_scanner_db")]
pub struct SnowDb(MysqlPool);
type DbConn = Connection<SnowDb>;
trait IsStatic {
fn is_static(self: &Self) -> bool;
@ -103,16 +115,29 @@ impl serialize::ToSql<Text, Mysql> for Scanners {
impl deserialize::FromSql<Text, Mysql> for Scanners {
fn from_sql(bytes: MysqlValue) -> deserialize::Result<Self> {
let value = <String as deserialize::FromSql<Text, Mysql>>::from_sql(bytes)?;
match &value as &str {
"stretchoid" => Ok(Scanners::Stretchoid),
"binaryedge" => Ok(Scanners::Binaryedge),
"internet-measurement.com" => Ok(Scanners::InternetMeasurement),
_ => Err("Unrecognized enum variant".into()),
let value = &value as &str;
let value: Result<Scanners, String> = value.try_into();
match value {
Ok(d) => Ok(d),
Err(err) => Err(err.into()),
}
}
}
async fn handle_ip(pool: web::Data<DbPool>, ip: String) -> Result<Scanner, Option<ResolvedResult>> {
impl TryInto<Scanners> for &str {
type Error = String;
fn try_into(self) -> Result<Scanners, Self::Error> {
match self {
"stretchoid" => Ok(Scanners::Stretchoid),
"binaryedge" => Ok(Scanners::Binaryedge),
"internet-measurement.com" => Ok(Scanners::InternetMeasurement),
value => Err(format!("Invalid value: {value}")),
}
}
}
async fn handle_ip(mut conn: DbConn, ip: String) -> Result<Scanner, Option<ResolvedResult>> {
let query_address = ip.parse().expect("To parse");
let ptr_result: Result<ResolvedResult, ()> = std::thread::spawn(move || {
@ -135,17 +160,11 @@ async fn handle_ip(pool: web::Data<DbPool>, ip: String) -> Result<Scanner, Optio
match detect_scanner(&result) {
Ok(Some(scanner_type)) => {
// use web::block to offload blocking Diesel queries without blocking server thread
web::block(move || {
// note that obtaining a connection from the pool is also potentially blocking
let conn = &mut pool.get().unwrap();
match Scanner::find_or_new(query_address, scanner_type, result.result, conn) {
Ok(scanner) => Ok(scanner),
Err(_) => Err(None),
}
})
.await
.unwrap()
match Scanner::find_or_new(query_address, scanner_type, result.result, &mut conn).await
{
Ok(scanner) => Ok(scanner),
Err(_) => Err(None),
}
}
Ok(None) => Err(None),
@ -176,55 +195,63 @@ static FORM: &str = r#"
</html>
"#;
#[derive(Serialize, Deserialize)]
pub struct ScanParams {
username: String,
ips: String,
#[derive(FromForm, Serialize, Deserialize)]
pub struct ScanParams<'r> {
username: &'r str,
ips: &'r str,
}
async fn handle_scan(pool: web::Data<DbPool>, params: web::Form<ScanParams>) -> HttpResponse {
if params.username.len() < 4 {
return plain_contents("Invalid username".to_string());
#[derive(Responder)]
enum MultiReply {
#[response(status = 500, content_type = "text")]
Error(ServerError),
#[response(status = 422)]
FormError(PlainText),
#[response(status = 404)]
NotFound(String),
#[response(status = 200)]
Content(HtmlContents),
#[response(status = 200)]
FileContents(NamedFile),
}
#[post("/scan", data = "<form>")]
async fn handle_scan(mut db: DbConn, form: Form<ScanParams<'_>>) -> MultiReply {
if form.username.len() < 4 {
return MultiReply::FormError(PlainText("Invalid username".to_string()));
}
let task_group_id: Uuid = Uuid::now_v7();
// use web::block to offload blocking Diesel queries without blocking server thread
let _ = web::block(move || {
// note that obtaining a connection from the pool is also potentially blocking
let conn = &mut pool.get().unwrap();
for ip in params.ips.lines() {
let scan_task = ScanTask {
task_group_id: task_group_id.to_string(),
cidr: ip.to_string(),
created_by_username: params.username.clone(),
created_at: Utc::now().naive_utc(),
updated_at: None,
started_at: None,
still_processing_at: None,
ended_at: None,
};
match scan_task.save(conn) {
Ok(_) => error!("Added {}", ip.to_string()),
Err(err) => error!("Not added: {:?}", err),
}
for ip in form.ips.lines() {
let scan_task = ScanTask {
task_group_id: task_group_id.to_string(),
cidr: ip.to_string(),
created_by_username: form.username.to_string(),
created_at: Utc::now().naive_utc(),
updated_at: None,
started_at: None,
still_processing_at: None,
ended_at: None,
};
match scan_task.save(&mut db).await {
Ok(_) => error!("Added {}", ip.to_string()),
Err(err) => error!("Not added: {:?}", err),
}
})
.await
// map diesel query errors to a 500 error response
.map_err(|err| ErrorInternalServerError(err));
}
html_contents(format!("New task added: {} !", task_group_id))
MultiReply::Content(HtmlContents(format!("New task added: {} !", task_group_id)))
}
#[derive(Serialize, Deserialize)]
#[derive(FromForm, Serialize, Deserialize)]
pub struct ReportParams {
ip: String,
}
async fn handle_report(pool: web::Data<DbPool>, params: web::Form<ReportParams>) -> HttpResponse {
match handle_ip(pool, params.ip.clone()).await {
Ok(scanner) => html_contents(match scanner.scanner_name {
#[post("/report", data = "<form>")]
async fn handle_report(db: DbConn, form: Form<ReportParams>) -> HtmlContents {
match handle_ip(db, form.ip.clone()).await {
Ok(scanner) => HtmlContents(match scanner.scanner_name {
Scanners::Binaryedge => match scanner.last_checked_at {
Some(date) => format!(
"Reported a binaryedge ninja! <b>{}</b> known as {} since {date}.",
@ -252,9 +279,9 @@ async fn handle_report(pool: web::Data<DbPool>, params: web::Form<ReportParams>)
_ => format!("Not supported"),
}),
Err(ptr_result) => html_contents(format!(
Err(ptr_result) => HtmlContents(format!(
"The IP <b>{}</a> resolved as {:?} did not match known scanners patterns.",
params.ip,
form.ip,
match ptr_result {
Some(res) => res.result,
None => None,
@ -267,87 +294,86 @@ struct SecurePath {
pub data: String,
}
impl<'de> Deserialize<'de> for SecurePath {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = <String>::deserialize(deserializer)?;
impl TryInto<SecurePath> for &str {
type Error = String;
fn try_into(self) -> Result<SecurePath, Self::Error> {
// A-Z a-z 0-9
// . - _
if s.chars()
if self
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
{
return Ok(SecurePath { data: s });
return Ok(SecurePath {
data: self.to_string(),
});
}
Err(serde::de::Error::custom(format!(
"Invalid value: {}",
s.to_string()
)))
Err(format!("Invalid value: {}", self.to_string()))
}
}
#[get("/collections/<vendor_name>/<file_name>")]
async fn handle_get_collection(
path: web::Path<(SecurePath, SecurePath)>,
req: HttpRequest,
static_data_dir: actix_web::web::Data<String>,
) -> actix_web::Result<HttpResponse> {
let (vendor_name, file_name) = path.into_inner();
vendor_name: &str,
file_name: &str,
app_configs: &State<AppConfigs>,
) -> MultiReply {
let vendor_name: Result<SecurePath, String> = vendor_name.try_into();
let vendor_name = match vendor_name {
Ok(secure_path) => secure_path.data,
Err(err) => return MultiReply::FormError(PlainText(err.to_string())),
};
let file_name: Result<SecurePath, String> = file_name.try_into();
let file_name = match file_name {
Ok(secure_path) => secure_path.data,
Err(err) => return MultiReply::FormError(PlainText(err.to_string())),
};
let mut path: PathBuf = PathBuf::new();
let static_data_dir: String = static_data_dir.into_inner().to_string();
let static_data_dir: String = app_configs.static_data_dir.clone();
path.push(static_data_dir);
path.push("collections");
path.push(vendor_name.data);
path.push(file_name.data);
match NamedFile::open(path) {
Ok(file) => Ok(file.into_response(&req)),
Err(err) => Ok(HttpResponse::NotFound()
.content_type(ContentType::plaintext())
.body(format!("File not found: {}.\n", err))),
path.push(vendor_name);
path.push(file_name);
match NamedFile::open(path).await {
Ok(file) => MultiReply::FileContents(file),
Err(err) => MultiReply::NotFound(err.to_string()),
}
}
#[get("/scanners/<scanner_name>")]
async fn handle_list_scanners(
pool: web::Data<DbPool>,
path: web::Path<Scanners>,
req: HttpRequest,
static_data_dir: actix_web::web::Data<String>,
) -> actix_web::Result<HttpResponse> {
let scanner_name = path.into_inner();
let static_data_dir: String = static_data_dir.into_inner().to_string();
mut db: DbConn,
scanner_name: &str,
app_configs: &State<AppConfigs>,
) -> MultiReply {
let static_data_dir: String = app_configs.static_data_dir.clone();
let scanner_name: Result<Scanners, String> = scanner_name.try_into();
let scanner_name = match scanner_name {
Ok(scanner_name) => scanner_name,
Err(err) => return MultiReply::FormError(PlainText(err.to_string())),
};
if scanner_name.is_static() {
let mut path: PathBuf = PathBuf::new();
path.push(static_data_dir);
path.push("scanners");
path.push(scanner_name.to_string());
return match NamedFile::open(path) {
Ok(file) => Ok(file.into_response(&req)),
Err(err) => Ok(HttpResponse::NotFound()
.content_type(ContentType::plaintext())
.body(format!("File not found: {}.\n", err))),
return match NamedFile::open(path).await {
Ok(file) => MultiReply::FileContents(file),
Err(err) => MultiReply::NotFound(err.to_string()),
};
}
// use web::block to offload blocking Diesel queries without blocking server thread
let scanners_list = web::block(move || {
// note that obtaining a connection from the pool is also potentially blocking
let conn = &mut pool.get().unwrap();
match Scanner::list_names(scanner_name, conn) {
Ok(data) => Ok(data),
Err(err) => Err(err),
}
})
.await
// map diesel query errors to a 500 error response
.map_err(|err| ErrorInternalServerError(err))
.unwrap();
let scanners_list = match Scanner::list_names(scanner_name, &mut db).await {
Ok(data) => Ok(data),
Err(err) => Err(err),
};
if let Ok(scanners) = scanners_list {
Ok(html_contents(scanners.join("\n")))
MultiReply::Content(HtmlContents(scanners.join("\n")))
} else {
Ok(server_error("Unable to list scanners".to_string()))
MultiReply::Error(ServerError("Unable to list scanners".to_string()))
}
}
@ -376,23 +402,16 @@ static SCAN_TASKS_FOOT: &str = r#"
</html>
"#;
async fn handle_list_scan_tasks(pool: web::Data<DbPool>) -> HttpResponse {
#[get("/scan/tasks")]
async fn handle_list_scan_tasks(mut db: Connection<SnowDb>) -> MultiReply {
let mut html_data: Vec<String> = vec![SCAN_TASKS_HEAD.to_string()];
// use web::block to offload blocking Diesel queries without blocking server thread
let scan_tasks_list = web::block(move || {
// note that obtaining a connection from the pool is also potentially blocking
let conn = &mut pool.get().unwrap();
match ScanTask::list(conn) {
Ok(data) => Ok(data),
Err(err) => Err(err),
}
})
.await
// map diesel query errors to a 500 error response
.map_err(|err| ErrorInternalServerError(err));
let scan_tasks_list = match ScanTask::list(&mut db).await {
Ok(data) => Ok(data),
Err(err) => Err(err),
};
if let Ok(scan_tasks) = scan_tasks_list.unwrap() {
if let Ok(scan_tasks) = scan_tasks_list {
for row in scan_tasks {
let cidr: String = row.cidr;
let started_at: Option<NaiveDateTime> = row.started_at;
@ -413,69 +432,81 @@ async fn handle_list_scan_tasks(pool: web::Data<DbPool>) -> HttpResponse {
html_data.push(SCAN_TASKS_FOOT.to_string());
html_contents(html_data.join("\n"))
MultiReply::Content(HtmlContents(html_data.join("\n")))
} else {
return server_error("Unable to list scan tasks".to_string());
return MultiReply::Error(ServerError("Unable to list scan tasks".to_string()));
}
}
fn get_connection(database_url: &str) -> DbPool {
let manager = ConnectionManager::<MysqlConnection>::new(database_url);
// Refer to the `r2d2` documentation for more methods to use
// when building a connection pool
Pool::builder()
.max_size(5)
.test_on_check_out(true)
.build(manager)
.expect("Could not build connection pool")
#[derive(Responder)]
#[response(status = 200, content_type = "text")]
pub struct PlainText(String);
#[derive(Responder)]
#[response(status = 200, content_type = "html")]
pub struct HtmlContents(String);
#[derive(Responder)]
#[response(status = 500, content_type = "html")]
pub struct ServerError(String);
#[get("/")]
async fn index() -> HtmlContents {
HtmlContents(FORM.to_string())
}
fn plain_contents(data: String) -> HttpResponse {
HttpResponse::Ok()
.content_type(ContentType::plaintext())
.body(data)
#[get("/ping")]
async fn pong() -> PlainText {
PlainText("pong".to_string())
}
fn html_contents(data: String) -> HttpResponse {
HttpResponse::Ok()
.content_type(ContentType::html())
.body(data)
#[get("/ws")]
pub async fn ws() -> PlainText {
info!("establish_ws_connection");
PlainText("ok".to_string())
// Ok(HttpResponse::Unauthorized().json(e))
/*
match result {
Ok(response) => Ok(response.into()),
Err(e) => {
error!("ws connection error: {:?}", e);
Err(e)
},
}*/
}
fn server_error(data: String) -> HttpResponse {
HttpResponse::InternalServerError()
.content_type(ContentType::html())
.body(data)
struct AppConfigs {
static_data_dir: String,
}
async fn index() -> HttpResponse {
html_contents(FORM.to_string())
async fn report_counts(rocket: Rocket<Build>) -> Rocket<Build> {
use rocket_db_pools::diesel::AsyncConnectionWrapper;
let conn = SnowDb::fetch(&rocket)
.expect("database is attached")
.get().await
.unwrap_or_else(|e| {
span_error!("failed to connect to MySQL database" => error!("{e}"));
panic!("aborting launch");
});
let _: AsyncConnectionWrapper<_> = conn.into();
info!("Connected to the DB");
rocket
}
async fn pong() -> HttpResponse {
plain_contents("pong".to_string())
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
let _log2 = log2::stdout()
.module(false)
.level(match env::var("RUST_LOG") {
Ok(level) => level,
Err(_) => "debug".to_string(),
})
.start();
let server_address: String = if let Ok(env) = env::var("SERVER_ADDRESS") {
env
#[launch]
fn rocket() -> _ {
let server_address: SocketAddr = if let Ok(env) = env::var("SERVER_ADDRESS") {
env.parse().expect("The ENV SERVER_ADDRESS should be a valid socket address (address:port)")
} else {
"127.0.0.1:8000".to_string()
"127.0.0.1:8000".parse().expect("The default address should be valid")
};
let worker_server_address: String = if let Ok(env) = env::var("WORKER_SERVER_ADDRESS") {
env
} else {
"127.0.0.1:8800".to_string()
let static_data_dir: String = match env::var("STATIC_DATA_DIR") {
Ok(val) => val,
Err(_) => "../data/".to_string(),
};
let db_url: String = if let Ok(env) = env::var("DB_URL") {
@ -485,40 +516,36 @@ async fn main() -> std::io::Result<()> {
"mysql://localhost".to_string()
};
let pool = get_connection(db_url.as_str());
// note that obtaining a connection from the pool is also potentially blocking
let conn = &mut pool.get().unwrap();
let names = Scanner::list_names(Scanners::Stretchoid, conn);
match names {
Ok(names) => info!("Found {} Stretchoid scanners", names.len()),
Err(err) => error!("Unable to get names: {}", err),
let db: Map<_, Value> = map! {
"url" => db_url.into(),
"pool_size" => 10.into(),
"timeout" => 5.into(),
};
let server = HttpServer::new(move || {
let static_data_dir: String = match env::var("STATIC_DATA_DIR") {
Ok(val) => val,
Err(_) => "../data/".to_string(),
};
let config_figment = rocket::Config::figment()
.merge(("address", server_address.ip().to_string()))
.merge(("port", server_address.port()))
.merge(("databases", map!["snow_scanner_db" => db]));
rocket::custom(config_figment)
.attach(SnowDb::init())
.attach(AdHoc::on_ignite("Report counts", report_counts))
.manage(AppConfigs { static_data_dir })
.mount(
"/",
routes![
index,
pong,
handle_report,
handle_scan,
handle_list_scan_tasks,
handle_list_scanners,
handle_get_collection,
ws,
],
)
/*
App::new()
.app_data(web::Data::new(pool.clone()))
.app_data(actix_web::web::Data::new(static_data_dir))
.route("/", web::get().to(index))
.route("/ping", web::get().to(pong))
.route("/report", web::post().to(handle_report))
.route("/scan", web::post().to(handle_scan))
.route("/scan/tasks", web::get().to(handle_list_scan_tasks))
.route(
"/scanners/{scanner_name}",
web::get().to(handle_list_scanners),
)
.route(
"/collections/{vendor_name}/{file_name}",
web::get().to(handle_get_collection),
)
})
.bind(&server_address);
match server {
Ok(server) => {
match ws2::listen(worker_server_address.as_str()) {
@ -544,13 +571,5 @@ async fn main() -> std::io::Result<()> {
}
Err(err) => error!("Unable to listen on {worker_server_address}: {err}"),
};
info!("Now listening on {}", server_address);
server.run().await
}
Err(err) => {
error!("Could not bind the server to {}", server_address);
Err(err)
}
}
}*/
}