From 8ff3bc792429fb2d79b3033166ebb260befaa373 Mon Sep 17 00:00:00 2001 From: Eugene Pankov Date: Tue, 19 Jul 2022 22:03:35 +0200 Subject: [PATCH] wip --- Cargo.lock | 189 +--- Cargo.toml | 1 + justfile | 2 +- warpgate-common/src/config.rs | 59 +- warpgate-database-protocols/Cargo.toml | 24 + warpgate-database-protocols/src/error.rs | 241 +++++ warpgate-database-protocols/src/io/buf.rs | 59 ++ warpgate-database-protocols/src/io/buf_mut.rs | 12 + .../src/io/buf_stream.rs | 166 ++++ warpgate-database-protocols/src/io/decode.rs | 29 + warpgate-database-protocols/src/io/encode.rs | 16 + warpgate-database-protocols/src/io/mod.rs | 12 + .../src/io/write_and_flush.rs | 47 + warpgate-database-protocols/src/lib.rs | 6 + .../src/mysql/collation.rs | 901 ++++++++++++++++++ .../src/mysql/io/buf.rs | 40 + .../src/mysql/io/buf_mut.rs | 126 +++ .../src/mysql/io/mod.rs | 5 + warpgate-database-protocols/src/mysql/mod.rs | 5 + .../src/mysql/protocol/auth.rs | 38 + .../src/mysql/protocol/capabilities.rs | 86 ++ .../src/mysql/protocol/connect/auth_switch.rs | 58 ++ .../src/mysql/protocol/connect/handshake.rs | 233 +++++ .../protocol/connect/handshake_response.rs | 147 +++ .../src/mysql/protocol/connect/mod.rs | 13 + .../src/mysql/protocol/connect/ssl_request.rs | 30 + .../src/mysql/protocol/mod.rs | 11 + .../src/mysql/protocol/packet.rs | 89 ++ .../src/mysql/protocol/response/eof.rs | 36 + .../src/mysql/protocol/response/err.rs | 81 ++ .../src/mysql/protocol/response/mod.rs | 14 + .../src/mysql/protocol/response/ok.rs | 63 ++ .../src/mysql/protocol/response/status.rs | 49 + .../src/mysql/protocol/row.rs | 17 + .../src/mysql/protocol/text/column.rs | 265 ++++++ .../src/mysql/protocol/text/mod.rs | 9 + .../src/mysql/protocol/text/ping.rs | 13 + .../src/mysql/protocol/text/query.rs | 32 + .../src/mysql/protocol/text/quit.rs | 13 + warpgate-protocol-mysql/Cargo.toml | 7 +- warpgate-protocol-mysql/src/client.rs | 56 +- warpgate-protocol-mysql/src/common.rs | 11 - warpgate-protocol-mysql/src/error.rs | 14 +- warpgate-protocol-mysql/src/session.rs | 16 +- warpgate-protocol-mysql/src/stream.rs | 2 +- 45 files changed, 3101 insertions(+), 242 deletions(-) create mode 100644 warpgate-database-protocols/Cargo.toml create mode 100644 warpgate-database-protocols/src/error.rs create mode 100644 warpgate-database-protocols/src/io/buf.rs create mode 100644 warpgate-database-protocols/src/io/buf_mut.rs create mode 100644 warpgate-database-protocols/src/io/buf_stream.rs create mode 100644 warpgate-database-protocols/src/io/decode.rs create mode 100644 warpgate-database-protocols/src/io/encode.rs create mode 100644 warpgate-database-protocols/src/io/mod.rs create mode 100644 warpgate-database-protocols/src/io/write_and_flush.rs create mode 100644 warpgate-database-protocols/src/lib.rs create mode 100644 warpgate-database-protocols/src/mysql/collation.rs create mode 100644 warpgate-database-protocols/src/mysql/io/buf.rs create mode 100644 warpgate-database-protocols/src/mysql/io/buf_mut.rs create mode 100644 warpgate-database-protocols/src/mysql/io/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/auth.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/capabilities.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/connect/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/packet.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/response/eof.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/response/err.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/response/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/response/ok.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/response/status.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/row.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/text/column.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/text/mod.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/text/ping.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/text/query.rs create mode 100644 warpgate-database-protocols/src/mysql/protocol/text/quit.rs diff --git a/Cargo.lock b/Cargo.lock index 00538af..9f0c3cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,12 +760,6 @@ dependencies = [ "tracing-subscriber", ] -[[package]] -name = "const-oid" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4c78c047431fee22c1a7bb92e00ad095a02a983affe4d8a72e2a2c62c1b94f3" - [[package]] name = "constant_time_eq" version = "0.2.1" @@ -875,16 +869,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "crypto-bigint" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c6a1d5fa1de37e071642dfa44ec552ca5b299adb128fab16138e24b548fd21" -dependencies = [ - "generic-array", - "subtle", -] - [[package]] name = "crypto-common" version = "0.1.3" @@ -976,17 +960,6 @@ dependencies = [ "syn", ] -[[package]] -name = "der" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6919815d73839e7ad218de758883aae3a257ba6759ce7a9992501efbb53d705c" -dependencies = [ - "const-oid", - "crypto-bigint", - "pem-rfc7468", -] - [[package]] name = "derive_more" version = "0.99.17" @@ -1828,9 +1801,6 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -dependencies = [ - "spin 0.5.2", -] [[package]] name = "lazycell" @@ -1927,12 +1897,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "libm" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33a33a362ce288760ec6a508b94caaec573ae7d3bbbd91b87aa0bad4456839db" - [[package]] name = "libsodium-sys" version = "0.2.7" @@ -2211,23 +2175,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-bigint-dig" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "566d173b2f9406afbc5510a90925d5a2cd80cae4605631f1212303df265de011" -dependencies = [ - "byteorder", - "lazy_static", - "libm", - "num-integer", - "num-iter", - "num-traits", - "rand", - "smallvec", - "zeroize", -] - [[package]] name = "num-integer" version = "0.1.44" @@ -2267,7 +2214,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" dependencies = [ "autocfg", - "libm", ] [[package]] @@ -2523,15 +2469,6 @@ dependencies = [ "base64", ] -[[package]] -name = "pem-rfc7468" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01de5d978f34aa4b2296576379fcc416034702fd94117c56ffd8a1a767cefb30" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.1.0" @@ -2588,28 +2525,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkcs1" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78f66c04ccc83dd4486fd46c33896f4e17b24a7a3a6400dedc48ed0ddd72320" -dependencies = [ - "der", - "pkcs8", - "zeroize", -] - -[[package]] -name = "pkcs8" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cabda3fb821068a9a4fab19a683eac3af12edf0f34b94a8be53c4972b8149d0" -dependencies = [ - "der", - "spki", - "zeroize", -] - [[package]] name = "pkg-config" version = "0.3.25" @@ -3064,26 +2979,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "rsa" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cf22754c49613d2b3b119f0e5d46e34a2c628a937e3024b8762de4e7d8c710b" -dependencies = [ - "byteorder", - "digest 0.10.3", - "num-bigint-dig", - "num-integer", - "num-iter", - "num-traits", - "pkcs1", - "pkcs8", - "rand_core", - "smallvec", - "subtle", - "zeroize", -] - [[package]] name = "russh" version = "0.34.0-beta.5" @@ -3672,16 +3567,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "spki" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d01ac02a6ccf3e07db148d2be087da624fea0221a16152ed01f0496a6b0a27" -dependencies = [ - "base64ct", - "der", -] - [[package]] name = "sqlformat" version = "0.1.8" @@ -3741,7 +3626,7 @@ dependencies = [ "sha2 0.10.2", "smallvec", "sqlformat", - "sqlx-rt 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", + "sqlx-rt", "stringprep", "thiserror", "tokio-stream", @@ -3749,52 +3634,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "sqlx-core-guts" -version = "0.6.0" -dependencies = [ - "ahash", - "atoi", - "bitflags", - "byteorder", - "bytes", - "crc", - "crossbeam-queue", - "digest 0.10.3", - "either", - "event-listener", - "futures-channel", - "futures-core", - "futures-intrusive", - "futures-util", - "generic-array", - "hashlink", - "hex", - "indexmap", - "itoa", - "libc", - "log", - "memchr", - "num-bigint", - "once_cell", - "paste", - "percent-encoding", - "rand", - "rsa", - "rustls", - "rustls-pemfile", - "sha-1", - "sha2 0.10.2", - "smallvec", - "sqlformat", - "sqlx-rt 0.6.0", - "stringprep", - "thiserror", - "tokio-stream", - "url", - "webpki-roots", -] - [[package]] name = "sqlx-macros" version = "0.6.0" @@ -3810,20 +3649,11 @@ dependencies = [ "serde_json", "sha2 0.10.2", "sqlx-core", - "sqlx-rt 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", + "sqlx-rt", "syn", "url", ] -[[package]] -name = "sqlx-rt" -version = "0.6.0" -dependencies = [ - "once_cell", - "tokio", - "tokio-rustls", -] - [[package]] name = "sqlx-rt" version = "0.6.0" @@ -4598,6 +4428,19 @@ dependencies = [ "warpgate-db-migrations", ] +[[package]] +name = "warpgate-database-protocols" +version = "0.3.0" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "memchr", + "thiserror", + "tokio", +] + [[package]] name = "warpgate-db-entities" version = "0.3.0" @@ -4664,7 +4507,6 @@ dependencies = [ "rustls", "rustls-pemfile", "sha1", - "sqlx-core-guts", "thiserror", "tokio", "tokio-rustls", @@ -4672,6 +4514,7 @@ dependencies = [ "uuid", "warpgate-admin", "warpgate-common", + "warpgate-database-protocols", "warpgate-db-entities", "webpki", "webpki-roots", diff --git a/Cargo.toml b/Cargo.toml index c0652d5..0732e18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "warpgate-common", "warpgate-db-migrations", "warpgate-db-entities", + "warpgate-database-protocols", "warpgate-protocol-http", "warpgate-protocol-mysql", "warpgate-protocol-ssh", diff --git a/justfile b/justfile index c233f68..5522ad6 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-protocol-ssh warpgate-protocol-mysql" +projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-database-protocols warpgate-protocol-ssh warpgate-protocol-mysql" run *ARGS: RUST_BACKTRACE=1 RUST_LOG=warpgate cd warpgate && cargo run -- --config ../config.yaml {{ARGS}} diff --git a/warpgate-common/src/config.rs b/warpgate-common/src/config.rs index d526910..40ec456 100644 --- a/warpgate-common/src/config.rs +++ b/warpgate-common/src/config.rs @@ -3,7 +3,7 @@ use std::net::ToSocketAddrs; use std::path::PathBuf; use std::time::Duration; -use poem_openapi::{Object, Union}; +use poem_openapi::{Enum, Object, Union}; use serde::{Deserialize, Serialize}; use crate::helpers::otp::OtpSecretKey; @@ -17,10 +17,14 @@ const fn _default_false() -> bool { false } -const fn _default_port() -> u16 { +const fn _default_ssh_port() -> u16 { 22 } +const fn _default_mysql_port() -> u16 { + 3306 +} + #[inline] fn _default_username() -> String { "root".to_owned() @@ -64,7 +68,7 @@ fn _default_empty_vec() -> Vec { #[derive(Debug, Deserialize, Serialize, Clone, Object)] pub struct TargetSSHOptions { pub host: String, - #[serde(default = "_default_port")] + #[serde(default = "_default_ssh_port")] pub port: u16, #[serde(default = "_default_username")] pub username: String, @@ -97,10 +101,57 @@ pub struct TargetHTTPOptions { pub headers: Option>, } +#[derive(Debug, Deserialize, Serialize, Clone, Enum, PartialEq, Eq)] +pub enum TlsMode { + Disabled, + Preferred, + Required, +} + +impl Default for TlsMode { + fn default() -> Self { + TlsMode::Preferred + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, Object)] +pub struct Tls { + #[serde(default)] + pub mode: TlsMode, + + #[serde(default)] + pub verify: bool, +} + +#[allow(clippy::derivable_impls)] +impl Default for Tls { + fn default() -> Self { + Self { + mode: TlsMode::default(), + verify: false, + } + } +} + #[derive(Debug, Deserialize, Serialize, Clone, Object)] pub struct TargetMySqlOptions { #[serde(default = "_default_empty_string")] - pub uri: String, + pub host: String, + + #[serde(default = "_default_mysql_port")] + pub port: u16, + + #[serde(default = "_default_username")] + pub username: String, + + #[serde(default)] + pub password: Option, + + #[serde(default)] + pub tls: Tls, + + #[serde(default)] + pub verify_tls: bool, } #[derive(Debug, Deserialize, Serialize, Clone, Object, Default)] diff --git a/warpgate-database-protocols/Cargo.toml b/warpgate-database-protocols/Cargo.toml new file mode 100644 index 0000000..927aa30 --- /dev/null +++ b/warpgate-database-protocols/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "warpgate-database-protocols" +version = "0.3.0" +description = "Core of SQLx, the rust SQL toolkit. Just the database protocol parts." +license = "MIT OR Apache-2.0" +edition = "2021" +authors = [ + "Ryan Leckey ", + "Austin Bonander ", + "Chloe Ross ", + "Daniel Akhterov ", +] + +[dependencies] +tokio = { version = "1.19", features = ["io-util"] } +bitflags = { version = "1.3", default-features = false } +bytes = "1.1" +futures-core = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false, features = [ + "alloc", + "sink", +] } +memchr = { version = "2.4.1", default-features = false } +thiserror = "1.0" diff --git a/warpgate-database-protocols/src/error.rs b/warpgate-database-protocols/src/error.rs new file mode 100644 index 0000000..9a47e4c --- /dev/null +++ b/warpgate-database-protocols/src/error.rs @@ -0,0 +1,241 @@ +//! Types for working with errors produced by SQLx. + +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt::Display; +use std::io; +use std::result::Result as StdResult; + +/// A specialized `Result` type for SQLx. +pub type Result = StdResult; + +// Convenience type alias for usage within SQLx. +// Do not make this type public. +pub type BoxDynError = Box; + +/// An unexpected `NULL` was encountered during decoding. +/// +/// Returned from [`Row::get`](crate::row::Row::get) if the value from the database is `NULL`, +/// and you are not decoding into an `Option`. +#[derive(thiserror::Error, Debug)] +#[error("unexpected null; try decoding as an `Option`")] +pub struct UnexpectedNullError; + +/// Represents all the ways a method can fail within SQLx. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error { + /// Error occurred while parsing a connection string. + #[error("error with configuration: {0}")] + Configuration(#[source] BoxDynError), + + /// Error returned from the database. + #[error("error returned from database: {0}")] + Database(#[source] Box), + + /// Error communicating with the database backend. + #[error("error communicating with database: {0}")] + Io(#[from] io::Error), + + /// Error occurred while attempting to establish a TLS connection. + #[error("error occurred while attempting to establish a TLS connection: {0}")] + Tls(#[source] BoxDynError), + + /// Unexpected or invalid data encountered while communicating with the database. + /// + /// This should indicate there is a programming error in a SQLx driver or there + /// is something corrupted with the connection to the database itself. + #[error("encountered unexpected or invalid data: {0}")] + Protocol(String), + + /// No rows returned by a query that expected to return at least one row. + #[error("no rows returned by a query that expected to return at least one row")] + RowNotFound, + + /// Type in query doesn't exist. Likely due to typo or missing user type. + #[error("type named {type_name} not found")] + TypeNotFound { type_name: String }, + + /// Column index was out of bounds. + #[error("column index out of bounds: the len is {len}, but the index is {index}")] + ColumnIndexOutOfBounds { index: usize, len: usize }, + + /// No column found for the given name. + #[error("no column found for name: {0}")] + ColumnNotFound(String), + + /// Error occurred while decoding a value from a specific column. + #[error("error occurred while decoding column {index}: {source}")] + ColumnDecode { + index: String, + + #[source] + source: BoxDynError, + }, + + /// Error occurred while decoding a value. + #[error("error occurred while decoding: {0}")] + Decode(#[source] BoxDynError), + + /// A [`Pool::acquire`] timed out due to connections not becoming available or + /// because another task encountered too many errors while trying to open a new connection. + /// + /// [`Pool::acquire`]: crate::pool::Pool::acquire + #[error("pool timed out while waiting for an open connection")] + PoolTimedOut, + + /// [`Pool::close`] was called while we were waiting in [`Pool::acquire`]. + /// + /// [`Pool::acquire`]: crate::pool::Pool::acquire + /// [`Pool::close`]: crate::pool::Pool::close + #[error("attempted to acquire a connection on a closed pool")] + PoolClosed, + + /// A background worker has crashed. + #[error("attempted to communicate with a crashed background worker")] + WorkerCrashed, + + #[cfg(feature = "migrate")] + #[error("{0}")] + Migrate(#[source] Box), +} + +impl StdError for Box {} + +impl Error { + #[allow(dead_code)] + #[inline] + pub(crate) fn protocol(err: impl Display) -> Self { + Error::Protocol(err.to_string()) + } + + #[allow(dead_code)] + #[inline] + pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self { + Error::Configuration(err.into()) + } +} + +/// An error that was returned from the database. +pub trait DatabaseError: 'static + Send + Sync + StdError { + /// The primary, human-readable error message. + fn message(&self) -> &str; + + /// The (SQLSTATE) code for the error. + fn code(&self) -> Option> { + None + } + + #[doc(hidden)] + fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static); + + #[doc(hidden)] + fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static); + + #[doc(hidden)] + fn into_error(self: Box) -> Box; + + #[doc(hidden)] + fn is_transient_in_connect_phase(&self) -> bool { + false + } + + /// Returns the name of the constraint that triggered the error, if applicable. + /// If the error was caused by a conflict of a unique index, this will be the index name. + /// + /// ### Note + /// Currently only populated by the Postgres driver. + fn constraint(&self) -> Option<&str> { + None + } +} + +impl dyn DatabaseError { + /// Downcast a reference to this generic database error to a specific + /// database error type. + /// + /// # Panics + /// + /// Panics if the database error type is not `E`. This is a deliberate contrast from + /// `Error::downcast_ref` which returns `Option<&E>`. In normal usage, you should know the + /// specific error type. In other cases, use `try_downcast_ref`. + pub fn downcast_ref(&self) -> &E { + self.try_downcast_ref().unwrap_or_else(|| { + panic!( + "downcast to wrong DatabaseError type; original error: {}", + self + ) + }) + } + + /// Downcast this generic database error to a specific database error type. + /// + /// # Panics + /// + /// Panics if the database error type is not `E`. This is a deliberate contrast from + /// `Error::downcast` which returns `Option`. In normal usage, you should know the + /// specific error type. In other cases, use `try_downcast`. + pub fn downcast(self: Box) -> Box { + self.try_downcast().unwrap_or_else(|e| { + panic!( + "downcast to wrong DatabaseError type; original error: {}", + e + ) + }) + } + + /// Downcast a reference to this generic database error to a specific + /// database error type. + #[inline] + pub fn try_downcast_ref(&self) -> Option<&E> { + self.as_error().downcast_ref() + } + + /// Downcast this generic database error to a specific database error type. + #[inline] + pub fn try_downcast(self: Box) -> StdResult, Box> { + if self.as_error().is::() { + Ok(self.into_error().downcast().unwrap()) + } else { + Err(self) + } + } +} + +impl From for Error +where + E: DatabaseError, +{ + #[inline] + fn from(error: E) -> Self { + Error::Database(Box::new(error)) + } +} + +#[cfg(feature = "migrate")] +impl From for Error { + #[inline] + fn from(error: crate::migrate::MigrateError) -> Self { + Error::Migrate(Box::new(error)) + } +} + +#[cfg(feature = "_tls-native-tls")] +impl From for Error { + #[inline] + fn from(error: sqlx_rt::native_tls::Error) -> Self { + Error::Tls(Box::new(error)) + } +} + +// Format an error message as a `Protocol` error +#[macro_export] +macro_rules! err_protocol { + ($expr:expr) => { + $crate::error::Error::Protocol($expr.into()) + }; + + ($fmt:expr, $($arg:tt)*) => { + $crate::error::Error::Protocol(format!($fmt, $($arg)*)) + }; +} diff --git a/warpgate-database-protocols/src/io/buf.rs b/warpgate-database-protocols/src/io/buf.rs new file mode 100644 index 0000000..f73764c --- /dev/null +++ b/warpgate-database-protocols/src/io/buf.rs @@ -0,0 +1,59 @@ +use std::str::from_utf8; + +use bytes::{Buf, Bytes}; +use memchr::memchr; + +use crate::err_protocol; +use crate::error::Error; + +pub trait BufExt: Buf { + // Read a nul-terminated byte sequence + fn get_bytes_nul(&mut self) -> Result; + + // Read a byte sequence of the exact length + fn get_bytes(&mut self, len: usize) -> Bytes; + + // Read a nul-terminated string + fn get_str_nul(&mut self) -> Result; + + // Read a string of the exact length + fn get_str(&mut self, len: usize) -> Result; +} + +impl BufExt for Bytes { + fn get_bytes_nul(&mut self) -> Result { + let nul = + memchr(b'\0', self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?; + + let v = self.slice(0..nul); + + self.advance(nul + 1); + + Ok(v) + } + + fn get_bytes(&mut self, len: usize) -> Bytes { + let v = self.slice(..len); + self.advance(len); + + v + } + + fn get_str_nul(&mut self) -> Result { + self.get_bytes_nul().and_then(|bytes| { + from_utf8(&*bytes) + .map(ToOwned::to_owned) + .map_err(|err| err_protocol!("{}", err)) + }) + } + + fn get_str(&mut self, len: usize) -> Result { + let v = from_utf8(&self[..len]) + .map_err(|err| err_protocol!("{}", err)) + .map(ToOwned::to_owned)?; + + self.advance(len); + + Ok(v) + } +} diff --git a/warpgate-database-protocols/src/io/buf_mut.rs b/warpgate-database-protocols/src/io/buf_mut.rs new file mode 100644 index 0000000..565d850 --- /dev/null +++ b/warpgate-database-protocols/src/io/buf_mut.rs @@ -0,0 +1,12 @@ +use bytes::BufMut; + +pub trait BufMutExt: BufMut { + fn put_str_nul(&mut self, s: &str); +} + +impl BufMutExt for Vec { + fn put_str_nul(&mut self, s: &str) { + self.extend(s.as_bytes()); + self.push(0); + } +} diff --git a/warpgate-database-protocols/src/io/buf_stream.rs b/warpgate-database-protocols/src/io/buf_stream.rs new file mode 100644 index 0000000..6711e3b --- /dev/null +++ b/warpgate-database-protocols/src/io/buf_stream.rs @@ -0,0 +1,166 @@ +#![allow(dead_code)] + +use std::io; +use std::io::Cursor; +use std::ops::{Deref, DerefMut}; + +use bytes::BytesMut; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +use crate::error::Error; +use crate::io::decode::Decode; +use crate::io::encode::Encode; +use crate::io::write_and_flush::WriteAndFlush; + +pub struct BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) stream: S, + + // writes with `write` to the underlying stream are buffered + // this can be flushed with `flush` + pub(crate) wbuf: Vec, + + // we read into the read buffer using 100% safe code + rbuf: BytesMut, +} + +impl BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(stream: S) -> Self { + Self { + stream, + wbuf: Vec::with_capacity(512), + rbuf: BytesMut::with_capacity(4096), + } + } + + pub fn write<'en, T>(&mut self, value: T) + where + T: Encode<'en, ()>, + { + self.write_with(value, ()) + } + + pub fn write_with<'en, T, C>(&mut self, value: T, context: C) + where + T: Encode<'en, C>, + { + value.encode_with(&mut self.wbuf, context); + } + + pub fn flush(&mut self) -> WriteAndFlush<'_, S> { + WriteAndFlush { + stream: &mut self.stream, + buf: Cursor::new(&mut self.wbuf), + } + } + + pub async fn read<'de, T>(&mut self, cnt: usize) -> Result + where + T: Decode<'de, ()>, + { + self.read_with(cnt, ()).await + } + + pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.read_raw(cnt).await?.freeze(), context) + } + + pub async fn read_raw(&mut self, cnt: usize) -> Result { + read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?; + let buf = self.rbuf.split_to(cnt); + + Ok(buf) + } + + pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> { + read_raw_into(&mut self.stream, buf, cnt).await + } + + pub fn take(self) -> S { + self.stream + } +} + +impl Deref for BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +// Holds a buffer which has been temporarily extended, so that +// we can read into it. Automatically shrinks the buffer back +// down if the read is cancelled. +struct BufTruncator<'a> { + buf: &'a mut BytesMut, + filled_len: usize, +} + +impl<'a> BufTruncator<'a> { + fn new(buf: &'a mut BytesMut) -> Self { + let filled_len = buf.len(); + Self { buf, filled_len } + } + fn reserve(&mut self, space: usize) { + self.buf.resize(self.filled_len + space, 0); + } + async fn read(&mut self, stream: &mut S) -> Result { + let n = stream.read(&mut self.buf[self.filled_len..]).await?; + self.filled_len += n; + Ok(n) + } + fn is_full(&self) -> bool { + self.filled_len >= self.buf.len() + } +} + +impl Drop for BufTruncator<'_> { + fn drop(&mut self) { + self.buf.truncate(self.filled_len); + } +} + +async fn read_raw_into( + stream: &mut S, + buf: &mut BytesMut, + cnt: usize, +) -> Result<(), Error> { + let mut buf = BufTruncator::new(buf); + buf.reserve(cnt); + + while !buf.is_full() { + let n = buf.read(stream).await?; + + if n == 0 { + // a zero read when we had space in the read buffer + // should be treated as an EOF + + // and an unexpected EOF means the server told us to go away + + return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); + } + } + + Ok(()) +} diff --git a/warpgate-database-protocols/src/io/decode.rs b/warpgate-database-protocols/src/io/decode.rs new file mode 100644 index 0000000..2f39712 --- /dev/null +++ b/warpgate-database-protocols/src/io/decode.rs @@ -0,0 +1,29 @@ +use bytes::Bytes; + +use crate::error::Error; + +pub trait Decode<'de, Context = ()> +where + Self: Sized, +{ + fn decode(buf: Bytes) -> Result + where + Self: Decode<'de, ()>, + { + Self::decode_with(buf, ()) + } + + fn decode_with(buf: Bytes, context: Context) -> Result; +} + +impl Decode<'_> for Bytes { + fn decode_with(buf: Bytes, _: ()) -> Result { + Ok(buf) + } +} + +impl Decode<'_> for () { + fn decode_with(_: Bytes, _: ()) -> Result<(), Error> { + Ok(()) + } +} diff --git a/warpgate-database-protocols/src/io/encode.rs b/warpgate-database-protocols/src/io/encode.rs new file mode 100644 index 0000000..a417ef9 --- /dev/null +++ b/warpgate-database-protocols/src/io/encode.rs @@ -0,0 +1,16 @@ +pub trait Encode<'en, Context = ()> { + fn encode(&self, buf: &mut Vec) + where + Self: Encode<'en, ()>, + { + self.encode_with(buf, ()); + } + + fn encode_with(&self, buf: &mut Vec, context: Context); +} + +impl<'en, C> Encode<'en, C> for &'_ [u8] { + fn encode_with(&self, buf: &mut Vec, _: C) { + buf.extend_from_slice(self); + } +} diff --git a/warpgate-database-protocols/src/io/mod.rs b/warpgate-database-protocols/src/io/mod.rs new file mode 100644 index 0000000..f994965 --- /dev/null +++ b/warpgate-database-protocols/src/io/mod.rs @@ -0,0 +1,12 @@ +mod buf; +mod buf_mut; +mod buf_stream; +mod decode; +mod encode; +mod write_and_flush; + +pub use buf::BufExt; +pub use buf_mut::BufMutExt; +pub use buf_stream::BufStream; +pub use decode::Decode; +pub use encode::Encode; diff --git a/warpgate-database-protocols/src/io/write_and_flush.rs b/warpgate-database-protocols/src/io/write_and_flush.rs new file mode 100644 index 0000000..8d37d34 --- /dev/null +++ b/warpgate-database-protocols/src/io/write_and_flush.rs @@ -0,0 +1,47 @@ +use std::io::{BufRead, Cursor}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_core::Future; +use futures_util::ready; +use tokio::io::AsyncWrite; + +use crate::error::Error; + +// Atomic operation that writes the full buffer to the stream, flushes the stream, and then +// clears the buffer (even if either of the two previous operations failed). +pub struct WriteAndFlush<'a, S> { + pub(super) stream: &'a mut S, + pub(super) buf: Cursor<&'a mut Vec>, +} + +impl Future for WriteAndFlush<'_, S> { + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Self { + ref mut stream, + ref mut buf, + } = *self; + + loop { + let read = buf.fill_buf()?; + + if !read.is_empty() { + let written = ready!(Pin::new(&mut *stream).poll_write(cx, read)?); + buf.consume(written); + } else { + break; + } + } + + Pin::new(stream).poll_flush(cx).map_err(Error::Io) + } +} + +impl<'a, S> Drop for WriteAndFlush<'a, S> { + fn drop(&mut self) { + // clear the buffer regardless of whether the flush succeeded or not + self.buf.get_mut().clear(); + } +} diff --git a/warpgate-database-protocols/src/lib.rs b/warpgate-database-protocols/src/lib.rs new file mode 100644 index 0000000..d75cfe5 --- /dev/null +++ b/warpgate-database-protocols/src/lib.rs @@ -0,0 +1,6 @@ +#![allow(dead_code)] +pub mod io; +pub mod mysql; + +#[macro_use] +pub mod error; diff --git a/warpgate-database-protocols/src/mysql/collation.rs b/warpgate-database-protocols/src/mysql/collation.rs new file mode 100644 index 0000000..0247679 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/collation.rs @@ -0,0 +1,901 @@ +use std::str::FromStr; + +use crate::error::Error; + +#[allow(non_camel_case_types)] +#[derive(Copy, Clone)] +pub(crate) enum CharSet { + armscii8, + ascii, + big5, + binary, + cp1250, + cp1251, + cp1256, + cp1257, + cp850, + cp852, + cp866, + cp932, + dec8, + eucjpms, + euckr, + gb18030, + gb2312, + gbk, + geostd8, + greek, + hebrew, + hp8, + keybcs2, + koi8r, + koi8u, + latin1, + latin2, + latin5, + latin7, + macce, + macroman, + sjis, + swe7, + tis620, + ucs2, + ujis, + utf16, + utf16le, + utf32, + utf8, + utf8mb4, +} + +impl CharSet { + pub(crate) fn as_str(&self) -> &'static str { + match self { + CharSet::armscii8 => "armscii8", + CharSet::ascii => "ascii", + CharSet::big5 => "big5", + CharSet::binary => "binary", + CharSet::cp1250 => "cp1250", + CharSet::cp1251 => "cp1251", + CharSet::cp1256 => "cp1256", + CharSet::cp1257 => "cp1257", + CharSet::cp850 => "cp850", + CharSet::cp852 => "cp852", + CharSet::cp866 => "cp866", + CharSet::cp932 => "cp932", + CharSet::dec8 => "dec8", + CharSet::eucjpms => "eucjpms", + CharSet::euckr => "euckr", + CharSet::gb18030 => "gb18030", + CharSet::gb2312 => "gb2312", + CharSet::gbk => "gbk", + CharSet::geostd8 => "geostd8", + CharSet::greek => "greek", + CharSet::hebrew => "hebrew", + CharSet::hp8 => "hp8", + CharSet::keybcs2 => "keybcs2", + CharSet::koi8r => "koi8r", + CharSet::koi8u => "koi8u", + CharSet::latin1 => "latin1", + CharSet::latin2 => "latin2", + CharSet::latin5 => "latin5", + CharSet::latin7 => "latin7", + CharSet::macce => "macce", + CharSet::macroman => "macroman", + CharSet::sjis => "sjis", + CharSet::swe7 => "swe7", + CharSet::tis620 => "tis620", + CharSet::ucs2 => "ucs2", + CharSet::ujis => "ujis", + CharSet::utf16 => "utf16", + CharSet::utf16le => "utf16le", + CharSet::utf32 => "utf32", + CharSet::utf8 => "utf8", + CharSet::utf8mb4 => "utf8mb4", + } + } + + pub(crate) fn default_collation(&self) -> Collation { + match self { + CharSet::armscii8 => Collation::armscii8_general_ci, + CharSet::ascii => Collation::ascii_general_ci, + CharSet::big5 => Collation::big5_chinese_ci, + CharSet::binary => Collation::binary, + CharSet::cp1250 => Collation::cp1250_general_ci, + CharSet::cp1251 => Collation::cp1251_general_ci, + CharSet::cp1256 => Collation::cp1256_general_ci, + CharSet::cp1257 => Collation::cp1257_general_ci, + CharSet::cp850 => Collation::cp850_general_ci, + CharSet::cp852 => Collation::cp852_general_ci, + CharSet::cp866 => Collation::cp866_general_ci, + CharSet::cp932 => Collation::cp932_japanese_ci, + CharSet::dec8 => Collation::dec8_swedish_ci, + CharSet::eucjpms => Collation::eucjpms_japanese_ci, + CharSet::euckr => Collation::euckr_korean_ci, + CharSet::gb18030 => Collation::gb18030_chinese_ci, + CharSet::gb2312 => Collation::gb2312_chinese_ci, + CharSet::gbk => Collation::gbk_chinese_ci, + CharSet::geostd8 => Collation::geostd8_general_ci, + CharSet::greek => Collation::greek_general_ci, + CharSet::hebrew => Collation::hebrew_general_ci, + CharSet::hp8 => Collation::hp8_english_ci, + CharSet::keybcs2 => Collation::keybcs2_general_ci, + CharSet::koi8r => Collation::koi8r_general_ci, + CharSet::koi8u => Collation::koi8u_general_ci, + CharSet::latin1 => Collation::latin1_swedish_ci, + CharSet::latin2 => Collation::latin2_general_ci, + CharSet::latin5 => Collation::latin5_turkish_ci, + CharSet::latin7 => Collation::latin7_general_ci, + CharSet::macce => Collation::macce_general_ci, + CharSet::macroman => Collation::macroman_general_ci, + CharSet::sjis => Collation::sjis_japanese_ci, + CharSet::swe7 => Collation::swe7_swedish_ci, + CharSet::tis620 => Collation::tis620_thai_ci, + CharSet::ucs2 => Collation::ucs2_general_ci, + CharSet::ujis => Collation::ujis_japanese_ci, + CharSet::utf16 => Collation::utf16_general_ci, + CharSet::utf16le => Collation::utf16le_general_ci, + CharSet::utf32 => Collation::utf32_general_ci, + CharSet::utf8 => Collation::utf8_unicode_ci, + CharSet::utf8mb4 => Collation::utf8mb4_unicode_ci, + } + } +} + +impl FromStr for CharSet { + type Err = Error; + + fn from_str(char_set: &str) -> Result { + Ok(match char_set { + "armscii8" => CharSet::armscii8, + "ascii" => CharSet::ascii, + "big5" => CharSet::big5, + "binary" => CharSet::binary, + "cp1250" => CharSet::cp1250, + "cp1251" => CharSet::cp1251, + "cp1256" => CharSet::cp1256, + "cp1257" => CharSet::cp1257, + "cp850" => CharSet::cp850, + "cp852" => CharSet::cp852, + "cp866" => CharSet::cp866, + "cp932" => CharSet::cp932, + "dec8" => CharSet::dec8, + "eucjpms" => CharSet::eucjpms, + "euckr" => CharSet::euckr, + "gb18030" => CharSet::gb18030, + "gb2312" => CharSet::gb2312, + "gbk" => CharSet::gbk, + "geostd8" => CharSet::geostd8, + "greek" => CharSet::greek, + "hebrew" => CharSet::hebrew, + "hp8" => CharSet::hp8, + "keybcs2" => CharSet::keybcs2, + "koi8r" => CharSet::koi8r, + "koi8u" => CharSet::koi8u, + "latin1" => CharSet::latin1, + "latin2" => CharSet::latin2, + "latin5" => CharSet::latin5, + "latin7" => CharSet::latin7, + "macce" => CharSet::macce, + "macroman" => CharSet::macroman, + "sjis" => CharSet::sjis, + "swe7" => CharSet::swe7, + "tis620" => CharSet::tis620, + "ucs2" => CharSet::ucs2, + "ujis" => CharSet::ujis, + "utf16" => CharSet::utf16, + "utf16le" => CharSet::utf16le, + "utf32" => CharSet::utf32, + "utf8" => CharSet::utf8, + "utf8mb4" => CharSet::utf8mb4, + + _ => { + return Err(Error::Configuration( + format!("unsupported MySQL charset: {}", char_set).into(), + )); + } + }) + } +} + +#[derive(Copy, Clone)] +#[allow(non_camel_case_types)] +#[repr(u8)] +pub(crate) enum Collation { + armscii8_bin = 64, + armscii8_general_ci = 32, + ascii_bin = 65, + ascii_general_ci = 11, + big5_bin = 84, + big5_chinese_ci = 1, + binary = 63, + cp1250_bin = 66, + cp1250_croatian_ci = 44, + cp1250_czech_cs = 34, + cp1250_general_ci = 26, + cp1250_polish_ci = 99, + cp1251_bin = 50, + cp1251_bulgarian_ci = 14, + cp1251_general_ci = 51, + cp1251_general_cs = 52, + cp1251_ukrainian_ci = 23, + cp1256_bin = 67, + cp1256_general_ci = 57, + cp1257_bin = 58, + cp1257_general_ci = 59, + cp1257_lithuanian_ci = 29, + cp850_bin = 80, + cp850_general_ci = 4, + cp852_bin = 81, + cp852_general_ci = 40, + cp866_bin = 68, + cp866_general_ci = 36, + cp932_bin = 96, + cp932_japanese_ci = 95, + dec8_bin = 69, + dec8_swedish_ci = 3, + eucjpms_bin = 98, + eucjpms_japanese_ci = 97, + euckr_bin = 85, + euckr_korean_ci = 19, + gb18030_bin = 249, + gb18030_chinese_ci = 248, + gb18030_unicode_520_ci = 250, + gb2312_bin = 86, + gb2312_chinese_ci = 24, + gbk_bin = 87, + gbk_chinese_ci = 28, + geostd8_bin = 93, + geostd8_general_ci = 92, + greek_bin = 70, + greek_general_ci = 25, + hebrew_bin = 71, + hebrew_general_ci = 16, + hp8_bin = 72, + hp8_english_ci = 6, + keybcs2_bin = 73, + keybcs2_general_ci = 37, + koi8r_bin = 74, + koi8r_general_ci = 7, + koi8u_bin = 75, + koi8u_general_ci = 22, + latin1_bin = 47, + latin1_danish_ci = 15, + latin1_general_ci = 48, + latin1_general_cs = 49, + latin1_german1_ci = 5, + latin1_german2_ci = 31, + latin1_spanish_ci = 94, + latin1_swedish_ci = 8, + latin2_bin = 77, + latin2_croatian_ci = 27, + latin2_czech_cs = 2, + latin2_general_ci = 9, + latin2_hungarian_ci = 21, + latin5_bin = 78, + latin5_turkish_ci = 30, + latin7_bin = 79, + latin7_estonian_cs = 20, + latin7_general_ci = 41, + latin7_general_cs = 42, + macce_bin = 43, + macce_general_ci = 38, + macroman_bin = 53, + macroman_general_ci = 39, + sjis_bin = 88, + sjis_japanese_ci = 13, + swe7_bin = 82, + swe7_swedish_ci = 10, + tis620_bin = 89, + tis620_thai_ci = 18, + ucs2_bin = 90, + ucs2_croatian_ci = 149, + ucs2_czech_ci = 138, + ucs2_danish_ci = 139, + ucs2_esperanto_ci = 145, + ucs2_estonian_ci = 134, + ucs2_general_ci = 35, + ucs2_general_mysql500_ci = 159, + ucs2_german2_ci = 148, + ucs2_hungarian_ci = 146, + ucs2_icelandic_ci = 129, + ucs2_latvian_ci = 130, + ucs2_lithuanian_ci = 140, + ucs2_persian_ci = 144, + ucs2_polish_ci = 133, + ucs2_roman_ci = 143, + ucs2_romanian_ci = 131, + ucs2_sinhala_ci = 147, + ucs2_slovak_ci = 141, + ucs2_slovenian_ci = 132, + ucs2_spanish_ci = 135, + ucs2_spanish2_ci = 142, + ucs2_swedish_ci = 136, + ucs2_turkish_ci = 137, + ucs2_unicode_520_ci = 150, + ucs2_unicode_ci = 128, + ucs2_vietnamese_ci = 151, + ujis_bin = 91, + ujis_japanese_ci = 12, + utf16_bin = 55, + utf16_croatian_ci = 122, + utf16_czech_ci = 111, + utf16_danish_ci = 112, + utf16_esperanto_ci = 118, + utf16_estonian_ci = 107, + utf16_general_ci = 54, + utf16_german2_ci = 121, + utf16_hungarian_ci = 119, + utf16_icelandic_ci = 102, + utf16_latvian_ci = 103, + utf16_lithuanian_ci = 113, + utf16_persian_ci = 117, + utf16_polish_ci = 106, + utf16_roman_ci = 116, + utf16_romanian_ci = 104, + utf16_sinhala_ci = 120, + utf16_slovak_ci = 114, + utf16_slovenian_ci = 105, + utf16_spanish_ci = 108, + utf16_spanish2_ci = 115, + utf16_swedish_ci = 109, + utf16_turkish_ci = 110, + utf16_unicode_520_ci = 123, + utf16_unicode_ci = 101, + utf16_vietnamese_ci = 124, + utf16le_bin = 62, + utf16le_general_ci = 56, + utf32_bin = 61, + utf32_croatian_ci = 181, + utf32_czech_ci = 170, + utf32_danish_ci = 171, + utf32_esperanto_ci = 177, + utf32_estonian_ci = 166, + utf32_general_ci = 60, + utf32_german2_ci = 180, + utf32_hungarian_ci = 178, + utf32_icelandic_ci = 161, + utf32_latvian_ci = 162, + utf32_lithuanian_ci = 172, + utf32_persian_ci = 176, + utf32_polish_ci = 165, + utf32_roman_ci = 175, + utf32_romanian_ci = 163, + utf32_sinhala_ci = 179, + utf32_slovak_ci = 173, + utf32_slovenian_ci = 164, + utf32_spanish_ci = 167, + utf32_spanish2_ci = 174, + utf32_swedish_ci = 168, + utf32_turkish_ci = 169, + utf32_unicode_520_ci = 182, + utf32_unicode_ci = 160, + utf32_vietnamese_ci = 183, + utf8_bin = 83, + utf8_croatian_ci = 213, + utf8_czech_ci = 202, + utf8_danish_ci = 203, + utf8_esperanto_ci = 209, + utf8_estonian_ci = 198, + utf8_general_ci = 33, + utf8_general_mysql500_ci = 223, + utf8_german2_ci = 212, + utf8_hungarian_ci = 210, + utf8_icelandic_ci = 193, + utf8_latvian_ci = 194, + utf8_lithuanian_ci = 204, + utf8_persian_ci = 208, + utf8_polish_ci = 197, + utf8_roman_ci = 207, + utf8_romanian_ci = 195, + utf8_sinhala_ci = 211, + utf8_slovak_ci = 205, + utf8_slovenian_ci = 196, + utf8_spanish_ci = 199, + utf8_spanish2_ci = 206, + utf8_swedish_ci = 200, + utf8_tolower_ci = 76, + utf8_turkish_ci = 201, + utf8_unicode_520_ci = 214, + utf8_unicode_ci = 192, + utf8_vietnamese_ci = 215, + utf8mb4_0900_ai_ci = 255, + utf8mb4_bin = 46, + utf8mb4_croatian_ci = 245, + utf8mb4_czech_ci = 234, + utf8mb4_danish_ci = 235, + utf8mb4_esperanto_ci = 241, + utf8mb4_estonian_ci = 230, + utf8mb4_general_ci = 45, + utf8mb4_german2_ci = 244, + utf8mb4_hungarian_ci = 242, + utf8mb4_icelandic_ci = 225, + utf8mb4_latvian_ci = 226, + utf8mb4_lithuanian_ci = 236, + utf8mb4_persian_ci = 240, + utf8mb4_polish_ci = 229, + utf8mb4_roman_ci = 239, + utf8mb4_romanian_ci = 227, + utf8mb4_sinhala_ci = 243, + utf8mb4_slovak_ci = 237, + utf8mb4_slovenian_ci = 228, + utf8mb4_spanish_ci = 231, + utf8mb4_spanish2_ci = 238, + utf8mb4_swedish_ci = 232, + utf8mb4_turkish_ci = 233, + utf8mb4_unicode_520_ci = 246, + utf8mb4_unicode_ci = 224, + utf8mb4_vietnamese_ci = 247, +} + +impl Collation { + pub(crate) fn as_str(&self) -> &'static str { + match self { + Collation::armscii8_bin => "armscii8_bin", + Collation::armscii8_general_ci => "armscii8_general_ci", + Collation::ascii_bin => "ascii_bin", + Collation::ascii_general_ci => "ascii_general_ci", + Collation::big5_bin => "big5_bin", + Collation::big5_chinese_ci => "big5_chinese_ci", + Collation::binary => "binary", + Collation::cp1250_bin => "cp1250_bin", + Collation::cp1250_croatian_ci => "cp1250_croatian_ci", + Collation::cp1250_czech_cs => "cp1250_czech_cs", + Collation::cp1250_general_ci => "cp1250_general_ci", + Collation::cp1250_polish_ci => "cp1250_polish_ci", + Collation::cp1251_bin => "cp1251_bin", + Collation::cp1251_bulgarian_ci => "cp1251_bulgarian_ci", + Collation::cp1251_general_ci => "cp1251_general_ci", + Collation::cp1251_general_cs => "cp1251_general_cs", + Collation::cp1251_ukrainian_ci => "cp1251_ukrainian_ci", + Collation::cp1256_bin => "cp1256_bin", + Collation::cp1256_general_ci => "cp1256_general_ci", + Collation::cp1257_bin => "cp1257_bin", + Collation::cp1257_general_ci => "cp1257_general_ci", + Collation::cp1257_lithuanian_ci => "cp1257_lithuanian_ci", + Collation::cp850_bin => "cp850_bin", + Collation::cp850_general_ci => "cp850_general_ci", + Collation::cp852_bin => "cp852_bin", + Collation::cp852_general_ci => "cp852_general_ci", + Collation::cp866_bin => "cp866_bin", + Collation::cp866_general_ci => "cp866_general_ci", + Collation::cp932_bin => "cp932_bin", + Collation::cp932_japanese_ci => "cp932_japanese_ci", + Collation::dec8_bin => "dec8_bin", + Collation::dec8_swedish_ci => "dec8_swedish_ci", + Collation::eucjpms_bin => "eucjpms_bin", + Collation::eucjpms_japanese_ci => "eucjpms_japanese_ci", + Collation::euckr_bin => "euckr_bin", + Collation::euckr_korean_ci => "euckr_korean_ci", + Collation::gb18030_bin => "gb18030_bin", + Collation::gb18030_chinese_ci => "gb18030_chinese_ci", + Collation::gb18030_unicode_520_ci => "gb18030_unicode_520_ci", + Collation::gb2312_bin => "gb2312_bin", + Collation::gb2312_chinese_ci => "gb2312_chinese_ci", + Collation::gbk_bin => "gbk_bin", + Collation::gbk_chinese_ci => "gbk_chinese_ci", + Collation::geostd8_bin => "geostd8_bin", + Collation::geostd8_general_ci => "geostd8_general_ci", + Collation::greek_bin => "greek_bin", + Collation::greek_general_ci => "greek_general_ci", + Collation::hebrew_bin => "hebrew_bin", + Collation::hebrew_general_ci => "hebrew_general_ci", + Collation::hp8_bin => "hp8_bin", + Collation::hp8_english_ci => "hp8_english_ci", + Collation::keybcs2_bin => "keybcs2_bin", + Collation::keybcs2_general_ci => "keybcs2_general_ci", + Collation::koi8r_bin => "koi8r_bin", + Collation::koi8r_general_ci => "koi8r_general_ci", + Collation::koi8u_bin => "koi8u_bin", + Collation::koi8u_general_ci => "koi8u_general_ci", + Collation::latin1_bin => "latin1_bin", + Collation::latin1_danish_ci => "latin1_danish_ci", + Collation::latin1_general_ci => "latin1_general_ci", + Collation::latin1_general_cs => "latin1_general_cs", + Collation::latin1_german1_ci => "latin1_german1_ci", + Collation::latin1_german2_ci => "latin1_german2_ci", + Collation::latin1_spanish_ci => "latin1_spanish_ci", + Collation::latin1_swedish_ci => "latin1_swedish_ci", + Collation::latin2_bin => "latin2_bin", + Collation::latin2_croatian_ci => "latin2_croatian_ci", + Collation::latin2_czech_cs => "latin2_czech_cs", + Collation::latin2_general_ci => "latin2_general_ci", + Collation::latin2_hungarian_ci => "latin2_hungarian_ci", + Collation::latin5_bin => "latin5_bin", + Collation::latin5_turkish_ci => "latin5_turkish_ci", + Collation::latin7_bin => "latin7_bin", + Collation::latin7_estonian_cs => "latin7_estonian_cs", + Collation::latin7_general_ci => "latin7_general_ci", + Collation::latin7_general_cs => "latin7_general_cs", + Collation::macce_bin => "macce_bin", + Collation::macce_general_ci => "macce_general_ci", + Collation::macroman_bin => "macroman_bin", + Collation::macroman_general_ci => "macroman_general_ci", + Collation::sjis_bin => "sjis_bin", + Collation::sjis_japanese_ci => "sjis_japanese_ci", + Collation::swe7_bin => "swe7_bin", + Collation::swe7_swedish_ci => "swe7_swedish_ci", + Collation::tis620_bin => "tis620_bin", + Collation::tis620_thai_ci => "tis620_thai_ci", + Collation::ucs2_bin => "ucs2_bin", + Collation::ucs2_croatian_ci => "ucs2_croatian_ci", + Collation::ucs2_czech_ci => "ucs2_czech_ci", + Collation::ucs2_danish_ci => "ucs2_danish_ci", + Collation::ucs2_esperanto_ci => "ucs2_esperanto_ci", + Collation::ucs2_estonian_ci => "ucs2_estonian_ci", + Collation::ucs2_general_ci => "ucs2_general_ci", + Collation::ucs2_general_mysql500_ci => "ucs2_general_mysql500_ci", + Collation::ucs2_german2_ci => "ucs2_german2_ci", + Collation::ucs2_hungarian_ci => "ucs2_hungarian_ci", + Collation::ucs2_icelandic_ci => "ucs2_icelandic_ci", + Collation::ucs2_latvian_ci => "ucs2_latvian_ci", + Collation::ucs2_lithuanian_ci => "ucs2_lithuanian_ci", + Collation::ucs2_persian_ci => "ucs2_persian_ci", + Collation::ucs2_polish_ci => "ucs2_polish_ci", + Collation::ucs2_roman_ci => "ucs2_roman_ci", + Collation::ucs2_romanian_ci => "ucs2_romanian_ci", + Collation::ucs2_sinhala_ci => "ucs2_sinhala_ci", + Collation::ucs2_slovak_ci => "ucs2_slovak_ci", + Collation::ucs2_slovenian_ci => "ucs2_slovenian_ci", + Collation::ucs2_spanish_ci => "ucs2_spanish_ci", + Collation::ucs2_spanish2_ci => "ucs2_spanish2_ci", + Collation::ucs2_swedish_ci => "ucs2_swedish_ci", + Collation::ucs2_turkish_ci => "ucs2_turkish_ci", + Collation::ucs2_unicode_520_ci => "ucs2_unicode_520_ci", + Collation::ucs2_unicode_ci => "ucs2_unicode_ci", + Collation::ucs2_vietnamese_ci => "ucs2_vietnamese_ci", + Collation::ujis_bin => "ujis_bin", + Collation::ujis_japanese_ci => "ujis_japanese_ci", + Collation::utf16_bin => "utf16_bin", + Collation::utf16_croatian_ci => "utf16_croatian_ci", + Collation::utf16_czech_ci => "utf16_czech_ci", + Collation::utf16_danish_ci => "utf16_danish_ci", + Collation::utf16_esperanto_ci => "utf16_esperanto_ci", + Collation::utf16_estonian_ci => "utf16_estonian_ci", + Collation::utf16_general_ci => "utf16_general_ci", + Collation::utf16_german2_ci => "utf16_german2_ci", + Collation::utf16_hungarian_ci => "utf16_hungarian_ci", + Collation::utf16_icelandic_ci => "utf16_icelandic_ci", + Collation::utf16_latvian_ci => "utf16_latvian_ci", + Collation::utf16_lithuanian_ci => "utf16_lithuanian_ci", + Collation::utf16_persian_ci => "utf16_persian_ci", + Collation::utf16_polish_ci => "utf16_polish_ci", + Collation::utf16_roman_ci => "utf16_roman_ci", + Collation::utf16_romanian_ci => "utf16_romanian_ci", + Collation::utf16_sinhala_ci => "utf16_sinhala_ci", + Collation::utf16_slovak_ci => "utf16_slovak_ci", + Collation::utf16_slovenian_ci => "utf16_slovenian_ci", + Collation::utf16_spanish_ci => "utf16_spanish_ci", + Collation::utf16_spanish2_ci => "utf16_spanish2_ci", + Collation::utf16_swedish_ci => "utf16_swedish_ci", + Collation::utf16_turkish_ci => "utf16_turkish_ci", + Collation::utf16_unicode_520_ci => "utf16_unicode_520_ci", + Collation::utf16_unicode_ci => "utf16_unicode_ci", + Collation::utf16_vietnamese_ci => "utf16_vietnamese_ci", + Collation::utf16le_bin => "utf16le_bin", + Collation::utf16le_general_ci => "utf16le_general_ci", + Collation::utf32_bin => "utf32_bin", + Collation::utf32_croatian_ci => "utf32_croatian_ci", + Collation::utf32_czech_ci => "utf32_czech_ci", + Collation::utf32_danish_ci => "utf32_danish_ci", + Collation::utf32_esperanto_ci => "utf32_esperanto_ci", + Collation::utf32_estonian_ci => "utf32_estonian_ci", + Collation::utf32_general_ci => "utf32_general_ci", + Collation::utf32_german2_ci => "utf32_german2_ci", + Collation::utf32_hungarian_ci => "utf32_hungarian_ci", + Collation::utf32_icelandic_ci => "utf32_icelandic_ci", + Collation::utf32_latvian_ci => "utf32_latvian_ci", + Collation::utf32_lithuanian_ci => "utf32_lithuanian_ci", + Collation::utf32_persian_ci => "utf32_persian_ci", + Collation::utf32_polish_ci => "utf32_polish_ci", + Collation::utf32_roman_ci => "utf32_roman_ci", + Collation::utf32_romanian_ci => "utf32_romanian_ci", + Collation::utf32_sinhala_ci => "utf32_sinhala_ci", + Collation::utf32_slovak_ci => "utf32_slovak_ci", + Collation::utf32_slovenian_ci => "utf32_slovenian_ci", + Collation::utf32_spanish_ci => "utf32_spanish_ci", + Collation::utf32_spanish2_ci => "utf32_spanish2_ci", + Collation::utf32_swedish_ci => "utf32_swedish_ci", + Collation::utf32_turkish_ci => "utf32_turkish_ci", + Collation::utf32_unicode_520_ci => "utf32_unicode_520_ci", + Collation::utf32_unicode_ci => "utf32_unicode_ci", + Collation::utf32_vietnamese_ci => "utf32_vietnamese_ci", + Collation::utf8_bin => "utf8_bin", + Collation::utf8_croatian_ci => "utf8_croatian_ci", + Collation::utf8_czech_ci => "utf8_czech_ci", + Collation::utf8_danish_ci => "utf8_danish_ci", + Collation::utf8_esperanto_ci => "utf8_esperanto_ci", + Collation::utf8_estonian_ci => "utf8_estonian_ci", + Collation::utf8_general_ci => "utf8_general_ci", + Collation::utf8_general_mysql500_ci => "utf8_general_mysql500_ci", + Collation::utf8_german2_ci => "utf8_german2_ci", + Collation::utf8_hungarian_ci => "utf8_hungarian_ci", + Collation::utf8_icelandic_ci => "utf8_icelandic_ci", + Collation::utf8_latvian_ci => "utf8_latvian_ci", + Collation::utf8_lithuanian_ci => "utf8_lithuanian_ci", + Collation::utf8_persian_ci => "utf8_persian_ci", + Collation::utf8_polish_ci => "utf8_polish_ci", + Collation::utf8_roman_ci => "utf8_roman_ci", + Collation::utf8_romanian_ci => "utf8_romanian_ci", + Collation::utf8_sinhala_ci => "utf8_sinhala_ci", + Collation::utf8_slovak_ci => "utf8_slovak_ci", + Collation::utf8_slovenian_ci => "utf8_slovenian_ci", + Collation::utf8_spanish_ci => "utf8_spanish_ci", + Collation::utf8_spanish2_ci => "utf8_spanish2_ci", + Collation::utf8_swedish_ci => "utf8_swedish_ci", + Collation::utf8_tolower_ci => "utf8_tolower_ci", + Collation::utf8_turkish_ci => "utf8_turkish_ci", + Collation::utf8_unicode_520_ci => "utf8_unicode_520_ci", + Collation::utf8_unicode_ci => "utf8_unicode_ci", + Collation::utf8_vietnamese_ci => "utf8_vietnamese_ci", + Collation::utf8mb4_0900_ai_ci => "utf8mb4_0900_ai_ci", + Collation::utf8mb4_bin => "utf8mb4_bin", + Collation::utf8mb4_croatian_ci => "utf8mb4_croatian_ci", + Collation::utf8mb4_czech_ci => "utf8mb4_czech_ci", + Collation::utf8mb4_danish_ci => "utf8mb4_danish_ci", + Collation::utf8mb4_esperanto_ci => "utf8mb4_esperanto_ci", + Collation::utf8mb4_estonian_ci => "utf8mb4_estonian_ci", + Collation::utf8mb4_general_ci => "utf8mb4_general_ci", + Collation::utf8mb4_german2_ci => "utf8mb4_german2_ci", + Collation::utf8mb4_hungarian_ci => "utf8mb4_hungarian_ci", + Collation::utf8mb4_icelandic_ci => "utf8mb4_icelandic_ci", + Collation::utf8mb4_latvian_ci => "utf8mb4_latvian_ci", + Collation::utf8mb4_lithuanian_ci => "utf8mb4_lithuanian_ci", + Collation::utf8mb4_persian_ci => "utf8mb4_persian_ci", + Collation::utf8mb4_polish_ci => "utf8mb4_polish_ci", + Collation::utf8mb4_roman_ci => "utf8mb4_roman_ci", + Collation::utf8mb4_romanian_ci => "utf8mb4_romanian_ci", + Collation::utf8mb4_sinhala_ci => "utf8mb4_sinhala_ci", + Collation::utf8mb4_slovak_ci => "utf8mb4_slovak_ci", + Collation::utf8mb4_slovenian_ci => "utf8mb4_slovenian_ci", + Collation::utf8mb4_spanish_ci => "utf8mb4_spanish_ci", + Collation::utf8mb4_spanish2_ci => "utf8mb4_spanish2_ci", + Collation::utf8mb4_swedish_ci => "utf8mb4_swedish_ci", + Collation::utf8mb4_turkish_ci => "utf8mb4_turkish_ci", + Collation::utf8mb4_unicode_520_ci => "utf8mb4_unicode_520_ci", + Collation::utf8mb4_unicode_ci => "utf8mb4_unicode_ci", + Collation::utf8mb4_vietnamese_ci => "utf8mb4_vietnamese_ci", + } + } +} + +// Handshake packet have only 1 byte for collation_id. +// So we can't use collations with ID > 255. +impl FromStr for Collation { + type Err = Error; + + fn from_str(collation: &str) -> Result { + Ok(match collation { + "big5_chinese_ci" => Collation::big5_chinese_ci, + "swe7_swedish_ci" => Collation::swe7_swedish_ci, + "utf16_unicode_ci" => Collation::utf16_unicode_ci, + "utf16_icelandic_ci" => Collation::utf16_icelandic_ci, + "utf16_latvian_ci" => Collation::utf16_latvian_ci, + "utf16_romanian_ci" => Collation::utf16_romanian_ci, + "utf16_slovenian_ci" => Collation::utf16_slovenian_ci, + "utf16_polish_ci" => Collation::utf16_polish_ci, + "utf16_estonian_ci" => Collation::utf16_estonian_ci, + "utf16_spanish_ci" => Collation::utf16_spanish_ci, + "utf16_swedish_ci" => Collation::utf16_swedish_ci, + "ascii_general_ci" => Collation::ascii_general_ci, + "utf16_turkish_ci" => Collation::utf16_turkish_ci, + "utf16_czech_ci" => Collation::utf16_czech_ci, + "utf16_danish_ci" => Collation::utf16_danish_ci, + "utf16_lithuanian_ci" => Collation::utf16_lithuanian_ci, + "utf16_slovak_ci" => Collation::utf16_slovak_ci, + "utf16_spanish2_ci" => Collation::utf16_spanish2_ci, + "utf16_roman_ci" => Collation::utf16_roman_ci, + "utf16_persian_ci" => Collation::utf16_persian_ci, + "utf16_esperanto_ci" => Collation::utf16_esperanto_ci, + "utf16_hungarian_ci" => Collation::utf16_hungarian_ci, + "ujis_japanese_ci" => Collation::ujis_japanese_ci, + "utf16_sinhala_ci" => Collation::utf16_sinhala_ci, + "utf16_german2_ci" => Collation::utf16_german2_ci, + "utf16_croatian_ci" => Collation::utf16_croatian_ci, + "utf16_unicode_520_ci" => Collation::utf16_unicode_520_ci, + "utf16_vietnamese_ci" => Collation::utf16_vietnamese_ci, + "ucs2_unicode_ci" => Collation::ucs2_unicode_ci, + "ucs2_icelandic_ci" => Collation::ucs2_icelandic_ci, + "sjis_japanese_ci" => Collation::sjis_japanese_ci, + "ucs2_latvian_ci" => Collation::ucs2_latvian_ci, + "ucs2_romanian_ci" => Collation::ucs2_romanian_ci, + "ucs2_slovenian_ci" => Collation::ucs2_slovenian_ci, + "ucs2_polish_ci" => Collation::ucs2_polish_ci, + "ucs2_estonian_ci" => Collation::ucs2_estonian_ci, + "ucs2_spanish_ci" => Collation::ucs2_spanish_ci, + "ucs2_swedish_ci" => Collation::ucs2_swedish_ci, + "ucs2_turkish_ci" => Collation::ucs2_turkish_ci, + "ucs2_czech_ci" => Collation::ucs2_czech_ci, + "ucs2_danish_ci" => Collation::ucs2_danish_ci, + "cp1251_bulgarian_ci" => Collation::cp1251_bulgarian_ci, + "ucs2_lithuanian_ci" => Collation::ucs2_lithuanian_ci, + "ucs2_slovak_ci" => Collation::ucs2_slovak_ci, + "ucs2_spanish2_ci" => Collation::ucs2_spanish2_ci, + "ucs2_roman_ci" => Collation::ucs2_roman_ci, + "ucs2_persian_ci" => Collation::ucs2_persian_ci, + "ucs2_esperanto_ci" => Collation::ucs2_esperanto_ci, + "ucs2_hungarian_ci" => Collation::ucs2_hungarian_ci, + "ucs2_sinhala_ci" => Collation::ucs2_sinhala_ci, + "ucs2_german2_ci" => Collation::ucs2_german2_ci, + "ucs2_croatian_ci" => Collation::ucs2_croatian_ci, + "latin1_danish_ci" => Collation::latin1_danish_ci, + "ucs2_unicode_520_ci" => Collation::ucs2_unicode_520_ci, + "ucs2_vietnamese_ci" => Collation::ucs2_vietnamese_ci, + "ucs2_general_mysql500_ci" => Collation::ucs2_general_mysql500_ci, + "hebrew_general_ci" => Collation::hebrew_general_ci, + "utf32_unicode_ci" => Collation::utf32_unicode_ci, + "utf32_icelandic_ci" => Collation::utf32_icelandic_ci, + "utf32_latvian_ci" => Collation::utf32_latvian_ci, + "utf32_romanian_ci" => Collation::utf32_romanian_ci, + "utf32_slovenian_ci" => Collation::utf32_slovenian_ci, + "utf32_polish_ci" => Collation::utf32_polish_ci, + "utf32_estonian_ci" => Collation::utf32_estonian_ci, + "utf32_spanish_ci" => Collation::utf32_spanish_ci, + "utf32_swedish_ci" => Collation::utf32_swedish_ci, + "utf32_turkish_ci" => Collation::utf32_turkish_ci, + "utf32_czech_ci" => Collation::utf32_czech_ci, + "utf32_danish_ci" => Collation::utf32_danish_ci, + "utf32_lithuanian_ci" => Collation::utf32_lithuanian_ci, + "utf32_slovak_ci" => Collation::utf32_slovak_ci, + "utf32_spanish2_ci" => Collation::utf32_spanish2_ci, + "utf32_roman_ci" => Collation::utf32_roman_ci, + "utf32_persian_ci" => Collation::utf32_persian_ci, + "utf32_esperanto_ci" => Collation::utf32_esperanto_ci, + "utf32_hungarian_ci" => Collation::utf32_hungarian_ci, + "utf32_sinhala_ci" => Collation::utf32_sinhala_ci, + "tis620_thai_ci" => Collation::tis620_thai_ci, + "utf32_german2_ci" => Collation::utf32_german2_ci, + "utf32_croatian_ci" => Collation::utf32_croatian_ci, + "utf32_unicode_520_ci" => Collation::utf32_unicode_520_ci, + "utf32_vietnamese_ci" => Collation::utf32_vietnamese_ci, + "euckr_korean_ci" => Collation::euckr_korean_ci, + "utf8_unicode_ci" => Collation::utf8_unicode_ci, + "utf8_icelandic_ci" => Collation::utf8_icelandic_ci, + "utf8_latvian_ci" => Collation::utf8_latvian_ci, + "utf8_romanian_ci" => Collation::utf8_romanian_ci, + "utf8_slovenian_ci" => Collation::utf8_slovenian_ci, + "utf8_polish_ci" => Collation::utf8_polish_ci, + "utf8_estonian_ci" => Collation::utf8_estonian_ci, + "utf8_spanish_ci" => Collation::utf8_spanish_ci, + "latin2_czech_cs" => Collation::latin2_czech_cs, + "latin7_estonian_cs" => Collation::latin7_estonian_cs, + "utf8_swedish_ci" => Collation::utf8_swedish_ci, + "utf8_turkish_ci" => Collation::utf8_turkish_ci, + "utf8_czech_ci" => Collation::utf8_czech_ci, + "utf8_danish_ci" => Collation::utf8_danish_ci, + "utf8_lithuanian_ci" => Collation::utf8_lithuanian_ci, + "utf8_slovak_ci" => Collation::utf8_slovak_ci, + "utf8_spanish2_ci" => Collation::utf8_spanish2_ci, + "utf8_roman_ci" => Collation::utf8_roman_ci, + "utf8_persian_ci" => Collation::utf8_persian_ci, + "utf8_esperanto_ci" => Collation::utf8_esperanto_ci, + "latin2_hungarian_ci" => Collation::latin2_hungarian_ci, + "utf8_hungarian_ci" => Collation::utf8_hungarian_ci, + "utf8_sinhala_ci" => Collation::utf8_sinhala_ci, + "utf8_german2_ci" => Collation::utf8_german2_ci, + "utf8_croatian_ci" => Collation::utf8_croatian_ci, + "utf8_unicode_520_ci" => Collation::utf8_unicode_520_ci, + "utf8_vietnamese_ci" => Collation::utf8_vietnamese_ci, + "koi8u_general_ci" => Collation::koi8u_general_ci, + "utf8_general_mysql500_ci" => Collation::utf8_general_mysql500_ci, + "utf8mb4_unicode_ci" => Collation::utf8mb4_unicode_ci, + "utf8mb4_icelandic_ci" => Collation::utf8mb4_icelandic_ci, + "utf8mb4_latvian_ci" => Collation::utf8mb4_latvian_ci, + "utf8mb4_romanian_ci" => Collation::utf8mb4_romanian_ci, + "utf8mb4_slovenian_ci" => Collation::utf8mb4_slovenian_ci, + "utf8mb4_polish_ci" => Collation::utf8mb4_polish_ci, + "cp1251_ukrainian_ci" => Collation::cp1251_ukrainian_ci, + "utf8mb4_estonian_ci" => Collation::utf8mb4_estonian_ci, + "utf8mb4_spanish_ci" => Collation::utf8mb4_spanish_ci, + "utf8mb4_swedish_ci" => Collation::utf8mb4_swedish_ci, + "utf8mb4_turkish_ci" => Collation::utf8mb4_turkish_ci, + "utf8mb4_czech_ci" => Collation::utf8mb4_czech_ci, + "utf8mb4_danish_ci" => Collation::utf8mb4_danish_ci, + "utf8mb4_lithuanian_ci" => Collation::utf8mb4_lithuanian_ci, + "utf8mb4_slovak_ci" => Collation::utf8mb4_slovak_ci, + "utf8mb4_spanish2_ci" => Collation::utf8mb4_spanish2_ci, + "utf8mb4_roman_ci" => Collation::utf8mb4_roman_ci, + "gb2312_chinese_ci" => Collation::gb2312_chinese_ci, + "utf8mb4_persian_ci" => Collation::utf8mb4_persian_ci, + "utf8mb4_esperanto_ci" => Collation::utf8mb4_esperanto_ci, + "utf8mb4_hungarian_ci" => Collation::utf8mb4_hungarian_ci, + "utf8mb4_sinhala_ci" => Collation::utf8mb4_sinhala_ci, + "utf8mb4_german2_ci" => Collation::utf8mb4_german2_ci, + "utf8mb4_croatian_ci" => Collation::utf8mb4_croatian_ci, + "utf8mb4_unicode_520_ci" => Collation::utf8mb4_unicode_520_ci, + "utf8mb4_vietnamese_ci" => Collation::utf8mb4_vietnamese_ci, + "gb18030_chinese_ci" => Collation::gb18030_chinese_ci, + "gb18030_bin" => Collation::gb18030_bin, + "greek_general_ci" => Collation::greek_general_ci, + "gb18030_unicode_520_ci" => Collation::gb18030_unicode_520_ci, + "utf8mb4_0900_ai_ci" => Collation::utf8mb4_0900_ai_ci, + "cp1250_general_ci" => Collation::cp1250_general_ci, + "latin2_croatian_ci" => Collation::latin2_croatian_ci, + "gbk_chinese_ci" => Collation::gbk_chinese_ci, + "cp1257_lithuanian_ci" => Collation::cp1257_lithuanian_ci, + "dec8_swedish_ci" => Collation::dec8_swedish_ci, + "latin5_turkish_ci" => Collation::latin5_turkish_ci, + "latin1_german2_ci" => Collation::latin1_german2_ci, + "armscii8_general_ci" => Collation::armscii8_general_ci, + "utf8_general_ci" => Collation::utf8_general_ci, + "cp1250_czech_cs" => Collation::cp1250_czech_cs, + "ucs2_general_ci" => Collation::ucs2_general_ci, + "cp866_general_ci" => Collation::cp866_general_ci, + "keybcs2_general_ci" => Collation::keybcs2_general_ci, + "macce_general_ci" => Collation::macce_general_ci, + "macroman_general_ci" => Collation::macroman_general_ci, + "cp850_general_ci" => Collation::cp850_general_ci, + "cp852_general_ci" => Collation::cp852_general_ci, + "latin7_general_ci" => Collation::latin7_general_ci, + "latin7_general_cs" => Collation::latin7_general_cs, + "macce_bin" => Collation::macce_bin, + "cp1250_croatian_ci" => Collation::cp1250_croatian_ci, + "utf8mb4_general_ci" => Collation::utf8mb4_general_ci, + "utf8mb4_bin" => Collation::utf8mb4_bin, + "latin1_bin" => Collation::latin1_bin, + "latin1_general_ci" => Collation::latin1_general_ci, + "latin1_general_cs" => Collation::latin1_general_cs, + "latin1_german1_ci" => Collation::latin1_german1_ci, + "cp1251_bin" => Collation::cp1251_bin, + "cp1251_general_ci" => Collation::cp1251_general_ci, + "cp1251_general_cs" => Collation::cp1251_general_cs, + "macroman_bin" => Collation::macroman_bin, + "utf16_general_ci" => Collation::utf16_general_ci, + "utf16_bin" => Collation::utf16_bin, + "utf16le_general_ci" => Collation::utf16le_general_ci, + "cp1256_general_ci" => Collation::cp1256_general_ci, + "cp1257_bin" => Collation::cp1257_bin, + "cp1257_general_ci" => Collation::cp1257_general_ci, + "hp8_english_ci" => Collation::hp8_english_ci, + "utf32_general_ci" => Collation::utf32_general_ci, + "utf32_bin" => Collation::utf32_bin, + "utf16le_bin" => Collation::utf16le_bin, + "binary" => Collation::binary, + "armscii8_bin" => Collation::armscii8_bin, + "ascii_bin" => Collation::ascii_bin, + "cp1250_bin" => Collation::cp1250_bin, + "cp1256_bin" => Collation::cp1256_bin, + "cp866_bin" => Collation::cp866_bin, + "dec8_bin" => Collation::dec8_bin, + "koi8r_general_ci" => Collation::koi8r_general_ci, + "greek_bin" => Collation::greek_bin, + "hebrew_bin" => Collation::hebrew_bin, + "hp8_bin" => Collation::hp8_bin, + "keybcs2_bin" => Collation::keybcs2_bin, + "koi8r_bin" => Collation::koi8r_bin, + "koi8u_bin" => Collation::koi8u_bin, + "utf8_tolower_ci" => Collation::utf8_tolower_ci, + "latin2_bin" => Collation::latin2_bin, + "latin5_bin" => Collation::latin5_bin, + "latin7_bin" => Collation::latin7_bin, + "latin1_swedish_ci" => Collation::latin1_swedish_ci, + "cp850_bin" => Collation::cp850_bin, + "cp852_bin" => Collation::cp852_bin, + "swe7_bin" => Collation::swe7_bin, + "utf8_bin" => Collation::utf8_bin, + "big5_bin" => Collation::big5_bin, + "euckr_bin" => Collation::euckr_bin, + "gb2312_bin" => Collation::gb2312_bin, + "gbk_bin" => Collation::gbk_bin, + "sjis_bin" => Collation::sjis_bin, + "tis620_bin" => Collation::tis620_bin, + "latin2_general_ci" => Collation::latin2_general_ci, + "ucs2_bin" => Collation::ucs2_bin, + "ujis_bin" => Collation::ujis_bin, + "geostd8_general_ci" => Collation::geostd8_general_ci, + "geostd8_bin" => Collation::geostd8_bin, + "latin1_spanish_ci" => Collation::latin1_spanish_ci, + "cp932_japanese_ci" => Collation::cp932_japanese_ci, + "cp932_bin" => Collation::cp932_bin, + "eucjpms_japanese_ci" => Collation::eucjpms_japanese_ci, + "eucjpms_bin" => Collation::eucjpms_bin, + "cp1250_polish_ci" => Collation::cp1250_polish_ci, + + _ => { + return Err(Error::Configuration( + format!("unsupported MySQL collation: {}", collation).into(), + )); + } + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/io/buf.rs b/warpgate-database-protocols/src/mysql/io/buf.rs new file mode 100644 index 0000000..9ccb62e --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/buf.rs @@ -0,0 +1,40 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::BufExt; + +pub trait MySqlBufExt: Buf { + // Read a length-encoded integer. + // NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL. + // NOTE: 0xff is only returned during a result set to indicate ERR. + // + fn get_uint_lenenc(&mut self) -> u64; + + // Read a length-encoded string. + fn get_str_lenenc(&mut self) -> Result; + + // Read a length-encoded byte sequence. + fn get_bytes_lenenc(&mut self) -> Bytes; +} + +impl MySqlBufExt for Bytes { + fn get_uint_lenenc(&mut self) -> u64 { + match self.get_u8() { + 0xfc => u64::from(self.get_u16_le()), + 0xfd => self.get_uint_le(3), + 0xfe => self.get_u64_le(), + + v => u64::from(v), + } + } + + fn get_str_lenenc(&mut self) -> Result { + let size = self.get_uint_lenenc(); + self.get_str(size as usize) + } + + fn get_bytes_lenenc(&mut self) -> Bytes { + let size = self.get_uint_lenenc(); + self.split_to(size as usize) + } +} diff --git a/warpgate-database-protocols/src/mysql/io/buf_mut.rs b/warpgate-database-protocols/src/mysql/io/buf_mut.rs new file mode 100644 index 0000000..ba2ba3e --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/buf_mut.rs @@ -0,0 +1,126 @@ +use bytes::BufMut; + +pub trait MySqlBufMutExt: BufMut { + fn put_uint_lenenc(&mut self, v: u64); + + fn put_str_lenenc(&mut self, v: &str); + + fn put_bytes_lenenc(&mut self, v: &[u8]); +} + +impl MySqlBufMutExt for Vec { + fn put_uint_lenenc(&mut self, v: u64) { + // https://dev.mysql.com/doc/internals/en/integer.html + // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers + + if v < 251 { + self.push(v as u8); + } else if v < 0x1_00_00 { + self.push(0xfc); + self.extend(&(v as u16).to_le_bytes()); + } else if v < 0x1_00_00_00 { + self.push(0xfd); + self.extend(&(v as u32).to_le_bytes()[..3]); + } else { + self.push(0xfe); + self.extend(&v.to_le_bytes()); + } + } + + fn put_str_lenenc(&mut self, v: &str) { + self.put_bytes_lenenc(v.as_bytes()); + } + + fn put_bytes_lenenc(&mut self, v: &[u8]) { + self.put_uint_lenenc(v.len() as u64); + self.extend(v); + } +} + +#[test] +fn test_encodes_int_lenenc_u8() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFA_u64); + + assert_eq!(&buf[..], b"\xFA"); +} + +#[test] +fn test_encodes_int_lenenc_u16() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u16::MAX as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u24() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF_FF_FF_u64); + + assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u64() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u64::MAX); + + assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_fb() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFB_u64); + + assert_eq!(&buf[..], b"\xFC\xFB\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fc() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFC_u64); + + assert_eq!(&buf[..], b"\xFC\xFC\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fd() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFD_u64); + + assert_eq!(&buf[..], b"\xFC\xFD\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fe() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFE_u64); + + assert_eq!(&buf[..], b"\xFC\xFE\x00"); +} + +#[test] +fn test_encodes_int_lenenc_ff() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF_u64); + + assert_eq!(&buf[..], b"\xFC\xFF\x00"); +} + +#[test] +fn test_encodes_string_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_lenenc("random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} + +#[test] +fn test_encodes_byte_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_bytes_lenenc(b"random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} diff --git a/warpgate-database-protocols/src/mysql/io/mod.rs b/warpgate-database-protocols/src/mysql/io/mod.rs new file mode 100644 index 0000000..fafc914 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/mod.rs @@ -0,0 +1,5 @@ +mod buf; +mod buf_mut; + +pub use buf::MySqlBufExt; +pub use buf_mut::MySqlBufMutExt; diff --git a/warpgate-database-protocols/src/mysql/mod.rs b/warpgate-database-protocols/src/mysql/mod.rs new file mode 100644 index 0000000..146321f --- /dev/null +++ b/warpgate-database-protocols/src/mysql/mod.rs @@ -0,0 +1,5 @@ +//! **MySQL** database driver. + +pub mod collation; +pub mod io; +pub mod protocol; diff --git a/warpgate-database-protocols/src/mysql/protocol/auth.rs b/warpgate-database-protocols/src/mysql/protocol/auth.rs new file mode 100644 index 0000000..aa27bf7 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/auth.rs @@ -0,0 +1,38 @@ +use std::str::FromStr; + +use crate::err_protocol; +use crate::error::Error; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AuthPlugin { + MySqlClearPassword, + MySqlNativePassword, + CachingSha2Password, + Sha256Password, +} + +impl AuthPlugin { + pub(crate) fn name(self) -> &'static str { + match self { + AuthPlugin::MySqlClearPassword => "mysql_clear_password", + AuthPlugin::MySqlNativePassword => "mysql_native_password", + AuthPlugin::CachingSha2Password => "caching_sha2_password", + AuthPlugin::Sha256Password => "sha256_password", + } + } +} + +impl FromStr for AuthPlugin { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "mysql_clear_password" => Ok(AuthPlugin::MySqlClearPassword), + "mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword), + "caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password), + "sha256_password" => Ok(AuthPlugin::Sha256Password), + + _ => Err(err_protocol!("unknown authentication plugin: {}", s)), + } + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/capabilities.rs b/warpgate-database-protocols/src/mysql/protocol/capabilities.rs new file mode 100644 index 0000000..6d7b582 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/capabilities.rs @@ -0,0 +1,86 @@ +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html +// https://mariadb.com/kb/en/library/connection/#capabilities +bitflags::bitflags! { + pub struct Capabilities: u64 { + // [MariaDB] MySQL compatibility + const MYSQL = 1; + + // [*] Send found rows instead of affected rows in EOF_Packet. + const FOUND_ROWS = 2; + + // Get all column flags. + const LONG_FLAG = 4; + + // [*] Database (schema) name can be specified on connect in Handshake Response Packet. + const CONNECT_WITH_DB = 8; + + // Don't allow database.table.column + const NO_SCHEMA = 16; + + // [*] Compression protocol supported + const COMPRESS = 32; + + // Special handling of ODBC behavior. + const ODBC = 64; + + // Can use LOAD DATA LOCAL + const LOCAL_FILES = 128; + + // [*] Ignore spaces before '(' + const IGNORE_SPACE = 256; + + // [*] New 4.1+ protocol + const PROTOCOL_41 = 512; + + // This is an interactive client + const INTERACTIVE = 1024; + + // Use SSL encryption for this session + const SSL = 2048; + + // Client knows about transactions + const TRANSACTIONS = 8192; + + // 4.1+ authentication + const SECURE_CONNECTION = (1 << 15); + + // Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE + const MULTI_STATEMENTS = (1 << 16); + + // Enable/disable multi-results for COM_QUERY + const MULTI_RESULTS = (1 << 17); + + // Enable/disable multi-results for COM_STMT_PREPARE + const PS_MULTI_RESULTS = (1 << 18); + + // Client supports plugin authentication + const PLUGIN_AUTH = (1 << 19); + + // Client supports connection attributes + const CONNECT_ATTRS = (1 << 20); + + // Enable authentication response packet to be larger than 255 bytes. + const PLUGIN_AUTH_LENENC_DATA = (1 << 21); + + // Don't close the connection for a user account with expired password. + const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22); + + // Capable of handling server state change information. + const SESSION_TRACK = (1 << 23); + + // Client no longer needs EOF_Packet and will use OK_Packet instead. + const DEPRECATE_EOF = (1 << 24); + + // Support ZSTD protocol compression + const ZSTD_COMPRESSION_ALGORITHM = (1 << 26); + + // Verify server certificate + const SSL_VERIFY_SERVER_CERT = (1 << 30); + + // The client can handle optional metadata information in the resultset + const OPTIONAL_RESULTSET_METADATA = (1 << 25); + + // Don't reset the options after an unsuccessful connect + const REMEMBER_OPTIONS = (1 << 31); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs b/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs new file mode 100644 index 0000000..bdb330f --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs @@ -0,0 +1,58 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + +#[derive(Debug)] +pub struct AuthSwitchRequest { + pub plugin: AuthPlugin, + pub data: Bytes, +} + +impl Decode<'_> for AuthSwitchRequest { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (AUTH_SWITCH) but found 0x{:x}", + header + )); + } + + let plugin = buf.get_str_nul()?.parse()?; + + // See: https://github.com/mysql/mysql-server/blob/ea7d2e2d16ac03afdd9cb72a972a95981107bf51/sql/auth/sha2_password.cc#L942 + if buf.len() != 21 { + return Err(err_protocol!( + "expected 21 bytes but found {} bytes", + buf.len() + )); + } + let data = buf.get_bytes(20); + buf.advance(1); // NUL-terminator + + Ok(Self { plugin, data }) + } +} + +impl Encode<'_, ()> for AuthSwitchRequest { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0xfe); + buf.put_str_nul(self.plugin.name()); + buf.extend(&self.data); + } +} + +#[derive(Debug)] +pub struct AuthSwitchResponse(pub Vec); + +impl Encode<'_, Capabilities> for AuthSwitchResponse { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.extend_from_slice(&self.0); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs b/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs new file mode 100644 index 0000000..32ccdc3 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs @@ -0,0 +1,233 @@ +use bytes::buf::Chain; +use bytes::{Buf, BufMut, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +// https://mariadb.com/kb/en/connection/#initial-handshake-packet + +#[derive(Debug)] +pub struct Handshake { + #[allow(unused)] + pub protocol_version: u8, + pub server_version: String, + #[allow(unused)] + pub connection_id: u32, + pub server_capabilities: Capabilities, + #[allow(unused)] + pub server_default_collation: u8, + #[allow(unused)] + pub status: Status, + pub auth_plugin: Option, + pub auth_plugin_data: Chain, +} + +impl Decode<'_> for Handshake { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let protocol_version = buf.get_u8(); // int<1> + let server_version = buf.get_str_nul()?; // string + let connection_id = buf.get_u32_le(); // int<4> + let auth_plugin_data_1 = buf.get_bytes(8); // string<8> + + buf.advance(1); // reserved: string<1> + + let capabilities_1 = buf.get_u16_le(); // int<2> + let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); + + let collation = buf.get_u8(); // int<1> + let status = Status::from_bits_truncate(buf.get_u16_le()); + + let capabilities_2 = buf.get_u16_le(); // int<2> + capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); + + let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + buf.get_u8() + } else { + buf.advance(1); // int<1> + 0 + }; + + buf.advance(6); // reserved: string<6> + + if capabilities.contains(Capabilities::MYSQL) { + buf.advance(4); // reserved: string<4> + } else { + let capabilities_3 = buf.get_u32_le(); // int<4> + capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); + } + + let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { + let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize; + let v = buf.get_bytes(len); + buf.advance(1); // NUL-terminator + + v + } else { + Bytes::new() + }; + + let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + Some(buf.get_str_nul()?.parse()?) + } else { + None + }; + + Ok(Self { + protocol_version, + server_version, + connection_id, + server_default_collation: collation, + status, + server_capabilities: capabilities, + auth_plugin, + auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2), + }) + } +} + +impl Encode<'_, ()> for Handshake { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(self.protocol_version); + buf.put_str_nul(&self.server_version); + buf.put_u32_le(self.connection_id); + buf.put_slice(self.auth_plugin_data.first_ref()); + buf.put_u8(0x00); + buf.put_u16_le((self.server_capabilities.bits() & 0x0000_FFFF) as u16); + buf.put_u8(self.server_default_collation); + buf.put_u16_le(self.status.bits()); + buf.put_u16_le(((self.server_capabilities.bits() & 0xFFFF_0000) >> 16) as u16); + + if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) { + buf.put_u8((self.auth_plugin_data.last_ref().len() + 8 + 1) as u8); + } else { + buf.put_u8(0); + } + + buf.put_slice(&[0_u8; 10][..]); + + if self + .server_capabilities + .contains(Capabilities::SECURE_CONNECTION) + { + buf.put_slice(self.auth_plugin_data.last_ref()); + buf.put_u8(0); + } + + if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) { + if let Some(auth_plugin) = self.auth_plugin { + buf.put_str_nul(auth_plugin.name()); + } + } + } +} + +#[test] +fn test_decode_handshake_mysql_8_0_18() { + const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00"; + + let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); + + assert_eq!(p.protocol_version, 10); + + p.server_capabilities.toggle( + Capabilities::MYSQL + | Capabilities::FOUND_ROWS + | Capabilities::LONG_FLAG + | Capabilities::CONNECT_WITH_DB + | Capabilities::NO_SCHEMA + | Capabilities::COMPRESS + | Capabilities::ODBC + | Capabilities::LOCAL_FILES + | Capabilities::IGNORE_SPACE + | Capabilities::PROTOCOL_41 + | Capabilities::INTERACTIVE + | Capabilities::SSL + | Capabilities::TRANSACTIONS + | Capabilities::SECURE_CONNECTION + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PS_MULTI_RESULTS + | Capabilities::PLUGIN_AUTH + | Capabilities::CONNECT_ATTRS + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS + | Capabilities::SESSION_TRACK + | Capabilities::DEPRECATE_EOF + | Capabilities::ZSTD_COMPRESSION_ALGORITHM + | Capabilities::SSL_VERIFY_SERVER_CERT + | Capabilities::OPTIONAL_RESULTSET_METADATA + | Capabilities::REMEMBER_OPTIONS, + ); + + assert!(p.server_capabilities.is_empty()); + + assert_eq!(p.server_default_collation, 255); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + + assert!(matches!( + p.auth_plugin, + Some(AuthPlugin::CachingSha2Password) + )); + + assert_eq!( + &*p.auth_plugin_data.into_iter().collect::>(), + &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,] + ); +} + +#[test] +fn test_decode_handshake_mariadb_10_4_7() { + const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\">(), + &[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,] + ); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs b/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs new file mode 100644 index 0000000..29862ca --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs @@ -0,0 +1,147 @@ +use std::str::FromStr; + +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::connect::ssl_request::SslRequest; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +// https://mariadb.com/kb/en/connection/#client-handshake-response + +#[derive(Debug)] +pub struct HandshakeResponse { + pub database: Option, + + /// Max size of a command packet that the client wants to send to the server + pub max_packet_size: u32, + + /// Default collation for the connection + pub collation: u8, + + /// Name of the SQL account which client wants to log in + pub username: String, + + /// Authentication method used by the client + pub auth_plugin: Option, + + /// Opaque authentication response + pub auth_response: Option, +} + +impl Encode<'_, Capabilities> for HandshakeResponse { + fn encode_with(&self, buf: &mut Vec, mut capabilities: Capabilities) { + if self.auth_plugin.is_none() { + // ensure PLUGIN_AUTH is set *only* if we have a defined plugin + capabilities.remove(Capabilities::PLUGIN_AUTH); + } + + // NOTE: Half of this packet is identical to the SSL Request packet + SslRequest { + max_packet_size: self.max_packet_size, + collation: self.collation, + } + .encode_with(buf, capabilities); + + buf.put_str_nul(&self.username); + + if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + if let Some(response) = &self.auth_response { + buf.put_bytes_lenenc(response); + } else { + buf.put_bytes_lenenc(&[]); + } + } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { + if let Some(response) = &self.auth_response { + buf.push(response.len() as u8); + buf.extend(response); + } else { + buf.push(0); + } + } else { + buf.push(0); + } + + if capabilities.contains(Capabilities::CONNECT_WITH_DB) { + if let Some(database) = &self.database { + buf.put_str_nul(database); + } else { + buf.push(0); + } + } + + if capabilities.contains(Capabilities::PLUGIN_AUTH) { + if let Some(plugin) = &self.auth_plugin { + buf.put_str_nul(plugin.name()); + } else { + buf.push(0); + } + } + } +} + +impl Decode<'_, &mut Capabilities> for HandshakeResponse { + fn decode_with(mut buf: Bytes, server_capabilities: &mut Capabilities) -> Result { + let mut capabilities = buf.get_u32_le() as u64; + let max_packet_size = buf.get_u32_le(); + let collation = buf.get_u8(); + buf.advance(19); + + let partial_cap = Capabilities::from_bits_truncate(capabilities); + + if partial_cap.contains(Capabilities::MYSQL) { + // reserved: string<4> + buf.advance(4); + } else { + capabilities += (buf.get_u32_le() as u64) << 32; + } + + let partial_cap = Capabilities::from_bits_truncate(capabilities); + if partial_cap.contains(Capabilities::SSL) && buf.is_empty() { + return Ok(HandshakeResponse { + collation, + max_packet_size, + username: "".to_string(), + auth_response: None, + auth_plugin: None, + database: None, + }); + } + let username = buf.get_str_nul()?; + + let auth_response = if partial_cap.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + Some(buf.get_bytes_lenenc()) + } else if partial_cap.contains(Capabilities::SECURE_CONNECTION) { + let len = buf.get_u8(); + Some(buf.get_bytes(len as usize)) + } else { + Some(buf.get_bytes_nul()?) + }; + + let database = if partial_cap.contains(Capabilities::CONNECT_WITH_DB) { + Some(buf.get_str_nul()?) + } else { + None + }; + + let auth_plugin: Option = if partial_cap.contains(Capabilities::PLUGIN_AUTH) { + Some(AuthPlugin::from_str(&buf.get_str_nul()?)?) + } else { + None + }; + + *server_capabilities &= Capabilities::from_bits_truncate(capabilities); + + Ok(HandshakeResponse { + collation, + max_packet_size, + username, + auth_response, + auth_plugin, + database, + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs b/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs new file mode 100644 index 0000000..71f9999 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs @@ -0,0 +1,13 @@ +//! Connection Phase +//! +//! + +mod auth_switch; +mod handshake; +mod handshake_response; +mod ssl_request; + +pub use auth_switch::{AuthSwitchRequest, AuthSwitchResponse}; +pub use handshake::Handshake; +pub use handshake_response::HandshakeResponse; +pub use ssl_request::SslRequest; diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs b/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs new file mode 100644 index 0000000..5f0c2d8 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs @@ -0,0 +1,30 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + +#[derive(Debug)] +pub struct SslRequest { + pub max_packet_size: u32, + pub collation: u8, +} + +impl Encode<'_, Capabilities> for SslRequest { + fn encode_with(&self, buf: &mut Vec, capabilities: Capabilities) { + buf.extend(&(capabilities.bits() as u32).to_le_bytes()); + buf.extend(&self.max_packet_size.to_le_bytes()); + buf.push(self.collation); + + // reserved: string<19> + buf.extend(&[0_u8; 19]); + + if capabilities.contains(Capabilities::MYSQL) { + // reserved: string<4> + buf.extend(&[0_u8; 4]); + } else { + // extended client capabilities (MariaDB-specified): int<4> + buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes()); + } + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/mod.rs b/warpgate-database-protocols/src/mysql/protocol/mod.rs new file mode 100644 index 0000000..22b5a03 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/mod.rs @@ -0,0 +1,11 @@ +pub mod auth; +pub mod capabilities; +pub mod connect; +pub mod packet; +pub mod response; +pub mod row; +pub mod text; + +pub use capabilities::Capabilities; +pub use packet::Packet; +pub use row::Row; diff --git a/warpgate-database-protocols/src/mysql/protocol/packet.rs b/warpgate-database-protocols/src/mysql/protocol/packet.rs new file mode 100644 index 0000000..add23f0 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/packet.rs @@ -0,0 +1,89 @@ +use std::ops::{Deref, DerefMut}; + +use bytes::Bytes; + +use crate::error::Error; +use crate::io::{Decode, Encode}; +use crate::mysql::protocol::response::{EofPacket, OkPacket}; +use crate::mysql::protocol::Capabilities; + +#[derive(Debug)] +pub struct Packet(pub T); + +impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet +where + T: Encode<'en, Capabilities>, +{ + fn encode_with( + &self, + buf: &mut Vec, + (capabilities, sequence_id): (Capabilities, &'stream mut u8), + ) { + // reserve space to write the prefixed length + let offset = buf.len(); + buf.extend(&[0_u8; 4]); + + // encode the payload + self.0.encode_with(buf, capabilities); + + // determine the length of the encoded payload + // and write to our reserved space + let len = buf.len() - offset - 4; + let header = &mut buf[offset..]; + + // FIXME: Support larger packets + assert!(len < 0xFF_FF_FF); + + header[..4].copy_from_slice(&(len as u32).to_le_bytes()); + header[3] = *sequence_id; + + *sequence_id = sequence_id.wrapping_add(1); + } +} + +impl Packet { + pub(crate) fn decode<'de, T>(self) -> Result + where + T: Decode<'de, ()>, + { + self.decode_with(()) + } + + pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.0, context) + } + + pub(crate) fn ok(self) -> Result { + self.decode() + } + + pub(crate) fn eof(self, capabilities: Capabilities) -> Result { + if capabilities.contains(Capabilities::DEPRECATE_EOF) { + let ok = self.ok()?; + + Ok(EofPacket { + warnings: ok.warnings, + status: ok.status, + }) + } else { + self.decode_with(capabilities) + } + } +} + +impl Deref for Packet { + type Target = Bytes; + + fn deref(&self) -> &Bytes { + &self.0 + } +} + +impl DerefMut for Packet { + fn deref_mut(&mut self) -> &mut Bytes { + &mut self.0 + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/eof.rs b/warpgate-database-protocols/src/mysql/protocol/response/eof.rs new file mode 100644 index 0000000..25568b5 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/eof.rs @@ -0,0 +1,36 @@ +use bytes::{Buf, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +/// Marks the end of a result set, returning status and warnings. +/// +/// # Note +/// +/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL +/// prior MySQL versions. +#[derive(Debug)] +pub struct EofPacket { + pub warnings: u16, + pub status: Status, +} + +impl Decode<'_, Capabilities> for EofPacket { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (EOF_Packet) but found 0x{:x}", + header + )); + } + + let warnings = buf.get_u16_le(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + + Ok(Self { status, warnings }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/err.rs b/warpgate-database-protocols/src/mysql/protocol/response/err.rs new file mode 100644 index 0000000..1071de5 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/err.rs @@ -0,0 +1,81 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{BufExt, Decode, Encode}; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html +// https://mariadb.com/kb/en/err_packet/ + +/// Indicates that an error occurred. +#[derive(Debug)] +pub struct ErrPacket { + pub error_code: u16, + pub sql_state: Option, + pub error_message: String, +} + +impl Decode<'_, Capabilities> for ErrPacket { + fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xff { + return Err(err_protocol!( + "expected 0xff (ERR_Packet) but found 0x{:x}", + header + )); + } + + let error_code = buf.get_u16_le(); + let mut sql_state = None; + + if capabilities.contains(Capabilities::PROTOCOL_41) { + // If the next byte is '#' then we have a SQL STATE + if buf.get(0) == Some(&0x23) { + buf.advance(1); + sql_state = Some(buf.get_str(5)?); + } + } + + let error_message = buf.get_str(buf.len())?; + + Ok(Self { + error_code, + sql_state, + error_message, + }) + } +} + +impl Encode<'_, ()> for ErrPacket { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0xff); + buf.put_u16_le(self.error_code); + buf.extend_from_slice(self.error_message.as_bytes()) + //TODO: sql_state + } +} + +#[test] +fn test_decode_err_packet_out_of_order() { + const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; + + let p = + ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(&p.error_message, "Got packets out of order"); + assert_eq!(p.error_code, 1156); + assert_eq!(p.sql_state, None); +} + +#[test] +fn test_decode_err_packet_unknown_database() { + const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; + + let p = + ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(p.error_code, 1049); + assert_eq!(p.sql_state.as_deref(), Some("42000")); + assert_eq!(&p.error_message, "Unknown database \'unknown\'"); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/mod.rs b/warpgate-database-protocols/src/mysql/protocol/response/mod.rs new file mode 100644 index 0000000..79767dc --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/mod.rs @@ -0,0 +1,14 @@ +//! Generic Response Packets +//! +//! +//! + +mod eof; +mod err; +mod ok; +mod status; + +pub use eof::EofPacket; +pub use err::ErrPacket; +pub use ok::OkPacket; +pub use status::Status; diff --git a/warpgate-database-protocols/src/mysql/protocol/response/ok.rs b/warpgate-database-protocols/src/mysql/protocol/response/ok.rs new file mode 100644 index 0000000..cfd8089 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/ok.rs @@ -0,0 +1,63 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{Decode, Encode}; +use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt}; +use crate::mysql::protocol::response::Status; + +/// Indicates successful completion of a previous command sent by the client. +#[derive(Debug)] +pub struct OkPacket { + pub affected_rows: u64, + pub last_insert_id: u64, + pub status: Status, + pub warnings: u16, +} + +impl Decode<'_> for OkPacket { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0 && header != 0xfe { + return Err(err_protocol!( + "expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}", + header + )); + } + + let affected_rows = buf.get_uint_lenenc(); + let last_insert_id = buf.get_uint_lenenc(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + let warnings = buf.get_u16_le(); + + Ok(Self { + affected_rows, + last_insert_id, + status, + warnings, + }) + } +} + +impl Encode<'_, ()> for OkPacket { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0); + buf.put_uint_lenenc(self.affected_rows); + buf.put_uint_lenenc(self.last_insert_id); + buf.put_u16_le(self.status.bits()); + buf.put_u16_le(self.warnings); + } +} + +#[test] +fn test_decode_ok_packet() { + const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; + + let p = OkPacket::decode(DATA.into()).unwrap(); + + assert_eq!(p.affected_rows, 0); + assert_eq!(p.last_insert_id, 0); + assert_eq!(p.warnings, 0); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED)); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/status.rs b/warpgate-database-protocols/src/mysql/protocol/response/status.rs new file mode 100644 index 0000000..0338c0d --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/status.rs @@ -0,0 +1,49 @@ +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad +// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status +bitflags::bitflags! { + pub struct Status: u16 { + // Is raised when a multi-statement transaction has been started, either explicitly, + // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first + // transactional statement, when autocommit=off. + const SERVER_STATUS_IN_TRANS = 1; + + // Autocommit mode is set + const SERVER_STATUS_AUTOCOMMIT = 2; + + // Multi query - next query exists. + const SERVER_MORE_RESULTS_EXISTS = 8; + + const SERVER_QUERY_NO_GOOD_INDEX_USED = 16; + const SERVER_QUERY_NO_INDEX_USED = 32; + + // When using COM_STMT_FETCH, indicate that current cursor still has result + const SERVER_STATUS_CURSOR_EXISTS = 64; + + // When using COM_STMT_FETCH, indicate that current cursor has finished to send results + const SERVER_STATUS_LAST_ROW_SENT = 128; + + // Database has been dropped + const SERVER_STATUS_DB_DROPPED = (1 << 8); + + // Current escape mode is "no backslash escape" + const SERVER_STATUS_NO_BACKSLASH_ESCAPES = (1 << 9); + + // A DDL change did have an impact on an existing PREPARE (an automatic + // re-prepare has been executed) + const SERVER_STATUS_METADATA_CHANGED = (1 << 10); + + // Last statement took more than the time value specified + // in server variable long_query_time. + const SERVER_QUERY_WAS_SLOW = (1 << 11); + + // This result-set contain stored procedure output parameter. + const SERVER_PS_OUT_PARAMS = (1 << 12); + + // Current transaction is a read-only transaction. + const SERVER_STATUS_IN_TRANS_READONLY = (1 << 13); + + // This status flag, when on, implies that one of the state information has changed + // on the server because of the execution of the last statement. + const SERVER_SESSION_STATE_CHANGED = (1 << 14); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/row.rs b/warpgate-database-protocols/src/mysql/protocol/row.rs new file mode 100644 index 0000000..8e53be6 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/row.rs @@ -0,0 +1,17 @@ +use std::ops::Range; + +use bytes::Bytes; + +#[derive(Debug)] +pub struct Row { + pub(crate) storage: Bytes, + pub(crate) values: Vec>>, +} + +impl Row { + pub(crate) fn get(&self, index: usize) -> Option<&[u8]> { + self.values[index] + .as_ref() + .map(|col| &self.storage[(col.start as usize)..(col.end as usize)]) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/column.rs b/warpgate-database-protocols/src/mysql/protocol/text/column.rs new file mode 100644 index 0000000..c901f29 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/column.rs @@ -0,0 +1,265 @@ +use std::str::from_utf8; + +use bitflags::bitflags; +use bytes::{Buf, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html + +bitflags! { + #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] + pub struct ColumnFlags: u16 { + /// Field can't be `NULL`. + const NOT_NULL = 1; + + /// Field is part of a primary key. + const PRIMARY_KEY = 2; + + /// Field is part of a unique key. + const UNIQUE_KEY = 4; + + /// Field is part of a multi-part unique or primary key. + const MULTIPLE_KEY = 8; + + /// Field is a blob. + const BLOB = 16; + + /// Field is unsigned. + const UNSIGNED = 32; + + /// Field is zero filled. + const ZEROFILL = 64; + + /// Field is binary. + const BINARY = 128; + + /// Field is an enumeration. + const ENUM = 256; + + /// Field is an auto-incement field. + const AUTO_INCREMENT = 512; + + /// Field is a timestamp. + const TIMESTAMP = 1024; + + /// Field is a set. + const SET = 2048; + + /// Field does not have a default value. + const NO_DEFAULT_VALUE = 4096; + + /// Field is set to NOW on UPDATE. + const ON_UPDATE_NOW = 8192; + + /// Field is a number. + const NUM = 32768; + } +} + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type + +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[repr(u8)] +pub enum ColumnType { + Decimal = 0x00, + Tiny = 0x01, + Short = 0x02, + Long = 0x03, + Float = 0x04, + Double = 0x05, + Null = 0x06, + Timestamp = 0x07, + LongLong = 0x08, + Int24 = 0x09, + Date = 0x0a, + Time = 0x0b, + Datetime = 0x0c, + Year = 0x0d, + VarChar = 0x0f, + Bit = 0x10, + Json = 0xf5, + NewDecimal = 0xf6, + Enum = 0xf7, + Set = 0xf8, + TinyBlob = 0xf9, + MediumBlob = 0xfa, + LongBlob = 0xfb, + Blob = 0xfc, + VarString = 0xfd, + String = 0xfe, + Geometry = 0xff, +} + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html +// https://mariadb.com/kb/en/resultset/#column-definition-packet +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 + +#[derive(Debug)] +pub struct ColumnDefinition { + #[allow(unused)] + catalog: Bytes, + #[allow(unused)] + schema: Bytes, + #[allow(unused)] + table_alias: Bytes, + #[allow(unused)] + table: Bytes, + alias: Bytes, + name: Bytes, + pub(crate) char_set: u16, + pub(crate) max_size: u32, + pub(crate) r#type: ColumnType, + pub(crate) flags: ColumnFlags, + #[allow(unused)] + decimals: u8, +} + +impl ColumnDefinition { + // NOTE: strings in-protocol are transmitted according to the client character set + // as this is UTF-8, all these strings should be UTF-8 + + pub(crate) fn name(&self) -> Result<&str, Error> { + from_utf8(&self.name).map_err(Error::protocol) + } + + pub(crate) fn alias(&self) -> Result<&str, Error> { + from_utf8(&self.alias).map_err(Error::protocol) + } +} + +impl Decode<'_, Capabilities> for ColumnDefinition { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let catalog = buf.get_bytes_lenenc(); + let schema = buf.get_bytes_lenenc(); + let table_alias = buf.get_bytes_lenenc(); + let table = buf.get_bytes_lenenc(); + let alias = buf.get_bytes_lenenc(); + let name = buf.get_bytes_lenenc(); + let _next_len = buf.get_uint_lenenc(); // always 0x0c + let char_set = buf.get_u16_le(); + let max_size = buf.get_u32_le(); + let type_id = buf.get_u8(); + let flags = buf.get_u16_le(); + let decimals = buf.get_u8(); + + Ok(Self { + catalog, + schema, + table_alias, + table, + alias, + name, + char_set, + max_size, + r#type: ColumnType::try_from_u16(type_id)?, + flags: ColumnFlags::from_bits_truncate(flags), + decimals, + }) + } +} + +impl ColumnType { + pub(crate) fn name( + self, + char_set: u16, + flags: ColumnFlags, + max_size: Option, + ) -> &'static str { + let is_binary = char_set == 63; + let is_unsigned = flags.contains(ColumnFlags::UNSIGNED); + let is_enum = flags.contains(ColumnFlags::ENUM); + + match self { + ColumnType::Tiny if max_size == Some(1) => "BOOLEAN", + ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED", + ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED", + ColumnType::Long if is_unsigned => "INT UNSIGNED", + ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED", + ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED", + ColumnType::Tiny => "TINYINT", + ColumnType::Short => "SMALLINT", + ColumnType::Long => "INT", + ColumnType::Int24 => "MEDIUMINT", + ColumnType::LongLong => "BIGINT", + ColumnType::Float => "FLOAT", + ColumnType::Double => "DOUBLE", + ColumnType::Null => "NULL", + ColumnType::Timestamp => "TIMESTAMP", + ColumnType::Date => "DATE", + ColumnType::Time => "TIME", + ColumnType::Datetime => "DATETIME", + ColumnType::Year => "YEAR", + ColumnType::Bit => "BIT", + ColumnType::Enum => "ENUM", + ColumnType::Set => "SET", + ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL", + ColumnType::Geometry => "GEOMETRY", + ColumnType::Json => "JSON", + + ColumnType::String if is_binary => "BINARY", + ColumnType::String if is_enum => "ENUM", + ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY", + + ColumnType::String => "CHAR", + ColumnType::VarChar | ColumnType::VarString => "VARCHAR", + + ColumnType::TinyBlob if is_binary => "TINYBLOB", + ColumnType::TinyBlob => "TINYTEXT", + + ColumnType::Blob if is_binary => "BLOB", + ColumnType::Blob => "TEXT", + + ColumnType::MediumBlob if is_binary => "MEDIUMBLOB", + ColumnType::MediumBlob => "MEDIUMTEXT", + + ColumnType::LongBlob if is_binary => "LONGBLOB", + ColumnType::LongBlob => "LONGTEXT", + } + } + + pub(crate) fn try_from_u16(id: u8) -> Result { + Ok(match id { + 0x00 => ColumnType::Decimal, + 0x01 => ColumnType::Tiny, + 0x02 => ColumnType::Short, + 0x03 => ColumnType::Long, + 0x04 => ColumnType::Float, + 0x05 => ColumnType::Double, + 0x06 => ColumnType::Null, + 0x07 => ColumnType::Timestamp, + 0x08 => ColumnType::LongLong, + 0x09 => ColumnType::Int24, + 0x0a => ColumnType::Date, + 0x0b => ColumnType::Time, + 0x0c => ColumnType::Datetime, + 0x0d => ColumnType::Year, + // [internal] 0x0e => ColumnType::NewDate, + 0x0f => ColumnType::VarChar, + 0x10 => ColumnType::Bit, + // [internal] 0x11 => ColumnType::Timestamp2, + // [internal] 0x12 => ColumnType::Datetime2, + // [internal] 0x13 => ColumnType::Time2, + 0xf5 => ColumnType::Json, + 0xf6 => ColumnType::NewDecimal, + 0xf7 => ColumnType::Enum, + 0xf8 => ColumnType::Set, + 0xf9 => ColumnType::TinyBlob, + 0xfa => ColumnType::MediumBlob, + 0xfb => ColumnType::LongBlob, + 0xfc => ColumnType::Blob, + 0xfd => ColumnType::VarString, + 0xfe => ColumnType::String, + 0xff => ColumnType::Geometry, + + _ => { + return Err(err_protocol!("unknown column type 0x{:02x}", id)); + } + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/mod.rs b/warpgate-database-protocols/src/mysql/protocol/text/mod.rs new file mode 100644 index 0000000..6c174e7 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/mod.rs @@ -0,0 +1,9 @@ +mod column; +mod ping; +mod query; +mod quit; + +pub use column::{ColumnDefinition, ColumnFlags, ColumnType}; +pub use ping::Ping; +pub use query::Query; +pub use quit::Quit; diff --git a/warpgate-database-protocols/src/mysql/protocol/text/ping.rs b/warpgate-database-protocols/src/mysql/protocol/text/ping.rs new file mode 100644 index 0000000..97c21d1 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/ping.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-ping.html + +#[derive(Debug)] +pub struct Ping; + +impl Encode<'_, Capabilities> for Ping { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x0e); // COM_PING + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/query.rs b/warpgate-database-protocols/src/mysql/protocol/text/query.rs new file mode 100644 index 0000000..3209d48 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/query.rs @@ -0,0 +1,32 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, Decode, Encode}; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-query.html + +#[derive(Debug)] +pub struct Query(pub String); + +impl Encode<'_, ()> for Query { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.push(0x03); // COM_QUERY + buf.extend(self.0.as_bytes()) + } +} + +impl Encode<'_, Capabilities> for Query { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x03); // COM_QUERY + buf.extend(self.0.as_bytes()) + } +} + +impl Decode<'_> for Query { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + buf.advance(1); + let q = buf.get_str(buf.len())?; + Ok(Query(q)) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/quit.rs b/warpgate-database-protocols/src/mysql/protocol/text/quit.rs new file mode 100644 index 0000000..ef6676d --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/quit.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-quit.html + +#[derive(Debug)] +pub struct Quit; + +impl Encode<'_, Capabilities> for Quit { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x01); // COM_QUIT + } +} diff --git a/warpgate-protocol-mysql/Cargo.toml b/warpgate-protocol-mysql/Cargo.toml index a68f9f6..bbc0ba9 100644 --- a/warpgate-protocol-mysql/Cargo.toml +++ b/warpgate-protocol-mysql/Cargo.toml @@ -5,13 +5,10 @@ name = "warpgate-protocol-mysql" version = "0.3.0" [dependencies] -sqlx-core-guts = { version = "0.6", features = [ - "runtime-tokio-rustls", - "mysql", -], path = "../../sqlx-core-guts/sqlx-core-guts" } warpgate-admin = { version = "*", path = "../warpgate-admin" } warpgate-common = { version = "*", path = "../warpgate-common" } warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" } +warpgate-database-protocols = { version = "*", path = "../warpgate-database-protocols" } anyhow = { version = "1.0", features = ["std"] } async-trait = "0.1" tokio = { version = "1.19", features = ["tracing", "signal"] } @@ -23,7 +20,7 @@ rand = "0.8" sha1 = "0.10.1" password-hash = { version = "0.2", features = ["std"] } delegate = "0.6" -rustls = "0.20" +rustls = { version = "0.20", features = ["dangerous_configuration"] } rustls-pemfile = "1.0" tokio-rustls = "0.23" thiserror = "1.0" diff --git a/warpgate-protocol-mysql/src/client.rs b/warpgate-protocol-mysql/src/client.rs index 3a8801a..a2ba6b0 100644 --- a/warpgate-protocol-mysql/src/client.rs +++ b/warpgate-protocol-mysql/src/client.rs @@ -1,17 +1,19 @@ use std::sync::Arc; use bytes::BytesMut; -use sqlx_core_guts::io::Decode; -use sqlx_core_guts::mysql::protocol::auth::AuthPlugin; -use sqlx_core_guts::mysql::protocol::connect::{Handshake, HandshakeResponse, SslRequest}; -use sqlx_core_guts::mysql::protocol::response::ErrPacket; -use sqlx_core_guts::mysql::protocol::Capabilities; -use sqlx_core_guts::mysql::MySqlSslMode; use tokio::net::TcpStream; use tracing::*; +use warpgate_common::{TargetMySqlOptions, TlsMode}; +use warpgate_database_protocols::io::Decode; +use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin; +use warpgate_database_protocols::mysql::protocol::connect::{ + Handshake, HandshakeResponse, SslRequest, +}; +use warpgate_database_protocols::mysql::protocol::response::ErrPacket; +use warpgate_database_protocols::mysql::protocol::Capabilities; -use crate::common::{compute_auth_challenge_response, parse_mysql_uri}; -use crate::error::{InvalidMySqlTargetConfig, MySqlError}; +use crate::common::compute_auth_challenge_response; +use crate::error::MySqlError; use crate::stream::MySqlStream; use crate::tls::configure_tls_connector; @@ -28,13 +30,15 @@ pub struct ConnectionOptions { } impl MySqlClient { - pub async fn connect(uri: &str, mut options: ConnectionOptions) -> Result { - let opts = parse_mysql_uri(uri)?; + pub async fn connect( + target: &TargetMySqlOptions, + mut options: ConnectionOptions, + ) -> Result { let mut stream = - MySqlStream::new(TcpStream::connect((opts.host.clone(), opts.port)).await?); + MySqlStream::new(TcpStream::connect((target.host.clone(), target.port)).await?); options.capabilities.remove(Capabilities::SSL); - if opts.ssl_mode != MySqlSslMode::Disabled { + if target.tls.mode != TlsMode::Disabled { options.capabilities |= Capabilities::SSL; } @@ -44,20 +48,17 @@ impl MySqlClient { let handshake = Handshake::decode(payload)?; options.capabilities &= handshake.server_capabilities; - if opts.ssl_mode != MySqlSslMode::Disabled - && opts.ssl_mode != MySqlSslMode::Preferred - && !options.capabilities.contains(Capabilities::SSL) + if target.tls.mode == TlsMode::Required && !options.capabilities.contains(Capabilities::SSL) { return Err(MySqlError::TlsNotSupported); } info!(capabilities=?options.capabilities, "Target handshake"); - if options.capabilities.contains(Capabilities::SSL) - && opts.ssl_mode != MySqlSslMode::Disabled + if options.capabilities.contains(Capabilities::SSL) && target.tls.mode != TlsMode::Disabled { - let accept_invalid_certs = opts.ssl_mode == MySqlSslMode::Preferred; - let accept_invalid_hostname = opts.ssl_mode != MySqlSslMode::VerifyIdentity; + let accept_invalid_certs = !target.tls.verify; + let accept_invalid_hostname = false; // ca + hostname verification let client_config = Arc::new( configure_tls_connector(accept_invalid_certs, accept_invalid_hostname, None) .await?, @@ -70,7 +71,8 @@ impl MySqlClient { stream.flush().await?; stream = stream .upgrade(( - opts.host + target + .host .as_str() .try_into() .map_err(|_| MySqlError::InvalidDomainName)?, @@ -86,7 +88,7 @@ impl MySqlClient { collation: options.collation, database: options.database, max_packet_size: options.max_packet_size, - username: opts.username, + username: target.username.clone(), }; if handshake.auth_plugin == Some(AuthPlugin::MySqlNativePassword) { @@ -100,15 +102,15 @@ impl MySqlClient { warn!("Invalid scramble length ({})", scramble_bytes.len()); } Ok(scramble) => { - let Some(password) = opts.password else { - return Err(InvalidMySqlTargetConfig::NoPassword.into()) - }; response.auth_plugin = Some(AuthPlugin::MySqlNativePassword); response.auth_response = Some( BytesMut::from( - compute_auth_challenge_response(scramble, &password) - .map_err(MySqlError::other)? - .as_bytes(), + compute_auth_challenge_response( + scramble, + target.password.as_deref().unwrap_or(""), + ) + .map_err(MySqlError::other)? + .as_bytes(), ) .freeze(), ); diff --git a/warpgate-protocol-mysql/src/common.rs b/warpgate-protocol-mysql/src/common.rs index d07a0c9..a80b191 100644 --- a/warpgate-protocol-mysql/src/common.rs +++ b/warpgate-protocol-mysql/src/common.rs @@ -1,19 +1,8 @@ use sha1::Digest; -use sqlx_core_guts::error::Error as SqlxError; -use sqlx_core_guts::mysql::MySqlConnectOptions; use warpgate_common::ProtocolName; -use crate::error::InvalidMySqlTargetConfig; - pub const PROTOCOL_NAME: ProtocolName = "MySQL"; -pub fn parse_mysql_uri(uri: &str) -> Result { - uri.parse().map_err(|e| match e { - SqlxError::Configuration(e) => InvalidMySqlTargetConfig::UriParse(e), - _ => InvalidMySqlTargetConfig::Unknown, - }) -} - pub fn compute_auth_challenge_response( challenge: [u8; 20], password: &str, diff --git a/warpgate-protocol-mysql/src/error.rs b/warpgate-protocol-mysql/src/error.rs index f45a911..3a96f57 100644 --- a/warpgate-protocol-mysql/src/error.rs +++ b/warpgate-protocol-mysql/src/error.rs @@ -1,15 +1,13 @@ use std::error::Error; -use sqlx_core_guts::error::Error as SqlxError; use warpgate_common::WarpgateError; +use warpgate_database_protocols::error::Error as SqlxError; use crate::stream::MySqlStreamError; use crate::tls::{MaybeTlsStreamError, RustlsSetupError}; #[derive(thiserror::Error, Debug)] pub enum MySqlError { - #[error("invalid target config: {0}")] - InvalidTargetConfig(#[from] InvalidMySqlTargetConfig), #[error("protocol error: {0}")] ProtocolError(String), #[error("sudden disconnection")] @@ -38,16 +36,6 @@ pub enum MySqlError { Other(Box), } -#[derive(thiserror::Error, Debug)] -pub enum InvalidMySqlTargetConfig { - #[error("Password not set")] - NoPassword, - #[error("URI parse error: {0}")] - UriParse(Box), - #[error("Unkown")] - Unknown, -} - impl MySqlError { pub fn other(err: E) -> Self { Self::Other(Box::new(err)) diff --git a/warpgate-protocol-mysql/src/session.rs b/warpgate-protocol-mysql/src/session.rs index 3867132..b9d6299 100644 --- a/warpgate-protocol-mysql/src/session.rs +++ b/warpgate-protocol-mysql/src/session.rs @@ -3,12 +3,6 @@ use std::sync::Arc; use bytes::{Buf, Bytes, BytesMut}; use rand::Rng; use rustls::ServerConfig; -use sqlx_core_guts::io::{BufExt, Decode}; -use sqlx_core_guts::mysql::protocol::auth::AuthPlugin; -use sqlx_core_guts::mysql::protocol::connect::{AuthSwitchRequest, Handshake, HandshakeResponse}; -use sqlx_core_guts::mysql::protocol::response::{ErrPacket, OkPacket, Status}; -use sqlx_core_guts::mysql::protocol::text::Query; -use sqlx_core_guts::mysql::protocol::Capabilities; use tokio::net::TcpStream; use tokio::sync::Mutex; use tracing::*; @@ -19,6 +13,14 @@ use warpgate_common::{ authorize_ticket, AuthCredential, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions, WarpgateServerHandle, }; +use warpgate_database_protocols::io::{BufExt, Decode}; +use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin; +use warpgate_database_protocols::mysql::protocol::connect::{ + AuthSwitchRequest, Handshake, HandshakeResponse, +}; +use warpgate_database_protocols::mysql::protocol::response::{ErrPacket, OkPacket, Status}; +use warpgate_database_protocols::mysql::protocol::text::Query; +use warpgate_database_protocols::mysql::protocol::Capabilities; use crate::client::{ConnectionOptions, MySqlClient}; use crate::error::MySqlError; @@ -313,7 +315,7 @@ impl MySqlSession { } let mut client = match MySqlClient::connect( - &options.uri, + &options, ConnectionOptions { collation: handshake.collation, database: handshake.database, diff --git a/warpgate-protocol-mysql/src/stream.rs b/warpgate-protocol-mysql/src/stream.rs index fe493f5..f429c71 100644 --- a/warpgate-protocol-mysql/src/stream.rs +++ b/warpgate-protocol-mysql/src/stream.rs @@ -1,10 +1,10 @@ use bytes::{Bytes, BytesMut}; use mysql_common::proto::codec::error::PacketCodecError; use mysql_common::proto::codec::PacketCodec; -use sqlx_core_guts::io::Encode; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tracing::*; +use warpgate_database_protocols::io::Encode; use crate::tls::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream};