From cd9e83fac0c37d2e95189ca4f3529c3680641694 Mon Sep 17 00:00:00 2001 From: lhear Date: Tue, 28 Apr 2026 23:10:14 +0800 Subject: [PATCH] feat: implement e2e encryption and split-stream transport protocol - Implement X25519 + ML-KEM-768 hybrid KEM with HKDF for AES-256-GCM authenticated encryption of cookies and data frames. - Refactor transport to a dual-stream architecture, utilizing separate POST handles for downlink and batch-based uplink. - Enable encrypted stream cookies to securely transmit target routing and ephemeral key state. - Integrate a ManualResolver in the client to allow binding to specific remote addresses. - Add a new CLI command to generate cryptographic X25519 keypairs. --- .github/workflows/build.yml | 29 +- CONFIGURATION.md | 15 +- Cargo.lock | 387 ++++++++++++++++++- Cargo.toml | 17 +- src/bin/client.rs | 89 +++++ src/bin/server.rs | 81 ++++ src/{bypass.rs => bypass/mod.rs} | 0 src/client.rs | 300 --------------- src/client/connection.rs | 311 ++++++++++++++++ src/client/constants.rs | 16 + src/client/handshake.rs | 372 +++++++++++++++++++ src/client/mod.rs | 48 +++ src/client/proxy.rs | 76 ++++ src/client/state.rs | 42 +++ src/client/tunnel.rs | 234 ++++++++++++ src/client/utils.rs | 108 ++++++ src/config/mod.rs | 108 ++++++ src/crypto/cipher.rs | 202 ++++++++++ src/crypto/handshake.rs | 90 +++++ src/crypto/keys.rs | 135 +++++++ src/crypto/mod.rs | 15 + src/{dns.rs => dns/mod.rs} | 0 src/error/mod.rs | 166 +++++++++ src/lib.rs | 9 + src/{log.rs => log/mod.rs} | 0 src/server.rs | 369 ------------------ src/server/connection.rs | 157 ++++++++ src/server/constants.rs | 29 ++ src/server/handlers.rs | 616 ++++++++++++++++++++++++++++++ src/server/janitor.rs | 45 +++ src/server/mod.rs | 179 +++++++++ src/server/state.rs | 129 +++++++ src/server/utils.rs | 78 ++++ src/shaper.rs | 246 ------------ src/shaper/mod.rs | 618 +++++++++++++++++++++++++++++++ 35 files changed, 4389 insertions(+), 927 deletions(-) create mode 100644 src/bin/client.rs create mode 100644 src/bin/server.rs rename src/{bypass.rs => bypass/mod.rs} (100%) delete mode 100644 src/client.rs create mode 100644 src/client/connection.rs create mode 100644 src/client/constants.rs create mode 100644 src/client/handshake.rs create mode 100644 src/client/mod.rs create mode 100644 src/client/proxy.rs create mode 100644 src/client/state.rs create mode 100644 src/client/tunnel.rs create mode 100644 src/client/utils.rs create mode 100644 src/config/mod.rs create mode 100644 src/crypto/cipher.rs create mode 100644 src/crypto/handshake.rs create mode 100644 src/crypto/keys.rs create mode 100644 src/crypto/mod.rs rename src/{dns.rs => dns/mod.rs} (100%) create mode 100644 src/error/mod.rs create mode 100644 src/lib.rs rename src/{log.rs => log/mod.rs} (100%) delete mode 100644 src/server.rs create mode 100644 src/server/connection.rs create mode 100644 src/server/constants.rs create mode 100644 src/server/handlers.rs create mode 100644 src/server/janitor.rs create mode 100644 src/server/mod.rs create mode 100644 src/server/state.rs create mode 100644 src/server/utils.rs delete mode 100644 src/shaper.rs create mode 100644 src/shaper/mod.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3a7c170..d7614ba 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,8 +19,21 @@ jobs: with: command: check licenses - linux-gnu: + test: needs: lint-audit + name: cargo test --lib + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Rust Cache + uses: swatinem/rust-cache@v2 + with: + key: test-ubuntu + - name: Run unit tests + run: cargo test --lib + + linux-gnu: + needs: [lint-audit, test] name: Build Linux (GNU) runs-on: ubuntu-latest strategy: @@ -29,7 +42,8 @@ jobs: - target: x86_64 target_triple: x86_64-unknown-linux-gnu apt_packages: "" - custom_env: {} + custom_env: + RUSTFLAGS: -C target-cpu=x86-64-v3 - target: arm64 target_triple: aarch64-unknown-linux-gnu apt_packages: crossbuild-essential-arm64 @@ -69,7 +83,7 @@ jobs: path: dist/ linux-musl: - needs: lint-audit + needs: [lint-audit, test] name: Build Linux (musl) runs-on: ubuntu-latest strategy: @@ -83,6 +97,7 @@ jobs: CXX: x86_64-linux-musl-g++ CARGO_TARGET_X86_64_UNKNOWN_LINUX_MUSL_LINKER: x86_64-linux-musl-g++ RUSTC_LINKER: x86_64-linux-musl-g++ + RUSTFLAGS: -C target-cpu=x86-64-v3 - target: arm64 target_triple: aarch64-unknown-linux-musl package: aarch64-linux-musl-cross @@ -124,7 +139,7 @@ jobs: path: dist/ windows: - needs: lint-audit + needs: [lint-audit, test] name: Build Windows runs-on: windows-latest strategy: @@ -132,6 +147,8 @@ jobs: include: - target: x86_64 target_triple: x86_64-pc-windows-msvc + custom_env: + RUSTFLAGS: -C target-cpu=x86-64-v3 - target: arm64 target_triple: aarch64-pc-windows-msvc steps: @@ -162,7 +179,7 @@ jobs: path: dist/ darwin: - needs: lint-audit + needs: [lint-audit, test] name: Build Darwin runs-on: macos-latest strategy: @@ -196,7 +213,7 @@ jobs: path: dist/ android: - needs: lint-audit + needs: [lint-audit, test] name: Build Android runs-on: ubuntu-latest strategy: diff --git a/CONFIGURATION.md b/CONFIGURATION.md index ac8cf03..b1ca09c 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -22,6 +22,16 @@ Generate a secure bearer token. **The secret used here must match the `secret` i ./server gen-token --secret "my_secret_key" --user "admin" --exp 1768281600 ``` +## Keypair Generation + +Generate an X25519 keypair for end-to-end encryption. **The public key will be used in the client configuration, while the private key must be kept secure on the server.** + +> **Note**: End-to-end encryption is optional. If the `public_key` is not configured in the client, encryption will be disabled. + +```bash +./server gen-key +``` + ## Client Configuration `config.toml`: @@ -30,6 +40,8 @@ Generate a secure bearer token. **The secret used here must match the `secret` i [client] listen = "127.0.0.1:8080" remote = "https://your-server-domain/YOUR_SECRET_PATH" +# address = "your-server-ip" +# public_key = "your-public-key" [auth] token = "your-token" @@ -85,6 +97,7 @@ padding_threshold = 2000 [server] listen = "/dev/shm/httproxy.sock" path = "/YOUR_SECRET_PATH" +# private_key = "your-private-key" [auth] secret = "my_secret_key" @@ -191,7 +204,7 @@ server { proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_request_buffering off; proxy_http_version 1.1; - client_max_body_size 0; + client_max_body_size 1m; proxy_buffering off; proxy_buffer_size 16k; proxy_buffers 2 16k; diff --git a/Cargo.lock b/Cargo.lock index a0ccd7d..4e84ad1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,41 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common 0.1.7", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures 0.2.17", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.12" @@ -225,6 +260,12 @@ dependencies = [ "syn", ] +[[package]] +name = "base122-fast" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2361be42fbd11eefcef8fec542469d77cdea96a05eb92c3d7e61c20bffa89f3d" + [[package]] name = "base64" version = "0.22.1" @@ -255,6 +296,15 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +[[package]] +name = "block-buffer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" +dependencies = [ + "hybrid-array", +] + [[package]] name = "boring-sys2" version = "5.0.0-alpha.13" @@ -347,7 +397,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.3.0", "rand_core 0.10.1", ] @@ -362,6 +412,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common 0.1.7", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -422,6 +482,12 @@ dependencies = [ "cc", ] +[[package]] +name = "cmov" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" + [[package]] name = "colorchoice" version = "1.0.5" @@ -437,12 +503,27 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-oid" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" + [[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "cpufeatures" version = "0.3.0" @@ -485,6 +566,85 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "typenum", +] + +[[package]] +name = "crypto-common" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77727bb15fa921304124b128af125e7e3b968275d1b108b379190264f4423710" +dependencies = [ + "hybrid-array", + "rand_core 0.10.1", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "ctutils" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", +] + +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deranged" version = "0.5.8" @@ -494,6 +654,18 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +dependencies = [ + "block-buffer", + "const-oid", + "crypto-common 0.2.1", + "ctutils", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -570,6 +742,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -744,6 +922,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -781,6 +969,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "glob" version = "0.3.3" @@ -842,6 +1040,24 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hkdf" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4aaa26c720c68b866f2c96ef5c1264b3e6f473fe5d4ce61cd44bbe913e553018" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.4.0" @@ -911,21 +1127,28 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" name = "httproxy" version = "0.1.16" dependencies = [ + "aes-gcm", "anyhow", "axum", + "base122-fast", + "base64", "bytes", "clap", + "crypto-common 0.2.1", + "dashmap", "domain", "fst", "futures", "h2", + "hkdf", "http", - "http-body", "http-body-util", "httparse", "ip_network", "ip_network_table", "jsonwebtoken", + "memchr", + "ml-kem", "moka", "pin-project-lite", "rand 0.10.1", @@ -933,6 +1156,7 @@ dependencies = [ "rustls", "serde", "serde_json", + "sha2", "singleflight-async", "tokio", "tokio-rustls", @@ -945,9 +1169,23 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "url", + "uuid", "webpki-roots", "wreq", "wreq-util", + "x25519-dalek", + "zeroize", +] + +[[package]] +name = "hybrid-array" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" +dependencies = [ + "ctutils", + "typenum", + "zeroize", ] [[package]] @@ -1131,6 +1369,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "ip_network" version = "0.4.1" @@ -1217,6 +1464,26 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "keccak" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e24a010dd405bd7ed803e5253182815b41bf2e6a80cc3bfc066658e03a198aa" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", +] + +[[package]] +name = "kem" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01737161ba802849cfd486b5bd209d38ba4943494c249a8126005170c7621edd" +dependencies = [ + "crypto-common 0.2.1", + "rand_core 0.10.1", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1326,6 +1593,32 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "ml-kem" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68c77d5ff6d755d09a0ef4d4d28c2b7e83658fe83e8c736d55e93d43e380d1cd" +dependencies = [ + "hybrid-array", + "kem", + "module-lattice", + "rand_core 0.10.1", + "sha3", + "zeroize", +] + +[[package]] +name = "module-lattice" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc7c90d33a0dac244570c26461d761ffaeadb3bfc2b17cc625ae2185cafdffae" +dependencies = [ + "ctutils", + "hybrid-array", + "num-traits", + "zeroize", +] + [[package]] name = "moka" version = "0.12.15" @@ -1421,6 +1714,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl-macros" version = "0.1.1" @@ -1489,6 +1788,18 @@ version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -1674,6 +1985,15 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustls" version = "0.23.38" @@ -1820,6 +2140,27 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest", +] + +[[package]] +name = "sha3" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be176f1a57ce4e3d31c1a166222d9768de5954f811601fb7ca06fc8203905ce1" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2323,6 +2664,12 @@ dependencies = [ "syn", ] +[[package]] +name = "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + [[package]] name = "unicode-ident" version = "1.0.24" @@ -2335,6 +2682,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common 0.1.7", + "subtle", +] + [[package]] name = "untrusted" version = "0.7.1" @@ -2837,6 +3194,18 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "yoke" version = "0.8.2" @@ -2906,6 +3275,20 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zerotrie" diff --git a/Cargo.toml b/Cargo.toml index bd7f226..2e60d84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,11 +6,11 @@ version = "0.1.16" [[bin]] name = "server" -path = "src/server.rs" +path = "src/bin/server.rs" [[bin]] name = "client" -path = "src/client.rs" +path = "src/bin/client.rs" [profile.release] codegen-units = 1 @@ -19,21 +19,28 @@ panic = "abort" strip = true [dependencies] +aes-gcm = "0.10.3" anyhow = "1.0" axum = {version = "0.8.9", features = ["http2", "macros"]} +base122-fast = "0.1.3" +base64 = "0.22.1" bytes = "1.11" clap = {version = "4.6.1", features = ["derive"]} +crypto-common = "0.2.1" +dashmap = "6.1.0" domain = "0.11" fst = "0.4.7" futures = "0.3" h2 = "0.4" +hkdf = "0.13.0" http = "1.4" -http-body = "1.0" http-body-util = "0.1" httparse = "1.10" ip_network = "0.4.1" ip_network_table = "0.2.0" jsonwebtoken = {version = "10.2", features = ["aws_lc_rs"]} +memchr = "2.8.0" +ml-kem = {version = "0.3", features = ["zeroize"]} moka = {version = "0.12", features = ["future"]} pin-project-lite = "0.2" rand = "0.10" @@ -41,6 +48,7 @@ rand_distr = "0.6" rustls = "0.23.38" serde = {version = "1.0", features = ["derive"]} serde_json = "1.0.149" +sha2 = "0.11.0" singleflight-async = "0.2" tokio = {version = "1.52.1", features = ["rt-multi-thread"]} tokio-rustls = "0.26" @@ -57,6 +65,9 @@ tracing-subscriber = {version = "0.3", features = [ "json", ]} url = "2.5" +uuid = "1.23.1" webpki-roots = "1.0.7" wreq = "6.0.0-rc.28" wreq-util = "3.0.0-rc.10" +x25519-dalek = {version = "2.0", features = ["static_secrets", "getrandom"]} +zeroize = "1.8.2" diff --git a/src/bin/client.rs b/src/bin/client.rs new file mode 100644 index 0000000..7d86248 --- /dev/null +++ b/src/bin/client.rs @@ -0,0 +1,89 @@ +use anyhow::Context; +use clap::Parser; +use std::fs; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; +use tokio::net::TcpListener; +use tracing::{Instrument, error_span, info, warn}; + +static NEXT_SPAN_ID: AtomicU64 = AtomicU64::new(1); + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Cli { + #[arg(short = 'c', long, default_value = "config.toml")] + config: String, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + let config_str = fs::read_to_string(&cli.config)?; + let config: httproxy::config::ClientTopConfig = toml::from_str(&config_str)?; + + let _guard = httproxy::log::init_tracing(&config.log.clone().unwrap_or_default()); + + let state = httproxy::client::build_state(&config)?; + + let addr: std::net::SocketAddr = config.client.listen.parse()?; + let listener = TcpListener::bind(addr).await?; + + let remote_url: url::Url = state.remote_str.parse()?; + let domain = remote_url.host_str().context("No domain in remote URL")?; + let final_addr = config + .client + .address + .clone() + .unwrap_or_else(|| domain.to_string()); + + let http_client = Arc::new( + wreq::Client::builder() + .tcp_nodelay(true) + .tcp_keepalive(Duration::from_secs(45)) + .tcp_keepalive_interval(Duration::from_secs(45)) + .pool_idle_timeout(Duration::from_secs(300)) + .pool_max_idle_per_host(6) + .emulation(wreq_util::Emulation::Chrome143) + .no_proxy() + .dns_resolver(Arc::new(httproxy::client::state::ManualResolver { + target_addr: final_addr, + })) + .build()?, + ); + + info!(listen = %addr, "proxy listening"); + + loop { + let (socket, peer) = match listener.accept().await { + Ok(conn) => conn, + Err(e) => { + warn!(reason = %e, "accept failed"); + continue; + } + }; + + let http_client = Arc::clone(&http_client); + let state = Arc::clone(&state); + + let span_id = NEXT_SPAN_ID.fetch_add(1, Ordering::Relaxed); + + tokio::spawn( + async move { + if let Err(e) = + httproxy::client::connection::handle_connection(socket, http_client, state) + .await + && !httproxy::client::utils::is_silent_error(e.root_cause()) + { + warn!(reason = %e, "connection aborted"); + } + } + .instrument(error_span!( + "session", + id = span_id, + client = %peer, + target = tracing::field::Empty, + )), + ); + } +} diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 0000000..d9309e7 --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,81 @@ +use clap::Parser; +use serde::{Deserialize, Serialize}; +use std::fs; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Cli { + #[arg(short = 'c', long, default_value = "config.toml")] + config: String, + #[command(subcommand)] + command: Option, +} + +#[derive(clap::Subcommand, Debug)] +enum Commands { + #[command(about = "Generate a JWT token")] + GenToken { + #[arg(short, long, help = "Secret key for signing")] + secret: String, + #[arg(short, long, help = "Username or Subject")] + user: String, + #[arg(short, long, help = "Expiration timestamp (Unix)")] + exp: u64, + }, + #[command(about = "Generate an x25519 keypair")] + GenKey, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + sub: String, + exp: u64, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + + if let Some(cmd) = cli.command { + match cmd { + Commands::GenToken { secret, user, exp } => { + let token = jsonwebtoken::encode( + &jsonwebtoken::Header::default(), + &Claims { sub: user, exp }, + &jsonwebtoken::EncodingKey::from_secret(secret.as_bytes()), + )?; + println!("{token}"); + return Ok(()); + } + Commands::GenKey => { + let (sk, pk) = httproxy::crypto::generate_keypair(); + println!( + "private_key = \"{}\"", + httproxy::crypto::private_key_to_b64(&sk) + ); + println!( + "public_key = \"{}\"", + httproxy::crypto::public_key_to_b64(&pk) + ); + return Ok(()); + } + } + } + + let config_str = fs::read_to_string(&cli.config)?; + let mut config: httproxy::config::ServerTopConfig = toml::from_str(&config_str)?; + + let _guard = httproxy::log::init_tracing(&config.log.as_ref().cloned().unwrap_or_default()); + + let state = httproxy::server::build_state(&mut config).await?; + + let (_stream_jh, _master_jh) = httproxy::server::spawn_janitors(&state); + + let router = httproxy::server::build_router(state, &config.server.path); + httproxy::server::run_server(router, &config.server.listen).await?; + + drop(_stream_jh); + drop(_master_jh); + + Ok(()) +} diff --git a/src/bypass.rs b/src/bypass/mod.rs similarity index 100% rename from src/bypass.rs rename to src/bypass/mod.rs diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index e07ab04..0000000 --- a/src/client.rs +++ /dev/null @@ -1,300 +0,0 @@ -mod bypass; -mod log; -mod shaper; - -use anyhow::{Context, Result, anyhow}; -use bypass::{BypassConfig, BypassRules}; -use bytes::{Buf, BytesMut}; -use clap::Parser; -use futures::StreamExt; -use http::uri::Authority; -use http_body::Frame; -use http_body_util::{BodyExt, StreamBody}; -use rand::RngExt; -use serde::Deserialize; -use std::{ - fs, - net::SocketAddr, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::Duration, -}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, TcpStream}, -}; -use tracing::{Instrument, error_span, info, warn}; -use url::Url; -use wreq::{Body, Client}; -use wreq_util::Emulation; - -static NEXT_STREAM_ID: AtomicU64 = AtomicU64::new(1); - -const CONNECT_RESPONSE: &[u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n"; -const MAX_HEADER_LEN: usize = 16 * 1024; -const INITIAL_BUF_CAP: usize = 16 * 1024; -const PADDING: [u8; 32] = [b'X'; 32]; -const MIN_PADDING: usize = 16; - -#[derive(Parser, Debug)] -#[command(version, about, long_about = None)] -struct Cli { - #[arg(short = 'c', long, default_value = "config.toml")] - config: String, -} - -#[derive(Deserialize, Debug)] -struct Config { - client: ClientConfig, - auth: AuthConfig, - log: Option, - traffic_shaping: shaper::TrafficConfig, - #[serde(default)] - bypass: BypassConfig, -} - -#[derive(Deserialize, Debug)] -struct ClientConfig { - listen: String, - remote: String, -} - -#[derive(Deserialize, Debug)] -struct AuthConfig { - token: String, -} - -struct SharedState { - remote: Url, - auth_header: String, - traffic_config: shaper::TrafficConfig, - bypass: Option>, -} - -#[tokio::main] -async fn main() -> Result<()> { - let cli = Cli::parse(); - let config: Config = toml::from_str(&fs::read_to_string(&cli.config)?)?; - let _guard = log::init_tracing(&config.log.clone().unwrap_or_default()); - - run_server(&config.client.listen, Arc::new(build_state(&config)?)).await -} - -fn build_state(cfg: &Config) -> Result { - let bypass = if cfg.bypass.bypass_files.is_empty() { - None - } else { - let rules = BypassRules::load(&cfg.bypass).context("failed to load bypass rules")?; - if rules.is_empty() { - None - } else { - Some(Arc::new(rules)) - } - }; - - Ok(SharedState { - remote: cfg.client.remote.parse().context("invalid server URL")?, - auth_header: format!("Bearer {}", cfg.auth.token), - traffic_config: cfg.traffic_shaping.clone(), - bypass, - }) -} - -async fn run_server(listen: &str, state: Arc) -> Result<()> { - let addr: SocketAddr = listen.parse().context("invalid bind address")?; - let listener = TcpListener::bind(addr).await?; - - let http_client = Arc::new( - Client::builder() - .tcp_nodelay(true) - .tcp_keepalive(Duration::from_secs(45)) - .tcp_keepalive_interval(Duration::from_secs(45)) - .pool_idle_timeout(Duration::from_secs(300)) - .pool_max_idle_per_host(6) - .emulation(Emulation::Chrome143) - .no_proxy() - .build()?, - ); - info!(listen = %addr, "server started"); - - loop { - let (socket, peer) = match listener.accept().await { - Ok(conn) => conn, - Err(e) => { - warn!(reason = %e, "accept failed"); - continue; - } - }; - - let http_client = Arc::clone(&http_client); - let state = Arc::clone(&state); - let id = NEXT_STREAM_ID.fetch_add(1, Ordering::Relaxed); - - tokio::spawn( - async move { - if let Err(e) = handle_connection(socket, http_client, state).await { - if !is_silent_error(e.root_cause()) { - warn!(reason = %e, "connection aborted"); - } - } - } - .instrument(error_span!( - "session", - id, - client = %peer, - target = tracing::field::Empty, - )), - ); - } -} - -fn is_silent_error(root: &(dyn std::error::Error + 'static)) -> bool { - use std::io::ErrorKind::*; - if let Some(e) = root.downcast_ref::() { - return e.is_reset() || e.is_library(); - } - if let Some(e) = root.downcast_ref::() { - return matches!(e.kind(), ConnectionReset | UnexpectedEof | NotConnected); - } - root.to_string() - .contains("connection closed during header parsing") -} - -async fn parse_proxy_request( - reader: &mut (impl AsyncReadExt + Unpin), - buffer: &mut BytesMut, -) -> Result<(String, usize, String)> { - loop { - if reader.read_buf(buffer).await? == 0 { - return Err(anyhow!("connection closed during header parsing")); - } - let mut headers = [httparse::EMPTY_HEADER; 64]; - let mut req = httparse::Request::new(&mut headers); - if let httparse::Status::Complete(amt) = req.parse(buffer)? { - return Ok(( - req.method.context("no method")?.to_owned(), - amt, - req.path.context("no path")?.to_owned(), - )); - } - if buffer.len() > MAX_HEADER_LEN { - return Err(anyhow!("header too large")); - } - } -} - -fn resolve_target_host(method: &str, url_str: &str) -> Result { - if method == "CONNECT" { - let auth: Authority = url_str - .parse() - .map_err(|_| anyhow!("invalid target: {url_str}"))?; - let port = auth - .port_u16() - .ok_or_else(|| anyhow!("port required: {url_str}"))?; - return Ok(format!("{}:{port}", auth.host())); - } - - let url = Url::parse(url_str).context("invalid proxy URL")?; - let host = url.host_str().context("URL has no host")?; - let port = url.port_or_known_default().context("port required")?; - Ok(format!("{host}:{port}")) -} - -async fn handle_connection( - socket: TcpStream, - http_client: Arc, - state: Arc, -) -> Result<()> { - socket.set_nodelay(true)?; - let (mut read_half, mut write_half) = socket.into_split(); - - let mut buffer = BytesMut::with_capacity(INITIAL_BUF_CAP); - let (method, header_len, url) = parse_proxy_request(&mut read_half, &mut buffer).await?; - - if method == "CONNECT" { - buffer.advance(header_len); - write_half.write_all(CONNECT_RESPONSE).await?; - } - - let target_host = resolve_target_host(&method, &url)?; - tracing::Span::current().record("target", target_host.as_str()); - - if let Some(bypass) = &state.bypass { - if bypass.should_bypass(&target_host) { - info!(mode = "bypass", "direct connect"); - let payload = buffer.split().freeze(); - return handle_bypass(read_half, write_half, &target_host, payload).await; - } - } - - let payload = buffer.split().freeze(); - - info!(mode = "proxy", "connecting"); - - let mut remote_url = state.remote.clone(); - remote_url - .query_pairs_mut() - .append_pair("target", &target_host); - - let reader = AsyncReadExt::chain(std::io::Cursor::new(payload), read_half); - let body_stream = shaper::TrafficShaper::new(reader, state.traffic_config.clone()) - .map(|item| item.map(Frame::data)); - - let padding_len = rand::rng().random_range(MIN_PADDING..PADDING.len()); - - let response = http_client - .post(remote_url.as_str()) - .header("Authorization", state.auth_header.as_str()) - .header("X-Padding", &PADDING[..padding_len]) - .body(Body::wrap(StreamBody::new(body_stream))) - .send() - .await - .context("http post failed")?; - - if !response.status().is_success() { - return Err(anyhow!("upstream rejected: {}", response.status())); - } - - let mut data_stream = response.into_data_stream(); - - while let Some(chunk) = data_stream.next().await { - buffer.extend_from_slice(&chunk.context("stream read error")?); - while let Some(frame) = shaper::TrafficShaper::decode_from_buffer(&mut buffer)? { - write_half.write_all(&frame).await?; - } - } - - write_half.shutdown().await?; - Ok(()) -} - -async fn handle_bypass( - mut read_half: tokio::net::tcp::OwnedReadHalf, - mut write_half: tokio::net::tcp::OwnedWriteHalf, - target: &str, - initial_payload: bytes::Bytes, -) -> Result<()> { - let mut remote = TcpStream::connect(target) - .await - .with_context(|| format!("bypass connect to {target} failed"))?; - remote.set_nodelay(true)?; - - if !initial_payload.is_empty() { - remote.write_all(&initial_payload).await?; - } - - let (mut remote_read, mut remote_write) = remote.into_split(); - - let client_to_remote = tokio::io::copy(&mut read_half, &mut remote_write); - let remote_to_client = tokio::io::copy(&mut remote_read, &mut write_half); - - tokio::select! { - result = client_to_remote => { result.context("bypass client→remote")?; } - result = remote_to_client => { result.context("bypass remote→client")?; } - } - - let _ = write_half.shutdown().await; - Ok(()) -} diff --git a/src/client/connection.rs b/src/client/connection.rs new file mode 100644 index 0000000..6a0610a --- /dev/null +++ b/src/client/connection.rs @@ -0,0 +1,311 @@ +use anyhow::{Context, Result, anyhow}; +use bytes::{Buf, Bytes, BytesMut}; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tracing::{Instrument, info, warn}; + +use crate::client::{ + constants::{CONNECT_RESPONSE, DOWNLOAD_CONNECT_TIMEOUT, EARLY_READ_WINDOW}, + handshake::{self, try_pq_connect}, + proxy, + state::SharedState, + tunnel, utils, +}; + +pub async fn handle_connection( + socket: TcpStream, + http_client: Arc, + state: Arc, +) -> Result<()> { + socket.set_nodelay(true)?; + let (mut read_half, mut write_half) = socket.into_split(); + + let mut buffer = BytesMut::with_capacity(16 * 1024); + let (method, header_len, url) = proxy::parse_proxy_request(&mut read_half, &mut buffer).await?; + + if method == "CONNECT" { + buffer.advance(header_len); + write_half.write_all(CONNECT_RESPONSE).await?; + let deadline = tokio::time::Instant::now() + EARLY_READ_WINDOW; + loop { + let remaining = crate::shaper::MAX_RAW_PAYLOAD.saturating_sub(buffer.len()); + if remaining == 0 { + break; + } + match tokio::time::timeout_at(deadline, read_half.read_buf(&mut buffer)).await { + Ok(Ok(0)) => break, + Ok(Ok(_)) => {} + _ => break, + } + } + } + + let target_host = proxy::resolve_target_host(&method, &url)?; + tracing::Span::current().record("target", target_host.as_str()); + + if let Some(ref bypass) = state.bypass + && bypass.should_bypass(&target_host) + { + info!(mode = "bypass", "direct connect"); + let payload = buffer.split().freeze(); + return handle_bypass(read_half, write_half, &target_host, payload).await; + } + + let payload = buffer.split().freeze(); + info!(mode = "proxy", "connecting"); + + let server_pk_opt = state.server_public_key; + if let Some(ref server_pk) = server_pk_opt { + handle_pq_proxy( + read_half, + write_half, + http_client, + state, + payload, + &target_host, + server_pk, + ) + .await + } else { + handle_plain_proxy( + read_half, + write_half, + http_client, + state, + payload, + &target_host, + ) + .await + } +} + +async fn handle_pq_proxy( + read_half: tokio::net::tcp::OwnedReadHalf, + write_half: tokio::net::tcp::OwnedWriteHalf, + http_client: Arc, + state: Arc, + initial_payload: Bytes, + target_host: &str, + server_pk: &x25519_dalek::PublicKey, +) -> Result<()> { + let mut read_half = Some(read_half); + let mut write_half = Some(write_half); + + { + let mut master_guard = state.initial_master.lock().await; + if let Some((session_id, master, created)) = master_guard.as_ref() { + if created.elapsed() < Duration::from_secs(1200 - 30) { + let (session_id, master) = (session_id.clone(), **master); + drop(master_guard); + match try_pq_connect( + &http_client, + &state, + &master, + &session_id, + target_host, + initial_payload.clone(), + &mut read_half, + &mut write_half, + ) + .await + { + Ok(()) => return Ok(()), + Err(e) => { + warn!("session resumption failed, falling back to full handshake: {e}"); + if read_half.is_none() { + return Err(e); + } + } + } + } else { + *master_guard = None; + } + } + } + + { + let handshake_mutex = state + .handshake_lock + .get_or_init(|| async { tokio::sync::Mutex::new(()) }) + .await; + let _guard = handshake_mutex.lock().await; + + { + let master_guard = state.initial_master.lock().await; + if let Some((session_id, master, created)) = master_guard.as_ref() + && created.elapsed() < Duration::from_secs(1200 - 30) + { + let (session_id, master) = (session_id.clone(), **master); + drop(master_guard); + match try_pq_connect( + &http_client, + &state, + &master, + &session_id, + target_host, + initial_payload.clone(), + &mut read_half, + &mut write_half, + ) + .await + { + Ok(()) => return Ok(()), + Err(e) => { + warn!( + "session resumption (post-lock) failed, falling back to full handshake: {e}" + ); + if read_half.is_none() { + return Err(e); + } + let mut mg = state.initial_master.lock().await; + if let Some((ref cur_sid, _, _)) = *mg + && cur_sid == &session_id + { + *mg = None; + } + } + } + } + } + + let rh = read_half.take().context("read half already consumed")?; + let wh = write_half.take().context("write half already consumed")?; + handshake::full_handshake( + &http_client, + &state, + server_pk, + target_host, + initial_payload, + rh, + wh, + ) + .await + } +} + +async fn handle_plain_proxy( + read_half: tokio::net::tcp::OwnedReadHalf, + write_half: tokio::net::tcp::OwnedWriteHalf, + http_client: Arc, + state: Arc, + payload: Bytes, + target_host: &str, +) -> Result<()> { + let stream_id = uuid::Uuid::new_v4().to_string(); + let mut cookie = String::new(); + utils::build_tunnel_cookie(&mut cookie, &stream_id); + + let (early_data, remaining_payload, frames_sent) = utils::encode_initial_payload( + &payload, + crate::shaper::MAX_RAW_PAYLOAD, + None, + &state.traffic_config, + )?; + + info!(target = %target_host, initial_bytes = %payload.len(), + body_len = %early_data.len(), frames_sent = %frames_sent, "connection initiated"); + + let response = tokio::time::timeout( + DOWNLOAD_CONNECT_TIMEOUT, + http_client + .post(state.remote_str.as_str()) + .header("Authorization", state.auth_header.as_str()) + .header("X-Target", target_host) + .header("Cookie", cookie) + .body(wreq::Body::from(early_data)) + .send(), + ) + .await + .context("download connect timed out")? + .context("download request failed")?; + + if !response.status().is_success() { + let status = response.status(); + let _ = response.bytes().await; + return Err(anyhow!("upstream rejected download: {status}")); + } + + let encoding = state.traffic_config.encoding_type; + let upload_client = Arc::clone(&http_client); + let upload_state = Arc::clone(&state); + let stream_id_clone = stream_id.clone(); + + let mut upload_task = tokio::spawn( + async move { + tunnel::upload_loop( + upload_client, + upload_state, + remaining_payload, + read_half, + None, + stream_id_clone, + frames_sent, + ) + .await + } + .instrument(tracing::Span::current()), + ); + + let download_fut = tunnel::download_loop(response, write_half, None, encoding); + tokio::pin!(download_fut); + + let result: Result<()> = tokio::select! { + biased; + upload_res = &mut upload_task => { + let upload_outcome: Result<()> = match upload_res { + Ok(r) => r, + Err(e) if e.is_cancelled() => Ok(()), + Err(e) => Err(anyhow!("upload task panicked: {e}")), + }; + if let Err(ref e) = upload_outcome { + warn!(reason = %e, "upload failed; aborting download"); + return upload_outcome; + } + download_fut.await + } + dl_res = &mut download_fut => { + upload_task.abort(); + let _ = upload_task.await; + if let Err(ref e) = dl_res { + warn!(reason = %e, "download failed; upload task aborted"); + } + dl_res + } + }; + result +} + +async fn handle_bypass( + mut read_half: tokio::net::tcp::OwnedReadHalf, + mut write_half: tokio::net::tcp::OwnedWriteHalf, + target: &str, + initial_payload: Bytes, +) -> Result<()> { + let mut remote = TcpStream::connect(target) + .await + .with_context(|| format!("bypass connect to {target} failed"))?; + remote.set_nodelay(true)?; + + if !initial_payload.is_empty() { + remote.write_all(&initial_payload).await?; + } + + let (mut remote_read, mut remote_write) = remote.into_split(); + + let up = async { + tokio::io::copy(&mut read_half, &mut remote_write).await?; + remote_write.shutdown().await + }; + + let down = async { + tokio::io::copy(&mut remote_read, &mut write_half).await?; + write_half.shutdown().await + }; + + let (up_res, down_res) = tokio::join!(up, down); + up_res.context("bypass client->remote")?; + down_res.context("bypass remote->client")?; + Ok(()) +} diff --git a/src/client/constants.rs b/src/client/constants.rs new file mode 100644 index 0000000..7c66415 --- /dev/null +++ b/src/client/constants.rs @@ -0,0 +1,16 @@ +use std::time::Duration; + +pub const CONNECT_RESPONSE: &[u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n"; +pub const EARLY_READ_WINDOW: Duration = Duration::from_millis(2); + +pub const DOWNLOAD_CONNECT_TIMEOUT: Duration = Duration::from_secs(15); +pub const UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); + +pub const MAX_BATCH_BYTES: usize = 1024 * 1024; +pub const MAX_IN_FLIGHT_BYTES: usize = 2 * 1024 * 1024; +pub const UPLOAD_CONCURRENCY: usize = 128; + +pub const DECODE_BUF_CAPACITY: usize = 16 * 1024 + 2396; + +pub const PADDING_POOL: &[u8] = b"padding=XXXXXXXXXXXXXXXXXXXXXXXXXX"; +pub const MIN_PADDING: usize = 16; diff --git a/src/client/handshake.rs b/src/client/handshake.rs new file mode 100644 index 0000000..e4d6497 --- /dev/null +++ b/src/client/handshake.rs @@ -0,0 +1,372 @@ +use anyhow::{Context, Result, anyhow}; +use base64::Engine; +use bytes::{Bytes, BytesMut}; +use crypto_common::KeyExport; +use futures::StreamExt; +use http_body_util::BodyExt; +use rand::RngExt; +use std::sync::Arc; +use tracing::{Instrument, info, warn}; +use zeroize::Zeroizing; + +use crate::client::tunnel; +use crate::client::utils; +use crate::crypto::{self, AesFrameCipher}; +use crate::shaper::{self, FrameCipher, MAX_RAW_PAYLOAD}; + +use super::state::SharedState; +use super::tunnel::download_loop; +use crate::client::constants::{ + DECODE_BUF_CAPACITY, DOWNLOAD_CONNECT_TIMEOUT, MIN_PADDING, PADDING_POOL, +}; + +pub async fn try_pq_connect( + http_client: &Arc, + state: &Arc, + master: &[u8; 32], + session_id: &str, + target_host: &str, + initial_payload: Bytes, + read_half: &mut Option, + write_half: &mut Option, +) -> Result<()> { + info!(session_id = %session_id, target = %target_host, "session resumption: attempting to reuse session"); + + let conn_nonce: [u8; 16] = rand::rng().random(); + let (upload_key, download_key) = crypto::derive_connection_keys(master, &conn_nonce); + let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); + let download_cipher = Arc::new(AesFrameCipher::new(download_key)); + + let cookie_master_key = crypto::derive_cookie_master_key(master); + let enc_target = crypto::encrypt_bytes(&cookie_master_key, target_host.as_bytes())?; + let enc_conn_nonce = crypto::encrypt_bytes(&cookie_master_key, &conn_nonce)?; + + let cookie_val = format!( + "{}:{}:{}", + session_id, + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&enc_target), + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&enc_conn_nonce) + ); + + let (early_data, remaining_payload, frames_sent) = utils::encode_initial_payload( + &initial_payload, + MAX_RAW_PAYLOAD, + Some(upload_cipher.as_ref() as &dyn FrameCipher), + &state.traffic_config, + )?; + + let mut session_cookie = String::new(); + utils::build_tunnel_cookie(&mut session_cookie, &cookie_val); + + let response = tokio::time::timeout( + DOWNLOAD_CONNECT_TIMEOUT, + http_client + .post(state.remote_str.as_str()) + .header("Cookie", &session_cookie) + .body(wreq::Body::from(early_data)) + .send(), + ) + .await + .context("session resumption download connect timed out")? + .context("session resumption POST failed")?; + + if response.status().as_u16() == 428 { + let _ = response.bytes().await; + return Err(anyhow!("server requests re-handshake (428)")); + } + if !response.status().is_success() { + let status = response.status(); + let _ = response.bytes().await; + return Err(anyhow!("server rejected session resumption: {status}")); + } + + let read_half = read_half + .take() + .ok_or_else(|| anyhow!("read half already consumed"))?; + let write_half = write_half + .take() + .ok_or_else(|| anyhow!("write half already consumed"))?; + + let encoding = state.traffic_config.encoding_type; + + let upload_client = Arc::clone(http_client); + let upload_state = Arc::clone(state); + let upload_cipher_clone = Arc::clone(&upload_cipher); + let session_cookie_val = cookie_val.clone(); + + let mut upload_task = tokio::spawn( + async move { + tunnel::upload_loop( + upload_client, + upload_state, + remaining_payload, + read_half, + Some(upload_cipher_clone), + session_cookie_val, + frames_sent, + ) + .await + } + .instrument(tracing::Span::current()), + ); + + let download_fut = download_loop(response, write_half, Some(download_cipher), encoding); + tokio::pin!(download_fut); + + let result: Result<()> = tokio::select! { + biased; + upload_res = &mut upload_task => { + let upload_outcome: Result<()> = match upload_res { + Ok(r) => r, + Err(e) if e.is_cancelled() => Ok(()), + Err(e) => Err(anyhow!("upload task panicked: {e}")), + }; + if let Err(ref e) = upload_outcome { + warn!(reason = %e, "upload failed; aborting download"); + return upload_outcome; + } + download_fut.await + } + dl_res = &mut download_fut => { + upload_task.abort(); + let _ = upload_task.await; + if let Err(ref e) = dl_res { + warn!(reason = %e, "download failed; upload task aborted"); + } + dl_res.context("download failed") + } + }; + result +} + +pub async fn full_handshake( + http_client: &Arc, + state: &Arc, + server_pk: &x25519_dalek::PublicKey, + target_host: &str, + initial_payload: Bytes, + read_half: tokio::net::tcp::OwnedReadHalf, + write_half: tokio::net::tcp::OwnedWriteHalf, +) -> Result<()> { + info!(target = %target_host, "PQ handshake initiated"); + + let (eph_sk_a, eph_pk_a) = crypto::generate_keypair(); + let eph_sk_a = Zeroizing::new(eph_sk_a); + let x25519_shared_a = crypto::diffie_hellman(&eph_sk_a, server_pk); + let handshake_key = crypto::derive_handshake_key(&*x25519_shared_a); + let handshake_cipher = AesFrameCipher::new(crypto::AesKey::from(*handshake_key)); + + let (kem_sk, kem_pk) = crypto::generate_mlkem_keypair(); + let kem_pk_bytes = kem_pk.to_bytes(); + + let (eph_sk_b, eph_pk_b) = crypto::generate_keypair(); + let eph_sk_b = Zeroizing::new(eph_sk_b); + let eph_pk_b_bytes = eph_pk_b.to_bytes().to_vec(); + + let mut client_hello = Vec::with_capacity(2 + kem_pk_bytes.len() + 2 + eph_pk_b_bytes.len()); + client_hello.extend_from_slice(&(kem_pk_bytes.len() as u16).to_be_bytes()); + client_hello.extend_from_slice(&kem_pk_bytes); + client_hello.extend_from_slice(&(eph_pk_b_bytes.len() as u16).to_be_bytes()); + client_hello.extend_from_slice(&eph_pk_b_bytes); + let client_hello_frame = shaper::encode_frame( + &client_hello, + 0, + Some(&handshake_cipher), + &state.traffic_config, + )?; + + let eph_pk_a_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(eph_pk_a.as_bytes()); + + let mut handshake_cookie = + String::with_capacity(9 + eph_pk_a_b64.len() + 2 + MIN_PADDING + PADDING_POOL.len()); + handshake_cookie.push_str("eph_pk_a="); + handshake_cookie.push_str(&eph_pk_a_b64); + handshake_cookie.push_str("; "); + let padding_len = rand::rng().random_range(MIN_PADDING..PADDING_POOL.len()); + handshake_cookie + .push_str(std::str::from_utf8(&PADDING_POOL[..padding_len]).expect("Invalid UTF-8")); + + info!("ClientHello ready, sending to server"); + + let response = tokio::time::timeout( + DOWNLOAD_CONNECT_TIMEOUT, + http_client + .post(state.remote_str.as_str()) + .header("Authorization", state.auth_header.as_str()) + .header("Cookie", &handshake_cookie) + .body(wreq::Body::from(client_hello_frame)) + .send(), + ) + .await + .context("handshake download connect timed out")? + .context("handshake POST failed")?; + + if !response.status().is_success() { + let status = response.status(); + let _ = response.bytes().await; + return Err(anyhow!("server rejected handshake: {status}")); + } + + let handshake_cipher_ref: &dyn FrameCipher = &handshake_cipher; + let mut body_buf = BytesMut::with_capacity(DECODE_BUF_CAPACITY); + let mut stream = response.into_data_stream(); + let server_hello_data = loop { + match stream.next().await { + Some(Ok(chunk)) => { + body_buf.extend_from_slice(&chunk); + if let Some((_, data)) = shaper::decode_from_buffer( + &mut body_buf, + Some(handshake_cipher_ref), + state.traffic_config.encoding_type, + )? { + break data; + } + } + Some(Err(e)) => return Err(e.into()), + None => return Err(anyhow!("ServerHello not received")), + } + }; + + if server_hello_data.len() < 2 { + return Err(anyhow!("ServerHello too short")); + } + + info!("ServerHello received, deriving master key"); + + let sid_len = u16::from_be_bytes([server_hello_data[0], server_hello_data[1]]) as usize; + let ct_start = 2 + sid_len; + let ct_end = ct_start + 1088; + if server_hello_data.len() < ct_end + 32 { + return Err(anyhow!("ServerHello truncated")); + } + let session_id = std::str::from_utf8(&server_hello_data[2..2 + sid_len]) + .context("invalid session_id")? + .to_owned(); + let ct_bytes = &server_hello_data[ct_start..ct_end]; + let ct: ml_kem::Ciphertext = ct_bytes + .try_into() + .map_err(|_| anyhow!("invalid ct: wrong length"))?; + let server_eph_pk_bytes: [u8; 32] = server_hello_data[ct_end..ct_end + 32].try_into().unwrap(); + let server_eph_pk = x25519_dalek::PublicKey::from(server_eph_pk_bytes); + + let master = { + let ss_mlkem = crypto::mlkem_decapsulate(&kem_sk, &ct); + let ss_x25519 = crypto::diffie_hellman(&eph_sk_b, &server_eph_pk); + crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519) + }; + + info!(session_id = %session_id, "handshake complete, master key derived"); + + { + let mut lock = state.initial_master.lock().await; + *lock = Some(( + session_id.clone(), + Zeroizing::new(*master), + std::time::Instant::now(), + )); + } + + let conn_nonce: [u8; 16] = rand::rng().random(); + let (upload_key, download_key) = crypto::derive_connection_keys(&*master, &conn_nonce); + let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); + let download_cipher = Arc::new(AesFrameCipher::new(download_key)); + + let cookie_master_key = crypto::derive_cookie_master_key(&*master); + let enc_target = crypto::encrypt_bytes(&cookie_master_key, target_host.as_bytes())?; + let enc_conn_nonce = crypto::encrypt_bytes(&cookie_master_key, &conn_nonce)?; + + let cookie_val = format!( + "{}:{}:{}", + session_id, + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&enc_target), + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&enc_conn_nonce) + ); + + let mut session_cookie = String::new(); + utils::build_tunnel_cookie(&mut session_cookie, &cookie_val); + + drop(stream); + + let (early_data, remaining_payload, frames_sent) = utils::encode_initial_payload( + &initial_payload, + MAX_RAW_PAYLOAD, + Some(upload_cipher.as_ref() as &dyn FrameCipher), + &state.traffic_config, + )?; + + info!(session_id = %session_id, target = %target_host, + body_len = %early_data.len(), frames_sent = %frames_sent, "PQ tunnel established"); + + let response = tokio::time::timeout( + DOWNLOAD_CONNECT_TIMEOUT, + http_client + .post(state.remote_str.as_str()) + .header("Cookie", &session_cookie) + .body(wreq::Body::from(early_data)) + .send(), + ) + .await + .context("post-handshake download connect timed out")? + .context("post-handshake POST failed")?; + + if !response.status().is_success() { + let status = response.status(); + let _ = response.bytes().await; + return Err(anyhow!("post-handshake download rejected: {status}")); + } + + let encoding = state.traffic_config.encoding_type; + let upload_client = Arc::clone(http_client); + let upload_state = Arc::clone(state); + let upload_cipher_clone = Arc::clone(&upload_cipher); + let session_cookie_val = cookie_val.clone(); + + drop(eph_sk_a); + drop(handshake_key); + drop(kem_sk); + drop(eph_sk_b); + + let mut upload_task = tokio::spawn( + async move { + tunnel::upload_loop( + upload_client, + upload_state, + remaining_payload, + read_half, + Some(upload_cipher_clone), + session_cookie_val, + frames_sent, + ) + .await + } + .instrument(tracing::Span::current()), + ); + + let download_fut = download_loop(response, write_half, Some(download_cipher), encoding); + tokio::pin!(download_fut); + + let result: Result<()> = tokio::select! { + biased; + upload_res = &mut upload_task => { + let upload_outcome: Result<()> = match upload_res { + Ok(r) => r, + Err(e) if e.is_cancelled() => Ok(()), + Err(e) => Err(anyhow!("upload task panicked: {e}")), + }; + if let Err(ref e) = upload_outcome { + warn!(reason = %e, "upload failed; aborting download"); + return upload_outcome; + } + download_fut.await + } + dl_res = &mut download_fut => { + upload_task.abort(); + let _ = upload_task.await; + if let Err(ref e) = dl_res { + warn!(reason = %e, "download failed; upload task aborted"); + } + dl_res.context("tunnel download failed") + } + }; + result +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..ec20a5f --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,48 @@ +pub mod connection; +pub mod constants; +pub mod handshake; +pub mod proxy; +pub mod state; +pub mod tunnel; +pub mod utils; + +use crate::config::ClientTopConfig; +use crate::crypto; + +use anyhow::{Context, Result}; +use std::sync::Arc; +use tokio::sync::{Mutex, OnceCell}; + +pub fn build_state(cfg: &ClientTopConfig) -> Result> { + let bypass = if cfg.bypass.bypass_files.is_empty() { + None + } else { + let rules = + crate::bypass::BypassRules::load(&cfg.bypass).context("failed to load bypass rules")?; + if rules.is_empty() { + None + } else { + Some(Arc::new(rules)) + } + }; + + let server_public_key = cfg + .client + .public_key + .as_deref() + .map(crypto::b64_to_public_key) + .transpose()?; + + let remote: url::Url = cfg.client.remote.parse().context("invalid server URL")?; + let remote_str = remote.as_str().to_owned(); + + Ok(Arc::new(state::SharedState { + remote_str, + auth_header: format!("Bearer {}", cfg.auth.token), + traffic_config: cfg.traffic_shaping.clone(), + bypass, + server_public_key, + initial_master: Mutex::new(None), + handshake_lock: OnceCell::new(), + })) +} diff --git a/src/client/proxy.rs b/src/client/proxy.rs new file mode 100644 index 0000000..0055537 --- /dev/null +++ b/src/client/proxy.rs @@ -0,0 +1,76 @@ +use anyhow::{Context, Result, anyhow}; +use bytes::BytesMut; +use http::uri::Authority; +use tokio::io::AsyncReadExt; +use url::Url; + +pub async fn parse_proxy_request( + reader: &mut (impl AsyncReadExt + Unpin), + buffer: &mut BytesMut, +) -> Result<(String, usize, String)> { + const MAX_HEADER_LEN: usize = 16 * 1024; + + loop { + if reader.read_buf(buffer).await? == 0 { + return Err(anyhow!("connection closed during header parsing")); + } + let mut headers = [httparse::EMPTY_HEADER; 64]; + let mut req = httparse::Request::new(&mut headers); + if let httparse::Status::Complete(amt) = req.parse(buffer)? { + return Ok(( + req.method.context("no method")?.to_owned(), + amt, + req.path.context("no path")?.to_owned(), + )); + } + if buffer.len() > MAX_HEADER_LEN { + return Err(anyhow!("header too large")); + } + } +} + +#[inline] +pub fn resolve_target_host(method: &str, url_str: &str) -> Result { + if method == "CONNECT" { + let auth: Authority = url_str + .parse() + .map_err(|_| anyhow!("invalid target: {url_str}"))?; + let port = auth + .port_u16() + .ok_or_else(|| anyhow!("port required: {url_str}"))?; + return Ok(format!("{}:{port}", auth.host())); + } + + let url = Url::parse(url_str).context("invalid proxy URL")?; + let host = url.host_str().context("URL has no host")?; + let port = url.port_or_known_default().context("port required")?; + Ok(format!("{host}:{port}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_connect_host() { + let target = resolve_target_host("CONNECT", "example.com:443").unwrap(); + assert_eq!(target, "example.com:443"); + } + + #[test] + fn resolve_http_url() { + let target = resolve_target_host("GET", "http://example.com/path").unwrap(); + assert_eq!(target, "example.com:80"); + } + + #[test] + fn resolve_https_url() { + let target = resolve_target_host("GET", "https://example.com/path").unwrap(); + assert_eq!(target, "example.com:443"); + } + + #[test] + fn resolve_connect_no_port_fails() { + assert!(resolve_target_host("CONNECT", "example.com").is_err()); + } +} diff --git a/src/client/state.rs b/src/client/state.rs new file mode 100644 index 0000000..2adcb11 --- /dev/null +++ b/src/client/state.rs @@ -0,0 +1,42 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::{Mutex, OnceCell}; +use zeroize::Zeroizing; + +use crate::bypass::BypassRules; +use crate::shaper::TrafficConfig; + +pub struct ManualResolver { + pub target_addr: String, +} + +impl wreq::dns::Resolve for ManualResolver { + fn resolve(&self, _name: wreq::dns::Name) -> wreq::dns::Resolving { + let target = self.target_addr.clone(); + Box::pin(async move { + let mut lookup_str = String::with_capacity(target.len() + 2); + lookup_str.push_str(&target); + lookup_str.push_str(":0"); + let addrs = tokio::net::lookup_host(lookup_str) + .await? + .map(|mut s| { + s.set_port(0); + s + }) + .collect::>(); + Ok(Box::new(addrs.into_iter()) + as Box + Send + 'static>) + }) + } +} + +pub struct SharedState { + pub remote_str: String, + pub auth_header: String, + pub traffic_config: TrafficConfig, + pub bypass: Option>, + pub server_public_key: Option, + pub initial_master: Mutex, Instant)>>, + pub handshake_lock: OnceCell>, +} diff --git a/src/client/tunnel.rs b/src/client/tunnel.rs new file mode 100644 index 0000000..dbf331f --- /dev/null +++ b/src/client/tunnel.rs @@ -0,0 +1,234 @@ +use anyhow::{Context, Result, anyhow}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{FutureExt, StreamExt}; +use http_body_util::BodyExt; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::task::JoinSet; +use tracing::warn; + +use crate::client::constants::{ + DECODE_BUF_CAPACITY, MAX_BATCH_BYTES, MAX_IN_FLIGHT_BYTES, UPLOAD_CONCURRENCY, + UPLOAD_REQUEST_TIMEOUT, +}; +use crate::client::utils; +use crate::crypto::AesFrameCipher; +use crate::shaper::{self, EncodingType, FrameCipher}; + +use super::state::SharedState; + +#[inline] +pub async fn send_upload_post( + http_client: &wreq::Client, + state: &SharedState, + body: Bytes, + session_cookie_val: &str, +) -> Result<()> { + debug_assert!(!body.is_empty(), "empty upload body"); + let mut cookie = String::new(); + utils::build_tunnel_cookie(&mut cookie, session_cookie_val); + let mut req = http_client + .post(state.remote_str.as_str()) + .header("Accept-Encoding", "identity") + .header("Cache-Control", "no-store, no-transform") + .header("Content-Type", "application/octet-stream") + .header("Cookie", cookie); + if state.server_public_key.is_none() { + req = req.header("Authorization", state.auth_header.as_str()); + } + let response = tokio::time::timeout( + UPLOAD_REQUEST_TIMEOUT, + req.body(wreq::Body::from(body)).send(), + ) + .await + .context("upload POST timed out")? + .context("http post failed")?; + + if !response.status().is_success() { + let status = response.status(); + let _ = response.bytes().await; + return Err(anyhow!("upstream rejected upload: {status}")); + } + response.bytes().await.context("drain upload response")?; + Ok(()) +} + +pub async fn upload_loop( + http_client: Arc, + state: Arc, + initial_payload: Bytes, + read_half: tokio::net::tcp::OwnedReadHalf, + cipher: Option>, + encrypted_session: String, + start_seq: u64, +) -> Result<()> { + let reader = AsyncReadExt::chain(std::io::Cursor::new(initial_payload), read_half); + let traffic_cipher: Option> = cipher + .as_ref() + .map(|c| Arc::clone(c) as Arc); + + let mut shaped = Box::pin(shaper::TrafficShaper::with_seq( + reader, + state.traffic_config.clone(), + traffic_cipher, + start_seq, + )); + + let request_sem = Arc::new(Semaphore::new(UPLOAD_CONCURRENCY)); + let bytes_sem = Arc::new(Semaphore::new(MAX_IN_FLIGHT_BYTES)); + + let mut tasks = JoinSet::new(); + let mut leftover: Option = None; + + loop { + let mut batch_buf = BytesMut::with_capacity(8 * 1024); + let mut stream_ended = false; + let mut bytes_permits: Vec = Vec::new(); + + if let Some(data) = leftover.take() { + let size = data.len() as u32; + let permit = bytes_sem + .clone() + .acquire_many_owned(size) + .await + .map_err(|_| anyhow!("bytes semaphore closed"))?; + batch_buf.put_slice(&data); + bytes_permits.push(permit); + } + + if batch_buf.is_empty() { + match shaped.next().await { + Some(Ok((_seq, data))) => { + let size = data.len() as u32; + let permit = bytes_sem + .clone() + .acquire_many_owned(size) + .await + .map_err(|_| anyhow!("bytes semaphore closed"))?; + batch_buf.put_slice(&data); + bytes_permits.push(permit); + } + Some(Err(e)) => return Err(e.into()), + None => break, + } + } + + loop { + match shaped.next().now_or_never() { + Some(Some(Ok((_seq, data)))) => { + let frame_size = data.len(); + if batch_buf.len() + frame_size > MAX_BATCH_BYTES { + leftover = Some(data); + break; + } + match bytes_sem.clone().try_acquire_many_owned(frame_size as u32) { + Ok(permit) => { + batch_buf.put_slice(&data); + bytes_permits.push(permit); + } + Err(_) => { + leftover = Some(data); + break; + } + } + } + Some(Some(Err(e))) => return Err(e.into()), + Some(None) => { + stream_ended = true; + break; + } + None => break, + } + } + + if batch_buf.is_empty() { + if stream_ended { + break; + } + continue; + } + + let req_permit = request_sem + .clone() + .acquire_owned() + .await + .map_err(|_| anyhow!("request semaphore closed"))?; + + let body = batch_buf.freeze(); + let http_client = Arc::clone(&http_client); + let state_ref = Arc::clone(&state); + let session_val = encrypted_session.clone(); + + tasks.spawn(async move { + let _req_guard = req_permit; + let _bytes_guards = bytes_permits; + send_upload_post(&http_client, &state_ref, body, &session_val).await + }); + + while let Some(result) = tasks.try_join_next() { + match result { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e.context("upload POST failed")), + Err(join_err) => return Err(anyhow!("upload task panicked: {}", join_err)), + } + } + + if stream_ended { + break; + } + } + + while let Some(result) = tasks.join_next().await { + match result { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(join_err) => return Err(anyhow!("upload task panicked: {}", join_err)), + } + } + + Ok(()) +} + +pub async fn download_loop( + response: wreq::Response, + mut write_half: tokio::net::tcp::OwnedWriteHalf, + cipher: Option>, + encoding: EncodingType, +) -> Result<()> { + let mut buffer = BytesMut::with_capacity(DECODE_BUF_CAPACITY); + let mut data_stream = response.into_data_stream(); + let cipher_ref: Option<&dyn FrameCipher> = cipher.as_deref().map(|c| c as &dyn FrameCipher); + let mut expected_seq: u64 = 0; + + let result: Result<()> = async { + while let Some(chunk) = data_stream.next().await { + buffer.extend_from_slice(&chunk.context("response read error")?); + while let Some((seq, frame)) = + shaper::decode_from_buffer(&mut buffer, cipher_ref, encoding)? + { + if seq != expected_seq { + return Err(anyhow!( + "download frame seq {} out of order, expected {}", + seq, + expected_seq + )); + } + expected_seq += 1; + write_half.write_all(&frame).await?; + } + } + Ok(()) + } + .await; + + if !buffer.is_empty() { + warn!( + remaining = buffer.len(), + "download stream ended with undecoded data" + ); + } + + let _ = write_half.shutdown().await; + result +} diff --git a/src/client/utils.rs b/src/client/utils.rs new file mode 100644 index 0000000..89a1d7a --- /dev/null +++ b/src/client/utils.rs @@ -0,0 +1,108 @@ +use anyhow::{Context, Result}; +use bytes::Bytes; +use rand::RngExt; + +use crate::client::constants::{MIN_PADDING, PADDING_POOL}; +use crate::shaper::{self, FrameCipher}; + +#[inline] +pub fn build_tunnel_cookie(buf: &mut String, session_val: &str) { + buf.clear(); + let cap = 8 + session_val.len() + MIN_PADDING + PADDING_POOL.len(); + buf.reserve(cap); + buf.push_str("session="); + buf.push_str(session_val); + buf.push_str("; "); + let padding_len = rand::rng().random_range(MIN_PADDING..PADDING_POOL.len()); + buf.push_str(std::str::from_utf8(&PADDING_POOL[..padding_len]).expect("Invalid UTF-8")) +} + +pub fn encode_initial_payload( + initial_payload: &[u8], + max_bytes: usize, + cipher: Option<&dyn FrameCipher>, + config: &shaper::TrafficConfig, +) -> Result<(Vec, Bytes, u64)> { + let take_len = initial_payload.len().min(max_bytes); + let data_to_send = &initial_payload[..take_len]; + let remaining = if take_len < initial_payload.len() { + Bytes::copy_from_slice(&initial_payload[take_len..]) + } else { + Bytes::new() + }; + + let raw_payload_limit = shaper::MAX_RAW_PAYLOAD; + let mut body = Vec::new(); + let mut offset = 0; + let mut seq: u64 = 0; + + while offset < data_to_send.len() { + let chunk_end = (offset + raw_payload_limit).min(data_to_send.len()); + let chunk = &data_to_send[offset..chunk_end]; + let frame = shaper::encode_frame(chunk, seq, cipher, config) + .context("encode_frame failed on initial payload")?; + body.extend_from_slice(&frame); + offset = chunk_end; + seq += 1; + } + + Ok((body, remaining, seq)) +} + +#[inline] +pub fn is_silent_error(root: &(dyn std::error::Error + 'static)) -> bool { + use std::io::ErrorKind::*; + if let Some(e) = root.downcast_ref::() { + return e.is_reset() || e.is_library(); + } + if let Some(e) = root.downcast_ref::() { + return matches!( + e.kind(), + ConnectionReset | UnexpectedEof | NotConnected | BrokenPipe + ); + } + root.to_string() + .contains("connection closed during header parsing") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_cookie_contains_session() { + let mut buf = String::new(); + build_tunnel_cookie(&mut buf, "abc:def:ghi"); + assert!(buf.starts_with("session=abc:def:ghi; ")); + assert!(buf.len() > 25); + } + + #[test] + fn encode_zero_payload() { + let config = shaper::TrafficConfig { + global: shaper::PaddingConfig { + padding_threshold: 16384, + padding_range: [0, 0], + }, + stages: vec![], + encoding_type: Default::default(), + }; + let (body, remaining, seq) = + encode_initial_payload(b"", shaper::MAX_RAW_PAYLOAD, None, &config).unwrap(); + assert!(body.is_empty()); + assert!(remaining.is_empty()); + assert_eq!(seq, 0); + } + + #[test] + fn is_silent_reset() { + let e = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset"); + assert!(is_silent_error(&e)); + } + + #[test] + fn is_not_silent_other() { + let e = std::io::Error::new(std::io::ErrorKind::Other, "other"); + assert!(!is_silent_error(&e)); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..5dd5595 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,108 @@ +use serde::Deserialize; + +use crate::log::LogConfig; + +pub use crate::bypass::BypassConfig; +pub use crate::shaper::TrafficConfig; + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ServerTopConfig { + pub server: ServerSection, + pub auth: AuthSection, + pub proxy: Option, + pub log: Option, + pub dns: Option, + pub traffic_shaping: TrafficConfig, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ClientTopConfig { + pub client: ClientSection, + pub auth: ClientAuthSection, + pub log: Option, + pub traffic_shaping: TrafficConfig, + #[serde(default)] + pub bypass: BypassConfig, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct AuthSection { + pub secret: String, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ProxySection { + pub socks5: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ServerSection { + pub listen: String, + pub path: String, + pub private_key: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ClientSection { + pub listen: String, + pub remote: String, + pub address: Option, + #[serde(default)] + pub public_key: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct ClientAuthSection { + pub token: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_minimal_server_config() { + let toml_str = r#" +[server] +listen = "127.0.0.1:3000" +path = "/secret" + +[auth] +secret = "key" + +[traffic_shaping.global] +padding_range = [0, 100] +padding_threshold = 50 +"#; + let cfg: ServerTopConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(cfg.server.listen, "127.0.0.1:3000"); + assert_eq!(cfg.server.path, "/secret"); + assert!(cfg.proxy.is_none()); + } + + #[test] + fn parse_minimal_client_config() { + let toml_str = r#" +[client] +listen = "127.0.0.1:8080" +remote = "https://example.com/secret" + +[auth] +token = "mytoken" + +[traffic_shaping.global] +padding_range = [0, 100] +padding_threshold = 50 +"#; + let cfg: ClientTopConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(cfg.client.listen, "127.0.0.1:8080"); + assert_eq!(cfg.auth.token, "mytoken"); + } +} diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs new file mode 100644 index 0000000..3c990c4 --- /dev/null +++ b/src/crypto/cipher.rs @@ -0,0 +1,202 @@ +use aes_gcm::{ + Aes256Gcm, Nonce, Tag, + aead::{AeadInPlace, KeyInit}, +}; +use anyhow::{Result, anyhow}; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use rand::Rng; +use std::io; +use zeroize::Zeroizing; + +use super::AesKey; +use crate::shaper::FrameCipher; + +const NONCE_LEN: usize = 12; +const TAG_LEN: usize = 16; +const EMPTY_AAD: &[u8] = b""; + +#[inline] +fn random_nonce() -> [u8; NONCE_LEN] { + let mut nonce = [0u8; NONCE_LEN]; + rand::rng().fill_bytes(&mut nonce); + nonce +} + +#[inline] +fn encrypt_with_cipher(cipher: &Aes256Gcm, plaintext: &[u8]) -> Result> { + let nonce_bytes = random_nonce(); + let mut out = Vec::with_capacity(NONCE_LEN + plaintext.len() + TAG_LEN); + out.extend_from_slice(&nonce_bytes); + out.extend_from_slice(plaintext); + + let tag = cipher + .encrypt_in_place_detached( + Nonce::from_slice(&nonce_bytes), + EMPTY_AAD, + &mut out[NONCE_LEN..], + ) + .map_err(|e| anyhow!("encryption error: {e}"))?; + out.extend_from_slice(tag.as_ref()); + Ok(out) +} + +#[inline] +fn decrypt_with_cipher( + cipher: &Aes256Gcm, + data: &[u8], + short_err: &'static str, +) -> Result> { + if data.len() < NONCE_LEN + TAG_LEN { + return Err(anyhow!(short_err)); + } + let ct_end = data.len() - TAG_LEN; + let mut plaintext = data[NONCE_LEN..ct_end].to_vec(); + + cipher + .decrypt_in_place_detached( + Nonce::from_slice(&data[..NONCE_LEN]), + EMPTY_AAD, + &mut plaintext, + Tag::from_slice(&data[ct_end..]), + ) + .map_err(|e| anyhow!("decryption error: {e}"))?; + Ok(plaintext) +} + +#[allow(dead_code)] +#[inline] +pub fn encrypt_cookie(key: &AesKey, plaintext: &str) -> Result { + let cipher = Aes256Gcm::new(key); + let encrypted = encrypt_with_cipher(&cipher, plaintext.as_bytes())?; + Ok(URL_SAFE_NO_PAD.encode(encrypted)) +} + +#[allow(dead_code)] +#[inline] +pub fn decrypt_cookie(key: &AesKey, ciphertext_b64: &str) -> Result { + let mut combined = URL_SAFE_NO_PAD.decode(ciphertext_b64.as_bytes())?; + if combined.len() < NONCE_LEN + TAG_LEN { + return Err(anyhow!("ciphertext too short")); + } + let cipher = Aes256Gcm::new(key); + let ct_len = combined.len() - NONCE_LEN - TAG_LEN; + let (nonce_bytes, rest) = combined.split_at_mut(NONCE_LEN); + let (ciphertext, tag_bytes) = rest.split_at_mut(ct_len); + + cipher + .decrypt_in_place_detached( + Nonce::from_slice(nonce_bytes), + EMPTY_AAD, + ciphertext, + Tag::from_slice(tag_bytes), + ) + .map_err(|e| anyhow!("decryption error: {e}"))?; + + let result = ciphertext.to_vec(); + String::from_utf8(result).map_err(|e| anyhow!("invalid utf8: {e}")) +} + +#[inline] +pub fn encrypt_bytes(key: &AesKey, data: &[u8]) -> Result> { + let cipher = Aes256Gcm::new(key); + encrypt_with_cipher(&cipher, data) +} + +#[inline] +pub fn decrypt_bytes(key: &AesKey, data: &[u8]) -> Result> { + let cipher = Aes256Gcm::new(key); + decrypt_with_cipher(&cipher, data, "encrypted data too short") +} + +pub struct AesFrameCipher { + key: Zeroizing<[u8; 32]>, + cipher: Aes256Gcm, +} + +impl Clone for AesFrameCipher { + #[inline] + fn clone(&self) -> Self { + Self::new(AesKey::from(*self.key)) + } +} + +impl AesFrameCipher { + #[inline] + pub fn new(key: AesKey) -> Self { + let mut key_bytes = [0u8; 32]; + key_bytes.copy_from_slice(key.as_ref()); + let cipher = Aes256Gcm::new(&key); + Self { + key: Zeroizing::new(key_bytes), + cipher, + } + } +} + +impl FrameCipher for AesFrameCipher { + #[inline] + fn encrypt(&self, data: &[u8]) -> io::Result> { + encrypt_with_cipher(&self.cipher, data).map_err(io::Error::other) + } + + #[inline] + fn decrypt(&self, data: &[u8]) -> io::Result> { + decrypt_with_cipher(&self.cipher, data, "encrypted frame too short") + .map_err(io::Error::other) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + + fn random_key() -> AesKey { + let mut bytes = [0u8; 32]; + rand::rng().fill_bytes(&mut bytes); + AesKey::from(bytes) + } + + #[test] + fn encrypt_decrypt_roundtrip() { + let key = random_key(); + let plain = b"hello world test frame data"; + let ct = encrypt_bytes(&key, plain).unwrap(); + let pt = decrypt_bytes(&key, &ct).unwrap(); + assert_eq!(pt, plain); + } + + #[test] + fn cookie_encrypt_decrypt_roundtrip() { + let key = random_key(); + let msg = "session-1234"; + let ct = encrypt_cookie(&key, msg).unwrap(); + let pt = decrypt_cookie(&key, &ct).unwrap(); + assert_eq!(pt, msg); + } + + #[test] + fn frame_cipher_roundtrip() { + let key = random_key(); + let cipher = AesFrameCipher::new(key); + let data = b"frame data for cipher test"; + let ct = cipher.encrypt(data).unwrap(); + let pt = cipher.decrypt(&ct).unwrap(); + assert_eq!(pt, data); + } + + #[test] + fn decrypt_garbage_fails() { + let key = random_key(); + assert!(decrypt_bytes(&key, b"too-short").is_err()); + } + + #[test] + fn decrypt_cookie_invalid_utf8() { + let key = random_key(); + let junk = URL_SAFE_NO_PAD.encode(b"\xff\xfe\xfd"); + let padded = format!("AAAAQQAAAAAA{}{junk}", "x".repeat(12)); + assert!(decrypt_cookie(&key, &padded).is_err()); + } +} diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs new file mode 100644 index 0000000..0e13571 --- /dev/null +++ b/src/crypto/handshake.rs @@ -0,0 +1,90 @@ +use hkdf::Hkdf; +use sha2::Sha256; +use zeroize::Zeroizing; + +use super::AesKey; + +pub fn derive_handshake_key(shared: &[u8; 32]) -> Zeroizing<[u8; 32]> { + let hkdf = Hkdf::::new(None, shared); + let mut key = Zeroizing::new([0u8; 32]); + hkdf.expand(b"mlkem_handshake_key", &mut *key) + .expect("32 bytes is valid for HKDF"); + key +} + +pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8]) -> Zeroizing<[u8; 32]> { + let mut ikm = Vec::with_capacity(mlkem_ss.len() + x25519_ss.len()); + ikm.extend_from_slice(mlkem_ss); + ikm.extend_from_slice(x25519_ss); + let hkdf = Hkdf::::new(Some(b"initial_master_salt"), &ikm); + let mut master = Zeroizing::new([0u8; 32]); + hkdf.expand(b"", &mut *master) + .expect("32 bytes is valid for HKDF"); + master +} + +pub fn derive_cookie_master_key(master: &[u8; 32]) -> AesKey { + let hkdf = Hkdf::::new(None, master); + let mut key = [0u8; 32]; + hkdf.expand(b"cookie_master_key", &mut key) + .expect("32 bytes is valid for HKDF"); + key.into() +} + +pub fn derive_connection_keys(master: &[u8; 32], conn_nonce: &[u8; 16]) -> (AesKey, AesKey) { + let hkdf = Hkdf::::new(None, master); + let mut info = Vec::with_capacity(16 + 15); + info.extend_from_slice(conn_nonce); + info.extend_from_slice(b"connection_keys"); + let mut buf = [0u8; 64]; + hkdf.expand(&info, &mut buf) + .expect("64 bytes is valid for HKDF"); + + let mut upload_key = [0u8; 32]; + let mut download_key = [0u8; 32]; + upload_key.copy_from_slice(&buf[..32]); + download_key.copy_from_slice(&buf[32..]); + (upload_key.into(), download_key.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn derive_handshake_key_deterministic() { + let shared = [0xAAu8; 32]; + let k1 = derive_handshake_key(&shared); + let k2 = derive_handshake_key(&shared); + assert_eq!(*k1, *k2); + } + + #[test] + fn derive_initial_master_deterministic() { + let ml = [0x11u8; 32]; + let x2 = [0x22u8; 32]; + let m1 = derive_initial_master(&ml, &x2); + let m2 = derive_initial_master(&ml, &x2); + assert_eq!(*m1, *m2); + } + + #[test] + fn connection_keys_deterministic() { + let master = [0xBBu8; 32]; + let nonce = [0xCCu8; 16]; + let (up1, dn1) = derive_connection_keys(&master, &nonce); + let (up2, dn2) = derive_connection_keys(&master, &nonce); + assert_eq!(up1, up2); + assert_eq!(dn1, dn2); + } + + #[test] + fn connection_keys_different_for_different_nonces() { + let master = [0xBBu8; 32]; + let n1 = [0xCCu8; 16]; + let n2 = [0xDDu8; 16]; + let (up1, _) = derive_connection_keys(&master, &n1); + let (up2, _) = derive_connection_keys(&master, &n2); + assert_ne!(up1, up2); + } +} diff --git a/src/crypto/keys.rs b/src/crypto/keys.rs new file mode 100644 index 0000000..904190b --- /dev/null +++ b/src/crypto/keys.rs @@ -0,0 +1,135 @@ +use anyhow::{Result, anyhow}; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use crypto_common::Key; +use ml_kem::{ + Ciphertext, DecapsulationKey, EncapsulationKey, MlKem768, + kem::{Decapsulate, Encapsulate}, +}; +use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret}; +use zeroize::Zeroizing; + +const X25519_KEY_LEN: usize = 32; +const X25519_B64_LEN: usize = 43; + +#[inline] +pub fn generate_keypair() -> (StaticSecret, X25519PublicKey) { + let secret = StaticSecret::random(); + let public = X25519PublicKey::from(&secret); + (secret, public) +} + +#[inline] +fn encode_fixed_32(bytes: &[u8; X25519_KEY_LEN]) -> String { + let mut out = String::with_capacity(X25519_B64_LEN); + URL_SAFE_NO_PAD.encode_string(bytes, &mut out); + out +} + +#[inline] +fn decode_fixed_32(s: &str, kind: &'static str) -> Result<[u8; X25519_KEY_LEN]> { + let mut out = [0u8; X25519_KEY_LEN]; + let decoded = URL_SAFE_NO_PAD.decode(s.as_bytes())?; + if decoded.len() != X25519_KEY_LEN { + return Err(anyhow!("invalid {kind} length")); + } + out.copy_from_slice(&decoded); + Ok(out) +} + +#[inline] +pub fn public_key_to_b64(pk: &X25519PublicKey) -> String { + encode_fixed_32(pk.as_bytes()) +} + +#[inline] +pub fn private_key_to_b64(sk: &StaticSecret) -> String { + encode_fixed_32(&sk.to_bytes()) +} + +#[inline] +pub fn b64_to_public_key(s: &str) -> Result { + Ok(X25519PublicKey::from(decode_fixed_32(s, "public key")?)) +} + +#[inline] +pub fn b64_to_private_key(s: &str) -> Result { + Ok(StaticSecret::from(decode_fixed_32(s, "private key")?)) +} + +#[inline] +pub fn diffie_hellman(our_sk: &StaticSecret, their_pk: &X25519PublicKey) -> Zeroizing<[u8; 32]> { + Zeroizing::new(*our_sk.diffie_hellman(their_pk).as_bytes()) +} + +pub fn bytes_to_encapsulation_key(bytes: &[u8]) -> Result> { + let key: Key> = bytes + .try_into() + .map_err(|_| anyhow!("invalid encapsulation key length"))?; + EncapsulationKey::new(&key).map_err(|_| anyhow!("invalid encapsulation key")) +} + +pub fn generate_mlkem_keypair() -> (DecapsulationKey, EncapsulationKey) { + use ml_kem::kem::Kem; + ::generate_keypair_from_rng(&mut rand::rng()) +} + +pub fn mlkem_encapsulate( + pk: &EncapsulationKey, +) -> (Ciphertext, Zeroizing>) { + let (ct, ss) = pk.encapsulate_with_rng(&mut rand::rng()); + let ss_bytes: &[u8] = &ss; + (ct, Zeroizing::new(ss_bytes.to_vec())) +} + +pub fn mlkem_decapsulate( + sk: &DecapsulationKey, + ct: &Ciphertext, +) -> Zeroizing> { + let ss = sk.decapsulate(ct); + let ss_bytes: &[u8] = &ss; + Zeroizing::new(ss_bytes.to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn keypair_roundtrip() { + let (sk, pk) = generate_keypair(); + let b64_pk = public_key_to_b64(&pk); + let b64_sk = private_key_to_b64(&sk); + assert_eq!(b64_pk.len(), X25519_B64_LEN); + assert_eq!(b64_sk.len(), X25519_B64_LEN); + + let pk2 = b64_to_public_key(&b64_pk).unwrap(); + assert_eq!(pk.as_bytes(), pk2.as_bytes()); + + let sk2 = b64_to_private_key(&b64_sk).unwrap(); + assert_eq!(sk.to_bytes(), sk2.to_bytes()); + } + + #[test] + fn diffie_hellman_agreement() { + let (a_sk, a_pk) = generate_keypair(); + let (b_sk, b_pk) = generate_keypair(); + let ss_a = diffie_hellman(&a_sk, &b_pk); + let ss_b = diffie_hellman(&b_sk, &a_pk); + assert_eq!(*ss_a, *ss_b); + } + + #[test] + fn mlkem_roundtrip() { + let (sk, pk) = generate_mlkem_keypair(); + let (ct, ss_enc) = mlkem_encapsulate(&pk); + let ss_dec = mlkem_decapsulate(&sk, &ct); + assert_eq!(*ss_enc, *ss_dec); + } + + #[test] + fn b64_invalid_length_rejected() { + assert!(b64_to_public_key("abc").is_err()); + assert!(b64_to_private_key("abc").is_err()); + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs new file mode 100644 index 0000000..51c0275 --- /dev/null +++ b/src/crypto/mod.rs @@ -0,0 +1,15 @@ +mod cipher; +mod handshake; +mod keys; + +pub use cipher::{AesFrameCipher, decrypt_bytes, encrypt_bytes}; +pub use handshake::{ + derive_connection_keys, derive_cookie_master_key, derive_handshake_key, derive_initial_master, +}; +pub use keys::{ + b64_to_private_key, b64_to_public_key, bytes_to_encapsulation_key, diffie_hellman, + generate_keypair, generate_mlkem_keypair, mlkem_decapsulate, mlkem_encapsulate, + private_key_to_b64, public_key_to_b64, +}; + +pub type AesKey = aes_gcm::Key; diff --git a/src/dns.rs b/src/dns/mod.rs similarity index 100% rename from src/dns.rs rename to src/dns/mod.rs diff --git a/src/error/mod.rs b/src/error/mod.rs new file mode 100644 index 0000000..7190534 --- /dev/null +++ b/src/error/mod.rs @@ -0,0 +1,166 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use std::fmt; + +#[derive(Debug)] +pub enum HttpProxyError { + Anyhow(anyhow::Error), + Io(std::io::Error), + Protocol(String), + Config(String), + Auth(String), + Dns(String), + Timeout(String), + NotFound(String), + Precondition(String), + PayloadTooLarge(String), +} + +impl fmt::Display for HttpProxyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Anyhow(e) => write!(f, "{e}"), + Self::Io(e) => write!(f, "I/O error: {e}"), + Self::Protocol(m) => write!(f, "protocol error: {m}"), + Self::Config(m) => write!(f, "config error: {m}"), + Self::Auth(m) => write!(f, "auth error: {m}"), + Self::Dns(m) => write!(f, "dns error: {m}"), + Self::Timeout(m) => write!(f, "timeout: {m}"), + Self::NotFound(m) => write!(f, "not found: {m}"), + Self::Precondition(m) => write!(f, "precondition: {m}"), + Self::PayloadTooLarge(m) => write!(f, "payload too large: {m}"), + } + } +} + +impl std::error::Error for HttpProxyError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Anyhow(e) => Some(e.as_ref()), + Self::Io(e) => Some(e), + _ => None, + } + } +} + +impl From for HttpProxyError { + fn from(e: anyhow::Error) -> Self { + Self::Anyhow(e) + } +} + +impl From for HttpProxyError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +#[derive(Debug)] +pub struct ServerError(pub StatusCode, pub String); + +impl ServerError { + #[inline(always)] + pub fn bad_request(msg: impl Into) -> Self { + Self(StatusCode::BAD_REQUEST, msg.into()) + } + #[inline(always)] + pub fn bad_gateway(msg: impl Into) -> Self { + Self(StatusCode::BAD_GATEWAY, msg.into()) + } + #[inline(always)] + pub fn gateway_timeout(msg: impl Into) -> Self { + Self(StatusCode::GATEWAY_TIMEOUT, msg.into()) + } + #[inline(always)] + pub fn unauthorized(msg: impl Into) -> Self { + Self(StatusCode::UNAUTHORIZED, msg.into()) + } + #[inline(always)] + pub fn not_found(msg: impl Into) -> Self { + Self(StatusCode::GONE, msg.into()) + } + #[inline(always)] + pub fn internal(msg: impl Into) -> Self { + Self(StatusCode::INTERNAL_SERVER_ERROR, msg.into()) + } + #[inline(always)] + pub fn payload_too_large(msg: impl Into) -> Self { + Self(StatusCode::PAYLOAD_TOO_LARGE, msg.into()) + } + #[inline(always)] + pub fn precondition_required(msg: impl Into) -> Self { + Self(StatusCode::PRECONDITION_REQUIRED, msg.into()) + } +} + +impl IntoResponse for ServerError { + fn into_response(self) -> Response { + (self.0, self.1).into_response() + } +} + +impl From for ServerError { + fn from(err: E) -> Self { + Self::internal(err.to_string()) + } +} + +impl From for HttpProxyError { + fn from(e: ServerError) -> Self { + Self::Protocol(e.1) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn error_display_and_source() { + let io_err = HttpProxyError::Io(std::io::Error::new(std::io::ErrorKind::Other, "boom")); + assert!(io_err.to_string().contains("boom")); + assert!(std::error::Error::source(&io_err).is_some()); + + let proto_err = HttpProxyError::Protocol("bad frame".into()); + assert_eq!(proto_err.to_string(), "protocol error: bad frame"); + assert!(std::error::Error::source(&proto_err).is_none()); + } + + #[test] + fn app_error_constructors() { + assert_eq!(ServerError::bad_request("x").0, StatusCode::BAD_REQUEST); + assert_eq!(ServerError::bad_gateway("x").0, StatusCode::BAD_GATEWAY); + assert_eq!( + ServerError::gateway_timeout("x").0, + StatusCode::GATEWAY_TIMEOUT + ); + assert_eq!(ServerError::unauthorized("x").0, StatusCode::UNAUTHORIZED); + assert_eq!(ServerError::not_found("x").0, StatusCode::GONE); + assert_eq!( + ServerError::internal("x").0, + StatusCode::INTERNAL_SERVER_ERROR + ); + assert_eq!( + ServerError::payload_too_large("x").0, + StatusCode::PAYLOAD_TOO_LARGE + ); + assert_eq!( + ServerError::precondition_required("x").0, + StatusCode::PRECONDITION_REQUIRED + ); + } + + #[test] + fn app_error_into_response() { + let err = ServerError::bad_request("test"); + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn app_error_from_std_error() { + let io = std::io::Error::new(std::io::ErrorKind::Other, "oops"); + let app: ServerError = io.into(); + assert_eq!(app.0, StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..09d2230 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,9 @@ +pub mod bypass; +pub mod client; +pub mod config; +pub mod crypto; +pub mod dns; +pub mod error; +pub mod log; +pub mod server; +pub mod shaper; diff --git a/src/log.rs b/src/log/mod.rs similarity index 100% rename from src/log.rs rename to src/log/mod.rs diff --git a/src/server.rs b/src/server.rs deleted file mode 100644 index 0951b5f..0000000 --- a/src/server.rs +++ /dev/null @@ -1,369 +0,0 @@ -mod dns; -mod log; -mod shaper; - -use anyhow::Context; -use axum::{ - Router, - body::Body, - extract::{Query, State}, - http::{HeaderMap, StatusCode}, - response::{IntoResponse, Response}, - routing::post, -}; -use bytes::BytesMut; -use clap::Parser; -use jsonwebtoken::{DecodingKey, Validation}; -use rand::RngExt; -use serde::{Deserialize, Serialize}; -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt; -use std::{ - fs, - net::{IpAddr, SocketAddr}, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::Duration, -}; -use tokio::{io::AsyncWriteExt, net::TcpStream}; -use tokio_socks::tcp::Socks5Stream; -use tokio_stream::StreamExt; -use tower::ServiceBuilder; -use tower_http::trace::TraceLayer; -use tracing::{Instrument, info, warn}; - -static NEXT_STREAM_ID: AtomicU64 = AtomicU64::new(1); - -const PADDING_POOL: [u8; 62] = [b'X'; 62]; -const DECODE_BUF_CAPACITY: usize = 16 * 1024; - -#[derive(Parser, Debug)] -#[command(version, about, long_about = None)] -struct Cli { - #[arg(short = 'c', long, default_value = "config.toml")] - config: String, - #[command(subcommand)] - command: Option, -} - -#[derive(clap::Subcommand, Debug)] -enum Commands { - #[command(about = "Generate a JWT token")] - GenToken { - #[arg(short, long, help = "Secret key for signing")] - secret: String, - #[arg(short, long, help = "Username or Subject")] - user: String, - #[arg(short, long, help = "Expiration timestamp (Unix)")] - exp: u64, - }, -} - -#[derive(Deserialize, Debug)] -struct Config { - server: ServerConfig, - auth: AuthConfig, - proxy: Option, - log: Option, - dns: Option, - traffic_shaping: shaper::TrafficConfig, -} - -#[derive(Deserialize, Debug)] -struct ServerConfig { - listen: String, - path: String, -} - -#[derive(Deserialize, Debug)] -struct AuthConfig { - secret: String, -} - -#[derive(Deserialize, Debug)] -struct ProxyConfig { - socks5: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Claims { - sub: String, - exp: u64, -} - -#[derive(Clone)] -struct AppState { - decoding_key: DecodingKey, - jwt_validation: Validation, - socks5_proxy: Option>, - dns_client: Option>, - client_subnet: Option, - traffic_config: shaper::TrafficConfig, -} - -#[derive(Deserialize)] -struct TunnelQuery { - target: String, -} - -struct AppError(StatusCode, String); - -impl IntoResponse for AppError { - fn into_response(self) -> Response { - (self.0, self.1).into_response() - } -} - -impl AppError { - #[inline] - fn bad_request(msg: impl Into) -> Self { - Self(StatusCode::BAD_REQUEST, msg.into()) - } - - #[inline] - fn bad_gateway(msg: impl Into) -> Self { - Self(StatusCode::BAD_GATEWAY, msg.into()) - } - - #[inline] - fn gateway_timeout(msg: impl Into) -> Self { - Self(StatusCode::GATEWAY_TIMEOUT, msg.into()) - } - - #[inline] - fn unauthorized(msg: impl Into) -> Self { - Self(StatusCode::UNAUTHORIZED, msg.into()) - } - - #[inline] - fn internal(msg: impl Into) -> Self { - Self(StatusCode::INTERNAL_SERVER_ERROR, msg.into()) - } -} - -impl From for AppError { - fn from(err: E) -> Self { - Self::internal(err.to_string()) - } -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let cli = Cli::parse(); - - if let Some(Commands::GenToken { secret, user, exp }) = cli.command { - return gen_token(&secret, user, exp); - } - - let mut config: Config = toml::from_str(&fs::read_to_string(&cli.config)?)?; - let _guard = log::init_tracing(&config.log.as_ref().cloned().unwrap_or_default()); - let state = build_state(&mut config).await?; - - run_server( - build_router(state, &config.server.path), - &config.server.listen, - ) - .await -} - -fn gen_token(secret: &str, user: String, exp: u64) -> anyhow::Result<()> { - let token = jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - &Claims { sub: user, exp }, - &jsonwebtoken::EncodingKey::from_secret(secret.as_bytes()), - )?; - println!("{token}"); - Ok(()) -} - -async fn build_state(config: &mut Config) -> anyhow::Result> { - let (dns_client, client_subnet) = match config.dns { - Some(ref mut dc) => { - let mut dc = dc.clone(); - let client = dns::init_dns(&mut dc).await?; - (Some(client), dc.options.client_subnet) - } - None => (None, None), - }; - - Ok(Arc::new(AppState { - decoding_key: DecodingKey::from_secret(config.auth.secret.as_bytes()), - jwt_validation: Validation::default(), - socks5_proxy: config - .proxy - .as_ref() - .and_then(|p| p.socks5.as_deref()) - .map(Arc::from), - dns_client, - client_subnet, - traffic_config: config.traffic_shaping.clone(), - })) -} - -fn build_router(state: Arc, path: &str) -> Router { - use tracing::field::Empty; - - Router::new() - .route(path, post(tunnel_handler)) - .layer( - ServiceBuilder::new().layer(TraceLayer::new_for_http().make_span_with( - |req: &axum::http::Request| { - let id = NEXT_STREAM_ID.fetch_add(1, Ordering::Relaxed); - let client = req - .headers() - .get("X-Forwarded-For") - .and_then(|h| h.to_str().ok()) - .unwrap_or("-"); - tracing::error_span!("session", id, client, user = Empty, target = Empty) - }, - )), - ) - .with_state(state) -} - -async fn run_server(app: Router, listen: &str) -> anyhow::Result<()> { - #[cfg(unix)] - if listen.contains('/') || listen.ends_with(".sock") { - let path = std::path::Path::new(listen); - if path.exists() { - fs::remove_file(path)?; - } - let listener = tokio::net::UnixListener::bind(path)?; - fs::set_permissions(path, fs::Permissions::from_mode(0o666))?; - info!("listening on unix:{listen}"); - return Ok(axum::serve(listener, app.into_make_service()).await?); - } - - let addr: SocketAddr = listen.parse().context("invalid bind address")?; - info!("listening on {addr}"); - let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, app).await?; - Ok(()) -} - -#[inline] -fn validate_jwt( - headers: &HeaderMap, - key: &DecodingKey, - validation: &Validation, -) -> Result { - let token = headers - .get("Authorization") - .and_then(|h| h.to_str().ok()) - .and_then(|s| s.strip_prefix("Bearer ")) - .ok_or_else(|| { - warn!("rejected: missing or invalid authorization header"); - AppError::unauthorized("invalid header") - })?; - - jsonwebtoken::decode::(token, key, validation) - .map(|td| td.claims.sub) - .map_err(|_| { - warn!("rejected: invalid token"); - AppError::unauthorized("invalid token") - }) -} - -async fn connect_upstream( - dns_client: Option<&Arc>, - client_subnet: Option, - socks5_proxy: Option<&Arc>, - host: &str, - port: u16, -) -> Result { - if let Some(client) = dns_client { - return client - .connect( - host, - port, - client_subnet, - socks5_proxy.map(|s| s.to_string()), - ) - .await - .map_err(|e| format!("dns error: {e}")); - } - - match socks5_proxy { - Some(p) => Socks5Stream::connect(p.as_ref(), (host, port)) - .await - .map(Socks5Stream::into_inner) - .map_err(|e| e.to_string()), - None => TcpStream::connect((host, port)) - .await - .map_err(|e| e.to_string()), - } -} - -async fn tunnel_handler( - State(state): State>, - headers: HeaderMap, - Query(query): Query, - body: Body, -) -> Result { - let span = tracing::Span::current(); - span.record( - "user", - validate_jwt(&headers, &state.decoding_key, &state.jwt_validation)?, - ); - span.record("target", &query.target); - - let auth = query - .target - .parse::() - .map_err(|_| AppError::bad_request("invalid target format"))?; - - let host = auth.host(); - let port = auth - .port_u16() - .ok_or_else(|| AppError::bad_request("port required"))?; - - info!("connecting"); - - let upstream = tokio::time::timeout( - Duration::from_secs(10), - connect_upstream( - state.dns_client.as_ref(), - state.client_subnet, - state.socks5_proxy.as_ref(), - host, - port, - ), - ) - .await - .map_err(|_| AppError::gateway_timeout("connect timeout"))? - .map_err(AppError::bad_gateway)?; - - upstream.set_nodelay(true)?; - - let (upstream_read, mut upstream_write) = upstream.into_split(); - - tokio::spawn( - async move { - let mut stream = body.into_data_stream(); - let mut buf = BytesMut::with_capacity(DECODE_BUF_CAPACITY); - while let Some(chunk) = stream.next().await { - let data = chunk.context("stream error")?; - buf.extend_from_slice(&data); - while let Some(decoded) = shaper::TrafficShaper::decode_from_buffer(&mut buf)? { - upstream_write.write_all(&decoded).await?; - } - } - upstream_write.shutdown().await?; - Ok::<(), anyhow::Error>(()) - } - .instrument(tracing::Span::current()), - ); - - let shaper_stream = shaper::TrafficShaper::new(upstream_read, state.traffic_config.clone()); - let padding_len = rand::rng().random_range(30..=PADDING_POOL.len()); - - Ok(( - [ - ("Cache-Control", b"no-store" as &[u8]), - ("X-Padding", &PADDING_POOL[..padding_len]), - ], - Body::from_stream(shaper_stream), - )) -} diff --git a/src/server/connection.rs b/src/server/connection.rs new file mode 100644 index 0000000..aeee54a --- /dev/null +++ b/src/server/connection.rs @@ -0,0 +1,157 @@ +use bytes::Bytes; +use std::{cmp::Ordering, collections::BTreeMap, net::IpAddr, sync::Arc, time::Duration}; +use tokio::{ + net::TcpStream, + sync::{mpsc, oneshot}, +}; +use tokio_socks::tcp::Socks5Stream; +use tracing::{info, warn}; + +use tokio::io::AsyncWriteExt; + +use crate::dns::DnsClient; +use crate::server::constants::{ + MAX_PENDING_BYTES, MAX_PENDING_FRAMES, MAX_REORDER_SECS, WRITE_TIMEOUT, +}; +use crate::server::state::{FrameOrEos, UploadStream}; + +pub async fn connect_upstream( + dns_client: Option<&Arc>, + client_subnet: Option, + socks5_proxy: Option<&Arc>, + host: &str, + port: u16, +) -> Result { + if let Some(client) = dns_client { + return client + .connect( + host, + port, + client_subnet, + socks5_proxy.map(|s| s.to_string()), + ) + .await + .map_err(|e| format!("dns error: {e}")); + } + match socks5_proxy { + Some(p) => Socks5Stream::connect(p.as_ref(), (host, port)) + .await + .map(Socks5Stream::into_inner) + .map_err(|e| e.to_string()), + None => TcpStream::connect((host, port)) + .await + .map_err(|e| e.to_string()), + } +} + +pub async fn ordered_frame_writer( + mut rx: mpsc::Receiver, + mut upstream_write: tokio::net::tcp::OwnedWriteHalf, + stream_key: String, + stream: Arc, + initial_seq: u64, +) { + let mut next_seq: u64 = initial_seq; + let mut pending: BTreeMap = BTreeMap::new(); + let mut pending_bytes: usize = 0; + let mut eos_waiters: BTreeMap>> = BTreeMap::new(); + + 'main: loop { + tokio::select! { + cmd = rx.recv() => { + match cmd { + Some(FrameOrEos::Data { seq, data }) => { + match seq.cmp(&next_seq) { + Ordering::Less => { + warn!(stream_id = %stream_key, seq, "stale frame discarded"); + } + Ordering::Equal => { + let mut expected = next_seq + 1; + let mut run = vec![data]; + + while let Some(d) = pending.remove(&expected) { + pending_bytes -= d.len(); + run.push(d); + expected += 1; + } + next_seq = expected; + + for buf in run { + if let Err(e) = tokio::time::timeout(WRITE_TIMEOUT, upstream_write.write_all(&buf)).await { + warn!(stream_id = %stream_key, reason = %e, "upstream write failed"); + break 'main; + } + } + notify_eos(&mut eos_waiters, next_seq); + } + Ordering::Greater => { + if pending.contains_key(&seq) { + warn!(stream_id = %stream_key, seq, "duplicate pending frame discarded"); + continue; + } + let data_len = data.len(); + if pending.len() >= MAX_PENDING_FRAMES || + pending_bytes + data_len > MAX_PENDING_BYTES { + warn!(stream_id = %stream_key, "reorder buffer full"); + break; + } + pending.insert(seq, data); + pending_bytes += data_len; + } + } + } + Some(FrameOrEos::Eos { max_seq, done }) => { + if next_seq > max_seq { + let _ = done.send(()); + } else { + eos_waiters.entry(max_seq) + .or_default() + .push(done); + } + } + None => break + } + } + _ = stream.shutdown.notified() => { + info!(stream_id = %stream_key, "shutdown received, exiting"); + break; + }, + _ = tokio::time::sleep(Duration::from_secs(MAX_REORDER_SECS)), if !pending.is_empty() || !eos_waiters.is_empty() => { + warn!(stream_id = %stream_key, next_seq, + pending_frames = pending.len(), eos_waiters = eos_waiters.len(), + "reorder timeout"); + break; + } + } + } + + stream.do_shutdown(); + + while let Ok(cmd) = rx.try_recv() { + if let FrameOrEos::Eos { done, .. } = cmd { + let _ = done.send(()); + } + } + + for (_, waiters) in eos_waiters { + for sender in waiters { + let _ = sender.send(()); + } + } + + let _ = upstream_write.shutdown().await; + info!(stream_id = %stream_key, "frame writer exited"); +} + +#[inline] +fn notify_eos(eos_waiters: &mut BTreeMap>>, next_seq: u64) { + while let Some(entry) = eos_waiters.first_entry() { + if *entry.key() >= next_seq { + break; + } + let (_, senders) = entry.remove_entry(); + for sender in senders { + let _ = sender.send(()); + } + } +} diff --git a/src/server/constants.rs b/src/server/constants.rs new file mode 100644 index 0000000..059dbd9 --- /dev/null +++ b/src/server/constants.rs @@ -0,0 +1,29 @@ +use std::{ + sync::LazyLock, + time::{Duration, Instant}, +}; + +pub const MAX_UPLOAD_BODY_SIZE: usize = 1024 * 1024; + +pub const MAX_PENDING_BYTES: usize = 2 * 1024 * 1024; +pub const MAX_PENDING_FRAMES: usize = 8 * 1024; +pub const MAX_REORDER_SECS: u64 = 10; + +pub const STREAM_IDLE_TIMEOUT_SECS: u64 = 120; + +pub const JANITOR_INTERVAL: Duration = Duration::from_secs(30); +pub const NONCE_CLEANUP_INTERVAL: Duration = Duration::from_secs(60); + +pub const CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +pub const WRITE_TIMEOUT: Duration = Duration::from_secs(10); + +pub const MASTER_EXPIRY: Duration = Duration::from_secs(1200); + +pub const PADDING_POOL: &[u8] = b"padding=XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"; + +pub static START: LazyLock = LazyLock::new(Instant::now); + +#[inline(always)] +pub fn now_secs() -> u64 { + START.elapsed().as_secs() +} diff --git a/src/server/handlers.rs b/src/server/handlers.rs new file mode 100644 index 0000000..db3a7bb --- /dev/null +++ b/src/server/handlers.rs @@ -0,0 +1,616 @@ +use axum::{body::Body, extract::State, http::HeaderMap, response::Response}; +use base64::Engine; +use bytes::{Bytes, BytesMut}; +use futures::StreamExt; +use jsonwebtoken::{DecodingKey, Validation}; +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tokio::sync::{mpsc, oneshot}; +use tracing::{Instrument, info, warn}; +use uuid; +use zeroize::Zeroizing; + +use crate::crypto::{self, AesFrameCipher, AesKey}; +use crate::error::ServerError; +use crate::server::constants::{CONNECT_TIMEOUT, MASTER_EXPIRY, MAX_UPLOAD_BODY_SIZE}; +use crate::server::{ + connection::{self, connect_upstream}, + state::{DownloadStream, FrameOrEos, UploadStream}, + utils, +}; +use crate::shaper::{self, FrameCipher}; + +use super::AppState; + +pub async fn dispatch( + State(state): State>, + headers: HeaderMap, + body: Body, +) -> Result { + let span = tracing::Span::current(); + + let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) + .await + .map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; + + let has_x_target = headers.get("X-Target").is_some(); + let session_cookie = utils::extract_cookie_value(&headers, "session"); + + if has_x_target { + return handle_plaintext_download(state, headers, body_bytes, span).await; + } + + if session_cookie.is_none() { + return handle_fresh_handshake(state, headers, Body::from(body_bytes), span).await; + } + + let cookie_val = session_cookie.unwrap(); + let session_id = cookie_val.split(':').next().unwrap_or(cookie_val); + let is_pq = state.master_store.get(session_id).is_some(); + + if !is_pq { + return Err(ServerError::precondition_required("session not found")); + } + + if !body_bytes.is_empty() && !state.streams.contains_key(cookie_val) { + return handle_pq_download(state, cookie_val, body_bytes, span).await; + } + + if body_bytes.is_empty() { + handle_pq_download(state, cookie_val, body_bytes, span).await + } else { + let user = &state + .master_store + .get(session_id) + .map(|e| e.value().0.clone()) + .unwrap_or_default(); + span.record("user", user); + let upload_body = Body::from(body_bytes); + handle_stream_upload(state, cookie_val.to_owned(), upload_body, span).await + } +} + +async fn handle_plaintext_download( + state: Arc, + headers: HeaderMap, + early_data: Bytes, + span: tracing::Span, +) -> Result { + let user = validate_jwt_if_needed(&headers, false, &state.decoding_key, &state.jwt_validation)?; + span.record("user", &user); + + let target = headers + .get("X-Target") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| ServerError::bad_request("missing X-Target header"))?; + + span.record("target", target); + info!(user = %user, target = %target, body_len = %early_data.len(), "connection initiated"); + + let (host, port_str) = target + .rsplit_once(':') + .ok_or_else(|| ServerError::bad_request("target must be host:port"))?; + let port: u16 = port_str + .parse() + .map_err(|_| ServerError::bad_request("invalid port"))?; + + let mut upstream = tokio::time::timeout( + CONNECT_TIMEOUT, + connect_upstream( + state.dns_client.as_ref(), + state.client_subnet, + state.socks5_proxy.as_ref(), + host, + port, + ), + ) + .await + .map_err(|_| ServerError::gateway_timeout("connect timeout"))? + .map_err(ServerError::bad_gateway)?; + + upstream.set_nodelay(true)?; + + let frames_written: u64 = if !early_data.is_empty() { + let mut buf = BytesMut::from(&early_data[..]); + let mut count: u64 = 0; + while let Some((_seq, data)) = + shaper::decode_from_buffer(&mut buf, None, state.traffic_config.encoding_type)? + { + upstream.write_all(&data).await.map_err(|e| { + ServerError::bad_gateway(format!("initial upload write error: {e}")) + })?; + count += 1; + } + if !buf.is_empty() { + return Err(ServerError::bad_request( + "trailing data in initial upload body", + )); + } + count + } else { + 0 + }; + + let (upstream_read, upstream_write) = upstream.into_split(); + let session_id = utils::extract_cookie_value(&headers, "session") + .map(|s| s.to_owned()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let (frame_tx, frame_rx) = mpsc::channel::(1); + + let stream = Arc::new(UploadStream::new(frame_tx, None)); + match state.streams.entry(session_id.clone()) { + dashmap::mapref::entry::Entry::Occupied(_) => { + return Err(ServerError::bad_request("stream already exists")); + } + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&stream)); + } + } + + info!(stream_id = %session_id, user = %user, target = %target, + initial_frames = %frames_written, "stream established"); + + tokio::spawn( + connection::ordered_frame_writer( + frame_rx, + upstream_write, + session_id.clone(), + Arc::clone(&stream), + frames_written, + ) + .instrument(tracing::Span::current()), + ); + + let traffic_config = Arc::clone(&state.traffic_config); + let shaper = + crate::shaper::TrafficShaper::with_seq(upstream_read, (*traffic_config).clone(), None, 0); + + let shutdown_fut = { + let notify = Arc::clone(&stream.shutdown); + Box::pin(async move { notify.notified().await }) + }; + + let download_stream = DownloadStream { + shaper: Box::pin(shaper), + stream: Arc::clone(&stream), + streams: Arc::clone(&state.streams), + map_key: session_id.clone(), + log_key: session_id, + shutdown_fut: Some(shutdown_fut), + done: false, + }; + + let padding = utils::random_padding(); + let resp = Response::builder() + .header("Cache-Control", "no-store") + .header("Set-Cookie", padding) + .body(Body::from_stream(download_stream)) + .map_err(|e| ServerError::internal(e.to_string()))?; + + Ok(resp) +} + +async fn handle_fresh_handshake( + state: Arc, + headers: HeaderMap, + body: Body, + span: tracing::Span, +) -> Result { + let user = validate_jwt_if_needed(&headers, false, &state.decoding_key, &state.jwt_validation)?; + span.record("user", &user); + + info!(user = %user, "handshake: received ClientHello"); + + let eph_pk_a_b64 = utils::extract_cookie_value(&headers, "eph_pk_a") + .ok_or_else(|| ServerError::bad_request("missing eph_pk_a cookie"))?; + let eph_pk_a_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(eph_pk_a_b64.as_bytes()) + .map_err(|_| ServerError::bad_request("invalid eph_pk_a base64"))?; + if eph_pk_a_bytes.len() != 32 { + return Err(ServerError::bad_request("eph_pk_a must be 32 bytes")); + } + let eph_pk_a_arr: [u8; 32] = eph_pk_a_bytes[..].try_into().unwrap(); + let eph_pk_a = x25519_dalek::PublicKey::from(eph_pk_a_arr); + + let private_key = state + .private_key + .as_ref() + .ok_or_else(|| ServerError::internal("server private key not configured"))?; + + let shared_a = crypto::diffie_hellman(private_key, &eph_pk_a); + let handshake_key = crypto::derive_handshake_key(&*shared_a); + let handshake_cipher = AesFrameCipher::new(AesKey::from(*handshake_key)); + + let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) + .await + .map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; + + let mut buf = BytesMut::from(&body_bytes[..]); + let Some((_seq, client_hello_data)) = shaper::decode_from_buffer( + &mut buf, + Some(&handshake_cipher as &dyn FrameCipher), + state.traffic_config.encoding_type, + )? + else { + return Err(ServerError::bad_request("invalid ClientHello frame")); + }; + + if client_hello_data.len() < 2 { + return Err(ServerError::bad_request("ClientHello too short")); + } + + let len_kem = u16::from_be_bytes([client_hello_data[0], client_hello_data[1]]) as usize; + let kem_end = 2 + len_kem; + if client_hello_data.len() < kem_end { + return Err(ServerError::bad_request("ClientHello truncated (KEM part)")); + } + let kem_pk_bytes = &client_hello_data[2..kem_end]; + let kem_pk = crypto::bytes_to_encapsulation_key(kem_pk_bytes) + .map_err(|e| ServerError::bad_request(format!("invalid mlkem public key: {e}")))?; + + if client_hello_data.len() < kem_end + 2 { + return Err(ServerError::bad_request( + "ClientHello truncated (no X25519 part)", + )); + } + let len_x25519 = + u16::from_be_bytes([client_hello_data[kem_end], client_hello_data[kem_end + 1]]) as usize; + let x25519_end = kem_end + 2 + len_x25519; + if client_hello_data.len() < x25519_end { + return Err(ServerError::bad_request( + "ClientHello truncated (X25519 part)", + )); + } + let eph_pk_b_bytes = &client_hello_data[kem_end + 2..x25519_end]; + let eph_pk_b: [u8; 32] = eph_pk_b_bytes + .try_into() + .map_err(|_| ServerError::bad_request("invalid client x25519 pk length"))?; + let client_eph_pk_b = x25519_dalek::PublicKey::from(eph_pk_b); + + let (ct, ss_mlkem) = crypto::mlkem_encapsulate(&kem_pk); + let (server_eph_sk, server_eph_pk) = crypto::generate_keypair(); + + let master = { + let server_eph_sk = Zeroizing::new(server_eph_sk); + let ss_x25519 = crypto::diffie_hellman(&server_eph_sk, &client_eph_pk_b); + crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519) + }; + + let session_id = uuid::Uuid::new_v4().to_string(); + state.master_store.insert( + session_id.clone(), + (user.clone(), master, std::time::Instant::now()), + ); + + info!(session_id = %session_id, user = %user, "handshake: master key derived"); + + let ct_bytes: &[u8] = &ct; + let ct_bytes = ct_bytes.to_vec(); + let sid_bytes = session_id.as_bytes(); + let mut server_hello = Vec::with_capacity(2 + sid_bytes.len() + ct_bytes.len() + 32); + server_hello.extend_from_slice(&(sid_bytes.len() as u16).to_be_bytes()); + server_hello.extend_from_slice(sid_bytes); + server_hello.extend_from_slice(&ct_bytes); + server_hello.extend_from_slice(server_eph_pk.as_bytes()); + + let server_hello_frame = shaper::encode_frame( + &server_hello, + 0, + Some(&handshake_cipher as &dyn FrameCipher), + &state.traffic_config, + ) + .map_err(|e| ServerError::internal(format!("encode ServerHello: {e}")))?; + + drop(handshake_key); + + info!(session_id = %session_id, "handshake: ServerHello sent"); + + let padding = utils::random_padding(); + let resp = Response::builder() + .header("Cache-Control", "no-store") + .header("Set-Cookie", padding) + .body(Body::from(server_hello_frame)) + .map_err(|e| ServerError::internal(e.to_string()))?; + + Ok(resp) +} + +async fn handle_pq_download( + state: Arc, + cookie_val: &str, + early_data: Bytes, + span: tracing::Span, +) -> Result { + let parts: Vec<&str> = cookie_val.splitn(3, ':').collect(); + if parts.len() != 3 { + return Err(ServerError::bad_request("invalid session cookie format")); + } + let (session_id, enc_target_b64, enc_nonce_b64) = (parts[0], parts[1], parts[2]); + info!(session_id = %session_id, body_len = %early_data.len(), + "session resumption: download request received"); + + let entry = state + .master_store + .get(session_id) + .ok_or_else(|| ServerError::precondition_required("session not found"))?; + + let value_ref = entry.value(); + let mut master = Zeroizing::new([0u8; 32]); + let (username, master_z, created) = value_ref; + master.copy_from_slice(&**master_z); + let username = username.clone(); + let created = *created; + if created.elapsed() > MASTER_EXPIRY { + drop(entry); + state.master_store.remove(session_id); + return Err(ServerError::precondition_required("master key expired")); + } + + let cookie_master_key = crypto::derive_cookie_master_key(&*master); + + let enc_target = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(enc_target_b64) + .map_err(|_| ServerError::bad_request("invalid cookie encoding"))?; + let target_bytes = crypto::decrypt_bytes(&cookie_master_key, &enc_target) + .map_err(|_| ServerError::bad_request("failed to decrypt target"))?; + let target = String::from_utf8(target_bytes) + .map_err(|_| ServerError::bad_request("invalid target utf8"))?; + + let enc_nonce = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(enc_nonce_b64) + .map_err(|_| ServerError::bad_request("invalid cookie encoding"))?; + let conn_nonce_bytes = crypto::decrypt_bytes(&cookie_master_key, &enc_nonce) + .map_err(|_| ServerError::bad_request("failed to decrypt conn_nonce"))?; + let conn_nonce: [u8; 16] = conn_nonce_bytes + .try_into() + .map_err(|_| ServerError::bad_request("invalid conn_nonce length"))?; + + { + let nonce_set = state.used_nonces.entry(session_id.to_string()).or_default(); + if !nonce_set.insert(conn_nonce) { + return Err(ServerError::precondition_required( + "nonce already used (replay detected)", + )); + } + } + + drop(entry); + + let (upload_key, download_key) = crypto::derive_connection_keys(&*master, &conn_nonce); + let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); + let download_cipher = Arc::new(AesFrameCipher::new(download_key)); + + let (host, port_str) = target + .rsplit_once(':') + .ok_or_else(|| ServerError::bad_request("target must be host:port"))?; + let port: u16 = port_str + .parse() + .map_err(|_| ServerError::bad_request("invalid port"))?; + + span.record("user", &username); + span.record("target", &target); + + info!(session_id = %session_id, user = %username, "session resumption: connecting to {}", target); + + let mut upstream = tokio::time::timeout( + CONNECT_TIMEOUT, + connect_upstream( + state.dns_client.as_ref(), + state.client_subnet, + state.socks5_proxy.as_ref(), + host, + port, + ), + ) + .await + .map_err(|_| ServerError::gateway_timeout("connect timeout"))? + .map_err(ServerError::bad_gateway)?; + + upstream.set_nodelay(true)?; + + let upload_cipher_ref: &dyn FrameCipher = upload_cipher.as_ref() as &dyn FrameCipher; + let frames_written: u64 = if !early_data.is_empty() { + let mut buf = BytesMut::from(&early_data[..]); + let mut count: u64 = 0; + while let Some((_seq, data)) = shaper::decode_from_buffer( + &mut buf, + Some(upload_cipher_ref), + state.traffic_config.encoding_type, + )? { + upstream.write_all(&data).await.map_err(|e| { + ServerError::bad_gateway(format!("initial upload write error: {e}")) + })?; + count += 1; + } + if !buf.is_empty() { + return Err(ServerError::bad_request( + "trailing data in initial upload body", + )); + } + count + } else { + 0 + }; + + let (upstream_read, upstream_write) = upstream.into_split(); + let (frame_tx, frame_rx) = mpsc::channel::(1); + + let stream = Arc::new(UploadStream::new(frame_tx, Some(upload_cipher))); + match state.streams.entry(cookie_val.to_owned()) { + dashmap::mapref::entry::Entry::Occupied(_) => { + return Err(ServerError::bad_request("stream already exists")); + } + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&stream)); + } + } + + let display_key = session_id.to_owned(); + + tokio::spawn( + connection::ordered_frame_writer( + frame_rx, + upstream_write, + display_key, + Arc::clone(&stream), + frames_written, + ) + .instrument(tracing::Span::current()), + ); + + let download_cipher: Option> = + Some(download_cipher as Arc); + let traffic_config = Arc::clone(&state.traffic_config); + let shaper = crate::shaper::TrafficShaper::with_seq( + upstream_read, + (*traffic_config).clone(), + download_cipher, + 0, + ); + + let shutdown_fut = { + let notify = Arc::clone(&stream.shutdown); + Box::pin(async move { notify.notified().await }) + }; + + let download_stream = DownloadStream { + shaper: Box::pin(shaper), + stream: Arc::clone(&stream), + streams: Arc::clone(&state.streams), + map_key: cookie_val.to_owned(), + log_key: session_id.to_owned(), + shutdown_fut: Some(shutdown_fut), + done: false, + }; + + let padding = utils::random_padding(); + let resp = Response::builder() + .header("Cache-Control", "no-store") + .header("Set-Cookie", padding) + .body(Body::from_stream(download_stream)) + .map_err(|e| ServerError::internal(e.to_string()))?; + + Ok(resp) +} + +async fn handle_stream_upload( + state: Arc, + cookie_val: String, + body: Body, + _span: tracing::Span, +) -> Result { + let stream = state + .streams + .get(&cookie_val) + .map(|r| Arc::clone(r.value())) + .ok_or_else(|| ServerError::not_found("unknown upload stream"))?; + + let session_id = cookie_val.split(':').next().unwrap_or(&cookie_val); + if let Some(entry) = state.master_store.get(session_id) { + let user = &entry.value().0; + tracing::Span::current().record("user", user); + } + + let cipher_ref: Option<&dyn FrameCipher> = stream + .upload_cipher + .as_deref() + .map(|c| c as &dyn FrameCipher); + let encoding_type = state.traffic_config.encoding_type; + + let mut body = body.into_data_stream(); + let mut buf = BytesMut::with_capacity(8192); + let mut total_read = 0usize; + let mut max_seq = 0u64; + + while let Some(chunk) = body.next().await { + let chunk = + chunk.map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; + total_read += chunk.len(); + if total_read > MAX_UPLOAD_BODY_SIZE { + return Err(ServerError::payload_too_large( + "body exceeds max upload size", + )); + } + buf.extend_from_slice(&chunk); + + loop { + let Some((seq, data)) = + shaper::decode_from_buffer(&mut buf, cipher_ref, encoding_type)? + else { + break; + }; + if seq > max_seq { + max_seq = seq; + } + stream + .tx + .send(FrameOrEos::Data { seq, data }) + .await + .map_err(|_| ServerError::bad_gateway("upload channel closed"))?; + } + } + + if !buf.is_empty() { + return Err(ServerError::bad_request("incomplete frame in batch body")); + } + + let (done_tx, done_rx) = oneshot::channel(); + stream + .tx + .send(FrameOrEos::Eos { + max_seq, + done: done_tx, + }) + .await + .map_err(|_| ServerError::bad_gateway("upload channel closed"))?; + + done_rx + .await + .map_err(|_| ServerError::bad_gateway("upload stream closed"))?; + + stream.touch(); + + let padding = utils::random_padding(); + let resp = Response::builder() + .header("Cache-Control", "no-store") + .header("Set-Cookie", padding) + .status(axum::http::StatusCode::NO_CONTENT) + .body(Body::empty()) + .map_err(|e| ServerError::internal(e.to_string()))?; + + Ok(resp) +} + +#[inline] +pub fn validate_jwt_if_needed( + headers: &HeaderMap, + has_valid_session: bool, + key: &DecodingKey, + validation: &Validation, +) -> Result { + if has_valid_session { + return Ok("session-resumed".into()); + } + + let auth_header = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| { + warn!("rejected: missing or invalid authorization header"); + ServerError::unauthorized("invalid header") + })?; + + if !auth_header.starts_with("Bearer ") { + warn!("rejected: invalid authorization format"); + return Err(ServerError::unauthorized("invalid header")); + } + + let token = &auth_header[7..]; + + jsonwebtoken::decode::(token, key, validation) + .map(|td| td.claims.sub) + .map_err(|e| { + warn!("rejected: invalid token - {:?}", e); + ServerError::unauthorized("invalid token") + }) +} diff --git a/src/server/janitor.rs b/src/server/janitor.rs new file mode 100644 index 0000000..3c9bfa3 --- /dev/null +++ b/src/server/janitor.rs @@ -0,0 +1,45 @@ +use dashmap::{DashMap, DashSet}; +use std::sync::Arc; +use tracing::warn; + +use crate::server::constants::{JANITOR_INTERVAL, MASTER_EXPIRY, NONCE_CLEANUP_INTERVAL}; +use crate::server::state::UploadStream; + +pub async fn stream_janitor(streams: Arc>>) { + let mut interval = tokio::time::interval(JANITOR_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + interval.tick().await; + let mut expired = vec![]; + for entry in streams.iter() { + let stream = entry.value(); + if stream.is_idle() && stream.do_shutdown() { + expired.push(entry.key().clone()); + } + } + for key in expired { + streams.remove(&key); + let display_id = key.split(':').next().unwrap_or(&key); + warn!(stream_id = %display_id, reason = "idle timeout", "shutting down idle stream"); + } + } +} + +pub async fn master_and_nonce_janitor( + master_store: Arc, std::time::Instant)>>, + used_nonces: Arc>>, +) { + let mut interval = tokio::time::interval(NONCE_CLEANUP_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + interval.tick().await; + master_store.retain(|session_id, (_, _master, created)| { + if created.elapsed() >= MASTER_EXPIRY { + used_nonces.remove(session_id); + false + } else { + true + } + }); + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..6fef97d --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,179 @@ +pub mod connection; +pub mod constants; +pub mod handlers; +pub mod janitor; +pub mod state; +pub mod utils; + +use crate::config::ServerTopConfig; +use crate::crypto; +use crate::dns::{self, DnsClient}; +use crate::shaper::TrafficConfig; + +use anyhow::Context; +use axum::{Router, body::Body, routing::post}; +use dashmap::{DashMap, DashSet}; +use jsonwebtoken::{Algorithm, DecodingKey, Validation}; +use serde::{Deserialize, Serialize}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::Instant, +}; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; +use tracing::info; +use zeroize::Zeroizing; + +pub type MasterStoreEntry = (String, Zeroizing<[u8; 32]>, Instant); + +pub static NEXT_STREAM_ID: AtomicU64 = AtomicU64::new(1); + +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub sub: String, + pub exp: u64, +} + +#[derive(Clone)] +pub struct AppState { + pub decoding_key: DecodingKey, + pub jwt_validation: Validation, + pub socks5_proxy: Option>, + pub dns_client: Option>, + pub client_subnet: Option, + pub traffic_config: Arc, + pub streams: Arc>>, + pub private_key: Option, + pub master_store: Arc>, + pub used_nonces: Arc>>, +} + +pub async fn build_state(config: &mut ServerTopConfig) -> anyhow::Result> { + let (dns_client, client_subnet) = match config.dns { + Some(ref mut dc) => { + let mut dc = dc.clone(); + let client = dns::init_dns(&mut dc).await?; + (Some(client), dc.options.client_subnet) + } + None => (None, None), + }; + + let private_key = config + .server + .private_key + .as_deref() + .map(crypto::b64_to_private_key) + .transpose()?; + + Ok(Arc::new(AppState { + decoding_key: DecodingKey::from_secret(config.auth.secret.as_bytes()), + jwt_validation: { + let mut v = Validation::new(Algorithm::HS256); + v.validate_exp = true; + v + }, + socks5_proxy: config + .proxy + .as_ref() + .and_then(|p| p.socks5.as_deref()) + .map(Arc::from), + dns_client, + client_subnet, + traffic_config: Arc::new(config.traffic_shaping.clone()), + streams: Arc::new(DashMap::new()), + private_key, + master_store: Arc::new(DashMap::new()), + used_nonces: Arc::new(DashMap::new()), + })) +} + +pub fn build_router(state: Arc, path: &str) -> Router { + use tracing::field::Empty; + Router::new() + .route(path, post(handlers::dispatch)) + .layer( + ServiceBuilder::new().layer(TraceLayer::new_for_http().make_span_with( + |req: &axum::http::Request| { + let id = NEXT_STREAM_ID.fetch_add(1, Ordering::Relaxed); + let client = req + .headers() + .get("X-Forwarded-For") + .and_then(|h| h.to_str().ok()) + .unwrap_or("-"); + tracing::error_span!("session", id, client, user = Empty, target = Empty) + }, + )), + ) + .with_state(state) +} + +pub async fn run_server(app: Router, listen: &str) -> anyhow::Result<()> { + #[cfg(unix)] + if listen.contains('/') || listen.ends_with(".sock") { + let path = std::path::Path::new(listen); + if path.exists() { + std::fs::remove_file(path)?; + } + let listener = tokio::net::UnixListener::bind(path)?; + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o666))?; + info!("listening on unix:{listen}"); + return Ok(axum::serve(listener, app.into_make_service()).await?); + } + let addr: SocketAddr = listen.parse().context("invalid bind address")?; + info!("listening on {addr}"); + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + Ok(()) +} + +pub fn spawn_janitors( + state: &Arc, +) -> (tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>) { + let stream_handle = tokio::spawn(janitor::stream_janitor(Arc::clone(&state.streams))); + let master_handle = tokio::spawn(janitor::master_and_nonce_janitor( + Arc::clone(&state.master_store), + Arc::clone(&state.used_nonces), + )); + (stream_handle, master_handle) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{AuthSection, ServerSection}; + use crate::shaper; + + #[test] + fn build_state_with_minimal_config() { + let mut cfg = ServerTopConfig { + server: ServerSection { + listen: "0.0.0.0:0".into(), + path: "/t".into(), + private_key: None, + }, + auth: AuthSection { + secret: "test-key".into(), + }, + proxy: None, + log: None, + dns: None, + traffic_shaping: TrafficConfig { + global: shaper::PaddingConfig { + padding_threshold: 100, + padding_range: [0, 50], + }, + stages: vec![], + encoding_type: Default::default(), + }, + }; + let rt = tokio::runtime::Runtime::new().unwrap(); + let state = rt.block_on(build_state(&mut cfg)).unwrap(); + assert!(state.private_key.is_none()); + assert!(state.dns_client.is_none()); + } +} diff --git a/src/server/state.rs b/src/server/state.rs new file mode 100644 index 0000000..5f91b47 --- /dev/null +++ b/src/server/state.rs @@ -0,0 +1,129 @@ +use bytes::Bytes; +use dashmap::DashMap; +use futures::{Future, Stream}; +use std::{ + pin::Pin, + sync::{ + Arc, + atomic::{AtomicBool, AtomicU64, Ordering}, + }, + task::Poll, +}; +use tokio::sync::{Notify, mpsc, oneshot}; +use tracing::{info, warn}; + +use crate::crypto::AesFrameCipher; +use crate::server::constants::{STREAM_IDLE_TIMEOUT_SECS, now_secs}; + +pub enum FrameOrEos { + Data { + seq: u64, + data: Bytes, + }, + Eos { + max_seq: u64, + done: oneshot::Sender<()>, + }, +} + +pub struct UploadStream { + pub last_activity: AtomicU64, + pub tx: mpsc::Sender, + pub upload_cipher: Option>, + pub shutdown: Arc, + shutdown_flag: AtomicBool, +} + +impl UploadStream { + #[inline] + pub fn new(tx: mpsc::Sender, upload_cipher: Option>) -> Self { + Self { + last_activity: AtomicU64::new(now_secs()), + tx, + upload_cipher, + shutdown: Arc::new(Notify::new()), + shutdown_flag: AtomicBool::new(false), + } + } + #[inline(always)] + pub fn touch(&self) { + self.last_activity.store(now_secs(), Ordering::Relaxed); + } + #[inline(always)] + pub fn is_idle(&self) -> bool { + now_secs().saturating_sub(self.last_activity.load(Ordering::Relaxed)) + > STREAM_IDLE_TIMEOUT_SECS + } + #[inline(always)] + pub fn do_shutdown(&self) -> bool { + if self.shutdown_flag.swap(true, Ordering::AcqRel) { + false + } else { + self.shutdown.notify_one(); + true + } + } +} + +type ShaperStream = Pin> + Send>>; + +pub struct DownloadStream { + pub shaper: ShaperStream, + pub stream: Arc, + pub streams: Arc>>, + pub map_key: String, + pub log_key: String, + pub shutdown_fut: Option + Send>>>, + pub done: bool, +} + +impl Stream for DownloadStream { + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.done { + return Poll::Ready(None); + } + + if let Some(fut) = this.shutdown_fut.as_mut() + && fut.as_mut().poll(cx).is_ready() + { + info!(stream_id = %this.log_key, reason = "shutdown signal", "download stream ended"); + this.done = true; + this.shutdown_fut = None; + return Poll::Ready(None); + } + + match this.shaper.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok((_seq, data)))) => { + this.stream.touch(); + Poll::Ready(Some(Ok(data))) + } + Poll::Ready(Some(Err(e))) => { + warn!(stream_id = %this.log_key, error = %e, "upstream read error"); + this.done = true; + Poll::Ready(Some(Err(e))) + } + Poll::Ready(None) => { + info!(stream_id = %this.log_key, reason = "upstream closed", "download stream ended"); + this.done = true; + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for DownloadStream { + fn drop(&mut self) { + if !self.done { + info!(stream_id = %self.log_key, reason = "client disconnected", "download stream ended"); + } + self.stream.do_shutdown(); + self.streams.remove(&self.map_key); + } +} diff --git a/src/server/utils.rs b/src/server/utils.rs new file mode 100644 index 0000000..ef716f0 --- /dev/null +++ b/src/server/utils.rs @@ -0,0 +1,78 @@ +use rand::RngExt; + +use crate::server::constants::PADDING_POOL; + +#[inline] +pub fn extract_cookie_value<'a>(headers: &'a axum::http::HeaderMap, key: &str) -> Option<&'a str> { + let cookie_header = headers.get("Cookie")?.as_bytes(); + let cookie_str = std::str::from_utf8(cookie_header).ok()?; + let key_bytes = key.as_bytes(); + let key_len = key_bytes.len(); + + let mut pos = 0; + let haystack = cookie_str.as_bytes(); + let haystack_len = haystack.len(); + + while pos < haystack_len { + while pos < haystack_len && (haystack[pos] == b' ' || haystack[pos] == b';') { + pos += 1; + } + if pos >= haystack_len { + break; + } + + if pos + key_len < haystack_len + && &haystack[pos..pos + key_len] == key_bytes + && haystack[pos + key_len] == b'=' + { + let val_start = pos + key_len + 1; + let val_end = memchr::memchr(b';', &haystack[val_start..]) + .map(|i| val_start + i) + .unwrap_or(haystack_len); + let val = &cookie_str[val_start..val_end]; + return Some(val.trim()); + } + + match memchr::memchr(b';', &haystack[pos..]) { + Some(i) => pos += i + 1, + None => break, + } + } + None +} + +#[inline(always)] +pub fn random_padding() -> &'static [u8] { + let padding_len = rand::rng().random_range(30..=PADDING_POOL.len()); + &PADDING_POOL[..padding_len] +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::HeaderMap; + + #[test] + fn extract_cookie_basic() { + let mut headers = HeaderMap::new(); + headers.insert("Cookie", "session=abc123; other=val".parse().unwrap()); + assert_eq!(extract_cookie_value(&headers, "session"), Some("abc123")); + assert_eq!(extract_cookie_value(&headers, "other"), Some("val")); + assert_eq!(extract_cookie_value(&headers, "missing"), None); + } + + #[test] + fn extract_cookie_no_cookie_header() { + let headers = HeaderMap::new(); + assert_eq!(extract_cookie_value(&headers, "session"), None); + } + + #[test] + fn random_padding_length() { + for _ in 0..100 { + let p = random_padding(); + assert!(p.len() >= 30); + assert!(p.len() <= PADDING_POOL.len()); + } + } +} diff --git a/src/shaper.rs b/src/shaper.rs deleted file mode 100644 index 787f4d7..0000000 --- a/src/shaper.rs +++ /dev/null @@ -1,246 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use pin_project_lite::pin_project; -use rand::{Rng, RngExt, seq::SliceRandom}; -use rand_distr::{Distribution, Normal}; -use serde::Deserialize; -use std::{ - io::{Error, ErrorKind}, - pin::Pin, - sync::OnceLock, - task::{Context, Poll}, - time::Duration, -}; -use tokio::{ - io::{AsyncRead, ReadBuf}, - time::{Instant, Sleep}, -}; - -const TABLE_SIZE: usize = 1024; -const TABLE_MASK: usize = TABLE_SIZE - 1; -const CHUNK_SIZE: usize = 16 * 1024; -const HEADER_SIZE: usize = 4; -const MAX_PAYLOAD: usize = CHUNK_SIZE - HEADER_SIZE; -const AVG_LATENCY_MICROS: f64 = 5_000.0; - -static JITTER_TABLE: OnceLock> = OnceLock::new(); - -#[derive(Debug, Deserialize, Clone)] -pub struct PaddingConfig { - pub padding_threshold: usize, - pub padding_range: [usize; 2], -} - -#[derive(Debug, Deserialize, Clone)] -pub struct StageConfig { - pub count: Option, - pub count_range: Option<[usize; 2]>, - pub padding_threshold: usize, - pub padding_range: [usize; 2], -} - -#[derive(Debug, Deserialize, Clone)] -pub struct TrafficConfig { - pub global: PaddingConfig, - pub stages: Vec, -} - -#[derive(Debug, Clone, Copy)] -struct ResolvedStage { - end_count: usize, - padding_threshold: usize, - padding_range: [usize; 2], -} - -pin_project! { - #[project = TrafficShaperProj] - pub struct TrafficShaper { - #[pin] - reader: R, - frame_buffer: BytesMut, - #[pin] - flush_timer: Sleep, - cursor: usize, - stages: Vec, - global_threshold: usize, - global_range: [usize; 2], - packet_count: usize, - stage_idx: usize, - } -} - -impl TrafficShaper<()> { - pub fn decode_from_buffer(src: &mut BytesMut) -> Result, Error> { - if src.len() < HEADER_SIZE { - return Ok(None); - } - - let header = u32::from_be_bytes([src[0], src[1], src[2], src[3]]); - let actual_len = (header >> 16) as usize; - let total_len = (header & 0xFFFF) as usize; - - if total_len > MAX_PAYLOAD || actual_len > total_len { - return Err(Error::new(ErrorKind::InvalidData, "invalid frame size")); - } - - let full_frame_len = HEADER_SIZE + total_len; - if src.len() < full_frame_len { - return Ok(None); - } - - let mut frame = src.split_to(full_frame_len); - frame.advance(HEADER_SIZE); - frame.truncate(actual_len); - Ok(Some(frame.freeze())) - } -} - -impl TrafficShaper { - pub fn new(reader: R, config: TrafficConfig) -> Self { - let cursor = (rand::rng().next_u64() as usize) & TABLE_MASK; - - let mut stages: Vec = config - .stages - .iter() - .map(|s| ResolvedStage { - end_count: s - .count - .or_else(|| s.count_range.map(|[_, hi]| hi)) - .unwrap_or(0), - padding_threshold: s.padding_threshold, - padding_range: s.padding_range, - }) - .collect(); - stages.sort_unstable_by_key(|s| s.end_count); - - let mut frame_buffer = BytesMut::with_capacity(CHUNK_SIZE); - unsafe { frame_buffer.advance_mut(HEADER_SIZE) } - - Self { - reader, - frame_buffer, - flush_timer: tokio::time::sleep_until(Instant::now()), - stages, - global_threshold: config.global.padding_threshold, - global_range: config.global.padding_range, - packet_count: 0, - cursor, - stage_idx: 0, - } - } - - #[inline(always)] - fn prepare_next_frame(buf: &mut BytesMut) { - buf.reserve(CHUNK_SIZE); - unsafe { buf.advance_mut(HEADER_SIZE) } - } - - fn seal_and_emit(this: &mut TrafficShaperProj<'_, R>, actual_len: usize) -> Bytes { - *this.packet_count += 1; - - while let Some(stage) = this.stages.get(*this.stage_idx) { - if *this.packet_count <= stage.end_count { - break; - } - *this.stage_idx += 1; - } - - let (threshold, range) = match this.stages.get(*this.stage_idx) { - Some(s) => (s.padding_threshold, s.padding_range), - None => (*this.global_threshold, *this.global_range), - }; - - let padding_len = if actual_len < threshold { - rand::rng() - .random_range(range[0]..=range[1]) - .min(MAX_PAYLOAD - actual_len) - } else { - 0 - }; - - let total_payload = actual_len + padding_len; - let header = ((actual_len as u32) << 16) | (total_payload as u32); - unsafe { - std::ptr::copy_nonoverlapping( - header.to_be_bytes().as_ptr(), - this.frame_buffer.as_mut_ptr(), - HEADER_SIZE, - ); - } - - if padding_len > 0 { - this.frame_buffer.put_bytes(0, padding_len); - } - - let frame = this.frame_buffer.split().freeze(); - Self::prepare_next_frame(this.frame_buffer); - frame - } -} - -impl tokio_stream::Stream for TrafficShaper { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let actual_len = this.frame_buffer.len() - HEADER_SIZE; - let remaining = MAX_PAYLOAD - actual_len; - - if remaining == 0 { - return Poll::Ready(Some(Ok(Self::seal_and_emit(&mut this, actual_len)))); - } - - if actual_len > 0 && this.flush_timer.as_mut().poll(cx).is_ready() { - return Poll::Ready(Some(Ok(Self::seal_and_emit(&mut this, actual_len)))); - } - - let spare = this.frame_buffer.spare_capacity_mut(); - let read_limit = spare.len().min(remaining); - let mut read_buf = ReadBuf::uninit(&mut spare[..read_limit]); - - match this.reader.as_mut().poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) => { - let n = read_buf.filled().len(); - if n == 0 { - return if actual_len == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Self::seal_and_emit(&mut this, actual_len)))) - }; - } - - if actual_len == 0 && n < MAX_PAYLOAD { - let idx = *this.cursor; - let delay_us = jitter_table()[idx]; - *this.cursor = (idx + 1) & TABLE_MASK; - this.flush_timer - .as_mut() - .reset(Instant::now() + Duration::from_micros(delay_us)); - let _ = this.flush_timer.as_mut().poll(cx); - } - - unsafe { this.frame_buffer.advance_mut(n) } - } - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - } - } - } -} - -fn jitter_table() -> &'static [u64; TABLE_SIZE] { - JITTER_TABLE.get_or_init(|| { - let std_dev = AVG_LATENCY_MICROS / 3.0; - let normal = Normal::new(AVG_LATENCY_MICROS, std_dev).unwrap(); - let mut rng = rand::rng(); - let max_val = AVG_LATENCY_MICROS * 2.0; - - let mut table = Box::new([0u64; TABLE_SIZE]); - for slot in table.iter_mut() { - *slot = normal.sample(&mut rng).clamp(0.0, max_val) as u64; - } - table.shuffle(&mut rng); - table - }) -} diff --git a/src/shaper/mod.rs b/src/shaper/mod.rs new file mode 100644 index 0000000..d7374a3 --- /dev/null +++ b/src/shaper/mod.rs @@ -0,0 +1,618 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use pin_project_lite::pin_project; +use rand::rngs::SmallRng; +use rand::{Rng, RngExt, SeedableRng, seq::SliceRandom}; +use rand_distr::{Distribution, LogNormal}; +use serde::Deserialize; +use std::{ + io::{Error, ErrorKind}, + pin::Pin, + sync::{Arc, OnceLock}, + task::{Context, Poll}, + time::Duration, +}; +use tokio::{ + io::{AsyncRead, ReadBuf}, + time::{Instant, Sleep}, +}; + +pub const MAX_RAW_PAYLOAD: usize = 16 * 1024; + +const TABLE_SIZE: usize = 8192; +const TABLE_MASK: usize = TABLE_SIZE - 1; +const DELIMITER: u8 = b'\n'; +const AVG_LATENCY_MICROS: f64 = 5_000.0; +const MAX_BINARY_FRAME_LEN: usize = MAX_RAW_PAYLOAD + 38; +const MAX_JSON_LINE_LEN: usize = MAX_RAW_PAYLOAD + 2396; +const HEADER_LEN: usize = 10; + +static JITTER_TABLE: OnceLock> = OnceLock::new(); + +#[inline] +fn jitter_table() -> &'static [u64; TABLE_SIZE] { + JITTER_TABLE.get_or_init(|| { + let sigma = 0.5; + let avg = AVG_LATENCY_MICROS; + let mu = avg.ln() - (sigma * sigma) * 0.5; + let log_normal = LogNormal::new(mu, sigma).expect("Invalid parameters"); + let mut rng = rand::rng(); + let mut table = Box::new([0u64; TABLE_SIZE]); + for slot in table.iter_mut() { + *slot = log_normal.sample(&mut rng).round() as u64; + } + table.shuffle(&mut rng); + table + }) +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct PaddingConfig { + pub padding_threshold: usize, + pub padding_range: [usize; 2], +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct StageConfig { + pub count: Option, + pub count_range: Option<[usize; 2]>, + pub padding_threshold: usize, + pub padding_range: [usize; 2], +} + +#[derive(Debug, Deserialize, Clone, Copy)] +#[serde(rename_all = "lowercase")] +pub enum EncodingType { + Json, + Binary, +} + +impl Default for EncodingType { + #[inline] + fn default() -> Self { + EncodingType::Binary + } +} + +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct TrafficConfig { + pub global: PaddingConfig, + #[serde(default)] + pub stages: Vec, + #[serde(default)] + pub encoding_type: EncodingType, +} + +#[derive(Debug, Clone, Copy)] +struct ResolvedStage { + end_count: usize, + padding_threshold: usize, + padding_range: [usize; 2], +} + +pub trait FrameCipher: Send + Sync { + fn encrypt(&self, data: &[u8]) -> Result, Error>; + fn decrypt(&self, data: &[u8]) -> Result, Error>; +} + +#[inline] +fn read_u64_be(data: &[u8]) -> u64 { + let mut buf = [0u8; 8]; + buf.copy_from_slice(&data[..8]); + u64::from_be_bytes(buf) +} + +#[inline] +fn read_u16_be(data: &[u8]) -> u16 { + let mut buf = [0u8; 2]; + buf.copy_from_slice(&data[..2]); + u16::from_be_bytes(buf) +} + +#[inline(always)] +fn extract_frame(payload: &[u8]) -> Result<(u64, Bytes), Error> { + if payload.len() < HEADER_LEN { + return Err(Error::new(ErrorKind::InvalidData, "payload too short")); + } + let seq = read_u64_be(&payload[..8]); + let orig_len = read_u16_be(&payload[8..10]) as usize; + let total = HEADER_LEN + orig_len; + if payload.len() < total { + return Err(Error::new( + ErrorKind::InvalidData, + "payload shorter than declared original length", + )); + } + Ok((seq, Bytes::copy_from_slice(&payload[HEADER_LEN..total]))) +} + +#[inline] +fn trim_bytes(mut b: &[u8]) -> &[u8] { + while let Some((&first, rest)) = b.split_first() { + if first.is_ascii_whitespace() { + b = rest; + } else { + break; + } + } + while let Some((&last, rest)) = b.split_last() { + if last.is_ascii_whitespace() { + b = rest; + } else { + break; + } + } + b +} + +#[inline] +fn parse_json_payload(json: &[u8]) -> Result, Error> { + let json = trim_bytes(json); + let err = |msg: &str| Error::new(ErrorKind::InvalidData, msg); + + const PREFIX: &[u8] = b"\"data\":\""; + + let finder = memchr::memmem::Finder::new(PREFIX); + let start = finder + .find(json) + .ok_or_else(|| err("missing 'data' field"))?; + let data_start = start + PREFIX.len(); + + let remaining = &json[data_start..]; + let data_end = + memchr::memchr(b'"', remaining).ok_or_else(|| err("malformed JSON structure"))?; + + let enc_str_bytes = &remaining[..data_end]; + let enc_str = + std::str::from_utf8(enc_str_bytes).map_err(|_| err("payload is not valid UTF-8"))?; + + base122_fast::decode(enc_str).map_err(err) +} + +pub fn encode_frame( + data: &[u8], + seq: u64, + cipher: Option<&dyn FrameCipher>, + config: &TrafficConfig, +) -> std::io::Result> { + let raw_len = data.len(); + let encoding = config.encoding_type; + + let padding_len = if raw_len < config.global.padding_threshold { + let max_pad = MAX_RAW_PAYLOAD - raw_len; + let wanted = rand::rng() + .random_range(config.global.padding_range[0]..=config.global.padding_range[1]); + wanted.min(max_pad) + } else { + 0 + }; + + let payload_len = HEADER_LEN + raw_len + padding_len; + let mut payload = Vec::with_capacity(payload_len); + payload.put_u64(seq); + payload.put_u16(raw_len as u16); + payload.extend_from_slice(data); + if padding_len > 0 { + payload.resize(payload_len, 0u8); + } + + if let Some(cipher) = cipher { + payload = cipher.encrypt(&payload)?; + } + + let mut frame = Vec::new(); + match encoding { + EncodingType::Binary => { + frame.put_u16(payload.len() as u16); + frame.extend_from_slice(&payload); + } + EncodingType::Json => { + let enc_str = base122_fast::encode(&payload); + frame.extend_from_slice(b"{\"data\":\""); + frame.extend_from_slice(enc_str.as_bytes()); + frame.extend_from_slice(b"\"}\n"); + } + } + Ok(frame) +} + +pub fn decode_from_buffer( + src: &mut BytesMut, + cipher: Option<&dyn FrameCipher>, + encoding: EncodingType, +) -> Result, Error> { + match encoding { + EncodingType::Binary => { + if src.len() < 2 { + return Ok(None); + } + let frame_len = read_u16_be(&src[..2]) as usize; + + if frame_len > MAX_BINARY_FRAME_LEN { + return Err(Error::new( + ErrorKind::InvalidData, + "binary frame length exceeds limit", + )); + } + if src.len() < 2 + frame_len { + return Ok(None); + } + src.advance(2); + let frame_data = src.split_to(frame_len); + + if let Some(c) = cipher { + let decrypted = c.decrypt(&frame_data)?; + Ok(Some(extract_frame(&decrypted)?)) + } else { + Ok(Some(extract_frame(&frame_data)?)) + } + } + + EncodingType::Json => { + let newline_pos = memchr::memchr(DELIMITER, src); + + match newline_pos { + Some(pos) => { + if pos > MAX_JSON_LINE_LEN { + return Err(Error::new( + ErrorKind::InvalidData, + "JSON line exceeds maximum allowed length", + )); + } + + let line = src.split_to(pos); + src.advance(1); + + if line.is_empty() { + return Err(Error::new(ErrorKind::InvalidData, "empty frame line")); + } + + let encoded_payload = parse_json_payload(&line)?; + + if let Some(c) = cipher { + let decrypted = c.decrypt(&encoded_payload)?; + Ok(Some(extract_frame(&decrypted)?)) + } else { + Ok(Some(extract_frame(&encoded_payload)?)) + } + } + None => { + if src.len() > MAX_JSON_LINE_LEN { + return Err(Error::new( + ErrorKind::InvalidData, + "incomplete JSON line is too long", + )); + } + Ok(None) + } + } + } + } +} + +pin_project! { + #[project = Proj] + pub struct TrafficShaper { + #[pin] + reader: R, + + raw_buf: BytesMut, + out_buf: BytesMut, + + #[pin] + flush_timer: Sleep, + timer_armed: bool, + cursor: usize, + stages: Vec, + global_threshold: usize, + global_range: [usize; 2], + packet_count: usize, + stage_idx: usize, + rng: SmallRng, + cipher: Option>, + encoding: EncodingType, + seq: u64, + } +} + +impl TrafficShaper { + pub fn with_seq( + reader: R, + config: TrafficConfig, + cipher: Option>, + start_seq: u64, + ) -> Self { + let mut base_rng = rand::rng(); + let cursor = (base_rng.next_u64() as usize) & TABLE_MASK; + + let mut stages: Vec = config + .stages + .iter() + .map(|s| ResolvedStage { + end_count: s + .count + .or_else(|| s.count_range.map(|[_, hi]| hi)) + .unwrap_or(0), + padding_threshold: s.padding_threshold, + padding_range: s.padding_range, + }) + .collect(); + stages.sort_unstable_by_key(|s| s.end_count); + + let out_capacity = match config.encoding_type { + EncodingType::Binary => MAX_BINARY_FRAME_LEN + 2, + EncodingType::Json => MAX_JSON_LINE_LEN + 1, + }; + + Self { + reader, + raw_buf: BytesMut::with_capacity(MAX_RAW_PAYLOAD), + out_buf: BytesMut::with_capacity(out_capacity), + flush_timer: tokio::time::sleep_until(Instant::now()), + timer_armed: false, + stages, + global_threshold: config.global.padding_threshold, + global_range: config.global.padding_range, + packet_count: 0, + cursor, + stage_idx: 0, + rng: SmallRng::from_rng(&mut base_rng), + cipher, + encoding: config.encoding_type, + seq: start_seq, + } + } + + #[inline] + fn seal_and_emit(this: &mut Proj<'_, R>) -> Result<(u64, Bytes), Error> { + let raw_len = this.raw_buf.len(); + debug_assert!(raw_len > 0); + debug_assert!(raw_len <= MAX_RAW_PAYLOAD); + + *this.timer_armed = false; + + *this.packet_count += 1; + let seq = *this.seq; + *this.seq = seq + 1; + + let stages = &this.stages; + let pc = *this.packet_count; + let mut si = *this.stage_idx; + while si < stages.len() && pc > stages[si].end_count { + si += 1; + } + *this.stage_idx = si; + + let (threshold, range) = if si < stages.len() { + (stages[si].padding_threshold, stages[si].padding_range) + } else { + (*this.global_threshold, *this.global_range) + }; + + let padding_len = if raw_len < threshold { + let max_pad = MAX_RAW_PAYLOAD - raw_len; + let wanted = this.rng.random_range(range[0]..=range[1]); + wanted.min(max_pad) + } else { + 0 + }; + + let payload_len = HEADER_LEN + raw_len + padding_len; + + if let Some(cipher) = this.cipher { + this.out_buf.clear(); + this.out_buf.reserve(payload_len); + this.out_buf.put_u64(seq); + this.out_buf.put_u16(raw_len as u16); + this.out_buf.put_slice(&this.raw_buf[..raw_len]); + if padding_len > 0 { + this.out_buf.put_bytes(0u8, padding_len); + } + + let encrypted = cipher.encrypt(&this.out_buf[..payload_len])?; + this.out_buf.clear(); + + match *this.encoding { + EncodingType::Binary => { + this.out_buf.reserve(2 + encrypted.len()); + this.out_buf.put_u16(encrypted.len() as u16); + this.out_buf.put_slice(&encrypted); + } + EncodingType::Json => { + let enc_str = base122_fast::encode(&encrypted); + let enc_bytes = enc_str.as_bytes(); + + this.out_buf.reserve(9 + enc_bytes.len() + 2 + 1); + this.out_buf.put_slice(b"{\"data\":\""); + this.out_buf.put_slice(enc_bytes); + this.out_buf.put_slice(b"\"}\n"); + } + } + } else { + this.out_buf.clear(); + + match *this.encoding { + EncodingType::Binary => { + this.out_buf.reserve(2 + payload_len); + this.out_buf.put_u16(payload_len as u16); + this.out_buf.put_u64(seq); + this.out_buf.put_u16(raw_len as u16); + this.out_buf.put_slice(&this.raw_buf[..raw_len]); + if padding_len > 0 { + this.out_buf.put_bytes(0u8, padding_len); + } + } + EncodingType::Json => { + this.out_buf.reserve(payload_len); + this.out_buf.put_u64(seq); + this.out_buf.put_u16(raw_len as u16); + this.out_buf.put_slice(&this.raw_buf[..raw_len]); + if padding_len > 0 { + this.out_buf.put_bytes(0u8, padding_len); + } + let enc_str = base122_fast::encode(&this.out_buf[..payload_len]); + let enc_bytes = enc_str.as_bytes(); + + this.out_buf.clear(); + this.out_buf.reserve(9 + enc_bytes.len() + 2 + 1); + this.out_buf.put_slice(b"{\"data\":\""); + this.out_buf.put_slice(enc_bytes); + this.out_buf.put_slice(b"\"}\n"); + } + } + } + + this.raw_buf.clear(); + let result = this.out_buf.split().freeze(); + Ok((seq, result)) + } +} + +impl tokio_stream::Stream for TrafficShaper { + type Item = Result<(u64, Bytes), Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + let raw_len = this.raw_buf.len(); + let remaining = MAX_RAW_PAYLOAD - raw_len; + + if remaining == 0 { + return Poll::Ready(Some(Self::seal_and_emit(&mut this))); + } + + if *this.timer_armed && raw_len > 0 { + if this.flush_timer.as_mut().poll(cx).is_ready() { + return Poll::Ready(Some(Self::seal_and_emit(&mut this))); + } else { + return Poll::Pending; + } + } + + this.raw_buf.reserve(remaining); + let spare = this.raw_buf.spare_capacity_mut(); + let read_limit = spare.len().min(remaining); + let mut rb = ReadBuf::uninit(&mut spare[..read_limit]); + + match this.reader.as_mut().poll_read(cx, &mut rb) { + Poll::Ready(Ok(())) => { + let n = rb.filled().len(); + if n == 0 { + return if raw_len == 0 { + Poll::Ready(None) + } else { + Poll::Ready(Some(Self::seal_and_emit(&mut this))) + }; + } + + unsafe { this.raw_buf.advance_mut(n) } + + if raw_len == 0 && this.raw_buf.len() < MAX_RAW_PAYLOAD { + let idx = *this.cursor; + let delay_us = jitter_table()[idx]; + *this.cursor = (idx + 1) & TABLE_MASK; + this.flush_timer + .as_mut() + .reset(Instant::now() + Duration::from_micros(delay_us)); + *this.timer_armed = true; + + let _ = this.flush_timer.as_mut().poll(cx); + } + } + Poll::Pending => { + if *this.timer_armed + && raw_len > 0 + && this.flush_timer.as_mut().poll(cx).is_ready() + { + return Poll::Ready(Some(Self::seal_and_emit(&mut this))); + } + return Poll::Pending; + } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_config() -> TrafficConfig { + TrafficConfig { + encoding_type: EncodingType::Binary, + global: PaddingConfig { + padding_threshold: 16384, + padding_range: [0, 16], + }, + stages: vec![], + } + } + + #[test] + fn encode_decode_roundtrip_binary_no_cipher() { + let data = b"hello proxy frame"; + let config = test_config(); + let frame = encode_frame(data, 7, None, &config).unwrap(); + let mut buf = BytesMut::from(&frame[..]); + let (seq, decoded) = decode_from_buffer(&mut buf, None, EncodingType::Binary) + .unwrap() + .unwrap(); + assert_eq!(seq, 7); + assert_eq!(&decoded[..], data); + assert!(buf.is_empty()); + } + + #[test] + fn decode_incomplete_returns_none() { + let mut buf = BytesMut::new(); + buf.put_u16(100u16); + buf.put_u8(0xAA); + let result = decode_from_buffer(&mut buf, None, EncodingType::Binary).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn decode_too_long_rejected() { + let mut buf = BytesMut::new(); + buf.put_u16((MAX_RAW_PAYLOAD + 1000) as u16); + buf.resize(2 + MAX_RAW_PAYLOAD + 1000, 0u8); + let result = decode_from_buffer(&mut buf, None, EncodingType::Binary); + assert!(result.is_err()); + } + + #[test] + fn extract_frame_valid() { + let payload = [0u8; 8] + .iter() + .copied() + .chain(3u16.to_be_bytes()) + .chain(b"abc".iter().copied()) + .collect::>(); + let (seq, data) = extract_frame(&payload).unwrap(); + assert_eq!(seq, 0); + assert_eq!(&data[..], b"abc"); + } + + #[test] + fn extract_frame_too_short() { + assert!(extract_frame(b"short").is_err()); + } + + #[test] + fn parse_json_payload_valid() { + let enc = base122_fast::encode(b"hello"); + let json = format!("{{\"data\":\"{enc}\"}}"); + let result = parse_json_payload(json.as_bytes()).unwrap(); + assert_eq!(result, b"hello"); + } + + #[test] + fn parse_json_payload_missing_field() { + let result = parse_json_payload(b"{\"other\":\"x\"}"); + assert!(result.is_err()); + } +}