diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 89acee5..5e9a1db 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -27,6 +27,10 @@ replace = version = "{new_version}" search = version = "{current_version}" replace = version = "{new_version}" +[bumpversion:file:warpgate-protocol-mysql/Cargo.toml] +search = version = "{current_version}" +replace = version = "{new_version}" + [bumpversion:file:warpgate-protocol-ssh/Cargo.toml] search = version = "{current_version}" replace = version = "{new_version}" diff --git a/Cargo.lock b/Cargo.lock index 4706fb9..5df024c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,12 @@ dependencies = [ "password-hash 0.4.1", ] +[[package]] +name = "arrayvec" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" + [[package]] name = "async-attributes" version = "1.1.2" @@ -396,9 +402,9 @@ checksum = "8a32fd6af2b5827bce66c29053ba0e7c42b9dcab01835835058558c10851a46b" [[package]] name = "bcrypt-pbkdf" -version = "0.6.2" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c38c03b9506bd92bf1ef50665a81eda156f615438f7654bffba58907e6149d7" +checksum = "12621b8e87feb183a6e5dbb315e49026b2229c4398797ee0ae2d1bc00aef41b9" dependencies = [ "blowfish", "crypto-mac", @@ -407,12 +413,42 @@ dependencies = [ "zeroize", ] +[[package]] +name = "bigdecimal" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aaf33151a6429fe9211d1b276eafdf70cdff28b071e76c0b0e1503221ea3744" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bimap" version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc0455254eb5c6964c4545d8bac815e1a1be4f3afe0ae695ea539c12d728d44b" +[[package]] +name = "bindgen" +version = "0.59.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", +] + [[package]] name = "bit-vec" version = "0.6.3" @@ -425,6 +461,18 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.4" @@ -529,6 +577,15 @@ version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -564,6 +621,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "clang-sys" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a050e2153c5be08febd6734e29298e844fdb0fa21aeddd63b4eb7baa106c69b" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "3.2.2" @@ -603,6 +671,15 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "cmake" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8ad8cef104ac57b68b89df3208164d228503abbdce70f6880ffa3d970e7443a" +dependencies = [ + "cc", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -1031,6 +1108,7 @@ dependencies = [ "cfg-if", "crc32fast", "libc", + "libz-sys", "miniz_oxide", ] @@ -1077,6 +1155,70 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "frunk" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cd67cf7d54b7e72d0ea76f3985c3747d74aee43e0218ad993b7903ba7a5395e" +dependencies = [ + "frunk_core", + "frunk_derives", + "frunk_proc_macros", +] + +[[package]] +name = "frunk_core" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1246cf43ec80bf8b2505b5c360b8fb999c97dabd17dbb604d85558d5cbc25482" + +[[package]] +name = "frunk_derives" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dbc4f084ec5a3f031d24ccedeb87ab2c3189a2f33b8d070889073837d5ea09e" +dependencies = [ + "frunk_proc_macro_helpers", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macro_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f11257f106c6753f5ffcb8e601fb39c390a088017aaa55b70c526bff15f63e" +dependencies = [ + "frunk_core", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a078bd8459eccbb85e0b007b8f756585762a72a9efc53f359b371c3b6351dbcc" +dependencies = [ + "frunk_core", + "frunk_proc_macros_impl", + "proc-macro-hack", +] + +[[package]] +name = "frunk_proc_macros_impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ffba99f0fa4f57e42f57388fbb9a0ca863bc2b4261f3c5570fed579d5df6c32" +dependencies = [ + "frunk_core", + "frunk_proc_macro_helpers", + "proc-macro-hack", + "quote", + "syn", +] + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -1086,6 +1228,12 @@ dependencies = [ "libc", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.3.21" @@ -1238,6 +1386,12 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +[[package]] +name = "glob" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" + [[package]] name = "gloo-timers" version = "0.2.4" @@ -1648,12 +1802,101 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "lexical" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7aefb36fd43fef7003334742cbf77b243fcd36418a1d1bdd480d613a67968f6" +dependencies = [ + "lexical-core", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.124" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21a41fed9d98f27ab1c6d161da622a4fa35e8a54a8adc24bbf3ddd0ef70b0e50" +[[package]] +name = "libloading" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efbc0f03f9a775e9f6aed295c6a1ba2253c5757a9e03d55c6caa46a681abcddd" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "libsodium-sys" version = "0.2.7" @@ -1677,6 +1920,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9702761c3935f8cc2f101793272e202c72b99da8f4224a19ddcf1279a6450bbf" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -1818,6 +2072,43 @@ dependencies = [ "version_check", ] +[[package]] +name = "mysql_common" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20ce6fdcef94a8e87fea3f9402de4d7e403f10e8c20b6e2bd207e6b6f8a4eab0" +dependencies = [ + "base64", + "bigdecimal", + "bindgen", + "bitflags", + "bitvec", + "byteorder", + "bytes", + "cc", + "cmake", + "crc32fast", + "flate2", + "frunk", + "lazy_static", + "lexical", + "num-bigint", + "num-traits", + "rand", + "regex", + "rust_decimal", + "saturating", + "serde", + "serde_json", + "sha-1", + "sha2 0.10.2", + "smallvec", + "subprocess", + "thiserror", + "time 0.3.11", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.10" @@ -2163,6 +2454,12 @@ dependencies = [ "sha2 0.9.9", ] +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem" version = "1.0.2" @@ -2444,6 +2741,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + [[package]] name = "proc-macro2" version = "1.0.39" @@ -2505,6 +2808,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -2783,6 +3092,17 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust_decimal" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34a3bb58e85333f1ab191bf979104b586ebd77475bc6681882825f4532dfe87c" +dependencies = [ + "arrayvec", + "num-traits", + "serde", +] + [[package]] name = "rustc-demangle" version = "0.1.21" @@ -2858,6 +3178,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "saturating" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71" + [[package]] name = "schannel" version = "0.1.19" @@ -3139,6 +3465,17 @@ dependencies = [ "digest 0.10.3", ] +[[package]] +name = "sha1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77f4e7f65455545c2153c1253d25056825e77ee2533f0e41deb65a93a34852f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "sha2" version = "0.9.9" @@ -3172,6 +3509,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -3329,6 +3672,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stringprep" version = "0.1.2" @@ -3345,6 +3694,16 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "subprocess" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "subtle" version = "2.4.1" @@ -3368,6 +3727,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.3.0" @@ -3488,10 +3853,11 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.19.2" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "57aec3cfa4c296db7255446efb4928a6be304b431a806216105542a67b6ca82e" dependencies = [ + "autocfg", "bytes", "libc", "memchr", @@ -3561,9 +3927,9 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.17.1" +version = "0.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06cda1232a49558c46f8a504d5b93101d42c0bf7f911f12a105ba48168f821ae" +checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" dependencies = [ "futures-util", "log", @@ -3778,9 +4144,9 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" [[package]] name = "tungstenite" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96a2dea40e7570482f28eb57afbe42d97551905da6a9400acc5c328d24004f5" +checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" dependencies = [ "base64", "byteorder", @@ -3995,6 +4361,7 @@ dependencies = [ "warpgate-admin", "warpgate-common", "warpgate-protocol-http", + "warpgate-protocol-mysql", "warpgate-protocol-ssh", ] @@ -4011,6 +4378,7 @@ dependencies = [ "mime_guess", "poem", "poem-openapi", + "regex", "russh-keys", "rust-embed", "sea-orm", @@ -4041,6 +4409,7 @@ dependencies = [ "once_cell", "packet", "password-hash 0.4.1", + "poem", "poem-openapi", "rand", "rand_chacha", @@ -4060,6 +4429,19 @@ dependencies = [ "warpgate-db-migrations", ] +[[package]] +name = "warpgate-database-protocols" +version = "0.3.0" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "memchr", + "thiserror", + "tokio", +] + [[package]] name = "warpgate-db-entities" version = "0.3.0" @@ -4098,6 +4480,7 @@ dependencies = [ "percent-encoding", "poem", "poem-openapi", + "regex", "reqwest", "serde", "serde_json", @@ -4111,6 +4494,33 @@ dependencies = [ "warpgate-web", ] +[[package]] +name = "warpgate-protocol-mysql" +version = "0.3.0" +dependencies = [ + "anyhow", + "async-trait", + "bytes", + "delegate", + "mysql_common", + "password-hash 0.2.3", + "rand", + "rustls", + "rustls-pemfile", + "sha1", + "thiserror", + "tokio", + "tokio-rustls", + "tracing", + "uuid", + "warpgate-admin", + "warpgate-common", + "warpgate-database-protocols", + "warpgate-db-entities", + "webpki", + "webpki-roots", +] + [[package]] name = "warpgate-protocol-ssh" version = "0.3.0" @@ -4132,6 +4542,7 @@ dependencies = [ "uuid", "warpgate-common", "warpgate-db-entities", + "zeroize", ] [[package]] @@ -4242,6 +4653,15 @@ dependencies = [ "untrusted", ] +[[package]] +name = "webpki-roots" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf" +dependencies = [ + "webpki", +] + [[package]] name = "wepoll-ffi" version = "0.1.2" @@ -4334,6 +4754,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "wyz" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b31594f29d27036c383b53b59ed3476874d518f0efb151b27a4c275141390e" +dependencies = [ + "tap", +] + [[package]] name = "yaml-rust" version = "0.4.5" @@ -4364,6 +4793,6 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.3.0" +version = "1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4756f7db3f7b5574938c3eb1c117038b8e07f95ee6718c0efad4ac21508f1efd" +checksum = "20b578acffd8516a6c3f2a1bdefc1ec37e547bb4e0fb8b6b01a4cafc886b4442" diff --git a/Cargo.toml b/Cargo.toml index a7fc1ad..0732e18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,9 @@ members = [ "warpgate-common", "warpgate-db-migrations", "warpgate-db-entities", + "warpgate-database-protocols", "warpgate-protocol-http", + "warpgate-protocol-mysql", "warpgate-protocol-ssh", "warpgate-web", ] diff --git a/justfile b/justfile index 84e8730..5522ad6 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-protocol-ssh" +projects := "warpgate warpgate-admin warpgate-common warpgate-db-entities warpgate-db-migrations warpgate-database-protocols warpgate-protocol-ssh warpgate-protocol-mysql" run *ARGS: RUST_BACKTRACE=1 RUST_LOG=warpgate cd warpgate && cargo run -- --config ../config.yaml {{ARGS}} diff --git a/warpgate-admin/Cargo.toml b/warpgate-admin/Cargo.toml index d252efc..6ea044e 100644 --- a/warpgate-admin/Cargo.toml +++ b/warpgate-admin/Cargo.toml @@ -5,25 +5,40 @@ name = "warpgate-admin" version = "0.3.0" [dependencies] -anyhow = {version = "1.0", features = ["std"]} +anyhow = { version = "1.0", features = ["std"] } async-trait = "0.1" bytes = "1.1" chrono = "0.4" futures = "0.3" hex = "0.4" -mime_guess = {version = "2.0", default_features = false} -poem = {version = "^1.3.30", features = ["cookie", "session", "anyhow", "websocket"]} -poem-openapi = {version = "^2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]} -russh-keys = {version = "0.22.0-beta.3", features = ["openssl"]} +mime_guess = { version = "2.0", default_features = false } +poem = { version = "^1.3.30", features = [ + "cookie", + "session", + "anyhow", + "websocket", +] } +poem-openapi = { version = "^2.0.4", features = [ + "swagger-ui", + "chrono", + "uuid", + "static-files", +] } +russh-keys = { version = "0.22.0-beta.3", features = ["openssl"] } rust-embed = "6.3" -sea-orm = {version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false} +sea-orm = { version = "^0.9", features = [ + "sqlx-sqlite", + "runtime-tokio-native-tls", + "macros", +], default-features = false } serde = "1.0" serde_json = "1.0" thiserror = "1.0" -tokio = {version = "1.19", features = ["tracing"]} +tokio = {version = "1.20", features = ["tracing"]} tracing = "0.1" -uuid = {version = "1.0", features = ["v4", "serde"]} -warpgate-common = {version = "*", path = "../warpgate-common"} -warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"} -warpgate-protocol-ssh = {version = "*", path = "../warpgate-protocol-ssh"} -warpgate-web = {version = "*", path = "../warpgate-web"} +uuid = { version = "1.0", features = ["v4", "serde"] } +warpgate-common = { version = "*", path = "../warpgate-common" } +warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" } +warpgate-protocol-ssh = { version = "*", path = "../warpgate-protocol-ssh" } +warpgate-web = { version = "*", path = "../warpgate-web" } +regex = "1.5" diff --git a/warpgate-admin/src/main.rs b/warpgate-admin/src/main.rs index 565eb72..31cb32d 100644 --- a/warpgate-admin/src/main.rs +++ b/warpgate-admin/src/main.rs @@ -1,10 +1,18 @@ #![feature(type_alias_impl_trait, let_else, try_blocks)] mod api; use poem_openapi::OpenApiService; +use regex::Regex; pub fn main() { let api_service = OpenApiService::new(api::get(), "Warpgate Web Admin", env!("CARGO_PKG_VERSION")) .server("/@warpgate/admin/api"); - println!("{}", api_service.spec()); + + let spec = api_service.spec(); + let re = Regex::new(r"TargetOptions\[(?P\w+)\]").unwrap(); + let spec = re.replace_all(&spec, "TargetOptions$name"); + let re = Regex::new(r"PaginatedResponse<(?P\w+)>").unwrap(); + let spec = re.replace_all(&spec, "Paginated$name"); + + println!("{}", spec); } diff --git a/warpgate-common/Cargo.toml b/warpgate-common/Cargo.toml index a8b76b3..636fd1f 100644 --- a/warpgate-common/Cargo.toml +++ b/warpgate-common/Cargo.toml @@ -16,7 +16,8 @@ lazy_static = "1.4" once_cell = "1.10" packet = "0.1" password-hash = "0.4" -poem-openapi = {version = "^2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]} +poem = "^1.3.30" +poem-openapi = {version = "2.0.4", features = ["swagger-ui", "chrono", "uuid", "static-files"]} rand = "0.8" rand_chacha = "0.3" rand_core = {version = "0.6", features = ["std"]} @@ -24,7 +25,7 @@ sea-orm = {version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tl serde = "1.0" serde_json = "1.0" thiserror = "1.0" -tokio = {version = "1.19", features = ["tracing"]} +tokio = {version = "1.20", features = ["tracing"]} totp-rs = {version = "2.0", features = ["otpauth"]} tracing = "0.1" tracing-core = "0.1" diff --git a/warpgate-common/src/auth.rs b/warpgate-common/src/auth.rs index d13e188..270eb8f 100644 --- a/warpgate-common/src/auth.rs +++ b/warpgate-common/src/auth.rs @@ -20,7 +20,9 @@ impl From<&String> for AuthSelector { return AuthSelector::Ticket { secret }; } - let mut parts = selector.splitn(2, ':'); + let separator = if selector.contains('#') { '#' } else { ':' }; + + let mut parts = selector.splitn(2, separator); let username = parts.next().unwrap_or("").to_string(); let target_name = parts.next().unwrap_or("").to_string(); AuthSelector::User { diff --git a/warpgate-common/src/config.rs b/warpgate-common/src/config.rs index 19a15b5..bc524d7 100644 --- a/warpgate-common/src/config.rs +++ b/warpgate-common/src/config.rs @@ -1,12 +1,13 @@ use std::collections::HashMap; +use std::net::ToSocketAddrs; use std::path::PathBuf; use std::time::Duration; -use poem_openapi::{Object, Union}; +use poem_openapi::{Enum, Object, Union}; use serde::{Deserialize, Serialize}; use crate::helpers::otp::OtpSecretKey; -use crate::Secret; +use crate::{ListenEndpoint, Secret}; const fn _default_true() -> bool { true @@ -16,10 +17,14 @@ const fn _default_false() -> bool { false } -const fn _default_port() -> u16 { +const fn _default_ssh_port() -> u16 { 22 } +const fn _default_mysql_port() -> u16 { + 3306 +} + #[inline] fn _default_username() -> String { "root".to_owned() @@ -41,8 +46,13 @@ fn _default_database_url() -> Secret { } #[inline] -fn _default_http_listen() -> String { - "0.0.0.0:8888".to_owned() +fn _default_http_listen() -> ListenEndpoint { + ListenEndpoint("0.0.0.0:8888".to_socket_addrs().unwrap().next().unwrap()) +} + +#[inline] +fn _default_mysql_listen() -> ListenEndpoint { + ListenEndpoint("0.0.0.0:33306".to_socket_addrs().unwrap().next().unwrap()) } #[inline] @@ -58,7 +68,7 @@ fn _default_empty_vec() -> Vec { #[derive(Debug, Deserialize, Serialize, Clone, Object)] pub struct TargetSSHOptions { pub host: String, - #[serde(default = "_default_port")] + #[serde(default = "_default_ssh_port")] pub port: u16, #[serde(default = "_default_username")] pub username: String, @@ -91,6 +101,59 @@ pub struct TargetHTTPOptions { pub headers: Option>, } +#[derive(Debug, Deserialize, Serialize, Clone, Enum, PartialEq, Eq)] +pub enum TlsMode { + Disabled, + Preferred, + Required, +} + +impl Default for TlsMode { + fn default() -> Self { + TlsMode::Preferred + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, Object)] +pub struct Tls { + #[serde(default)] + pub mode: TlsMode, + + #[serde(default)] + pub verify: bool, +} + +#[allow(clippy::derivable_impls)] +impl Default for Tls { + fn default() -> Self { + Self { + mode: TlsMode::default(), + verify: false, + } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, Object)] +pub struct TargetMySqlOptions { + #[serde(default = "_default_empty_string")] + pub host: String, + + #[serde(default = "_default_mysql_port")] + pub port: u16, + + #[serde(default = "_default_username")] + pub username: String, + + #[serde(default)] + pub password: Option, + + #[serde(default)] + pub tls: Tls, + + #[serde(default)] + pub verify_tls: bool, +} + #[derive(Debug, Deserialize, Serialize, Clone, Object, Default)] pub struct TargetWebAdminOptions {} @@ -104,12 +167,14 @@ pub struct Target { } #[derive(Debug, Deserialize, Serialize, Clone, Union)] -#[oai(discriminator_name = "kind")] +#[oai(discriminator_name = "kind", one_of)] pub enum TargetOptions { #[serde(rename = "ssh")] Ssh(TargetSSHOptions), #[serde(rename = "http")] Http(TargetHTTPOptions), + #[serde(rename = "mysql")] + MySql(TargetMySqlOptions), #[serde(rename = "web_admin")] WebAdmin(TargetWebAdminOptions), } @@ -134,6 +199,8 @@ pub struct UserRequireCredentialsPolicy { pub http: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub ssh: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mysql: Option>, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -150,8 +217,8 @@ pub struct Role { pub name: String, } -fn _default_ssh_listen() -> String { - "0.0.0.0:2222".to_owned() +fn _default_ssh_listen() -> ListenEndpoint { + ListenEndpoint("0.0.0.0:2222".to_socket_addrs().unwrap().next().unwrap()) } fn _default_ssh_client_key() -> String { @@ -164,8 +231,11 @@ fn _default_ssh_keys_path() -> String { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct SSHConfig { + #[serde(default = "_default_false")] + pub enable: bool, + #[serde(default = "_default_ssh_listen")] - pub listen: String, + pub listen: ListenEndpoint, #[serde(default = "_default_ssh_keys_path")] pub keys: String, @@ -177,6 +247,7 @@ pub struct SSHConfig { impl Default for SSHConfig { fn default() -> Self { SSHConfig { + enable: true, listen: _default_ssh_listen(), keys: _default_ssh_keys_path(), client_key: _default_ssh_client_key(), @@ -186,11 +257,11 @@ impl Default for SSHConfig { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct HTTPConfig { - #[serde(default = "_default_false")] + #[serde(default = "_default_true")] pub enable: bool, #[serde(default = "_default_http_listen")] - pub listen: String, + pub listen: ListenEndpoint, #[serde(default)] pub certificate: String, @@ -210,6 +281,32 @@ impl Default for HTTPConfig { } } +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct MySQLConfig { + #[serde(default = "_default_false")] + pub enable: bool, + + #[serde(default = "_default_mysql_listen")] + pub listen: ListenEndpoint, + + #[serde(default)] + pub certificate: String, + + #[serde(default)] + pub key: String, +} + +impl Default for MySQLConfig { + fn default() -> Self { + MySQLConfig { + enable: true, + listen: _default_http_listen(), + certificate: "".to_owned(), + key: "".to_owned(), + } + } +} + #[derive(Debug, Deserialize, Serialize, Clone)] pub struct RecordingsConfig { #[serde(default = "_default_false")] @@ -267,6 +364,9 @@ pub struct WarpgateConfigStore { #[serde(default)] pub http: HTTPConfig, + #[serde(default)] + pub mysql: MySQLConfig, + #[serde(default)] pub log: LogConfig, } @@ -282,6 +382,7 @@ impl Default for WarpgateConfigStore { database_url: _default_database_url(), ssh: SSHConfig::default(), http: HTTPConfig::default(), + mysql: MySQLConfig::default(), log: LogConfig::default(), } } diff --git a/warpgate-common/src/config_providers/file.rs b/warpgate-common/src/config_providers/file.rs index a29e9db..5531840 100644 --- a/warpgate-common/src/config_providers/file.rs +++ b/warpgate-common/src/config_providers/file.rs @@ -1,7 +1,6 @@ use std::collections::HashSet; use std::sync::Arc; -use anyhow::Result; use async_trait::async_trait; use data_encoding::BASE64_MIME; use sea_orm::ActiveValue::Set; @@ -16,7 +15,7 @@ use crate::helpers::hash::verify_password_hash; use crate::helpers::otp::verify_totp; use crate::{ AuthCredential, AuthResult, ProtocolName, Target, User, UserAuthCredential, UserSnapshot, - WarpgateConfig, + WarpgateConfig, WarpgateError, }; pub struct FileConfigProvider { @@ -46,7 +45,7 @@ fn credential_is_type(c: &UserAuthCredential, k: &str) -> bool { #[async_trait] impl ConfigProvider for FileConfigProvider { - async fn list_users(&mut self) -> Result> { + async fn list_users(&mut self) -> Result, WarpgateError> { Ok(self .config .lock() @@ -58,7 +57,7 @@ impl ConfigProvider for FileConfigProvider { .collect::>()) } - async fn list_targets(&mut self) -> Result> { + async fn list_targets(&mut self) -> Result, WarpgateError> { Ok(self .config .lock() @@ -75,7 +74,7 @@ impl ConfigProvider for FileConfigProvider { username: &str, credentials: &[AuthCredential], protocol: ProtocolName, - ) -> Result { + ) -> Result { if credentials.is_empty() { return Ok(AuthResult::Rejected); } @@ -167,6 +166,7 @@ impl ConfigProvider for FileConfigProvider { let required_kinds = match protocol { "SSH" => &policy.ssh, "HTTP" => &policy.http, + "MySQL" => &policy.mysql, _ => { error!(%protocol, "Unkown protocol"); return Ok(AuthResult::Rejected); @@ -204,7 +204,11 @@ impl ConfigProvider for FileConfigProvider { }) } - async fn authorize_target(&mut self, username: &str, target_name: &str) -> Result { + async fn authorize_target( + &mut self, + username: &str, + target_name: &str, + ) -> Result { let config = self.config.lock().await; let user = config .store @@ -220,7 +224,7 @@ impl ConfigProvider for FileConfigProvider { }; let Some(target) = target else { - error!("Selected target not found: {}", target_name); + warn!("Selected target not found: {}", target_name); return Ok(false); }; @@ -244,11 +248,11 @@ impl ConfigProvider for FileConfigProvider { Ok(intersect) } - async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<()> { + async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<(), WarpgateError> { let db = self.db.lock().await; let ticket = Ticket::Entity::find_by_id(*ticket_id).one(&*db).await?; let Some(ticket) = ticket else { - anyhow::bail!("Ticket not found: {}", ticket_id); + return Err(WarpgateError::InvalidTicket(*ticket_id)); }; if let Some(uses_left) = ticket.uses_left { diff --git a/warpgate-common/src/config_providers/mod.rs b/warpgate-common/src/config_providers/mod.rs index 0d078f6..6ae1212 100644 --- a/warpgate-common/src/config_providers/mod.rs +++ b/warpgate-common/src/config_providers/mod.rs @@ -1,7 +1,6 @@ mod file; use std::sync::Arc; -use anyhow::Result; use async_trait::async_trait; use bytes::Bytes; pub use file::FileConfigProvider; @@ -11,7 +10,7 @@ use tracing::*; use uuid::Uuid; use warpgate_db_entities::Ticket; -use crate::{ProtocolName, Secret, Target, UserSnapshot}; +use crate::{ProtocolName, Secret, Target, UserSnapshot, WarpgateError}; pub enum AuthResult { Accepted { username: String }, @@ -30,27 +29,31 @@ pub enum AuthCredential { #[async_trait] pub trait ConfigProvider { - async fn list_users(&mut self) -> Result>; + async fn list_users(&mut self) -> Result, WarpgateError>; - async fn list_targets(&mut self) -> Result>; + async fn list_targets(&mut self) -> Result, WarpgateError>; async fn authorize( &mut self, username: &str, credentials: &[AuthCredential], protocol: ProtocolName, - ) -> Result; + ) -> Result; - async fn authorize_target(&mut self, username: &str, target: &str) -> Result; + async fn authorize_target( + &mut self, + username: &str, + target: &str, + ) -> Result; - async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<()>; + async fn consume_ticket(&mut self, ticket_id: &Uuid) -> Result<(), WarpgateError>; } //TODO: move this somewhere pub async fn authorize_ticket( db: &Arc>, secret: &Secret, -) -> Result> { +) -> Result, WarpgateError> { let ticket = { let db = db.lock().await; Ticket::Entity::find() diff --git a/warpgate-common/src/error.rs b/warpgate-common/src/error.rs new file mode 100644 index 0000000..f37cc95 --- /dev/null +++ b/warpgate-common/src/error.rs @@ -0,0 +1,26 @@ +use std::error::Error; + +use poem::error::ResponseError; +use uuid::Uuid; + +#[derive(thiserror::Error, Debug)] +pub enum WarpgateError { + #[error("database error: {0}")] + DatabaseError(#[from] sea_orm::DbErr), + #[error("ticket not found: {0}")] + InvalidTicket(Uuid), + #[error(transparent)] + Other(Box), +} + +impl ResponseError for WarpgateError { + fn status(&self) -> poem::http::StatusCode { + poem::http::StatusCode::INTERNAL_SERVER_ERROR + } +} + +impl WarpgateError { + pub fn other(err: E) -> Self { + Self::Other(Box::new(err)) + } +} diff --git a/warpgate-common/src/lib.rs b/warpgate-common/src/lib.rs index 457a418..a6a76e2 100644 --- a/warpgate-common/src/lib.rs +++ b/warpgate-common/src/lib.rs @@ -5,6 +5,7 @@ mod config_providers; pub mod consts; mod data; pub mod db; +mod error; pub mod eventhub; pub mod helpers; pub mod logging; @@ -18,6 +19,7 @@ mod types; pub use config::*; pub use config_providers::*; pub use data::*; +pub use error::WarpgateError; pub use protocols::*; pub use services::*; pub use state::{SessionState, SessionStateInit, State}; diff --git a/warpgate-common/src/protocols/handle.rs b/warpgate-common/src/protocols/handle.rs index 955a2b6..1a3fa93 100644 --- a/warpgate-common/src/protocols/handle.rs +++ b/warpgate-common/src/protocols/handle.rs @@ -1,11 +1,10 @@ use std::sync::Arc; -use anyhow::{Context, Result}; use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter}; use tokio::sync::Mutex; use warpgate_db_entities::Session; -use crate::{SessionId, SessionState, State, Target}; +use crate::{SessionId, SessionState, State, Target, WarpgateError}; pub trait SessionHandle { fn close(&mut self); @@ -41,7 +40,7 @@ impl WarpgateServerHandle { &self.session_state } - pub async fn set_username(&mut self, username: String) -> Result<()> { + pub async fn set_username(&self, username: String) -> Result<(), WarpgateError> { use sea_orm::ActiveValue::Set; { @@ -64,7 +63,7 @@ impl WarpgateServerHandle { Ok(()) } - pub async fn set_target(&self, target: &Target) -> Result<()> { + pub async fn set_target(&self, target: &Target) -> Result<(), WarpgateError> { use sea_orm::ActiveValue::Set; { let mut state = self.session_state.lock().await; @@ -77,7 +76,7 @@ impl WarpgateServerHandle { Session::Entity::update_many() .set(Session::ActiveModel { target_snapshot: Set(Some( - serde_json::to_string(&target).context("Error serializing target")?, + serde_json::to_string(&target).map_err(WarpgateError::other)?, )), ..Default::default() }) diff --git a/warpgate-common/src/types.rs b/warpgate-common/src/types.rs index 11f6463..5164248 100644 --- a/warpgate-common/src/types.rs +++ b/warpgate-common/src/types.rs @@ -1,4 +1,6 @@ use std::fmt::Debug; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::ops::Deref; use bytes::Bytes; use data_encoding::HEXLOWER; @@ -66,3 +68,51 @@ impl Debug for Secret { write!(f, "") } } + +#[derive(Clone)] +pub struct ListenEndpoint(pub SocketAddr); + +impl Deref for ListenEndpoint { + type Target = SocketAddr; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'de> Deserialize<'de> for ListenEndpoint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let v: String = Deserialize::deserialize::(deserializer)?; + let v = v + .to_socket_addrs() + .map_err(|e| { + serde::de::Error::custom(format!( + "failed to resolve {v} into a TCP endpoint: {e:?}" + )) + })? + .next() + .ok_or_else(|| { + return serde::de::Error::custom(format!( + "failed to resolve {v} into a TCP endpoint" + )); + })?; + Ok(Self(v)) + } +} + +impl Serialize for ListenEndpoint { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) + } +} + +impl Debug for ListenEndpoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} diff --git a/warpgate-database-protocols/Cargo.toml b/warpgate-database-protocols/Cargo.toml new file mode 100644 index 0000000..6e34dbc --- /dev/null +++ b/warpgate-database-protocols/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "warpgate-database-protocols" +version = "0.3.0" +description = "Core of SQLx, the rust SQL toolkit. Just the database protocol parts." +license = "MIT OR Apache-2.0" +edition = "2021" +authors = [ + "Ryan Leckey ", + "Austin Bonander ", + "Chloe Ross ", + "Daniel Akhterov ", +] + +[dependencies] +tokio = { version = "1.20", features = ["io-util"] } +bitflags = { version = "1.3", default-features = false } +bytes = "1.1" +futures-core = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false, features = [ + "alloc", + "sink", +] } +memchr = { version = "2.4.1", default-features = false } +thiserror = "1.0" diff --git a/warpgate-database-protocols/README.md b/warpgate-database-protocols/README.md new file mode 100644 index 0000000..735ab70 --- /dev/null +++ b/warpgate-database-protocols/README.md @@ -0,0 +1 @@ +This is an extract from sqlx-core with Encode/Decode impls added for server-side packet flow diff --git a/warpgate-database-protocols/src/error.rs b/warpgate-database-protocols/src/error.rs new file mode 100644 index 0000000..9a47e4c --- /dev/null +++ b/warpgate-database-protocols/src/error.rs @@ -0,0 +1,241 @@ +//! Types for working with errors produced by SQLx. + +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt::Display; +use std::io; +use std::result::Result as StdResult; + +/// A specialized `Result` type for SQLx. +pub type Result = StdResult; + +// Convenience type alias for usage within SQLx. +// Do not make this type public. +pub type BoxDynError = Box; + +/// An unexpected `NULL` was encountered during decoding. +/// +/// Returned from [`Row::get`](crate::row::Row::get) if the value from the database is `NULL`, +/// and you are not decoding into an `Option`. +#[derive(thiserror::Error, Debug)] +#[error("unexpected null; try decoding as an `Option`")] +pub struct UnexpectedNullError; + +/// Represents all the ways a method can fail within SQLx. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error { + /// Error occurred while parsing a connection string. + #[error("error with configuration: {0}")] + Configuration(#[source] BoxDynError), + + /// Error returned from the database. + #[error("error returned from database: {0}")] + Database(#[source] Box), + + /// Error communicating with the database backend. + #[error("error communicating with database: {0}")] + Io(#[from] io::Error), + + /// Error occurred while attempting to establish a TLS connection. + #[error("error occurred while attempting to establish a TLS connection: {0}")] + Tls(#[source] BoxDynError), + + /// Unexpected or invalid data encountered while communicating with the database. + /// + /// This should indicate there is a programming error in a SQLx driver or there + /// is something corrupted with the connection to the database itself. + #[error("encountered unexpected or invalid data: {0}")] + Protocol(String), + + /// No rows returned by a query that expected to return at least one row. + #[error("no rows returned by a query that expected to return at least one row")] + RowNotFound, + + /// Type in query doesn't exist. Likely due to typo or missing user type. + #[error("type named {type_name} not found")] + TypeNotFound { type_name: String }, + + /// Column index was out of bounds. + #[error("column index out of bounds: the len is {len}, but the index is {index}")] + ColumnIndexOutOfBounds { index: usize, len: usize }, + + /// No column found for the given name. + #[error("no column found for name: {0}")] + ColumnNotFound(String), + + /// Error occurred while decoding a value from a specific column. + #[error("error occurred while decoding column {index}: {source}")] + ColumnDecode { + index: String, + + #[source] + source: BoxDynError, + }, + + /// Error occurred while decoding a value. + #[error("error occurred while decoding: {0}")] + Decode(#[source] BoxDynError), + + /// A [`Pool::acquire`] timed out due to connections not becoming available or + /// because another task encountered too many errors while trying to open a new connection. + /// + /// [`Pool::acquire`]: crate::pool::Pool::acquire + #[error("pool timed out while waiting for an open connection")] + PoolTimedOut, + + /// [`Pool::close`] was called while we were waiting in [`Pool::acquire`]. + /// + /// [`Pool::acquire`]: crate::pool::Pool::acquire + /// [`Pool::close`]: crate::pool::Pool::close + #[error("attempted to acquire a connection on a closed pool")] + PoolClosed, + + /// A background worker has crashed. + #[error("attempted to communicate with a crashed background worker")] + WorkerCrashed, + + #[cfg(feature = "migrate")] + #[error("{0}")] + Migrate(#[source] Box), +} + +impl StdError for Box {} + +impl Error { + #[allow(dead_code)] + #[inline] + pub(crate) fn protocol(err: impl Display) -> Self { + Error::Protocol(err.to_string()) + } + + #[allow(dead_code)] + #[inline] + pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self { + Error::Configuration(err.into()) + } +} + +/// An error that was returned from the database. +pub trait DatabaseError: 'static + Send + Sync + StdError { + /// The primary, human-readable error message. + fn message(&self) -> &str; + + /// The (SQLSTATE) code for the error. + fn code(&self) -> Option> { + None + } + + #[doc(hidden)] + fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static); + + #[doc(hidden)] + fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static); + + #[doc(hidden)] + fn into_error(self: Box) -> Box; + + #[doc(hidden)] + fn is_transient_in_connect_phase(&self) -> bool { + false + } + + /// Returns the name of the constraint that triggered the error, if applicable. + /// If the error was caused by a conflict of a unique index, this will be the index name. + /// + /// ### Note + /// Currently only populated by the Postgres driver. + fn constraint(&self) -> Option<&str> { + None + } +} + +impl dyn DatabaseError { + /// Downcast a reference to this generic database error to a specific + /// database error type. + /// + /// # Panics + /// + /// Panics if the database error type is not `E`. This is a deliberate contrast from + /// `Error::downcast_ref` which returns `Option<&E>`. In normal usage, you should know the + /// specific error type. In other cases, use `try_downcast_ref`. + pub fn downcast_ref(&self) -> &E { + self.try_downcast_ref().unwrap_or_else(|| { + panic!( + "downcast to wrong DatabaseError type; original error: {}", + self + ) + }) + } + + /// Downcast this generic database error to a specific database error type. + /// + /// # Panics + /// + /// Panics if the database error type is not `E`. This is a deliberate contrast from + /// `Error::downcast` which returns `Option`. In normal usage, you should know the + /// specific error type. In other cases, use `try_downcast`. + pub fn downcast(self: Box) -> Box { + self.try_downcast().unwrap_or_else(|e| { + panic!( + "downcast to wrong DatabaseError type; original error: {}", + e + ) + }) + } + + /// Downcast a reference to this generic database error to a specific + /// database error type. + #[inline] + pub fn try_downcast_ref(&self) -> Option<&E> { + self.as_error().downcast_ref() + } + + /// Downcast this generic database error to a specific database error type. + #[inline] + pub fn try_downcast(self: Box) -> StdResult, Box> { + if self.as_error().is::() { + Ok(self.into_error().downcast().unwrap()) + } else { + Err(self) + } + } +} + +impl From for Error +where + E: DatabaseError, +{ + #[inline] + fn from(error: E) -> Self { + Error::Database(Box::new(error)) + } +} + +#[cfg(feature = "migrate")] +impl From for Error { + #[inline] + fn from(error: crate::migrate::MigrateError) -> Self { + Error::Migrate(Box::new(error)) + } +} + +#[cfg(feature = "_tls-native-tls")] +impl From for Error { + #[inline] + fn from(error: sqlx_rt::native_tls::Error) -> Self { + Error::Tls(Box::new(error)) + } +} + +// Format an error message as a `Protocol` error +#[macro_export] +macro_rules! err_protocol { + ($expr:expr) => { + $crate::error::Error::Protocol($expr.into()) + }; + + ($fmt:expr, $($arg:tt)*) => { + $crate::error::Error::Protocol(format!($fmt, $($arg)*)) + }; +} diff --git a/warpgate-database-protocols/src/io/buf.rs b/warpgate-database-protocols/src/io/buf.rs new file mode 100644 index 0000000..f73764c --- /dev/null +++ b/warpgate-database-protocols/src/io/buf.rs @@ -0,0 +1,59 @@ +use std::str::from_utf8; + +use bytes::{Buf, Bytes}; +use memchr::memchr; + +use crate::err_protocol; +use crate::error::Error; + +pub trait BufExt: Buf { + // Read a nul-terminated byte sequence + fn get_bytes_nul(&mut self) -> Result; + + // Read a byte sequence of the exact length + fn get_bytes(&mut self, len: usize) -> Bytes; + + // Read a nul-terminated string + fn get_str_nul(&mut self) -> Result; + + // Read a string of the exact length + fn get_str(&mut self, len: usize) -> Result; +} + +impl BufExt for Bytes { + fn get_bytes_nul(&mut self) -> Result { + let nul = + memchr(b'\0', self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?; + + let v = self.slice(0..nul); + + self.advance(nul + 1); + + Ok(v) + } + + fn get_bytes(&mut self, len: usize) -> Bytes { + let v = self.slice(..len); + self.advance(len); + + v + } + + fn get_str_nul(&mut self) -> Result { + self.get_bytes_nul().and_then(|bytes| { + from_utf8(&*bytes) + .map(ToOwned::to_owned) + .map_err(|err| err_protocol!("{}", err)) + }) + } + + fn get_str(&mut self, len: usize) -> Result { + let v = from_utf8(&self[..len]) + .map_err(|err| err_protocol!("{}", err)) + .map(ToOwned::to_owned)?; + + self.advance(len); + + Ok(v) + } +} diff --git a/warpgate-database-protocols/src/io/buf_mut.rs b/warpgate-database-protocols/src/io/buf_mut.rs new file mode 100644 index 0000000..565d850 --- /dev/null +++ b/warpgate-database-protocols/src/io/buf_mut.rs @@ -0,0 +1,12 @@ +use bytes::BufMut; + +pub trait BufMutExt: BufMut { + fn put_str_nul(&mut self, s: &str); +} + +impl BufMutExt for Vec { + fn put_str_nul(&mut self, s: &str) { + self.extend(s.as_bytes()); + self.push(0); + } +} diff --git a/warpgate-database-protocols/src/io/buf_stream.rs b/warpgate-database-protocols/src/io/buf_stream.rs new file mode 100644 index 0000000..6711e3b --- /dev/null +++ b/warpgate-database-protocols/src/io/buf_stream.rs @@ -0,0 +1,166 @@ +#![allow(dead_code)] + +use std::io; +use std::io::Cursor; +use std::ops::{Deref, DerefMut}; + +use bytes::BytesMut; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +use crate::error::Error; +use crate::io::decode::Decode; +use crate::io::encode::Encode; +use crate::io::write_and_flush::WriteAndFlush; + +pub struct BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) stream: S, + + // writes with `write` to the underlying stream are buffered + // this can be flushed with `flush` + pub(crate) wbuf: Vec, + + // we read into the read buffer using 100% safe code + rbuf: BytesMut, +} + +impl BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(stream: S) -> Self { + Self { + stream, + wbuf: Vec::with_capacity(512), + rbuf: BytesMut::with_capacity(4096), + } + } + + pub fn write<'en, T>(&mut self, value: T) + where + T: Encode<'en, ()>, + { + self.write_with(value, ()) + } + + pub fn write_with<'en, T, C>(&mut self, value: T, context: C) + where + T: Encode<'en, C>, + { + value.encode_with(&mut self.wbuf, context); + } + + pub fn flush(&mut self) -> WriteAndFlush<'_, S> { + WriteAndFlush { + stream: &mut self.stream, + buf: Cursor::new(&mut self.wbuf), + } + } + + pub async fn read<'de, T>(&mut self, cnt: usize) -> Result + where + T: Decode<'de, ()>, + { + self.read_with(cnt, ()).await + } + + pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.read_raw(cnt).await?.freeze(), context) + } + + pub async fn read_raw(&mut self, cnt: usize) -> Result { + read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?; + let buf = self.rbuf.split_to(cnt); + + Ok(buf) + } + + pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> { + read_raw_into(&mut self.stream, buf, cnt).await + } + + pub fn take(self) -> S { + self.stream + } +} + +impl Deref for BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +// Holds a buffer which has been temporarily extended, so that +// we can read into it. Automatically shrinks the buffer back +// down if the read is cancelled. +struct BufTruncator<'a> { + buf: &'a mut BytesMut, + filled_len: usize, +} + +impl<'a> BufTruncator<'a> { + fn new(buf: &'a mut BytesMut) -> Self { + let filled_len = buf.len(); + Self { buf, filled_len } + } + fn reserve(&mut self, space: usize) { + self.buf.resize(self.filled_len + space, 0); + } + async fn read(&mut self, stream: &mut S) -> Result { + let n = stream.read(&mut self.buf[self.filled_len..]).await?; + self.filled_len += n; + Ok(n) + } + fn is_full(&self) -> bool { + self.filled_len >= self.buf.len() + } +} + +impl Drop for BufTruncator<'_> { + fn drop(&mut self) { + self.buf.truncate(self.filled_len); + } +} + +async fn read_raw_into( + stream: &mut S, + buf: &mut BytesMut, + cnt: usize, +) -> Result<(), Error> { + let mut buf = BufTruncator::new(buf); + buf.reserve(cnt); + + while !buf.is_full() { + let n = buf.read(stream).await?; + + if n == 0 { + // a zero read when we had space in the read buffer + // should be treated as an EOF + + // and an unexpected EOF means the server told us to go away + + return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); + } + } + + Ok(()) +} diff --git a/warpgate-database-protocols/src/io/decode.rs b/warpgate-database-protocols/src/io/decode.rs new file mode 100644 index 0000000..2f39712 --- /dev/null +++ b/warpgate-database-protocols/src/io/decode.rs @@ -0,0 +1,29 @@ +use bytes::Bytes; + +use crate::error::Error; + +pub trait Decode<'de, Context = ()> +where + Self: Sized, +{ + fn decode(buf: Bytes) -> Result + where + Self: Decode<'de, ()>, + { + Self::decode_with(buf, ()) + } + + fn decode_with(buf: Bytes, context: Context) -> Result; +} + +impl Decode<'_> for Bytes { + fn decode_with(buf: Bytes, _: ()) -> Result { + Ok(buf) + } +} + +impl Decode<'_> for () { + fn decode_with(_: Bytes, _: ()) -> Result<(), Error> { + Ok(()) + } +} diff --git a/warpgate-database-protocols/src/io/encode.rs b/warpgate-database-protocols/src/io/encode.rs new file mode 100644 index 0000000..a417ef9 --- /dev/null +++ b/warpgate-database-protocols/src/io/encode.rs @@ -0,0 +1,16 @@ +pub trait Encode<'en, Context = ()> { + fn encode(&self, buf: &mut Vec) + where + Self: Encode<'en, ()>, + { + self.encode_with(buf, ()); + } + + fn encode_with(&self, buf: &mut Vec, context: Context); +} + +impl<'en, C> Encode<'en, C> for &'_ [u8] { + fn encode_with(&self, buf: &mut Vec, _: C) { + buf.extend_from_slice(self); + } +} diff --git a/warpgate-database-protocols/src/io/mod.rs b/warpgate-database-protocols/src/io/mod.rs new file mode 100644 index 0000000..f994965 --- /dev/null +++ b/warpgate-database-protocols/src/io/mod.rs @@ -0,0 +1,12 @@ +mod buf; +mod buf_mut; +mod buf_stream; +mod decode; +mod encode; +mod write_and_flush; + +pub use buf::BufExt; +pub use buf_mut::BufMutExt; +pub use buf_stream::BufStream; +pub use decode::Decode; +pub use encode::Encode; diff --git a/warpgate-database-protocols/src/io/write_and_flush.rs b/warpgate-database-protocols/src/io/write_and_flush.rs new file mode 100644 index 0000000..8d37d34 --- /dev/null +++ b/warpgate-database-protocols/src/io/write_and_flush.rs @@ -0,0 +1,47 @@ +use std::io::{BufRead, Cursor}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_core::Future; +use futures_util::ready; +use tokio::io::AsyncWrite; + +use crate::error::Error; + +// Atomic operation that writes the full buffer to the stream, flushes the stream, and then +// clears the buffer (even if either of the two previous operations failed). +pub struct WriteAndFlush<'a, S> { + pub(super) stream: &'a mut S, + pub(super) buf: Cursor<&'a mut Vec>, +} + +impl Future for WriteAndFlush<'_, S> { + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Self { + ref mut stream, + ref mut buf, + } = *self; + + loop { + let read = buf.fill_buf()?; + + if !read.is_empty() { + let written = ready!(Pin::new(&mut *stream).poll_write(cx, read)?); + buf.consume(written); + } else { + break; + } + } + + Pin::new(stream).poll_flush(cx).map_err(Error::Io) + } +} + +impl<'a, S> Drop for WriteAndFlush<'a, S> { + fn drop(&mut self) { + // clear the buffer regardless of whether the flush succeeded or not + self.buf.get_mut().clear(); + } +} diff --git a/warpgate-database-protocols/src/lib.rs b/warpgate-database-protocols/src/lib.rs new file mode 100644 index 0000000..d75cfe5 --- /dev/null +++ b/warpgate-database-protocols/src/lib.rs @@ -0,0 +1,6 @@ +#![allow(dead_code)] +pub mod io; +pub mod mysql; + +#[macro_use] +pub mod error; diff --git a/warpgate-database-protocols/src/mysql/collation.rs b/warpgate-database-protocols/src/mysql/collation.rs new file mode 100644 index 0000000..0247679 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/collation.rs @@ -0,0 +1,901 @@ +use std::str::FromStr; + +use crate::error::Error; + +#[allow(non_camel_case_types)] +#[derive(Copy, Clone)] +pub(crate) enum CharSet { + armscii8, + ascii, + big5, + binary, + cp1250, + cp1251, + cp1256, + cp1257, + cp850, + cp852, + cp866, + cp932, + dec8, + eucjpms, + euckr, + gb18030, + gb2312, + gbk, + geostd8, + greek, + hebrew, + hp8, + keybcs2, + koi8r, + koi8u, + latin1, + latin2, + latin5, + latin7, + macce, + macroman, + sjis, + swe7, + tis620, + ucs2, + ujis, + utf16, + utf16le, + utf32, + utf8, + utf8mb4, +} + +impl CharSet { + pub(crate) fn as_str(&self) -> &'static str { + match self { + CharSet::armscii8 => "armscii8", + CharSet::ascii => "ascii", + CharSet::big5 => "big5", + CharSet::binary => "binary", + CharSet::cp1250 => "cp1250", + CharSet::cp1251 => "cp1251", + CharSet::cp1256 => "cp1256", + CharSet::cp1257 => "cp1257", + CharSet::cp850 => "cp850", + CharSet::cp852 => "cp852", + CharSet::cp866 => "cp866", + CharSet::cp932 => "cp932", + CharSet::dec8 => "dec8", + CharSet::eucjpms => "eucjpms", + CharSet::euckr => "euckr", + CharSet::gb18030 => "gb18030", + CharSet::gb2312 => "gb2312", + CharSet::gbk => "gbk", + CharSet::geostd8 => "geostd8", + CharSet::greek => "greek", + CharSet::hebrew => "hebrew", + CharSet::hp8 => "hp8", + CharSet::keybcs2 => "keybcs2", + CharSet::koi8r => "koi8r", + CharSet::koi8u => "koi8u", + CharSet::latin1 => "latin1", + CharSet::latin2 => "latin2", + CharSet::latin5 => "latin5", + CharSet::latin7 => "latin7", + CharSet::macce => "macce", + CharSet::macroman => "macroman", + CharSet::sjis => "sjis", + CharSet::swe7 => "swe7", + CharSet::tis620 => "tis620", + CharSet::ucs2 => "ucs2", + CharSet::ujis => "ujis", + CharSet::utf16 => "utf16", + CharSet::utf16le => "utf16le", + CharSet::utf32 => "utf32", + CharSet::utf8 => "utf8", + CharSet::utf8mb4 => "utf8mb4", + } + } + + pub(crate) fn default_collation(&self) -> Collation { + match self { + CharSet::armscii8 => Collation::armscii8_general_ci, + CharSet::ascii => Collation::ascii_general_ci, + CharSet::big5 => Collation::big5_chinese_ci, + CharSet::binary => Collation::binary, + CharSet::cp1250 => Collation::cp1250_general_ci, + CharSet::cp1251 => Collation::cp1251_general_ci, + CharSet::cp1256 => Collation::cp1256_general_ci, + CharSet::cp1257 => Collation::cp1257_general_ci, + CharSet::cp850 => Collation::cp850_general_ci, + CharSet::cp852 => Collation::cp852_general_ci, + CharSet::cp866 => Collation::cp866_general_ci, + CharSet::cp932 => Collation::cp932_japanese_ci, + CharSet::dec8 => Collation::dec8_swedish_ci, + CharSet::eucjpms => Collation::eucjpms_japanese_ci, + CharSet::euckr => Collation::euckr_korean_ci, + CharSet::gb18030 => Collation::gb18030_chinese_ci, + CharSet::gb2312 => Collation::gb2312_chinese_ci, + CharSet::gbk => Collation::gbk_chinese_ci, + CharSet::geostd8 => Collation::geostd8_general_ci, + CharSet::greek => Collation::greek_general_ci, + CharSet::hebrew => Collation::hebrew_general_ci, + CharSet::hp8 => Collation::hp8_english_ci, + CharSet::keybcs2 => Collation::keybcs2_general_ci, + CharSet::koi8r => Collation::koi8r_general_ci, + CharSet::koi8u => Collation::koi8u_general_ci, + CharSet::latin1 => Collation::latin1_swedish_ci, + CharSet::latin2 => Collation::latin2_general_ci, + CharSet::latin5 => Collation::latin5_turkish_ci, + CharSet::latin7 => Collation::latin7_general_ci, + CharSet::macce => Collation::macce_general_ci, + CharSet::macroman => Collation::macroman_general_ci, + CharSet::sjis => Collation::sjis_japanese_ci, + CharSet::swe7 => Collation::swe7_swedish_ci, + CharSet::tis620 => Collation::tis620_thai_ci, + CharSet::ucs2 => Collation::ucs2_general_ci, + CharSet::ujis => Collation::ujis_japanese_ci, + CharSet::utf16 => Collation::utf16_general_ci, + CharSet::utf16le => Collation::utf16le_general_ci, + CharSet::utf32 => Collation::utf32_general_ci, + CharSet::utf8 => Collation::utf8_unicode_ci, + CharSet::utf8mb4 => Collation::utf8mb4_unicode_ci, + } + } +} + +impl FromStr for CharSet { + type Err = Error; + + fn from_str(char_set: &str) -> Result { + Ok(match char_set { + "armscii8" => CharSet::armscii8, + "ascii" => CharSet::ascii, + "big5" => CharSet::big5, + "binary" => CharSet::binary, + "cp1250" => CharSet::cp1250, + "cp1251" => CharSet::cp1251, + "cp1256" => CharSet::cp1256, + "cp1257" => CharSet::cp1257, + "cp850" => CharSet::cp850, + "cp852" => CharSet::cp852, + "cp866" => CharSet::cp866, + "cp932" => CharSet::cp932, + "dec8" => CharSet::dec8, + "eucjpms" => CharSet::eucjpms, + "euckr" => CharSet::euckr, + "gb18030" => CharSet::gb18030, + "gb2312" => CharSet::gb2312, + "gbk" => CharSet::gbk, + "geostd8" => CharSet::geostd8, + "greek" => CharSet::greek, + "hebrew" => CharSet::hebrew, + "hp8" => CharSet::hp8, + "keybcs2" => CharSet::keybcs2, + "koi8r" => CharSet::koi8r, + "koi8u" => CharSet::koi8u, + "latin1" => CharSet::latin1, + "latin2" => CharSet::latin2, + "latin5" => CharSet::latin5, + "latin7" => CharSet::latin7, + "macce" => CharSet::macce, + "macroman" => CharSet::macroman, + "sjis" => CharSet::sjis, + "swe7" => CharSet::swe7, + "tis620" => CharSet::tis620, + "ucs2" => CharSet::ucs2, + "ujis" => CharSet::ujis, + "utf16" => CharSet::utf16, + "utf16le" => CharSet::utf16le, + "utf32" => CharSet::utf32, + "utf8" => CharSet::utf8, + "utf8mb4" => CharSet::utf8mb4, + + _ => { + return Err(Error::Configuration( + format!("unsupported MySQL charset: {}", char_set).into(), + )); + } + }) + } +} + +#[derive(Copy, Clone)] +#[allow(non_camel_case_types)] +#[repr(u8)] +pub(crate) enum Collation { + armscii8_bin = 64, + armscii8_general_ci = 32, + ascii_bin = 65, + ascii_general_ci = 11, + big5_bin = 84, + big5_chinese_ci = 1, + binary = 63, + cp1250_bin = 66, + cp1250_croatian_ci = 44, + cp1250_czech_cs = 34, + cp1250_general_ci = 26, + cp1250_polish_ci = 99, + cp1251_bin = 50, + cp1251_bulgarian_ci = 14, + cp1251_general_ci = 51, + cp1251_general_cs = 52, + cp1251_ukrainian_ci = 23, + cp1256_bin = 67, + cp1256_general_ci = 57, + cp1257_bin = 58, + cp1257_general_ci = 59, + cp1257_lithuanian_ci = 29, + cp850_bin = 80, + cp850_general_ci = 4, + cp852_bin = 81, + cp852_general_ci = 40, + cp866_bin = 68, + cp866_general_ci = 36, + cp932_bin = 96, + cp932_japanese_ci = 95, + dec8_bin = 69, + dec8_swedish_ci = 3, + eucjpms_bin = 98, + eucjpms_japanese_ci = 97, + euckr_bin = 85, + euckr_korean_ci = 19, + gb18030_bin = 249, + gb18030_chinese_ci = 248, + gb18030_unicode_520_ci = 250, + gb2312_bin = 86, + gb2312_chinese_ci = 24, + gbk_bin = 87, + gbk_chinese_ci = 28, + geostd8_bin = 93, + geostd8_general_ci = 92, + greek_bin = 70, + greek_general_ci = 25, + hebrew_bin = 71, + hebrew_general_ci = 16, + hp8_bin = 72, + hp8_english_ci = 6, + keybcs2_bin = 73, + keybcs2_general_ci = 37, + koi8r_bin = 74, + koi8r_general_ci = 7, + koi8u_bin = 75, + koi8u_general_ci = 22, + latin1_bin = 47, + latin1_danish_ci = 15, + latin1_general_ci = 48, + latin1_general_cs = 49, + latin1_german1_ci = 5, + latin1_german2_ci = 31, + latin1_spanish_ci = 94, + latin1_swedish_ci = 8, + latin2_bin = 77, + latin2_croatian_ci = 27, + latin2_czech_cs = 2, + latin2_general_ci = 9, + latin2_hungarian_ci = 21, + latin5_bin = 78, + latin5_turkish_ci = 30, + latin7_bin = 79, + latin7_estonian_cs = 20, + latin7_general_ci = 41, + latin7_general_cs = 42, + macce_bin = 43, + macce_general_ci = 38, + macroman_bin = 53, + macroman_general_ci = 39, + sjis_bin = 88, + sjis_japanese_ci = 13, + swe7_bin = 82, + swe7_swedish_ci = 10, + tis620_bin = 89, + tis620_thai_ci = 18, + ucs2_bin = 90, + ucs2_croatian_ci = 149, + ucs2_czech_ci = 138, + ucs2_danish_ci = 139, + ucs2_esperanto_ci = 145, + ucs2_estonian_ci = 134, + ucs2_general_ci = 35, + ucs2_general_mysql500_ci = 159, + ucs2_german2_ci = 148, + ucs2_hungarian_ci = 146, + ucs2_icelandic_ci = 129, + ucs2_latvian_ci = 130, + ucs2_lithuanian_ci = 140, + ucs2_persian_ci = 144, + ucs2_polish_ci = 133, + ucs2_roman_ci = 143, + ucs2_romanian_ci = 131, + ucs2_sinhala_ci = 147, + ucs2_slovak_ci = 141, + ucs2_slovenian_ci = 132, + ucs2_spanish_ci = 135, + ucs2_spanish2_ci = 142, + ucs2_swedish_ci = 136, + ucs2_turkish_ci = 137, + ucs2_unicode_520_ci = 150, + ucs2_unicode_ci = 128, + ucs2_vietnamese_ci = 151, + ujis_bin = 91, + ujis_japanese_ci = 12, + utf16_bin = 55, + utf16_croatian_ci = 122, + utf16_czech_ci = 111, + utf16_danish_ci = 112, + utf16_esperanto_ci = 118, + utf16_estonian_ci = 107, + utf16_general_ci = 54, + utf16_german2_ci = 121, + utf16_hungarian_ci = 119, + utf16_icelandic_ci = 102, + utf16_latvian_ci = 103, + utf16_lithuanian_ci = 113, + utf16_persian_ci = 117, + utf16_polish_ci = 106, + utf16_roman_ci = 116, + utf16_romanian_ci = 104, + utf16_sinhala_ci = 120, + utf16_slovak_ci = 114, + utf16_slovenian_ci = 105, + utf16_spanish_ci = 108, + utf16_spanish2_ci = 115, + utf16_swedish_ci = 109, + utf16_turkish_ci = 110, + utf16_unicode_520_ci = 123, + utf16_unicode_ci = 101, + utf16_vietnamese_ci = 124, + utf16le_bin = 62, + utf16le_general_ci = 56, + utf32_bin = 61, + utf32_croatian_ci = 181, + utf32_czech_ci = 170, + utf32_danish_ci = 171, + utf32_esperanto_ci = 177, + utf32_estonian_ci = 166, + utf32_general_ci = 60, + utf32_german2_ci = 180, + utf32_hungarian_ci = 178, + utf32_icelandic_ci = 161, + utf32_latvian_ci = 162, + utf32_lithuanian_ci = 172, + utf32_persian_ci = 176, + utf32_polish_ci = 165, + utf32_roman_ci = 175, + utf32_romanian_ci = 163, + utf32_sinhala_ci = 179, + utf32_slovak_ci = 173, + utf32_slovenian_ci = 164, + utf32_spanish_ci = 167, + utf32_spanish2_ci = 174, + utf32_swedish_ci = 168, + utf32_turkish_ci = 169, + utf32_unicode_520_ci = 182, + utf32_unicode_ci = 160, + utf32_vietnamese_ci = 183, + utf8_bin = 83, + utf8_croatian_ci = 213, + utf8_czech_ci = 202, + utf8_danish_ci = 203, + utf8_esperanto_ci = 209, + utf8_estonian_ci = 198, + utf8_general_ci = 33, + utf8_general_mysql500_ci = 223, + utf8_german2_ci = 212, + utf8_hungarian_ci = 210, + utf8_icelandic_ci = 193, + utf8_latvian_ci = 194, + utf8_lithuanian_ci = 204, + utf8_persian_ci = 208, + utf8_polish_ci = 197, + utf8_roman_ci = 207, + utf8_romanian_ci = 195, + utf8_sinhala_ci = 211, + utf8_slovak_ci = 205, + utf8_slovenian_ci = 196, + utf8_spanish_ci = 199, + utf8_spanish2_ci = 206, + utf8_swedish_ci = 200, + utf8_tolower_ci = 76, + utf8_turkish_ci = 201, + utf8_unicode_520_ci = 214, + utf8_unicode_ci = 192, + utf8_vietnamese_ci = 215, + utf8mb4_0900_ai_ci = 255, + utf8mb4_bin = 46, + utf8mb4_croatian_ci = 245, + utf8mb4_czech_ci = 234, + utf8mb4_danish_ci = 235, + utf8mb4_esperanto_ci = 241, + utf8mb4_estonian_ci = 230, + utf8mb4_general_ci = 45, + utf8mb4_german2_ci = 244, + utf8mb4_hungarian_ci = 242, + utf8mb4_icelandic_ci = 225, + utf8mb4_latvian_ci = 226, + utf8mb4_lithuanian_ci = 236, + utf8mb4_persian_ci = 240, + utf8mb4_polish_ci = 229, + utf8mb4_roman_ci = 239, + utf8mb4_romanian_ci = 227, + utf8mb4_sinhala_ci = 243, + utf8mb4_slovak_ci = 237, + utf8mb4_slovenian_ci = 228, + utf8mb4_spanish_ci = 231, + utf8mb4_spanish2_ci = 238, + utf8mb4_swedish_ci = 232, + utf8mb4_turkish_ci = 233, + utf8mb4_unicode_520_ci = 246, + utf8mb4_unicode_ci = 224, + utf8mb4_vietnamese_ci = 247, +} + +impl Collation { + pub(crate) fn as_str(&self) -> &'static str { + match self { + Collation::armscii8_bin => "armscii8_bin", + Collation::armscii8_general_ci => "armscii8_general_ci", + Collation::ascii_bin => "ascii_bin", + Collation::ascii_general_ci => "ascii_general_ci", + Collation::big5_bin => "big5_bin", + Collation::big5_chinese_ci => "big5_chinese_ci", + Collation::binary => "binary", + Collation::cp1250_bin => "cp1250_bin", + Collation::cp1250_croatian_ci => "cp1250_croatian_ci", + Collation::cp1250_czech_cs => "cp1250_czech_cs", + Collation::cp1250_general_ci => "cp1250_general_ci", + Collation::cp1250_polish_ci => "cp1250_polish_ci", + Collation::cp1251_bin => "cp1251_bin", + Collation::cp1251_bulgarian_ci => "cp1251_bulgarian_ci", + Collation::cp1251_general_ci => "cp1251_general_ci", + Collation::cp1251_general_cs => "cp1251_general_cs", + Collation::cp1251_ukrainian_ci => "cp1251_ukrainian_ci", + Collation::cp1256_bin => "cp1256_bin", + Collation::cp1256_general_ci => "cp1256_general_ci", + Collation::cp1257_bin => "cp1257_bin", + Collation::cp1257_general_ci => "cp1257_general_ci", + Collation::cp1257_lithuanian_ci => "cp1257_lithuanian_ci", + Collation::cp850_bin => "cp850_bin", + Collation::cp850_general_ci => "cp850_general_ci", + Collation::cp852_bin => "cp852_bin", + Collation::cp852_general_ci => "cp852_general_ci", + Collation::cp866_bin => "cp866_bin", + Collation::cp866_general_ci => "cp866_general_ci", + Collation::cp932_bin => "cp932_bin", + Collation::cp932_japanese_ci => "cp932_japanese_ci", + Collation::dec8_bin => "dec8_bin", + Collation::dec8_swedish_ci => "dec8_swedish_ci", + Collation::eucjpms_bin => "eucjpms_bin", + Collation::eucjpms_japanese_ci => "eucjpms_japanese_ci", + Collation::euckr_bin => "euckr_bin", + Collation::euckr_korean_ci => "euckr_korean_ci", + Collation::gb18030_bin => "gb18030_bin", + Collation::gb18030_chinese_ci => "gb18030_chinese_ci", + Collation::gb18030_unicode_520_ci => "gb18030_unicode_520_ci", + Collation::gb2312_bin => "gb2312_bin", + Collation::gb2312_chinese_ci => "gb2312_chinese_ci", + Collation::gbk_bin => "gbk_bin", + Collation::gbk_chinese_ci => "gbk_chinese_ci", + Collation::geostd8_bin => "geostd8_bin", + Collation::geostd8_general_ci => "geostd8_general_ci", + Collation::greek_bin => "greek_bin", + Collation::greek_general_ci => "greek_general_ci", + Collation::hebrew_bin => "hebrew_bin", + Collation::hebrew_general_ci => "hebrew_general_ci", + Collation::hp8_bin => "hp8_bin", + Collation::hp8_english_ci => "hp8_english_ci", + Collation::keybcs2_bin => "keybcs2_bin", + Collation::keybcs2_general_ci => "keybcs2_general_ci", + Collation::koi8r_bin => "koi8r_bin", + Collation::koi8r_general_ci => "koi8r_general_ci", + Collation::koi8u_bin => "koi8u_bin", + Collation::koi8u_general_ci => "koi8u_general_ci", + Collation::latin1_bin => "latin1_bin", + Collation::latin1_danish_ci => "latin1_danish_ci", + Collation::latin1_general_ci => "latin1_general_ci", + Collation::latin1_general_cs => "latin1_general_cs", + Collation::latin1_german1_ci => "latin1_german1_ci", + Collation::latin1_german2_ci => "latin1_german2_ci", + Collation::latin1_spanish_ci => "latin1_spanish_ci", + Collation::latin1_swedish_ci => "latin1_swedish_ci", + Collation::latin2_bin => "latin2_bin", + Collation::latin2_croatian_ci => "latin2_croatian_ci", + Collation::latin2_czech_cs => "latin2_czech_cs", + Collation::latin2_general_ci => "latin2_general_ci", + Collation::latin2_hungarian_ci => "latin2_hungarian_ci", + Collation::latin5_bin => "latin5_bin", + Collation::latin5_turkish_ci => "latin5_turkish_ci", + Collation::latin7_bin => "latin7_bin", + Collation::latin7_estonian_cs => "latin7_estonian_cs", + Collation::latin7_general_ci => "latin7_general_ci", + Collation::latin7_general_cs => "latin7_general_cs", + Collation::macce_bin => "macce_bin", + Collation::macce_general_ci => "macce_general_ci", + Collation::macroman_bin => "macroman_bin", + Collation::macroman_general_ci => "macroman_general_ci", + Collation::sjis_bin => "sjis_bin", + Collation::sjis_japanese_ci => "sjis_japanese_ci", + Collation::swe7_bin => "swe7_bin", + Collation::swe7_swedish_ci => "swe7_swedish_ci", + Collation::tis620_bin => "tis620_bin", + Collation::tis620_thai_ci => "tis620_thai_ci", + Collation::ucs2_bin => "ucs2_bin", + Collation::ucs2_croatian_ci => "ucs2_croatian_ci", + Collation::ucs2_czech_ci => "ucs2_czech_ci", + Collation::ucs2_danish_ci => "ucs2_danish_ci", + Collation::ucs2_esperanto_ci => "ucs2_esperanto_ci", + Collation::ucs2_estonian_ci => "ucs2_estonian_ci", + Collation::ucs2_general_ci => "ucs2_general_ci", + Collation::ucs2_general_mysql500_ci => "ucs2_general_mysql500_ci", + Collation::ucs2_german2_ci => "ucs2_german2_ci", + Collation::ucs2_hungarian_ci => "ucs2_hungarian_ci", + Collation::ucs2_icelandic_ci => "ucs2_icelandic_ci", + Collation::ucs2_latvian_ci => "ucs2_latvian_ci", + Collation::ucs2_lithuanian_ci => "ucs2_lithuanian_ci", + Collation::ucs2_persian_ci => "ucs2_persian_ci", + Collation::ucs2_polish_ci => "ucs2_polish_ci", + Collation::ucs2_roman_ci => "ucs2_roman_ci", + Collation::ucs2_romanian_ci => "ucs2_romanian_ci", + Collation::ucs2_sinhala_ci => "ucs2_sinhala_ci", + Collation::ucs2_slovak_ci => "ucs2_slovak_ci", + Collation::ucs2_slovenian_ci => "ucs2_slovenian_ci", + Collation::ucs2_spanish_ci => "ucs2_spanish_ci", + Collation::ucs2_spanish2_ci => "ucs2_spanish2_ci", + Collation::ucs2_swedish_ci => "ucs2_swedish_ci", + Collation::ucs2_turkish_ci => "ucs2_turkish_ci", + Collation::ucs2_unicode_520_ci => "ucs2_unicode_520_ci", + Collation::ucs2_unicode_ci => "ucs2_unicode_ci", + Collation::ucs2_vietnamese_ci => "ucs2_vietnamese_ci", + Collation::ujis_bin => "ujis_bin", + Collation::ujis_japanese_ci => "ujis_japanese_ci", + Collation::utf16_bin => "utf16_bin", + Collation::utf16_croatian_ci => "utf16_croatian_ci", + Collation::utf16_czech_ci => "utf16_czech_ci", + Collation::utf16_danish_ci => "utf16_danish_ci", + Collation::utf16_esperanto_ci => "utf16_esperanto_ci", + Collation::utf16_estonian_ci => "utf16_estonian_ci", + Collation::utf16_general_ci => "utf16_general_ci", + Collation::utf16_german2_ci => "utf16_german2_ci", + Collation::utf16_hungarian_ci => "utf16_hungarian_ci", + Collation::utf16_icelandic_ci => "utf16_icelandic_ci", + Collation::utf16_latvian_ci => "utf16_latvian_ci", + Collation::utf16_lithuanian_ci => "utf16_lithuanian_ci", + Collation::utf16_persian_ci => "utf16_persian_ci", + Collation::utf16_polish_ci => "utf16_polish_ci", + Collation::utf16_roman_ci => "utf16_roman_ci", + Collation::utf16_romanian_ci => "utf16_romanian_ci", + Collation::utf16_sinhala_ci => "utf16_sinhala_ci", + Collation::utf16_slovak_ci => "utf16_slovak_ci", + Collation::utf16_slovenian_ci => "utf16_slovenian_ci", + Collation::utf16_spanish_ci => "utf16_spanish_ci", + Collation::utf16_spanish2_ci => "utf16_spanish2_ci", + Collation::utf16_swedish_ci => "utf16_swedish_ci", + Collation::utf16_turkish_ci => "utf16_turkish_ci", + Collation::utf16_unicode_520_ci => "utf16_unicode_520_ci", + Collation::utf16_unicode_ci => "utf16_unicode_ci", + Collation::utf16_vietnamese_ci => "utf16_vietnamese_ci", + Collation::utf16le_bin => "utf16le_bin", + Collation::utf16le_general_ci => "utf16le_general_ci", + Collation::utf32_bin => "utf32_bin", + Collation::utf32_croatian_ci => "utf32_croatian_ci", + Collation::utf32_czech_ci => "utf32_czech_ci", + Collation::utf32_danish_ci => "utf32_danish_ci", + Collation::utf32_esperanto_ci => "utf32_esperanto_ci", + Collation::utf32_estonian_ci => "utf32_estonian_ci", + Collation::utf32_general_ci => "utf32_general_ci", + Collation::utf32_german2_ci => "utf32_german2_ci", + Collation::utf32_hungarian_ci => "utf32_hungarian_ci", + Collation::utf32_icelandic_ci => "utf32_icelandic_ci", + Collation::utf32_latvian_ci => "utf32_latvian_ci", + Collation::utf32_lithuanian_ci => "utf32_lithuanian_ci", + Collation::utf32_persian_ci => "utf32_persian_ci", + Collation::utf32_polish_ci => "utf32_polish_ci", + Collation::utf32_roman_ci => "utf32_roman_ci", + Collation::utf32_romanian_ci => "utf32_romanian_ci", + Collation::utf32_sinhala_ci => "utf32_sinhala_ci", + Collation::utf32_slovak_ci => "utf32_slovak_ci", + Collation::utf32_slovenian_ci => "utf32_slovenian_ci", + Collation::utf32_spanish_ci => "utf32_spanish_ci", + Collation::utf32_spanish2_ci => "utf32_spanish2_ci", + Collation::utf32_swedish_ci => "utf32_swedish_ci", + Collation::utf32_turkish_ci => "utf32_turkish_ci", + Collation::utf32_unicode_520_ci => "utf32_unicode_520_ci", + Collation::utf32_unicode_ci => "utf32_unicode_ci", + Collation::utf32_vietnamese_ci => "utf32_vietnamese_ci", + Collation::utf8_bin => "utf8_bin", + Collation::utf8_croatian_ci => "utf8_croatian_ci", + Collation::utf8_czech_ci => "utf8_czech_ci", + Collation::utf8_danish_ci => "utf8_danish_ci", + Collation::utf8_esperanto_ci => "utf8_esperanto_ci", + Collation::utf8_estonian_ci => "utf8_estonian_ci", + Collation::utf8_general_ci => "utf8_general_ci", + Collation::utf8_general_mysql500_ci => "utf8_general_mysql500_ci", + Collation::utf8_german2_ci => "utf8_german2_ci", + Collation::utf8_hungarian_ci => "utf8_hungarian_ci", + Collation::utf8_icelandic_ci => "utf8_icelandic_ci", + Collation::utf8_latvian_ci => "utf8_latvian_ci", + Collation::utf8_lithuanian_ci => "utf8_lithuanian_ci", + Collation::utf8_persian_ci => "utf8_persian_ci", + Collation::utf8_polish_ci => "utf8_polish_ci", + Collation::utf8_roman_ci => "utf8_roman_ci", + Collation::utf8_romanian_ci => "utf8_romanian_ci", + Collation::utf8_sinhala_ci => "utf8_sinhala_ci", + Collation::utf8_slovak_ci => "utf8_slovak_ci", + Collation::utf8_slovenian_ci => "utf8_slovenian_ci", + Collation::utf8_spanish_ci => "utf8_spanish_ci", + Collation::utf8_spanish2_ci => "utf8_spanish2_ci", + Collation::utf8_swedish_ci => "utf8_swedish_ci", + Collation::utf8_tolower_ci => "utf8_tolower_ci", + Collation::utf8_turkish_ci => "utf8_turkish_ci", + Collation::utf8_unicode_520_ci => "utf8_unicode_520_ci", + Collation::utf8_unicode_ci => "utf8_unicode_ci", + Collation::utf8_vietnamese_ci => "utf8_vietnamese_ci", + Collation::utf8mb4_0900_ai_ci => "utf8mb4_0900_ai_ci", + Collation::utf8mb4_bin => "utf8mb4_bin", + Collation::utf8mb4_croatian_ci => "utf8mb4_croatian_ci", + Collation::utf8mb4_czech_ci => "utf8mb4_czech_ci", + Collation::utf8mb4_danish_ci => "utf8mb4_danish_ci", + Collation::utf8mb4_esperanto_ci => "utf8mb4_esperanto_ci", + Collation::utf8mb4_estonian_ci => "utf8mb4_estonian_ci", + Collation::utf8mb4_general_ci => "utf8mb4_general_ci", + Collation::utf8mb4_german2_ci => "utf8mb4_german2_ci", + Collation::utf8mb4_hungarian_ci => "utf8mb4_hungarian_ci", + Collation::utf8mb4_icelandic_ci => "utf8mb4_icelandic_ci", + Collation::utf8mb4_latvian_ci => "utf8mb4_latvian_ci", + Collation::utf8mb4_lithuanian_ci => "utf8mb4_lithuanian_ci", + Collation::utf8mb4_persian_ci => "utf8mb4_persian_ci", + Collation::utf8mb4_polish_ci => "utf8mb4_polish_ci", + Collation::utf8mb4_roman_ci => "utf8mb4_roman_ci", + Collation::utf8mb4_romanian_ci => "utf8mb4_romanian_ci", + Collation::utf8mb4_sinhala_ci => "utf8mb4_sinhala_ci", + Collation::utf8mb4_slovak_ci => "utf8mb4_slovak_ci", + Collation::utf8mb4_slovenian_ci => "utf8mb4_slovenian_ci", + Collation::utf8mb4_spanish_ci => "utf8mb4_spanish_ci", + Collation::utf8mb4_spanish2_ci => "utf8mb4_spanish2_ci", + Collation::utf8mb4_swedish_ci => "utf8mb4_swedish_ci", + Collation::utf8mb4_turkish_ci => "utf8mb4_turkish_ci", + Collation::utf8mb4_unicode_520_ci => "utf8mb4_unicode_520_ci", + Collation::utf8mb4_unicode_ci => "utf8mb4_unicode_ci", + Collation::utf8mb4_vietnamese_ci => "utf8mb4_vietnamese_ci", + } + } +} + +// Handshake packet have only 1 byte for collation_id. +// So we can't use collations with ID > 255. +impl FromStr for Collation { + type Err = Error; + + fn from_str(collation: &str) -> Result { + Ok(match collation { + "big5_chinese_ci" => Collation::big5_chinese_ci, + "swe7_swedish_ci" => Collation::swe7_swedish_ci, + "utf16_unicode_ci" => Collation::utf16_unicode_ci, + "utf16_icelandic_ci" => Collation::utf16_icelandic_ci, + "utf16_latvian_ci" => Collation::utf16_latvian_ci, + "utf16_romanian_ci" => Collation::utf16_romanian_ci, + "utf16_slovenian_ci" => Collation::utf16_slovenian_ci, + "utf16_polish_ci" => Collation::utf16_polish_ci, + "utf16_estonian_ci" => Collation::utf16_estonian_ci, + "utf16_spanish_ci" => Collation::utf16_spanish_ci, + "utf16_swedish_ci" => Collation::utf16_swedish_ci, + "ascii_general_ci" => Collation::ascii_general_ci, + "utf16_turkish_ci" => Collation::utf16_turkish_ci, + "utf16_czech_ci" => Collation::utf16_czech_ci, + "utf16_danish_ci" => Collation::utf16_danish_ci, + "utf16_lithuanian_ci" => Collation::utf16_lithuanian_ci, + "utf16_slovak_ci" => Collation::utf16_slovak_ci, + "utf16_spanish2_ci" => Collation::utf16_spanish2_ci, + "utf16_roman_ci" => Collation::utf16_roman_ci, + "utf16_persian_ci" => Collation::utf16_persian_ci, + "utf16_esperanto_ci" => Collation::utf16_esperanto_ci, + "utf16_hungarian_ci" => Collation::utf16_hungarian_ci, + "ujis_japanese_ci" => Collation::ujis_japanese_ci, + "utf16_sinhala_ci" => Collation::utf16_sinhala_ci, + "utf16_german2_ci" => Collation::utf16_german2_ci, + "utf16_croatian_ci" => Collation::utf16_croatian_ci, + "utf16_unicode_520_ci" => Collation::utf16_unicode_520_ci, + "utf16_vietnamese_ci" => Collation::utf16_vietnamese_ci, + "ucs2_unicode_ci" => Collation::ucs2_unicode_ci, + "ucs2_icelandic_ci" => Collation::ucs2_icelandic_ci, + "sjis_japanese_ci" => Collation::sjis_japanese_ci, + "ucs2_latvian_ci" => Collation::ucs2_latvian_ci, + "ucs2_romanian_ci" => Collation::ucs2_romanian_ci, + "ucs2_slovenian_ci" => Collation::ucs2_slovenian_ci, + "ucs2_polish_ci" => Collation::ucs2_polish_ci, + "ucs2_estonian_ci" => Collation::ucs2_estonian_ci, + "ucs2_spanish_ci" => Collation::ucs2_spanish_ci, + "ucs2_swedish_ci" => Collation::ucs2_swedish_ci, + "ucs2_turkish_ci" => Collation::ucs2_turkish_ci, + "ucs2_czech_ci" => Collation::ucs2_czech_ci, + "ucs2_danish_ci" => Collation::ucs2_danish_ci, + "cp1251_bulgarian_ci" => Collation::cp1251_bulgarian_ci, + "ucs2_lithuanian_ci" => Collation::ucs2_lithuanian_ci, + "ucs2_slovak_ci" => Collation::ucs2_slovak_ci, + "ucs2_spanish2_ci" => Collation::ucs2_spanish2_ci, + "ucs2_roman_ci" => Collation::ucs2_roman_ci, + "ucs2_persian_ci" => Collation::ucs2_persian_ci, + "ucs2_esperanto_ci" => Collation::ucs2_esperanto_ci, + "ucs2_hungarian_ci" => Collation::ucs2_hungarian_ci, + "ucs2_sinhala_ci" => Collation::ucs2_sinhala_ci, + "ucs2_german2_ci" => Collation::ucs2_german2_ci, + "ucs2_croatian_ci" => Collation::ucs2_croatian_ci, + "latin1_danish_ci" => Collation::latin1_danish_ci, + "ucs2_unicode_520_ci" => Collation::ucs2_unicode_520_ci, + "ucs2_vietnamese_ci" => Collation::ucs2_vietnamese_ci, + "ucs2_general_mysql500_ci" => Collation::ucs2_general_mysql500_ci, + "hebrew_general_ci" => Collation::hebrew_general_ci, + "utf32_unicode_ci" => Collation::utf32_unicode_ci, + "utf32_icelandic_ci" => Collation::utf32_icelandic_ci, + "utf32_latvian_ci" => Collation::utf32_latvian_ci, + "utf32_romanian_ci" => Collation::utf32_romanian_ci, + "utf32_slovenian_ci" => Collation::utf32_slovenian_ci, + "utf32_polish_ci" => Collation::utf32_polish_ci, + "utf32_estonian_ci" => Collation::utf32_estonian_ci, + "utf32_spanish_ci" => Collation::utf32_spanish_ci, + "utf32_swedish_ci" => Collation::utf32_swedish_ci, + "utf32_turkish_ci" => Collation::utf32_turkish_ci, + "utf32_czech_ci" => Collation::utf32_czech_ci, + "utf32_danish_ci" => Collation::utf32_danish_ci, + "utf32_lithuanian_ci" => Collation::utf32_lithuanian_ci, + "utf32_slovak_ci" => Collation::utf32_slovak_ci, + "utf32_spanish2_ci" => Collation::utf32_spanish2_ci, + "utf32_roman_ci" => Collation::utf32_roman_ci, + "utf32_persian_ci" => Collation::utf32_persian_ci, + "utf32_esperanto_ci" => Collation::utf32_esperanto_ci, + "utf32_hungarian_ci" => Collation::utf32_hungarian_ci, + "utf32_sinhala_ci" => Collation::utf32_sinhala_ci, + "tis620_thai_ci" => Collation::tis620_thai_ci, + "utf32_german2_ci" => Collation::utf32_german2_ci, + "utf32_croatian_ci" => Collation::utf32_croatian_ci, + "utf32_unicode_520_ci" => Collation::utf32_unicode_520_ci, + "utf32_vietnamese_ci" => Collation::utf32_vietnamese_ci, + "euckr_korean_ci" => Collation::euckr_korean_ci, + "utf8_unicode_ci" => Collation::utf8_unicode_ci, + "utf8_icelandic_ci" => Collation::utf8_icelandic_ci, + "utf8_latvian_ci" => Collation::utf8_latvian_ci, + "utf8_romanian_ci" => Collation::utf8_romanian_ci, + "utf8_slovenian_ci" => Collation::utf8_slovenian_ci, + "utf8_polish_ci" => Collation::utf8_polish_ci, + "utf8_estonian_ci" => Collation::utf8_estonian_ci, + "utf8_spanish_ci" => Collation::utf8_spanish_ci, + "latin2_czech_cs" => Collation::latin2_czech_cs, + "latin7_estonian_cs" => Collation::latin7_estonian_cs, + "utf8_swedish_ci" => Collation::utf8_swedish_ci, + "utf8_turkish_ci" => Collation::utf8_turkish_ci, + "utf8_czech_ci" => Collation::utf8_czech_ci, + "utf8_danish_ci" => Collation::utf8_danish_ci, + "utf8_lithuanian_ci" => Collation::utf8_lithuanian_ci, + "utf8_slovak_ci" => Collation::utf8_slovak_ci, + "utf8_spanish2_ci" => Collation::utf8_spanish2_ci, + "utf8_roman_ci" => Collation::utf8_roman_ci, + "utf8_persian_ci" => Collation::utf8_persian_ci, + "utf8_esperanto_ci" => Collation::utf8_esperanto_ci, + "latin2_hungarian_ci" => Collation::latin2_hungarian_ci, + "utf8_hungarian_ci" => Collation::utf8_hungarian_ci, + "utf8_sinhala_ci" => Collation::utf8_sinhala_ci, + "utf8_german2_ci" => Collation::utf8_german2_ci, + "utf8_croatian_ci" => Collation::utf8_croatian_ci, + "utf8_unicode_520_ci" => Collation::utf8_unicode_520_ci, + "utf8_vietnamese_ci" => Collation::utf8_vietnamese_ci, + "koi8u_general_ci" => Collation::koi8u_general_ci, + "utf8_general_mysql500_ci" => Collation::utf8_general_mysql500_ci, + "utf8mb4_unicode_ci" => Collation::utf8mb4_unicode_ci, + "utf8mb4_icelandic_ci" => Collation::utf8mb4_icelandic_ci, + "utf8mb4_latvian_ci" => Collation::utf8mb4_latvian_ci, + "utf8mb4_romanian_ci" => Collation::utf8mb4_romanian_ci, + "utf8mb4_slovenian_ci" => Collation::utf8mb4_slovenian_ci, + "utf8mb4_polish_ci" => Collation::utf8mb4_polish_ci, + "cp1251_ukrainian_ci" => Collation::cp1251_ukrainian_ci, + "utf8mb4_estonian_ci" => Collation::utf8mb4_estonian_ci, + "utf8mb4_spanish_ci" => Collation::utf8mb4_spanish_ci, + "utf8mb4_swedish_ci" => Collation::utf8mb4_swedish_ci, + "utf8mb4_turkish_ci" => Collation::utf8mb4_turkish_ci, + "utf8mb4_czech_ci" => Collation::utf8mb4_czech_ci, + "utf8mb4_danish_ci" => Collation::utf8mb4_danish_ci, + "utf8mb4_lithuanian_ci" => Collation::utf8mb4_lithuanian_ci, + "utf8mb4_slovak_ci" => Collation::utf8mb4_slovak_ci, + "utf8mb4_spanish2_ci" => Collation::utf8mb4_spanish2_ci, + "utf8mb4_roman_ci" => Collation::utf8mb4_roman_ci, + "gb2312_chinese_ci" => Collation::gb2312_chinese_ci, + "utf8mb4_persian_ci" => Collation::utf8mb4_persian_ci, + "utf8mb4_esperanto_ci" => Collation::utf8mb4_esperanto_ci, + "utf8mb4_hungarian_ci" => Collation::utf8mb4_hungarian_ci, + "utf8mb4_sinhala_ci" => Collation::utf8mb4_sinhala_ci, + "utf8mb4_german2_ci" => Collation::utf8mb4_german2_ci, + "utf8mb4_croatian_ci" => Collation::utf8mb4_croatian_ci, + "utf8mb4_unicode_520_ci" => Collation::utf8mb4_unicode_520_ci, + "utf8mb4_vietnamese_ci" => Collation::utf8mb4_vietnamese_ci, + "gb18030_chinese_ci" => Collation::gb18030_chinese_ci, + "gb18030_bin" => Collation::gb18030_bin, + "greek_general_ci" => Collation::greek_general_ci, + "gb18030_unicode_520_ci" => Collation::gb18030_unicode_520_ci, + "utf8mb4_0900_ai_ci" => Collation::utf8mb4_0900_ai_ci, + "cp1250_general_ci" => Collation::cp1250_general_ci, + "latin2_croatian_ci" => Collation::latin2_croatian_ci, + "gbk_chinese_ci" => Collation::gbk_chinese_ci, + "cp1257_lithuanian_ci" => Collation::cp1257_lithuanian_ci, + "dec8_swedish_ci" => Collation::dec8_swedish_ci, + "latin5_turkish_ci" => Collation::latin5_turkish_ci, + "latin1_german2_ci" => Collation::latin1_german2_ci, + "armscii8_general_ci" => Collation::armscii8_general_ci, + "utf8_general_ci" => Collation::utf8_general_ci, + "cp1250_czech_cs" => Collation::cp1250_czech_cs, + "ucs2_general_ci" => Collation::ucs2_general_ci, + "cp866_general_ci" => Collation::cp866_general_ci, + "keybcs2_general_ci" => Collation::keybcs2_general_ci, + "macce_general_ci" => Collation::macce_general_ci, + "macroman_general_ci" => Collation::macroman_general_ci, + "cp850_general_ci" => Collation::cp850_general_ci, + "cp852_general_ci" => Collation::cp852_general_ci, + "latin7_general_ci" => Collation::latin7_general_ci, + "latin7_general_cs" => Collation::latin7_general_cs, + "macce_bin" => Collation::macce_bin, + "cp1250_croatian_ci" => Collation::cp1250_croatian_ci, + "utf8mb4_general_ci" => Collation::utf8mb4_general_ci, + "utf8mb4_bin" => Collation::utf8mb4_bin, + "latin1_bin" => Collation::latin1_bin, + "latin1_general_ci" => Collation::latin1_general_ci, + "latin1_general_cs" => Collation::latin1_general_cs, + "latin1_german1_ci" => Collation::latin1_german1_ci, + "cp1251_bin" => Collation::cp1251_bin, + "cp1251_general_ci" => Collation::cp1251_general_ci, + "cp1251_general_cs" => Collation::cp1251_general_cs, + "macroman_bin" => Collation::macroman_bin, + "utf16_general_ci" => Collation::utf16_general_ci, + "utf16_bin" => Collation::utf16_bin, + "utf16le_general_ci" => Collation::utf16le_general_ci, + "cp1256_general_ci" => Collation::cp1256_general_ci, + "cp1257_bin" => Collation::cp1257_bin, + "cp1257_general_ci" => Collation::cp1257_general_ci, + "hp8_english_ci" => Collation::hp8_english_ci, + "utf32_general_ci" => Collation::utf32_general_ci, + "utf32_bin" => Collation::utf32_bin, + "utf16le_bin" => Collation::utf16le_bin, + "binary" => Collation::binary, + "armscii8_bin" => Collation::armscii8_bin, + "ascii_bin" => Collation::ascii_bin, + "cp1250_bin" => Collation::cp1250_bin, + "cp1256_bin" => Collation::cp1256_bin, + "cp866_bin" => Collation::cp866_bin, + "dec8_bin" => Collation::dec8_bin, + "koi8r_general_ci" => Collation::koi8r_general_ci, + "greek_bin" => Collation::greek_bin, + "hebrew_bin" => Collation::hebrew_bin, + "hp8_bin" => Collation::hp8_bin, + "keybcs2_bin" => Collation::keybcs2_bin, + "koi8r_bin" => Collation::koi8r_bin, + "koi8u_bin" => Collation::koi8u_bin, + "utf8_tolower_ci" => Collation::utf8_tolower_ci, + "latin2_bin" => Collation::latin2_bin, + "latin5_bin" => Collation::latin5_bin, + "latin7_bin" => Collation::latin7_bin, + "latin1_swedish_ci" => Collation::latin1_swedish_ci, + "cp850_bin" => Collation::cp850_bin, + "cp852_bin" => Collation::cp852_bin, + "swe7_bin" => Collation::swe7_bin, + "utf8_bin" => Collation::utf8_bin, + "big5_bin" => Collation::big5_bin, + "euckr_bin" => Collation::euckr_bin, + "gb2312_bin" => Collation::gb2312_bin, + "gbk_bin" => Collation::gbk_bin, + "sjis_bin" => Collation::sjis_bin, + "tis620_bin" => Collation::tis620_bin, + "latin2_general_ci" => Collation::latin2_general_ci, + "ucs2_bin" => Collation::ucs2_bin, + "ujis_bin" => Collation::ujis_bin, + "geostd8_general_ci" => Collation::geostd8_general_ci, + "geostd8_bin" => Collation::geostd8_bin, + "latin1_spanish_ci" => Collation::latin1_spanish_ci, + "cp932_japanese_ci" => Collation::cp932_japanese_ci, + "cp932_bin" => Collation::cp932_bin, + "eucjpms_japanese_ci" => Collation::eucjpms_japanese_ci, + "eucjpms_bin" => Collation::eucjpms_bin, + "cp1250_polish_ci" => Collation::cp1250_polish_ci, + + _ => { + return Err(Error::Configuration( + format!("unsupported MySQL collation: {}", collation).into(), + )); + } + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/io/buf.rs b/warpgate-database-protocols/src/mysql/io/buf.rs new file mode 100644 index 0000000..9ccb62e --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/buf.rs @@ -0,0 +1,40 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::BufExt; + +pub trait MySqlBufExt: Buf { + // Read a length-encoded integer. + // NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL. + // NOTE: 0xff is only returned during a result set to indicate ERR. + // + fn get_uint_lenenc(&mut self) -> u64; + + // Read a length-encoded string. + fn get_str_lenenc(&mut self) -> Result; + + // Read a length-encoded byte sequence. + fn get_bytes_lenenc(&mut self) -> Bytes; +} + +impl MySqlBufExt for Bytes { + fn get_uint_lenenc(&mut self) -> u64 { + match self.get_u8() { + 0xfc => u64::from(self.get_u16_le()), + 0xfd => self.get_uint_le(3), + 0xfe => self.get_u64_le(), + + v => u64::from(v), + } + } + + fn get_str_lenenc(&mut self) -> Result { + let size = self.get_uint_lenenc(); + self.get_str(size as usize) + } + + fn get_bytes_lenenc(&mut self) -> Bytes { + let size = self.get_uint_lenenc(); + self.split_to(size as usize) + } +} diff --git a/warpgate-database-protocols/src/mysql/io/buf_mut.rs b/warpgate-database-protocols/src/mysql/io/buf_mut.rs new file mode 100644 index 0000000..ba2ba3e --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/buf_mut.rs @@ -0,0 +1,126 @@ +use bytes::BufMut; + +pub trait MySqlBufMutExt: BufMut { + fn put_uint_lenenc(&mut self, v: u64); + + fn put_str_lenenc(&mut self, v: &str); + + fn put_bytes_lenenc(&mut self, v: &[u8]); +} + +impl MySqlBufMutExt for Vec { + fn put_uint_lenenc(&mut self, v: u64) { + // https://dev.mysql.com/doc/internals/en/integer.html + // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers + + if v < 251 { + self.push(v as u8); + } else if v < 0x1_00_00 { + self.push(0xfc); + self.extend(&(v as u16).to_le_bytes()); + } else if v < 0x1_00_00_00 { + self.push(0xfd); + self.extend(&(v as u32).to_le_bytes()[..3]); + } else { + self.push(0xfe); + self.extend(&v.to_le_bytes()); + } + } + + fn put_str_lenenc(&mut self, v: &str) { + self.put_bytes_lenenc(v.as_bytes()); + } + + fn put_bytes_lenenc(&mut self, v: &[u8]) { + self.put_uint_lenenc(v.len() as u64); + self.extend(v); + } +} + +#[test] +fn test_encodes_int_lenenc_u8() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFA_u64); + + assert_eq!(&buf[..], b"\xFA"); +} + +#[test] +fn test_encodes_int_lenenc_u16() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u16::MAX as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u24() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF_FF_FF_u64); + + assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u64() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u64::MAX); + + assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_fb() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFB_u64); + + assert_eq!(&buf[..], b"\xFC\xFB\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fc() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFC_u64); + + assert_eq!(&buf[..], b"\xFC\xFC\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fd() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFD_u64); + + assert_eq!(&buf[..], b"\xFC\xFD\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fe() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFE_u64); + + assert_eq!(&buf[..], b"\xFC\xFE\x00"); +} + +#[test] +fn test_encodes_int_lenenc_ff() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF_u64); + + assert_eq!(&buf[..], b"\xFC\xFF\x00"); +} + +#[test] +fn test_encodes_string_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_lenenc("random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} + +#[test] +fn test_encodes_byte_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_bytes_lenenc(b"random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} diff --git a/warpgate-database-protocols/src/mysql/io/mod.rs b/warpgate-database-protocols/src/mysql/io/mod.rs new file mode 100644 index 0000000..fafc914 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/io/mod.rs @@ -0,0 +1,5 @@ +mod buf; +mod buf_mut; + +pub use buf::MySqlBufExt; +pub use buf_mut::MySqlBufMutExt; diff --git a/warpgate-database-protocols/src/mysql/mod.rs b/warpgate-database-protocols/src/mysql/mod.rs new file mode 100644 index 0000000..146321f --- /dev/null +++ b/warpgate-database-protocols/src/mysql/mod.rs @@ -0,0 +1,5 @@ +//! **MySQL** database driver. + +pub mod collation; +pub mod io; +pub mod protocol; diff --git a/warpgate-database-protocols/src/mysql/protocol/auth.rs b/warpgate-database-protocols/src/mysql/protocol/auth.rs new file mode 100644 index 0000000..aa27bf7 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/auth.rs @@ -0,0 +1,38 @@ +use std::str::FromStr; + +use crate::err_protocol; +use crate::error::Error; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum AuthPlugin { + MySqlClearPassword, + MySqlNativePassword, + CachingSha2Password, + Sha256Password, +} + +impl AuthPlugin { + pub(crate) fn name(self) -> &'static str { + match self { + AuthPlugin::MySqlClearPassword => "mysql_clear_password", + AuthPlugin::MySqlNativePassword => "mysql_native_password", + AuthPlugin::CachingSha2Password => "caching_sha2_password", + AuthPlugin::Sha256Password => "sha256_password", + } + } +} + +impl FromStr for AuthPlugin { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "mysql_clear_password" => Ok(AuthPlugin::MySqlClearPassword), + "mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword), + "caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password), + "sha256_password" => Ok(AuthPlugin::Sha256Password), + + _ => Err(err_protocol!("unknown authentication plugin: {}", s)), + } + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/capabilities.rs b/warpgate-database-protocols/src/mysql/protocol/capabilities.rs new file mode 100644 index 0000000..6d7b582 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/capabilities.rs @@ -0,0 +1,86 @@ +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html +// https://mariadb.com/kb/en/library/connection/#capabilities +bitflags::bitflags! { + pub struct Capabilities: u64 { + // [MariaDB] MySQL compatibility + const MYSQL = 1; + + // [*] Send found rows instead of affected rows in EOF_Packet. + const FOUND_ROWS = 2; + + // Get all column flags. + const LONG_FLAG = 4; + + // [*] Database (schema) name can be specified on connect in Handshake Response Packet. + const CONNECT_WITH_DB = 8; + + // Don't allow database.table.column + const NO_SCHEMA = 16; + + // [*] Compression protocol supported + const COMPRESS = 32; + + // Special handling of ODBC behavior. + const ODBC = 64; + + // Can use LOAD DATA LOCAL + const LOCAL_FILES = 128; + + // [*] Ignore spaces before '(' + const IGNORE_SPACE = 256; + + // [*] New 4.1+ protocol + const PROTOCOL_41 = 512; + + // This is an interactive client + const INTERACTIVE = 1024; + + // Use SSL encryption for this session + const SSL = 2048; + + // Client knows about transactions + const TRANSACTIONS = 8192; + + // 4.1+ authentication + const SECURE_CONNECTION = (1 << 15); + + // Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE + const MULTI_STATEMENTS = (1 << 16); + + // Enable/disable multi-results for COM_QUERY + const MULTI_RESULTS = (1 << 17); + + // Enable/disable multi-results for COM_STMT_PREPARE + const PS_MULTI_RESULTS = (1 << 18); + + // Client supports plugin authentication + const PLUGIN_AUTH = (1 << 19); + + // Client supports connection attributes + const CONNECT_ATTRS = (1 << 20); + + // Enable authentication response packet to be larger than 255 bytes. + const PLUGIN_AUTH_LENENC_DATA = (1 << 21); + + // Don't close the connection for a user account with expired password. + const CAN_HANDLE_EXPIRED_PASSWORDS = (1 << 22); + + // Capable of handling server state change information. + const SESSION_TRACK = (1 << 23); + + // Client no longer needs EOF_Packet and will use OK_Packet instead. + const DEPRECATE_EOF = (1 << 24); + + // Support ZSTD protocol compression + const ZSTD_COMPRESSION_ALGORITHM = (1 << 26); + + // Verify server certificate + const SSL_VERIFY_SERVER_CERT = (1 << 30); + + // The client can handle optional metadata information in the resultset + const OPTIONAL_RESULTSET_METADATA = (1 << 25); + + // Don't reset the options after an unsuccessful connect + const REMEMBER_OPTIONS = (1 << 31); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs b/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs new file mode 100644 index 0000000..bdb330f --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/auth_switch.rs @@ -0,0 +1,58 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + +#[derive(Debug)] +pub struct AuthSwitchRequest { + pub plugin: AuthPlugin, + pub data: Bytes, +} + +impl Decode<'_> for AuthSwitchRequest { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (AUTH_SWITCH) but found 0x{:x}", + header + )); + } + + let plugin = buf.get_str_nul()?.parse()?; + + // See: https://github.com/mysql/mysql-server/blob/ea7d2e2d16ac03afdd9cb72a972a95981107bf51/sql/auth/sha2_password.cc#L942 + if buf.len() != 21 { + return Err(err_protocol!( + "expected 21 bytes but found {} bytes", + buf.len() + )); + } + let data = buf.get_bytes(20); + buf.advance(1); // NUL-terminator + + Ok(Self { plugin, data }) + } +} + +impl Encode<'_, ()> for AuthSwitchRequest { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0xfe); + buf.put_str_nul(self.plugin.name()); + buf.extend(&self.data); + } +} + +#[derive(Debug)] +pub struct AuthSwitchResponse(pub Vec); + +impl Encode<'_, Capabilities> for AuthSwitchResponse { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.extend_from_slice(&self.0); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs b/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs new file mode 100644 index 0000000..32ccdc3 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/handshake.rs @@ -0,0 +1,233 @@ +use bytes::buf::Chain; +use bytes::{Buf, BufMut, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +// https://mariadb.com/kb/en/connection/#initial-handshake-packet + +#[derive(Debug)] +pub struct Handshake { + #[allow(unused)] + pub protocol_version: u8, + pub server_version: String, + #[allow(unused)] + pub connection_id: u32, + pub server_capabilities: Capabilities, + #[allow(unused)] + pub server_default_collation: u8, + #[allow(unused)] + pub status: Status, + pub auth_plugin: Option, + pub auth_plugin_data: Chain, +} + +impl Decode<'_> for Handshake { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let protocol_version = buf.get_u8(); // int<1> + let server_version = buf.get_str_nul()?; // string + let connection_id = buf.get_u32_le(); // int<4> + let auth_plugin_data_1 = buf.get_bytes(8); // string<8> + + buf.advance(1); // reserved: string<1> + + let capabilities_1 = buf.get_u16_le(); // int<2> + let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); + + let collation = buf.get_u8(); // int<1> + let status = Status::from_bits_truncate(buf.get_u16_le()); + + let capabilities_2 = buf.get_u16_le(); // int<2> + capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); + + let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + buf.get_u8() + } else { + buf.advance(1); // int<1> + 0 + }; + + buf.advance(6); // reserved: string<6> + + if capabilities.contains(Capabilities::MYSQL) { + buf.advance(4); // reserved: string<4> + } else { + let capabilities_3 = buf.get_u32_le(); // int<4> + capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); + } + + let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { + let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize; + let v = buf.get_bytes(len); + buf.advance(1); // NUL-terminator + + v + } else { + Bytes::new() + }; + + let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + Some(buf.get_str_nul()?.parse()?) + } else { + None + }; + + Ok(Self { + protocol_version, + server_version, + connection_id, + server_default_collation: collation, + status, + server_capabilities: capabilities, + auth_plugin, + auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2), + }) + } +} + +impl Encode<'_, ()> for Handshake { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(self.protocol_version); + buf.put_str_nul(&self.server_version); + buf.put_u32_le(self.connection_id); + buf.put_slice(self.auth_plugin_data.first_ref()); + buf.put_u8(0x00); + buf.put_u16_le((self.server_capabilities.bits() & 0x0000_FFFF) as u16); + buf.put_u8(self.server_default_collation); + buf.put_u16_le(self.status.bits()); + buf.put_u16_le(((self.server_capabilities.bits() & 0xFFFF_0000) >> 16) as u16); + + if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) { + buf.put_u8((self.auth_plugin_data.last_ref().len() + 8 + 1) as u8); + } else { + buf.put_u8(0); + } + + buf.put_slice(&[0_u8; 10][..]); + + if self + .server_capabilities + .contains(Capabilities::SECURE_CONNECTION) + { + buf.put_slice(self.auth_plugin_data.last_ref()); + buf.put_u8(0); + } + + if self.server_capabilities.contains(Capabilities::PLUGIN_AUTH) { + if let Some(auth_plugin) = self.auth_plugin { + buf.put_str_nul(auth_plugin.name()); + } + } + } +} + +#[test] +fn test_decode_handshake_mysql_8_0_18() { + const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00"; + + let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); + + assert_eq!(p.protocol_version, 10); + + p.server_capabilities.toggle( + Capabilities::MYSQL + | Capabilities::FOUND_ROWS + | Capabilities::LONG_FLAG + | Capabilities::CONNECT_WITH_DB + | Capabilities::NO_SCHEMA + | Capabilities::COMPRESS + | Capabilities::ODBC + | Capabilities::LOCAL_FILES + | Capabilities::IGNORE_SPACE + | Capabilities::PROTOCOL_41 + | Capabilities::INTERACTIVE + | Capabilities::SSL + | Capabilities::TRANSACTIONS + | Capabilities::SECURE_CONNECTION + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PS_MULTI_RESULTS + | Capabilities::PLUGIN_AUTH + | Capabilities::CONNECT_ATTRS + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS + | Capabilities::SESSION_TRACK + | Capabilities::DEPRECATE_EOF + | Capabilities::ZSTD_COMPRESSION_ALGORITHM + | Capabilities::SSL_VERIFY_SERVER_CERT + | Capabilities::OPTIONAL_RESULTSET_METADATA + | Capabilities::REMEMBER_OPTIONS, + ); + + assert!(p.server_capabilities.is_empty()); + + assert_eq!(p.server_default_collation, 255); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + + assert!(matches!( + p.auth_plugin, + Some(AuthPlugin::CachingSha2Password) + )); + + assert_eq!( + &*p.auth_plugin_data.into_iter().collect::>(), + &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,] + ); +} + +#[test] +fn test_decode_handshake_mariadb_10_4_7() { + const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\">(), + &[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,] + ); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs b/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs new file mode 100644 index 0000000..29862ca --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/handshake_response.rs @@ -0,0 +1,147 @@ +use std::str::FromStr; + +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::connect::ssl_request::SslRequest; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +// https://mariadb.com/kb/en/connection/#client-handshake-response + +#[derive(Debug)] +pub struct HandshakeResponse { + pub database: Option, + + /// Max size of a command packet that the client wants to send to the server + pub max_packet_size: u32, + + /// Default collation for the connection + pub collation: u8, + + /// Name of the SQL account which client wants to log in + pub username: String, + + /// Authentication method used by the client + pub auth_plugin: Option, + + /// Opaque authentication response + pub auth_response: Option, +} + +impl Encode<'_, Capabilities> for HandshakeResponse { + fn encode_with(&self, buf: &mut Vec, mut capabilities: Capabilities) { + if self.auth_plugin.is_none() { + // ensure PLUGIN_AUTH is set *only* if we have a defined plugin + capabilities.remove(Capabilities::PLUGIN_AUTH); + } + + // NOTE: Half of this packet is identical to the SSL Request packet + SslRequest { + max_packet_size: self.max_packet_size, + collation: self.collation, + } + .encode_with(buf, capabilities); + + buf.put_str_nul(&self.username); + + if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + if let Some(response) = &self.auth_response { + buf.put_bytes_lenenc(response); + } else { + buf.put_bytes_lenenc(&[]); + } + } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { + if let Some(response) = &self.auth_response { + buf.push(response.len() as u8); + buf.extend(response); + } else { + buf.push(0); + } + } else { + buf.push(0); + } + + if capabilities.contains(Capabilities::CONNECT_WITH_DB) { + if let Some(database) = &self.database { + buf.put_str_nul(database); + } else { + buf.push(0); + } + } + + if capabilities.contains(Capabilities::PLUGIN_AUTH) { + if let Some(plugin) = &self.auth_plugin { + buf.put_str_nul(plugin.name()); + } else { + buf.push(0); + } + } + } +} + +impl Decode<'_, &mut Capabilities> for HandshakeResponse { + fn decode_with(mut buf: Bytes, server_capabilities: &mut Capabilities) -> Result { + let mut capabilities = buf.get_u32_le() as u64; + let max_packet_size = buf.get_u32_le(); + let collation = buf.get_u8(); + buf.advance(19); + + let partial_cap = Capabilities::from_bits_truncate(capabilities); + + if partial_cap.contains(Capabilities::MYSQL) { + // reserved: string<4> + buf.advance(4); + } else { + capabilities += (buf.get_u32_le() as u64) << 32; + } + + let partial_cap = Capabilities::from_bits_truncate(capabilities); + if partial_cap.contains(Capabilities::SSL) && buf.is_empty() { + return Ok(HandshakeResponse { + collation, + max_packet_size, + username: "".to_string(), + auth_response: None, + auth_plugin: None, + database: None, + }); + } + let username = buf.get_str_nul()?; + + let auth_response = if partial_cap.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + Some(buf.get_bytes_lenenc()) + } else if partial_cap.contains(Capabilities::SECURE_CONNECTION) { + let len = buf.get_u8(); + Some(buf.get_bytes(len as usize)) + } else { + Some(buf.get_bytes_nul()?) + }; + + let database = if partial_cap.contains(Capabilities::CONNECT_WITH_DB) { + Some(buf.get_str_nul()?) + } else { + None + }; + + let auth_plugin: Option = if partial_cap.contains(Capabilities::PLUGIN_AUTH) { + Some(AuthPlugin::from_str(&buf.get_str_nul()?)?) + } else { + None + }; + + *server_capabilities &= Capabilities::from_bits_truncate(capabilities); + + Ok(HandshakeResponse { + collation, + max_packet_size, + username, + auth_response, + auth_plugin, + database, + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs b/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs new file mode 100644 index 0000000..71f9999 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/mod.rs @@ -0,0 +1,13 @@ +//! Connection Phase +//! +//! + +mod auth_switch; +mod handshake; +mod handshake_response; +mod ssl_request; + +pub use auth_switch::{AuthSwitchRequest, AuthSwitchResponse}; +pub use handshake::Handshake; +pub use handshake_response::HandshakeResponse; +pub use ssl_request::SslRequest; diff --git a/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs b/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs new file mode 100644 index 0000000..5f0c2d8 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/connect/ssl_request.rs @@ -0,0 +1,30 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + +#[derive(Debug)] +pub struct SslRequest { + pub max_packet_size: u32, + pub collation: u8, +} + +impl Encode<'_, Capabilities> for SslRequest { + fn encode_with(&self, buf: &mut Vec, capabilities: Capabilities) { + buf.extend(&(capabilities.bits() as u32).to_le_bytes()); + buf.extend(&self.max_packet_size.to_le_bytes()); + buf.push(self.collation); + + // reserved: string<19> + buf.extend(&[0_u8; 19]); + + if capabilities.contains(Capabilities::MYSQL) { + // reserved: string<4> + buf.extend(&[0_u8; 4]); + } else { + // extended client capabilities (MariaDB-specified): int<4> + buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes()); + } + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/mod.rs b/warpgate-database-protocols/src/mysql/protocol/mod.rs new file mode 100644 index 0000000..22b5a03 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/mod.rs @@ -0,0 +1,11 @@ +pub mod auth; +pub mod capabilities; +pub mod connect; +pub mod packet; +pub mod response; +pub mod row; +pub mod text; + +pub use capabilities::Capabilities; +pub use packet::Packet; +pub use row::Row; diff --git a/warpgate-database-protocols/src/mysql/protocol/packet.rs b/warpgate-database-protocols/src/mysql/protocol/packet.rs new file mode 100644 index 0000000..add23f0 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/packet.rs @@ -0,0 +1,89 @@ +use std::ops::{Deref, DerefMut}; + +use bytes::Bytes; + +use crate::error::Error; +use crate::io::{Decode, Encode}; +use crate::mysql::protocol::response::{EofPacket, OkPacket}; +use crate::mysql::protocol::Capabilities; + +#[derive(Debug)] +pub struct Packet(pub T); + +impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet +where + T: Encode<'en, Capabilities>, +{ + fn encode_with( + &self, + buf: &mut Vec, + (capabilities, sequence_id): (Capabilities, &'stream mut u8), + ) { + // reserve space to write the prefixed length + let offset = buf.len(); + buf.extend(&[0_u8; 4]); + + // encode the payload + self.0.encode_with(buf, capabilities); + + // determine the length of the encoded payload + // and write to our reserved space + let len = buf.len() - offset - 4; + let header = &mut buf[offset..]; + + // FIXME: Support larger packets + assert!(len < 0xFF_FF_FF); + + header[..4].copy_from_slice(&(len as u32).to_le_bytes()); + header[3] = *sequence_id; + + *sequence_id = sequence_id.wrapping_add(1); + } +} + +impl Packet { + pub(crate) fn decode<'de, T>(self) -> Result + where + T: Decode<'de, ()>, + { + self.decode_with(()) + } + + pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.0, context) + } + + pub(crate) fn ok(self) -> Result { + self.decode() + } + + pub(crate) fn eof(self, capabilities: Capabilities) -> Result { + if capabilities.contains(Capabilities::DEPRECATE_EOF) { + let ok = self.ok()?; + + Ok(EofPacket { + warnings: ok.warnings, + status: ok.status, + }) + } else { + self.decode_with(capabilities) + } + } +} + +impl Deref for Packet { + type Target = Bytes; + + fn deref(&self) -> &Bytes { + &self.0 + } +} + +impl DerefMut for Packet { + fn deref_mut(&mut self) -> &mut Bytes { + &mut self.0 + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/eof.rs b/warpgate-database-protocols/src/mysql/protocol/response/eof.rs new file mode 100644 index 0000000..25568b5 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/eof.rs @@ -0,0 +1,36 @@ +use bytes::{Buf, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +/// Marks the end of a result set, returning status and warnings. +/// +/// # Note +/// +/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL +/// prior MySQL versions. +#[derive(Debug)] +pub struct EofPacket { + pub warnings: u16, + pub status: Status, +} + +impl Decode<'_, Capabilities> for EofPacket { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (EOF_Packet) but found 0x{:x}", + header + )); + } + + let warnings = buf.get_u16_le(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + + Ok(Self { status, warnings }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/err.rs b/warpgate-database-protocols/src/mysql/protocol/response/err.rs new file mode 100644 index 0000000..1071de5 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/err.rs @@ -0,0 +1,81 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{BufExt, Decode, Encode}; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html +// https://mariadb.com/kb/en/err_packet/ + +/// Indicates that an error occurred. +#[derive(Debug)] +pub struct ErrPacket { + pub error_code: u16, + pub sql_state: Option, + pub error_message: String, +} + +impl Decode<'_, Capabilities> for ErrPacket { + fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xff { + return Err(err_protocol!( + "expected 0xff (ERR_Packet) but found 0x{:x}", + header + )); + } + + let error_code = buf.get_u16_le(); + let mut sql_state = None; + + if capabilities.contains(Capabilities::PROTOCOL_41) { + // If the next byte is '#' then we have a SQL STATE + if buf.get(0) == Some(&0x23) { + buf.advance(1); + sql_state = Some(buf.get_str(5)?); + } + } + + let error_message = buf.get_str(buf.len())?; + + Ok(Self { + error_code, + sql_state, + error_message, + }) + } +} + +impl Encode<'_, ()> for ErrPacket { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0xff); + buf.put_u16_le(self.error_code); + buf.extend_from_slice(self.error_message.as_bytes()) + //TODO: sql_state + } +} + +#[test] +fn test_decode_err_packet_out_of_order() { + const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; + + let p = + ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(&p.error_message, "Got packets out of order"); + assert_eq!(p.error_code, 1156); + assert_eq!(p.sql_state, None); +} + +#[test] +fn test_decode_err_packet_unknown_database() { + const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; + + let p = + ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(p.error_code, 1049); + assert_eq!(p.sql_state.as_deref(), Some("42000")); + assert_eq!(&p.error_message, "Unknown database \'unknown\'"); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/mod.rs b/warpgate-database-protocols/src/mysql/protocol/response/mod.rs new file mode 100644 index 0000000..79767dc --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/mod.rs @@ -0,0 +1,14 @@ +//! Generic Response Packets +//! +//! +//! + +mod eof; +mod err; +mod ok; +mod status; + +pub use eof::EofPacket; +pub use err::ErrPacket; +pub use ok::OkPacket; +pub use status::Status; diff --git a/warpgate-database-protocols/src/mysql/protocol/response/ok.rs b/warpgate-database-protocols/src/mysql/protocol/response/ok.rs new file mode 100644 index 0000000..cfd8089 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/ok.rs @@ -0,0 +1,63 @@ +use bytes::{Buf, BufMut, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::{Decode, Encode}; +use crate::mysql::io::{MySqlBufExt, MySqlBufMutExt}; +use crate::mysql::protocol::response::Status; + +/// Indicates successful completion of a previous command sent by the client. +#[derive(Debug)] +pub struct OkPacket { + pub affected_rows: u64, + pub last_insert_id: u64, + pub status: Status, + pub warnings: u16, +} + +impl Decode<'_> for OkPacket { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0 && header != 0xfe { + return Err(err_protocol!( + "expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}", + header + )); + } + + let affected_rows = buf.get_uint_lenenc(); + let last_insert_id = buf.get_uint_lenenc(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + let warnings = buf.get_u16_le(); + + Ok(Self { + affected_rows, + last_insert_id, + status, + warnings, + }) + } +} + +impl Encode<'_, ()> for OkPacket { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.put_u8(0); + buf.put_uint_lenenc(self.affected_rows); + buf.put_uint_lenenc(self.last_insert_id); + buf.put_u16_le(self.status.bits()); + buf.put_u16_le(self.warnings); + } +} + +#[test] +fn test_decode_ok_packet() { + const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; + + let p = OkPacket::decode(DATA.into()).unwrap(); + + assert_eq!(p.affected_rows, 0); + assert_eq!(p.last_insert_id, 0); + assert_eq!(p.warnings, 0); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED)); +} diff --git a/warpgate-database-protocols/src/mysql/protocol/response/status.rs b/warpgate-database-protocols/src/mysql/protocol/response/status.rs new file mode 100644 index 0000000..0338c0d --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/response/status.rs @@ -0,0 +1,49 @@ +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad +// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status +bitflags::bitflags! { + pub struct Status: u16 { + // Is raised when a multi-statement transaction has been started, either explicitly, + // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first + // transactional statement, when autocommit=off. + const SERVER_STATUS_IN_TRANS = 1; + + // Autocommit mode is set + const SERVER_STATUS_AUTOCOMMIT = 2; + + // Multi query - next query exists. + const SERVER_MORE_RESULTS_EXISTS = 8; + + const SERVER_QUERY_NO_GOOD_INDEX_USED = 16; + const SERVER_QUERY_NO_INDEX_USED = 32; + + // When using COM_STMT_FETCH, indicate that current cursor still has result + const SERVER_STATUS_CURSOR_EXISTS = 64; + + // When using COM_STMT_FETCH, indicate that current cursor has finished to send results + const SERVER_STATUS_LAST_ROW_SENT = 128; + + // Database has been dropped + const SERVER_STATUS_DB_DROPPED = (1 << 8); + + // Current escape mode is "no backslash escape" + const SERVER_STATUS_NO_BACKSLASH_ESCAPES = (1 << 9); + + // A DDL change did have an impact on an existing PREPARE (an automatic + // re-prepare has been executed) + const SERVER_STATUS_METADATA_CHANGED = (1 << 10); + + // Last statement took more than the time value specified + // in server variable long_query_time. + const SERVER_QUERY_WAS_SLOW = (1 << 11); + + // This result-set contain stored procedure output parameter. + const SERVER_PS_OUT_PARAMS = (1 << 12); + + // Current transaction is a read-only transaction. + const SERVER_STATUS_IN_TRANS_READONLY = (1 << 13); + + // This status flag, when on, implies that one of the state information has changed + // on the server because of the execution of the last statement. + const SERVER_SESSION_STATE_CHANGED = (1 << 14); + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/row.rs b/warpgate-database-protocols/src/mysql/protocol/row.rs new file mode 100644 index 0000000..8e53be6 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/row.rs @@ -0,0 +1,17 @@ +use std::ops::Range; + +use bytes::Bytes; + +#[derive(Debug)] +pub struct Row { + pub(crate) storage: Bytes, + pub(crate) values: Vec>>, +} + +impl Row { + pub(crate) fn get(&self, index: usize) -> Option<&[u8]> { + self.values[index] + .as_ref() + .map(|col| &self.storage[(col.start as usize)..(col.end as usize)]) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/column.rs b/warpgate-database-protocols/src/mysql/protocol/text/column.rs new file mode 100644 index 0000000..c901f29 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/column.rs @@ -0,0 +1,265 @@ +use std::str::from_utf8; + +use bitflags::bitflags; +use bytes::{Buf, Bytes}; + +use crate::err_protocol; +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html + +bitflags! { + #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] + pub struct ColumnFlags: u16 { + /// Field can't be `NULL`. + const NOT_NULL = 1; + + /// Field is part of a primary key. + const PRIMARY_KEY = 2; + + /// Field is part of a unique key. + const UNIQUE_KEY = 4; + + /// Field is part of a multi-part unique or primary key. + const MULTIPLE_KEY = 8; + + /// Field is a blob. + const BLOB = 16; + + /// Field is unsigned. + const UNSIGNED = 32; + + /// Field is zero filled. + const ZEROFILL = 64; + + /// Field is binary. + const BINARY = 128; + + /// Field is an enumeration. + const ENUM = 256; + + /// Field is an auto-incement field. + const AUTO_INCREMENT = 512; + + /// Field is a timestamp. + const TIMESTAMP = 1024; + + /// Field is a set. + const SET = 2048; + + /// Field does not have a default value. + const NO_DEFAULT_VALUE = 4096; + + /// Field is set to NOW on UPDATE. + const ON_UPDATE_NOW = 8192; + + /// Field is a number. + const NUM = 32768; + } +} + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type + +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[repr(u8)] +pub enum ColumnType { + Decimal = 0x00, + Tiny = 0x01, + Short = 0x02, + Long = 0x03, + Float = 0x04, + Double = 0x05, + Null = 0x06, + Timestamp = 0x07, + LongLong = 0x08, + Int24 = 0x09, + Date = 0x0a, + Time = 0x0b, + Datetime = 0x0c, + Year = 0x0d, + VarChar = 0x0f, + Bit = 0x10, + Json = 0xf5, + NewDecimal = 0xf6, + Enum = 0xf7, + Set = 0xf8, + TinyBlob = 0xf9, + MediumBlob = 0xfa, + LongBlob = 0xfb, + Blob = 0xfc, + VarString = 0xfd, + String = 0xfe, + Geometry = 0xff, +} + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html +// https://mariadb.com/kb/en/resultset/#column-definition-packet +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 + +#[derive(Debug)] +pub struct ColumnDefinition { + #[allow(unused)] + catalog: Bytes, + #[allow(unused)] + schema: Bytes, + #[allow(unused)] + table_alias: Bytes, + #[allow(unused)] + table: Bytes, + alias: Bytes, + name: Bytes, + pub(crate) char_set: u16, + pub(crate) max_size: u32, + pub(crate) r#type: ColumnType, + pub(crate) flags: ColumnFlags, + #[allow(unused)] + decimals: u8, +} + +impl ColumnDefinition { + // NOTE: strings in-protocol are transmitted according to the client character set + // as this is UTF-8, all these strings should be UTF-8 + + pub(crate) fn name(&self) -> Result<&str, Error> { + from_utf8(&self.name).map_err(Error::protocol) + } + + pub(crate) fn alias(&self) -> Result<&str, Error> { + from_utf8(&self.alias).map_err(Error::protocol) + } +} + +impl Decode<'_, Capabilities> for ColumnDefinition { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let catalog = buf.get_bytes_lenenc(); + let schema = buf.get_bytes_lenenc(); + let table_alias = buf.get_bytes_lenenc(); + let table = buf.get_bytes_lenenc(); + let alias = buf.get_bytes_lenenc(); + let name = buf.get_bytes_lenenc(); + let _next_len = buf.get_uint_lenenc(); // always 0x0c + let char_set = buf.get_u16_le(); + let max_size = buf.get_u32_le(); + let type_id = buf.get_u8(); + let flags = buf.get_u16_le(); + let decimals = buf.get_u8(); + + Ok(Self { + catalog, + schema, + table_alias, + table, + alias, + name, + char_set, + max_size, + r#type: ColumnType::try_from_u16(type_id)?, + flags: ColumnFlags::from_bits_truncate(flags), + decimals, + }) + } +} + +impl ColumnType { + pub(crate) fn name( + self, + char_set: u16, + flags: ColumnFlags, + max_size: Option, + ) -> &'static str { + let is_binary = char_set == 63; + let is_unsigned = flags.contains(ColumnFlags::UNSIGNED); + let is_enum = flags.contains(ColumnFlags::ENUM); + + match self { + ColumnType::Tiny if max_size == Some(1) => "BOOLEAN", + ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED", + ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED", + ColumnType::Long if is_unsigned => "INT UNSIGNED", + ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED", + ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED", + ColumnType::Tiny => "TINYINT", + ColumnType::Short => "SMALLINT", + ColumnType::Long => "INT", + ColumnType::Int24 => "MEDIUMINT", + ColumnType::LongLong => "BIGINT", + ColumnType::Float => "FLOAT", + ColumnType::Double => "DOUBLE", + ColumnType::Null => "NULL", + ColumnType::Timestamp => "TIMESTAMP", + ColumnType::Date => "DATE", + ColumnType::Time => "TIME", + ColumnType::Datetime => "DATETIME", + ColumnType::Year => "YEAR", + ColumnType::Bit => "BIT", + ColumnType::Enum => "ENUM", + ColumnType::Set => "SET", + ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL", + ColumnType::Geometry => "GEOMETRY", + ColumnType::Json => "JSON", + + ColumnType::String if is_binary => "BINARY", + ColumnType::String if is_enum => "ENUM", + ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY", + + ColumnType::String => "CHAR", + ColumnType::VarChar | ColumnType::VarString => "VARCHAR", + + ColumnType::TinyBlob if is_binary => "TINYBLOB", + ColumnType::TinyBlob => "TINYTEXT", + + ColumnType::Blob if is_binary => "BLOB", + ColumnType::Blob => "TEXT", + + ColumnType::MediumBlob if is_binary => "MEDIUMBLOB", + ColumnType::MediumBlob => "MEDIUMTEXT", + + ColumnType::LongBlob if is_binary => "LONGBLOB", + ColumnType::LongBlob => "LONGTEXT", + } + } + + pub(crate) fn try_from_u16(id: u8) -> Result { + Ok(match id { + 0x00 => ColumnType::Decimal, + 0x01 => ColumnType::Tiny, + 0x02 => ColumnType::Short, + 0x03 => ColumnType::Long, + 0x04 => ColumnType::Float, + 0x05 => ColumnType::Double, + 0x06 => ColumnType::Null, + 0x07 => ColumnType::Timestamp, + 0x08 => ColumnType::LongLong, + 0x09 => ColumnType::Int24, + 0x0a => ColumnType::Date, + 0x0b => ColumnType::Time, + 0x0c => ColumnType::Datetime, + 0x0d => ColumnType::Year, + // [internal] 0x0e => ColumnType::NewDate, + 0x0f => ColumnType::VarChar, + 0x10 => ColumnType::Bit, + // [internal] 0x11 => ColumnType::Timestamp2, + // [internal] 0x12 => ColumnType::Datetime2, + // [internal] 0x13 => ColumnType::Time2, + 0xf5 => ColumnType::Json, + 0xf6 => ColumnType::NewDecimal, + 0xf7 => ColumnType::Enum, + 0xf8 => ColumnType::Set, + 0xf9 => ColumnType::TinyBlob, + 0xfa => ColumnType::MediumBlob, + 0xfb => ColumnType::LongBlob, + 0xfc => ColumnType::Blob, + 0xfd => ColumnType::VarString, + 0xfe => ColumnType::String, + 0xff => ColumnType::Geometry, + + _ => { + return Err(err_protocol!("unknown column type 0x{:02x}", id)); + } + }) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/mod.rs b/warpgate-database-protocols/src/mysql/protocol/text/mod.rs new file mode 100644 index 0000000..6c174e7 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/mod.rs @@ -0,0 +1,9 @@ +mod column; +mod ping; +mod query; +mod quit; + +pub use column::{ColumnDefinition, ColumnFlags, ColumnType}; +pub use ping::Ping; +pub use query::Query; +pub use quit::Quit; diff --git a/warpgate-database-protocols/src/mysql/protocol/text/ping.rs b/warpgate-database-protocols/src/mysql/protocol/text/ping.rs new file mode 100644 index 0000000..97c21d1 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/ping.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-ping.html + +#[derive(Debug)] +pub struct Ping; + +impl Encode<'_, Capabilities> for Ping { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x0e); // COM_PING + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/query.rs b/warpgate-database-protocols/src/mysql/protocol/text/query.rs new file mode 100644 index 0000000..3209d48 --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/query.rs @@ -0,0 +1,32 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, Decode, Encode}; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-query.html + +#[derive(Debug)] +pub struct Query(pub String); + +impl Encode<'_, ()> for Query { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.push(0x03); // COM_QUERY + buf.extend(self.0.as_bytes()) + } +} + +impl Encode<'_, Capabilities> for Query { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x03); // COM_QUERY + buf.extend(self.0.as_bytes()) + } +} + +impl Decode<'_> for Query { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + buf.advance(1); + let q = buf.get_str(buf.len())?; + Ok(Query(q)) + } +} diff --git a/warpgate-database-protocols/src/mysql/protocol/text/quit.rs b/warpgate-database-protocols/src/mysql/protocol/text/quit.rs new file mode 100644 index 0000000..ef6676d --- /dev/null +++ b/warpgate-database-protocols/src/mysql/protocol/text/quit.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-quit.html + +#[derive(Debug)] +pub struct Quit; + +impl Encode<'_, Capabilities> for Quit { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x01); // COM_QUIT + } +} diff --git a/warpgate-protocol-http/Cargo.toml b/warpgate-protocol-http/Cargo.toml index 0e34020..f902cc0 100644 --- a/warpgate-protocol-http/Cargo.toml +++ b/warpgate-protocol-http/Cargo.toml @@ -18,7 +18,7 @@ poem-openapi = {version = "^2.0.4", features = ["swagger-ui"]} reqwest = {version = "0.11", features = ["rustls-tls-native-roots", "stream"]} serde = "1.0" serde_json = "1.0" -tokio = {version = "1.19", features = ["tracing", "signal"]} +tokio = {version = "1.20", features = ["tracing", "signal"]} tokio-tungstenite = {version = "0.17", features = ["rustls-tls-native-roots"]} tracing = "0.1" warpgate-admin = {version = "*", path = "../warpgate-admin"} @@ -27,3 +27,4 @@ warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"} warpgate-web = {version = "*", path = "../warpgate-web"} percent-encoding = "2.1" uuid = {version = "1.0", features = ["v4"]} +regex = "1.5" diff --git a/warpgate-protocol-http/src/api/auth.rs b/warpgate-protocol-http/src/api/auth.rs index b7dc14c..6477df2 100644 --- a/warpgate-protocol-http/src/api/auth.rs +++ b/warpgate-protocol-http/src/api/auth.rs @@ -1,5 +1,6 @@ use crate::common::SessionExt; use crate::session::SessionMiddleware; +use anyhow::Context; use poem::session::Session; use poem::web::Data; use poem::Request; @@ -66,7 +67,7 @@ impl Api { config_provider .authorize(&body.username, &credentials, crate::common::PROTOCOL_NAME) .await - .map_err(|e| e.context("Failed to authorize user"))? + .context("Failed to authorize user")? }; match result { diff --git a/warpgate-protocol-http/src/api/info.rs b/warpgate-protocol-http/src/api/info.rs index 19eb1d1..fc733c1 100644 --- a/warpgate-protocol-http/src/api/info.rs +++ b/warpgate-protocol-http/src/api/info.rs @@ -1,6 +1,3 @@ -use std::net::ToSocketAddrs; - -use crate::common::SessionExt; use poem::session::Session; use poem::web::Data; use poem::Request; @@ -9,11 +6,14 @@ use poem_openapi::{ApiResponse, Object, OpenApi}; use serde::Serialize; use warpgate_common::Services; +use crate::common::SessionExt; + pub struct Api; #[derive(Serialize, Object)] pub struct PortsInfo { - ssh: u16, + ssh: Option, + mysql: Option, } #[derive(Serialize, Object)] @@ -54,15 +54,22 @@ impl Api { external_host: external_host.map(&str::to_string), ports: if session.is_authenticated() { PortsInfo { - ssh: config - .store - .ssh - .listen - .to_socket_addrs() - .map_or(0, |mut x| x.next().map(|x| x.port()).unwrap_or(0)), + ssh: if config.store.ssh.enable { + Some(config.store.ssh.listen.port()) + } else { + None + }, + mysql: if config.store.mysql.enable { + Some(config.store.mysql.listen.port()) + } else { + None + }, } } else { - PortsInfo { ssh: 0 } + PortsInfo { + ssh: None, + mysql: None, + } }, }))) } diff --git a/warpgate-protocol-http/src/api/targets_list.rs b/warpgate-protocol-http/src/api/targets_list.rs index 9e69d25..430f121 100644 --- a/warpgate-protocol-http/src/api/targets_list.rs +++ b/warpgate-protocol-http/src/api/targets_list.rs @@ -12,6 +12,7 @@ pub struct Api; #[derive(Debug, Serialize, Clone, Enum)] pub enum TargetKind { Http, + MySql, Ssh, WebAdmin, } @@ -68,6 +69,7 @@ impl Api { kind: match t.options { TargetOptions::Ssh(_) => TargetKind::Ssh, TargetOptions::Http(_) => TargetKind::Http, + TargetOptions::MySql(_) => TargetKind::MySql, TargetOptions::WebAdmin(_) => TargetKind::WebAdmin, }, }) diff --git a/warpgate-protocol-http/src/main.rs b/warpgate-protocol-http/src/main.rs index 1d14b7b..9fb25b4 100644 --- a/warpgate-protocol-http/src/main.rs +++ b/warpgate-protocol-http/src/main.rs @@ -1,5 +1,6 @@ #![feature(type_alias_impl_trait, let_else, try_blocks)] mod api; +use regex::Regex; mod common; mod session; mod session_handle; @@ -12,5 +13,10 @@ pub fn main() { env!("CARGO_PKG_VERSION"), ) .server("/@warpgate/api"); - println!("{}", api_service.spec()); + + let spec = api_service.spec(); + let re = Regex::new(r"PaginatedResponse<(?P\w+)>").unwrap(); + let spec = re.replace_all(&spec, "Paginated$name"); + + println!("{}", spec); } diff --git a/warpgate-protocol-mysql/Cargo.toml b/warpgate-protocol-mysql/Cargo.toml new file mode 100644 index 0000000..01c5837 --- /dev/null +++ b/warpgate-protocol-mysql/Cargo.toml @@ -0,0 +1,28 @@ +[package] +edition = "2021" +license = "Apache-2.0" +name = "warpgate-protocol-mysql" +version = "0.3.0" + +[dependencies] +warpgate-admin = { version = "*", path = "../warpgate-admin" } +warpgate-common = { version = "*", path = "../warpgate-common" } +warpgate-db-entities = { version = "*", path = "../warpgate-db-entities" } +warpgate-database-protocols = { version = "*", path = "../warpgate-database-protocols" } +anyhow = { version = "1.0", features = ["std"] } +async-trait = "0.1" +tokio = { version = "1.20", features = ["tracing", "signal"] } +tracing = "0.1" +uuid = { version = "1.0", features = ["v4"] } +bytes = "1.1" +mysql_common = "0.29" +rand = "0.8" +sha1 = "0.10.1" +password-hash = { version = "0.2", features = ["std"] } +delegate = "0.6" +rustls = { version = "0.20", features = ["dangerous_configuration"] } +rustls-pemfile = "1.0" +tokio-rustls = "0.23" +thiserror = "1.0" +webpki = "0.22" +webpki-roots = "0.22" diff --git a/warpgate-protocol-mysql/src/client.rs b/warpgate-protocol-mysql/src/client.rs new file mode 100644 index 0000000..a2ba6b0 --- /dev/null +++ b/warpgate-protocol-mysql/src/client.rs @@ -0,0 +1,150 @@ +use std::sync::Arc; + +use bytes::BytesMut; +use tokio::net::TcpStream; +use tracing::*; +use warpgate_common::{TargetMySqlOptions, TlsMode}; +use warpgate_database_protocols::io::Decode; +use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin; +use warpgate_database_protocols::mysql::protocol::connect::{ + Handshake, HandshakeResponse, SslRequest, +}; +use warpgate_database_protocols::mysql::protocol::response::ErrPacket; +use warpgate_database_protocols::mysql::protocol::Capabilities; + +use crate::common::compute_auth_challenge_response; +use crate::error::MySqlError; +use crate::stream::MySqlStream; +use crate::tls::configure_tls_connector; + +pub struct MySqlClient { + pub stream: MySqlStream>, + pub capabilities: Capabilities, +} + +pub struct ConnectionOptions { + pub collation: u8, + pub database: Option, + pub max_packet_size: u32, + pub capabilities: Capabilities, +} + +impl MySqlClient { + pub async fn connect( + target: &TargetMySqlOptions, + mut options: ConnectionOptions, + ) -> Result { + let mut stream = + MySqlStream::new(TcpStream::connect((target.host.clone(), target.port)).await?); + + options.capabilities.remove(Capabilities::SSL); + if target.tls.mode != TlsMode::Disabled { + options.capabilities |= Capabilities::SSL; + } + + let Some(payload) = stream.recv().await? else { + return Err(MySqlError::Eof) + }; + let handshake = Handshake::decode(payload)?; + + options.capabilities &= handshake.server_capabilities; + if target.tls.mode == TlsMode::Required && !options.capabilities.contains(Capabilities::SSL) + { + return Err(MySqlError::TlsNotSupported); + } + + info!(capabilities=?options.capabilities, "Target handshake"); + + if options.capabilities.contains(Capabilities::SSL) && target.tls.mode != TlsMode::Disabled + { + let accept_invalid_certs = !target.tls.verify; + let accept_invalid_hostname = false; // ca + hostname verification + let client_config = Arc::new( + configure_tls_connector(accept_invalid_certs, accept_invalid_hostname, None) + .await?, + ); + let req = SslRequest { + collation: options.collation, + max_packet_size: options.max_packet_size, + }; + stream.push(&req, options.capabilities)?; + stream.flush().await?; + stream = stream + .upgrade(( + target + .host + .as_str() + .try_into() + .map_err(|_| MySqlError::InvalidDomainName)?, + client_config, + )) + .await?; + info!("Target connection upgraded to TLS"); + } + + let mut response = HandshakeResponse { + auth_plugin: None, + auth_response: None, + collation: options.collation, + database: options.database, + max_packet_size: options.max_packet_size, + username: target.username.clone(), + }; + + if handshake.auth_plugin == Some(AuthPlugin::MySqlNativePassword) { + let scramble_bytes = [ + &handshake.auth_plugin_data.first_ref()[..], + &handshake.auth_plugin_data.last_ref()[..], + ] + .concat(); + match scramble_bytes.try_into() as Result<[u8; 20], Vec> { + Err(scramble_bytes) => { + warn!("Invalid scramble length ({})", scramble_bytes.len()); + } + Ok(scramble) => { + response.auth_plugin = Some(AuthPlugin::MySqlNativePassword); + response.auth_response = Some( + BytesMut::from( + compute_auth_challenge_response( + scramble, + target.password.as_deref().unwrap_or(""), + ) + .map_err(MySqlError::other)? + .as_bytes(), + ) + .freeze(), + ); + trace!(response=?response.auth_response, ?scramble, "auth"); + } + } + } + + stream.push(&response, options.capabilities)?; + stream.flush().await?; + + let Some(response) = stream.recv().await? else { + return Err(MySqlError::Eof) + }; + if response.get(0) == Some(&0) || response.get(0) == Some(&0xfe) { + debug!("Authorized"); + } else if response.get(0) == Some(&0xff) { + let error = ErrPacket::decode_with(response, options.capabilities)?; + return Err(MySqlError::ProtocolError(format!( + "handshake failed: {:?}", + error + ))); + } else { + return Err(MySqlError::ProtocolError(format!( + "unknown response type {:?}", + response.get(0) + ))); + } + + stream.reset_sequence_id(); + + Ok(Self { + stream, + capabilities: options.capabilities, + }) + } +} diff --git a/warpgate-protocol-mysql/src/common.rs b/warpgate-protocol-mysql/src/common.rs new file mode 100644 index 0000000..a80b191 --- /dev/null +++ b/warpgate-protocol-mysql/src/common.rs @@ -0,0 +1,25 @@ +use sha1::Digest; +use warpgate_common::ProtocolName; + +pub const PROTOCOL_NAME: ProtocolName = "MySQL"; + +pub fn compute_auth_challenge_response( + challenge: [u8; 20], + password: &str, +) -> Result { + password_hash::Output::new( + &{ + let password_sha: [u8; 20] = sha1::Sha1::digest(password).into(); + let password_sha_sha: [u8; 20] = sha1::Sha1::digest(password_sha).into(); + let password_seed_2sha_sha: [u8; 20] = + sha1::Sha1::digest([challenge, password_sha_sha].concat()).into(); + + let mut result = password_sha; + result + .iter_mut() + .zip(password_seed_2sha_sha.iter()) + .for_each(|(x1, x2)| *x1 ^= *x2); + result + }[..], + ) +} diff --git a/warpgate-protocol-mysql/src/error.rs b/warpgate-protocol-mysql/src/error.rs new file mode 100644 index 0000000..3a96f57 --- /dev/null +++ b/warpgate-protocol-mysql/src/error.rs @@ -0,0 +1,50 @@ +use std::error::Error; + +use warpgate_common::WarpgateError; +use warpgate_database_protocols::error::Error as SqlxError; + +use crate::stream::MySqlStreamError; +use crate::tls::{MaybeTlsStreamError, RustlsSetupError}; + +#[derive(thiserror::Error, Debug)] +pub enum MySqlError { + #[error("protocol error: {0}")] + ProtocolError(String), + #[error("sudden disconnection")] + Eof, + #[error("server doesn't offer TLS")] + TlsNotSupported, + #[error("client doesn't support TLS")] + TlsNotSupportedByClient, + #[error("TLS setup failed: {0}")] + TlsSetup(#[from] RustlsSetupError), + #[error("TLS stream error: {0}")] + Tls(#[from] MaybeTlsStreamError), + #[error("Invalid domain name")] + InvalidDomainName, + #[error("sqlx error: {0}")] + Sqlx(#[from] SqlxError), + #[error("MySQL stream error: {0}")] + MySqlStream(#[from] MySqlStreamError), + #[error("I/O: {0}")] + Io(#[from] std::io::Error), + #[error("packet decode error: {0}")] + Decode(Box), + #[error(transparent)] + Warpgate(#[from] WarpgateError), + #[error(transparent)] + Other(Box), +} + +impl MySqlError { + pub fn other(err: E) -> Self { + Self::Other(Box::new(err)) + } + + pub fn decode(err: SqlxError) -> Self { + match err { + SqlxError::Decode(err) => Self::Decode(err), + _ => Self::Sqlx(err), + } + } +} diff --git a/warpgate-protocol-mysql/src/lib.rs b/warpgate-protocol-mysql/src/lib.rs new file mode 100644 index 0000000..d41a97f --- /dev/null +++ b/warpgate-protocol-mysql/src/lib.rs @@ -0,0 +1,108 @@ +#![feature(type_alias_impl_trait, let_else, try_blocks)] +mod client; +mod common; +mod error; +mod session; +mod session_handle; +mod stream; +mod tls; +use std::fmt::Debug; +use std::net::SocketAddr; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tracing::*; +use warpgate_common::{ProtocolServer, Services, SessionStateInit, Target, TargetTestError}; + +use crate::session::MySqlSession; +use crate::session_handle::MySqlSessionHandle; +use crate::tls::FromCertificateAndKey; + +pub struct MySQLProtocolServer { + services: Services, +} + +impl MySQLProtocolServer { + pub async fn new(services: &Services) -> Result { + Ok(MySQLProtocolServer { + services: services.clone(), + }) + } +} + +#[async_trait] +impl ProtocolServer for MySQLProtocolServer { + async fn run(self, address: SocketAddr) -> Result<()> { + let (certificate, key) = { + let config = self.services.config.lock().await; + let certificate_path = config + .paths_relative_to + .join(&config.store.mysql.certificate); + let key_path = config.paths_relative_to.join(&config.store.mysql.key); + + ( + std::fs::read(&certificate_path).with_context(|| { + format!( + "reading SSL certificate from '{}'", + certificate_path.display() + ) + })?, + std::fs::read(&key_path).with_context(|| { + format!("reading SSL private key from '{}'", key_path.display()) + })?, + ) + }; + + let tls_config = ServerConfig::try_from_certificate_and_key(certificate, key)?; + + info!(?address, "Listening"); + let listener = TcpListener::bind(address).await?; + loop { + let (stream, remote_address) = listener.accept().await?; + let tls_config = tls_config.clone(); + let services = self.services.clone(); + tokio::spawn(async move { + let (session_handle, mut abort_rx) = MySqlSessionHandle::new(); + + let server_handle = services + .state + .lock() + .await + .register_session( + &crate::common::PROTOCOL_NAME, + SessionStateInit { + remote_address: Some(remote_address), + handle: Box::new(session_handle), + }, + ) + .await?; + + let session = MySqlSession::new(server_handle, services, stream, tls_config).await; + let span = session.make_logging_span(); + tokio::select! { + result = session.run().instrument(span) => match result { + Ok(_) => info!("Session ended"), + Err(e) => error!(error=%e, "Session failed"), + }, + _ = abort_rx.recv() => { + warn!("Session aborted by admin"); + }, + } + + Ok::<(), anyhow::Error>(()) + }); + } + } + + async fn test_target(self, _target: Target) -> Result<(), TargetTestError> { + Ok(()) + } +} + +impl Debug for MySQLProtocolServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MySQLProtocolServer") + } +} diff --git a/warpgate-protocol-mysql/src/session.rs b/warpgate-protocol-mysql/src/session.rs new file mode 100644 index 0000000..b9d6299 --- /dev/null +++ b/warpgate-protocol-mysql/src/session.rs @@ -0,0 +1,426 @@ +use std::sync::Arc; + +use bytes::{Buf, Bytes, BytesMut}; +use rand::Rng; +use rustls::ServerConfig; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tracing::*; +use uuid::Uuid; +use warpgate_common::auth::AuthSelector; +use warpgate_common::helpers::rng::get_crypto_rng; +use warpgate_common::{ + authorize_ticket, AuthCredential, AuthResult, Secret, Services, TargetMySqlOptions, + TargetOptions, WarpgateServerHandle, +}; +use warpgate_database_protocols::io::{BufExt, Decode}; +use warpgate_database_protocols::mysql::protocol::auth::AuthPlugin; +use warpgate_database_protocols::mysql::protocol::connect::{ + AuthSwitchRequest, Handshake, HandshakeResponse, +}; +use warpgate_database_protocols::mysql::protocol::response::{ErrPacket, OkPacket, Status}; +use warpgate_database_protocols::mysql::protocol::text::Query; +use warpgate_database_protocols::mysql::protocol::Capabilities; + +use crate::client::{ConnectionOptions, MySqlClient}; +use crate::error::MySqlError; +use crate::stream::MySqlStream; + +pub struct MySqlSession { + stream: MySqlStream>, + capabilities: Capabilities, + challenge: [u8; 20], + username: Option, + database: Option, + tls_config: Arc, + server_handle: Arc>, + id: Uuid, + services: Services, +} + +impl MySqlSession { + pub async fn new( + server_handle: Arc>, + services: Services, + stream: TcpStream, + tls_config: ServerConfig, + ) -> Self { + let id = server_handle.lock().await.id(); + Self { + services, + stream: MySqlStream::new(stream), + capabilities: Capabilities::PROTOCOL_41 + | Capabilities::PLUGIN_AUTH + | Capabilities::FOUND_ROWS + | Capabilities::LONG_FLAG + | Capabilities::NO_SCHEMA + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::CONNECT_WITH_DB + | Capabilities::SESSION_TRACK + | Capabilities::IGNORE_SPACE + | Capabilities::INTERACTIVE + | Capabilities::TRANSACTIONS + | Capabilities::DEPRECATE_EOF + | Capabilities::SECURE_CONNECTION + | Capabilities::SSL, + challenge: get_crypto_rng().gen(), + tls_config: Arc::new(tls_config), + username: None, + database: None, + server_handle, + id, + } + } + + pub fn make_logging_span(&self) -> tracing::Span { + match self.username { + Some(ref username) => info_span!("MySQL", session=%self.id, session_username=%username), + None => info_span!("MySQL", session=%self.id), + } + } + + pub async fn run(mut self) -> Result<(), MySqlError> { + let mut challenge_1 = BytesMut::from(&self.challenge[..]); + let challenge_2 = challenge_1.split_off(8); + let challenge_chain = challenge_1.freeze().chain(challenge_2.freeze()); + + let handshake = Handshake { + protocol_version: 10, + server_version: "8.0.0-Warpgate".to_owned(), + connection_id: 1, + auth_plugin_data: challenge_chain, + server_capabilities: self.capabilities, + server_default_collation: 45, + status: Status::empty(), + auth_plugin: Some(AuthPlugin::MySqlNativePassword), + }; + self.stream.push(&handshake, ())?; + self.stream.flush().await?; + + let resp = loop { + let Some(payload) = self.stream.recv().await? else { + return Err(MySqlError::Eof); + }; + let resp = HandshakeResponse::decode_with(payload, &mut self.capabilities) + .map_err(MySqlError::decode)?; + + trace!(?resp, "Handshake response"); + info!(capabilities=?self.capabilities, username=%resp.username, "User handshake"); + + if self.capabilities.contains(Capabilities::SSL) { + if self.stream.is_tls() { + break resp; + } + self.stream = self.stream.upgrade(self.tls_config.clone()).await?; + continue; + } else { + self.send_error(1002, "Warpgate requires TLS - please enable it in your client: add `--ssl` on the CLI or add `?sslMode=PREFERRED` to your database URI").await?; + return Err(MySqlError::TlsNotSupportedByClient); + } + }; + + if resp.auth_plugin == Some(AuthPlugin::MySqlClearPassword) { + if let Some(mut response) = resp.auth_response.clone() { + let password = Secret::new(response.get_str_nul()?); + return self.run_authorization(resp, password).await; + } + } + + let req = AuthSwitchRequest { + plugin: AuthPlugin::MySqlClearPassword, + data: Bytes::new(), + }; + self.stream.push(&req, ())?; + + // self.push(&RawBytes::< + self.stream.flush().await?; + + let Some(response) = &self.stream.recv().await? else { + return Err(MySqlError::Eof); + }; + let password = Secret::new(response.clone().get_str_nul()?); + return self.run_authorization(resp, password).await; + } + + async fn send_error(&mut self, code: u16, message: &str) -> Result<(), MySqlError> { + self.stream.push( + &ErrPacket { + error_code: code, + error_message: message.to_owned(), + sql_state: None, + }, + (), + )?; + self.stream.flush().await?; + Ok(()) + } + + pub async fn run_authorization( + mut self, + handshake: HandshakeResponse, + password: Secret, + ) -> Result<(), MySqlError> { + let selector: AuthSelector = (&handshake.username).into(); + + async fn fail(this: &mut MySqlSession) -> Result<(), MySqlError> { + this.stream.push( + &ErrPacket { + error_code: 1, + error_message: "Warpgate access denied".to_owned(), + sql_state: None, + }, + (), + )?; + this.stream.flush().await?; + Ok(()) + } + + let credentials = vec![AuthCredential::Password(password)]; + match selector { + AuthSelector::User { + username, + target_name, + } => { + let user_auth_result: AuthResult = { + self.services + .config_provider + .lock() + .await + .authorize(&username, &credentials, crate::common::PROTOCOL_NAME) + .await + .map_err(MySqlError::other)? + }; + + match user_auth_result { + AuthResult::Accepted { username } => { + let target_auth_result = { + self.services + .config_provider + .lock() + .await + .authorize_target(&username, &target_name) + .await + .map_err(MySqlError::other)? + }; + if !target_auth_result { + warn!( + "Target {} not authorized for user {}", + target_name, username + ); + return fail(&mut self).await; + } + return self.run_authorized(handshake, username, target_name).await; + } + AuthResult::Rejected | AuthResult::OtpNeeded => { + return fail(&mut self).await; + } + } + } + AuthSelector::Ticket { secret } => { + match authorize_ticket(&self.services.db, &secret) + .await + .map_err(MySqlError::other)? + { + Some(ticket) => { + info!("Authorized for {} with a ticket", ticket.target); + self.services + .config_provider + .lock() + .await + .consume_ticket(&ticket.id) + .await + .map_err(MySqlError::other)?; + + return self + .run_authorized(handshake, ticket.username, ticket.target) + .await; + } + _ => return fail(&mut self).await, + } + } + } + } + + async fn run_authorized( + mut self, + handshake: HandshakeResponse, + username: String, + target_name: String, + ) -> Result<(), MySqlError> { + self.stream.push( + &OkPacket { + affected_rows: 0, + last_insert_id: 0, + status: Status::empty(), + warnings: 0, + }, + (), + )?; + self.stream.flush().await?; + + info!(%username, "Authenticated"); + + let target = { + self.services + .config + .lock() + .await + .store + .targets + .iter() + .filter_map(|t| match t.options { + TargetOptions::MySql(ref options) => Some((t, options)), + _ => None, + }) + .find(|(t, _)| t.name == target_name) + .map(|(t, opt)| (t.clone(), opt.clone())) + }; + + let Some((target, mysql_options)) = target else { + warn!("Selected target not found"); + self.stream.push( + &ErrPacket { + error_code: 1, + error_message: "Warpgate access denied".to_owned(), + sql_state: None, + }, + (), + )?; + self.stream.flush().await?; + return Ok(()); + }; + + { + let handle = self.server_handle.lock().await; + handle.set_username(username).await?; + handle.set_target(&target).await?; + } + + let span = self.make_logging_span(); + return self + .run_authorized_inner(handshake, mysql_options) + .instrument(span) + .await; + } + + async fn run_authorized_inner( + mut self, + handshake: HandshakeResponse, + options: TargetMySqlOptions, + ) -> Result<(), MySqlError> { + self.database = handshake.database.clone(); + self.username = Some(handshake.username); + if let Some(ref database) = handshake.database { + info!("Selected database: {database}"); + } + + let mut client = match MySqlClient::connect( + &options, + ConnectionOptions { + collation: handshake.collation, + database: handshake.database, + max_packet_size: handshake.max_packet_size, + capabilities: self.capabilities, + }, + ) + .await + { + Err(error) => { + error!(%error, "Target connection failed"); + self.send_error(1045, "Access denied").await?; + Err(error) + } + x => x, + }?; + + loop { + self.stream.reset_sequence_id(); + client.stream.reset_sequence_id(); + let Some(payload) = self.stream.recv().await? else { + break; + }; + trace!(?payload, "server got packet"); + + let com = payload.get(0); + + // COM_QUERY + if com == Some(&0x03) { + let query = Query::decode(payload)?; + info!(query=%query.0, "SQL"); + + client.stream.push(&query, ())?; + client.stream.flush().await?; + + let mut eof_ctr = 0; + loop { + let Some(response) = client.stream.recv().await? else { + return Err(MySqlError::Eof); + }; + trace!(?response, "client got packet"); + self.stream.push(&&response[..], ())?; + self.stream.flush().await?; + if let Some(com) = response.get(0) { + if com == &0xfe { + if self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + break; + } + eof_ctr += 1; + if eof_ctr == 2 { + // todo check multiple results + break; + } + } + if com == &0 || com == &0xff { + break; + } + } + } + // COM_QUIT + } else if com == Some(&0x01) { + break; + // COM_INIT_DB + } else if com == Some(&0x02) { + let mut buf = payload.clone(); + buf.advance(1); + let db = buf.get_str(buf.len())?; + self.database = Some(db.clone()); + info!("Selected database: {db}"); + client.stream.push(&&payload[..], ())?; + client.stream.flush().await?; + self.passthrough_until_result(&mut client).await?; + // COM_FIELD_LIST, COM_PING, COM_RESET_CONNECTION + } else if com == Some(&0x04) || com == Some(&0x0e) || com == Some(&0x1f) { + client.stream.push(&&payload[..], ())?; + client.stream.flush().await?; + self.passthrough_until_result(&mut client).await?; + } else if let Some(com) = com { + warn!("Unknown packet type {com}"); + self.send_error(1047, "Not implemented").await?; + } else { + break; + } + } + + Ok(()) + } + + async fn passthrough_until_result( + &mut self, + client: &mut MySqlClient, + ) -> Result<(), MySqlError> { + loop { + let Some(response) = client.stream.recv().await? else{ + return Err(MySqlError::Eof); + }; + trace!(?response, "client got packet"); + self.stream.push(&&response[..], ())?; + self.stream.flush().await?; + if let Some(com) = response.get(0) { + if com == &0 || com == &0xff || com == &0xfe { + break; + } + } + } + Ok(()) + } +} diff --git a/warpgate-protocol-mysql/src/session_handle.rs b/warpgate-protocol-mysql/src/session_handle.rs new file mode 100644 index 0000000..0cf6544 --- /dev/null +++ b/warpgate-protocol-mysql/src/session_handle.rs @@ -0,0 +1,19 @@ +use tokio::sync::mpsc; +use warpgate_common::SessionHandle; + +pub struct MySqlSessionHandle { + abort_tx: mpsc::UnboundedSender<()>, +} + +impl MySqlSessionHandle { + pub fn new() -> (Self, mpsc::UnboundedReceiver<()>) { + let (abort_tx, abort_rx) = mpsc::unbounded_channel(); + (MySqlSessionHandle { abort_tx }, abort_rx) + } +} + +impl SessionHandle for MySqlSessionHandle { + fn close(&mut self) { + let _ = self.abort_tx.send(()); + } +} diff --git a/warpgate-protocol-mysql/src/stream.rs b/warpgate-protocol-mysql/src/stream.rs new file mode 100644 index 0000000..f429c71 --- /dev/null +++ b/warpgate-protocol-mysql/src/stream.rs @@ -0,0 +1,100 @@ +use bytes::{Bytes, BytesMut}; +use mysql_common::proto::codec::error::PacketCodecError; +use mysql_common::proto::codec::PacketCodec; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use tracing::*; +use warpgate_database_protocols::io::Encode; + +use crate::tls::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream}; + +#[derive(thiserror::Error, Debug)] +pub enum MySqlStreamError { + #[error("packet codec error: {0}")] + Codec(#[from] PacketCodecError), + #[error("I/O: {0}")] + Io(#[from] std::io::Error), +} + +pub struct MySqlStream +where + TcpStream: UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + stream: MaybeTlsStream, + codec: PacketCodec, + inbound_buffer: BytesMut, + outbound_buffer: BytesMut, +} + +impl MySqlStream +where + TcpStream: UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(stream: TcpStream) -> Self { + Self { + stream: MaybeTlsStream::new(stream), + codec: PacketCodec::default(), + inbound_buffer: BytesMut::new(), + outbound_buffer: BytesMut::new(), + } + } + + pub fn push<'a, C, P: Encode<'a, C>>( + &mut self, + packet: &'a P, + context: C, + ) -> Result<(), MySqlStreamError> { + let mut buf = vec![]; + packet.encode_with(&mut buf, context); + self.codec.encode(&mut &*buf, &mut self.outbound_buffer)?; + Ok(()) + } + + pub async fn flush(&mut self) -> std::io::Result<()> { + trace!(outbound_buffer=?self.outbound_buffer, "sending"); + self.stream.write_all(&self.outbound_buffer[..]).await?; + self.outbound_buffer = BytesMut::new(); + self.stream.flush().await?; + Ok(()) + } + + pub async fn recv(&mut self) -> Result, MySqlStreamError> { + let mut payload = BytesMut::new(); + loop { + { + let got_full_packet = self.codec.decode(&mut self.inbound_buffer, &mut payload)?; + if got_full_packet { + trace!(?payload, "received"); + return Ok(Some(payload.freeze())); + } + } + let read_bytes = self.stream.read_buf(&mut self.inbound_buffer).await?; + if read_bytes == 0 { + return Ok(None); + } + trace!(inbound_buffer=?self.inbound_buffer, "received chunk"); + } + } + + pub fn reset_sequence_id(&mut self) { + self.codec.reset_seq_id(); + } + + pub async fn upgrade( + mut self, + config: >::UpgradeConfig, + ) -> Result { + self.stream = self.stream.upgrade(config).await?; + Ok(self) + } + + pub fn is_tls(&self) -> bool { + match self.stream { + MaybeTlsStream::Raw(_) => false, + MaybeTlsStream::Tls(_) => true, + MaybeTlsStream::Upgrading => false, + } + } +} diff --git a/warpgate-protocol-mysql/src/tls/maybe_tls_stream.rs b/warpgate-protocol-mysql/src/tls/maybe_tls_stream.rs new file mode 100644 index 0000000..c6672e5 --- /dev/null +++ b/warpgate-protocol-mysql/src/tls/maybe_tls_stream.rs @@ -0,0 +1,155 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; + +use async_trait::async_trait; +use rustls::{ClientConfig, ServerConfig, ServerName}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::*; + +#[derive(thiserror::Error, Debug)] +pub enum MaybeTlsStreamError { + #[error("stream is already upgraded")] + AlreadyUpgraded, + #[error("I/O: {0}")] + Io(#[from] std::io::Error), +} + +#[async_trait] +pub trait UpgradableStream +where + Self: Sized, + T: AsyncRead + AsyncWrite + Unpin, +{ + type UpgradeConfig; + async fn upgrade(self, config: Self::UpgradeConfig) -> Result; +} + +pub enum MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin + UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + Tls(TS), + Raw(S), + Upgrading, +} + +impl MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin + UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(stream: S) -> Self { + Self::Raw(stream) + } +} + +impl MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin + UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + pub async fn upgrade( + mut self, + tls_config: S::UpgradeConfig, + ) -> Result { + if let Self::Raw(stream) = std::mem::replace(&mut self, Self::Upgrading) { + let stream = stream.upgrade(tls_config).await?; + Ok(MaybeTlsStream::Tls(stream)) + } else { + Err(MaybeTlsStreamError::AlreadyUpgraded) + } + } +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin + UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf), + MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_read(cx, buf), + _ => unreachable!(), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin + UpgradableStream, + TS: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.get_mut() { + MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf), + MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_write(cx, buf), + _ => unreachable!(), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx), + MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_flush(cx), + _ => unreachable!(), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.get_mut() { + MaybeTlsStream::Tls(tls) => Pin::new(tls).poll_shutdown(cx), + MaybeTlsStream::Raw(stream) => Pin::new(stream).poll_shutdown(cx), + _ => unreachable!(), + } + } +} + +#[async_trait] +impl UpgradableStream> for S +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + type UpgradeConfig = (ServerName, Arc); + + async fn upgrade( + mut self, + config: Self::UpgradeConfig, + ) -> Result, MaybeTlsStreamError> { + let (domain, tls_config) = config; + let connector = tokio_rustls::TlsConnector::from(tls_config); + Ok(connector.connect(domain, self).await?) + } +} + +#[async_trait] +impl UpgradableStream> for S +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + type UpgradeConfig = Arc; + + async fn upgrade( + mut self, + tls_config: Self::UpgradeConfig, + ) -> Result, MaybeTlsStreamError> { + let acceptor = tokio_rustls::TlsAcceptor::from(tls_config); + Ok(acceptor.accept(self).await?) + } +} diff --git a/warpgate-protocol-mysql/src/tls/mod.rs b/warpgate-protocol-mysql/src/tls/mod.rs new file mode 100644 index 0000000..2f857d3 --- /dev/null +++ b/warpgate-protocol-mysql/src/tls/mod.rs @@ -0,0 +1,5 @@ +mod maybe_tls_stream; +mod rustls_helpers; + +pub use maybe_tls_stream::{MaybeTlsStream, MaybeTlsStreamError, UpgradableStream}; +pub use rustls_helpers::{configure_tls_connector, FromCertificateAndKey, RustlsSetupError}; diff --git a/warpgate-protocol-mysql/src/tls/rustls_helpers.rs b/warpgate-protocol-mysql/src/tls/rustls_helpers.rs new file mode 100644 index 0000000..cb1bb51 --- /dev/null +++ b/warpgate-protocol-mysql/src/tls/rustls_helpers.rs @@ -0,0 +1,174 @@ +use std::io::Cursor; +use std::sync::Arc; +use std::time::SystemTime; + +use rustls::client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier}; +use rustls::server::{ClientHello, NoClientAuth, ResolvesServerCert}; +use rustls::sign::CertifiedKey; +use rustls::{ + Certificate, ClientConfig, Error as TlsError, OwnedTrustAnchor, PrivateKey, RootCertStore, + ServerConfig, ServerName, +}; + +#[derive(thiserror::Error, Debug)] +pub enum RustlsSetupError { + #[error("rustls")] + Rustls(#[from] rustls::Error), + #[error("sign")] + Sign(#[from] rustls::sign::SignError), + #[error("no private keys in key file")] + NoKeys, + #[error("I/O")] + Io(#[from] std::io::Error), + #[error("PKI")] + Pki(#[from] webpki::Error), +} + +pub trait FromCertificateAndKey +where + Self: Sized, +{ + fn try_from_certificate_and_key(cert: Vec, key: Vec) -> Result; +} + +impl FromCertificateAndKey for rustls::ServerConfig { + fn try_from_certificate_and_key( + cert: Vec, + key_bytes: Vec, + ) -> Result { + let certificates = rustls_pemfile::certs(&mut &cert[..]).map(|mut certs| { + certs + .drain(..) + .map(Certificate) + .collect::>() + })?; + + let mut key = rustls_pemfile::pkcs8_private_keys(&mut key_bytes.as_slice())? + .drain(..) + .next() + .map(PrivateKey); + + if key.is_none() { + key = rustls_pemfile::rsa_private_keys(&mut key_bytes.as_slice())? + .drain(..) + .next() + .map(PrivateKey); + } + + let key = key.ok_or(RustlsSetupError::NoKeys)?; + let key = rustls::sign::any_supported_type(&key)?; + + let cert_key = Arc::new(CertifiedKey { + cert: certificates, + key, + ocsp: None, + sct_list: None, + }); + + Ok(ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(NoClientAuth::new()) + .with_cert_resolver(Arc::new(ResolveServerCert(cert_key)))) + } +} + +struct ResolveServerCert(Arc); + +impl ResolvesServerCert for ResolveServerCert { + fn resolve(&self, _: ClientHello) -> Option> { + Some(self.0.clone()) + } +} + +pub async fn configure_tls_connector( + accept_invalid_certs: bool, + accept_invalid_hostnames: bool, + root_cert: Option<&[u8]>, +) -> Result { + let config = ClientConfig::builder().with_safe_defaults(); + + let config = if accept_invalid_certs { + config + .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier)) + .with_no_client_auth() + } else { + let mut cert_store = RootCertStore::empty(); + cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + if let Some(data) = root_cert { + let mut cursor = Cursor::new(data); + + for cert in rustls_pemfile::certs(&mut cursor)? { + cert_store.add(&rustls::Certificate(cert))?; + } + } + + if accept_invalid_hostnames { + let verifier = WebPkiVerifier::new(cert_store, None); + + config + .with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier })) + .with_no_client_auth() + } else { + config + .with_root_certificates(cert_store) + .with_no_client_auth() + } + }; + + Ok(config) +} + +struct DummyTlsVerifier; + +impl ServerCertVerifier for DummyTlsVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} + +pub struct NoHostnameTlsVerifier { + verifier: WebPkiVerifier, +} + +impl ServerCertVerifier for NoHostnameTlsVerifier { + fn verify_server_cert( + &self, + end_entity: &rustls::Certificate, + intermediates: &[rustls::Certificate], + server_name: &ServerName, + scts: &mut dyn Iterator, + ocsp_response: &[u8], + now: SystemTime, + ) -> Result { + match self.verifier.verify_server_cert( + end_entity, + intermediates, + server_name, + scts, + ocsp_response, + now, + ) { + Err(TlsError::InvalidCertificateData(reason)) + if reason.contains("CertNotValidForName") => + { + Ok(ServerCertVerified::assertion()) + } + res => res, + } + } +} diff --git a/warpgate-protocol-ssh/Cargo.toml b/warpgate-protocol-ssh/Cargo.toml index 0e06a14..1ca3e2b 100644 --- a/warpgate-protocol-ssh/Cargo.toml +++ b/warpgate-protocol-ssh/Cargo.toml @@ -17,8 +17,9 @@ russh-keys = {version = "0.22.0-beta.3", features = ["openssl"]} sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false} thiserror = "1.0" time = "0.3" -tokio = {version = "1.19", features = ["tracing", "signal"]} +tokio = {version = "1.20", features = ["tracing", "signal"]} tracing = "0.1" uuid = {version = "1.0", features = ["v4"]} warpgate-common = {version = "*", path = "../warpgate-common"} warpgate-db-entities = {version = "*", path = "../warpgate-db-entities"} +zeroize="^1.5" diff --git a/warpgate-protocol-ssh/src/server/session.rs b/warpgate-protocol-ssh/src/server/session.rs index 6aff545..3c71b1b 100644 --- a/warpgate-protocol-ssh/src/server/session.rs +++ b/warpgate-protocol-ssh/src/server/session.rs @@ -1003,7 +1003,7 @@ impl ServerSession { } async fn _auth_accept(&mut self, username: &str, target_name: &str) { - info!(username = username, "Authenticated"); + info!(%username, "Authenticated"); let _ = self .server_handle @@ -1031,7 +1031,7 @@ impl ServerSession { let Some((target, ssh_options)) = target else { self.target = TargetSelection::NotFound(target_name.to_string()); - info!("Selected target not found"); + warn!("Selected target not found"); return; }; diff --git a/warpgate-web/package.json b/warpgate-web/package.json index 0ba9724..0fd763c 100644 --- a/warpgate-web/package.json +++ b/warpgate-web/package.json @@ -13,8 +13,8 @@ "postinstall": "yarn run openapi:client:gateway && yarn run openapi:client:admin", "openapi:schema:gateway": "cargo run -p warpgate-protocol-http > src/gateway/lib/openapi-schema.json", "openapi:schema:admin": "cargo run -p warpgate-admin > src/admin/lib/openapi-schema.json", - "openapi:client:gateway": "openapi-generator-cli generate -g typescript-fetch -i src/gateway/lib/openapi-schema.json -o src/gateway/lib/api-client -p npmName=warpgate-gateway-api-client -p useSingleRequestParameter=true && cd src/gateway/lib/api-client && npm i && yarn tsc --target esnext --module esnext && rm -rf src", - "openapi:client:admin": "openapi-generator-cli generate -g typescript-fetch -i src/admin/lib/openapi-schema.json -o src/admin/lib/api-client -p npmName=warpgate-admin-api-client -p useSingleRequestParameter=true && cd src/admin/lib/api-client && npm i && yarn tsc --target esnext --module esnext && rm -rf src", + "openapi:client:gateway": "openapi-generator-cli generate -g typescript-fetch -i src/gateway/lib/openapi-schema.json -o src/gateway/lib/api-client -p npmName=warpgate-gateway-api-client -p useSingleRequestParameter=true && cd src/gateway/lib/api-client && npm i typescript@3.5 && npm i && yarn tsc --target esnext --module esnext && rm -rf src", + "openapi:client:admin": "openapi-generator-cli generate -g typescript-fetch -i src/admin/lib/openapi-schema.json -o src/admin/lib/api-client -p npmName=warpgate-admin-api-client -p useSingleRequestParameter=true && cd src/admin/lib/api-client && npm i typescript@3.5 && npm i && yarn tsc --target esnext --module esnext && rm -rf src", "openapi": "yarn run openapi:schema:admin && yarn run openapi:schema:gateway && yarn run openapi:client:admin && yarn run openapi:client:gateway" }, "devDependencies": { diff --git a/warpgate-web/src/admin/CreateTicket.svelte b/warpgate-web/src/admin/CreateTicket.svelte index 16c46a0..e4a5b37 100644 --- a/warpgate-web/src/admin/CreateTicket.svelte +++ b/warpgate-web/src/admin/CreateTicket.svelte @@ -1,6 +1,7 @@ {#if error} @@ -42,7 +34,7 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : '' {#each targets as target} { - if (target.options.kind !== 'TargetWebAdminOptions') { + if (target.options.kind !== 'WebAdmin') { selectedTarget = target } }}> @@ -50,10 +42,16 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : '' {target.name} - {#if target.options.kind === 'TargetSSHOptions'} + {#if target.options.kind === 'Http'} + HTTP + {/if} + {#if target.options.kind === 'MySql'} + MySQL + {/if} + {#if target.options.kind === 'Ssh'} SSH {/if} - {#if target.options.kind === 'TargetWebAdminOptions'} + {#if target.options.kind === 'WebAdmin'} This web admin interface {/if} @@ -67,17 +65,21 @@ $: targetURL = selectedTarget ? makeTargetURL(selectedTarget.name) : '' {selectedTarget?.name}
- {#if selectedTarget?.options.kind === 'TargetSSHOptions'} + {#if selectedTarget?.options.kind === 'MySql'} + MySQL target + {/if} + {#if selectedTarget?.options.kind === 'Ssh'} SSH target {/if} - {#if selectedTarget?.options.kind === 'TargetWebAdminOptions'} + {#if selectedTarget?.options.kind === 'WebAdmin'} This web admin interface {/if}

Access instructions

- {#if selectedTarget?.options.kind === 'TargetSSHOptions'} + + {#if selectedTarget?.options.kind === 'Ssh' || selectedTarget?.options.kind === 'MySql'} {#if users} {/if} - - - - - - - - - - {/if} - {#if selectedTarget?.options.kind === 'TargetHTTPOptions'} - - - - - {/if} +
{/if} diff --git a/warpgate-web/src/admin/lib/openapi-schema.json b/warpgate-web/src/admin/lib/openapi-schema.json index 49bb4c0..45ce395 100644 --- a/warpgate-web/src/admin/lib/openapi-schema.json +++ b/warpgate-web/src/admin/lib/openapi-schema.json @@ -2,7 +2,7 @@ "openapi": "3.0.0", "info": { "title": "Warpgate Web Admin", - "version": "0.2.5" + "version": "0.3.0" }, "servers": [ { @@ -59,28 +59,7 @@ "content": { "application/json": { "schema": { - "type": "object", - "required": [ - "items", - "offset", - "total" - ], - "properties": { - "items": { - "type": "array", - "items": { - "$ref": "#/components/schemas/SessionSnapshot" - } - }, - "offset": { - "type": "integer", - "format": "uint64" - }, - "total": { - "type": "integer", - "format": "uint64" - } - } + "$ref": "#/components/schemas/PaginatedSessionSnapshot" } } } @@ -509,6 +488,30 @@ } } }, + "PaginatedSessionSnapshot": { + "type": "object", + "required": [ + "items", + "offset", + "total" + ], + "properties": { + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/SessionSnapshot" + } + }, + "offset": { + "type": "integer", + "format": "uint64" + }, + "total": { + "type": "integer", + "format": "uint64" + } + } + }, "Recording": { "type": "object", "required": [ @@ -668,74 +671,139 @@ } } }, + "TargetMySqlOptions": { + "type": "object", + "required": [ + "host", + "port", + "username", + "tls", + "verify_tls" + ], + "properties": { + "host": { + "type": "string" + }, + "port": { + "type": "integer", + "format": "uint16" + }, + "username": { + "type": "string" + }, + "password": { + "type": "string" + }, + "tls": { + "$ref": "#/components/schemas/Tls" + }, + "verify_tls": { + "type": "boolean" + } + } + }, "TargetOptions": { "type": "object", - "anyOf": [ + "oneOf": [ { - "required": [ - "kind" - ], - "allOf": [ - { - "$ref": "#/components/schemas/TargetSSHOptions" - }, - { - "type": "object", - "title": "TargetSSHOptions", - "properties": { - "kind": { - "type": "string", - "example": "TargetSSHOptions" - } - } - } - ] + "$ref": "#/components/schemas/TargetOptionsTargetSSHOptions" }, { - "required": [ - "kind" - ], - "allOf": [ - { - "$ref": "#/components/schemas/TargetHTTPOptions" - }, - { - "type": "object", - "title": "TargetHTTPOptions", - "properties": { - "kind": { - "type": "string", - "example": "TargetHTTPOptions" - } - } - } - ] + "$ref": "#/components/schemas/TargetOptionsTargetHTTPOptions" }, { - "required": [ - "kind" - ], - "allOf": [ - { - "$ref": "#/components/schemas/TargetWebAdminOptions" - }, - { - "type": "object", - "title": "TargetWebAdminOptions", - "properties": { - "kind": { - "type": "string", - "example": "TargetWebAdminOptions" - } - } - } - ] + "$ref": "#/components/schemas/TargetOptionsTargetMySqlOptions" + }, + { + "$ref": "#/components/schemas/TargetOptionsTargetWebAdminOptions" } ], "discriminator": { - "propertyName": "kind" + "propertyName": "kind", + "mapping": { + "Ssh": "#/components/schemas/TargetOptionsTargetSSHOptions", + "Http": "#/components/schemas/TargetOptionsTargetHTTPOptions", + "MySql": "#/components/schemas/TargetOptionsTargetMySqlOptions", + "WebAdmin": "#/components/schemas/TargetOptionsTargetWebAdminOptions" + } } }, + "TargetOptionsTargetHTTPOptions": { + "allOf": [ + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "type": "string", + "example": "Http" + } + } + }, + { + "$ref": "#/components/schemas/TargetHTTPOptions" + } + ] + }, + "TargetOptionsTargetMySqlOptions": { + "allOf": [ + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "type": "string", + "example": "MySql" + } + } + }, + { + "$ref": "#/components/schemas/TargetMySqlOptions" + } + ] + }, + "TargetOptionsTargetSSHOptions": { + "allOf": [ + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "type": "string", + "example": "Ssh" + } + } + }, + { + "$ref": "#/components/schemas/TargetSSHOptions" + } + ] + }, + "TargetOptionsTargetWebAdminOptions": { + "allOf": [ + { + "type": "object", + "required": [ + "kind" + ], + "properties": { + "kind": { + "type": "string", + "example": "WebAdmin" + } + } + }, + { + "$ref": "#/components/schemas/TargetWebAdminOptions" + } + ] + }, "TargetSSHOptions": { "type": "object", "required": [ @@ -807,6 +875,29 @@ } } }, + "Tls": { + "type": "object", + "required": [ + "mode", + "verify" + ], + "properties": { + "mode": { + "$ref": "#/components/schemas/TlsMode" + }, + "verify": { + "type": "boolean" + } + } + }, + "TlsMode": { + "type": "string", + "enum": [ + "Disabled", + "Preferred", + "Required" + ] + }, "UserSnapshot": { "type": "object", "required": [ diff --git a/warpgate-web/src/common/ConnectionInstructions.svelte b/warpgate-web/src/common/ConnectionInstructions.svelte new file mode 100644 index 0000000..b33b353 --- /dev/null +++ b/warpgate-web/src/common/ConnectionInstructions.svelte @@ -0,0 +1,60 @@ + + +{#if targetKind === TargetKind.Ssh} + + + + + + + + + +{/if} + +{#if targetKind === TargetKind.Http} + + + + +{/if} + +{#if targetKind === TargetKind.MySql} + + + + + + + + + + + + + + + + + Make sure you've set your client to require TLS and allowed cleartext password authentication. + +{/if} diff --git a/warpgate-web/src/common/mysql.ts b/warpgate-web/src/common/mysql.ts new file mode 100644 index 0000000..f346c71 --- /dev/null +++ b/warpgate-web/src/common/mysql.ts @@ -0,0 +1,13 @@ +import type { Info } from 'gateway/lib/api' + +export function makeMySQLUsername (targetName?: string, username?: string): string { + return `${username ?? 'username'}#${targetName ?? 'target'}` +} + +export function makeExampleMySQLCommand (targetName?: string, username?: string, serverInfo?: Info): string { + return `mysql -u ${makeMySQLUsername(targetName, username)} --host ${serverInfo?.externalHost ?? 'warpgate-host'} --port ${serverInfo?.ports.mysql ?? 'warpgate-mysql-port'} -p --ssl` +} + +export function makeExampleMySQLURI (targetName?: string, username?: string, serverInfo?: Info): string { + return `mysql://${makeMySQLUsername(targetName, username)}:@${serverInfo?.externalHost ?? 'warpgate-host'}:${serverInfo?.ports.mysql ?? 'warpgate-mysql-port'}?sslMode=required` +} diff --git a/warpgate-web/src/common/ssh.ts b/warpgate-web/src/common/ssh.ts index 16547ac..d33f6b7 100644 --- a/warpgate-web/src/common/ssh.ts +++ b/warpgate-web/src/common/ssh.ts @@ -5,5 +5,5 @@ export function makeSSHUsername (targetName?: string, username?: string): string } export function makeExampleSSHCommand (targetName?: string, username?: string, serverInfo?: Info): string { - return `ssh ${makeSSHUsername(targetName, username)}@${serverInfo?.externalHost ?? 'warpgate-host'} -p ${serverInfo?.ports.ssh ?? 'warpgate-port'}` + return `ssh ${makeSSHUsername(targetName, username)}@${serverInfo?.externalHost ?? 'warpgate-host'} -p ${serverInfo?.ports.ssh ?? 'warpgate-ssh-port'}` } diff --git a/warpgate-web/src/gateway/TargetList.svelte b/warpgate-web/src/gateway/TargetList.svelte index 540afe9..b830729 100644 --- a/warpgate-web/src/gateway/TargetList.svelte +++ b/warpgate-web/src/gateway/TargetList.svelte @@ -1,21 +1,16 @@