This commit is contained in:
Eugene Pankov 2022-07-19 22:03:35 +02:00
parent d353e63f95
commit 8ff3bc7924
No known key found for this signature in database
GPG key ID: 5896FCBBDD1CF4F4
45 changed files with 3101 additions and 242 deletions

189
Cargo.lock generated
View file

@ -760,12 +760,6 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "const-oid"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4c78c047431fee22c1a7bb92e00ad095a02a983affe4d8a72e2a2c62c1b94f3"
[[package]]
name = "constant_time_eq"
version = "0.2.1"
@ -875,16 +869,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "crypto-bigint"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03c6a1d5fa1de37e071642dfa44ec552ca5b299adb128fab16138e24b548fd21"
dependencies = [
"generic-array",
"subtle",
]
[[package]]
name = "crypto-common"
version = "0.1.3"
@ -976,17 +960,6 @@ dependencies = [
"syn",
]
[[package]]
name = "der"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6919815d73839e7ad218de758883aae3a257ba6759ce7a9992501efbb53d705c"
dependencies = [
"const-oid",
"crypto-bigint",
"pem-rfc7468",
]
[[package]]
name = "derive_more"
version = "0.99.17"
@ -1828,9 +1801,6 @@ name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
dependencies = [
"spin 0.5.2",
]
[[package]]
name = "lazycell"
@ -1927,12 +1897,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "libm"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33a33a362ce288760ec6a508b94caaec573ae7d3bbbd91b87aa0bad4456839db"
[[package]]
name = "libsodium-sys"
version = "0.2.7"
@ -2211,23 +2175,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "566d173b2f9406afbc5510a90925d5a2cd80cae4605631f1212303df265de011"
dependencies = [
"byteorder",
"lazy_static",
"libm",
"num-integer",
"num-iter",
"num-traits",
"rand",
"smallvec",
"zeroize",
]
[[package]]
name = "num-integer"
version = "0.1.44"
@ -2267,7 +2214,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -2523,15 +2469,6 @@ dependencies = [
"base64",
]
[[package]]
name = "pem-rfc7468"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01de5d978f34aa4b2296576379fcc416034702fd94117c56ffd8a1a767cefb30"
dependencies = [
"base64ct",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@ -2588,28 +2525,6 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkcs1"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a78f66c04ccc83dd4486fd46c33896f4e17b24a7a3a6400dedc48ed0ddd72320"
dependencies = [
"der",
"pkcs8",
"zeroize",
]
[[package]]
name = "pkcs8"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cabda3fb821068a9a4fab19a683eac3af12edf0f34b94a8be53c4972b8149d0"
dependencies = [
"der",
"spki",
"zeroize",
]
[[package]]
name = "pkg-config"
version = "0.3.25"
@ -3064,26 +2979,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "rsa"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cf22754c49613d2b3b119f0e5d46e34a2c628a937e3024b8762de4e7d8c710b"
dependencies = [
"byteorder",
"digest 0.10.3",
"num-bigint-dig",
"num-integer",
"num-iter",
"num-traits",
"pkcs1",
"pkcs8",
"rand_core",
"smallvec",
"subtle",
"zeroize",
]
[[package]]
name = "russh"
version = "0.34.0-beta.5"
@ -3672,16 +3567,6 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d01ac02a6ccf3e07db148d2be087da624fea0221a16152ed01f0496a6b0a27"
dependencies = [
"base64ct",
"der",
]
[[package]]
name = "sqlformat"
version = "0.1.8"
@ -3741,7 +3626,7 @@ dependencies = [
"sha2 0.10.2",
"smallvec",
"sqlformat",
"sqlx-rt 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
"sqlx-rt",
"stringprep",
"thiserror",
"tokio-stream",
@ -3749,52 +3634,6 @@ dependencies = [
"uuid",
]
[[package]]
name = "sqlx-core-guts"
version = "0.6.0"
dependencies = [
"ahash",
"atoi",
"bitflags",
"byteorder",
"bytes",
"crc",
"crossbeam-queue",
"digest 0.10.3",
"either",
"event-listener",
"futures-channel",
"futures-core",
"futures-intrusive",
"futures-util",
"generic-array",
"hashlink",
"hex",
"indexmap",
"itoa",
"libc",
"log",
"memchr",
"num-bigint",
"once_cell",
"paste",
"percent-encoding",
"rand",
"rsa",
"rustls",
"rustls-pemfile",
"sha-1",
"sha2 0.10.2",
"smallvec",
"sqlformat",
"sqlx-rt 0.6.0",
"stringprep",
"thiserror",
"tokio-stream",
"url",
"webpki-roots",
]
[[package]]
name = "sqlx-macros"
version = "0.6.0"
@ -3810,20 +3649,11 @@ dependencies = [
"serde_json",
"sha2 0.10.2",
"sqlx-core",
"sqlx-rt 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
"sqlx-rt",
"syn",
"url",
]
[[package]]
name = "sqlx-rt"
version = "0.6.0"
dependencies = [
"once_cell",
"tokio",
"tokio-rustls",
]
[[package]]
name = "sqlx-rt"
version = "0.6.0"
@ -4598,6 +4428,19 @@ dependencies = [
"warpgate-db-migrations",
]
[[package]]
name = "warpgate-database-protocols"
version = "0.3.0"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"memchr",
"thiserror",
"tokio",
]
[[package]]
name = "warpgate-db-entities"
version = "0.3.0"
@ -4664,7 +4507,6 @@ dependencies = [
"rustls",
"rustls-pemfile",
"sha1",
"sqlx-core-guts",
"thiserror",
"tokio",
"tokio-rustls",
@ -4672,6 +4514,7 @@ dependencies = [
"uuid",
"warpgate-admin",
"warpgate-common",
"warpgate-database-protocols",
"warpgate-db-entities",
"webpki",
"webpki-roots",

View file

@ -5,6 +5,7 @@ members = [
"warpgate-common",
"warpgate-db-migrations",
"warpgate-db-entities",
"warpgate-database-protocols",
"warpgate-protocol-http",
"warpgate-protocol-mysql",
"warpgate-protocol-ssh",

View file

@ -1,4 +1,4 @@
projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-protocol-ssh warpgate-protocol-mysql"
projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-database-protocols warpgate-protocol-ssh warpgate-protocol-mysql"
run *ARGS:
RUST_BACKTRACE=1 RUST_LOG=warpgate cd warpgate && cargo run -- --config ../config.yaml {{ARGS}}

View file

@ -3,7 +3,7 @@ use std::net::ToSocketAddrs;
use std::path::PathBuf;
use std::time::Duration;
use poem_openapi::{Object, Union};
use poem_openapi::{Enum, Object, Union};
use serde::{Deserialize, Serialize};
use crate::helpers::otp::OtpSecretKey;
@ -17,10 +17,14 @@ const fn _default_false() -> bool {
false
}
const fn _default_port() -> u16 {
const fn _default_ssh_port() -> u16 {
22
}
const fn _default_mysql_port() -> u16 {
3306
}
#[inline]
fn _default_username() -> String {
"root".to_owned()
@ -64,7 +68,7 @@ fn _default_empty_vec<T>() -> Vec<T> {
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct TargetSSHOptions {
pub host: String,
#[serde(default = "_default_port")]
#[serde(default = "_default_ssh_port")]
pub port: u16,
#[serde(default = "_default_username")]
pub username: String,
@ -97,10 +101,57 @@ pub struct TargetHTTPOptions {
pub headers: Option<HashMap<String, String>>,
}
#[derive(Debug, Deserialize, Serialize, Clone, Enum, PartialEq, Eq)]
pub enum TlsMode {
Disabled,
Preferred,
Required,
}
impl Default for TlsMode {
fn default() -> Self {
TlsMode::Preferred
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct Tls {
#[serde(default)]
pub mode: TlsMode,
#[serde(default)]
pub verify: bool,
}
#[allow(clippy::derivable_impls)]
impl Default for Tls {
fn default() -> Self {
Self {
mode: TlsMode::default(),
verify: false,
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Object)]
pub struct TargetMySqlOptions {
#[serde(default = "_default_empty_string")]
pub uri: String,
pub host: String,
#[serde(default = "_default_mysql_port")]
pub port: u16,
#[serde(default = "_default_username")]
pub username: String,
#[serde(default)]
pub password: Option<String>,
#[serde(default)]
pub tls: Tls,
#[serde(default)]
pub verify_tls: bool,
}
#[derive(Debug, Deserialize, Serialize, Clone, Object, Default)]

View file

@ -0,0 +1,24 @@
[package]
name = "warpgate-database-protocols"
version = "0.3.0"
description = "Core of SQLx, the rust SQL toolkit. Just the database protocol parts."
license = "MIT OR Apache-2.0"
edition = "2021"
authors = [
"Ryan Leckey <leckey.ryan@gmail.com>",
"Austin Bonander <austin.bonander@gmail.com>",
"Chloe Ross <orangesnowfox@gmail.com>",
"Daniel Akhterov <akhterovd@gmail.com>",
]
[dependencies]
tokio = { version = "1.19", features = ["io-util"] }
bitflags = { version = "1.3", default-features = false }
bytes = "1.1"
futures-core = { version = "0.3", default-features = false }
futures-util = { version = "0.3", default-features = false, features = [
"alloc",
"sink",
] }
memchr = { version = "2.4.1", default-features = false }
thiserror = "1.0"

View file

@ -0,0 +1,241 @@
//! Types for working with errors produced by SQLx.
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt::Display;
use std::io;
use std::result::Result as StdResult;
/// A specialized `Result` type for SQLx.
pub type Result<T> = StdResult<T, Error>;
// Convenience type alias for usage within SQLx.
// Do not make this type public.
pub type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
/// An unexpected `NULL` was encountered during decoding.
///
/// Returned from [`Row::get`](crate::row::Row::get) if the value from the database is `NULL`,
/// and you are not decoding into an `Option`.
#[derive(thiserror::Error, Debug)]
#[error("unexpected null; try decoding as an `Option`")]
pub struct UnexpectedNullError;
/// Represents all the ways a method can fail within SQLx.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// Error occurred while parsing a connection string.
#[error("error with configuration: {0}")]
Configuration(#[source] BoxDynError),
/// Error returned from the database.
#[error("error returned from database: {0}")]
Database(#[source] Box<dyn DatabaseError>),
/// Error communicating with the database backend.
#[error("error communicating with database: {0}")]
Io(#[from] io::Error),
/// Error occurred while attempting to establish a TLS connection.
#[error("error occurred while attempting to establish a TLS connection: {0}")]
Tls(#[source] BoxDynError),
/// Unexpected or invalid data encountered while communicating with the database.
///
/// This should indicate there is a programming error in a SQLx driver or there
/// is something corrupted with the connection to the database itself.
#[error("encountered unexpected or invalid data: {0}")]
Protocol(String),
/// No rows returned by a query that expected to return at least one row.
#[error("no rows returned by a query that expected to return at least one row")]
RowNotFound,
/// Type in query doesn't exist. Likely due to typo or missing user type.
#[error("type named {type_name} not found")]
TypeNotFound { type_name: String },
/// Column index was out of bounds.
#[error("column index out of bounds: the len is {len}, but the index is {index}")]
ColumnIndexOutOfBounds { index: usize, len: usize },
/// No column found for the given name.
#[error("no column found for name: {0}")]
ColumnNotFound(String),
/// Error occurred while decoding a value from a specific column.
#[error("error occurred while decoding column {index}: {source}")]
ColumnDecode {
index: String,
#[source]
source: BoxDynError,
},
/// Error occurred while decoding a value.
#[error("error occurred while decoding: {0}")]
Decode(#[source] BoxDynError),
/// A [`Pool::acquire`] timed out due to connections not becoming available or
/// because another task encountered too many errors while trying to open a new connection.
///
/// [`Pool::acquire`]: crate::pool::Pool::acquire
#[error("pool timed out while waiting for an open connection")]
PoolTimedOut,
/// [`Pool::close`] was called while we were waiting in [`Pool::acquire`].
///
/// [`Pool::acquire`]: crate::pool::Pool::acquire
/// [`Pool::close`]: crate::pool::Pool::close
#[error("attempted to acquire a connection on a closed pool")]
PoolClosed,
/// A background worker has crashed.
#[error("attempted to communicate with a crashed background worker")]
WorkerCrashed,
#[cfg(feature = "migrate")]
#[error("{0}")]
Migrate(#[source] Box<crate::migrate::MigrateError>),
}
impl StdError for Box<dyn DatabaseError> {}
impl Error {
#[allow(dead_code)]
#[inline]
pub(crate) fn protocol(err: impl Display) -> Self {
Error::Protocol(err.to_string())
}
#[allow(dead_code)]
#[inline]
pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self {
Error::Configuration(err.into())
}
}
/// An error that was returned from the database.
pub trait DatabaseError: 'static + Send + Sync + StdError {
/// The primary, human-readable error message.
fn message(&self) -> &str;
/// The (SQLSTATE) code for the error.
fn code(&self) -> Option<Cow<'_, str>> {
None
}
#[doc(hidden)]
fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static);
#[doc(hidden)]
fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static);
#[doc(hidden)]
fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static>;
#[doc(hidden)]
fn is_transient_in_connect_phase(&self) -> bool {
false
}
/// Returns the name of the constraint that triggered the error, if applicable.
/// If the error was caused by a conflict of a unique index, this will be the index name.
///
/// ### Note
/// Currently only populated by the Postgres driver.
fn constraint(&self) -> Option<&str> {
None
}
}
impl dyn DatabaseError {
/// Downcast a reference to this generic database error to a specific
/// database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast_ref` which returns `Option<&E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast_ref`.
pub fn downcast_ref<E: DatabaseError>(&self) -> &E {
self.try_downcast_ref().unwrap_or_else(|| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
self
)
})
}
/// Downcast this generic database error to a specific database error type.
///
/// # Panics
///
/// Panics if the database error type is not `E`. This is a deliberate contrast from
/// `Error::downcast` which returns `Option<E>`. In normal usage, you should know the
/// specific error type. In other cases, use `try_downcast`.
pub fn downcast<E: DatabaseError>(self: Box<Self>) -> Box<E> {
self.try_downcast().unwrap_or_else(|e| {
panic!(
"downcast to wrong DatabaseError type; original error: {}",
e
)
})
}
/// Downcast a reference to this generic database error to a specific
/// database error type.
#[inline]
pub fn try_downcast_ref<E: DatabaseError>(&self) -> Option<&E> {
self.as_error().downcast_ref()
}
/// Downcast this generic database error to a specific database error type.
#[inline]
pub fn try_downcast<E: DatabaseError>(self: Box<Self>) -> StdResult<Box<E>, Box<Self>> {
if self.as_error().is::<E>() {
Ok(self.into_error().downcast().unwrap())
} else {
Err(self)
}
}
}
impl<E> From<E> for Error
where
E: DatabaseError,
{
#[inline]
fn from(error: E) -> Self {
Error::Database(Box::new(error))
}
}
#[cfg(feature = "migrate")]
impl From<crate::migrate::MigrateError> for Error {
#[inline]
fn from(error: crate::migrate::MigrateError) -> Self {
Error::Migrate(Box::new(error))
}
}
#[cfg(feature = "_tls-native-tls")]
impl From<sqlx_rt::native_tls::Error> for Error {
#[inline]
fn from(error: sqlx_rt::native_tls::Error) -> Self {
Error::Tls(Box::new(error))
}
}
// Format an error message as a `Protocol` error
#[macro_export]
macro_rules! err_protocol {
($expr:expr) => {
$crate::error::Error::Protocol($expr.into())
};
($fmt:expr, $($arg:tt)*) => {
$crate::error::Error::Protocol(format!($fmt, $($arg)*))
};
}

View file

@ -0,0 +1,59 @@
use std::str::from_utf8;
use bytes::{Buf, Bytes};
use memchr::memchr;
use crate::err_protocol;
use crate::error::Error;
pub trait BufExt: Buf {
// Read a nul-terminated byte sequence
fn get_bytes_nul(&mut self) -> Result<Bytes, Error>;
// Read a byte sequence of the exact length
fn get_bytes(&mut self, len: usize) -> Bytes;
// Read a nul-terminated string
fn get_str_nul(&mut self) -> Result<String, Error>;
// Read a string of the exact length
fn get_str(&mut self, len: usize) -> Result<String, Error>;
}
impl BufExt for Bytes {
fn get_bytes_nul(&mut self) -> Result<Bytes, Error> {
let nul =
memchr(b'\0', self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;
let v = self.slice(0..nul);
self.advance(nul + 1);
Ok(v)
}
fn get_bytes(&mut self, len: usize) -> Bytes {
let v = self.slice(..len);
self.advance(len);
v
}
fn get_str_nul(&mut self) -> Result<String, Error> {
self.get_bytes_nul().and_then(|bytes| {
from_utf8(&*bytes)
.map(ToOwned::to_owned)
.map_err(|err| err_protocol!("{}", err))
})
}
fn get_str(&mut self, len: usize) -> Result<String, Error> {
let v = from_utf8(&self[..len])
.map_err(|err| err_protocol!("{}", err))
.map(ToOwned::to_owned)?;
self.advance(len);
Ok(v)
}
}

View file

@ -0,0 +1,12 @@
use bytes::BufMut;
pub trait BufMutExt: BufMut {
fn put_str_nul(&mut self, s: &str);
}
impl BufMutExt for Vec<u8> {
fn put_str_nul(&mut self, s: &str) {
self.extend(s.as_bytes());
self.push(0);
}
}

View file

@ -0,0 +1,166 @@
#![allow(dead_code)]
use std::io;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use crate::error::Error;
use crate::io::decode::Decode;
use crate::io::encode::Encode;
use crate::io::write_and_flush::WriteAndFlush;
pub struct BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) stream: S,
// writes with `write` to the underlying stream are buffered
// this can be flushed with `flush`
pub(crate) wbuf: Vec<u8>,
// we read into the read buffer using 100% safe code
rbuf: BytesMut,
}
impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(stream: S) -> Self {
Self {
stream,
wbuf: Vec::with_capacity(512),
rbuf: BytesMut::with_capacity(4096),
}
}
pub fn write<'en, T>(&mut self, value: T)
where
T: Encode<'en, ()>,
{
self.write_with(value, ())
}
pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
where
T: Encode<'en, C>,
{
value.encode_with(&mut self.wbuf, context);
}
pub fn flush(&mut self) -> WriteAndFlush<'_, S> {
WriteAndFlush {
stream: &mut self.stream,
buf: Cursor::new(&mut self.wbuf),
}
}
pub async fn read<'de, T>(&mut self, cnt: usize) -> Result<T, Error>
where
T: Decode<'de, ()>,
{
self.read_with(cnt, ()).await
}
pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
{
T::decode_with(self.read_raw(cnt).await?.freeze(), context)
}
pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
let buf = self.rbuf.split_to(cnt);
Ok(buf)
}
pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
read_raw_into(&mut self.stream, buf, cnt).await
}
pub fn take(self) -> S {
self.stream
}
}
impl<S> Deref for BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
type Target = S;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<S> DerefMut for BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
// Holds a buffer which has been temporarily extended, so that
// we can read into it. Automatically shrinks the buffer back
// down if the read is cancelled.
struct BufTruncator<'a> {
buf: &'a mut BytesMut,
filled_len: usize,
}
impl<'a> BufTruncator<'a> {
fn new(buf: &'a mut BytesMut) -> Self {
let filled_len = buf.len();
Self { buf, filled_len }
}
fn reserve(&mut self, space: usize) {
self.buf.resize(self.filled_len + space, 0);
}
async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
let n = stream.read(&mut self.buf[self.filled_len..]).await?;
self.filled_len += n;
Ok(n)
}
fn is_full(&self) -> bool {
self.filled_len >= self.buf.len()
}
}
impl Drop for BufTruncator<'_> {
fn drop(&mut self) {
self.buf.truncate(self.filled_len);
}
}
async fn read_raw_into<S: AsyncRead + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
cnt: usize,
) -> Result<(), Error> {
let mut buf = BufTruncator::new(buf);
buf.reserve(cnt);
while !buf.is_full() {
let n = buf.read(stream).await?;
if n == 0 {
// a zero read when we had space in the read buffer
// should be treated as an EOF
// and an unexpected EOF means the server told us to go away
return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
}
}
Ok(())
}

View file

@ -0,0 +1,29 @@
use bytes::Bytes;
use crate::error::Error;
pub trait Decode<'de, Context = ()>
where
Self: Sized,
{
fn decode(buf: Bytes) -> Result<Self, Error>
where
Self: Decode<'de, ()>,
{
Self::decode_with(buf, ())
}
fn decode_with(buf: Bytes, context: Context) -> Result<Self, Error>;
}
impl Decode<'_> for Bytes {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
Ok(buf)
}
}
impl Decode<'_> for () {
fn decode_with(_: Bytes, _: ()) -> Result<(), Error> {
Ok(())
}
}

