added cranky and removed all .unwrap() usages

This commit is contained in:
Eugene Pankov 2022-07-23 21:31:35 +02:00
parent 3c78ce4888
commit 47518c6253
No known key found for this signature in database
GPG key ID: 5896FCBBDD1CF4F4
29 changed files with 140 additions and 129 deletions

8
Cranky.toml Normal file
View file

@ -0,0 +1,8 @@
deny = [
"unsafe_code",
"clippy::unwrap_used",
"clippy::expect_used",
"clippy::panic",
# "clippy::indexing_slicing",
# "clippy::integer_arithmetic",
]

2
clippy.toml Normal file
View file

@ -0,0 +1,2 @@
avoid-breaking-exported-api = false
allow-unwrap-in-tests = true

View file

@ -10,7 +10,7 @@ fix *ARGS:
for p in {{projects}}; do cargo fix -p $p {{ARGS}}; done for p in {{projects}}; do cargo fix -p $p {{ARGS}}; done
clippy *ARGS: clippy *ARGS:
for p in {{projects}}; do cargo clippy -p $p {{ARGS}}; done for p in {{projects}}; do cargo cranky -p $p {{ARGS}}; done
yarn *ARGS: yarn *ARGS:
cd warpgate-web && yarn {{ARGS}} cd warpgate-web && yarn {{ARGS}}

View file

@ -1,2 +1,2 @@
[toolchain] [toolchain]
channel = "nightly-2022-03-14" channel = "nightly-2022-07-22"

View file