View file

@ -0,0 +1,16 @@
pub trait Encode<'en, Context = ()> {
fn encode(&self, buf: &mut Vec<u8>)
where
Self: Encode<'en, ()>,
{
self.encode_with(buf, ());
}
fn encode_with(&self, buf: &mut Vec<u8>, context: Context);
}
impl<'en, C> Encode<'en, C> for &'_ [u8] {
fn encode_with(&self, buf: &mut Vec<u8>, _: C) {
buf.extend_from_slice(self);
}
}

View file

@ -0,0 +1,12 @@
mod buf;
mod buf_mut;
mod buf_stream;
mod decode;
mod encode;
mod write_and_flush;
pub use buf::BufExt;
pub use buf_mut::BufMutExt;
pub use buf_stream::BufStream;
pub use decode::Decode;
pub use encode::Encode;

View file

@ -0,0 +1,47 @@
use std::io::{BufRead, Cursor};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Future;
use futures_util::ready;
use tokio::io::AsyncWrite;
use crate::error::Error;
// Atomic operation that writes the full buffer to the stream, flushes the stream, and then
// clears the buffer (even if either of the two previous operations failed).
pub struct WriteAndFlush<'a, S> {
pub(super) stream: &'a mut S,
pub(super) buf: Cursor<&'a mut Vec<u8>>,
}
impl<S: AsyncWrite + Unpin> Future for WriteAndFlush<'_, S> {
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
ref mut stream,
ref mut buf,
} = *self;
loop {
let read = buf.fill_buf()?;
if !read.is_empty() {
let written = ready!(Pin::new(&mut *stream).poll_write(cx, read)?);
buf.consume(written);
} else {
break;
}
}
Pin::new(stream).poll_flush(cx).map_err(Error::Io)
}
}
impl<'a, S> Drop for WriteAndFlush<'a, S> {
fn drop(&mut self) {
// clear the buffer regardless of whether the flush succeeded or not
self.buf.get_mut().clear();
}
}

View file

@ -0,0 +1,6 @@
#![allow(dead_code)]
pub mod io;
pub mod mysql;
#[macro_use]
pub mod error;

View file

@ -0,0 +1,901 @@
use std::str::FromStr;
use crate::error::Error;
#[allow(non_camel_case_types)]
#[derive(Copy, Clone)]
pub(crate) enum CharSet {
armscii8,
ascii,
big5,
binary,
cp1250,
cp1251,
cp1256,
cp1257,
cp850,
cp852,
cp866,
cp932,
dec8,
eucjpms,
euckr,
gb18030,
gb2312,
gbk,
geostd8,
greek,
hebrew,
hp8,
keybcs2,
koi8r,
koi8u,
latin1,
latin2,
latin5,
latin7,
macce,
macroman,
sjis,
swe7,
tis620,
ucs2,
ujis,
utf16,
utf16le,
utf32,
utf8,
utf8mb4,
}
impl CharSet {
pub(crate) fn as_str(&self) -> &'static str {
match self {
CharSet::armscii8 => "armscii8",
CharSet::ascii => "ascii",
CharSet::big5 => "big5",
CharSet::binary => "binary",
CharSet::cp1250 => "cp1250",
CharSet::cp1251 => "cp1251",
CharSet::cp1256 => "cp1256",
CharSet::cp1257 => "cp1257",
CharSet::cp850 => "cp850",
CharSet::cp852 => "cp852",
CharSet::cp866 => "cp866",
CharSet::cp932 => "cp932",
CharSet::dec8 => "dec8",
CharSet::eucjpms => "eucjpms",
CharSet::euckr => "euckr",
CharSet::gb18030 => "gb18030",
CharSet::gb2312 => "gb2312",
CharSet::gbk => "gbk",
CharSet::geostd8 => "geostd8",
CharSet::greek => "greek",
CharSet::hebrew => "hebrew",
CharSet::hp8 => "hp8",
CharSet::keybcs2 => "keybcs2",
CharSet::koi8r => "koi8r",
CharSet::koi8u => "koi8u",
CharSet::latin1 => "latin1",
CharSet::latin2 => "latin2",
CharSet::latin5 => "latin5",
CharSet::latin7 => "latin7",
CharSet::macce => "macce",
CharSet::macroman => "macroman",
CharSet::sjis => "sjis",
CharSet::swe7 => "swe7",
CharSet::tis620 => "tis620",
CharSet::ucs2 => "ucs2",
CharSet::ujis => "ujis",
CharSet::utf16 => "utf16",
CharSet::utf16le => "utf16le",
CharSet::utf32 => "utf32",
CharSet::utf8 => "utf8",
CharSet::utf8mb4 => "utf8mb4",
}
}
pub(crate) fn default_collation(&self) -> Collation {
match self {
CharSet::armscii8 => Collation::armscii8_general_ci,
CharSet::ascii => Collation::ascii_general_ci,
CharSet::big5 => Collation::big5_chinese_ci,
CharSet::binary => Collation::binary,
CharSet::cp1250 => Collation::cp1250_general_ci,
CharSet::cp1251 => Collation::cp1251_general_ci,
CharSet::cp1256 => Collation::cp1256_general_ci,
CharSet::cp1257 => Collation::cp1257_general_ci,
CharSet::cp850 => Collation::cp850_general_ci,
CharSet::cp852 => Collation::cp852_general_ci,
CharSet::cp866 => Collation::cp866_general_ci,
CharSet::cp932 => Collation::cp932_japanese_ci,
CharSet::dec8 => Collation::dec8_swedish_ci,
CharSet::eucjpms => Collation::eucjpms_japanese_ci,
CharSet::euckr => Collation::euckr_korean_ci,
CharSet::gb18030 => Collation::gb18030_chinese_ci,
CharSet::gb2312 => Collation::gb2312_chinese_ci,
CharSet::gbk => Collation::gbk_chinese_ci,
CharSet::geostd8 => Collation::geostd8_general_ci,
CharSet::greek => Collation::greek_general_ci,
CharSet::hebrew => Collation::hebrew_general_ci,
CharSet::hp8 => Collation::hp8_english_ci,
CharSet::keybcs2 => Collation::keybcs2_general_ci,
CharSet::koi8r => Collation::koi8r_general_ci,
CharSet::koi8u => Collation::koi8u_general_ci,
CharSet::latin1 => Collation::latin1_swedish_ci,
CharSet::latin2 => Collation::latin2_general_ci,
CharSet::latin5 => Collation::latin5_turkish_ci,
CharSet::latin7 => Collation::latin7_general_ci,
CharSet::macce => Collation::macce_general_ci,
CharSet::macroman => Collation::macroman_general_ci,
CharSet::sjis => Collation::sjis_japanese_ci,
CharSet::swe7 => Collation::swe7_swedish_ci,
CharSet::tis620 => Collation::tis620_thai_ci,
CharSet::ucs2 => Collation::ucs2_general_ci,
CharSet::ujis => Collation::ujis_japanese_ci,
CharSet::utf16 => Collation::utf16_general_ci,
CharSet::utf16le => Collation::utf16le_general_ci,
CharSet::utf32 => Collation::utf32_general_ci,
CharSet::utf8 => Collation::utf8_unicode_ci,
CharSet::utf8mb4 => Collation::utf8mb4_unicode_ci,
}
}
}
impl FromStr for CharSet {
type Err = Error;
fn from_str(char_set: &str) -> Result<Self, Self::Err> {
Ok(match char_set {
"armscii8" => CharSet::armscii8,
"ascii" => CharSet::ascii,
"big5" => CharSet::big5,
"binary" => CharSet::binary,
"cp1250" => CharSet::cp1250,
"cp1251" => CharSet::cp1251,
"cp1256" => CharSet::cp1256,
"cp1257" => CharSet::cp1257,
"cp850" => CharSet::cp850,
"cp852" => CharSet::cp852,
"cp866" => CharSet::cp866,
"cp932" => CharSet::cp932,
"dec8" => CharSet::dec8,
"eucjpms" => CharSet::eucjpms,
"euckr" => CharSet::euckr,
"gb18030" => CharSet::gb18030,
"gb2312" => CharSet::gb2312,
"gbk" => CharSet::gbk,
"geostd8" => CharSet::geostd8,
"greek" => CharSet::greek,
"hebrew" => CharSet::hebrew,
"hp8" => CharSet::hp8,
"keybcs2" => CharSet::keybcs2,
"koi8r" => CharSet::koi8r,
"koi8u" => CharSet::koi8u,
"latin1" => CharSet::latin1,
"latin2" => CharSet::latin2,
"latin5" => CharSet::latin5,
"latin7" => CharSet::latin7,
"macce" => CharSet::macce,
"macroman" => CharSet::macroman,
"sjis" => CharSet::sjis,
"swe7" => CharSet::swe7,
"tis620" => CharSet::tis620,
"ucs2" => CharSet::ucs2,
"ujis" => CharSet::ujis,
"utf16" => CharSet::utf16,
"utf16le" => CharSet::utf16le,
"utf32" => CharSet::utf32,
"utf8" => CharSet::utf8,
"utf8mb4" => CharSet::utf8mb4,
_ => {
return Err(Error::Configuration(
format!("unsupported MySQL charset: {}", char_set).into(),
));
}
})
}
}
#[derive(Copy, Clone)]
#[allow(non_camel_case_types)]
#[repr(u8)]
pub(crate) enum Collation {
armscii8_bin = 64,
armscii8_general_ci = 32,
ascii_bin = 65,
ascii_general_ci = 11,
big5_bin = 84,
big5_chinese_ci = 1,
binary = 63,
cp1250_bin = 66,
cp1250_croatian_ci = 44,
cp1250_czech_cs = 34,
cp1250_general_ci = 26,
cp1250_polish_ci = 99,
cp1251_bin = 50,
cp1251_bulgarian_ci = 14,
cp1251_general_ci = 51,
cp1251_general_cs = 52,
cp1251_ukrainian_ci = 23,
cp1256_bin = 67,
cp1256_general_ci = 57,
cp1257_bin = 58,
cp1257_general_ci = 59,
cp1257_lithuanian_ci = 29,
cp850_bin = 80,
cp850_general_ci = 4,
cp852_bin = 81,
cp852_general_ci = 40,
cp866_bin = 68,
cp866_general_ci = 36,
cp932_bin = 96,
cp932_japanese_ci = 95,
dec8_bin = 69,
dec8_swedish_ci = 3,
eucjpms_bin = 98,
eucjpms_japanese_ci = 97,
euckr_bin = 85,
euckr_korean_ci = 19,
gb18030_bin = 249,
gb18030_chinese_ci = 248,
gb18030_unicode_520_ci = 250,
gb2312_bin = 86,
gb2312_chinese_ci = 24,
gbk_bin = 87,
gbk_chinese_ci = 28,
geostd8_bin = 93,
geostd8_general_ci = 92,
greek_bin = 70,
greek_general_ci = 25,
hebrew_bin = 71,
hebrew_general_ci = 16,
hp8_bin = 72,
hp8_english_ci = 6,
keybcs2_bin = 73,
keybcs2_general_ci = 37,
koi8r_bin = 74,
koi8r_general_ci = 7,
koi8u_bin = 75,
koi8u_general_ci = 22,
latin1_bin = 47,
latin1_danish_ci = 15,
latin1_general_ci = 48,
latin1_general_cs = 49,
latin1_german1_ci = 5,
latin1_german2_ci = 31,
latin1_spanish_ci = 94,
latin1_swedish_ci = 8,
latin2_bin = 77,
latin2_croatian_ci = 27,
latin2_czech_cs = 2,
latin2_general_ci = 9,
latin2_hungarian_ci = 21,
latin5_bin = 78,
latin5_turkish_ci = 30,
latin7_bin = 79,
latin7_estonian_cs = 20,
latin7_general_ci = 41,
latin7_general_cs = 42,
macce_bin = 43,
macce_general_ci = 38,
macroman_bin = 53,
macroman_general_ci = 39,
sjis_bin = 88,
sjis_japanese_ci = 13,
swe7_bin = 82,
swe7_swedish_ci = 10,
tis620_bin = 89,
tis620_thai_ci = 18,
ucs2_bin = 90,
ucs2_croatian_ci = 149,
ucs2_czech_ci = 138,
ucs2_danish_ci = 139,
ucs2_esperanto_ci = 145,
ucs2_estonian_ci = 134,
ucs2_general_ci = 35,
ucs2_general_mysql500_ci = 159,
ucs2_german2_ci = 148,
ucs2_hungarian_ci = 146,
ucs2_icelandic_ci = 129,
ucs2_latvian_ci = 130,
ucs2_lithuanian_ci = 140,
ucs2_persian_ci = 144,
ucs2_polish_ci = 133,
ucs2_roman_ci = 143,
ucs2_romanian_ci = 131,
ucs2_sinhala_ci = 147,
ucs2_slovak_ci = 141,
ucs2_slovenian_ci = 132,
ucs2_spanish_ci = 135,
ucs2_spanish2_ci = 142,
ucs2_swedish_ci = 136,
ucs2_turkish_ci = 137,
ucs2_unicode_520_ci = 150,
ucs2_unicode_ci = 128,
ucs2_vietnamese_ci = 151,
ujis_bin = 91,
ujis_japanese_ci = 12,
utf16_bin = 55,
utf16_croatian_ci = 122,
utf16_czech_ci = 111,
utf16_danish_ci = 112,
utf16_esperanto_ci = 118,
utf16_estonian_ci = 107,
utf16_general_ci = 54,
utf16_german2_ci = 121,
utf16_hungarian_ci = 119,
utf16_icelandic_ci = 102,
utf16_latvian_ci = 103,
utf16_lithuanian_ci = 113,
utf16_persian_ci = 117,
utf16_polish_ci = 106,
utf16_roman_ci = 116,
utf16_romanian_ci = 104,
utf16_sinhala_ci = 120,
utf16_slovak_ci = 114,
utf16_slovenian_ci = 105,
utf16_spanish_ci = 108,
utf16_spanish2_ci = 115,
utf16_swedish_ci = 109,
utf16_turkish_ci = 110,
utf16_unicode_520_ci = 123,
utf16_unicode_ci = 101,
utf16_vietnamese_ci = 124,
utf16le_bin = 62,
utf16le_general_ci = 56,
utf32_bin = 61,
utf32_croatian_ci = 181,
utf32_czech_ci = 170,
utf32_danish_ci = 171,
utf32_esperanto_ci = 177,
utf32_estonian_ci = 166,
utf32_general_ci = 60,
utf32_german2_ci = 180,
utf32_hungarian_ci = 178,
utf32_icelandic_ci = 161,
utf32_latvian_ci = 162,
utf32_lithuanian_ci = 172,
utf32_persian_ci = 176,
utf32_polish_ci = 165,
utf32_roman_ci = 175,
utf32_romanian_ci = 163,
utf32_sinhala_ci = 179,
utf32_slovak_ci = 173,
utf32_slovenian_ci = 164,
utf32_spanish_ci = 167,
utf32_spanish2_ci = 174,
utf32_swedish_ci = 168,
utf32_turkish_ci = 169,
utf32_unicode_520_ci = 182,
utf32_unicode_ci = 160,
utf32_vietnamese_ci = 183,
utf8_bin = 83,
utf8_croatian_ci = 213,
utf8_czech_ci = 202,
utf8_danish_ci = 203,
utf8_esperanto_ci = 209,
utf8_estonian_ci = 198,
utf8_general_ci = 33,
utf8_general_mysql500_ci = 223,
utf8_german2_ci = 212,
utf8_hungarian_ci = 210,
utf8_icelandic_ci = 193,
utf8_latvian_ci = 194,
utf8_lithuanian_ci = 204,
utf8_persian_ci = 208,
utf8_polish_ci = 197,
utf8_roman_ci = 207,
utf8_romanian_ci = 195,
utf8_sinhala_ci = 211,
utf8_slovak_ci = 205,
utf8_slovenian_ci = 196,
utf8_spanish_ci = 199,
utf8_spanish2_ci = 206,
utf8_swedish_ci = 200,
utf8_tolower_ci = 76,
utf8_turkish_ci = 201,
utf8_unicode_520_ci = 214,
utf8_unicode_ci = 192,
utf8_vietnamese_ci = 215,
utf8mb4_0900_ai_ci = 255,
utf8mb4_bin = 46,
utf8mb4_croatian_ci = 245,
utf8mb4_czech_ci = 234,
utf8mb4_danish_ci = 235,
utf8mb4_esperanto_ci = 241,
utf8mb4_estonian_ci = 230,
utf8mb4_general_ci = 45,
utf8mb4_german2_ci = 244,
utf8mb4_hungarian_ci = 242,
utf8mb4_icelandic_ci = 225,
utf8mb4_latvian_ci = 226,
utf8mb4_lithuanian_ci = 236,
utf8mb4_persian_ci = 240,
utf8mb4_polish_ci = 229,
utf8mb4_roman_ci = 239,
utf8mb4_romanian_ci = 227,
utf8mb4_sinhala_ci = 243,
utf8mb4_slovak_ci = 237,
utf8mb4_slovenian_ci = 228,
utf8mb4_spanish_ci = 231,
utf8mb4_spanish2_ci = 238,
utf8mb4_swedish_ci = 232,
utf8mb4_turkish_ci = 233,
utf8mb4_unicode_520_ci = 246,
utf8mb4_unicode_ci = 224,
utf8mb4_vietnamese_ci = 247,
}
impl Collation {
pub(crate) fn as_str(&self) -> &'static str {
match self {
Collation::armscii8_bin => "armscii8_bin",
Collation::armscii8_general_ci => "armscii8_general_ci",
Collation::ascii_bin => "ascii_bin",
Collation::ascii_general_ci => "ascii_general_ci",
Collation::big5_bin => "big5_bin",
Collation::big5_chinese_ci => "big5_chinese_ci",
Collation::binary => "binary",
Collation::cp1250_bin => "cp1250_bin",
Collation::cp1250_croatian_ci => "cp1250_croatian_ci",
Collation::cp1250_czech_cs => "cp1250_czech_cs",
Collation::cp1250_general_ci => "cp1250_general_ci",
Collation::cp1250_polish_ci => "cp1250_polish_ci",
Collation::cp1251_bin => "cp1251_bin",
Collation::cp1251_bulgarian_ci => "cp1251_bulgarian_ci",
Collation::cp1251_general_ci => "cp1251_general_ci",
Collation::cp1251_general_cs => "cp1251_general_cs",
Collation::cp1251_ukrainian_ci => "cp1251_ukrainian_ci",
Collation::cp1256_bin => "cp1256_bin",
Collation::cp1256_general_ci => "cp1256_general_ci",
Collation::cp1257_bin => "cp1257_bin",
Collation::cp1257_general_ci => "cp1257_general_ci",
Collation::cp1257_lithuanian_ci => "cp1257_lithuanian_ci",
Collation::cp850_bin => "cp850_bin",
Collation::cp850_general_ci => "cp850_general_ci",
Collation::cp852_bin => "cp852_bin",
Collation::cp852_general_ci => "cp852_general_ci",
Collation::cp866_bin => "cp866_bin",
Collation::cp866_general_ci => "cp866_general_ci",
Collation::cp932_bin => "cp932_bin",
Collation::cp932_japanese_ci => "cp932_japanese_ci",
Collation::dec8_bin => "dec8_bin",
Collation::dec8_swedish_ci => "dec8_swedish_ci",
Collation::eucjpms_bin => "eucjpms_bin",
Collation::eucjpms_japanese_ci => "eucjpms_japanese_ci",
Collation::euckr_bin => "euckr_bin",
Collation::euckr_korean_ci => "euckr_korean_ci",
Collation::gb18030_bin => "gb18030_bin",
Collation::gb18030_chinese_ci => "gb18030_chinese_ci",
Collation::gb18030_unicode_520_ci => "gb18030_unicode_520_ci",
Collation::gb2312_bin => "gb2312_bin",
Collation::gb2312_chinese_ci => "gb2312_chinese_ci",
Collation::gbk_bin => "gbk_bin",
Collation::gbk_chinese_ci => "gbk_chinese_ci",
Collation::geostd8_bin => "geostd8_bin",
Collation::geostd8_general_ci => "geostd8_general_ci",
Collation::greek_bin => "greek_bin",
Collation::greek_general_ci => "greek_general_ci",
Collation::hebrew_bin => "hebrew_bin",
Collation::hebrew_general_ci => "hebrew_general_ci",
Collation::hp8_bin => "hp8_bin",
Collation::hp8_english_ci => "hp8_english_ci",
Collation::keybcs2_bin => "keybcs2_bin",
Collation::keybcs2_general_ci => "keybcs2_general_ci",
Collation::koi8r_bin => "koi8r_bin",
Collation::koi8r_general_ci => "koi8r_general_ci",
Collation::koi8u_bin => "koi8u_bin",
Collation::koi8u_general_ci => "koi8u_general_ci",
Collation::latin1_bin => "latin1_bin",
Collation::latin1_danish_ci => "latin1_danish_ci",
Collation::latin1_general_ci => "latin1_general_ci",
Collation::latin1_general_cs => "latin1_general_cs",
Collation::latin1_german1_ci => "latin1_german1_ci",
Collation::latin1_german2_ci => "latin1_german2_ci",
Collation::latin1_spanish_ci => "latin1_spanish_ci",
Collation::latin1_swedish_ci => "latin1_swedish_ci",
Collation::latin2_bin => "latin2_bin",
Collation::latin2_croatian_ci => "latin2_croatian_ci",
Collation::latin2_czech_cs => "latin2_czech_cs",
Collation::latin2_general_ci => "latin2_general_ci",
Collation::latin2_hungarian_ci => "latin2_hungarian_ci",
Collation::latin5_bin => "latin5_bin",
Collation::latin5_turkish_ci => "latin5_turkish_ci",
Collation::latin7_bin => "latin7_bin",
Collation::latin7_estonian_cs => "latin7_estonian_cs",
Collation::latin7_general_ci => "latin7_general_ci",
Collation::latin7_general_cs => "latin7_general_cs",
Collation::macce_bin => "macce_bin",
Collation::macce_general_ci => "macce_general_ci",
Collation::macroman_bin => "macroman_bin",
Collation::macroman_general_ci => "macroman_general_ci",
Collation::sjis_bin => "sjis_bin",
Collation::sjis_japanese_ci => "sjis_japanese_ci",
Collation::swe7_bin => "swe7_bin",
Collation::swe7_swedish_ci => "swe7_swedish_ci",
Collation::tis620_bin => "tis620_bin",
Collation::tis620_thai_ci => "tis620_thai_ci",
Collation::ucs2_bin => "ucs2_bin",
Collation::ucs2_croatian_ci => "ucs2_croatian_ci",
Collation::ucs2_czech_ci => "ucs2_czech_ci",
Collation::ucs2_danish_ci => "ucs2_danish_ci",
Collation::ucs2_esperanto_ci => "ucs2_esperanto_ci",
Collation::ucs2_estonian_ci => "ucs2_estonian_ci",
Collation::ucs2_general_ci => "ucs2_general_ci",
Collation::ucs2_general_mysql500_ci => "ucs2_general_mysql500_ci",
Collation::ucs2_german2_ci => "ucs2_german2_ci",
Collation::ucs2_hungarian_ci => "ucs2_hungarian_ci",
Collation::ucs2_icelandic_ci => "ucs2_icelandic_ci",
Collation::ucs2_latvian_ci => "ucs2_latvian_ci",
Collation::ucs2_lithuanian_ci => "ucs2_lithuanian_ci",
Collation::ucs2_persian_ci => "ucs2_persian_ci",
Collation::ucs2_polish_ci => "ucs2_polish_ci",
Collation::ucs2_roman_ci => "ucs2_roman_ci",
Collation::ucs2_romanian_ci => "ucs2_romanian_ci",
Collation::ucs2_sinhala_ci => "ucs2_sinhala_ci",
Collation::ucs2_slovak_ci => "ucs2_slovak_ci",
Collation::ucs2_slovenian_ci => "ucs2_slovenian_ci",
Collation::ucs2_spanish_ci => "ucs2_spanish_ci",
Collation::ucs2_spanish2_ci => "ucs2_spanish2_ci",
Collation::ucs2_swedish_ci => "ucs2_swedish_ci",
Collation::ucs2_turkish_ci => "ucs2_turkish_ci",
Collation::ucs2_unicode_520_ci => "ucs2_unicode_520_ci",
Collation::ucs2_unicode_ci => "ucs2_unicode_ci",
Collation::ucs2_vietnamese_ci => "ucs2_vietnamese_ci",
Collation::ujis_bin => "ujis_bin",
Collation::ujis_japanese_ci => "ujis_japanese_ci",
Collation::utf16_bin => "utf16_bin",
Collation::utf16_croatian_ci => "utf16_croatian_ci",
Collation::utf16_czech_ci => "utf16_czech_ci",
Collation::utf16_danish_ci => "utf16_danish_ci",
Collation::utf16_esperanto_ci => "utf16_esperanto_ci",
Collation::utf16_estonian_ci => "utf16_estonian_ci",
Collation::utf16_general_ci => "utf16_general_ci",
Collation::utf16_german2_ci => "utf16_german2_ci",
Collation::utf16_hungarian_ci => "utf16_hungarian_ci",
Collation::utf16_icelandic_ci => "utf16_icelandic_ci",
Collation::utf16_latvian_ci => "utf16_latvian_ci",
Collation::utf16_lithuanian_ci => "utf16_lithuanian_ci",
Collation::utf16_persian_ci => "utf16_persian_ci",
Collation::utf16_polish_ci => "utf16_polish_ci",
Collation::utf16_roman_ci => "utf16_roman_ci",
Collation::utf16_romanian_ci => "utf16_romanian_ci",
Collation::utf16_sinhala_ci => "utf16_sinhala_ci",
Collation::utf16_slovak_ci => "utf16_slovak_ci",
Collation::utf16_slovenian_ci => "utf16_slovenian_ci",
Collation::utf16_spanish_ci => "utf16_spanish_ci",
Collation::utf16_spanish2_ci => "utf16_spanish2_ci",
Collation::utf16_swedish_ci => "utf16_swedish_ci",
Collation::utf16_turkish_ci => "utf16_turkish_ci",
Collation::utf16_unicode_520_ci => "utf16_unicode_520_ci",
Collation::utf16_unicode_ci => "utf16_unicode_ci",
Collation::utf16_vietnamese_ci => "utf16_vietnamese_ci",
Collation::utf16le_bin => "utf16le_bin",
Collation::utf16le_general_ci => "utf16le_general_ci",
Collation::utf32_bin => "utf32_bin",
Collation::utf32_croatian_ci => "utf32_croatian_ci",
Collation::utf32_czech_ci => "utf32_czech_ci",
Collation::utf32_danish_ci => "utf32_danish_ci",
Collation::utf32_esperanto_ci => "utf32_esperanto_ci",
Collation::utf32_estonian_ci => "utf32_estonian_ci",
Collation::utf32_general_ci => "utf32_general_ci",
Collation::utf32_german2_ci => "utf32_german2_ci",
Collation::utf32_hungarian_ci => "utf32_hungarian_ci",
Collation::utf32_icelandic_ci => "utf32_icelandic_ci",
Collation::utf32_latvian_ci => "utf32_latvian_ci",
Collation::utf32_lithuanian_ci => "utf32_lithuanian_ci",
Collation::utf32_persian_ci => "utf32_persian_ci",
Collation::utf32_polish_ci => "utf32_polish_ci",
Collation::utf32_roman_ci => "utf32_roman_ci",
Collation::utf32_romanian_ci => "utf32_romanian_ci",
Collation::utf32_sinhala_ci => "utf32_sinhala_ci",
Collation::utf32_slovak_ci => "utf32_slovak_ci",
Collation::utf32_slovenian_ci => "utf32_slovenian_ci",
Collation::utf32_spanish_ci => "utf32_spanish_ci",
Collation::utf32_spanish2_ci => "utf32_spanish2_ci",
Collation::utf32_swedish_ci => "utf32_swedish_ci",
Collation::utf32_turkish_ci => "utf32_turkish_ci",
Collation::utf32_unicode_520_ci => "utf32_unicode_520_ci",
Collation::utf32_unicode_ci => "utf32_unicode_ci",
Collation::utf32_vietnamese_ci => "utf32_vietnamese_ci",
Collation::utf8_bin => "utf8_bin",
Collation::utf8_croatian_ci => "utf8_croatian_ci",
Collation::utf8_czech_ci => "utf8_czech_ci",
Collation::utf8_danish_ci => "utf8_danish_ci",
Collation::utf8_esperanto_ci => "utf8_esperanto_ci",
Collation::utf8_estonian_ci => "utf8_estonian_ci",
Collation::utf8_general_ci => "utf8_general_ci",
Collation::utf8_general_mysql500_ci => "utf8_general_mysql500_ci",
Collation::utf8_german2_ci => "utf8_german2_ci",
Collation::utf8_hungarian_ci => "utf8_hungarian_ci",
Collation::utf8_icelandic_ci => "utf8_icelandic_ci",
Collation::utf8_latvian_ci => "utf8_latvian_ci",
Collation::utf8_lithuanian_ci => "utf8_lithuanian_ci",
Collation::utf8_persian_ci => "utf8_persian_ci",
Collation::utf8_polish_ci => "utf8_polish_ci",
Collation::utf8_roman_ci => "utf8_roman_ci",
Collation::utf8_romanian_ci => "utf8_romanian_ci",
Collation::utf8_sinhala_ci => "utf8_sinhala_ci",
Collation::utf8_slovak_ci => "utf8_slovak_ci",
Collation::utf8_slovenian_ci => "utf8_slovenian_ci",
Collation::utf8_spanish_ci => "utf8_spanish_ci",
Collation::utf8_spanish2_ci => "utf8_spanish2_ci",
Collation::utf8_swedish_ci => "utf8_swedish_ci",
Collation::utf8_tolower_ci => "utf8_tolower_ci",
Collation::utf8_turkish_ci => "utf8_turkish_ci",
Collation::utf8_unicode_520_ci => "utf8_unicode_520_ci",
Collation::utf8_unicode_ci => "utf8_unicode_ci",
Collation::utf8_vietnamese_ci => "utf8_vietnamese_ci",
Collation::utf8mb4_0900_ai_ci => "utf8mb4_0900_ai_ci",
Collation::utf8mb4_bin => "utf8mb4_bin",
Collation::utf8mb4_croatian_ci => "utf8mb4_croatian_ci",
Collation::utf8mb4_czech_ci => "utf8mb4_czech_ci",
Collation::utf8mb4_danish_ci => "utf8mb4_danish_ci",
Collation::utf8mb4_esperanto_ci => "utf8mb4_esperanto_ci",
Collation::utf8mb4_estonian_ci => "utf8mb4_estonian_ci",
Collation::utf8mb4_general_ci => "utf8mb4_general_ci",
Collation::utf8mb4_german2_ci => "utf8mb4_german2_ci",
Collation::utf8mb4_hungarian_ci => "utf8mb4_hungarian_ci",
Collation::utf8mb4_icelandic_ci => "utf8mb4_icelandic_ci",
Collation::utf8mb4_latvian_ci => "utf8mb4_latvian_ci",
Collation::utf8mb4_lithuanian_ci => "utf8mb4_lithuanian_ci",
Collation::utf8mb4_persian_ci => "utf8mb4_persian_ci",
Collation::utf8mb4_polish_ci => "utf8mb4_polish_ci",
Collation::utf8mb4_roman_ci => "utf8mb4_roman_ci",
Collation::utf8mb4_romanian_ci => "utf8mb4_romanian_ci",
Collation::utf8mb4_sinhala_ci => "utf8mb4_sinhala_ci",
Collation::utf8mb4_slovak_ci => "utf8mb4_slovak_ci",
Collation::utf8mb4_slovenian_ci => "utf8mb4_slovenian_ci",
Collation::utf8mb4_spanish_ci => "utf8mb4_spanish_ci",
Collation::utf8mb4_spanish2_ci => "utf8mb4_spanish2_ci",
Collation::utf8mb4_swedish_ci => "utf8mb4_swedish_ci",
Collation::utf8mb4_turkish_ci => "utf8mb4_turkish_ci",
Collation::utf8mb4_unicode_520_ci => "utf8mb4_unicode_520_ci",
Collation::utf8mb4_unicode_ci => "utf8mb4_unicode_ci",
Collation::utf8mb4_vietnamese_ci => "utf8mb4_vietnamese_ci",
}
}
}
// Handshake packet have only 1 byte for collation_id.
// So we can't use collations with ID > 255.
impl FromStr for Collation {
type Err = Error;
fn from_str(collation: &str) -> Result<Self, Self::Err> {
Ok(match collation {
"big5_chinese_ci" => Collation::big5_chinese_ci,
"swe7_swedish_ci" => Collation::swe7_swedish_ci,
"utf16_unicode_ci" => Collation::utf16_unicode_ci,
"utf16_icelandic_ci" => Collation::utf16_icelandic_ci,
"utf16_latvian_ci" => Collation::utf16_latvian_ci,
"utf16_romanian_ci" => Collation::utf16_romanian_ci,
"utf16_slovenian_ci" => Collation::utf16_slovenian_ci,
"utf16_polish_ci" => Collation::utf16_polish_ci,
"utf16_estonian_ci" => Collation::utf16_estonian_ci,
"utf16_spanish_ci" => Collation::utf16_spanish_ci,
"utf16_swedish_ci" => Collation::utf16_swedish_ci,
"ascii_general_ci" => Collation::ascii_general_ci,
"utf16_turkish_ci" => Collation::utf16_turkish_ci,
"utf16_czech_ci" => Collation::utf16_czech_ci,
"utf16_danish_ci" => Collation::utf16_danish_ci,
"utf16_lithuanian_ci" => Collation::utf16_lithuanian_ci,
"utf16_slovak_ci" => Collation::utf16_slovak_ci,
"utf16_spanish2_ci" => Collation::utf16_spanish2_ci,
"utf16_roman_ci" => Collation::utf16_roman_ci,
"utf16_persian_ci" => Collation::utf16_persian_ci,
"utf16_esperanto_ci" => Collation::utf16_esperanto_ci,
"utf16_hungarian_ci" => Collation::utf16_hungarian_ci,
"ujis_japanese_ci" => Collation::ujis_japanese_ci,
"utf16_sinhala_ci" => Collation::utf16_sinhala_ci,
"utf16_german2_ci" => Collation::utf16_german2_ci,
"utf16_croatian_ci" => Collation::utf16_croatian_ci,
"utf16_unicode_520_ci" => Collation::utf16_unicode_520_ci,
"utf16_vietnamese_ci" => Collation::utf16_vietnamese_ci,
"ucs2_unicode_ci" => Collation::ucs2_unicode_ci,
"ucs2_icelandic_ci" => Collation::ucs2_icelandic_ci,
"sjis_japanese_ci" => Collation::sjis_japanese_ci,
"ucs2_latvian_ci" => Collation::ucs2_latvian_ci,
"ucs2_romanian_ci" => Collation::ucs2_romanian_ci,
"ucs2_slovenian_ci" => Collation::ucs2_slovenian_ci,
"ucs2_polish_ci" => Collation::ucs2_polish_ci,
"ucs2_estonian_ci" => Collation::ucs2_estonian_ci,
"ucs2_spanish_ci" => Collation::ucs2_spanish_ci,
"ucs2_swedish_ci" => Collation::ucs2_swedish_ci,
"ucs2_turkish_ci" => Collation::ucs2_turkish_ci,
"ucs2_czech_ci" => Collation::ucs2_czech_ci,
"ucs2_danish_ci" => Collation::ucs2_danish_ci,
"cp1251_bulgarian_ci" => Collation::cp1251_bulgarian_ci,
"ucs2_lithuanian_ci" => Collation::ucs2_lithuanian_ci,
"ucs2_slovak_ci" => Collation::ucs2_slovak_ci,
"ucs2_spanish2_ci" => Collation::ucs2_spanish2_ci,
"ucs2_roman_ci" => Collation::ucs2_roman_ci,
"ucs2_persian_ci" => Collation::ucs2_persian_ci,
"ucs2_esperanto_ci" => Collation::ucs2_esperanto_ci,
"ucs2_hungarian_ci" => Collation::ucs2_hungarian_ci,
"ucs2_sinhala_ci" => Collation::ucs2_sinhala_ci,
"ucs2_german2_ci" => Collation::ucs2_german2_ci,
"ucs2_croatian_ci" => Collation::ucs2_croatian_ci,
"latin1_danish_ci" => Collation::latin1_danish_ci,
"ucs2_unicode_520_ci" => Collation::ucs2_unicode_520_ci,
"ucs2_vietnamese_ci" => Collation::ucs2_vietnamese_ci,
"ucs2_general_mysql500_ci" => Collation::ucs2_general_mysql500_ci,
"hebrew_general_ci" => Collation::hebrew_general_ci,
"utf32_unicode_ci" => Collation::utf32_unicode_ci,
"utf32_icelandic_ci" => Collation::utf32_icelandic_ci,
"utf32_latvian_ci" => Collation::utf32_latvian_ci,
"utf32_romanian_ci" => Collation::utf32_romanian_ci,
"utf32_slovenian_ci" => Collation::utf32_slovenian_ci,
"utf32_polish_ci" => Collation::utf32_polish_ci,
"utf32_estonian_ci" => Collation::utf32_estonian_ci,
"utf32_spanish_ci" => Collation::utf32_spanish_ci,
"utf32_swedish_ci" => Collation::utf32_swedish_ci,
"utf32_turkish_ci" => Collation::utf32_turkish_ci,
"utf32_czech_ci" => Collation::utf32_czech_ci,
"utf32_danish_ci" => Collation::utf32_danish_ci,
"utf32_lithuanian_ci" => Collation::utf32_lithuanian_ci,
"utf32_slovak_ci" => Collation::utf32_slovak_ci,
"utf32_spanish2_ci" => Collation::utf32_spanish2_ci,
"utf32_roman_ci" => Collation::utf32_roman_ci,
"utf32_persian_ci" => Collation::utf32_persian_ci,
"utf32_esperanto_ci" => Collation::utf32_esperanto_ci,
"utf32_hungarian_ci" => Collation::utf32_hungarian_ci,
"utf32_sinhala_ci" => Collation::utf32_sinhala_ci,
"tis620_thai_ci" => Collation::tis620_thai_ci,
"utf32_german2_ci" => Collation::utf32_german2_ci,
"utf32_croatian_ci" => Collation::utf32_croatian_ci,
"utf32_unicode_520_ci" => Collation::utf32_unicode_520_ci,
"utf32_vietnamese_ci" => Collation::utf32_vietnamese_ci,
"euckr_korean_ci" => Collation::euckr_korean_ci,
"utf8_unicode_ci" => Collation::utf8_unicode_ci,
"utf8_icelandic_ci" => Collation::utf8_icelandic_ci,
"utf8_latvian_ci" => Collation::utf8_latvian_ci,
"utf8_romanian_ci" => Collation::utf8_romanian_ci,
"utf8_slovenian_ci" => Collation::utf8_slovenian_ci,
"utf8_polish_ci" => Collation::utf8_polish_ci,
"utf8_estonian_ci" => Collation::utf8_estonian_ci,
"utf8_spanish_ci" => Collation::utf8_spanish_ci,
"latin2_czech_cs" => Collation::latin2_czech_cs,
"latin7_estonian_cs" => Collation::latin7_estonian_cs,
"utf8_swedish_ci" => Collation::utf8_swedish_ci,
"utf8_turkish_ci" => Collation::utf8_turkish_ci,
"utf8_czech_ci" => Collation::utf8_czech_ci,
"utf8_danish_ci" => Collation::utf8_danish_ci,
"utf8_lithuanian_ci" => Collation::utf8_lithuanian_ci,
"utf8_slovak_ci" => Collation::utf8_slovak_ci,
"utf8_spanish2_ci" => Collation::utf8_spanish2_ci,
"utf8_roman_ci" => Collation::utf8_roman_ci,
"utf8_persian_ci" => Collation::utf8_persian_ci,
"utf8_esperanto_ci" => Collation::utf8_esperanto_ci,
"latin2_hungarian_ci" => Collation::latin2_hungarian_ci,
"utf8_hungarian_ci" => Collation::utf8_hungarian_ci,
"utf8_sinhala_ci" => Collation::utf8_sinhala_ci,
"utf8_german2_ci" => Collation::utf8_german2_ci,
"utf8_croatian_ci" => Collation::utf8_croatian_ci,
"utf8_unicode_520_ci" => Collation::utf8_unicode_520_ci,
"utf8_vietnamese_ci" => Collation::utf8_vietnamese_ci,
"koi8u_general_ci" => Collation::koi8u_general_ci,
"utf8_general_mysql500_ci" => Collation::utf8_general_mysql500_ci,
"utf8mb4_unicode_ci" => Collation::utf8mb4_unicode_ci,
"utf8mb4_icelandic_ci" => Collation::utf8mb4_icelandic_ci,
"utf8mb4_latvian_ci" => Collation::utf8mb4_latvian_ci,
"utf8mb4_romanian_ci" => Collation::utf8mb4_romanian_ci,
"utf8mb4_slovenian_ci" => Collation::utf8mb4_slovenian_ci,
"utf8mb4_polish_ci" => Collation::utf8mb4_polish_ci,
"cp1251_ukrainian_ci" => Collation::cp1251_ukrainian_ci,
"utf8mb4_estonian_ci" => Collation::utf8mb4_estonian_ci,
"utf8mb4_spanish_ci" => Collation::utf8mb4_spanish_ci,
"utf8mb4_swedish_ci" => Collation::utf8mb4_swedish_ci,
"utf8mb4_turkish_ci" => Collation::utf8mb4_turkish_ci,
"utf8mb4_czech_ci" => Collation::utf8mb4_czech_ci,
"utf8mb4_danish_ci" => Collation::utf8mb4_danish_ci,
"utf8mb4_lithuanian_ci" => Collation::utf8mb4_lithuanian_ci,
"utf8mb4_slovak_ci" => Collation::utf8mb4_slovak_ci,
"utf8mb4_spanish2_ci" => Collation::utf8mb4_spanish2_ci,
"utf8mb4_roman_ci" => Collation::utf8mb4_roman_ci,
"gb2312_chinese_ci" => Collation::gb2312_chinese_ci,
"utf8mb4_persian_ci" => Collation::utf8mb4_persian_ci,
"utf8mb4_esperanto_ci" => Collation::utf8mb4_esperanto_ci,
"utf8mb4_hungarian_ci" => Collation::utf8mb4_hungarian_ci,
"utf8mb4_sinhala_ci" => Collation::utf8mb4_sinhala_ci,
"utf8mb4_german2_ci" => Collation::utf8mb4_german2_ci,
"utf8mb4_croatian_ci" => Collation::utf8mb4_croatian_ci,
"utf8mb4_unicode_520_ci" => Collation::utf8mb4_unicode_520_ci,
"utf8mb4_vietnamese_ci" => Collation::utf8mb4_vietnamese_ci,
"gb18030_chinese_ci" => Collation::gb18030_chinese_ci,
"gb18030_bin" => Collation::gb18030_bin,
"greek_general_ci" => Collation::greek_general_ci,
"gb18030_unicode_520_ci" => Collation::gb18030_unicode_520_ci,
"utf8mb4_0900_ai_ci" => Collation::utf8mb4_0900_ai_ci,
"cp1250_general_ci" => Collation::cp1250_general_ci,
"latin2_croatian_ci" => Collation::latin2_croatian_ci,
"gbk_chinese_ci" => Collation::gbk_chinese_ci,
"cp1257_lithuanian_ci" => Collation::cp1257_lithuanian_ci,
"dec8_swedish_ci" => Collation::dec8_swedish_ci,
"latin5_turkish_ci" => Collation::latin5_turkish_ci,
"latin1_german2_ci" => Collation::latin1_german2_ci,
"armscii8_general_ci" => Collation::armscii8_general_ci,
"utf8_general_ci" => Collation::utf8_general_ci,
"cp1250_czech_cs" => Collation::cp1250_czech_cs,
"ucs2_general_ci" => Collation::ucs2_general_ci,
"cp866_general_ci" => Collation::cp866_general_ci,
"keybcs2_general_ci" => Collation::keybcs2_general_ci,
"macce_general_ci" => Collation::macce_general_ci,
"macroman_general_ci" => Collation::macroman_general_ci,
"cp850_general_ci" => Collation::cp850_general_ci,
"cp852_general_ci" => Collation::cp852_general_ci,
"latin7_general_ci" => Collation::latin7_general_ci,
"latin7_general_cs" => Collation::latin7_general_cs,
"macce_bin" => Collation::macce_bin,
"cp1250_croatian_ci" => Collation::cp1250_croatian_ci,
"utf8mb4_general_ci" => Collation::utf8mb4_general_ci,
"utf8mb4_bin" => Collation::utf8mb4_bin,
"latin1_bin" => Collation::latin1_bin,
"latin1_general_ci" => Collation::latin1_general_ci,
"latin1_general_cs" => Collation::latin1_general_cs,
"latin1_german1_ci" => Collation::latin1_german1_ci,
"cp1251_bin" => Collation::cp1251_bin,
"cp1251_general_ci" => Collation::cp1251_general_ci,
"cp1251_general_cs" => Collation::cp1251_general_cs,
"macroman_bin" => Collation::macroman_bin,
"utf16_general_ci" => Collation::utf16_general_ci,
"utf16_bin" => Collation::utf16_bin,
"utf16le_general_ci" => Collation::utf16le_general_ci,
"cp1256_general_ci" => Collation::cp1256_general_ci,
"cp1257_bin" => Collation::cp1257_bin,
"cp1257_general_ci" => Collation::cp1257_general_ci,
"hp8_english_ci" => Collation::hp8_english_ci,
"utf32_general_ci" => Collation::utf32_general_ci,
"utf32_bin" => Collation::utf32_bin,
"utf16le_bin" => Collation::utf16le_bin,
"binary" => Collation::binary,
"armscii8_bin" => Collation::armscii8_bin,
"ascii_bin" => Collation::ascii_bin,
"cp1250_bin" => Collation::cp1250_bin,
"cp1256_bin" => Collation::cp1256_bin,
"cp866_bin" => Collation::cp866_bin,
"dec8_bin" => Collation::dec8_bin,
"koi8r_general_ci" => Collation::koi8r_general_ci,
"greek_bin" => Collation::greek_bin,
"hebrew_bin" => Collation::hebrew_bin,
"hp8_bin" => Collation::hp8_bin,
"keybcs2_bin" => Collation::keybcs2_bin,
"koi8r_bin" => Collation::koi8r_bin,
"koi8u_bin" => Collation::koi8u_bin,
"utf8_tolower_ci" => Collation::utf8_tolower_ci,
"latin2_bin" => Collation::latin2_bin,
"latin5_bin" => Collation::latin5_bin,
"latin7_bin" => Collation::latin7_bin,
"latin1_swedish_ci" => Collation::latin1_swedish_ci,
"cp850_bin" => Collation::cp850_bin,
"cp852_bin" => Collation::cp852_bin,
"swe7_bin" => Collation::swe7_bin,
"utf8_bin" => Collation::utf8_bin,
"big5_bin" => Collation::big5_bin,
"euckr_bin" => Collation::euckr_bin,
"gb2312_bin" => Collation::gb2312_bin,
"gbk_bin" => Collation::gbk_bin,
"sjis_bin" => Collation::sjis_bin,
"tis620_bin" => Collation::tis620_bin,
"latin2_general_ci" => Collation::latin2_general_ci,
"ucs2_bin" => Collation::ucs2_bin,
"ujis_bin" => Collation::ujis_bin,
"geostd8_general_ci" => Collation::geostd8_general_ci,
"geostd8_bin" => Collation::geostd8_bin,
"latin1_spanish_ci" => Collation::latin1_spanish_ci,
"cp932_japanese_ci" => Collation::cp932_japanese_ci,
"cp932_bin" => Collation::cp932_bin,
"eucjpms_japanese_ci" => Collation::eucjpms_japanese_ci,
"eucjpms_bin" => Collation::eucjpms_bin,
"cp1250_polish_ci" => Collation::cp1250_polish_ci,
_ => {
return Err(Error::Configuration(
format!("unsupported MySQL collation: {}", collation).into(),
));
}
})
}
}

View file

@ -0,0 +1,40 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::BufExt;
pub trait MySqlBufExt: Buf {
// Read a length-encoded integer.
// NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL.
// NOTE: 0xff is only returned during a result set to indicate ERR.
// <https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger>
fn get_uint_lenenc(&mut self) -> u64;
// Read a length-encoded string.
fn get_str_lenenc(&mut self) -> Result<String, Error>;
// Read a length-encoded byte sequence.
fn get_bytes_lenenc(&mut self) -> Bytes;
}
impl MySqlBufExt for Bytes {
fn get_uint_lenenc(&mut self) -> u64 {
match self.get_u8() {
0xfc => u64::from(self.get_u16_le()),
0xfd => self.get_uint_le(3),
0xfe => self.get_u64_le(),
v => u64::from(v),
}
}
fn get_str_lenenc(&mut self) -> Result<String, Error> {
let size = self.get_uint_lenenc();
self.get_str(size as usize)
}
fn get_bytes_lenenc(&mut self) -> Bytes {
let size = self.get_uint_lenenc();
self.split_to(size as usize)
}
}

View file

@ -0,0 +1,126 @@
use bytes::BufMut;
pub trait MySqlBufMutExt: BufMut {
fn put_uint_lenenc(&mut self, v: u64);
fn put_str_lenenc(&mut self, v: &str);
fn put_bytes_lenenc(&mut self, v: &[u8]);
}
impl MySqlBufMutExt for Vec<u8> {
fn put_uint_lenenc(&mut self, v: u64) {
// https://dev.mysql.com/doc/internals/en/integer.html
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if v < 251 {
self.push(v as u8);
} else if v < 0x1_00_00 {
self.push(0xfc);
self.extend(&(v as u16).to_le_bytes());
} else if v < 0x1_00_00_00 {
self.push(0xfd);
self.extend(&(v as u32).to_le_bytes()[..3]);
} else {
self.push(0xfe);
self.extend(&v.to_le_bytes());
}
}
fn put_str_lenenc(&mut self, v: &str) {
self.put_bytes_lenenc(v.as_bytes());
}
fn put_bytes_lenenc(&mut self, v: &[u8]) {
self.put_uint_lenenc(v.len() as u64);
self.extend(v);
}
}
#[test]
fn test_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFA_u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn test_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF_FF_FF_u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFB_u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn test_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFC_u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn test_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFD_u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn test_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFE_u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn test_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF_u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn test_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn test_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_bytes_lenenc(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}