@ -3,6 +3,7 @@ mod api;
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use regex::Regex; use regex::Regex;
#[allow(clippy::unwrap_used)]
pub fn main() { pub fn main() {
let api_service = let api_service =
OpenApiService::new(api::get(), "Warpgate Web Admin", env!("CARGO_PKG_VERSION")) OpenApiService::new(api::get(), "Warpgate Web Admin", env!("CARGO_PKG_VERSION"))

View file

@ -47,11 +47,13 @@ fn _default_database_url() -> Secret<String> {
#[inline] #[inline]
fn _default_http_listen() -> ListenEndpoint { fn _default_http_listen() -> ListenEndpoint {
#[allow(clippy::unwrap_used)]
ListenEndpoint("0.0.0.0:8888".to_socket_addrs().unwrap().next().unwrap()) ListenEndpoint("0.0.0.0:8888".to_socket_addrs().unwrap().next().unwrap())
} }
#[inline] #[inline]
fn _default_mysql_listen() -> ListenEndpoint { fn _default_mysql_listen() -> ListenEndpoint {
#[allow(clippy::unwrap_used)]
ListenEndpoint("0.0.0.0:33306".to_socket_addrs().unwrap().next().unwrap()) ListenEndpoint("0.0.0.0:33306".to_socket_addrs().unwrap().next().unwrap())
} }
@ -77,7 +79,7 @@ pub struct TargetSSHOptions {
pub auth: SSHTargetAuth, pub auth: SSHTargetAuth,
} }
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)] #[serde(untagged)]
pub enum SSHTargetAuth { pub enum SSHTargetAuth {
#[serde(rename = "password")] #[serde(rename = "password")]
@ -182,7 +184,7 @@ pub enum TargetOptions {
WebAdmin(TargetWebAdminOptions), WebAdmin(TargetWebAdminOptions),
} }
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum UserAuthCredential { pub enum UserAuthCredential {
#[serde(rename = "password")] #[serde(rename = "password")]
@ -221,6 +223,7 @@ pub struct Role {
} }
fn _default_ssh_listen() -> ListenEndpoint { fn _default_ssh_listen() -> ListenEndpoint {
#[allow(clippy::unwrap_used)]
ListenEndpoint("0.0.0.0:2222".to_socket_addrs().unwrap().next().unwrap()) ListenEndpoint("0.0.0.0:2222".to_socket_addrs().unwrap().next().unwrap())
} }

View file

@ -232,15 +232,13 @@ impl ConfigProvider for FileConfigProvider {
.roles .roles
.iter() .iter()
.map(|x| config.store.roles.iter().find(|y| &y.name == x)) .map(|x| config.store.roles.iter().find(|y| &y.name == x))
.filter(|x| x.is_some()) .filter_map(|x| x.to_owned())
.map(|x| x.unwrap().to_owned())
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let target_roles = target let target_roles = target
.allow_roles .allow_roles
.iter() .iter()
.map(|x| config.store.roles.iter().find(|y| &y.name == x)) .map(|x| config.store.roles.iter().find(|y| &y.name == x))
.filter(|x| x.is_some()) .filter_map(|x| x.to_owned())
.map(|x| x.unwrap().to_owned())
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let intersect = user_roles.intersection(&target_roles).count() > 0; let intersect = user_roles.intersection(&target_roles).count() > 0;

View file

@ -11,6 +11,8 @@ use crate::Secret;
pub fn hash_password(password: &str) -> String { pub fn hash_password(password: &str) -> String {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default(); let argon2 = Argon2::default();
// Only panics for invalid hash parameters
#[allow(clippy::unwrap_used)]
argon2 argon2
.hash_password(password.as_bytes(), &salt) .hash_password(password.as_bytes(), &salt)
.unwrap() .unwrap()

View file

@ -32,6 +32,7 @@ fn get_totp(key: &OtpSecretKey, label: Option<&str>) -> TOTP<OtpExposedSecretKey
} }
pub fn verify_totp(code: &str, key: &OtpSecretKey) -> bool { pub fn verify_totp(code: &str, key: &OtpSecretKey) -> bool {
#[allow(clippy::unwrap_used)]
let time = SystemTime::now() let time = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.unwrap() .unwrap()

View file

@ -32,7 +32,11 @@ where
pub fn install_database_logger(database: Arc<Mutex<DatabaseConnection>>) { pub fn install_database_logger(database: Arc<Mutex<DatabaseConnection>>) {
tokio::spawn(async move { tokio::spawn(async move {
let mut receiver = LOG_SENDER.get().unwrap().subscribe(); #[allow(clippy::expect_used)]
let mut receiver = LOG_SENDER
.get()
.expect("Log sender not ready yet")
.subscribe();
loop { loop {
match receiver.recv().await { match receiver.recv().await {
Err(_) => break, Err(_) => break,

View file

@ -49,9 +49,12 @@ where
return return
}; };
let buffer = BytesMut::from( let Ok(serialized) = serde_json::to_vec(&values) else {
&serde_json::to_vec(&values).expect("Cannot serialize log entry, this is a bug")[..], eprintln!("Failed to serialize log entry {values:?}");
); continue
};
let buffer = BytesMut::from(&serialized[..]);
if let Err(error) = socket.send_to(buffer.as_ref(), socket_address).await { if let Err(error) = socket.send_to(buffer.as_ref(), socket_address).await {
error!(%error, is_socket_logging_error=true, "Failed to forward log entry"); error!(%error, is_socket_logging_error=true, "Failed to forward log entry");
} }

View file

@ -8,7 +8,7 @@ use tracing_core::Field;
pub type SerializedRecordValuesInner = HashMap<&'static str, String>; pub type SerializedRecordValuesInner = HashMap<&'static str, String>;
#[derive(Serialize)] #[derive(Serialize, Debug)]
pub struct SerializedRecordValues(SerializedRecordValuesInner); pub struct SerializedRecordValues(SerializedRecordValuesInner);
impl SerializedRecordValues { impl SerializedRecordValues {

View file

@ -30,6 +30,9 @@ pub enum Error {
#[error("Disabled")] #[error("Disabled")]
Disabled, Disabled,
#[error("Invalid recording path")]
InvalidPath,
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -71,7 +74,7 @@ impl SessionRecordings {
} }
let path = self.path_for(id, &name); let path = self.path_for(id, &name);
tokio::fs::create_dir_all(&path.parent().unwrap()).await?; tokio::fs::create_dir_all(&path.parent().ok_or(Error::InvalidPath)?).await?;
info!(%name, path=?path, "Recording session {}", id); info!(%name, path=?path, "Recording session {}", id);
let model = { let model = {

View file

@ -30,7 +30,7 @@ fn test_catch() {
let mut caught = false; let mut caught = false;
try_block!({ try_block!({
let _: u32 = "asdf".parse()?; let _: u32 = "asdf".parse()?;
panic!(); assert!(false)
} catch (e: anyhow::Error) { } catch (e: anyhow::Error) {
assert_eq!(e.to_string(), "asdf".parse::<i32>().unwrap_err().to_string()); assert_eq!(e.to_string(), "asdf".parse::<i32>().unwrap_err().to_string());
caught = true; caught = true;
@ -43,6 +43,6 @@ fn test_success() {
try_block!({ try_block!({
let _: u32 = "123".parse()?; let _: u32 = "123".parse()?;
} catch (_e: anyhow::Error) { } catch (_e: anyhow::Error) {
panic!(); assert!(false)
}); });
} }

View file

@ -13,7 +13,7 @@ use crate::helpers::rng::get_crypto_rng;
pub type SessionId = Uuid; pub type SessionId = Uuid;
pub type ProtocolName = &'static str; pub type ProtocolName = &'static str;
#[derive(PartialEq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct Secret<T>(T); pub struct Secret<T>(T);
impl Secret<String> { impl Secret<String> {
@ -94,9 +94,7 @@ impl<'de> Deserialize<'de> for ListenEndpoint {
})? })?
.next() .next()
.ok_or_else(|| { .ok_or_else(|| {
return serde::de::Error::custom(format!( serde::de::Error::custom(format!("failed to resolve {v} into a TCP endpoint"))
"failed to resolve {v} into a TCP endpoint"
));
})?; })?;
Ok(Self(v)) Ok(Self(v))
} }

View file

@ -151,39 +151,6 @@ pub trait DatabaseError: 'static + Send + Sync + StdError {
} }
impl dyn DatabaseError { impl dyn DatabaseError {
/// Downcast a reference to this generic database error to a specific
/// database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast_ref` which returns `Option<&E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast_ref`.
pub fn downcast_ref<E: DatabaseError>(&self) -> &E {
self.try_downcast_ref().unwrap_or_else(|| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
self
)
})
}
/// Downcast this generic database error to a specific database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast` which returns `Option<E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast`.
pub fn downcast<E: DatabaseError>(self: Box<Self>) -> Box<E> {
self.try_downcast().unwrap_or_else(|e| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
e
)
})
}
/// Downcast a reference to this generic database error to a specific /// Downcast a reference to this generic database error to a specific
/// database error type. /// database error type.
#[inline] #[inline]
@ -195,6 +162,7 @@ impl dyn DatabaseError {
#[inline] #[inline]
pub fn try_downcast<E: DatabaseError>(self: Box<Self>) -> StdResult<Box<E>, Box<Self>> { pub fn try_downcast<E: DatabaseError>(self: Box<Self>) -> StdResult<Box<E>, Box<Self>> {
if self.as_error().is::<E>() { if self.as_error().is::<E>() {
#[allow(clippy::unwrap_used)]
Ok(self.into_error().downcast().unwrap()) Ok(self.into_error().downcast().unwrap())
} else { } else {
Err(self) Err(self)

View file

@ -41,7 +41,7 @@ impl BufExt for Bytes {
fn get_str_nul(&mut self) -> Result<String, Error> { fn get_str_nul(&mut self) -> Result<String, Error> {
self.get_bytes_nul().and_then(|bytes| { self.get_bytes_nul().and_then(|bytes| {
from_utf8(&*bytes) from_utf8(&bytes)
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
.map_err(|err| err_protocol!("{}", err)) .map_err(|err| err_protocol!("{}", err))
}) })

View file

@ -4,6 +4,7 @@ use crate::err_protocol;
use crate::error::Error; use crate::error::Error;
#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[allow(clippy::enum_variant_names)]
pub enum AuthPlugin { pub enum AuthPlugin {
MySqlClearPassword, MySqlClearPassword,
MySqlNativePassword, MySqlNativePassword,

View file

@ -31,7 +31,7 @@ impl Decode<'_, Capabilities> for ErrPacket {
if capabilities.contains(Capabilities::PROTOCOL_41) { if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE // If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) { if buf.first() == Some(&0x23) {
buf.advance(1); buf.advance(1);
sql_state = Some(buf.get_str(5)?); sql_state = Some(buf.get_str(5)?);
} }

View file

@ -63,7 +63,7 @@ bitflags! {
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type // https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)] #[repr(u8)]
pub enum ColumnType { pub enum ColumnType {

View file

@ -6,6 +6,7 @@ mod session;
mod session_handle; mod session_handle;
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
#[allow(clippy::unwrap_used)]
pub fn main() { pub fn main() {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
api::get(), api::get(),

View file

@ -1,4 +1,8 @@
use anyhow::Result; use std::borrow::Cow;
use std::collections::HashSet;
use std::str::FromStr;
use anyhow::{Context, Result};
use cookie::Cookie; use cookie::Cookie;
use delegate::delegate; use delegate::delegate;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
@ -7,9 +11,6 @@ use http::uri::{Authority, Scheme};
use http::Uri; use http::Uri;
use poem::web::websocket::{CloseCode, Message, WebSocket}; use poem::web::websocket::{CloseCode, Message, WebSocket};
use poem::{Body, IntoResponse, Request, Response}; use poem::{Body, IntoResponse, Request, Response};
use std::borrow::Cow;
use std::collections::HashSet;
use std::str::FromStr;
use tokio_tungstenite::{connect_async_with_config, tungstenite}; use tokio_tungstenite::{connect_async_with_config, tungstenite};
use tracing::*; use tracing::*;
use warpgate_common::{try_block, TargetHTTPOptions}; use warpgate_common::{try_block, TargetHTTPOptions};
@ -75,33 +76,40 @@ lazy_static::lazy_static! {
}; };
} }
fn construct_uri(req: &Request, options: &TargetHTTPOptions, websocket: bool) -> Uri { fn construct_uri(req: &Request, options: &TargetHTTPOptions, websocket: bool) -> Result<Uri> {
let target_uri = Uri::try_from(options.url.clone()).unwrap(); let target_uri = Uri::try_from(options.url.clone())?;
let source_uri = req.uri().clone(); let source_uri = req.uri().clone();
let authority = target_uri.authority().unwrap().to_string(); let authority = target_uri
let authority = authority.split("@").last().unwrap(); .authority()
let authority: Authority = authority.try_into().unwrap(); .context("No authority in the URL")?
.to_string();
let authority = authority.split("@").last().context("Authority is empty")?;
let authority: Authority = authority.try_into()?;
let mut uri = http::uri::Builder::new() let mut uri = http::uri::Builder::new()
.authority(authority) .authority(authority)
.path_and_query(source_uri.path_and_query().unwrap().clone()); .path_and_query(
source_uri
.path_and_query()
.context("No path in the URL")?
.clone(),
);
uri = uri.scheme(target_uri.scheme().unwrap().clone()); let scheme = target_uri.scheme().context("No scheme in the URL")?;
uri = uri.scheme(scheme.clone());
if websocket { if websocket {
uri = uri.scheme( uri = uri.scheme(
Scheme::from_str( Scheme::from_str(if scheme == &Scheme::from_str("http").unwrap() {
if target_uri.scheme().unwrap() == &Scheme::from_str("http").unwrap() { "ws"
"ws" } else {
} else { "wss"
"wss" })
},
)
.unwrap(), .unwrap(),
); );
} }
uri.build().unwrap() Ok(uri.build()?)
} }
fn copy_client_response<R: SomeResponse>( fn copy_client_response<R: SomeResponse>(
@ -131,20 +139,23 @@ fn rewrite_request<B: SomeRequestBuilder>(mut req: B, options: &TargetHTTPOption
} }
fn rewrite_response(resp: &mut Response, options: &TargetHTTPOptions) -> Result<()> { fn rewrite_response(resp: &mut Response, options: &TargetHTTPOptions) -> Result<()> {
let target_uri = Uri::try_from(options.url.clone()).unwrap(); let target_uri = Uri::try_from(options.url.clone())?;
let headers = resp.headers_mut(); let headers = resp.headers_mut();
if let Some(value) = headers.get_mut(http::header::LOCATION) { if let Some(value) = headers.get_mut(http::header::LOCATION) {
let redirect_uri = Uri::try_from(value.as_bytes()).unwrap(); let redirect_uri = Uri::try_from(value.as_bytes())?;
if redirect_uri.authority() == target_uri.authority() { if redirect_uri.authority() == target_uri.authority() {
let old_value = value.clone(); let old_value = value.clone();
*value = Uri::builder() *value = Uri::builder()
.path_and_query(redirect_uri.path_and_query().unwrap().clone()) .path_and_query(
.build() redirect_uri
.unwrap() .path_and_query()
.context("No path in URL")?
.clone(),
)
.build()?
.to_string() .to_string()
.parse() .parse()?;
.unwrap();
debug!("Rewrote a redirect from {:?} to {:?}", old_value, value); debug!("Rewrote a redirect from {:?} to {:?}", old_value, value);
} }
} }
@ -174,7 +185,8 @@ fn copy_server_request<B: SomeRequestBuilder>(req: &Request, mut target: B) -> B
req.headers() req.headers()
.get_all(k) .get_all(k)
.iter() .iter()
.map(|v| v.to_str().unwrap().to_string()) .map(|v| v.to_str().map(|x| x.to_string()))
.filter_map(|x| x.ok())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("; "), .join("; "),
); );
@ -187,7 +199,7 @@ pub async fn proxy_normal_request(
body: Body, body: Body,
options: &TargetHTTPOptions, options: &TargetHTTPOptions,
) -> poem::Result<Response> { ) -> poem::Result<Response> {
let uri = construct_uri(req, &options, false).to_string(); let uri = construct_uri(req, &options, false)?.to_string();
tracing::debug!("URI: {:?}", uri); tracing::debug!("URI: {:?}", uri);
@ -195,15 +207,18 @@ pub async fn proxy_normal_request(
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
.connection_verbose(true) .connection_verbose(true)
.build() .build()
.unwrap(); .context("Could not build request")?;
let mut client_request = client.request(req.method().into(), uri.clone()); let mut client_request = client.request(req.method().into(), uri.clone());
client_request = copy_server_request(&req, client_request); client_request = copy_server_request(&req, client_request);
client_request = rewrite_request(client_request, options)?; client_request = rewrite_request(client_request, options)?;
client_request = client_request.body(reqwest::Body::wrap_stream(body.into_bytes_stream())); client_request = client_request.body(reqwest::Body::wrap_stream(body.into_bytes_stream()));
let client_request = client_request.build().unwrap(); let client_request = client_request.build().context("Could not build request")?;
let client_response = client.execute(client_request).await.unwrap(); let client_response = client
.execute(client_request)
.await
.context("Could not execute request")?;
let status = client_response.status().clone(); let status = client_response.status().clone();
let mut response: Response = "".into(); let mut response: Response = "".into();
@ -275,7 +290,7 @@ pub async fn proxy_websocket_request(
ws: WebSocket, ws: WebSocket,
options: &TargetHTTPOptions, options: &TargetHTTPOptions,
) -> poem::Result<impl IntoResponse> { ) -> poem::Result<impl IntoResponse> {
let uri = construct_uri(req, &options, true); let uri = construct_uri(req, &options, true)?;
proxy_ws_inner(req, ws, uri.clone(), options) proxy_ws_inner(req, ws, uri.clone(), options)
.await .await
.map_err(|error| { .map_err(|error| {

View file

@ -1,17 +1,19 @@
use crate::common::{PROTOCOL_NAME, SESSION_MAX_AGE}; use std::collections::{BTreeMap, HashMap};
use crate::session_handle::{ use std::sync::{Arc, Weak};
HttpSessionHandle, SessionHandleCommand, WarpgateServerHandleFromRequest, use std::time::{Duration, Instant};
};
use poem::session::{Session, SessionStorage}; use poem::session::{Session, SessionStorage};
use poem::web::{Data, RemoteAddr}; use poem::web::{Data, RemoteAddr};
use poem::{FromRequest, Request}; use poem::{FromRequest, Request};
use serde_json::Value; use serde_json::Value;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::*; use tracing::*;
use warpgate_common::{Services, SessionId, WarpgateServerHandle, SessionStateInit}; use warpgate_common::{Services, SessionId, SessionStateInit, WarpgateServerHandle};
use crate::common::{PROTOCOL_NAME, SESSION_MAX_AGE};
use crate::session_handle::{
HttpSessionHandle, SessionHandleCommand, WarpgateServerHandleFromRequest,
};
#[derive(Clone)] #[derive(Clone)]
pub struct SharedSessionStorage(pub Arc<Mutex<Box<dyn SessionStorage>>>); pub struct SharedSessionStorage(pub Arc<Mutex<Box<dyn SessionStorage>>>);
@ -124,7 +126,9 @@ impl SessionMiddleware {
session.set(SESSION_ID_SESSION_KEY, id); session.set(SESSION_ID_SESSION_KEY, id);
let this = self.this.upgrade().unwrap(); let Some(this) = self.this.upgrade() else {
return Err(anyhow::anyhow!("Invalid session state").into())
};
tokio::spawn({ tokio::spawn({
let session_storage = (*session_storage).clone(); let session_storage = (*session_storage).clone();
let poem_session_id: Option<String> = session.get(POEM_SESSION_ID_SESSION_KEY); let poem_session_id: Option<String> = session.get(POEM_SESSION_ID_SESSION_KEY);

View file

@ -125,9 +125,9 @@ impl MySqlClient {
let Some(response) = stream.recv().await? else { let Some(response) = stream.recv().await? else {
return Err(MySqlError::Eof) return Err(MySqlError::Eof)
}; };
if response.get(0) == Some(&0) || response.get(0) == Some(&0xfe) { if response.first() == Some(&0) || response.first() == Some(&0xfe) {
debug!("Authorized"); debug!("Authorized");
} else if response.get(0) == Some(&0xff) { } else if response.first() == Some(&0xff) {
let error = ErrPacket::decode_with(response, options.capabilities)?; let error = ErrPacket::decode_with(response, options.capabilities)?;
return Err(MySqlError::ProtocolError(format!( return Err(MySqlError::ProtocolError(format!(
"handshake failed: {:?}", "handshake failed: {:?}",
@ -136,7 +136,7 @@ impl MySqlClient {
} else { } else {
return Err(MySqlError::ProtocolError(format!( return Err(MySqlError::ProtocolError(format!(
"unknown response type {:?}", "unknown response type {:?}",
response.get(0) response.first()
))); )));
} }

View file

@ -139,7 +139,7 @@ impl MySqlSession {
return Err(MySqlError::Eof); return Err(MySqlError::Eof);
}; };
let password = Secret::new(response.clone().get_str_nul()?); let password = Secret::new(response.clone().get_str_nul()?);
return self.run_authorization(resp, password).await; self.run_authorization(resp, password).await
} }
async fn send_error(&mut self, code: u16, message: &str) -> Result<(), MySqlError> { async fn send_error(&mut self, code: u16, message: &str) -> Result<(), MySqlError> {
@ -209,11 +209,9 @@ impl MySqlSession {
); );
return fail(&mut self).await; return fail(&mut self).await;
} }
return self.run_authorized(handshake, username, target_name).await; self.run_authorized(handshake, username, target_name).await
}
AuthResult::Rejected | AuthResult::OtpNeeded => {
return fail(&mut self).await;
} }
AuthResult::Rejected | AuthResult::OtpNeeded => fail(&mut self).await,
} }
} }
AuthSelector::Ticket { secret } => { AuthSelector::Ticket { secret } => {
@ -231,11 +229,10 @@ impl MySqlSession {
.await .await
.map_err(MySqlError::other)?; .map_err(MySqlError::other)?;
return self self.run_authorized(handshake, ticket.username, ticket.target)
.run_authorized(handshake, ticket.username, ticket.target) .await
.await;
} }
_ => return fail(&mut self).await, _ => fail(&mut self).await,
} }
} }
} }
@ -297,10 +294,9 @@ impl MySqlSession {
} }
let span = self.make_logging_span(); let span = self.make_logging_span();
return self self.run_authorized_inner(handshake, mysql_options)
.run_authorized_inner(handshake, mysql_options)
.instrument(span) .instrument(span)
.await; .await
} }
async fn run_authorized_inner( async fn run_authorized_inner(
@ -341,7 +337,7 @@ impl MySqlSession {
}; };
trace!(?payload, "server got packet"); trace!(?payload, "server got packet");
let com = payload.get(0); let com = payload.first();
// COM_QUERY // COM_QUERY
if com == Some(&0x03) { if com == Some(&0x03) {
@ -359,7 +355,7 @@ impl MySqlSession {
trace!(?response, "client got packet"); trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?; self.stream.push(&&response[..], ())?;
self.stream.flush().await?; self.stream.flush().await?;
if let Some(com) = response.get(0) { if let Some(com) = response.first() {
if com == &0xfe { if com == &0xfe {
if self.capabilities.contains(Capabilities::DEPRECATE_EOF) { if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
break; break;
@ -415,7 +411,7 @@ impl MySqlSession {
trace!(?response, "client got packet"); trace!(?response, "client got packet");
self.stream.push(&&response[..], ())?; self.stream.push(&&response[..], ())?;
self.stream.flush().await?; self.stream.flush().await?;
if let Some(com) = response.get(0) { if let Some(com) = response.first() {
if com == &0 || com == &0xff || com == &0xfe { if com == &0 || com == &0xff || com == &0xfe {
break; break;
} }

View file

@ -1,7 +1,7 @@
use std::fs::{create_dir_all, File}; use std::fs::{create_dir_all, File};
use std::path::PathBuf; use std::path::PathBuf;
use anyhow::Result; use anyhow::{Context, Result};
use russh_keys::key::{KeyPair, SignatureHash}; use russh_keys::key::{KeyPair, SignatureHash};
use russh_keys::{encode_pkcs8_pem, load_secret_key}; use russh_keys::{encode_pkcs8_pem, load_secret_key};
use tracing::*; use tracing::*;
@ -22,7 +22,7 @@ pub fn generate_host_keys(config: &WarpgateConfig) -> Result<()> {
let key_path = path.join("host-ed25519"); let key_path = path.join("host-ed25519");
if !key_path.exists() { if !key_path.exists() {
info!("Generating Ed25519 host key"); info!("Generating Ed25519 host key");
let key = KeyPair::generate_ed25519().unwrap(); let key = KeyPair::generate_ed25519().context("Failed to generate Ed25519 host key")?;
let f = File::create(key_path)?; let f = File::create(key_path)?;
encode_pkcs8_pem(&key, f)?; encode_pkcs8_pem(&key, f)?;
} }
@ -30,7 +30,8 @@ pub fn generate_host_keys(config: &WarpgateConfig) -> Result<()> {
let key_path = path.join("host-rsa"); let key_path = path.join("host-rsa");
if !key_path.exists() { if !key_path.exists() {
info!("Generating RSA host key"); info!("Generating RSA host key");
let key = KeyPair::generate_rsa(4096, SignatureHash::SHA2_512).unwrap(); let key = KeyPair::generate_rsa(4096, SignatureHash::SHA2_512)
.context("Failed to generate RSA key")?;
let f = File::create(key_path)?; let f = File::create(key_path)?;
encode_pkcs8_pem(&key, f)?; encode_pkcs8_pem(&key, f)?;
} }
@ -59,7 +60,7 @@ pub fn generate_client_keys(config: &WarpgateConfig) -> Result<()> {
let key_path = path.join("client-ed25519"); let key_path = path.join("client-ed25519");
if !key_path.exists() { if !key_path.exists() {
info!("Generating Ed25519 client key"); info!("Generating Ed25519 client key");
let key = KeyPair::generate_ed25519().unwrap(); let key = KeyPair::generate_ed25519().context("Failed to generate Ed25519 client key")?;
let f = File::create(key_path)?; let f = File::create(key_path)?;
encode_pkcs8_pem(&key, f)?; encode_pkcs8_pem(&key, f)?;
} }
@ -67,7 +68,8 @@ pub fn generate_client_keys(config: &WarpgateConfig) -> Result<()> {
let key_path = path.join("client-rsa"); let key_path = path.join("client-rsa");
if !key_path.exists() { if !key_path.exists() {
info!("Generating RSA client key"); info!("Generating RSA client key");
let key = KeyPair::generate_rsa(4096, SignatureHash::SHA2_512).unwrap(); let key = KeyPair::generate_rsa(4096, SignatureHash::SHA2_512)
.context("Failed to generate RSA client key")?;
let f = File::create(key_path)?; let f = File::create(key_path)?;
encode_pkcs8_pem(&key, f)?; encode_pkcs8_pem(&key, f)?;
} }

View file

@ -270,7 +270,7 @@ impl ServerSession {
pub async fn maybe_connect_remote(&mut self) -> Result<()> { pub async fn maybe_connect_remote(&mut self) -> Result<()> {
match self.target.clone() { match self.target.clone() {
TargetSelection::None => { TargetSelection::None => {
panic!("Target not set"); anyhow::bail!("Invalid session state (target not set)")
} }
TargetSelection::NotFound(name) => { TargetSelection::NotFound(name) => {
self.emit_service_message(&format!("Selected target not found: {name}")) self.emit_service_message(&format!("Selected target not found: {name}"))
@ -585,6 +585,7 @@ impl ServerSession {
.traffic_recorder_for(&params.host_to_connect, params.port_to_connect) .traffic_recorder_for(&params.host_to_connect, params.port_to_connect)
.await; .await;
if let Some(recorder) = recorder { if let Some(recorder) = recorder {
#[allow(clippy::unwrap_used)]
let mut recorder = recorder.connection(TrafficConnectionParams { let mut recorder = recorder.connection(TrafficConnectionParams {
dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(), dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(),
dst_port: params.port_to_connect as u16, dst_port: params.port_to_connect as u16,
@ -630,7 +631,7 @@ impl ServerSession {
let _ = self let _ = self
.session_handle .session_handle
.as_mut() .as_mut()
.unwrap() .context("Invalid session state")?
.channel_success(server_channel_id.0) .channel_success(server_channel_id.0)
.await; .await;
self.pty_channels.push(channel_id); self.pty_channels.push(channel_id);
@ -790,7 +791,7 @@ impl ServerSession {
let _ = self let _ = self
.session_handle .session_handle
.as_mut() .as_mut()
.unwrap() .context("Invalid session state")?
.channel_success(server_channel_id.0) .channel_success(server_channel_id.0)
.await; .await;
let _ = self.maybe_connect_remote().await; let _ = self.maybe_connect_remote().await;
@ -1099,8 +1100,7 @@ impl ServerSession {
let channels = all_channels let channels = all_channels
.into_iter() .into_iter()
.map(|x| self.map_channel_reverse(&x)) .map(|x| self.map_channel_reverse(&x))
.filter(|x| x.is_ok()) .filter_map(|x| x.ok())
.map(|x| x.unwrap())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let _ = self let _ = self

View file

@ -29,7 +29,7 @@ pub fn load_config(path: &Path, secure: bool) -> Result<WarpgateConfig> {
let config = WarpgateConfig { let config = WarpgateConfig {
store, store,
paths_relative_to: path.parent().unwrap().to_path_buf(), paths_relative_to: path.parent().context("FS root reached")?.to_path_buf(),
}; };
info!( info!(
@ -83,7 +83,7 @@ pub async fn watch_config<P: AsRef<Path>>(
) -> Result<()> { ) -> Result<()> {
let (tx, mut rx) = mpsc::channel(1); let (tx, mut rx) = mpsc::channel(1);
let mut watcher = RecommendedWatcher::new(move |res| { let mut watcher = RecommendedWatcher::new(move |res| {
tx.blocking_send(res).unwrap(); let _ = tx.blocking_send(res);
})?; })?;
watcher.configure(notify::Config::PreciseEvents(true))?; watcher.configure(notify::Config::PreciseEvents(true))?;
watcher.watch(path.as_ref(), RecursiveMode::NonRecursive)?; watcher.watch(path.as_ref(), RecursiveMode::NonRecursive)?;

View file

@ -14,8 +14,7 @@ pub async fn init_logging(config: Option<&WarpgateConfig>) {
std::env::set_var("RUST_LOG", "warpgate=info") std::env::set_var("RUST_LOG", "warpgate=info")
} }
let offset = UtcOffset::current_local_offset() let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC);
.unwrap_or_else(|_| UtcOffset::from_whole_seconds(0).unwrap());
let env_filter = Arc::new(EnvFilter::from_default_env()); let env_filter = Arc::new(EnvFilter::from_default_env());
let enable_colors = console::user_attended(); let enable_colors = console::user_attended();
@ -40,6 +39,7 @@ pub async fn init_logging(config: Option<&WarpgateConfig>) {
.with_ansi(enable_colors) .with_ansi(enable_colors)
.with_timer(OffsetTime::new( .with_timer(OffsetTime::new(
offset, offset,
#[allow(clippy::unwrap_used)]
format_description::parse("[day].[month].[year] [hour]:[minute]:[second]") format_description::parse("[day].[month].[year] [hour]:[minute]:[second]")
.unwrap(), .unwrap(),
)) ))
@ -56,6 +56,7 @@ pub async fn init_logging(config: Option<&WarpgateConfig>) {
.with_target(false) .with_target(false)
.with_timer(OffsetTime::new( .with_timer(OffsetTime::new(
offset, offset,
#[allow(clippy::unwrap_used)]
format_description::parse("[hour]:[minute]:[second]").unwrap(), format_description::parse("[hour]:[minute]:[second]").unwrap(),
)) ))
.with_filter(dynamic_filter_fn(move |m, c| { .with_filter(dynamic_filter_fn(move |m, c| {