View file

@ -0,0 +1,5 @@
mod buf;
mod buf_mut;
pub use buf::MySqlBufExt;
pub use buf_mut::MySqlBufMutExt;

View file

@ -0,0 +1,5 @@
//! **MySQL** database driver.
pub mod collation;
pub mod io;
pub mod protocol;

View file

@ -0,0 +1,38 @@
use std::str::FromStr;
use crate::err_protocol;
use crate::error::Error;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AuthPlugin {
MySqlClearPassword,
MySqlNativePassword,
CachingSha2Password,
Sha256Password,
}
impl AuthPlugin {
pub(crate) fn name(self) -> &'static str {
match self {
AuthPlugin::MySqlClearPassword => "mysql_clear_password",
AuthPlugin::MySqlNativePassword => "mysql_native_password",
AuthPlugin::CachingSha2Password => "caching_sha2_password",
AuthPlugin::Sha256Password => "sha256_password",
}
}
}
impl FromStr for AuthPlugin {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"mysql_clear_password" => Ok(AuthPlugin::MySqlClearPassword),
"mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword),
"caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password),
"sha256_password" => Ok(AuthPlugin::Sha256Password),
_ => Err(err_protocol!("unknown authentication plugin: {}", s)),
}
}
}

View file

@ -0,0 +1,86 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html
// https://mariadb.com/kb/en/library/connection/#capabilities
bitflags::bitflags! {
pub struct Capabilities: u64 {
// [MariaDB] MySQL compatibility
const MYSQL = 1;
// [*] Send found rows instead of affected rows in EOF_Packet.
const FOUND_ROWS = 2;
// Get all column flags.
const LONG_FLAG = 4;
// [*] Database (schema) name can be specified on connect in Handshake Response Packet.
const CONNECT_WITH_DB = 8;
// Don't allow database.table.column
const NO_SCHEMA = 16;
// [*] Compression protocol supported
const COMPRESS = 32;
// Special handling of ODBC behavior.
const ODBC = 64;
// Can use LOAD DATA LOCAL
const LOCAL_FILES = 128;
// [*] Ignore spaces before '('
const IGNORE_SPACE = 256;
// [*] New 4.1+ protocol
const PROTOCOL_41 = 512;
// This is an interactive client
const INTERACTIVE = 1024;
// Use SSL encryption for this session
const SSL = 2048;
// Client knows about transactions
const TRANSACTIONS = 8192;
// 4.1+ authentication
const SECURE_CONNECTION = (1 << 15);
// Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE
const MULTI_STATEMENTS = (1 << 16);
// Enable/disable multi-results for COM_QUERY
const MULTI_RESULTS = (1 << 17);
// Enable/disable multi-results for COM_STMT_PREPARE
const PS_MULTI_RESULTS = (1 << 18);
// Client supports plugin authentication
const PLUGIN_AUTH = (1 << 19);
// Client supports connection attributes
const CONNECT_ATTRS = (1 << 20);
// Enable authentication response packet to be larger than 255 bytes.
const PLUGIN_AUTH_LENENC_DATA = (1 << 21);
// Don't close the connection for a user account with expired password.
const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22);
// Capable of handling server state change information.
const SESSION_TRACK = (1 << 23);
// Client no longer needs EOF_Packet and will use OK_Packet instead.
const DEPRECATE_EOF = (1 << 24);
// Support ZSTD protocol compression
const ZSTD_COMPRESSION_ALGORITHM = (1 << 26);
// Verify server certificate
const SSL_VERIFY_SERVER_CERT = (1 << 30);
// The client can handle optional metadata information in the resultset
const OPTIONAL_RESULTSET_METADATA = (1 << 25);
// Don't reset the options after an unsuccessful connect
const REMEMBER_OPTIONS = (1 << 31);
}
}

View file

@ -0,0 +1,58 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
#[derive(Debug)]
pub struct AuthSwitchRequest {
pub plugin: AuthPlugin,
pub data: Bytes,
}
impl Decode<'_> for AuthSwitchRequest {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (AUTH_SWITCH) but found 0x{:x}",
header
));
}
let plugin = buf.get_str_nul()?.parse()?;
// See: https://github.com/mysql/mysql-server/blob/ea7d2e2d16ac03afdd9cb72a972a95981107bf51/sql/auth/sha2_password.cc#L942
if buf.len() != 21 {
return Err(err_protocol!(
"expected 21 bytes but found {} bytes",
buf.len()
));
}
let data = buf.get_bytes(20);
buf.advance(1); // NUL-terminator
Ok(Self { plugin, data })
}
}
impl Encode<'_, ()> for AuthSwitchRequest {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0xfe);
buf.put_str_nul(self.plugin.name());
buf.extend(&self.data);
}
}
#[derive(Debug)]
pub struct AuthSwitchResponse(pub Vec<u8>);
impl Encode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.extend_from_slice(&self.0);
}
}

View file

@ -0,0 +1,233 @@
use bytes::buf::Chain;
use bytes::{Buf, BufMut, Bytes};
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
#[derive(Debug)]
pub struct Handshake {
#[allow(unused)]
pub protocol_version: u8,
pub server_version: String,
#[allow(unused)]
pub connection_id: u32,
pub server_capabilities: Capabilities,
#[allow(unused)]
pub server_default_collation: u8,
#[allow(unused)]
pub status: Status,
pub auth_plugin: Option<AuthPlugin>,
pub auth_plugin_data: Chain<Bytes, Bytes>,
}
impl Decode<'_> for Handshake {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let protocol_version = buf.get_u8(); // int<1>
let server_version = buf.get_str_nul()?; // string<NUL>
let connection_id = buf.get_u32_le(); // int<4>
let auth_plugin_data_1 = buf.get_bytes(8); // string<8>
buf.advance(1); // reserved: string<1>
let capabilities_1 = buf.get_u16_le(); // int<2>
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
let collation = buf.get_u8(); // int<1>
let status = Status::from_bits_truncate(buf.get_u16_le());
let capabilities_2 = buf.get_u16_le(); // int<2>
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.get_u8()
} else {
buf.advance(1); // int<1>
0
};
buf.advance(6); // reserved: string<6>
if capabilities.contains(Capabilities::MYSQL) {
buf.advance(4); // reserved: string<4>
} else {
let capabilities_3 = buf.get_u32_le(); // int<4>
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
}
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
let v = buf.get_bytes(len);
buf.advance(1); // NUL-terminator
v
} else {
Bytes::new()
};
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
Some(buf.get_str_nul()?.parse()?)
} else {
None
};
Ok(Self {
protocol_version,
server_version,
connection_id,
server_default_collation: collation,
status,
server_capabilities: capabilities,
auth_plugin,
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
})
}
}
impl Encode<'_, ()> for Handshake {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(self.protocol_version);
buf.put_str_nul(&self.server_version);
buf.put_u32_le(self.connection_id);
buf.put_slice(self.auth_plugin_data.first_ref());
buf.put_u8(0x00);
buf.put_u16_le((self.server_capabilities.bits() & 0x0000_FFFF) as u16);
buf.put_u8(self.server_default_collation);
buf.put_u16_le(self.status.bits());
buf.put_u16_le(((self.server_capabilities.bits() & 0xFFFF_0000) >> 16) as u16);
if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.put_u8((self.auth_plugin_data.last_ref().len() + 8 + 1) as u8);
} else {
buf.put_u8(0);
}
buf.put_slice(&[0_u8; 10][..]);
if self
.server_capabilities
.contains(Capabilities::SECURE_CONNECTION)
{
buf.put_slice(self.auth_plugin_data.last_ref());
buf.put_u8(0);
}
if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) {
if let Some(auth_plugin) = self.auth_plugin {
buf.put_str_nul(auth_plugin.name());
}
}
}
}
#[test]
fn test_decode_handshake_mysql_8_0_18() {
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
assert_eq!(p.protocol_version, 10);
p.server_capabilities.toggle(
Capabilities::MYSQL
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
| Capabilities::SSL_VERIFY_SERVER_CERT
| Capabilities::OPTIONAL_RESULTSET_METADATA
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::CachingSha2Password)
));
assert_eq!(
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
#[test]
fn test_decode_handshake_mariadb_10_4_7() {
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
assert_eq!(p.protocol_version, 10);
assert_eq!(
&*p.server_version,
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
);
p.server_capabilities.toggle(
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 8);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::MySqlNativePassword)
));
assert_eq!(
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,]
);
}

View file

@ -0,0 +1,147 @@
use std::str::FromStr;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::connect::ssl_request::SslRequest;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
// https://mariadb.com/kb/en/connection/#client-handshake-response
#[derive(Debug)]
pub struct HandshakeResponse {
pub database: Option<String>,
/// Max size of a command packet that the client wants to send to the server
pub max_packet_size: u32,
/// Default collation for the connection
pub collation: u8,
/// Name of the SQL account which client wants to log in
pub username: String,
/// Authentication method used by the client
pub auth_plugin: Option<AuthPlugin>,
/// Opaque authentication response
pub auth_response: Option<Bytes>,
}
impl Encode<'_, Capabilities> for HandshakeResponse {
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
if self.auth_plugin.is_none() {
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
capabilities.remove(Capabilities::PLUGIN_AUTH);
}
// NOTE: Half of this packet is identical to the SSL Request packet
SslRequest {
max_packet_size: self.max_packet_size,
collation: self.collation,
}
.encode_with(buf, capabilities);
buf.put_str_nul(&self.username);
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
if let Some(response) = &self.auth_response {
buf.put_bytes_lenenc(response);
} else {
buf.put_bytes_lenenc(&[]);
}
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
if let Some(response) = &self.auth_response {
buf.push(response.len() as u8);
buf.extend(response);
} else {
buf.push(0);
}
} else {
buf.push(0);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = &self.database {
buf.put_str_nul(database);
} else {
buf.push(0);
}
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if let Some(plugin) = &self.auth_plugin {
buf.put_str_nul(plugin.name());
} else {
buf.push(0);
}
}
}
}
impl Decode<'_, &mut Capabilities> for HandshakeResponse {
fn decode_with(mut buf: Bytes, server_capabilities: &mut Capabilities) -> Result<Self, Error> {
let mut capabilities = buf.get_u32_le() as u64;
let max_packet_size = buf.get_u32_le();
let collation = buf.get_u8();
buf.advance(19);
let partial_cap = Capabilities::from_bits_truncate(capabilities);
if partial_cap.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.advance(4);
} else {
capabilities += (buf.get_u32_le() as u64) << 32;
}
let partial_cap = Capabilities::from_bits_truncate(capabilities);
if partial_cap.contains(Capabilities::SSL) && buf.is_empty() {
return Ok(HandshakeResponse {
collation,
max_packet_size,
username: "".to_string(),
auth_response: None,
auth_plugin: None,
database: None,
});
}
let username = buf.get_str_nul()?;
let auth_response = if partial_cap.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
Some(buf.get_bytes_lenenc())
} else if partial_cap.contains(Capabilities::SECURE_CONNECTION) {
let len = buf.get_u8();
Some(buf.get_bytes(len as usize))
} else {
Some(buf.get_bytes_nul()?)
};
let database = if partial_cap.contains(Capabilities::CONNECT_WITH_DB) {
Some(buf.get_str_nul()?)
} else {
None
};
let auth_plugin: Option<AuthPlugin> = if partial_cap.contains(Capabilities::PLUGIN_AUTH) {
Some(AuthPlugin::from_str(&buf.get_str_nul()?)?)
} else {
None
};
*server_capabilities &= Capabilities::from_bits_truncate(capabilities);
Ok(HandshakeResponse {
collation,
max_packet_size,
username,
auth_response,
auth_plugin,
database,
})
}
}

View file

@ -0,0 +1,13 @@
//! Connection Phase
//!
//! <https://dev.mysql.com/doc/internals/en/connection-phase.html>
mod auth_switch;
mod handshake;
mod handshake_response;
mod ssl_request;
pub use auth_switch::{AuthSwitchRequest, AuthSwitchResponse};
pub use handshake::Handshake;
pub use handshake_response::HandshakeResponse;
pub use ssl_request::SslRequest;

View file

@ -0,0 +1,30 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
#[derive(Debug)]
pub struct SslRequest {
pub max_packet_size: u32,
pub collation: u8,
}
impl Encode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
buf.extend(&self.max_packet_size.to_le_bytes());
buf.push(self.collation);
// reserved: string<19>
buf.extend(&[0_u8; 19]);
if capabilities.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.extend(&[0_u8; 4]);
} else {
// extended client capabilities (MariaDB-specified): int<4>
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
}
}
}

View file

@ -0,0 +1,11 @@
pub mod auth;
pub mod capabilities;
pub mod connect;
pub mod packet;
pub mod response;
pub mod row;
pub mod text;
pub use capabilities::Capabilities;
pub use packet::Packet;
pub use row::Row;

View file

@ -0,0 +1,89 @@
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::mysql::protocol::response::{EofPacket, OkPacket};
use crate::mysql::protocol::Capabilities;
#[derive(Debug)]
pub struct Packet<T>(pub T);
impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
where
T: Encode<'en, Capabilities>,
{
fn encode_with(
&self,
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
// encode the payload
self.0.encode_with(buf, capabilities);
// determine the length of the encoded payload
// and write to our reserved space
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];
// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);
}
}
impl Packet<Bytes> {
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
where
T: Decode<'de, ()>,
{
self.decode_with(())
}
pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
{
T::decode_with(self.0, context)
}
pub(crate) fn ok(self) -> Result<OkPacket, Error> {
self.decode()
}
pub(crate) fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
if capabilities.contains(Capabilities::DEPRECATE_EOF) {
let ok = self.ok()?;
Ok(EofPacket {
warnings: ok.warnings,
status: ok.status,
})
} else {
self.decode_with(capabilities)
}
}
}
impl Deref for Packet<Bytes> {
type Target = Bytes;
fn deref(&self) -> &Bytes {
&self.0
}
}
impl DerefMut for Packet<Bytes> {
fn deref_mut(&mut self) -> &mut Bytes {
&mut self.0
}
}

View file

@ -0,0 +1,36 @@
use bytes::{Buf, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
/// Marks the end of a result set, returning status and warnings.
///
/// # Note
///
/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL
/// prior MySQL versions.
#[derive(Debug)]
pub struct EofPacket {
pub warnings: u16,
pub status: Status,
}
impl Decode<'_, Capabilities> for EofPacket {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (EOF_Packet) but found 0x{:x}",
header
));
}
let warnings = buf.get_u16_le();
let status = Status::from_bits_truncate(buf.get_u16_le());
Ok(Self { status, warnings })
}
}

View file

@ -0,0 +1,81 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{BufExt, Decode, Encode};
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
// https://mariadb.com/kb/en/err_packet/
/// Indicates that an error occurred.
#[derive(Debug)]
pub struct ErrPacket {
pub error_code: u16,
pub sql_state: Option<String>,
pub error_message: String,
}
impl Decode<'_, Capabilities> for ErrPacket {
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xff {
return Err(err_protocol!(
"expected 0xff (ERR_Packet) but found 0x{:x}",
header
));
}
let error_code = buf.get_u16_le();
let mut sql_state = None;
if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) {
buf.advance(1);
sql_state = Some(buf.get_str(5)?);
}
}
let error_message = buf.get_str(buf.len())?;
Ok(Self {
error_code,
sql_state,
error_message,
})
}
}
impl Encode<'_, ()> for ErrPacket {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0xff);
buf.put_u16_le(self.error_code);
buf.extend_from_slice(self.error_message.as_bytes())
//TODO: sql_state
}
}
#[test]
fn test_decode_err_packet_out_of_order() {
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
let p =
ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(&p.error_message, "Got packets out of order");
assert_eq!(p.error_code, 1156);
assert_eq!(p.sql_state, None);
}
#[test]
fn test_decode_err_packet_unknown_database() {
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
let p =
ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(p.error_code, 1049);
assert_eq!(p.sql_state.as_deref(), Some("42000"));
assert_eq!(&p.error_message, "Unknown database \'unknown\'");
}

View file

@ -0,0 +1,14 @@
//! Generic Response Packets
//!
//! <https://dev.mysql.com/doc/internals/en/generic-response-packets.html>
//! <https://mariadb.com/kb/en/4-server-response-packets/>
mod eof;
mod err;
mod ok;
mod status;
pub use eof::EofPacket;
pub use err::ErrPacket;
pub use ok::OkPacket;
pub use status::Status;

View file

@ -0,0 +1,63 @@
use bytes::{Buf, BufMut, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt};
use crate::mysql::protocol::response::Status;
/// Indicates successful completion of a previous command sent by the client.
#[derive(Debug)]
pub struct OkPacket {
pub affected_rows: u64,
pub last_insert_id: u64,
pub status: Status,
pub warnings: u16,
}
impl Decode<'_> for OkPacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 && header != 0xfe {
return Err(err_protocol!(
"expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}",
header
));
}
let affected_rows = buf.get_uint_lenenc();
let last_insert_id = buf.get_uint_lenenc();
let status = Status::from_bits_truncate(buf.get_u16_le());
let warnings = buf.get_u16_le();
Ok(Self {
affected_rows,
last_insert_id,
status,
warnings,
})
}
}
impl Encode<'_, ()> for OkPacket {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.put_u8(0);
buf.put_uint_lenenc(self.affected_rows);
buf.put_uint_lenenc(self.last_insert_id);
buf.put_u16_le(self.status.bits());
buf.put_u16_le(self.warnings);
}
}
#[test]
fn test_decode_ok_packet() {
const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
let p = OkPacket::decode(DATA.into()).unwrap();
assert_eq!(p.affected_rows, 0);
assert_eq!(p.last_insert_id, 0);
assert_eq!(p.warnings, 0);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
}

View file

@ -0,0 +1,49 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
bitflags::bitflags! {
pub struct Status: u16 {
// Is raised when a multi-statement transaction has been started, either explicitly,
// by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first
// transactional statement, when autocommit=off.
const SERVER_STATUS_IN_TRANS = 1;
// Autocommit mode is set
const SERVER_STATUS_AUTOCOMMIT = 2;
// Multi query - next query exists.
const SERVER_MORE_RESULTS_EXISTS = 8;
const SERVER_QUERY_NO_GOOD_INDEX_USED = 16;
const SERVER_QUERY_NO_INDEX_USED = 32;
// When using COM_STMT_FETCH, indicate that current cursor still has result
const SERVER_STATUS_CURSOR_EXISTS = 64;
// When using COM_STMT_FETCH, indicate that current cursor has finished to send results
const SERVER_STATUS_LAST_ROW_SENT = 128;
// Database has been dropped
const SERVER_STATUS_DB_DROPPED = (1 << 8);
// Current escape mode is "no backslash escape"
const SERVER_STATUS_NO_BACKSLASH_ESCAPES = (1 << 9);
// A DDL change did have an impact on an existing PREPARE (an automatic
// re-prepare has been executed)
const SERVER_STATUS_METADATA_CHANGED = (1 << 10);
// Last statement took more than the time value specified
// in server variable long_query_time.
const SERVER_QUERY_WAS_SLOW = (1 << 11);
// This result-set contain stored procedure output parameter.
const SERVER_PS_OUT_PARAMS = (1 << 12);
// Current transaction is a read-only transaction.
const SERVER_STATUS_IN_TRANS_READONLY = (1 << 13);
// This status flag, when on, implies that one of the state information has changed
// on the server because of the execution of the last statement.
const SERVER_SESSION_STATE_CHANGED = (1 << 14);
}
}

View file

@ -0,0 +1,17 @@
use std::ops::Range;
use bytes::Bytes;
#[derive(Debug)]
pub struct Row {
pub(crate) storage: Bytes,
pub(crate) values: Vec<Option<Range<usize>>>,
}
impl Row {
pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
self.values[index]
.as_ref()
.map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
}
}

View file

@ -0,0 +1,265 @@
use std::str::from_utf8;
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::err_protocol;
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
bitflags! {
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct ColumnFlags: u16 {
/// Field can't be `NULL`.
const NOT_NULL = 1;
/// Field is part of a primary key.
const PRIMARY_KEY = 2;
/// Field is part of a unique key.
const UNIQUE_KEY = 4;
/// Field is part of a multi-part unique or primary key.
const MULTIPLE_KEY = 8;
/// Field is a blob.
const BLOB = 16;
/// Field is unsigned.
const UNSIGNED = 32;
/// Field is zero filled.
const ZEROFILL = 64;
/// Field is binary.
const BINARY = 128;
/// Field is an enumeration.
const ENUM = 256;
/// Field is an auto-incement field.
const AUTO_INCREMENT = 512;
/// Field is a timestamp.
const TIMESTAMP = 1024;
/// Field is a set.
const SET = 2048;
/// Field does not have a default value.
const NO_DEFAULT_VALUE = 4096;
/// Field is set to NOW on UPDATE.
const ON_UPDATE_NOW = 8192;
/// Field is a number.
const NUM = 32768;
}
}
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ColumnType {
Decimal = 0x00,
Tiny = 0x01,
Short = 0x02,
Long = 0x03,
Float = 0x04,
Double = 0x05,
Null = 0x06,
Timestamp = 0x07,
LongLong = 0x08,
Int24 = 0x09,
Date = 0x0a,
Time = 0x0b,
Datetime = 0x0c,
Year = 0x0d,
VarChar = 0x0f,
Bit = 0x10,
Json = 0xf5,
NewDecimal = 0xf6,
Enum = 0xf7,
Set = 0xf8,
TinyBlob = 0xf9,
MediumBlob = 0xfa,
LongBlob = 0xfb,
Blob = 0xfc,
VarString = 0xfd,
String = 0xfe,
Geometry = 0xff,
}
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
// https://mariadb.com/kb/en/resultset/#column-definition-packet
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
#[derive(Debug)]
pub struct ColumnDefinition {
#[allow(unused)]
catalog: Bytes,
#[allow(unused)]
schema: Bytes,
#[allow(unused)]
table_alias: Bytes,
#[allow(unused)]
table: Bytes,
alias: Bytes,
name: Bytes,
pub(crate) char_set: u16,
pub(crate) max_size: u32,
pub(crate) r#type: ColumnType,
pub(crate) flags: ColumnFlags,
#[allow(unused)]
decimals: u8,
}
impl ColumnDefinition {
// NOTE: strings in-protocol are transmitted according to the client character set
// as this is UTF-8, all these strings should be UTF-8
pub(crate) fn name(&self) -> Result<&str, Error> {
from_utf8(&self.name).map_err(Error::protocol)
}
pub(crate) fn alias(&self) -> Result<&str, Error> {
from_utf8(&self.alias).map_err(Error::protocol)
}
}
impl Decode<'_, Capabilities> for ColumnDefinition {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let catalog = buf.get_bytes_lenenc();
let schema = buf.get_bytes_lenenc();
let table_alias = buf.get_bytes_lenenc();
let table = buf.get_bytes_lenenc();
let alias = buf.get_bytes_lenenc();
let name = buf.get_bytes_lenenc();
let _next_len = buf.get_uint_lenenc(); // always 0x0c
let char_set = buf.get_u16_le();
let max_size = buf.get_u32_le();
let type_id = buf.get_u8();
let flags = buf.get_u16_le();
let decimals = buf.get_u8();
Ok(Self {
catalog,
schema,
table_alias,
table,
alias,
name,
char_set,
max_size,
r#type: ColumnType::try_from_u16(type_id)?,
flags: ColumnFlags::from_bits_truncate(flags),
decimals,
})
}
}
impl ColumnType {
pub(crate) fn name(
self,
char_set: u16,
flags: ColumnFlags,
max_size: Option<u32>,
) -> &'static str {
let is_binary = char_set == 63;
let is_unsigned = flags.contains(ColumnFlags::UNSIGNED);
let is_enum = flags.contains(ColumnFlags::ENUM);
match self {
ColumnType::Tiny if max_size == Some(1) => "BOOLEAN",
ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED",
ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED",
ColumnType::Long if is_unsigned => "INT UNSIGNED",
ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED",
ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED",
ColumnType::Tiny => "TINYINT",
ColumnType::Short => "SMALLINT",
ColumnType::Long => "INT",
ColumnType::Int24 => "MEDIUMINT",
ColumnType::LongLong => "BIGINT",
ColumnType::Float => "FLOAT",
ColumnType::Double => "DOUBLE",
ColumnType::Null => "NULL",
ColumnType::Timestamp => "TIMESTAMP",
ColumnType::Date => "DATE",
ColumnType::Time => "TIME",
ColumnType::Datetime => "DATETIME",
ColumnType::Year => "YEAR",
ColumnType::Bit => "BIT",
ColumnType::Enum => "ENUM",
ColumnType::Set => "SET",
ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL",
ColumnType::Geometry => "GEOMETRY",
ColumnType::Json => "JSON",
ColumnType::String if is_binary => "BINARY",
ColumnType::String if is_enum => "ENUM",
ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY",
ColumnType::String => "CHAR",
ColumnType::VarChar | ColumnType::VarString => "VARCHAR",
ColumnType::TinyBlob if is_binary => "TINYBLOB",
ColumnType::TinyBlob => "TINYTEXT",
ColumnType::Blob if is_binary => "BLOB",
ColumnType::Blob => "TEXT",
ColumnType::MediumBlob if is_binary => "MEDIUMBLOB",
ColumnType::MediumBlob => "MEDIUMTEXT",
ColumnType::LongBlob if is_binary => "LONGBLOB",
ColumnType::LongBlob => "LONGTEXT",
}
}
pub(crate) fn try_from_u16(id: u8) -> Result<Self, Error> {
Ok(match id {
0x00 => ColumnType::Decimal,
0x01 => ColumnType::Tiny,
0x02 => ColumnType::Short,
0x03 => ColumnType::Long,
0x04 => ColumnType::Float,
0x05 => ColumnType::Double,
0x06 => ColumnType::Null,
0x07 => ColumnType::Timestamp,
0x08 => ColumnType::LongLong,
0x09 => ColumnType::Int24,
0x0a => ColumnType::Date,
0x0b => ColumnType::Time,
0x0c => ColumnType::Datetime,
0x0d => ColumnType::Year,
// [internal] 0x0e => ColumnType::NewDate,
0x0f => ColumnType::VarChar,
0x10 => ColumnType::Bit,
// [internal] 0x11 => ColumnType::Timestamp2,
// [internal] 0x12 => ColumnType::Datetime2,
// [internal] 0x13 => ColumnType::Time2,
0xf5 => ColumnType::Json,
0xf6 => ColumnType::NewDecimal,
0xf7 => ColumnType::Enum,
0xf8 => ColumnType::Set,
0xf9 => ColumnType::TinyBlob,
0xfa => ColumnType::MediumBlob,
0xfb => ColumnType::LongBlob,
0xfc => ColumnType::Blob,
0xfd => ColumnType::VarString,
0xfe => ColumnType::String,
0xff => ColumnType::Geometry,
_ => {
return Err(err_protocol!("unknown column type 0x{:02x}", id));
}
})
}
}

View file

@ -0,0 +1,9 @@
mod column;
mod ping;
mod query;
mod quit;
pub use column::{ColumnDefinition, ColumnFlags, ColumnType};
pub use ping::Ping;
pub use query::Query;
pub use quit::Quit;

View file

@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-ping.html
#[derive(Debug)]
pub struct Ping;
impl Encode<'_, Capabilities> for Ping {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x0e); // COM_PING
}
}

View file

@ -0,0 +1,32 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode, Encode};
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-query.html
#[derive(Debug)]
pub struct Query(pub String);
impl Encode<'_, ()> for Query {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
}
}
impl Encode<'_, Capabilities> for Query {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
}
}
impl Decode<'_> for Query {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
buf.advance(1);
let q = buf.get_str(buf.len())?;
Ok(Query(q))
}
}

View file

@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-quit.html
#[derive(Debug)]
pub struct Quit;
impl Encode<'_, Capabilities> for Quit {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x01); // COM_QUIT
}
}

View file

@ -5,13 +5,10 @@ name = "warpgate-protocol-mysql"
version = "0.3.0"
[dependencies]
sqlx-core-guts = { version = "0.6", features = [
"runtime-tokio-rustls",
"mysql",
], path = "../../sqlx-core-guts/sqlx-core-guts" }
warpgate-admin = { version = "*", path = "../warpgate-admin" }
warpgate-common = { version = "*", path = "../warpgate-common" }
warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" }
warpgate-database-protocols = { version = "*", path = "../warpgate-database-protocols" }
anyhow = { version = "1.0", features = ["std"] }
async-trait = "0.1"
tokio = { version = "1.19", features = ["tracing", "signal"] }
@ -23,7 +20,7 @@ rand = "0.8"
sha1 = "0.10.1"
password-hash = { version = "0.2", features = ["std"] }
delegate = "0.6"
rustls = "0.20"
rustls = { version = "0.20", features = ["dangerous_configuration"] }
rustls-pemfile = "1.0"
tokio-rustls = "0.23"
thiserror = "1.0"

View file

@ -1,17 +1,19 @@
use std::sync::Arc;
use bytes::BytesMut;
use sqlx_core_guts::io::Decode;
use sqlx_core_guts::mysql::protocol::auth::AuthPlugin;
use sqlx_core_guts::mysql::protocol::connect::{Handshake, HandshakeResponse, SslRequest};
use sqlx_core_guts::mysql::protocol::response::ErrPacket;
use sqlx_core_guts::mysql::protocol::Capabilities;
use sqlx_core_guts::mysql::MySqlSslMode;
use tokio::net::TcpStream;
use tracing::*;
use warpgate_common::{TargetMySqlOptions, TlsMode};
use warpgate_database_protocols::io::Decode;
use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin;
use warpgate_database_protocols::mysql::protocol::connect::{
Handshake, HandshakeResponse, SslRequest,
};
use warpgate_database_protocols::mysql::protocol::response::ErrPacket;
use warpgate_database_protocols::mysql::protocol::Capabilities;
use crate::common::{compute_auth_challenge_response, parse_mysql_uri};
use crate::error::{InvalidMySqlTargetConfig, MySqlError};
use crate::common::compute_auth_challenge_response;
use crate::error::MySqlError;
use crate::stream::MySqlStream;
use crate::tls::configure_tls_connector;
@ -28,13 +30,15 @@ pub struct ConnectionOptions {
}
impl MySqlClient {
pub async fn connect(uri: &str, mut options: ConnectionOptions) -> Result<Self, MySqlError> {
let opts = parse_mysql_uri(uri)?;
pub async fn connect(
target: &TargetMySqlOptions,
mut options: ConnectionOptions,
) -> Result<Self, MySqlError> {
let mut stream =
MySqlStream::new(TcpStream::connect((opts.host.clone(), opts.port)).await?);
MySqlStream::new(TcpStream::connect((target.host.clone(), target.port)).await?);
options.capabilities.remove(Capabilities::SSL);
if opts.ssl_mode != MySqlSslMode::Disabled {
if target.tls.mode != TlsMode::Disabled {
options.capabilities |= Capabilities::SSL;
}
@ -44,20 +48,17 @@ impl MySqlClient {
let handshake = Handshake::decode(payload)?;
options.capabilities &= handshake.server_capabilities;
if opts.ssl_mode != MySqlSslMode::Disabled
&& opts.ssl_mode != MySqlSslMode::Preferred
&& !options.capabilities.contains(Capabilities::SSL)
if target.tls.mode == TlsMode::Required && !options.capabilities.contains(Capabilities::SSL)
{
return Err(MySqlError::TlsNotSupported);
}
info!(capabilities=?options.capabilities, "Target handshake");
if options.capabilities.contains(Capabilities::SSL)
&& opts.ssl_mode != MySqlSslMode::Disabled
if options.capabilities.contains(Capabilities::SSL) && target.tls.mode != TlsMode::Disabled
{
let accept_invalid_certs = opts.ssl_mode == MySqlSslMode::Preferred;
let accept_invalid_hostname = opts.ssl_mode != MySqlSslMode::VerifyIdentity;
let accept_invalid_certs = !target.tls.verify;
let accept_invalid_hostname = false; // ca + hostname verification
let client_config = Arc::new(
configure_tls_connector(accept_invalid_certs, accept_invalid_hostname, None)
.await?,
@ -70,7 +71,8 @@ impl MySqlClient {
stream.flush().await?;
stream = stream
.upgrade((
opts.host
target
.host
.as_str()
.try_into()
.map_err(|_| MySqlError::InvalidDomainName)?,
@ -86,7 +88,7 @@ impl MySqlClient {
collation: options.collation,
database: options.database,
max_packet_size: options.max_packet_size,
username: opts.username,
username: target.username.clone(),
};
if handshake.auth_plugin == Some(AuthPlugin::MySqlNativePassword) {
@ -100,13 +102,13 @@ impl MySqlClient {
warn!("Invalid scramble length ({})", scramble_bytes.len());
}
Ok(scramble) => {
let Some(password) = opts.password else {
return Err(InvalidMySqlTargetConfig::NoPassword.into())
};
response.auth_plugin = Some(AuthPlugin::MySqlNativePassword);
response.auth_response = Some(
BytesMut::from(
compute_auth_challenge_response(scramble, &password)
compute_auth_challenge_response(
scramble,
target.password.as_deref().unwrap_or(""),
)
.map_err(MySqlError::other)?
.as_bytes(),
)

View file

@ -1,19 +1,8 @@
use sha1::Digest;
use sqlx_core_guts::error::Error as SqlxError;
use sqlx_core_guts::mysql::MySqlConnectOptions;
use warpgate_common::ProtocolName;
use crate::error::InvalidMySqlTargetConfig;
pub const PROTOCOL_NAME: ProtocolName = "MySQL";
pub fn parse_mysql_uri(uri: &str) -> Result<MySqlConnectOptions, InvalidMySqlTargetConfig> {
uri.parse().map_err(|e| match e {
SqlxError::Configuration(e) => InvalidMySqlTargetConfig::UriParse(e),
_ => InvalidMySqlTargetConfig::Unknown,
})
}
pub fn compute_auth_challenge_response(
challenge: [u8; 20],
password: &str,

View file

@ -1,15 +1,13 @@
use std::error::Error;
use sqlx_core_guts::error::Error as SqlxError;
use warpgate_common::WarpgateError;
use warpgate_database_protocols::error::Error as SqlxError;
use crate::stream::MySqlStreamError;
use crate::tls::{MaybeTlsStreamError, RustlsSetupError};
#[derive(thiserror::Error, Debug)]
pub enum MySqlError {
#[error("invalid target config: {0}")]
InvalidTargetConfig(#[from] InvalidMySqlTargetConfig),
#[error("protocol error: {0}")]
ProtocolError(String),
#[error("sudden disconnection")]
@ -38,16 +36,6 @@ pub enum MySqlError {
Other(Box<dyn Error + Send + Sync>),
}
#[derive(thiserror::Error, Debug)]
pub enum InvalidMySqlTargetConfig {
#[error("Password not set")]
NoPassword,
#[error("URI parse error: {0}")]
UriParse(Box<dyn Error + Send + Sync>),
#[error("Unkown")]
Unknown,
}
impl MySqlError {
pub fn other<E: Error + Send + Sync + 'static>(err: E) -> Self {
Self::Other(Box::new(err))

View file

@ -3,12 +3,6 @@ use std::sync::Arc;
use bytes::{Buf, Bytes, BytesMut};
use rand::Rng;
use rustls::ServerConfig;
use sqlx_core_guts::io::{BufExt, Decode};
use sqlx_core_guts::mysql::protocol::auth::AuthPlugin;
use sqlx_core_guts::mysql::protocol::connect::{AuthSwitchRequest, Handshake, HandshakeResponse};
use sqlx_core_guts::mysql::protocol::response::{ErrPacket, OkPacket, Status};
use sqlx_core_guts::mysql::protocol::text::Query;
use sqlx_core_guts::mysql::protocol::Capabilities;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::*;
@ -19,6 +13,14 @@ use warpgate_common::{
authorize_ticket, AuthCredential, AuthResult, Secret, Services, TargetMySqlOptions,
TargetOptions, WarpgateServerHandle,
};
use warpgate_database_protocols::io::{BufExt, Decode};
use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin;
use warpgate_database_protocols::mysql::protocol::connect::{
AuthSwitchRequest, Handshake, HandshakeResponse,
};
use warpgate_database_protocols::mysql::protocol::response::{ErrPacket, OkPacket, Status};
use warpgate_database_protocols::mysql::protocol::text::Query;
use warpgate_database_protocols::mysql::protocol::Capabilities;
use crate::client::{ConnectionOptions, MySqlClient};
use crate::error::MySqlError;
@ -313,7 +315,7 @@ impl MySqlSession {
}
let mut client = match MySqlClient::connect(
&options.uri,
&options,
ConnectionOptions {
collation: handshake.collation,
database: handshake.database,

View file

@ -1,10 +1,10 @@
use bytes::{Bytes, BytesMut};
use mysql_common::proto::codec::error::PacketCodecError;
use mysql_common::proto::codec::PacketCodec;
use sqlx_core_guts::io::Encode;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::*;
use warpgate_database_protocols::io::Encode;
use crate::tls::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream};