From c3b2b1b8e71024bbe8ac56787b54b9c43f484e72 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Thu, 23 Jan 2025 17:39:51 +0100 Subject: [PATCH] First commit Co-authored-by: Edoardo Morandi Co-authored-by: Angel Hudgins --- .github/workflows/ci.yml | 87 + .gitignore | 1 + .gitlab-ci.yml | 46 + Cargo.lock | 1812 +++++++++++++++++ Cargo.toml | 30 + LICENSE-APACHE | 201 ++ LICENSE-MIT | 21 + README.md | 50 + deny.toml | 36 + watermelon-mini/Cargo.toml | 38 + watermelon-mini/LICENSE-APACHE | 1 + watermelon-mini/LICENSE-MIT | 1 + watermelon-mini/README.md | 1 + watermelon-mini/src/lib.rs | 79 + watermelon-mini/src/non_standard_zstd.rs | 107 + watermelon-mini/src/proto/authenticator.rs | 148 ++ .../src/proto/connection/compression.rs | 106 + watermelon-mini/src/proto/connection/mod.rs | 5 + .../src/proto/connection/security.rs | 101 + watermelon-mini/src/proto/connector.rs | 223 ++ watermelon-mini/src/proto/mod.rs | 8 + watermelon-mini/src/util.rs | 56 + watermelon-net/Cargo.toml | 43 + watermelon-net/LICENSE-APACHE | 1 + watermelon-net/LICENSE-MIT | 1 + watermelon-net/README.md | 1 + watermelon-net/src/connection/mod.rs | 264 +++ watermelon-net/src/connection/streaming.rs | 242 +++ watermelon-net/src/connection/websocket.rs | 143 ++ watermelon-net/src/happy_eyeballs.rs | 213 ++ watermelon-net/src/lib.rs | 13 + watermelon-nkeys/Cargo.toml | 26 + watermelon-nkeys/LICENSE-APACHE | 1 + watermelon-nkeys/LICENSE-MIT | 1 + watermelon-nkeys/README.md | 1 + watermelon-nkeys/src/crc.rs | 25 + watermelon-nkeys/src/lib.rs | 4 + watermelon-nkeys/src/seed.rs | 131 ++ watermelon-proto/Cargo.toml | 37 + watermelon-proto/LICENSE-APACHE | 1 + watermelon-proto/LICENSE-MIT | 1 + watermelon-proto/README.md | 1 + watermelon-proto/src/connect.rs | 63 + watermelon-proto/src/headers/map.rs | 303 +++ watermelon-proto/src/headers/mod.rs | 12 + watermelon-proto/src/headers/name.rs | 203 ++ watermelon-proto/src/headers/value.rs | 151 ++ watermelon-proto/src/lib.rs | 36 + watermelon-proto/src/message.rs | 18 + watermelon-proto/src/proto/client.rs | 28 + watermelon-proto/src/proto/decoder/framed.rs | 27 + watermelon-proto/src/proto/decoder/mod.rs | 373 ++++ watermelon-proto/src/proto/decoder/stream.rs | 125 ++ watermelon-proto/src/proto/encoder/framed.rs | 192 ++ watermelon-proto/src/proto/encoder/mod.rs | 173 ++ watermelon-proto/src/proto/encoder/stream.rs | 339 +++ watermelon-proto/src/proto/mod.rs | 13 + watermelon-proto/src/proto/server.rs | 13 + watermelon-proto/src/queue_group.rs | 253 +++ watermelon-proto/src/server_addr.rs | 367 ++++ watermelon-proto/src/server_error.rs | 110 + watermelon-proto/src/server_info.rs | 69 + watermelon-proto/src/status_code.rs | 152 ++ watermelon-proto/src/subject.rs | 259 +++ watermelon-proto/src/subscription_id.rs | 38 + watermelon-proto/src/tests.rs | 12 + watermelon-proto/src/util/buf_list.rs | 111 + watermelon-proto/src/util/lines_iter.rs | 54 + watermelon-proto/src/util/mod.rs | 10 + watermelon-proto/src/util/split_spaces.rs | 29 + watermelon-proto/src/util/uint.rs | 51 + watermelon/Cargo.toml | 54 + watermelon/LICENSE-APACHE | 1 + watermelon/LICENSE-MIT | 1 + watermelon/README.md | 1 + watermelon/src/atomic.rs | 4 + watermelon/src/client/builder.rs | 201 ++ watermelon/src/client/commands/mod.rs | 11 + watermelon/src/client/commands/publish.rs | 302 +++ watermelon/src/client/commands/request.rs | 416 ++++ watermelon/src/client/from_env.rs | 42 + .../jetstream/commands/consumer_batch.rs | 143 ++ .../jetstream/commands/consumer_list.rs | 117 ++ .../jetstream/commands/consumer_stream.rs | 127 ++ .../src/client/jetstream/commands/mod.rs | 9 + .../client/jetstream/commands/stream_list.rs | 110 + watermelon/src/client/jetstream/mod.rs | 250 +++ .../client/jetstream/resources/consumer.rs | 400 ++++ .../src/client/jetstream/resources/mod.rs | 235 +++ .../src/client/jetstream/resources/stream.rs | 101 + watermelon/src/client/mod.rs | 506 +++++ watermelon/src/client/quick_info.rs | 165 ++ watermelon/src/client/tests.rs | 14 + watermelon/src/handler.rs | 721 +++++++ watermelon/src/lib.rs | 61 + watermelon/src/multiplexed_subscription.rs | 69 + watermelon/src/subscription.rs | 375 ++++ watermelon/src/tests.rs | 26 + 98 files changed, 12354 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 .gitlab-ci.yml create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 LICENSE-APACHE create mode 100644 LICENSE-MIT create mode 100644 README.md create mode 100644 deny.toml create mode 100644 watermelon-mini/Cargo.toml create mode 120000 watermelon-mini/LICENSE-APACHE create mode 120000 watermelon-mini/LICENSE-MIT create mode 120000 watermelon-mini/README.md create mode 100644 watermelon-mini/src/lib.rs create mode 100644 watermelon-mini/src/non_standard_zstd.rs create mode 100644 watermelon-mini/src/proto/authenticator.rs create mode 100644 watermelon-mini/src/proto/connection/compression.rs create mode 100644 watermelon-mini/src/proto/connection/mod.rs create mode 100644 watermelon-mini/src/proto/connection/security.rs create mode 100644 watermelon-mini/src/proto/connector.rs create mode 100644 watermelon-mini/src/proto/mod.rs create mode 100644 watermelon-mini/src/util.rs create mode 100644 watermelon-net/Cargo.toml create mode 120000 watermelon-net/LICENSE-APACHE create mode 120000 watermelon-net/LICENSE-MIT create mode 120000 watermelon-net/README.md create mode 100644 watermelon-net/src/connection/mod.rs create mode 100644 watermelon-net/src/connection/streaming.rs create mode 100644 watermelon-net/src/connection/websocket.rs create mode 100644 watermelon-net/src/happy_eyeballs.rs create mode 100644 watermelon-net/src/lib.rs create mode 100644 watermelon-nkeys/Cargo.toml create mode 120000 watermelon-nkeys/LICENSE-APACHE create mode 120000 watermelon-nkeys/LICENSE-MIT create mode 120000 watermelon-nkeys/README.md create mode 100644 watermelon-nkeys/src/crc.rs create mode 100644 watermelon-nkeys/src/lib.rs create mode 100644 watermelon-nkeys/src/seed.rs create mode 100644 watermelon-proto/Cargo.toml create mode 120000 watermelon-proto/LICENSE-APACHE create mode 120000 watermelon-proto/LICENSE-MIT create mode 120000 watermelon-proto/README.md create mode 100644 watermelon-proto/src/connect.rs create mode 100644 watermelon-proto/src/headers/map.rs create mode 100644 watermelon-proto/src/headers/mod.rs create mode 100644 watermelon-proto/src/headers/name.rs create mode 100644 watermelon-proto/src/headers/value.rs create mode 100644 watermelon-proto/src/lib.rs create mode 100644 watermelon-proto/src/message.rs create mode 100644 watermelon-proto/src/proto/client.rs create mode 100644 watermelon-proto/src/proto/decoder/framed.rs create mode 100644 watermelon-proto/src/proto/decoder/mod.rs create mode 100644 watermelon-proto/src/proto/decoder/stream.rs create mode 100644 watermelon-proto/src/proto/encoder/framed.rs create mode 100644 watermelon-proto/src/proto/encoder/mod.rs create mode 100644 watermelon-proto/src/proto/encoder/stream.rs create mode 100644 watermelon-proto/src/proto/mod.rs create mode 100644 watermelon-proto/src/proto/server.rs create mode 100644 watermelon-proto/src/queue_group.rs create mode 100644 watermelon-proto/src/server_addr.rs create mode 100644 watermelon-proto/src/server_error.rs create mode 100644 watermelon-proto/src/server_info.rs create mode 100644 watermelon-proto/src/status_code.rs create mode 100644 watermelon-proto/src/subject.rs create mode 100644 watermelon-proto/src/subscription_id.rs create mode 100644 watermelon-proto/src/tests.rs create mode 100644 watermelon-proto/src/util/buf_list.rs create mode 100644 watermelon-proto/src/util/lines_iter.rs create mode 100644 watermelon-proto/src/util/mod.rs create mode 100644 watermelon-proto/src/util/split_spaces.rs create mode 100644 watermelon-proto/src/util/uint.rs create mode 100644 watermelon/Cargo.toml create mode 120000 watermelon/LICENSE-APACHE create mode 120000 watermelon/LICENSE-MIT create mode 120000 watermelon/README.md create mode 100644 watermelon/src/atomic.rs create mode 100644 watermelon/src/client/builder.rs create mode 100644 watermelon/src/client/commands/mod.rs create mode 100644 watermelon/src/client/commands/publish.rs create mode 100644 watermelon/src/client/commands/request.rs create mode 100644 watermelon/src/client/from_env.rs create mode 100644 watermelon/src/client/jetstream/commands/consumer_batch.rs create mode 100644 watermelon/src/client/jetstream/commands/consumer_list.rs create mode 100644 watermelon/src/client/jetstream/commands/consumer_stream.rs create mode 100644 watermelon/src/client/jetstream/commands/mod.rs create mode 100644 watermelon/src/client/jetstream/commands/stream_list.rs create mode 100644 watermelon/src/client/jetstream/mod.rs create mode 100644 watermelon/src/client/jetstream/resources/consumer.rs create mode 100644 watermelon/src/client/jetstream/resources/mod.rs create mode 100644 watermelon/src/client/jetstream/resources/stream.rs create mode 100644 watermelon/src/client/mod.rs create mode 100644 watermelon/src/client/quick_info.rs create mode 100644 watermelon/src/client/tests.rs create mode 100644 watermelon/src/handler.rs create mode 100644 watermelon/src/lib.rs create mode 100644 watermelon/src/multiplexed_subscription.rs create mode 100644 watermelon/src/subscription.rs create mode 100644 watermelon/src/tests.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..70743bb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,87 @@ +name: CI + +on: [push, pull_request] + +jobs: + cargo-deny: + name: cargo-deny + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: EmbarkStudios/cargo-deny-action@v2 + + fmt: + name: rustfmt / 1.84.0 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@1.84.0 + with: + components: rustfmt + + - name: Rust rustfmt + run: cargo fmt --all -- --check + + clippy: + name: clippy / 1.84.0 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@1.84.0 + with: + components: clippy + + - name: Run clippy + run: cargo clippy --all-features -- -D warnings + + cargo-hack: + name: cargo-hack / 1.84.0 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@1.84.0 + + - uses: taiki-e/install-action@v2 + with: + tool: cargo-hack@0.6.34 + + - name: Run cargo-hack + run: cargo hack check --feature-powerset --no-dev-deps --at-least-one-of aws-lc-rs,ring + + test: + name: test / ${{ matrix.name }} + runs-on: ubuntu-latest + + strategy: + matrix: + include: + - name: stable + rust: stable + - name: beta + rust: beta + - name: nightly + rust: nightly + - name: 1.81.0 + rust: 1.81.0 + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + + - name: Run tests + run: cargo test + + - name: Run tests (--features websocket,portable-atomic) + run: cargo test --features websocket,portable-atomic + + - name: Run tests (--no-default-features --features ring) + run: cargo test --no-default-features --features ring diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..0041553 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,46 @@ +stages: + - test + +rust:deny: + stage: test + image: rust:1.84-alpine3.21 + before_script: + - apk add cargo-deny + script: + - cargo deny check + +rust:fmt: + stage: test + image: rust:1.84-alpine3.21 + before_script: + - rustup component add rustfmt + script: + - cargo fmt -- --check + +rust:clippy: + stage: test + image: rust:1.84-alpine3.20 + before_script: + - apk add build-base musl-dev linux-headers cmake perl go + - rustup component add clippy + script: + - cargo clippy --all-features -- -D warnings + +rust:hack: + stage: test + image: rust:1.84-alpine3.20 + before_script: + - apk add build-base musl-dev linux-headers cmake perl go + - cargo install --locked cargo-hack@0.6.34 + script: + - cargo hack check --feature-powerset --no-dev-deps --at-least-one-of aws-lc-rs,ring + +rust:test: + stage: test + image: rust:1.84-alpine3.21 + before_script: + - apk add musl-dev cmake perl go + script: + - cargo test + - cargo test --features websocket,portable-atomic + - cargo test --no-default-features --features ring diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..9fcec13 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,1812 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "async-compression" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" +dependencies = [ + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "zstd", + "zstd-safe", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "aws-lc-fips-sys" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c7557f6c81ecd3e38582996b31a0f329900586abaae5f092e756686958f22c" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "paste", + "regex", +] + +[[package]] +name = "aws-lc-rs" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2b7ddaa2c56a367ad27a094ad8ef4faacf8a617c2575acb2ba88949df999ca" +dependencies = [ + "aws-lc-fips-sys", + "aws-lc-sys", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71b2ddd3ada61a305e1d8bb6c005d1eaa7d14d903681edfc400406d523a9b491" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "paste", +] + +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", + "which", +] + +[[package]] +name = "bitflags" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" + +[[package]] +name = "bytestring" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e465647ae23b2823b0753f50decb2d5a86d2bb2cac04788fafd1f80e45378e5f" +dependencies = [ + "bytes", + "serde", +] + +[[package]] +name = "cc" +version = "1.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" +dependencies = [ + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets 0.52.6", +] + +[[package]] +name = "claims" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18" + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "cmake" +version = "0.1.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +dependencies = [ + "cc", +] + +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + +[[package]] +name = "data-encoding" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f" + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "envy" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f47e0157f2cb54f5ae1bd371b30a2ae4311e1c028f575cd4e81de7353215965" +dependencies = [ + "serde", +] + +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "http" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" + +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" + +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "libc" +version = "0.2.169" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" + +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + +[[package]] +name = "log" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.52.0", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustls" +version = "0.23.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +dependencies = [ + "aws-lc-rs", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pki-types" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" + +[[package]] +name = "rustls-platform-verifier" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e012c45844a1790332c9386ed4ca3a06def221092eda277e6f079728f8ea99da" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.217" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.217" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.137" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.43.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-websockets" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bad3d80d26e290444b67405d54661959107d3aadc6ecc35d717a626b5e208c51" +dependencies = [ + "aws-lc-rs", + "base64", + "bytes", + "futures-core", + "futures-sink", + "http", + "httparse", + "rand", + "ring", + "tokio", + "tokio-rustls", + "tokio-util", +] + +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + +[[package]] +name = "unicode-ident" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11cd88e12b17c6494200a9c1b683a04fcac9573ed74cd1b62aeb2727c5592243" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "watermelon" +version = "0.1.0" +dependencies = [ + "arc-swap", + "bytes", + "chrono", + "claims", + "envy", + "futures-core", + "futures-util", + "pin-project-lite", + "portable-atomic", + "rand", + "serde", + "serde_json", + "thiserror 2.0.11", + "tokio", + "watermelon-mini", + "watermelon-net", + "watermelon-nkeys", + "watermelon-proto", +] + +[[package]] +name = "watermelon-mini" +version = "0.1.0" +dependencies = [ + "async-compression", + "rustls-platform-verifier", + "thiserror 2.0.11", + "tokio", + "tokio-rustls", + "watermelon-net", + "watermelon-nkeys", + "watermelon-proto", +] + +[[package]] +name = "watermelon-net" +version = "0.1.0" +dependencies = [ + "bytes", + "claims", + "futures-sink", + "futures-util", + "http", + "pin-project-lite", + "thiserror 2.0.11", + "tokio", + "tokio-websockets", + "watermelon-proto", +] + +[[package]] +name = "watermelon-nkeys" +version = "0.1.0" +dependencies = [ + "aws-lc-rs", + "crc", + "data-encoding", + "ring", + "thiserror 2.0.11", +] + +[[package]] +name = "watermelon-proto" +version = "0.1.0" +dependencies = [ + "bytes", + "bytestring", + "claims", + "memchr", + "percent-encoding", + "serde", + "serde_json", + "thiserror 2.0.11", + "unicase", + "url", +] + +[[package]] +name = "webpki-root-certs" +version = "0.26.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cd5da49bdf1f30054cfe0b8ce2958b8fbeb67c4d82c8967a598af481bef255c" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8965987 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,30 @@ +[workspace] +members = [ + "watermelon", + "watermelon-mini", + "watermelon-net", + "watermelon-proto", + "watermelon-nkeys", +] +resolver = "2" + +[workspace.package] +edition = "2021" +license = "MIT OR Apache-2.0" +repository = "https://github.com/M4SS-Code/watermelon" +rust-version = "1.81" + +[workspace.lints.rust] +unsafe_code = "deny" +unreachable_pub = "deny" + +[workspace.lints.clippy] +pedantic = { level = "warn", priority = -1 } +module_name_repetitions = "allow" +await_holding_refcell_ref = "deny" +map_unwrap_or = "warn" +needless_lifetimes = "warn" +needless_raw_string_hashes = "warn" +redundant_closure_for_method_calls = "warn" +semicolon_if_nothing_returned = "warn" +str_to_string = "warn" diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..710f700 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024-2025 M4SS Srl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..eb5f33e --- /dev/null +++ b/README.md @@ -0,0 +1,50 @@ +

watermelon

+
+ + Pure Rust NATS client implementation and tokio integration + +
+ +`watermelon` is an independent implementation of the NATS protocol. +The goal of the project is to produce an opinionated, composable, +idiomatic implementation with a keen eye on security, correctness and +ease of use. + +Watermelon is divided into multiple crates, all hosted in the same monorepo. + +| Crate name | Crates.io release | Documentation | Description | +| ------------------ | --------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | +| `watermelon` | [![crates.io](https://img.shields.io/crates/v/watermelon.svg)](https://crates.io/crates/watermelon) | [![Documentation](https://docs.rs/watermelon/badge.svg)](https://docs.rs/watermelon) | High level actor based implementation NATS Core and NATS Jetstream client implementation | +| `watermelon-mini` | [![crates.io](https://img.shields.io/crates/v/watermelon-mini.svg)](https://crates.io/crates/watermelon-mini) | [![Documentation](https://docs.rs/watermelon-mini/badge.svg)](https://docs.rs/watermelon-mini) | Minimal NATS Core client implementation | +| `watermelon-net` | [![crates.io](https://img.shields.io/crates/v/watermelon-net.svg)](https://crates.io/crates/watermelon-net) | [![Documentation](https://docs.rs/watermelon-net/badge.svg)](https://docs.rs/watermelon-net) | Low-level NATS Core network implementation | +| `watermelon-proto` | [![crates.io](https://img.shields.io/crates/v/watermelon-proto.svg)](https://crates.io/crates/watermelon-proto) | [![Documentation](https://docs.rs/watermelon-proto/badge.svg)](https://docs.rs/watermelon-proto) | `#[no_std]` NATS Core Sans-IO protocol implementation | +| `watermelon-proto` | [![crates.io](https://img.shields.io/crates/v/watermelon-proto.svg)](https://crates.io/crates/watermelon-nkeys) | [![Documentation](https://docs.rs/watermelon-nkeys/badge.svg)](https://docs.rs/watermelon-nkeys) | Minimal NKeys implementation for NATS client authentication | + +# Advantages over `async-nats` + +1. **Security**: this client is protected against command injection attacks via checked APIs like `Subject`. +2. **Extendibility**: exposes the inner components via `watermelon-mini` and `watermelon-net`. +3. **Error handling**: `subscribe` errors are correctly caught - internally enables server verbose mode. +4. **Fresh start**: this client only supports nats-server >=2.10.0. We may drop support for older server versions as new ones come out. `tls://` also uses _TLS-first handshake_ mode by default. +5. **Licensing**: dual licensed under MIT and APACHE-2.0. + +# Disadvantages over `async-nats` + +1. **Completeness**: most APIs (Jetstream specifically) haven't been implemented yet. +2. **Future work**: we may never add support for functionallity that we, M4SS Srl, don't use. +3. **Difference in APIs**: official NATS clients tend to have similar APIs. `watermelon` does not follow the official guidelines. +4. **Ecosystem**: as the client does not support older server versions or ignore old configuration options, it may not work in an environment that hasn't adopted the new standards yet. +5. **Backwards compatibility**: this client is in no way compatible with the `async-nats` API and may make frequent breaking changes. + +## License + +Licensed under either of + +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or ) +- MIT license ([LICENSE-MIT](LICENSE-MIT) or ) + +at your option. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..dd03599 --- /dev/null +++ b/deny.toml @@ -0,0 +1,36 @@ +[advisories] +ignore = [ +] + +[licenses] +allow = [ + "MIT", + "Apache-2.0", + "BSD-3-Clause", + "ISC", + "Unicode-3.0", + "MPL-2.0", + "0BSD", + "OpenSSL", +] + +[[licenses.clarify]] +name = "ring" +expression = "ISC AND MIT AND OpenSSL" +license-files = [{ path = "LICENSE", hash = 0xbd0eed23 }] + +[licenses.private] +ignore = true + +[bans] +multiple-versions = "warn" +wildcards = "deny" +deny = [ +] + +[sources] +unknown-registry = "deny" +unknown-git = "deny" + +[sources.allow-org] +#github = ["M4SS-Code"] diff --git a/watermelon-mini/Cargo.toml b/watermelon-mini/Cargo.toml new file mode 100644 index 0000000..37a253e --- /dev/null +++ b/watermelon-mini/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "watermelon-mini" +version = "0.1.0" +description = "Minimal NATS Core client implementation" +categories = ["api-bindings", "network-programming"] +keywords = ["nats", "client"] +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[package.metadata.docs.rs] +features = ["websocket", "non-standard-zstd"] + +[dependencies] +tokio = { version = "1", features = ["net"] } +tokio-rustls = { version = "0.26", default-features = false } +rustls-platform-verifier = "0.5" + +watermelon-net = { version = "0.1", path = "../watermelon-net" } +watermelon-proto = { version = "0.1", path = "../watermelon-proto" } +watermelon-nkeys = { version = "0.1", path = "../watermelon-nkeys", default-features = false } + +thiserror = "2" + +# non-standard-zstd +async-compression = { version = "0.4", features = ["tokio"], optional = true } + +[features] +default = ["aws-lc-rs"] +websocket = ["watermelon-net/websocket"] +aws-lc-rs = ["tokio-rustls/aws-lc-rs", "watermelon-net/aws-lc-rs", "watermelon-nkeys/aws-lc-rs"] +ring = ["tokio-rustls/ring", "watermelon-net/ring", "watermelon-nkeys/ring"] +fips = ["tokio-rustls/fips", "watermelon-net/fips", "watermelon-nkeys/fips"] +non-standard-zstd = ["watermelon-net/non-standard-zstd", "watermelon-proto/non-standard-zstd", "dep:async-compression", "async-compression/zstd"] + +[lints] +workspace = true diff --git a/watermelon-mini/LICENSE-APACHE b/watermelon-mini/LICENSE-APACHE new file mode 120000 index 0000000..965b606 --- /dev/null +++ b/watermelon-mini/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/watermelon-mini/LICENSE-MIT b/watermelon-mini/LICENSE-MIT new file mode 120000 index 0000000..76219eb --- /dev/null +++ b/watermelon-mini/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/watermelon-mini/README.md b/watermelon-mini/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/watermelon-mini/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/watermelon-mini/src/lib.rs b/watermelon-mini/src/lib.rs new file mode 100644 index 0000000..49d1620 --- /dev/null +++ b/watermelon-mini/src/lib.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use rustls_platform_verifier::Verifier; +use tokio::net::TcpStream; +use tokio_rustls::{ + rustls::{self, crypto::CryptoProvider, version::TLS13, ClientConfig}, + TlsConnector, +}; +use watermelon_net::Connection; +use watermelon_proto::{ServerAddr, ServerInfo}; + +#[cfg(feature = "non-standard-zstd")] +pub use self::non_standard_zstd::ZstdStream; +use self::proto::connect; +pub use self::proto::{ + AuthenticationMethod, ConnectError, ConnectionCompression, ConnectionSecurity, +}; + +#[cfg(feature = "non-standard-zstd")] +pub(crate) mod non_standard_zstd; +mod proto; +mod util; + +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct ConnectFlags { + pub echo: bool, + #[cfg(feature = "non-standard-zstd")] + pub zstd: bool, +} + +/// Connect to a given address with some reasonable presets. +/// +/// The function is going to establish a TLS 1.3 connection, without the support of the client +/// authorization. +/// +/// # Errors +/// +/// This returns an error in case the connection fails. +#[expect( + clippy::missing_panics_doc, + reason = "the crypto_provider function always returns a provider that supports TLS 1.3" +)] +pub async fn easy_connect( + addr: &ServerAddr, + auth: Option<&AuthenticationMethod>, + flags: ConnectFlags, +) -> Result< + ( + Connection< + ConnectionCompression>, + ConnectionSecurity, + >, + Box, + ), + ConnectError, +> { + let provider = Arc::new(crypto_provider()); + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder_with_provider(Arc::clone(&provider)) + .with_protocol_versions(&[&TLS13]) + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(Verifier::new().with_provider(provider))) + .with_no_client_auth(), + )); + + let (conn, info) = connect(&connector, addr, "watermelon".to_owned(), auth, flags).await?; + Ok((conn, info)) +} + +fn crypto_provider() -> CryptoProvider { + #[cfg(feature = "aws-lc-rs")] + return rustls::crypto::aws_lc_rs::default_provider(); + #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] + return rustls::crypto::ring::default_provider(); + #[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))] + compile_error!("Please enable the `aws-lc-rs` or the `ring` feature") +} diff --git a/watermelon-mini/src/non_standard_zstd.rs b/watermelon-mini/src/non_standard_zstd.rs new file mode 100644 index 0000000..62f8453 --- /dev/null +++ b/watermelon-mini/src/non_standard_zstd.rs @@ -0,0 +1,107 @@ +use std::{ + fmt::{self, Debug, Formatter}, + io, + pin::Pin, + task::{Context, Poll}, +}; + +use async_compression::tokio::{bufread::ZstdDecoder, write::ZstdEncoder}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf}; + +use crate::util::MaybeConnection; + +pub struct ZstdStream { + decoder: ZstdDecoder>>, + encoder: ZstdEncoder>, +} + +impl ZstdStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + #[must_use] + pub fn new(stream: S) -> Self { + Self { + decoder: ZstdDecoder::new(BufReader::new(MaybeConnection(Some(stream)))), + encoder: ZstdEncoder::new(MaybeConnection(None)), + } + } +} + +impl AsyncRead for ZstdStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(stream) = self.encoder.get_mut().0.take() { + self.decoder.get_mut().get_mut().0 = Some(stream); + } + + Pin::new(&mut self.decoder).poll_read(cx, buf) + } +} + +impl AsyncWrite for ZstdStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { + self.encoder.get_mut().0 = Some(stream); + } + + Pin::new(&mut self.encoder).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { + self.encoder.get_mut().0 = Some(stream); + } + + Pin::new(&mut self.encoder).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + if let Some(stream) = &self.encoder.get_ref().0 { + stream.is_write_vectored() + } else if let Some(stream) = &self.decoder.get_ref().get_ref().0 { + stream.is_write_vectored() + } else { + unreachable!() + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { + self.encoder.get_mut().0 = Some(stream); + } + + Pin::new(&mut self.encoder).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(stream) = self.decoder.get_mut().get_mut().0.take() { + self.encoder.get_mut().0 = Some(stream); + } + + Pin::new(&mut self.encoder).poll_shutdown(cx) + } +} + +impl Debug for ZstdStream { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("ZstdStream").finish_non_exhaustive() + } +} diff --git a/watermelon-mini/src/proto/authenticator.rs b/watermelon-mini/src/proto/authenticator.rs new file mode 100644 index 0000000..b4c7e8a --- /dev/null +++ b/watermelon-mini/src/proto/authenticator.rs @@ -0,0 +1,148 @@ +use std::fmt::{self, Debug, Formatter}; + +use watermelon_nkeys::{KeyPair, KeyPairFromSeedError}; +use watermelon_proto::{Connect, ServerAddr, ServerInfo}; + +pub enum AuthenticationMethod { + UserAndPassword { username: String, password: String }, + Creds { jwt: String, nkey: KeyPair }, +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthenticationError { + #[error("missing nonce")] + MissingNonce, +} + +#[derive(Debug, thiserror::Error)] +pub enum CredsParseError { + #[error("contents are truncated")] + Truncated, + #[error("missing closing for JWT")] + MissingJwtClosing, + #[error("missing closing for nkey")] + MissingNkeyClosing, + #[error("missing JWT")] + MissingJwt, + #[error("missing nkey")] + MissingNkey, + #[error("invalid nkey")] + InvalidKey(#[source] KeyPairFromSeedError), +} + +impl AuthenticationMethod { + pub(crate) fn try_from_addr(addr: &ServerAddr) -> Option { + if let (Some(username), Some(password)) = (addr.username(), addr.password()) { + Some(Self::UserAndPassword { + username: username.to_owned(), + password: password.to_owned(), + }) + } else { + None + } + } + + pub(crate) fn prepare_for_auth( + &self, + info: &ServerInfo, + connect: &mut Connect, + ) -> Result<(), AuthenticationError> { + match self { + Self::UserAndPassword { username, password } => { + connect.username = Some(username.clone()); + connect.password = Some(password.clone()); + } + Self::Creds { jwt, nkey } => { + let nonce = info + .nonce + .as_deref() + .ok_or(AuthenticationError::MissingNonce)?; + let signature = nkey.sign(nonce.as_bytes()).to_string(); + + connect.jwt = Some(jwt.clone()); + connect.nkey = Some(nkey.public_key().to_string()); + connect.signature = Some(signature); + } + } + + Ok(()) + } + + /// Creates an `AuthenticationMethod` from the content of a credentials file. + /// + /// # Errors + /// + /// It returns an error if the content is not valid. + pub fn from_creds(contents: &str) -> Result { + let mut jtw = None; + let mut secret = None; + + let mut lines = contents.lines(); + while let Some(line) = lines.next() { + if line == "-----BEGIN NATS USER JWT-----" { + jtw = Some(lines.next().ok_or(CredsParseError::Truncated)?); + + let line = lines.next().ok_or(CredsParseError::Truncated)?; + if line != "------END NATS USER JWT------" { + return Err(CredsParseError::MissingJwtClosing); + } + } else if line == "-----BEGIN USER NKEY SEED-----" { + secret = Some(lines.next().ok_or(CredsParseError::Truncated)?); + + let line = lines.next().ok_or(CredsParseError::Truncated)?; + if line != "------END USER NKEY SEED------" { + return Err(CredsParseError::MissingNkeyClosing); + } + } + } + + let jtw = jtw.ok_or(CredsParseError::MissingJwt)?; + let nkey = secret.ok_or(CredsParseError::MissingNkey)?; + let nkey = KeyPair::from_encoded_seed(nkey).map_err(CredsParseError::InvalidKey)?; + + Ok(Self::Creds { + jwt: jtw.to_owned(), + nkey, + }) + } +} + +impl Debug for AuthenticationMethod { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("AuthenticationMethod") + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::AuthenticationMethod; + + #[test] + fn parse_creds() { + let creds = r"-----BEGIN NATS USER JWT----- +eyJ0eXAiOiJqd3QiLCJhbGciOiJlZDI1NTE5In0.eyJqdGkiOiJUVlNNTEtTWkJBN01VWDNYQUxNUVQzTjRISUw1UkZGQU9YNUtaUFhEU0oyWlAzNkVMNVJBIiwiaWF0IjoxNTU4MDQ1NTYyLCJpc3MiOiJBQlZTQk0zVTQ1REdZRVVFQ0tYUVM3QkVOSFdHN0tGUVVEUlRFSEFKQVNPUlBWV0JaNEhPSUtDSCIsIm5hbWUiOiJvbWVnYSIsInN1YiI6IlVEWEIyVk1MWFBBU0FKN1pEVEtZTlE3UU9DRldTR0I0Rk9NWVFRMjVIUVdTQUY3WlFKRUJTUVNXIiwidHlwZSI6InVzZXIiLCJuYXRzIjp7InB1YiI6e30sInN1YiI6e319fQ.6TQ2ilCDb6m2ZDiJuj_D_OePGXFyN3Ap2DEm3ipcU5AhrWrNvneJryWrpgi_yuVWKo1UoD5s8bxlmwypWVGFAA +------END NATS USER JWT------ + +************************* IMPORTANT ************************* +NKEY Seed printed below can be used to sign and prove identity. +NKEYs are sensitive and should be treated as secrets. + +-----BEGIN USER NKEY SEED----- +SUAOY5JZ2WJKVR4UO2KJ2P3SW6FZFNWEOIMAXF4WZEUNVQXXUOKGM55CYE +------END USER NKEY SEED------ + +*************************************************************"; + + let AuthenticationMethod::Creds { jwt, nkey } = + AuthenticationMethod::from_creds(creds).unwrap() + else { + panic!("invalid auth method"); + }; + assert_eq!(jwt, "eyJ0eXAiOiJqd3QiLCJhbGciOiJlZDI1NTE5In0.eyJqdGkiOiJUVlNNTEtTWkJBN01VWDNYQUxNUVQzTjRISUw1UkZGQU9YNUtaUFhEU0oyWlAzNkVMNVJBIiwiaWF0IjoxNTU4MDQ1NTYyLCJpc3MiOiJBQlZTQk0zVTQ1REdZRVVFQ0tYUVM3QkVOSFdHN0tGUVVEUlRFSEFKQVNPUlBWV0JaNEhPSUtDSCIsIm5hbWUiOiJvbWVnYSIsInN1YiI6IlVEWEIyVk1MWFBBU0FKN1pEVEtZTlE3UU9DRldTR0I0Rk9NWVFRMjVIUVdTQUY3WlFKRUJTUVNXIiwidHlwZSI6InVzZXIiLCJuYXRzIjp7InB1YiI6e30sInN1YiI6e319fQ.6TQ2ilCDb6m2ZDiJuj_D_OePGXFyN3Ap2DEm3ipcU5AhrWrNvneJryWrpgi_yuVWKo1UoD5s8bxlmwypWVGFAA"); + assert_eq!( + nkey.public_key().to_string(), + "SAAO4HKVRO54CIBH7EONLBWD6BYIW2IYHQVZTCCDLU6C2IAX7GBEQGJDYE" + ); + } +} diff --git a/watermelon-mini/src/proto/connection/compression.rs b/watermelon-mini/src/proto/connection/compression.rs new file mode 100644 index 0000000..6350f8d --- /dev/null +++ b/watermelon-mini/src/proto/connection/compression.rs @@ -0,0 +1,106 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[cfg(feature = "non-standard-zstd")] +use crate::non_standard_zstd::ZstdStream; + +#[derive(Debug)] +pub enum ConnectionCompression { + Plain(S), + #[cfg(feature = "non-standard-zstd")] + Zstd(ZstdStream), +} + +impl ConnectionCompression +where + S: AsyncRead + AsyncWrite + Unpin, +{ + #[cfg(feature = "non-standard-zstd")] + pub(crate) fn upgrade_zstd(self) -> Self { + let Self::Plain(socket) = self else { + unreachable!() + }; + + Self::Zstd(ZstdStream::new(socket)) + } + + #[cfg(feature = "non-standard-zstd")] + pub fn is_zstd_compressed(&self) -> bool { + matches!(self, Self::Zstd(_)) + } +} + +impl AsyncRead for ConnectionCompression +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => Pin::new(conn).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ConnectionCompression +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => Pin::new(conn).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_flush(cx), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => Pin::new(conn).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => Pin::new(conn).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Plain(conn) => conn.is_write_vectored(), + #[cfg(feature = "non-standard-zstd")] + Self::Zstd(conn) => conn.is_write_vectored(), + } + } +} diff --git a/watermelon-mini/src/proto/connection/mod.rs b/watermelon-mini/src/proto/connection/mod.rs new file mode 100644 index 0000000..60029a6 --- /dev/null +++ b/watermelon-mini/src/proto/connection/mod.rs @@ -0,0 +1,5 @@ +pub use self::compression::ConnectionCompression; +pub use self::security::ConnectionSecurity; + +mod compression; +mod security; diff --git a/watermelon-mini/src/proto/connection/security.rs b/watermelon-mini/src/proto/connection/security.rs new file mode 100644 index 0000000..7c9d0ce --- /dev/null +++ b/watermelon-mini/src/proto/connection/security.rs @@ -0,0 +1,101 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{client::TlsStream, rustls::pki_types::ServerName, TlsConnector}; + +#[derive(Debug)] +#[expect( + clippy::large_enum_variant, + reason = "using TLS is the recommended thing, we do not want to affect it" +)] +pub enum ConnectionSecurity { + Plain(S), + Tls(TlsStream), +} + +impl ConnectionSecurity +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) async fn upgrade_tls( + self, + connector: &TlsConnector, + domain: ServerName<'static>, + ) -> io::Result { + let conn = match self { + Self::Plain(conn) => conn, + Self::Tls(_) => unreachable!("trying to upgrade to Tls a Tls connection"), + }; + + let conn = connector.connect(domain, conn).await?; + Ok(Self::Tls(conn)) + } +} + +impl AsyncRead for ConnectionSecurity +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_read(cx, buf), + Self::Tls(conn) => Pin::new(conn).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ConnectionSecurity +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_write(cx, buf), + Self::Tls(conn) => Pin::new(conn).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_flush(cx), + Self::Tls(conn) => Pin::new(conn).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_shutdown(cx), + Self::Tls(conn) => Pin::new(conn).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Self::Plain(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), + Self::Tls(conn) => Pin::new(conn).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Plain(conn) => conn.is_write_vectored(), + Self::Tls(conn) => conn.is_write_vectored(), + } + } +} diff --git a/watermelon-mini/src/proto/connector.rs b/watermelon-mini/src/proto/connector.rs new file mode 100644 index 0000000..c7725d3 --- /dev/null +++ b/watermelon-mini/src/proto/connector.rs @@ -0,0 +1,223 @@ +use std::io; + +use tokio::net::TcpStream; +use tokio_rustls::{ + rustls::pki_types::{InvalidDnsNameError, ServerName}, + TlsConnector, +}; +use watermelon_net::{ + connect_tcp, + error::{ConnectionReadError, StreamingReadError}, + proto_connect, Connection, StreamingConnection, +}; +#[cfg(feature = "websocket")] +use watermelon_net::{error::WebsocketReadError, WebsocketConnection}; +#[cfg(feature = "websocket")] +use watermelon_proto::proto::error::FrameDecoderError; +use watermelon_proto::{ + proto::{error::DecoderError, ServerOp}, + Connect, Host, NonStandardConnect, Protocol, ServerAddr, ServerInfo, Transport, +}; + +use crate::{util::MaybeConnection, ConnectFlags, ConnectionCompression}; + +use super::{ + authenticator::{AuthenticationError, AuthenticationMethod}, + connection::ConnectionSecurity, +}; + +#[derive(Debug, thiserror::Error)] +pub enum ConnectError { + #[error("io error")] + Io(#[source] io::Error), + #[error("invalid DNS name")] + InvalidDnsName(#[source] InvalidDnsNameError), + #[error("websocket not supported")] + WebsocketUnsupported, + #[error("unexpected ServerOp")] + UnexpectedServerOp, + #[error("decoder error")] + Decoder(#[source] DecoderError), + #[error("authentication error")] + Authentication(#[source] AuthenticationError), + #[error("connect")] + Connect(#[source] watermelon_net::error::ConnectError), +} + +#[expect(clippy::too_many_lines)] +pub(crate) async fn connect( + connector: &TlsConnector, + addr: &ServerAddr, + client_name: String, + auth_method: Option<&AuthenticationMethod>, + flags: ConnectFlags, +) -> Result< + ( + Connection< + ConnectionCompression>, + ConnectionSecurity, + >, + Box, + ), + ConnectError, +> { + let conn = connect_tcp(addr).await.map_err(ConnectError::Io)?; + conn.set_nodelay(true).map_err(ConnectError::Io)?; + let mut conn = ConnectionSecurity::Plain(conn); + + if matches!(addr.protocol(), Protocol::TLS) { + let domain = rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?; + conn = conn + .upgrade_tls(connector, domain.to_owned()) + .await + .map_err(ConnectError::Io)?; + } + + let mut conn = match addr.transport() { + Transport::TCP => Connection::Streaming(StreamingConnection::new(conn)), + #[cfg(feature = "websocket")] + Transport::Websocket => { + let uri = addr.to_string().parse().unwrap(); + Connection::Websocket( + WebsocketConnection::new(uri, conn) + .await + .map_err(ConnectError::Io)?, + ) + } + #[cfg(not(feature = "websocket"))] + Transport::Websocket => return Err(ConnectError::WebsocketUnsupported), + }; + let info = match conn.read_next().await { + Ok(ServerOp::Info { info }) => info, + Ok(_) => return Err(ConnectError::UnexpectedServerOp), + Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => { + return Err(ConnectError::Io(err)) + } + Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => { + return Err(ConnectError::Decoder(err)) + } + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => { + return Err(ConnectError::Io(err)) + } + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( + FrameDecoderError::Decoder(err), + ))) => return Err(ConnectError::Decoder(err)), + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( + FrameDecoderError::IncompleteFrame, + ))) => todo!(), + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(), + }; + + let conn = match conn { + Connection::Streaming(streaming) => Connection::Streaming( + if matches!( + (addr.protocol(), info.tls_required), + (Protocol::PossiblyPlain, true) + ) { + let domain = + rustls_server_name_from_addr(addr).map_err(ConnectError::InvalidDnsName)?; + StreamingConnection::new( + streaming + .into_inner() + .upgrade_tls(connector, domain.to_owned()) + .await + .map_err(ConnectError::Io)?, + ) + } else { + streaming + }, + ), + Connection::Websocket(websocket) => Connection::Websocket(websocket), + }; + + let auth; + let auth_method = if let Some(auth_method) = auth_method { + Some(auth_method) + } else if let Some(auth_method) = AuthenticationMethod::try_from_addr(addr) { + auth = auth_method; + Some(&auth) + } else { + None + }; + + #[allow(unused_mut)] + let mut non_standard = NonStandardConnect::default(); + #[cfg(feature = "non-standard-zstd")] + if matches!(conn, Connection::Streaming(_)) { + non_standard.zstd = flags.zstd && info.non_standard.zstd; + } + + let mut connect = Connect { + verbose: true, + pedantic: false, + require_tls: false, + auth_token: None, + username: None, + password: None, + client_name: Some(client_name), + client_lang: "rust-watermelon", + client_version: env!("CARGO_PKG_VERSION"), + protocol: 1, + echo: flags.echo, + signature: None, + jwt: None, + supports_no_responders: true, + supports_headers: true, + nkey: None, + non_standard, + }; + if let Some(auth_method) = auth_method { + auth_method + .prepare_for_auth(&info, &mut connect) + .map_err(ConnectError::Authentication)?; + } + + let mut conn = match conn { + Connection::Streaming(streaming) => { + Connection::Streaming(streaming.replace_socket(|stream| { + MaybeConnection(Some(ConnectionCompression::Plain(stream))) + })) + } + Connection::Websocket(websocket) => Connection::Websocket(websocket), + }; + + #[cfg(feature = "non-standard-zstd")] + let zstd = connect.non_standard.zstd; + + proto_connect(&mut conn, connect, |conn| { + #[cfg(feature = "non-standard-zstd")] + match conn { + Connection::Streaming(streaming) => { + if zstd { + let stream = streaming.socket_mut().0.take().unwrap(); + streaming.socket_mut().0 = Some(stream.upgrade_zstd()); + } + } + Connection::Websocket(_websocket) => {} + } + + let _ = conn; + }) + .await + .map_err(ConnectError::Connect)?; + + let conn = match conn { + Connection::Streaming(streaming) => { + Connection::Streaming(streaming.replace_socket(|stream| stream.0.unwrap())) + } + Connection::Websocket(websocket) => Connection::Websocket(websocket), + }; + + Ok((conn, info)) +} + +fn rustls_server_name_from_addr(addr: &ServerAddr) -> Result, InvalidDnsNameError> { + match addr.host() { + Host::Ip(addr) => Ok(ServerName::IpAddress((*addr).into())), + Host::Dns(name) => <_ as AsRef>::as_ref(name).try_into(), + } +} diff --git a/watermelon-mini/src/proto/mod.rs b/watermelon-mini/src/proto/mod.rs new file mode 100644 index 0000000..134ead3 --- /dev/null +++ b/watermelon-mini/src/proto/mod.rs @@ -0,0 +1,8 @@ +pub use self::authenticator::AuthenticationMethod; +pub use self::connection::{ConnectionCompression, ConnectionSecurity}; +pub(crate) use self::connector::connect; +pub use self::connector::ConnectError; + +mod authenticator; +mod connection; +mod connector; diff --git a/watermelon-mini/src/util.rs b/watermelon-mini/src/util.rs new file mode 100644 index 0000000..0fdbc4e --- /dev/null +++ b/watermelon-mini/src/util.rs @@ -0,0 +1,56 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(Debug)] +pub(crate) struct MaybeConnection(pub(crate) Option); + +impl AsyncRead for MaybeConnection +where + S: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(self.0.as_mut().unwrap()).poll_read(cx, buf) + } +} + +impl AsyncWrite for MaybeConnection +where + S: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.0.as_mut().unwrap()).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(self.0.as_mut().unwrap()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.0.as_ref().unwrap().is_write_vectored() + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.0.as_mut().unwrap()).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.0.as_mut().unwrap()).poll_shutdown(cx) + } +} diff --git a/watermelon-net/Cargo.toml b/watermelon-net/Cargo.toml new file mode 100644 index 0000000..e48c688 --- /dev/null +++ b/watermelon-net/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "watermelon-net" +version = "0.1.0" +description = "Low-level NATS Core network implementation" +categories = ["api-bindings", "network-programming"] +keywords = ["nats", "client"] +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[package.metadata.docs.rs] +features = ["websocket", "non-standard-zstd"] + +[dependencies] +tokio = { version = "1", features = ["net", "time", "io-util"] } +futures-util = { version = "0.3.14", default-features = false, features = ["alloc"] } +bytes = "1" + +tokio-websockets = { version = "0.11", features = ["client", "rand"], optional = true } +futures-sink = { version = "0.3.14", default-features = false, optional = true } +http = { version = "1", optional = true } + +watermelon-proto = { version = "0.1", path = "../watermelon-proto" } + +thiserror = "2" +pin-project-lite = "0.2.15" + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt"] } +futures-util = { version = "0.3.14", default-features = false } +claims = "0.8" + +[features] +default = ["aws-lc-rs"] +websocket = ["dep:tokio-websockets", "dep:futures-sink", "dep:http"] +ring = ["tokio-websockets?/ring"] +aws-lc-rs = ["tokio-websockets?/aws-lc-rs"] +fips = ["tokio-websockets?/fips"] +non-standard-zstd = ["watermelon-proto/non-standard-zstd"] + +[lints] +workspace = true diff --git a/watermelon-net/LICENSE-APACHE b/watermelon-net/LICENSE-APACHE new file mode 120000 index 0000000..965b606 --- /dev/null +++ b/watermelon-net/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/watermelon-net/LICENSE-MIT b/watermelon-net/LICENSE-MIT new file mode 120000 index 0000000..76219eb --- /dev/null +++ b/watermelon-net/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/watermelon-net/README.md b/watermelon-net/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/watermelon-net/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/watermelon-net/src/connection/mod.rs b/watermelon-net/src/connection/mod.rs new file mode 100644 index 0000000..71e3ff0 --- /dev/null +++ b/watermelon-net/src/connection/mod.rs @@ -0,0 +1,264 @@ +#[cfg(not(feature = "websocket"))] +use std::{convert::Infallible, marker::PhantomData}; +use std::{ + io, + task::{Context, Poll}, +}; + +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(feature = "websocket")] +use watermelon_proto::proto::error::FrameDecoderError; +use watermelon_proto::{ + error::ServerError, + proto::{error::DecoderError, ClientOp, ServerOp}, + Connect, +}; + +pub use self::streaming::{StreamingConnection, StreamingReadError}; +#[cfg(feature = "websocket")] +pub use self::websocket::{WebsocketConnection, WebsocketReadError}; + +mod streaming; +#[cfg(feature = "websocket")] +mod websocket; + +#[derive(Debug)] +pub enum Connection { + Streaming(StreamingConnection), + Websocket(WebsocketConnection), +} + +#[derive(Debug)] +#[cfg(not(feature = "websocket"))] +#[doc(hidden)] +pub struct WebsocketConnection { + _socket: PhantomData, + _impossible: Infallible, +} + +#[derive(Debug, thiserror::Error)] +pub enum ConnectionReadError { + #[error("streaming connection error")] + Streaming(#[source] StreamingReadError), + #[cfg(feature = "websocket")] + #[error("websocket connection error")] + Websocket(#[source] WebsocketReadError), +} + +impl Connection +where + S1: AsyncRead + AsyncWrite + Unpin, + S2: AsyncRead + AsyncWrite + Unpin, +{ + pub fn poll_read_next( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + match self { + Self::Streaming(streaming) => streaming + .poll_read_next(cx) + .map_err(ConnectionReadError::Streaming), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket + .poll_read_next(cx) + .map_err(ConnectionReadError::Websocket), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + /// Read the next incoming server operation. + /// + /// # Errors + /// + /// Returns an error if reading or decoding the message fails. + pub async fn read_next(&mut self) -> Result { + match self { + Self::Streaming(streaming) => streaming + .read_next() + .await + .map_err(ConnectionReadError::Streaming), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket + .read_next() + .await + .map_err(ConnectionReadError::Websocket), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + pub fn flushes_automatically_when_full(&self) -> bool { + match self { + Self::Streaming(_streaming) => true, + #[cfg(feature = "websocket")] + Self::Websocket(_websocket) => false, + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + pub fn should_flush(&self) -> bool { + match self { + Self::Streaming(streaming) => streaming.may_flush(), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.should_flush(), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + pub fn may_enqueue_more_ops(&mut self) -> bool { + match self { + Self::Streaming(streaming) => streaming.may_enqueue_more_ops(), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.may_enqueue_more_ops(), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + pub fn enqueue_write_op(&mut self, item: &ClientOp) { + match self { + Self::Streaming(streaming) => streaming.enqueue_write_op(item), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.enqueue_write_op(item), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + /// Convenience function for writing enqueued messages and flushing. + /// + /// # Errors + /// + /// Returns an error if writing or flushing fails. + pub async fn write_and_flush(&mut self) -> io::Result<()> { + if let Self::Streaming(streaming) = self { + while streaming.may_write() { + streaming.write_next().await?; + } + } + + self.flush().await + } + + pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Self::Streaming(streaming) => streaming.poll_flush(cx), + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.poll_flush(cx), + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + /// Flush any buffered writes to the connection + /// + /// # Errors + /// + /// Returns an error if flushing fails + pub async fn flush(&mut self) -> io::Result<()> { + match self { + Self::Streaming(streaming) => streaming.flush().await, + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.flush().await, + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } + + /// Shutdown the connection + /// + /// # Errors + /// + /// Returns an error if shutting down the connection fails. + /// Implementations usually ignore this error. + pub async fn shutdown(&mut self) -> io::Result<()> { + match self { + Self::Streaming(streaming) => streaming.shutdown().await, + #[cfg(feature = "websocket")] + Self::Websocket(websocket) => websocket.shutdown().await, + #[cfg(not(feature = "websocket"))] + Self::Websocket(_) => unreachable!(), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ConnectError { + #[error("proto")] + Proto(#[source] DecoderError), + #[error("server")] + ServerError(#[source] ServerError), + #[error("io")] + Io(#[source] io::Error), + #[error("unexpected ServerOp")] + UnexpectedOp, +} + +/// Send the `CONNECT` command to a pre-establised connection `conn`. +/// +/// # Errors +/// +/// Returns an error if connecting fails +pub async fn connect( + conn: &mut Connection, + connect: Connect, + after_connect: F, +) -> Result<(), ConnectError> +where + S1: AsyncRead + AsyncWrite + Unpin, + S2: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut Connection), +{ + conn.enqueue_write_op(&ClientOp::Connect { + connect: Box::new(connect), + }); + conn.write_and_flush().await.map_err(ConnectError::Io)?; + + after_connect(conn); + conn.enqueue_write_op(&ClientOp::Ping); + conn.write_and_flush().await.map_err(ConnectError::Io)?; + + loop { + match conn.read_next().await { + Ok(ServerOp::Success) => { + // Success. Repeat to receive the PONG + } + Ok(ServerOp::Pong) => { + // Success. We've received the PONG, + // possibly after having received OK. + return Ok(()); + } + Ok(ServerOp::Ping) => { + // I guess this could somehow happen. Handle it and repeat + conn.enqueue_write_op(&ClientOp::Pong); + } + Ok(ServerOp::Error { error }) => return Err(ConnectError::ServerError(error)), + Ok(ServerOp::Info { .. } | ServerOp::Message { .. }) => { + return Err(ConnectError::UnexpectedOp); + } + Err(ConnectionReadError::Streaming(StreamingReadError::Decoder(err))) => { + return Err(ConnectError::Proto(err)) + } + Err(ConnectionReadError::Streaming(StreamingReadError::Io(err))) => { + return Err(ConnectError::Io(err)) + } + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( + FrameDecoderError::Decoder(err), + ))) => return Err(ConnectError::Proto(err)), + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Decoder( + FrameDecoderError::IncompleteFrame, + ))) => todo!(), + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Io(err))) => { + return Err(ConnectError::Io(err)) + } + #[cfg(feature = "websocket")] + Err(ConnectionReadError::Websocket(WebsocketReadError::Closed)) => todo!(), + } + } +} diff --git a/watermelon-net/src/connection/streaming.rs b/watermelon-net/src/connection/streaming.rs new file mode 100644 index 0000000..5bf60f2 --- /dev/null +++ b/watermelon-net/src/connection/streaming.rs @@ -0,0 +1,242 @@ +use std::{ + future::{self, Future}, + io, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use bytes::Buf; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use watermelon_proto::proto::{ + error::DecoderError, ClientOp, ServerOp, StreamDecoder, StreamEncoder, +}; + +#[derive(Debug)] +pub struct StreamingConnection { + socket: S, + encoder: StreamEncoder, + decoder: StreamDecoder, + may_flush: bool, +} + +impl StreamingConnection +where + S: AsyncRead + AsyncWrite + Unpin, +{ + #[must_use] + pub fn new(socket: S) -> Self { + Self { + socket, + encoder: StreamEncoder::new(), + decoder: StreamDecoder::new(), + may_flush: false, + } + } + + pub fn poll_read_next( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.decoder.decode() { + Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)), + Ok(None) => {} + Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))), + } + + let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf())); + match read_buf_fut.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(1..)) => {} + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(StreamingReadError::Io( + io::ErrorKind::UnexpectedEof.into(), + ))) + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))), + } + } + } + + /// Reads the next [`ServerOp`]. + /// + /// # Errors + /// + /// It returns an error if the content cannot be decoded or if an I/O error occurs. + pub async fn read_next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_read_next(cx)).await + } + + pub fn may_write(&self) -> bool { + self.encoder.has_remaining() + } + + pub fn may_flush(&self) -> bool { + self.may_flush + } + + pub fn may_enqueue_more_ops(&self) -> bool { + self.encoder.remaining() < 8_290_304 + } + + pub fn enqueue_write_op(&mut self, item: &ClientOp) { + self.encoder.enqueue_write_op(item); + } + + pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.encoder.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let write_outcome = if self.socket.is_write_vectored() { + let mut bufs = [io::IoSlice::new(&[]); 64]; + let n = self.encoder.chunks_vectored(&mut bufs); + debug_assert!(n > 0); + + Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n]) + } else { + Pin::new(&mut self.socket).poll_write(cx, self.encoder.chunk()) + }; + + match write_outcome { + Poll::Pending => { + self.may_flush = false; + Poll::Pending + } + Poll::Ready(Ok(n)) => { + self.encoder.advance(n); + self.may_flush = true; + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } + } + + /// Writes the next chunk of data to the socket. + /// + /// It returns the number of bytes that have been written. + /// + /// # Errors + /// + /// An I/O error is returned if it is not possible to write to the socket. + pub async fn write_next(&mut self) -> io::Result { + future::poll_fn(|cx| self.poll_write_next(cx)).await + } + + pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.socket).poll_flush(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + self.may_flush = false; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } + } + + /// Flush any buffered writes to the connection + /// + /// # Errors + /// + /// Returns an error if flushing fails + pub async fn flush(&mut self) -> io::Result<()> { + future::poll_fn(|cx| self.poll_flush(cx)).await + } + + /// Shutdown the connection + /// + /// # Errors + /// + /// Returns an error if shutting down the connection fails. + /// Implementations usually ignore this error. + pub async fn shutdown(&mut self) -> io::Result<()> { + future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await + } + + pub fn socket(&self) -> &S { + &self.socket + } + + pub fn socket_mut(&mut self) -> &mut S { + &mut self.socket + } + + pub fn replace_socket(self, replacer: F) -> StreamingConnection + where + F: FnOnce(S) -> S2, + { + StreamingConnection { + socket: replacer(self.socket), + encoder: self.encoder, + decoder: self.decoder, + may_flush: self.may_flush, + } + } + + pub fn into_inner(self) -> S { + self.socket + } +} + +#[derive(Debug, thiserror::Error)] +pub enum StreamingReadError { + #[error("decoder")] + Decoder(#[source] DecoderError), + #[error("io")] + Io(#[source] io::Error), +} + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + use claims::assert_matches; + use futures_util::task; + use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; + use watermelon_proto::proto::{ClientOp, ServerOp}; + + use super::StreamingConnection; + + #[test] + fn ping_pong() { + let waker = task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let (socket, mut conn) = io::duplex(1024); + + let mut client = StreamingConnection::new(socket); + + // Initial state is ok + assert!(client.poll_read_next(&mut cx).is_pending()); + assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(0))); + + let mut buf = [0; 1024]; + let mut read_buf = ReadBuf::new(&mut buf); + assert!(Pin::new(&mut conn) + .poll_read(&mut cx, &mut read_buf) + .is_pending()); + + // Write PING and verify it was received + client.enqueue_write_op(&ClientOp::Ping); + assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(6))); + assert_matches!( + Pin::new(&mut conn).poll_read(&mut cx, &mut read_buf), + Poll::Ready(Ok(())) + ); + assert_eq!(read_buf.filled(), b"PING\r\n"); + + // Receive PONG + assert_matches!( + Pin::new(&mut conn).poll_write(&mut cx, b"PONG\r\n"), + Poll::Ready(Ok(6)) + ); + assert_matches!( + client.poll_read_next(&mut cx), + Poll::Ready(Ok(ServerOp::Pong)) + ); + assert!(client.poll_read_next(&mut cx).is_pending()); + } +} diff --git a/watermelon-net/src/connection/websocket.rs b/watermelon-net/src/connection/websocket.rs new file mode 100644 index 0000000..481fc48 --- /dev/null +++ b/watermelon-net/src/connection/websocket.rs @@ -0,0 +1,143 @@ +use std::{ + future, io, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_sink::Sink; +use futures_util::{task::noop_waker_ref, Stream}; +use http::Uri; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_websockets::{ClientBuilder, Message, WebSocketStream}; +use watermelon_proto::proto::{ + decode_frame, error::FrameDecoderError, ClientOp, FramedEncoder, ServerOp, +}; + +#[derive(Debug)] +pub struct WebsocketConnection { + socket: WebSocketStream, + encoder: FramedEncoder, + residual_frame: Bytes, + should_flush: bool, +} + +impl WebsocketConnection +where + S: AsyncRead + AsyncWrite + Unpin, +{ + /// Construct a websocket stream to a pre-established connection `socket`. + /// + /// # Errors + /// + /// Returns an error if the websocket handshake fails. + pub async fn new(uri: Uri, socket: S) -> io::Result { + let (socket, _resp) = ClientBuilder::from_uri(uri) + .connect_on(socket) + .await + .map_err(websockets_error_to_io)?; + Ok(Self { + socket, + encoder: FramedEncoder::new(), + residual_frame: Bytes::new(), + should_flush: false, + }) + } + + pub fn poll_read_next( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + if !self.residual_frame.is_empty() { + return Poll::Ready( + decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder), + ); + } + + match Pin::new(&mut self.socket).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(message))) if message.is_binary() => { + self.residual_frame = message.into_payload().into(); + } + Poll::Ready(Some(Ok(_message))) => {} + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err)))) + } + Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)), + } + } + } + + /// Reads the next [`ServerOp`]. + /// + /// # Errors + /// + /// It returns an error if the content cannot be decoded or if an I/O error occurs. + pub async fn read_next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_read_next(cx)).await + } + + pub fn should_flush(&self) -> bool { + self.should_flush + } + + pub fn may_enqueue_more_ops(&mut self) -> bool { + let mut cx = Context::from_waker(noop_waker_ref()); + Pin::new(&mut self.socket).poll_ready(&mut cx).is_ready() + } + + /// Enqueue `item` to be written. + #[expect(clippy::missing_panics_doc)] + pub fn enqueue_write_op(&mut self, item: &ClientOp) { + let payload = self.encoder.encode(item); + Pin::new(&mut self.socket) + .start_send(Message::binary(payload)) + .unwrap(); + self.should_flush = true; + } + + pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.socket) + .poll_flush(cx) + .map_err(websockets_error_to_io) + } + + /// Flush any buffered writes to the connection + /// + /// # Errors + /// + /// Returns an error if flushing fails + pub async fn flush(&mut self) -> io::Result<()> { + future::poll_fn(|cx| self.poll_flush(cx)).await + } + + /// Shutdown the connection + /// + /// # Errors + /// + /// Returns an error if shutting down the connection fails. + /// Implementations usually ignore this error. + pub async fn shutdown(&mut self) -> io::Result<()> { + future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx)) + .await + .map_err(websockets_error_to_io) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum WebsocketReadError { + #[error("decoder")] + Decoder(#[source] FrameDecoderError), + #[error("io")] + Io(#[source] io::Error), + #[error("closed")] + Closed, +} + +fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error { + match err { + tokio_websockets::Error::Io(err) => err, + err => io::Error::new(io::ErrorKind::Other, err), + } +} diff --git a/watermelon-net/src/happy_eyeballs.rs b/watermelon-net/src/happy_eyeballs.rs new file mode 100644 index 0000000..955b3e3 --- /dev/null +++ b/watermelon-net/src/happy_eyeballs.rs @@ -0,0 +1,213 @@ +use std::{ + future::Future, + io, + net::SocketAddr, + pin::{pin, Pin}, + task::{Context, Poll}, + time::Duration, +}; + +use futures_util::{ + stream::{self, FusedStream, FuturesUnordered}, + Stream, StreamExt, +}; +use pin_project_lite::pin_project; +use tokio::{ + net::{self, TcpStream}, + time::{self, Sleep}, +}; +use watermelon_proto::{Host, ServerAddr}; + +const CONN_ATTEMPT_DELAY: Duration = Duration::from_millis(250); + +/// Connects to an address and returns a [`TcpStream`]. +/// +/// If the given address is an ip, this just uses [`TcpStream::connect`]. Otherwise, if a host is +/// given, the [Happy Eyeballs] protocol is being used. +/// +/// [Happy Eyeballs]: https://en.wikipedia.org/wiki/Happy_Eyeballs +/// +/// # Errors +/// +/// It returns an error if it is not possible to connect to any host. +pub async fn connect(addr: &ServerAddr) -> io::Result { + match addr.host() { + Host::Ip(ip) => TcpStream::connect(SocketAddr::new(*ip, addr.port())).await, + Host::Dns(host) => { + let host = <_ as AsRef>::as_ref(host); + let addrs = net::lookup_host(format!("{}:{}", host, addr.port())).await?; + + let mut happy_eyeballs = pin!(HappyEyeballs::new(stream::iter(addrs))); + let mut last_err = None; + loop { + match happy_eyeballs.next().await { + Some(Ok(conn)) => return Ok(conn), + Some(Err(err)) => last_err = Some(err), + None => { + return Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })); + } + } + } + } + } +} + +pin_project! { + #[project = HappyEyeballsProj] + struct HappyEyeballs { + dns: Option, + dns_received: Vec, + connecting: FuturesUnordered< + Pin> + Send + Sync + 'static>>, + >, + last_attempted: Option, + #[pin] + next_attempt_delay: Option, + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum LastAttempted { + Ipv4, + Ipv6, +} + +impl HappyEyeballs { + fn new(dns: D) -> Self { + Self { + dns: Some(dns), + dns_received: Vec::new(), + connecting: FuturesUnordered::new(), + last_attempted: None, + next_attempt_delay: None, + } + } +} + +impl HappyEyeballsProj<'_, D> { + fn next_dns_record(&mut self) -> Option { + if self.dns_received.is_empty() { + return None; + } + + let next_kind = self + .last_attempted + .map_or(LastAttempted::Ipv6, LastAttempted::opposite); + for i in 0..self.dns_received.len() { + if LastAttempted::from_addr(self.dns_received[i]) == next_kind { + *self.last_attempted = Some(next_kind); + return Some(self.dns_received.remove(i)); + } + } + + let record = self.dns_received.remove(0); + *self.last_attempted = Some(LastAttempted::from_addr(record)); + Some(record) + } +} + +impl Stream for HappyEyeballs +where + D: Stream + Unpin, +{ + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + let mut dead_end = true; + + while let Some(dns) = &mut this.dns { + match Pin::new(&mut *dns).poll_next(cx) { + Poll::Pending => { + dead_end = false; + break; + } + Poll::Ready(Some(record)) => { + dead_end = false; + this.dns_received.push(record); + } + Poll::Ready(None) => *this.dns = None, + } + } + + loop { + match Pin::new(&mut this.connecting).poll_next(cx) { + Poll::Pending => dead_end = false, + Poll::Ready(Some(maybe_conn)) => return Poll::Ready(Some(maybe_conn)), + Poll::Ready(None) => {} + } + + let make_new_attempt = if this.connecting.is_empty() { + true + } else if let Some(next_attempt_delay) = this.next_attempt_delay.as_mut().as_pin_mut() { + match next_attempt_delay.poll(cx) { + Poll::Pending => false, + Poll::Ready(()) => { + this.next_attempt_delay.set(None); + true + } + } + } else { + true + }; + if !make_new_attempt { + break; + } + + match this.next_dns_record() { + Some(record) => { + let conn_fut = TcpStream::connect(record); + this.connecting.push(Box::pin(conn_fut)); + this.next_attempt_delay + .set(Some(time::sleep(CONN_ATTEMPT_DELAY))); + } + None => break, + } + } + + if dead_end { + Poll::Ready(None) + } else { + Poll::Pending + } + } + + fn size_hint(&self) -> (usize, Option) { + let (mut len, mut max) = self.dns.as_ref().map_or((0, Some(0)), Stream::size_hint); + len = len.saturating_add(self.dns_received.len() + self.connecting.len()); + if let Some(max) = &mut max { + *max = max.saturating_add(self.dns_received.len() + self.connecting.len()); + } + (len, max) + } +} + +impl FusedStream for HappyEyeballs +where + D: Stream + Unpin, +{ + fn is_terminated(&self) -> bool { + self.dns.is_none() && self.dns_received.is_empty() && self.connecting.is_empty() + } +} + +impl LastAttempted { + fn from_addr(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(_) => Self::Ipv4, + SocketAddr::V6(_) => Self::Ipv6, + } + } + + fn opposite(self) -> Self { + match self { + Self::Ipv4 => Self::Ipv6, + Self::Ipv6 => Self::Ipv4, + } + } +} diff --git a/watermelon-net/src/lib.rs b/watermelon-net/src/lib.rs new file mode 100644 index 0000000..57c9243 --- /dev/null +++ b/watermelon-net/src/lib.rs @@ -0,0 +1,13 @@ +#[cfg(feature = "websocket")] +pub use self::connection::WebsocketConnection; +pub use self::connection::{connect as proto_connect, Connection, StreamingConnection}; +pub use self::happy_eyeballs::connect as connect_tcp; + +mod connection; +mod happy_eyeballs; + +pub mod error { + #[cfg(feature = "websocket")] + pub use super::connection::WebsocketReadError; + pub use super::connection::{ConnectError, ConnectionReadError, StreamingReadError}; +} diff --git a/watermelon-nkeys/Cargo.toml b/watermelon-nkeys/Cargo.toml new file mode 100644 index 0000000..ae09e0a --- /dev/null +++ b/watermelon-nkeys/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "watermelon-nkeys" +version = "0.1.0" +description = "Minimal NKeys implementation for NATS client authentication" +categories = ["parser-implementations", "cryptography"] +keywords = ["nats", "nkey"] +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +aws-lc-rs = { version = "1.12.2", default-features = false, features = ["aws-lc-sys", "prebuilt-nasm"], optional = true } +ring = { version = "0.17", optional = true } +crc = "3.2.1" +thiserror = "2" +data-encoding = { version = "2.7.0", default-features = false } + +[features] +default = ["aws-lc-rs"] +aws-lc-rs = ["dep:aws-lc-rs"] +ring = ["dep:ring"] +fips = ["aws-lc-rs", "aws-lc-rs/fips"] + +[lints] +workspace = true diff --git a/watermelon-nkeys/LICENSE-APACHE b/watermelon-nkeys/LICENSE-APACHE new file mode 120000 index 0000000..965b606 --- /dev/null +++ b/watermelon-nkeys/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/watermelon-nkeys/LICENSE-MIT b/watermelon-nkeys/LICENSE-MIT new file mode 120000 index 0000000..76219eb --- /dev/null +++ b/watermelon-nkeys/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/watermelon-nkeys/README.md b/watermelon-nkeys/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/watermelon-nkeys/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/watermelon-nkeys/src/crc.rs b/watermelon-nkeys/src/crc.rs new file mode 100644 index 0000000..17ca01b --- /dev/null +++ b/watermelon-nkeys/src/crc.rs @@ -0,0 +1,25 @@ +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct Crc16(u16); + +impl Crc16 { + pub(crate) fn compute(buf: &[u8]) -> Self { + Self(crc::Crc::::new(&crc::CRC_16_XMODEM).checksum(buf)) + } + + pub(crate) fn from_raw_encoded(val: [u8; 2]) -> Self { + Self::from_raw(u16::from_le_bytes(val)) + } + + pub(crate) fn from_raw(val: u16) -> Self { + Self(val) + } + + #[expect(dead_code)] + pub(crate) fn to_raw(&self) -> u16 { + self.0 + } + + pub(crate) fn to_raw_encoded(&self) -> [u8; 2] { + self.0.to_le_bytes() + } +} diff --git a/watermelon-nkeys/src/lib.rs b/watermelon-nkeys/src/lib.rs new file mode 100644 index 0000000..bb49db9 --- /dev/null +++ b/watermelon-nkeys/src/lib.rs @@ -0,0 +1,4 @@ +pub use self::seed::{KeyPair, KeyPairFromSeedError, PublicKey}; + +mod crc; +mod seed; diff --git a/watermelon-nkeys/src/seed.rs b/watermelon-nkeys/src/seed.rs new file mode 100644 index 0000000..12f8a8a --- /dev/null +++ b/watermelon-nkeys/src/seed.rs @@ -0,0 +1,131 @@ +use std::fmt::{self, Display}; + +#[cfg(feature = "aws-lc-rs")] +use aws_lc_rs::{ + self as crypto_provider, + signature::{Ed25519KeyPair, KeyPair as _}, +}; +use data_encoding::{BASE32_NOPAD, BASE64URL_NOPAD}; +#[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))] +use ring::{ + self as crypto_provider, + signature::{Ed25519KeyPair, KeyPair as _}, +}; + +#[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))] +compile_error!("Please enable the `aws-lc-rs` or the `ring` feature"); + +use crate::crc::Crc16; + +const SEED_PREFIX_BYTE: u8 = 18 << 3; + +/// A `NKey` private/public key pair. +#[derive(Debug)] +pub struct KeyPair { + kind: u8, + key: Ed25519KeyPair, +} + +/// The public key within an `NKey` private/public key pair. +#[derive(Debug)] +pub struct PublicKey<'a>(&'a KeyPair); + +/// An error encountered while decoding an `NKey`. +#[derive(Debug, thiserror::Error)] +pub enum KeyPairFromSeedError { + /// The string rapresentation of the seed has an invalid length. + #[error("invalid length of the seed's string the string rapresentation")] + InvalidSeedLength, + /// The string rapresentation of the seed contains characters that are not part of the base32 dictionary. + #[error("the seed contains non-base32 characters")] + InvalidBase32, + /// The decoded base32 rapresentation of the seed has an invalid length. + #[error("invalid base32 decoded seed length")] + InvalidRawSeedLength, + /// The CRC does not match the crc calculated for the seed payload. + #[error("invalid CRC")] + BadCrc, + /// The prefix for the seed is invalid + #[error("invalid seed prefix")] + InvalidPrefix, + /// the seed could not be decoded by the crypto backend + #[error("")] + DecodeError, +} + +pub struct Signature(crypto_provider::signature::Signature); + +impl KeyPair { + /// Decode a key from an `NKey` seed. + /// + /// # Errors + /// + /// Returns an error if `seed` is invalid. + #[expect( + clippy::missing_panics_doc, + reason = "the array `TryInto` calls cannot panic" + )] + pub fn from_encoded_seed(seed: &str) -> Result { + if seed.len() != 58 { + return Err(KeyPairFromSeedError::InvalidSeedLength); + } + + let mut full_raw_seed = [0; 36]; + let len = BASE32_NOPAD + .decode_mut(seed.as_bytes(), &mut full_raw_seed) + .map_err(|_| KeyPairFromSeedError::InvalidBase32)?; + if len != full_raw_seed.len() { + return Err(KeyPairFromSeedError::InvalidRawSeedLength); + } + + let (raw_seed, crc) = full_raw_seed.split_at(full_raw_seed.len() - 2); + let raw_seed_crc = Crc16::compute(raw_seed); + let expected_crc = Crc16::from_raw_encoded(crc.try_into().unwrap()); + if raw_seed_crc != expected_crc { + return Err(KeyPairFromSeedError::BadCrc); + } + + Self::from_raw_seed(raw_seed.try_into().unwrap()) + } + + fn from_raw_seed(raw_seed: [u8; 34]) -> Result { + if raw_seed[0] & 248 != SEED_PREFIX_BYTE { + println!("{:x}", raw_seed[0]); + return Err(KeyPairFromSeedError::InvalidPrefix); + } + + let kind = raw_seed[1]; + + let key = Ed25519KeyPair::from_seed_unchecked(&raw_seed[2..]) + .map_err(|_| KeyPairFromSeedError::DecodeError)?; + Ok(Self { kind, key }) + } + + #[must_use] + pub fn public_key(&self) -> PublicKey<'_> { + PublicKey(self) + } + + #[must_use] + pub fn sign(&self, buf: &[u8]) -> Signature { + Signature(self.key.sign(buf)) + } +} + +impl Display for Signature { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&BASE64URL_NOPAD.encode_display(self.0.as_ref()), f) + } +} + +impl Display for PublicKey<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut full_raw_seed = [0; 36]; + full_raw_seed[0] = SEED_PREFIX_BYTE; + full_raw_seed[1] = self.0.kind; + full_raw_seed[2..34].copy_from_slice(self.0.key.public_key().as_ref()); + let crc = Crc16::compute(&full_raw_seed[..34]); + full_raw_seed[34..36].copy_from_slice(&crc.to_raw_encoded()); + Display::fmt(&BASE32_NOPAD.encode_display(&full_raw_seed), f) + } +} diff --git a/watermelon-proto/Cargo.toml b/watermelon-proto/Cargo.toml new file mode 100644 index 0000000..173e834 --- /dev/null +++ b/watermelon-proto/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "watermelon-proto" +version = "0.1.0" +description = "#[no_std] NATS Core Sans-IO protocol implementation" +categories = ["network-programming", "parser-implementations", "no-std"] +keywords = ["nats", "client"] +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[package.metadata.docs.rs] +features = ["non-standard-zstd"] + +[dependencies] +bytes = { version = "1", default-features = false } +bytestring = { version = "1", default-features = false, features = ["serde"] } +url = { version = "2.5.3", default-features = false, features = ["serde"] } +percent-encoding = { version = "2", default-features = false, features = ["alloc"] } +memchr = { version = "2.4", default-features = false } +unicase = "2.7" + +serde = { version = "1.0.107", default-features = false, features = ["derive"] } +serde_json = { version = "1", default-features = false, features = ["alloc"] } + +thiserror = { version = "2", default-features = false } + +[dev-dependencies] +claims = "0.8" + +[features] +default = ["std"] +std = ["bytes/std", "url/std", "percent-encoding/std", "memchr/std", "serde/std", "serde_json/std", "thiserror/std"] +non-standard-zstd = [] + +[lints] +workspace = true diff --git a/watermelon-proto/LICENSE-APACHE b/watermelon-proto/LICENSE-APACHE new file mode 120000 index 0000000..965b606 --- /dev/null +++ b/watermelon-proto/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/watermelon-proto/LICENSE-MIT b/watermelon-proto/LICENSE-MIT new file mode 120000 index 0000000..76219eb --- /dev/null +++ b/watermelon-proto/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/watermelon-proto/README.md b/watermelon-proto/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/watermelon-proto/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/watermelon-proto/src/connect.rs b/watermelon-proto/src/connect.rs new file mode 100644 index 0000000..25b91c1 --- /dev/null +++ b/watermelon-proto/src/connect.rs @@ -0,0 +1,63 @@ +use alloc::string::String; + +use serde::Serialize; + +#[derive(Debug, Serialize)] +#[allow(clippy::struct_excessive_bools)] +pub struct Connect { + pub verbose: bool, + pub pedantic: bool, + #[serde(rename = "tls_required")] + pub require_tls: bool, + pub auth_token: Option, + #[serde(rename = "user")] + pub username: Option, + #[serde(rename = "pass")] + pub password: Option, + #[serde(rename = "name")] + pub client_name: Option, + #[serde(rename = "lang")] + pub client_lang: &'static str, + #[serde(rename = "version")] + pub client_version: &'static str, + pub protocol: u8, + pub echo: bool, + #[serde(rename = "sig")] + pub signature: Option, + pub jwt: Option, + #[serde(rename = "no_responders")] + pub supports_no_responders: bool, + #[serde(rename = "headers")] + pub supports_headers: bool, + pub nkey: Option, + + #[serde(flatten)] + pub non_standard: NonStandardConnect, +} + +#[derive(Debug, Serialize)] +#[non_exhaustive] +pub struct NonStandardConnect { + #[cfg(feature = "non-standard-zstd")] + #[serde( + rename = "m4ss_zstd", + skip_serializing_if = "skip_serializing_if_false" + )] + pub zstd: bool, +} + +#[allow(clippy::derivable_impls)] +impl Default for NonStandardConnect { + fn default() -> Self { + Self { + #[cfg(feature = "non-standard-zstd")] + zstd: false, + } + } +} + +#[cfg(feature = "non-standard-zstd")] +#[allow(clippy::trivially_copy_pass_by_ref)] +fn skip_serializing_if_false(val: &bool) -> bool { + !*val +} diff --git a/watermelon-proto/src/headers/map.rs b/watermelon-proto/src/headers/map.rs new file mode 100644 index 0000000..bc301cc --- /dev/null +++ b/watermelon-proto/src/headers/map.rs @@ -0,0 +1,303 @@ +use alloc::{ + collections::{btree_map::Entry, BTreeMap}, + vec, + vec::Vec, +}; +use core::{iter, mem, slice}; + +use super::{HeaderName, HeaderValue}; + +/// A set of NATS headers +/// +/// [`HeaderMap`] is a multimap of [`HeaderName`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HeaderMap { + headers: BTreeMap, + len: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum OneOrMany { + One(HeaderValue), + Many(Vec), +} + +impl HeaderMap { + /// Create an empty `HeaderMap` + /// + /// The map will be created without any capacity. This function will not allocate. + /// + /// Consider using the [`FromIterator`], [`Extend`] implementations if the final + /// length is known upfront. + #[must_use] + pub const fn new() -> Self { + Self { + headers: BTreeMap::new(), + len: 0, + } + } + + pub fn insert(&mut self, name: HeaderName, value: HeaderValue) { + if let Some(prev) = self.headers.insert(name, OneOrMany::One(value)) { + self.len -= prev.len(); + } + self.len += 1; + } + + pub fn append(&mut self, name: HeaderName, value: HeaderValue) { + match self.headers.entry(name) { + Entry::Vacant(vacant) => { + vacant.insert(OneOrMany::One(value)); + } + Entry::Occupied(mut occupied) => { + occupied.get_mut().push(value); + } + } + self.len += 1; + } + + pub fn remove(&mut self, name: &HeaderName) { + if let Some(prev) = self.headers.remove(name) { + self.len -= prev.len(); + } + } + + /// Returns the number of keys stored in the map + /// + /// This number will be less than or equal to [`HeaderMap::len`]. + #[must_use] + pub fn keys_len(&self) -> usize { + self.headers.len() + } + + /// Returns the number of headers stored in the map + /// + /// This number represents the total number of **values** stored in the map. + /// This number can be greater than or equal to the number of **keys** stored. + #[must_use] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if the map contains no elements + #[must_use] + pub fn is_empty(&self) -> bool { + self.headers.is_empty() + } + + /// Clear the map, removing all key-value pairs. Keeps the allocated memory for reuse + pub fn clear(&mut self) { + self.headers.clear(); + self.len = 0; + } + + #[cfg(test)] + fn keys(&self) -> impl Iterator { + self.headers.keys() + } + + pub(crate) fn iter( + &self, + ) -> impl DoubleEndedIterator)> + { + self.headers + .iter() + .map(|(name, value)| (name, value.iter())) + } +} + +impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap { + fn from_iter>(iter: I) -> Self { + let mut this = Self::new(); + this.extend(iter); + this + } +} + +impl Extend<(HeaderName, HeaderValue)> for HeaderMap { + fn extend>(&mut self, iter: T) { + iter.into_iter().for_each(|(name, value)| { + self.append(name, value); + }); + } +} + +impl Default for HeaderMap { + fn default() -> Self { + Self::new() + } +} + +impl OneOrMany { + fn len(&self) -> usize { + match self { + Self::One(_) => 1, + Self::Many(vec) => vec.len(), + } + } + + fn push(&mut self, item: HeaderValue) { + match self { + Self::One(current_item) => { + let current_item = + mem::replace(current_item, HeaderValue::from_static("replacing")); + *self = Self::Many(vec![current_item, item]); + } + Self::Many(vec) => { + debug_assert!(!vec.is_empty(), "OneOrMany can't be empty"); + vec.push(item); + } + } + } + + fn iter(&self) -> impl Iterator { + enum Either<'a> { + A(iter::Once<&'a HeaderValue>), + B(slice::Iter<'a, HeaderValue>), + } + + impl<'a> Iterator for Either<'a> { + type Item = &'a HeaderValue; + + fn next(&mut self) -> Option { + match self { + Self::A(a) => a.next(), + Self::B(b) => b.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + Self::A(a) => a.size_hint(), + Self::B(b) => b.size_hint(), + } + } + + fn last(mut self) -> Option { + self.next_back() + } + + fn fold(self, init: B, f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + match self { + Self::A(a) => a.fold(init, f), + Self::B(b) => b.fold(init, f), + } + } + } + + impl DoubleEndedIterator for Either<'_> { + fn next_back(&mut self) -> Option { + match self { + Self::A(a) => a.next_back(), + Self::B(b) => b.next_back(), + } + } + + fn rfold(self, init: B, f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + match self { + Self::A(a) => a.rfold(init, f), + Self::B(b) => b.rfold(init, f), + } + } + } + + match self { + Self::One(one) => Either::A(iter::once(one)), + Self::Many(many) => Either::B(many.iter()), + } + } +} + +#[cfg(test)] +mod tests { + use alloc::{vec, vec::Vec}; + + use crate::headers::{HeaderName, HeaderValue}; + + use super::HeaderMap; + + #[test] + fn manual() { + let mut headers = HeaderMap::new(); + headers.append( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("abcd"), + ); + headers.append( + HeaderName::from_static("Nats-Sequence"), + HeaderValue::from_static("1"), + ); + headers.append( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("1234"), + ); + headers.append( + HeaderName::from_static("Nats-Time-Stamp"), + HeaderValue::from_static("0"), + ); + headers.remove(&HeaderName::from_static("Nats-Time-Stamp")); + + verify_header_map(&headers); + } + + #[test] + fn collect() { + let headers = [ + ( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("abcd"), + ), + ( + HeaderName::from_static("Nats-Sequence"), + HeaderValue::from_static("1"), + ), + ( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("1234"), + ), + ] + .into_iter() + .collect::(); + + verify_header_map(&headers); + } + + fn verify_header_map(headers: &HeaderMap) { + assert_eq!( + [ + HeaderName::from_static("Nats-Message-Id"), + HeaderName::from_static("Nats-Sequence") + ] + .as_slice(), + headers.keys().cloned().collect::>().as_slice() + ); + + let raw_headers = headers + .iter() + .map(|(name, values)| (name.clone(), values.cloned().collect::>())) + .collect::>(); + assert_eq!( + [ + ( + HeaderName::from_static("Nats-Message-Id"), + vec![ + HeaderValue::from_static("abcd"), + HeaderValue::from_static("1234") + ] + ), + ( + HeaderName::from_static("Nats-Sequence"), + vec![HeaderValue::from_static("1")] + ), + ] + .as_slice(), + raw_headers.as_slice(), + ); + } +} diff --git a/watermelon-proto/src/headers/mod.rs b/watermelon-proto/src/headers/mod.rs new file mode 100644 index 0000000..5807ee0 --- /dev/null +++ b/watermelon-proto/src/headers/mod.rs @@ -0,0 +1,12 @@ +pub use self::map::HeaderMap; +pub use self::name::HeaderName; +pub use self::value::HeaderValue; + +mod map; +mod name; +mod value; + +pub mod error { + pub use super::name::HeaderNameValidateError; + pub use super::value::HeaderValueValidateError; +} diff --git a/watermelon-proto/src/headers/name.rs b/watermelon-proto/src/headers/name.rs new file mode 100644 index 0000000..b626d90 --- /dev/null +++ b/watermelon-proto/src/headers/name.rs @@ -0,0 +1,203 @@ +use alloc::string::String; +use core::{ + fmt::{self, Display}, + ops::Deref, +}; +use unicase::UniCase; + +use bytestring::ByteString; + +/// A string that can be used to represent an header name +/// +/// `HeaderName` contains a string that is guaranteed [^1] to +/// contain a valid header name that meets the following requirements: +/// +/// * The value is not empty +/// * The value has a length less than or equal to 64 [^2] +/// * The value does not contain any whitespace characters or `:` +/// +/// `HeaderName` can be constructed from [`HeaderName::from_static`] +/// or any of the `TryFrom` implementations. +/// +/// [^1]: Because [`HeaderName::from_dangerous_value`] is safe to call, +/// unsafe code must not assume any of the above invariants. +/// [^2]: Messages coming from the NATS server are allowed to violate this rule. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct HeaderName(UniCase); + +impl HeaderName { + /// Client-defined unique identifier for a message that will be used by the server apply de-duplication within the configured Jetstream _Duplicate Window_ + pub const MESSAGE_ID: Self = Self::new_internal("Nats-Msg-Id"); + /// Have Jetstream assert that the published message is received by the expected stream + pub const EXPECTED_STREAM: Self = Self::new_internal("Nats-Expected-Stream"); + /// Have Jetstream assert that the last expected [`HeaderName::MESSAGE_ID`] matches this ID + pub const EXPECTED_LAST_MESSAGE_ID: Self = Self::new_internal("Nats-Expected-Last-Msg-Id"); + /// Have Jetstream assert that the last sequence ID matches this ID + pub const EXPECTED_LAST_SEQUENCE: Self = Self::new_internal("Nats-Expected-Last-Sequence"); + /// Purge all prior messages in the stream (`all` value) or at the subject-level (`sub` value) + pub const ROLLUP: Self = Self::new_internal("Nats-Rollup"); + + /// Name of the stream the message was republished from + pub const STREAM: Self = Self::new_internal("Nats-Stream"); + /// Original subject to which the message was republished from + pub const SUBJECT: Self = Self::new_internal("Nats-Subject"); + /// Original sequence ID the message was republished from + pub const SEQUENCE: Self = Self::new_internal("Nats-Sequence"); + /// Last sequence ID of the message having the same subject, or zero if this is the first message for the subject + pub const LAST_SEQUENCE: Self = Self::new_internal("Nats-Last-Sequence"); + /// The original RFC3339 timestamp of the message + pub const TIMESTAMP: Self = Self::new_internal("Nats-Time-Stamp"); + + /// Origin stream name, subject, sequence number, subject filter and destination transform of the message being sourced + pub const STREAM_SOURCE: Self = Self::new_internal("Nats-Stream-Source"); + + /// Size of the message payload in bytes for an headers-only message + pub const MESSAGE_SIZE: Self = Self::new_internal("Nats-Msg-Size"); + + /// Construct `HeaderName` from a static string + /// + /// # Panics + /// + /// Will panic if `value` isn't a valid `HeaderName` + #[must_use] + pub fn from_static(value: &'static str) -> Self { + Self::try_from(ByteString::from_static(value)).expect("invalid HeaderName") + } + + /// Construct a `HeaderName` from a string, without checking invariants + /// + /// This method bypasses invariants checks implemented by [`HeaderName::from_static`] + /// and all `TryFrom` implementations. + /// + /// # Security + /// + /// While calling this method can eliminate the runtime performance cost of + /// checking the string, constructing `HeaderName` with an invalid string and + /// then calling the NATS server with it can cause serious security issues. + /// When in doubt use the [`HeaderName::from_static`] or any of the `TryFrom` + /// implementations. + #[expect( + clippy::missing_panics_doc, + reason = "The header validation is only made in debug" + )] + #[must_use] + pub fn from_dangerous_value(value: ByteString) -> Self { + if cfg!(debug_assertions) { + if let Err(err) = validate_header_name(&value) { + panic!("HeaderName {value:?} isn't valid {err:?}"); + } + } + Self(UniCase::new(value)) + } + + const fn new_internal(value: &'static str) -> Self { + if value.is_ascii() { + Self(UniCase::ascii(ByteString::from_static(value))) + } else { + Self(UniCase::unicode(ByteString::from_static(value))) + } + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for HeaderName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl TryFrom for HeaderName { + type Error = HeaderNameValidateError; + + fn try_from(value: ByteString) -> Result { + validate_header_name(&value)?; + Ok(Self::from_dangerous_value(value)) + } +} + +impl TryFrom for HeaderName { + type Error = HeaderNameValidateError; + + fn try_from(value: String) -> Result { + validate_header_name(&value)?; + Ok(Self::from_dangerous_value(value.into())) + } +} + +impl From for ByteString { + fn from(value: HeaderName) -> Self { + value.0.into_inner() + } +} + +impl AsRef<[u8]> for HeaderName { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for HeaderName { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Deref for HeaderName { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +/// An error encountered while validating [`HeaderName`] +#[derive(Debug, thiserror::Error)] +pub enum HeaderNameValidateError { + /// The value is empty + #[error("HeaderName is empty")] + Empty, + /// The value has a length greater than 64 + #[error("HeaderName is too long")] + TooLong, + /// The value contains an Unicode whitespace character or `:` + #[error("HeaderName contained an illegal whitespace character")] + IllegalCharacter, +} + +fn validate_header_name(header_name: &str) -> Result<(), HeaderNameValidateError> { + if header_name.is_empty() { + return Err(HeaderNameValidateError::Empty); + } + + if header_name.len() > 64 { + // This is an arbitrary limit, but I guess the server must also have one + return Err(HeaderNameValidateError::TooLong); + } + + if header_name.chars().any(|c| c.is_whitespace() || c == ':') { + // The theoretical security limit is just ` `, `\t`, `\r`, `\n` and `:`. + // Let's be more careful. + return Err(HeaderNameValidateError::IllegalCharacter); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use core::cmp::Ordering; + + use super::HeaderName; + + #[test] + fn eq() { + let cased = HeaderName::from_static("Nats-Message-Id"); + let lowercase = HeaderName::from_static("nats-message-id"); + assert_eq!(cased, lowercase); + assert_eq!(cased.cmp(&lowercase), Ordering::Equal); + } +} diff --git a/watermelon-proto/src/headers/value.rs b/watermelon-proto/src/headers/value.rs new file mode 100644 index 0000000..f5d7046 --- /dev/null +++ b/watermelon-proto/src/headers/value.rs @@ -0,0 +1,151 @@ +use alloc::string::String; +use core::{ + fmt::{self, Display}, + ops::Deref, +}; + +use bytestring::ByteString; + +/// A string that can be used to represent an header value +/// +/// `HeaderValue` contains a string that is guaranteed [^1] to +/// contain a valid header value that meets the following requirements: +/// +/// * The value is not empty +/// * The value has a length less than or equal to 1024 [^2] +/// * The value does not contain any whitespace characters +/// +/// `HeaderValue` can be constructed from [`HeaderValue::from_static`] +/// or any of the `TryFrom` implementations. +/// +/// [^1]: Because [`HeaderValue::from_dangerous_value`] is safe to call, +/// unsafe code must not assume any of the above invariants. +/// [^2]: Messages coming from the NATS server are allowed to violate this rule. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct HeaderValue(ByteString); + +impl HeaderValue { + /// Construct `HeaderValue` from a static string + /// + /// # Panics + /// + /// Will panic if `value` isn't a valid `HeaderValue` + #[must_use] + pub fn from_static(value: &'static str) -> Self { + Self::try_from(ByteString::from_static(value)).expect("invalid HeaderValue") + } + + /// Construct a `HeaderValue` from a string, without checking invariants + /// + /// This method bypasses invariants checks implemented by [`HeaderValue::from_static`] + /// and all `TryFrom` implementations. + /// + /// # Security + /// + /// While calling this method can eliminate the runtime performance cost of + /// checking the string, constructing `HeaderValue` with an invalid string and + /// then calling the NATS server with it can cause serious security issues. + /// When in doubt use the [`HeaderValue::from_static`] or any of the `TryFrom` + /// implementations. + #[must_use] + #[expect( + clippy::missing_panics_doc, + reason = "The header validation is only made in debug" + )] + pub fn from_dangerous_value(value: ByteString) -> Self { + if cfg!(debug_assertions) { + if let Err(err) = validate_header_value(&value) { + panic!("HeaderValue {value:?} isn't valid {err:?}"); + } + } + Self(value) + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for HeaderValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl TryFrom for HeaderValue { + type Error = HeaderValueValidateError; + + fn try_from(value: ByteString) -> Result { + validate_header_value(&value)?; + Ok(Self::from_dangerous_value(value)) + } +} + +impl TryFrom for HeaderValue { + type Error = HeaderValueValidateError; + + fn try_from(value: String) -> Result { + validate_header_value(&value)?; + Ok(Self::from_dangerous_value(value.into())) + } +} + +impl From for ByteString { + fn from(value: HeaderValue) -> Self { + value.0 + } +} + +impl AsRef<[u8]> for HeaderValue { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for HeaderValue { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Deref for HeaderValue { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +/// An error encountered while validating [`HeaderValue`] +#[derive(Debug, thiserror::Error)] +pub enum HeaderValueValidateError { + /// The value is empty + #[error("HeaderValue is empty")] + Empty, + /// The value has a length greater than 64 + #[error("HeaderValue is too long")] + TooLong, + /// The value contains an Unicode whitespace character + #[error("HeaderValue contained an illegal whitespace character")] + IllegalCharacter, +} + +fn validate_header_value(header_value: &str) -> Result<(), HeaderValueValidateError> { + if header_value.is_empty() { + return Err(HeaderValueValidateError::Empty); + } + + if header_value.len() > 1024 { + // This is an arbitrary limit, but I guess the server must also have one + return Err(HeaderValueValidateError::TooLong); + } + + if header_value.chars().any(char::is_whitespace) { + // The theoretical security limit is just ` `, `\t`, `\r` and `\n`. + // Let's be more careful. + return Err(HeaderValueValidateError::IllegalCharacter); + } + + Ok(()) +} diff --git a/watermelon-proto/src/lib.rs b/watermelon-proto/src/lib.rs new file mode 100644 index 0000000..6788c7f --- /dev/null +++ b/watermelon-proto/src/lib.rs @@ -0,0 +1,36 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +extern crate alloc; + +pub use self::connect::{Connect, NonStandardConnect}; +pub use self::message::{MessageBase, ServerMessage}; +pub use self::queue_group::QueueGroup; +pub use self::server_addr::{Host, Protocol, ServerAddr, Transport}; +pub use self::server_info::{NonStandardServerInfo, ServerInfo}; +pub use self::status_code::StatusCode; +pub use self::subject::Subject; +pub use self::subscription_id::SubscriptionId; + +mod connect; +pub mod headers; +mod message; +pub mod proto; +mod queue_group; +mod server_addr; +mod server_error; +mod server_info; +mod status_code; +mod subject; +mod subscription_id; +#[cfg(test)] +mod tests; +mod util; + +pub mod error { + pub use super::queue_group::QueueGroupValidateError; + pub use super::server_addr::ServerAddrError; + pub use super::server_error::ServerError; + pub use super::status_code::StatusCodeError; + pub use super::subject::SubjectValidateError; + pub use super::util::ParseUintError; +} diff --git a/watermelon-proto/src/message.rs b/watermelon-proto/src/message.rs new file mode 100644 index 0000000..353fe36 --- /dev/null +++ b/watermelon-proto/src/message.rs @@ -0,0 +1,18 @@ +use bytes::Bytes; + +use crate::{headers::HeaderMap, subscription_id::SubscriptionId, StatusCode, Subject}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MessageBase { + pub subject: Subject, + pub reply_subject: Option, + pub headers: HeaderMap, + pub payload: Bytes, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ServerMessage { + pub status_code: Option, + pub subscription_id: SubscriptionId, + pub base: MessageBase, +} diff --git a/watermelon-proto/src/proto/client.rs b/watermelon-proto/src/proto/client.rs new file mode 100644 index 0000000..549ca93 --- /dev/null +++ b/watermelon-proto/src/proto/client.rs @@ -0,0 +1,28 @@ +use alloc::boxed::Box; +use core::num::NonZeroU64; + +use crate::{ + connect::Connect, message::MessageBase, queue_group::QueueGroup, + subscription_id::SubscriptionId, Subject, +}; + +#[derive(Debug)] +pub enum ClientOp { + Connect { + connect: Box, + }, + Publish { + message: MessageBase, + }, + Subscribe { + id: SubscriptionId, + subject: Subject, + queue_group: Option, + }, + Unsubscribe { + id: SubscriptionId, + max_messages: Option, + }, + Ping, + Pong, +} diff --git a/watermelon-proto/src/proto/decoder/framed.rs b/watermelon-proto/src/proto/decoder/framed.rs new file mode 100644 index 0000000..f6a2145 --- /dev/null +++ b/watermelon-proto/src/proto/decoder/framed.rs @@ -0,0 +1,27 @@ +use bytes::Bytes; + +use crate::proto::ServerOp; + +use super::{DecoderError, DecoderStatus}; + +/// Decodes a frame of bytes into a [`ServerOp`]. +/// +/// # Errors +/// +/// It returns an error in case the frame is incomplete or if a decoding error occurs. +pub fn decode_frame(frame: &mut Bytes) -> Result { + let mut status = DecoderStatus::ControlLine { last_bytes_read: 0 }; + match super::decode(&mut status, frame) { + Ok(Some(server_op)) => Ok(server_op), + Ok(None) => Err(FrameDecoderError::IncompleteFrame), + Err(err) => Err(FrameDecoderError::Decoder(err)), + } +} + +#[derive(Debug, thiserror::Error)] +pub enum FrameDecoderError { + #[error("incomplete frame")] + IncompleteFrame, + #[error("decoder error")] + Decoder(#[source] DecoderError), +} diff --git a/watermelon-proto/src/proto/decoder/mod.rs b/watermelon-proto/src/proto/decoder/mod.rs new file mode 100644 index 0000000..932b45e --- /dev/null +++ b/watermelon-proto/src/proto/decoder/mod.rs @@ -0,0 +1,373 @@ +use core::{mem, ops::Deref}; + +use bytes::{Buf, Bytes, BytesMut}; +use bytestring::ByteString; + +use crate::{ + error::ServerError, + headers::{ + error::{HeaderNameValidateError, HeaderValueValidateError}, + HeaderMap, HeaderName, HeaderValue, + }, + status_code::StatusCodeError, + util::{self, ParseUintError}, + MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId, +}; + +pub use self::framed::{decode_frame, FrameDecoderError}; +pub use self::stream::StreamDecoder; + +use super::ServerOp; + +mod framed; +mod stream; + +const MAX_HEAD_LEN: usize = 16 * 1024; + +#[derive(Debug)] +pub(super) enum DecoderStatus { + ControlLine { + last_bytes_read: usize, + }, + Headers { + subscription_id: SubscriptionId, + subject: Subject, + reply_subject: Option, + header_len: usize, + payload_len: usize, + }, + Payload { + subscription_id: SubscriptionId, + subject: Subject, + reply_subject: Option, + status_code: Option, + headers: HeaderMap, + payload_len: usize, + }, + Poisoned, +} + +pub(super) trait BytesLike: Buf + Deref { + fn len(&self) -> usize { + Buf::remaining(self) + } + + fn split_to(&mut self, at: usize) -> Bytes { + self.copy_to_bytes(at) + } +} + +impl BytesLike for Bytes {} +impl BytesLike for BytesMut {} + +pub(super) fn decode( + status: &mut DecoderStatus, + read_buf: &mut impl BytesLike, +) -> Result, DecoderError> { + loop { + match status { + DecoderStatus::ControlLine { last_bytes_read } => { + if *last_bytes_read == read_buf.len() { + // No progress has been made + return Ok(None); + } + + let Some(control_line_len) = memchr::memmem::find(read_buf, b"\r\n") else { + *last_bytes_read = read_buf.len(); + return Ok(None); + }; + + let mut control_line = read_buf.split_to(control_line_len + "\r\n".len()); + control_line.truncate(control_line.len() - 2); + + return if control_line.starts_with(b"+OK") { + Ok(Some(ServerOp::Success)) + } else if control_line.starts_with(b"MSG ") { + *status = decode_msg(control_line)?; + continue; + } else if control_line.starts_with(b"HMSG ") { + *status = decode_hmsg(control_line)?; + continue; + } else if control_line.starts_with(b"PING") { + Ok(Some(ServerOp::Ping)) + } else if control_line.starts_with(b"PONG") { + Ok(Some(ServerOp::Pong)) + } else if control_line.starts_with(b"-ERR ") { + control_line.advance("-ERR ".len()); + if !control_line.starts_with(b"'") || !control_line.ends_with(b"'") { + return Err(DecoderError::InvalidErrorMessage); + } + + control_line.advance(1); + control_line.truncate(control_line.len() - 1); + let raw_message = ByteString::try_from(control_line) + .map_err(|_| DecoderError::InvalidErrorMessage)?; + let error = ServerError::parse(raw_message); + Ok(Some(ServerOp::Error { error })) + } else if let Some(info) = control_line.strip_prefix(b"INFO ") { + let info = serde_json::from_slice(info).map_err(DecoderError::InvalidInfo)?; + Ok(Some(ServerOp::Info { info })) + } else if read_buf.len() > MAX_HEAD_LEN { + Err(DecoderError::HeadTooLong { + len: read_buf.len(), + }) + } else { + Err(DecoderError::InvalidCommand) + }; + } + DecoderStatus::Headers { header_len, .. } => { + if read_buf.len() < *header_len { + return Ok(None); + } + + decode_headers(read_buf, status)?; + } + DecoderStatus::Payload { payload_len, .. } => { + if read_buf.len() < *payload_len + "\r\n".len() { + return Ok(None); + } + + let DecoderStatus::Payload { + subscription_id, + subject, + reply_subject, + status_code, + headers, + payload_len, + } = mem::replace(status, DecoderStatus::ControlLine { last_bytes_read: 0 }) + else { + unreachable!() + }; + + let payload = read_buf.split_to(payload_len); + read_buf.advance("\r\n".len()); + let message = ServerMessage { + status_code, + subscription_id, + base: MessageBase { + subject, + reply_subject, + headers, + payload, + }, + }; + return Ok(Some(ServerOp::Message { message })); + } + DecoderStatus::Poisoned => return Err(DecoderError::Poisoned), + } + } +} + +fn decode_msg(mut control_line: Bytes) -> Result { + control_line.advance("MSG ".len()); + + let mut chunks = util::split_spaces(control_line); + let (subject, subscription_id, reply_subject, payload_len) = match ( + chunks.next(), + chunks.next(), + chunks.next(), + chunks.next(), + chunks.next(), + ) { + (Some(subject), Some(subscription_id), Some(reply_subject), Some(payload_len), None) => { + (subject, subscription_id, Some(reply_subject), payload_len) + } + (Some(subject), Some(subscription_id), Some(payload_len), None, None) => { + (subject, subscription_id, None, payload_len) + } + _ => return Err(DecoderError::InvalidMsgArgsCount), + }; + let subject = Subject::from_dangerous_value( + subject + .try_into() + .map_err(|_| DecoderError::SubjectInvalidUtf8)?, + ); + let subscription_id = + SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?; + let reply_subject = reply_subject + .map(|reply_subject| { + ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8) + }) + .transpose()? + .map(Subject::from_dangerous_value); + let payload_len = + util::parse_usize(&payload_len).map_err(DecoderError::InvalidPayloadLength)?; + Ok(DecoderStatus::Payload { + subscription_id, + subject, + reply_subject, + status_code: None, + headers: HeaderMap::new(), + payload_len, + }) +} + +fn decode_hmsg(mut control_line: Bytes) -> Result { + control_line.advance("HMSG ".len()); + let mut chunks = util::split_spaces(control_line); + + let (subject, subscription_id, reply_subject, header_len, total_len) = match ( + chunks.next(), + chunks.next(), + chunks.next(), + chunks.next(), + chunks.next(), + chunks.next(), + ) { + ( + Some(subject), + Some(subscription_id), + Some(reply_to), + Some(header_len), + Some(total_len), + None, + ) => ( + subject, + subscription_id, + Some(reply_to), + header_len, + total_len, + ), + (Some(subject), Some(subscription_id), Some(header_len), Some(total_len), None, None) => { + (subject, subscription_id, None, header_len, total_len) + } + _ => return Err(DecoderError::InvalidHmsgArgsCount), + }; + + let subject = Subject::from_dangerous_value( + subject + .try_into() + .map_err(|_| DecoderError::SubjectInvalidUtf8)?, + ); + let subscription_id = + SubscriptionId::from_ascii_bytes(&subscription_id).map_err(DecoderError::SubscriptionId)?; + let reply_subject = reply_subject + .map(|reply_subject| { + ByteString::try_from(reply_subject).map_err(|_| DecoderError::SubjectInvalidUtf8) + }) + .transpose()? + .map(Subject::from_dangerous_value); + let header_len = util::parse_usize(&header_len).map_err(DecoderError::InvalidHeaderLength)?; + let total_len = util::parse_usize(&total_len).map_err(DecoderError::InvalidPayloadLength)?; + + let payload_len = total_len + .checked_sub(header_len) + .ok_or(DecoderError::InvalidTotalLength)?; + + Ok(DecoderStatus::Headers { + subscription_id, + subject, + reply_subject, + header_len, + payload_len, + }) +} + +fn decode_headers( + read_buf: &mut impl BytesLike, + status: &mut DecoderStatus, +) -> Result<(), DecoderError> { + let DecoderStatus::Headers { + subscription_id, + subject, + reply_subject, + header_len, + payload_len, + } = mem::replace(status, DecoderStatus::Poisoned) + else { + unreachable!() + }; + + let header = read_buf.split_to(header_len); + let mut lines = util::lines_iter(header); + let head = lines.next().ok_or(DecoderError::MissingHead)?; + let head = head + .strip_prefix(b"NATS/1.0") + .ok_or(DecoderError::InvalidHead)?; + let status_code = if head.len() >= 4 { + Some(StatusCode::from_ascii_bytes(&head[1..4]).map_err(DecoderError::StatusCode)?) + } else { + None + }; + + let headers = lines + .filter(|line| !line.is_empty()) + .map(|mut line| { + let i = memchr::memchr(b':', &line).ok_or(DecoderError::InvalidHeaderLine)?; + + let name = line.split_to(i); + line.advance(":".len()); + if line[0].is_ascii_whitespace() { + // The fact that this is allowed sounds like BS to me + line.advance(1); + } + let value = line; + + let name = HeaderName::try_from( + ByteString::try_from(name).map_err(|_| DecoderError::HeaderNameInvalidUtf8)?, + ) + .map_err(DecoderError::HeaderName)?; + let value = HeaderValue::try_from( + ByteString::try_from(value).map_err(|_| DecoderError::HeaderValueInvalidUtf8)?, + ) + .map_err(DecoderError::HeaderValue)?; + Ok((name, value)) + }) + .collect::>()?; + + *status = DecoderStatus::Payload { + subscription_id, + subject, + reply_subject, + status_code, + headers, + payload_len, + }; + Ok(()) +} + +#[derive(Debug, thiserror::Error)] +pub enum DecoderError { + #[error("The head exceeded the maximum head length (len {len} maximum {MAX_HEAD_LEN}")] + HeadTooLong { len: usize }, + #[error("Invalid command")] + InvalidCommand, + #[error("MSG command has an unexpected number of arguments")] + InvalidMsgArgsCount, + #[error("HMSG command has an unexpected number of arguments")] + InvalidHmsgArgsCount, + #[error("The subject isn't valid utf-8")] + SubjectInvalidUtf8, + #[error("The reply subject isn't valid utf-8")] + ReplySubjectInvalidUtf8, + #[error("Couldn't parse the Subscription ID")] + SubscriptionId(#[source] ParseUintError), + #[error("Couldn't parse the length of the header")] + InvalidHeaderLength(#[source] ParseUintError), + #[error("Couldn't parse the length of the payload")] + InvalidPayloadLength(#[source] ParseUintError), + #[error("The total length is greater than the header length")] + InvalidTotalLength, + #[error("HMSG is missing head")] + MissingHead, + #[error("HMSG has an invalid head")] + InvalidHead, + #[error("HMSG header line is missing ': '")] + InvalidHeaderLine, + #[error("Couldn't parse the status code")] + StatusCode(#[source] StatusCodeError), + #[error("The header name isn't valid utf-8")] + HeaderNameInvalidUtf8, + #[error("The header name coouldn't be parsed")] + HeaderName(#[source] HeaderNameValidateError), + #[error("The header value isn't valid utf-8")] + HeaderValueInvalidUtf8, + #[error("The header value coouldn't be parsed")] + HeaderValue(#[source] HeaderValueValidateError), + #[error("INFO command JSON payload couldn't be deserialized")] + InvalidInfo(#[source] serde_json::Error), + #[error("-ERR command message couldn't be deserialized")] + InvalidErrorMessage, + #[error("The decoder was poisoned")] + Poisoned, +} diff --git a/watermelon-proto/src/proto/decoder/stream.rs b/watermelon-proto/src/proto/decoder/stream.rs new file mode 100644 index 0000000..727d326 --- /dev/null +++ b/watermelon-proto/src/proto/decoder/stream.rs @@ -0,0 +1,125 @@ +use bytes::{BufMut, BytesMut}; + +use crate::proto::{error::DecoderError, ServerOp}; + +use super::DecoderStatus; + +const INITIAL_READ_BUF_CAPACITY: usize = 64 * 1024; + +#[derive(Debug)] +pub struct StreamDecoder { + read_buf: BytesMut, + status: DecoderStatus, +} + +impl StreamDecoder { + #[must_use] + pub fn new() -> Self { + Self { + read_buf: BytesMut::with_capacity(INITIAL_READ_BUF_CAPACITY), + status: DecoderStatus::ControlLine { last_bytes_read: 0 }, + } + } + + #[must_use] + pub fn read_buf(&mut self) -> &mut impl BufMut { + &mut self.read_buf + } + + /// Decodes the next frame of bytes into a [`ServerOp`]. + /// + /// A `None` variant is returned in case no progress is made, + /// + /// # Errors + /// + /// It returns an error if a decoding error occurs. + pub fn decode(&mut self) -> Result, DecoderError> { + super::decode(&mut self.status, &mut self.read_buf) + } +} + +impl Default for StreamDecoder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use bytes::{BufMut as _, Bytes}; + use claims::assert_ok_eq; + + use crate::{ + error::ServerError, + headers::HeaderMap, + message::{MessageBase, ServerMessage}, + proto::server::ServerOp, + Subject, + }; + + use super::StreamDecoder; + + #[test] + fn decode_ping() { + let mut decoder = StreamDecoder::new(); + decoder.read_buf().put(Bytes::from_static(b"PING\r\n")); + assert_ok_eq!(decoder.decode(), Some(ServerOp::Ping)); + assert_ok_eq!(decoder.decode(), None); + } + + #[test] + fn decode_pong() { + let mut decoder = StreamDecoder::new(); + decoder.read_buf().put(Bytes::from_static(b"PONG\r\n")); + assert_ok_eq!(decoder.decode(), Some(ServerOp::Pong)); + assert_ok_eq!(decoder.decode(), None); + } + + #[test] + fn decode_ok() { + let mut decoder = StreamDecoder::new(); + decoder.read_buf().put(Bytes::from_static(b"+OK\r\n")); + assert_ok_eq!(decoder.decode(), Some(ServerOp::Success)); + assert_ok_eq!(decoder.decode(), None); + } + + #[test] + fn decode_error() { + let mut decoder = StreamDecoder::new(); + decoder + .read_buf() + .put(Bytes::from_static(b"-ERR 'Authorization Violation'\r\n")); + assert_ok_eq!( + decoder.decode(), + Some(ServerOp::Error { + error: ServerError::AuthorizationViolation + }) + ); + assert_ok_eq!(decoder.decode(), None); + } + + #[test] + + fn decode_msg() { + let mut decoder = StreamDecoder::new(); + decoder.read_buf().put(Bytes::from_static( + b"MSG hello.world 1 12\r\nHello World!\r\n", + )); + assert_ok_eq!( + decoder.decode(), + Some(ServerOp::Message { + message: ServerMessage { + status_code: None, + subscription_id: 1.into(), + base: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::from_static(b"Hello World!") + } + } + }) + ); + assert_ok_eq!(decoder.decode(), None); + } +} diff --git a/watermelon-proto/src/proto/encoder/framed.rs b/watermelon-proto/src/proto/encoder/framed.rs new file mode 100644 index 0000000..c5a390e --- /dev/null +++ b/watermelon-proto/src/proto/encoder/framed.rs @@ -0,0 +1,192 @@ +use bytes::BytesMut; + +use crate::proto::ClientOp; + +use super::FrameEncoder; + +#[derive(Debug)] +pub struct FramedEncoder { + buf: BytesMut, +} + +impl FramedEncoder { + #[must_use] + pub fn new() -> Self { + Self { + buf: BytesMut::new(), + } + } + + pub fn encode(&mut self, item: &ClientOp) -> BytesMut { + struct Encoder<'a>(&'a mut FramedEncoder); + + impl FrameEncoder for Encoder<'_> { + fn small_write(&mut self, buf: &[u8]) { + self.0.buf.extend_from_slice(buf); + } + } + + super::encode(&mut Encoder(self), item); + self.buf.split() + } +} + +impl Default for FramedEncoder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use core::num::NonZeroU64; + + use bytes::Bytes; + + use super::FramedEncoder; + use crate::{ + headers::{HeaderMap, HeaderName, HeaderValue}, + proto::ClientOp, + tests::ToBytes as _, + MessageBase, QueueGroup, Subject, + }; + + #[test] + fn encode_ping() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder.encode(&ClientOp::Ping).to_bytes(), + "PING\r\n".as_bytes() + ); + } + + #[test] + fn encode_pong() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder.encode(&ClientOp::Pong).to_bytes(), + "PONG\r\n".as_bytes() + ); + } + + #[test] + fn encode_subscribe() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Subscribe { + id: 1.into(), + subject: Subject::from_static("hello.world"), + queue_group: None, + }) + .to_bytes(), + "SUB hello.world 1\r\n".as_bytes() + ); + } + + #[test] + fn encode_subscribe_with_queue_group() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Subscribe { + id: 1.into(), + subject: Subject::from_static("hello.world"), + queue_group: Some(QueueGroup::from_static("stuff")), + }) + .to_bytes(), + "SUB hello.world stuff 1\r\n".as_bytes() + ); + } + + #[test] + fn encode_unsubscribe() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Unsubscribe { + id: 1.into(), + max_messages: None, + }) + .to_bytes(), + "UNSUB 1\r\n".as_bytes() + ); + } + + #[test] + fn encode_unsubscribe_with_max_messages() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Unsubscribe { + id: 1.into(), + max_messages: Some(NonZeroU64::new(5).unwrap()), + }) + .to_bytes(), + "UNSUB 1 5\r\n".as_bytes() + ); + } + + #[test] + fn encode_publish() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::from_static(b"Hello World!"), + }, + }) + .to_bytes(), + "PUB hello.world 12\r\nHello World!\r\n".as_bytes() + ); + } + + #[test] + fn encode_publish_with_reply_subject() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder + .encode(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: Some(Subject::from_static("_INBOX.1234")), + headers: HeaderMap::new(), + payload: Bytes::from_static(b"Hello World!"), + }, + }) + .to_bytes(), + "PUB hello.world _INBOX.1234 12\r\nHello World!\r\n".as_bytes() + ); + } + + #[test] + fn encode_publish_with_headers() { + let mut encoder = FramedEncoder::new(); + assert_eq!( + encoder.encode(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: None, + headers: [ + ( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("abcd"), + ), + ( + HeaderName::from_static("Nats-Sequence"), + HeaderValue::from_static("1"), + ), + ] + .into_iter() + .collect(), + payload: Bytes::from_static(b"Hello World!"), + }, + }).to_bytes(), + "HPUB hello.world 53 65\r\nNATS/1.0\r\nNats-Message-Id: abcd\r\nNats-Sequence: 1\r\n\r\nHello World!\r\n".as_bytes() + ); + } +} diff --git a/watermelon-proto/src/proto/encoder/mod.rs b/watermelon-proto/src/proto/encoder/mod.rs new file mode 100644 index 0000000..6addaa9 --- /dev/null +++ b/watermelon-proto/src/proto/encoder/mod.rs @@ -0,0 +1,173 @@ +use core::fmt::{self, Write as _}; +#[cfg(feature = "std")] +use std::io; + +use bytes::Bytes; + +use crate::headers::HeaderMap; +use crate::MessageBase; + +pub use self::framed::FramedEncoder; +pub use self::stream::StreamEncoder; + +use super::ClientOp; + +mod framed; +mod stream; + +pub(super) trait FrameEncoder { + fn small_write(&mut self, buf: &[u8]); + + fn write(&mut self, buf: B) + where + B: Into + AsRef<[u8]>, + { + self.small_write(buf.as_ref()); + } + + fn small_fmt_writer(&mut self) -> SmallFmtWriter<'_, Self> { + SmallFmtWriter(self) + } + + #[cfg(feature = "std")] + fn small_io_writer(&mut self) -> SmallIoWriter<'_, Self> { + SmallIoWriter(self) + } +} + +pub(super) struct SmallFmtWriter<'a, E: ?Sized>(&'a mut E); + +impl fmt::Write for SmallFmtWriter<'_, E> +where + E: FrameEncoder, +{ + fn write_str(&mut self, s: &str) -> fmt::Result { + self.0.small_write(s.as_bytes()); + Ok(()) + } +} + +#[cfg(feature = "std")] +pub(super) struct SmallIoWriter<'a, E: ?Sized>(&'a mut E); + +#[cfg(feature = "std")] +impl io::Write for SmallIoWriter<'_, E> +where + E: FrameEncoder, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.small_write(buf); + Ok(buf.len()) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.0.small_write(buf); + Ok(()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub(super) fn encode(encoder: &mut E, item: &ClientOp) { + macro_rules! small_write { + ($dst:expr) => { + write!(encoder.small_fmt_writer(), $dst).expect("do small write to Connection"); + }; + } + + match item { + ClientOp::Publish { message } => { + let MessageBase { + subject, + reply_subject, + headers, + payload, + } = &message; + let verb = if headers.is_empty() { "PUB" } else { "HPUB" }; + + small_write!("{verb} {subject} "); + + if let Some(reply_subject) = reply_subject { + small_write!("{reply_subject} "); + } + + if headers.is_empty() { + let payload_len = payload.len(); + small_write!("{payload_len}\r\n"); + } else { + let headers_len = encode_headers(headers).fold(0, |len, s| len + s.len()); + + let total_len = headers_len + payload.len(); + small_write!("{headers_len} {total_len}\r\n"); + + encode_headers(headers).for_each(|s| { + encoder.small_write(s.as_bytes()); + }); + } + + encoder.write(IntoBytes(payload)); + encoder.small_write(b"\r\n"); + } + ClientOp::Subscribe { + id, + subject, + queue_group, + } => match queue_group { + Some(queue_group) => { + small_write!("SUB {subject} {queue_group} {id}\r\n"); + } + None => { + small_write!("SUB {subject} {id}\r\n"); + } + }, + ClientOp::Unsubscribe { id, max_messages } => match max_messages { + Some(max_messages) => { + small_write!("UNSUB {id} {max_messages}\r\n"); + } + None => { + small_write!("UNSUB {id}\r\n"); + } + }, + ClientOp::Connect { connect } => { + encoder.small_write(b"CONNECT "); + #[cfg(feature = "std")] + serde_json::to_writer(encoder.small_io_writer(), &connect) + .expect("serialize `Connect`"); + #[cfg(not(feature = "std"))] + encoder.write(serde_json::to_vec(&connect).expect("serialize `Connect`")); + encoder.small_write(b"\r\n"); + } + ClientOp::Ping => { + encoder.small_write(b"PING\r\n"); + } + ClientOp::Pong => { + encoder.small_write(b"PONG\r\n"); + } + } +} + +struct IntoBytes<'a>(&'a Bytes); + +impl<'a> From> for Bytes { + fn from(value: IntoBytes<'a>) -> Self { + Bytes::clone(value.0) + } +} + +impl AsRef<[u8]> for IntoBytes<'_> { + fn as_ref(&self) -> &[u8] { + self.0 + } +} + +fn encode_headers(headers: &HeaderMap) -> impl Iterator { + let head = ["NATS/1.0\r\n"]; + let headers = headers.iter().flat_map(|(name, values)| { + values.flat_map(|value| [name.as_str(), ": ", value.as_str(), "\r\n"]) + }); + let footer = ["\r\n"]; + + head.into_iter().chain(headers).chain(footer) +} diff --git a/watermelon-proto/src/proto/encoder/stream.rs b/watermelon-proto/src/proto/encoder/stream.rs new file mode 100644 index 0000000..c0a9c86 --- /dev/null +++ b/watermelon-proto/src/proto/encoder/stream.rs @@ -0,0 +1,339 @@ +#[cfg(feature = "std")] +use std::io; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::util::BufList; + +use super::{ClientOp, FrameEncoder}; + +const WRITE_FLATTEN_THRESHOLD: usize = 4096; + +#[derive(Debug)] +pub struct StreamEncoder { + write_buf: BufList, + flattened_writes: BytesMut, +} + +impl StreamEncoder { + #[must_use] + pub fn new() -> Self { + Self { + write_buf: BufList::new(), + flattened_writes: BytesMut::new(), + } + } + + pub fn enqueue_write_op(&mut self, item: &ClientOp) { + super::encode(self, item); + } + + #[cfg(test)] + fn all_bytes(&mut self) -> alloc::vec::Vec { + self.copy_to_bytes(self.remaining()).to_vec() + } +} + +impl Buf for StreamEncoder { + fn remaining(&self) -> usize { + self.write_buf.remaining() + self.flattened_writes.remaining() + } + + fn has_remaining(&self) -> bool { + self.write_buf.has_remaining() || self.flattened_writes.has_remaining() + } + + fn chunk(&self) -> &[u8] { + let chunk = self.write_buf.chunk(); + if chunk.is_empty() { + &self.flattened_writes + } else { + chunk + } + } + + #[cfg(feature = "std")] + fn chunks_vectored<'a>(&'a self, dst: &mut [io::IoSlice<'a>]) -> usize { + let mut n = self.write_buf.chunks_vectored(dst); + n += self.flattened_writes.chunks_vectored(&mut dst[n..]); + n + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining()); + + let mid = self.write_buf.remaining().min(cnt); + self.write_buf.advance(mid); + + let rem = cnt - mid; + if rem == self.flattened_writes.len() { + // https://github.com/tokio-rs/bytes/pull/698 + self.flattened_writes.clear(); + } else { + self.flattened_writes.advance(rem); + } + } + + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + assert!( + len <= self.remaining(), + "copy_to_bytes out of range ({} <= {})", + len, + self.remaining() + ); + + if self.write_buf.remaining() >= len { + self.write_buf.copy_to_bytes(len) + } else if !self.write_buf.has_remaining() { + self.flattened_writes.copy_to_bytes(len) + } else { + let rem = len - self.write_buf.remaining(); + + let mut bufs = BytesMut::with_capacity(len); + bufs.put(&mut self.write_buf); + bufs.put_slice(&self.flattened_writes[..rem]); + + if self.flattened_writes.remaining() == rem { + // https://github.com/tokio-rs/bytes/pull/698 + self.flattened_writes.clear(); + } else { + self.flattened_writes.advance(rem); + } + + bufs.freeze() + } + } +} + +impl FrameEncoder for StreamEncoder { + fn small_write(&mut self, buf: &[u8]) { + self.flattened_writes.extend_from_slice(buf); + } + + fn write(&mut self, buf: B) + where + B: Into + AsRef<[u8]>, + { + let b = buf.as_ref(); + + let len = b.len(); + if len == 0 { + return; + } + + if len < WRITE_FLATTEN_THRESHOLD { + self.flattened_writes.extend_from_slice(b); + } else { + if !self.flattened_writes.is_empty() { + let buf = self.flattened_writes.split().freeze(); + self.write_buf.push(buf); + } + + self.write_buf.push(buf.into()); + } + } +} + +impl Default for StreamEncoder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use core::num::NonZeroU64; + #[cfg(feature = "std")] + use std::io::IoSlice; + + #[cfg(feature = "std")] + use bytes::Buf; + use bytes::Bytes; + + use super::StreamEncoder; + #[cfg(feature = "std")] + use crate::proto::encoder::FrameEncoder; + use crate::{ + headers::{HeaderMap, HeaderName, HeaderValue}, + proto::ClientOp, + MessageBase, QueueGroup, Subject, + }; + + #[test] + fn base() { + let encoder = StreamEncoder::new(); + assert_eq!(0, encoder.remaining()); + assert!(!encoder.has_remaining()); + } + + #[cfg(feature = "std")] + #[test] + fn vectored() { + let mut encoder = StreamEncoder::new(); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(0, encoder.chunks_vectored(&mut bufs)); + + encoder.small_write(b"1234"); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(1, encoder.chunks_vectored(&mut bufs)); + + encoder.small_write(b"5678"); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(1, encoder.chunks_vectored(&mut bufs)); + + encoder.write("9"); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(1, encoder.chunks_vectored(&mut bufs)); + + encoder.write(vec![b'_'; 8196]); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(2, encoder.chunks_vectored(&mut bufs)); + + encoder.write(vec![b':'; 8196]); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(3, encoder.chunks_vectored(&mut bufs)); + + encoder.write("0"); + let mut bufs = [IoSlice::new(&[]); 64]; + assert_eq!(4, encoder.chunks_vectored(&mut bufs)); + } + + #[test] + fn encode_ping() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Ping); + assert_eq!(6, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!("PING\r\n".as_bytes(), encoder.all_bytes()); + } + + #[test] + fn encode_pong() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Pong); + assert_eq!(6, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!("PONG\r\n".as_bytes(), encoder.all_bytes()); + } + + #[test] + fn encode_subscribe() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Subscribe { + id: 1.into(), + subject: Subject::from_static("hello.world"), + queue_group: None, + }); + assert_eq!(19, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!("SUB hello.world 1\r\n".as_bytes(), encoder.all_bytes()); + } + + #[test] + fn encode_subscribe_with_queue_group() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Subscribe { + id: 1.into(), + subject: Subject::from_static("hello.world"), + queue_group: Some(QueueGroup::from_static("stuff")), + }); + assert_eq!(25, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!( + "SUB hello.world stuff 1\r\n".as_bytes(), + encoder.all_bytes() + ); + } + + #[test] + fn encode_unsubscribe() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Unsubscribe { + id: 1.into(), + max_messages: None, + }); + assert_eq!(9, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!("UNSUB 1\r\n".as_bytes(), encoder.all_bytes()); + } + + #[test] + fn encode_unsubscribe_with_max_messages() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Unsubscribe { + id: 1.into(), + max_messages: Some(NonZeroU64::new(5).unwrap()), + }); + assert_eq!(11, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!("UNSUB 1 5\r\n".as_bytes(), encoder.all_bytes()); + } + + #[test] + fn encode_publish() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::from_static(b"Hello World!"), + }, + }); + assert_eq!(34, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!( + "PUB hello.world 12\r\nHello World!\r\n".as_bytes(), + encoder.all_bytes() + ); + } + + #[test] + fn encode_publish_with_reply_subject() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: Some(Subject::from_static("_INBOX.1234")), + headers: HeaderMap::new(), + payload: Bytes::from_static(b"Hello World!"), + }, + }); + assert_eq!(46, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!( + "PUB hello.world _INBOX.1234 12\r\nHello World!\r\n".as_bytes(), + encoder.all_bytes() + ); + } + + #[test] + fn encode_publish_with_headers() { + let mut encoder = StreamEncoder::new(); + encoder.enqueue_write_op(&ClientOp::Publish { + message: MessageBase { + subject: Subject::from_static("hello.world"), + reply_subject: None, + headers: [ + ( + HeaderName::from_static("Nats-Message-Id"), + HeaderValue::from_static("abcd"), + ), + ( + HeaderName::from_static("Nats-Sequence"), + HeaderValue::from_static("1"), + ), + ] + .into_iter() + .collect(), + payload: Bytes::from_static(b"Hello World!"), + }, + }); + assert_eq!(91, encoder.remaining()); + assert!(encoder.has_remaining()); + assert_eq!( + "HPUB hello.world 53 65\r\nNATS/1.0\r\nNats-Message-Id: abcd\r\nNats-Sequence: 1\r\n\r\nHello World!\r\n".as_bytes(), + encoder.all_bytes() + ); + } +} diff --git a/watermelon-proto/src/proto/mod.rs b/watermelon-proto/src/proto/mod.rs new file mode 100644 index 0000000..7cadc17 --- /dev/null +++ b/watermelon-proto/src/proto/mod.rs @@ -0,0 +1,13 @@ +pub use self::client::ClientOp; +pub use self::decoder::{decode_frame, StreamDecoder}; +pub use self::encoder::{FramedEncoder, StreamEncoder}; +pub use self::server::ServerOp; + +mod client; +mod decoder; +mod encoder; +mod server; + +pub mod error { + pub use super::decoder::{DecoderError, FrameDecoderError}; +} diff --git a/watermelon-proto/src/proto/server.rs b/watermelon-proto/src/proto/server.rs new file mode 100644 index 0000000..1ff9573 --- /dev/null +++ b/watermelon-proto/src/proto/server.rs @@ -0,0 +1,13 @@ +use alloc::boxed::Box; + +use crate::{error::ServerError, message::ServerMessage, ServerInfo}; + +#[derive(Debug, PartialEq, Eq)] +pub enum ServerOp { + Info { info: Box }, + Message { message: ServerMessage }, + Success, + Error { error: ServerError }, + Ping, + Pong, +} diff --git a/watermelon-proto/src/queue_group.rs b/watermelon-proto/src/queue_group.rs new file mode 100644 index 0000000..743adcd --- /dev/null +++ b/watermelon-proto/src/queue_group.rs @@ -0,0 +1,253 @@ +use alloc::string::String; +use core::{ + fmt::{self, Display}, + ops::Deref, +}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + +use bytestring::ByteString; + +/// A string that can be used to represent an queue group +/// +/// `QueueGroup` contains a string that is guaranteed [^1] to +/// contain a valid header name that meets the following requirements: +/// +/// * The value is not empty +/// * The value has a length less than or equal to 64 [^2] +/// * The value does not contain any whitespace characters or `:` +/// +/// `QueueGroup` can be constructed from [`QueueGroup::from_static`] +/// or any of the `TryFrom` implementations. +/// +/// [^1]: Because [`QueueGroup::from_dangerous_value`] is safe to call, +/// unsafe code must not assume any of the above invariants. +/// [^2]: Messages coming from the NATS server are allowed to violate this rule. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct QueueGroup(ByteString); + +impl QueueGroup { + /// Construct `QueueGroup` from a static string + /// + /// # Panics + /// + /// Will panic if `value` isn't a valid `QueueGroup` + #[must_use] + pub fn from_static(value: &'static str) -> Self { + Self::try_from(ByteString::from_static(value)).expect("invalid QueueGroup") + } + + /// Construct a `QueueGroup` from a string, without checking invariants + /// + /// This method bypasses invariants checks implemented by [`QueueGroup::from_static`] + /// and all `TryFrom` implementations. + /// + /// # Security + /// + /// While calling this method can eliminate the runtime performance cost of + /// checking the string, constructing `QueueGroup` with an invalid string and + /// then calling the NATS server with it can cause serious security issues. + /// When in doubt use the [`QueueGroup::from_static`] or any of the `TryFrom` + /// implementations. + #[must_use] + #[expect( + clippy::missing_panics_doc, + reason = "The queue group validation is only made in debug" + )] + pub fn from_dangerous_value(value: ByteString) -> Self { + if cfg!(debug_assertions) { + if let Err(err) = validate_queue_group(&value) { + panic!("QueueGroup {value:?} isn't valid {err:?}"); + } + } + Self(value) + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for QueueGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl TryFrom for QueueGroup { + type Error = QueueGroupValidateError; + + fn try_from(value: ByteString) -> Result { + validate_queue_group(&value)?; + Ok(Self::from_dangerous_value(value)) + } +} + +impl TryFrom for QueueGroup { + type Error = QueueGroupValidateError; + + fn try_from(value: String) -> Result { + validate_queue_group(&value)?; + Ok(Self::from_dangerous_value(value.into())) + } +} + +impl From for ByteString { + fn from(value: QueueGroup) -> Self { + value.0 + } +} + +impl AsRef<[u8]> for QueueGroup { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for QueueGroup { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Deref for QueueGroup { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl Serialize for QueueGroup { + fn serialize(&self, serializer: S) -> Result { + self.as_str().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for QueueGroup { + fn deserialize>(deserializer: D) -> Result { + let s = ByteString::deserialize(deserializer)?; + s.try_into().map_err(de::Error::custom) + } +} + +/// An error encountered while validating [`QueueGroup`] +#[derive(Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum QueueGroupValidateError { + /// The value is empty + #[error("QueueGroup is empty")] + Empty, + /// The value has a length greater than 64 + #[error("QueueGroup is too long")] + TooLong, + /// The value contains an Unicode whitespace character + #[error("QueueGroup contained an illegal whitespace character")] + IllegalCharacter, +} + +fn validate_queue_group(queue_group: &str) -> Result<(), QueueGroupValidateError> { + if queue_group.is_empty() { + return Err(QueueGroupValidateError::Empty); + } + + if queue_group.len() > 64 { + // This is an arbitrary limit, but I guess the server must also have one + return Err(QueueGroupValidateError::TooLong); + } + + if queue_group.chars().any(char::is_whitespace) { + // The theoretical security limit is just ` `, `\t`, `\r` and `\n`. + // Let's be more careful. + return Err(QueueGroupValidateError::IllegalCharacter); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use bytestring::ByteString; + + use super::{QueueGroup, QueueGroupValidateError}; + + #[test] + fn valid_queue_groups() { + let queue_groups = ["importer", "importer.thing", "blablabla:itworks"]; + for queue_group in queue_groups { + let q = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap(); + assert_eq!(queue_group, q.as_str()); + } + } + + #[test] + fn invalid_queue_groups() { + let queue_groups = [ + ("", QueueGroupValidateError::Empty), + ( + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + QueueGroupValidateError::TooLong, + ), + ("importer ", QueueGroupValidateError::IllegalCharacter), + ("importer .thing", QueueGroupValidateError::IllegalCharacter), + (" importer", QueueGroupValidateError::IllegalCharacter), + ("importer.thing ", QueueGroupValidateError::IllegalCharacter), + ( + "importer.thing.works ", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing.works\r", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing.works\n", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing.works\t", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thi ng.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thi\rng.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thi\nng.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thi\tng.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing .works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing\r.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing\n.works", + QueueGroupValidateError::IllegalCharacter, + ), + ( + "importer.thing\t.works", + QueueGroupValidateError::IllegalCharacter, + ), + (" ", QueueGroupValidateError::IllegalCharacter), + ("\r", QueueGroupValidateError::IllegalCharacter), + ("\n", QueueGroupValidateError::IllegalCharacter), + ("\t", QueueGroupValidateError::IllegalCharacter), + ]; + for (queue_group, expected_err) in queue_groups { + let err = QueueGroup::try_from(ByteString::from_static(queue_group)).unwrap_err(); + assert_eq!(expected_err, err); + } + } +} diff --git a/watermelon-proto/src/server_addr.rs b/watermelon-proto/src/server_addr.rs new file mode 100644 index 0000000..ecefcb6 --- /dev/null +++ b/watermelon-proto/src/server_addr.rs @@ -0,0 +1,367 @@ +use alloc::{str::FromStr, string::String}; +use core::{ + fmt::{self, Debug, Display, Write}, + net::IpAddr, + ops::Deref, +}; + +use bytestring::ByteString; +use percent_encoding::{percent_decode_str, percent_encode, NON_ALPHANUMERIC}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use url::Url; + +/// Address of a NATS server +#[derive(Clone, PartialEq, Eq)] +pub struct ServerAddr { + protocol: Protocol, + transport: Transport, + host: Host, + port: u16, + username: ByteString, + password: ByteString, +} + +/// The protocol of the NATS server +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum Protocol { + /// Plaintext with the option to later upgrade to TLS + /// + /// This option should only be used when esplicity wanting to + /// connect using a plaintext connection. Using this option + /// over the public internet or other untrusted networks + /// leaves the client open to MITM attacks. + /// + /// Corresponds to the `nats` scheme. + PossiblyPlain, + /// TLS connection + /// + /// Requires the TCP connection to successfully upgrade to TLS. + /// + /// Corresponds to the `tls` scheme. + TLS, +} + +/// The transport protocol of the NATS server +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum Transport { + /// Transmit data over a TCP stream + TCP, + /// Transmit data over WebSocket frames + Websocket, +} + +/// The hostname of the NATS server +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// An IPv4 or IPv6 address + Ip(IpAddr), + /// A DNS hostname + Dns(ByteString), +} + +impl ServerAddr { + /// Get the connection protocol + pub fn protocol(&self) -> Protocol { + self.protocol + } + + /// Get the transport protocol + pub fn transport(&self) -> Transport { + self.transport + } + + /// Get the hostname + pub fn host(&self) -> &Host { + &self.host + } + + /// Get the port + pub fn port(&self) -> u16 { + self.port + } + + fn is_default_port(&self) -> bool { + self.port == protocol_transport_to_port(self.protocol, self.transport) + } + + /// Get the username + pub fn username(&self) -> Option<&str> { + if self.username.is_empty() { + None + } else { + Some(&self.username) + } + } + + /// Get the password + pub fn password(&self) -> Option<&str> { + if self.password.is_empty() { + None + } else { + Some(&self.password) + } + } +} + +impl FromStr for ServerAddr { + type Err = ServerAddrError; + + fn from_str(value: &str) -> Result { + let url = value.parse::().map_err(ServerAddrError::InvalidUrl)?; + + let (protocol, transport) = match url.scheme() { + "nats" => (Protocol::PossiblyPlain, Transport::TCP), + "tls" => (Protocol::TLS, Transport::TCP), + "ws" => (Protocol::PossiblyPlain, Transport::Websocket), + "wss" => (Protocol::TLS, Transport::Websocket), + _ => return Err(ServerAddrError::InvalidScheme), + }; + + let host = match url.host() { + Some(url::Host::Ipv4(addr)) => Host::Ip(IpAddr::V4(addr)), + Some(url::Host::Ipv6(addr)) => Host::Ip(IpAddr::V6(addr)), + Some(url::Host::Domain(host)) => { + // TODO: this shouldn't be necessary + let host = host + .strip_prefix('[') + .and_then(|host| host.strip_suffix(']')) + .unwrap_or(host); + match host.parse::() { + Ok(ip) => Host::Ip(ip), + Err(_) => Host::Dns(host.into()), + } + } + None => return Err(ServerAddrError::MissingHost), + }; + + let port = if let Some(port) = url.port() { + port + } else { + protocol_transport_to_port(protocol, transport) + }; + + let username = percent_decode_str(url.username()) + .decode_utf8() + .map_err(|_| ServerAddrError::UsernameInvalidUtf8)? + .deref() + .into(); + let password = percent_decode_str(url.password().unwrap_or_default()) + .decode_utf8() + .map_err(|_| ServerAddrError::PasswordInvalidUtf8)? + .deref() + .into(); + + Ok(Self { + protocol, + transport, + host, + port, + username, + password, + }) + } +} + +impl Debug for ServerAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let username = if self.username.is_empty() { + "" + } else { + "" + }; + let password = if self.password.is_empty() { + "" + } else { + "" + }; + f.debug_struct("ServerAddr") + .field("protocol", &self.protocol) + .field("transport", &self.transport) + .field("host", &self.host) + .field("port", &self.port) + .field("username", &username) + .field("password", &password) + .finish() + } +} + +impl Display for ServerAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match (self.protocol, self.transport) { + (Protocol::PossiblyPlain, Transport::TCP) => "nats", + (Protocol::TLS, Transport::TCP) => "tls", + (Protocol::PossiblyPlain, Transport::Websocket) => "ws", + (Protocol::TLS, Transport::Websocket) => "wss", + })?; + f.write_str("://")?; + + if let Some(username) = self.username() { + Display::fmt(&percent_encode(username.as_bytes(), NON_ALPHANUMERIC), f)?; + + if let Some(password) = self.password() { + write!( + f, + ":{}", + percent_encode(password.as_bytes(), NON_ALPHANUMERIC) + )?; + } + f.write_char('@')?; + } + + match &self.host { + Host::Ip(IpAddr::V4(addr)) => Display::fmt(addr, f)?, + Host::Ip(IpAddr::V6(addr)) => write!(f, "[{addr}]")?, + Host::Dns(record) => Display::fmt(record, f)?, + } + if !self.is_default_port() { + write!(f, ":{}", self.port)?; + } + + Ok(()) + } +} + +impl<'de> Deserialize<'de> for ServerAddr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let val = String::deserialize(deserializer)?; + val.parse().map_err(de::Error::custom) + } +} + +impl Serialize for ServerAddr { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.collect_str(self) + } +} + +/// An error encountered while parsing [`ServerAddr`] +#[derive(Debug, thiserror::Error)] +pub enum ServerAddrError { + /// The Url could not be parsed + #[error("invalid Url")] + InvalidUrl(#[source] url::ParseError), + /// The Url has a bad scheme + #[error("invalid Url scheme")] + InvalidScheme, + /// The Url is missing the hostname + #[error("missing host")] + MissingHost, + /// The Url contains a non-utf8 username + #[error("username is not utf-8")] + UsernameInvalidUtf8, + /// The Url contains a non-utf8 password + #[error("password is not utf-8")] + PasswordInvalidUtf8, +} + +fn protocol_transport_to_port(protocol: Protocol, transport: Transport) -> u16 { + match (protocol, transport) { + (Protocol::PossiblyPlain | Protocol::TLS, Transport::TCP) => 4222, + (Protocol::PossiblyPlain, Transport::Websocket) => 80, + (Protocol::TLS, Transport::Websocket) => 443, + } +} + +#[cfg(test)] +mod tests { + use alloc::string::ToString; + use core::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use super::{Host, Protocol, ServerAddr, Transport}; + + #[test] + fn nats() { + let server_addr = "nats://127.0.0.1".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::TCP); + assert_eq!(server_addr.protocol(), Protocol::PossiblyPlain); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 4222); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "nats://127.0.0.1"); + } + + #[test] + fn nats_non_default_port() { + let server_addr = "nats://127.0.0.1:4321".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::TCP); + assert_eq!(server_addr.protocol(), Protocol::PossiblyPlain); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 4321); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "nats://127.0.0.1:4321"); + } + + #[test] + fn nats_ipv6() { + let server_addr = "nats://[::1]".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::TCP); + assert_eq!(server_addr.protocol(), Protocol::PossiblyPlain); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V6(Ipv6Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 4222); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "nats://[::1]"); + } + + #[test] + fn tls() { + let server_addr = "tls://127.0.0.1".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::TCP); + assert_eq!(server_addr.protocol(), Protocol::TLS); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 4222); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "tls://127.0.0.1"); + } + + #[test] + fn ws() { + let server_addr = "ws://127.0.0.1".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::Websocket); + assert_eq!(server_addr.protocol(), Protocol::PossiblyPlain); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 80); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "ws://127.0.0.1"); + } + + #[test] + fn wss() { + let server_addr = "wss://127.0.0.1".parse::().unwrap(); + assert_eq!(server_addr.transport(), Transport::Websocket); + assert_eq!(server_addr.protocol(), Protocol::TLS); + assert_eq!( + server_addr.host(), + &Host::Ip(IpAddr::V4(Ipv4Addr::LOCALHOST)) + ); + assert_eq!(server_addr.port(), 443); + assert_eq!(server_addr.username(), None); + assert_eq!(server_addr.password(), None); + assert_eq!(server_addr.to_string(), "wss://127.0.0.1"); + } +} diff --git a/watermelon-proto/src/server_error.rs b/watermelon-proto/src/server_error.rs new file mode 100644 index 0000000..6ab403a --- /dev/null +++ b/watermelon-proto/src/server_error.rs @@ -0,0 +1,110 @@ +use bytestring::ByteString; + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum ServerError { + #[error("subject is invalid")] + InvalidSubject, + #[error("permissions violation for publish")] + PublishPermissionViolation, + #[error("permissions violation for subscription")] + SubscribePermissionViolation, + + #[error("unknown protocol operation")] + UnknownProtocolOperation, + + #[error("attempted to connect to route port")] + ConnectionAttemptedToWrongPort, + + #[error("authorization violation")] + AuthorizationViolation, + #[error("authorization timeout")] + AuthorizationTimeout, + #[error("invalid client protocol")] + InvalidClientProtocol, + #[error("maximum control line exceeded")] + MaximumControlLineExceeded, + #[error("parser error")] + ParseError, + #[error("secure connection, tls required")] + TlsRequired, + #[error("stale connection")] + StaleConnection, + #[error("maximum connections exceeded")] + MaximumConnectionsExceeded, + #[error("slow consumer")] + SlowConsumer, + #[error("maximum payload violation")] + MaximumPayloadViolation, + + #[error("unknown error: {raw_message}")] + Other { raw_message: ByteString }, +} + +impl ServerError { + pub fn is_fatal(&self) -> Option { + match self { + Self::InvalidSubject + | Self::PublishPermissionViolation + | Self::SubscribePermissionViolation => Some(false), + + Self::UnknownProtocolOperation + | Self::ConnectionAttemptedToWrongPort + | Self::AuthorizationViolation + | Self::AuthorizationTimeout + | Self::InvalidClientProtocol + | Self::MaximumControlLineExceeded + | Self::ParseError + | Self::TlsRequired + | Self::StaleConnection + | Self::MaximumConnectionsExceeded + | Self::SlowConsumer + | Self::MaximumPayloadViolation => Some(true), + + Self::Other { .. } => None, + } + } + + pub(crate) fn parse(raw_message: ByteString) -> Self { + const PUBLISH_PERMISSIONS: &str = "Permissions Violation for Publish"; + const SUBSCRIPTION_PERMISSIONS: &str = "Permissions Violation for Subscription"; + + let m = raw_message.trim(); + if m.eq_ignore_ascii_case("Invalid Subject") { + Self::InvalidSubject + } else if m.len() > PUBLISH_PERMISSIONS.len() + && m[..PUBLISH_PERMISSIONS.len()].eq_ignore_ascii_case(PUBLISH_PERMISSIONS) + { + Self::PublishPermissionViolation + } else if m.len() > SUBSCRIPTION_PERMISSIONS.len() + && m[..SUBSCRIPTION_PERMISSIONS.len()].eq_ignore_ascii_case(SUBSCRIPTION_PERMISSIONS) + { + Self::SubscribePermissionViolation + } else if m.eq_ignore_ascii_case("Unknown Protocol Operation") { + Self::UnknownProtocolOperation + } else if m.eq_ignore_ascii_case("Attempted To Connect To Route Port") { + Self::ConnectionAttemptedToWrongPort + } else if m.eq_ignore_ascii_case("Authorization Violation") { + Self::AuthorizationViolation + } else if m.eq_ignore_ascii_case("Authorization Timeout") { + Self::AuthorizationTimeout + } else if m.eq_ignore_ascii_case("Invalid Client Protocol") { + Self::InvalidClientProtocol + } else if m.eq_ignore_ascii_case("Maximum Control Line Exceeded") { + Self::MaximumControlLineExceeded + } else if m.eq_ignore_ascii_case("Parser Error") { + Self::ParseError + } else if m.eq_ignore_ascii_case("Secure Connection - TLS Required") { + Self::TlsRequired + } else if m.eq_ignore_ascii_case("Stale Connection") { + Self::StaleConnection + } else if m.eq_ignore_ascii_case("Maximum Connections Exceeded") { + Self::MaximumConnectionsExceeded + } else if m.eq_ignore_ascii_case("Slow Consumer") { + Self::SlowConsumer + } else if m.eq_ignore_ascii_case("Maximum Payload Violation") { + Self::MaximumPayloadViolation + } else { + Self::Other { raw_message } + } + } +} diff --git a/watermelon-proto/src/server_info.rs b/watermelon-proto/src/server_info.rs new file mode 100644 index 0000000..492bb6b --- /dev/null +++ b/watermelon-proto/src/server_info.rs @@ -0,0 +1,69 @@ +use alloc::{string::String, vec::Vec}; +use core::{ + net::IpAddr, + num::{NonZeroU16, NonZeroU32}, +}; + +use serde::Deserialize; + +use crate::ServerAddr; + +#[derive(Debug, PartialEq, Eq, Deserialize)] +#[allow(clippy::struct_excessive_bools)] +pub struct ServerInfo { + #[serde(rename = "server_id")] + pub id: String, + #[serde(rename = "server_name")] + pub name: String, + pub version: String, + #[serde(rename = "go")] + pub go_version: String, + pub host: IpAddr, + pub port: NonZeroU16, + #[serde(rename = "headers")] + pub supports_headers: bool, + pub max_payload: NonZeroU32, + #[serde(rename = "proto")] + pub protocol_version: u32, + #[serde(default)] + pub client_id: Option, + #[serde(default)] + pub auth_required: bool, + #[serde(default)] + pub tls_required: bool, + #[serde(default)] + pub tls_verify: bool, + #[serde(default)] + pub tls_available: bool, + #[serde(default)] + pub connect_urls: Vec, + #[serde(default, rename = "ws_connect_urls")] + pub websocket_connect_urls: Vec, + #[serde(default, rename = "ldm")] + pub lame_duck_mode: bool, + #[serde(default)] + pub git_commit: Option, + #[serde(default, rename = "jetstream")] + pub supports_jetstream: bool, + #[serde(default)] + pub ip: Option, + #[serde(default)] + pub client_ip: Option, + #[serde(default)] + pub nonce: Option, + #[serde(default, rename = "cluster")] + pub cluster_name: Option, + #[serde(default)] + pub domain: Option, + + #[serde(flatten)] + pub non_standard: NonStandardServerInfo, +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Default)] +#[non_exhaustive] +pub struct NonStandardServerInfo { + #[cfg(feature = "non-standard-zstd")] + #[serde(default, rename = "m4ss_zstd")] + pub zstd: bool, +} diff --git a/watermelon-proto/src/status_code.rs b/watermelon-proto/src/status_code.rs new file mode 100644 index 0000000..83fbfd7 --- /dev/null +++ b/watermelon-proto/src/status_code.rs @@ -0,0 +1,152 @@ +use core::{ + fmt::{self, Display, Formatter}, + num::NonZeroU16, + str::FromStr, +}; + +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + +use crate::util; + +/// A NATS status code +/// +/// Constants are provided for known and accurately status codes +/// within the NATS Server. +/// +/// Values are guaranteed to be in range `100..1000`. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct StatusCode(NonZeroU16); + +impl StatusCode { + /// The Jetstream consumer hearthbeat timeout has been reached with no new messages to deliver + /// + /// See [ADR-9]. + /// + /// [ADR-9]: https://github.com/nats-io/nats-architecture-and-design/blob/main/adr/ADR-9.md + pub const IDLE_HEARTBEAT: StatusCode = Self::new_internal(100); + /// The request has successfully been sent + pub const OK: StatusCode = Self::new_internal(200); + /// The requested Jetstream resource doesn't exist + pub const NOT_FOUND: StatusCode = Self::new_internal(404); + /// The pull consumer batch reached the timeout + pub const TIMEOUT: StatusCode = Self::new_internal(408); + /// The request was sent to a subject that does not appear to have any subscribers listening + pub const NO_RESPONDERS: StatusCode = Self::new_internal(503); + + /// Decodes a status code from a slice of ASCII characters. + /// + /// The ASCII representation is expected to be in the form of `"NNN"`, where `N` is a numeric + /// digit. + /// + /// # Errors + /// + /// It returns an error if the slice of bytes does not contain a valid status code. + pub fn from_ascii_bytes(buf: &[u8]) -> Result { + if buf.len() != 3 { + return Err(StatusCodeError); + } + + util::parse_u16(buf) + .map_err(|_| StatusCodeError)? + .try_into() + .map(Self) + .map_err(|_| StatusCodeError) + } + + const fn new_internal(val: u16) -> Self { + match NonZeroU16::new(val) { + Some(val) => Self(val), + None => unreachable!(), + } + } +} + +impl FromStr for StatusCode { + type Err = StatusCodeError; + + fn from_str(s: &str) -> Result { + Self::from_ascii_bytes(s.as_bytes()) + } +} + +impl Display for StatusCode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for StatusCode { + type Error = StatusCodeError; + + fn try_from(value: u16) -> Result { + if (100..1000).contains(&value) { + Ok(Self(NonZeroU16::new(value).unwrap())) + } else { + Err(StatusCodeError) + } + } +} + +impl From for u16 { + fn from(value: StatusCode) -> Self { + value.0.get() + } +} + +impl Serialize for StatusCode { + fn serialize(&self, serializer: S) -> Result { + u16::from(*self).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for StatusCode { + fn deserialize>(deserializer: D) -> Result { + let n = u16::deserialize(deserializer)?; + n.try_into().map_err(de::Error::custom) + } +} + +/// An error encountered while parsing [`StatusCode`] +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +#[error("invalid status code")] +pub struct StatusCodeError; + +#[cfg(test)] +mod tests { + use alloc::string::ToString; + + use claims::assert_err; + + use super::StatusCode; + + #[test] + fn valid_status_codes() { + let status_codes = [100, 200, 404, 408, 409, 503]; + + for status_code in status_codes { + assert_eq!( + status_code, + u16::from(StatusCode::try_from(status_code).unwrap()) + ); + + let s = status_code.to_string(); + assert_eq!( + status_code, + u16::from(StatusCode::from_ascii_bytes(s.as_bytes()).unwrap()) + ); + } + } + + #[test] + fn invalid_status_codes() { + let status_codes = [0, 5, 55, 9999]; + + for status_code in status_codes { + assert_err!(StatusCode::try_from(status_code)); + + let s = status_code.to_string(); + assert_err!(StatusCode::from_ascii_bytes(s.as_bytes())); + } + } +} diff --git a/watermelon-proto/src/subject.rs b/watermelon-proto/src/subject.rs new file mode 100644 index 0000000..68e4008 --- /dev/null +++ b/watermelon-proto/src/subject.rs @@ -0,0 +1,259 @@ +use alloc::string::String; +use core::{ + fmt::{self, Display}, + ops::Deref, +}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + +use bytestring::ByteString; + +/// A string that can be used to represent a subject +/// +/// `Subject` contains a string that is guaranteed [^1] to +/// contain a valid subject that meets the following requirements: +/// +/// * The value is not empty +/// * The value has a length less than or equal to 256 [^2] +/// * The value does not contain any whitespace characters or `:` +/// * The value does not contain wrongly placed `*` or `>` characters +/// +/// `Subject` can be constructed from [`Subject::from_static`] +/// or any of the `TryFrom` implementations. +/// +/// [^1]: Because [`Subject::from_dangerous_value`] is safe to call, +/// unsafe code must not assume any of the above invariants. +/// [^2]: Messages coming from the NATS server are allowed to violate this rule. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Subject(ByteString); + +impl Subject { + /// Construct `Subject` from a static string + /// + /// # Panics + /// + /// Will panic if `value` isn't a valid `Subject` + #[must_use] + pub fn from_static(value: &'static str) -> Self { + Self::try_from(ByteString::from_static(value)).expect("invalid Subject") + } + + /// Construct a `Subject` from a string, without checking invariants + /// + /// This method bypasses invariants checks implemented by [`Subject::from_static`] + /// and all `TryFrom` implementations. + /// + /// # Security + /// + /// While calling this method can eliminate the runtime performance cost of + /// checking the string, constructing `Subject` with an invalid string and + /// then calling the NATS server with it can cause serious security issues. + /// When in doubt use the [`Subject::from_static`] or any of the `TryFrom` + /// implementations. + #[expect( + clippy::missing_panics_doc, + reason = "The subject validation is only made in debug" + )] + #[must_use] + pub fn from_dangerous_value(value: ByteString) -> Self { + if cfg!(debug_assertions) { + if let Err(err) = validate_subject(&value) { + panic!("Subject {value:?} isn't valid {err:?}"); + } + } + Self(value) + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl Display for Subject { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl TryFrom for Subject { + type Error = SubjectValidateError; + + fn try_from(value: ByteString) -> Result { + validate_subject(&value)?; + Ok(Self::from_dangerous_value(value)) + } +} + +impl TryFrom for Subject { + type Error = SubjectValidateError; + + fn try_from(value: String) -> Result { + validate_subject(&value)?; + Ok(Self::from_dangerous_value(value.into())) + } +} + +impl From for ByteString { + fn from(value: Subject) -> Self { + value.0 + } +} + +impl AsRef<[u8]> for Subject { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for Subject { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Deref for Subject { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl Serialize for Subject { + fn serialize(&self, serializer: S) -> Result { + self.as_str().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Subject { + fn deserialize>(deserializer: D) -> Result { + let s = ByteString::deserialize(deserializer)?; + s.try_into().map_err(de::Error::custom) + } +} + +/// An error encountered while validating [`Subject`] +#[derive(Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub enum SubjectValidateError { + /// The value is empty + #[error("Subject is empty")] + Empty, + /// The value has a length greater than 256 + #[error("Subject is too long")] + TooLong, + /// The value contains an Unicode whitespace character + #[error("Subject contained an illegal whitespace character")] + IllegalCharacter, + /// The value contains consecutive `.` characters + #[error("Subject contained a broken token")] + BrokenToken, + /// The value contains `.` or `>` together with other characters + /// in the same token, or the `>` is in the non-last token + #[error("Subject contained a broken wildcard")] + BrokenWildcard, +} + +fn validate_subject(subject: &str) -> Result<(), SubjectValidateError> { + if subject.is_empty() { + return Err(SubjectValidateError::Empty); + } + + if subject.len() > 256 { + // This is an arbitrary limit, but I guess the server must also have one + return Err(SubjectValidateError::TooLong); + } + + if subject.chars().any(char::is_whitespace) { + // The theoretical security limit is just ` `, `\t`, `\r` and `\n`. + // Let's be more careful. + return Err(SubjectValidateError::IllegalCharacter); + } + + let mut tokens = subject.split('.').peekable(); + while let Some(token) = tokens.next() { + if token.is_empty() || token.contains("..") { + return Err(SubjectValidateError::BrokenToken); + } + + if token.len() > 1 && (token.contains(['*', '>'])) { + return Err(SubjectValidateError::BrokenWildcard); + } + + if token == ">" && tokens.peek().is_some() { + return Err(SubjectValidateError::BrokenWildcard); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use bytestring::ByteString; + + use super::{Subject, SubjectValidateError}; + + #[test] + fn valid_subjects() { + let subjects = [ + "cmd", + "cmd.endpoint", + "cmd.endpoint.detail", + "cmd.*.detail", + "cmd.*.*", + "cmd.endpoint.>", + ]; + for subject in subjects { + let s = Subject::try_from(ByteString::from_static(subject)).unwrap(); + assert_eq!(subject, s.as_str()); + } + } + + #[test] + fn invalid_subjects() { + let subjects = [ + ("", SubjectValidateError::Empty), + + ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", SubjectValidateError::TooLong), + + ("cmd ", SubjectValidateError::IllegalCharacter), + ("cmd .endpoint", SubjectValidateError::IllegalCharacter), + (" cmd", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint ", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint.detail ", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint.detail\r", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint.detail\n", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint.detail\t", SubjectValidateError::IllegalCharacter), + ("cmd.endp oint.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endp\roint.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endp\noint.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endp\toint.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint .detail", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint\r.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint\n.detail", SubjectValidateError::IllegalCharacter), + ("cmd.endpoint\t.detail", SubjectValidateError::IllegalCharacter), + (" ", SubjectValidateError::IllegalCharacter), + ("\r", SubjectValidateError::IllegalCharacter), + ("\n", SubjectValidateError::IllegalCharacter), + ("\t", SubjectValidateError::IllegalCharacter), + + ("cmd..endpoint", SubjectValidateError::BrokenToken), + (".cmd.endpoint", SubjectValidateError::BrokenToken), + ("cmd.endpoint.", SubjectValidateError::BrokenToken), + + ("cmd.**", SubjectValidateError::BrokenWildcard), + ("cmd.**.endpoint", SubjectValidateError::BrokenWildcard), + ("cmd.a*.endpoint", SubjectValidateError::BrokenWildcard), + ("cmd.*a.endpoint", SubjectValidateError::BrokenWildcard), + ("cmd.>.endpoint", SubjectValidateError::BrokenWildcard), + ("cmd.a>.endpoint", SubjectValidateError::BrokenWildcard), + ("cmd.endpoint.a>", SubjectValidateError::BrokenWildcard), + ("cmd.endpoint.>a", SubjectValidateError::BrokenWildcard), + ]; + for (subject, expected_err) in subjects { + let err = Subject::try_from(ByteString::from_static(subject)).unwrap_err(); + assert_eq!(expected_err, err); + } + } +} diff --git a/watermelon-proto/src/subscription_id.rs b/watermelon-proto/src/subscription_id.rs new file mode 100644 index 0000000..1a5b506 --- /dev/null +++ b/watermelon-proto/src/subscription_id.rs @@ -0,0 +1,38 @@ +use core::fmt::{self, Display}; + +use crate::util::{self, ParseUintError}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct SubscriptionId(u64); + +impl SubscriptionId { + pub const MIN: Self = SubscriptionId(1); + pub const MAX: Self = SubscriptionId(u64::MAX); + + /// Converts a slice of ASCII bytes to a `SubscriptionId`. + /// + /// # Errors + /// + /// It returns an error if the bytes do not contain a valid numeric value. + pub fn from_ascii_bytes(buf: &[u8]) -> Result { + util::parse_u64(buf).map(Self) + } +} + +impl From for SubscriptionId { + fn from(value: u64) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: SubscriptionId) -> Self { + value.0 + } +} + +impl Display for SubscriptionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(&self.0, f) + } +} diff --git a/watermelon-proto/src/tests.rs b/watermelon-proto/src/tests.rs new file mode 100644 index 0000000..14e1ce5 --- /dev/null +++ b/watermelon-proto/src/tests.rs @@ -0,0 +1,12 @@ +use bytes::{Buf, Bytes}; + +pub(crate) trait ToBytes: Buf { + fn to_bytes(mut self) -> Bytes + where + Self: Sized, + { + self.copy_to_bytes(self.remaining()) + } +} + +impl ToBytes for T {} diff --git a/watermelon-proto/src/util/buf_list.rs b/watermelon-proto/src/util/buf_list.rs new file mode 100644 index 0000000..028ebbc --- /dev/null +++ b/watermelon-proto/src/util/buf_list.rs @@ -0,0 +1,111 @@ +use alloc::collections::VecDeque; +use core::cmp::Ordering; +#[cfg(feature = "std")] +use std::io; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +#[derive(Debug)] +pub(crate) struct BufList { + bufs: VecDeque, + len: usize, +} + +impl BufList { + pub(crate) const fn new() -> Self { + Self { + bufs: VecDeque::new(), + len: 0, + } + } + + pub(crate) fn push(&mut self, buf: B) { + debug_assert!(buf.has_remaining()); + let rem = buf.remaining(); + self.bufs.push_back(buf); + self.len += rem; + } +} + +impl Buf for BufList { + fn remaining(&self) -> usize { + self.len + } + + fn has_remaining(&self) -> bool { + !self.bufs.is_empty() + } + + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or_default() + } + + fn advance(&mut self, mut cnt: usize) { + assert!( + cnt <= self.remaining(), + "advance out of range ({} <= {})", + cnt, + self.remaining() + ); + + while cnt > 0 { + let entry = self.bufs.front_mut().unwrap(); + let remaining = entry.remaining(); + if remaining < cnt { + entry.advance(cnt); + self.len -= cnt; + cnt -= cnt; + } else { + let _ = self.bufs.remove(0); + self.len -= remaining; + cnt -= remaining; + } + } + } + + #[cfg(feature = "std")] + fn chunks_vectored<'a>(&'a self, mut dst: &mut [io::IoSlice<'a>]) -> usize { + let mut filled = 0; + for buf in &self.bufs { + let n = buf.chunks_vectored(dst); + filled += n; + + dst = &mut dst[n..]; + if dst.is_empty() { + break; + } + } + + filled + } + + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + assert!( + len <= self.remaining(), + "copy_to_bytes out of range ({} <= {})", + len, + self.remaining() + ); + + if let Some(first) = self.bufs.front_mut() { + match first.remaining().cmp(&len) { + Ordering::Greater => { + self.len -= len; + return first.copy_to_bytes(len); + } + Ordering::Equal => { + self.len -= len; + return self.bufs.remove(0).unwrap().copy_to_bytes(len); + } + Ordering::Less => {} + } + } + + let mut bufs = BytesMut::with_capacity(len); + bufs.put(self.take(len)); + let bufs = bufs.freeze(); + + self.len -= len; + bufs + } +} diff --git a/watermelon-proto/src/util/lines_iter.rs b/watermelon-proto/src/util/lines_iter.rs new file mode 100644 index 0000000..498e849 --- /dev/null +++ b/watermelon-proto/src/util/lines_iter.rs @@ -0,0 +1,54 @@ +use core::mem; + +use bytes::{Buf, Bytes}; + +pub(crate) fn lines_iter(bytes: Bytes) -> impl Iterator { + struct LinesIterator(Bytes); + + impl Iterator for LinesIterator { + type Item = Bytes; + + fn next(&mut self) -> Option { + if self.0.is_empty() { + return None; + } + + Some(match memchr::memmem::find(&self.0, b"\r\n") { + Some(i) => { + let chunk = self.0.split_to(i); + self.0.advance("\r\n".len()); + chunk + } + None => mem::take(&mut self.0), + }) + } + } + + LinesIterator(bytes) +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use super::lines_iter; + + #[test] + fn iterate_lines() { + let expected_chunks = ["", "abcd", "12334534", "alkfdasfsd", "", "-"]; + let mut combined_chunk = expected_chunks + .iter() + .fold(BytesMut::new(), |mut buf, chunk| { + buf.extend_from_slice(chunk.as_bytes()); + buf.extend_from_slice(b"\r\n"); + buf + }); + combined_chunk.truncate(combined_chunk.len() - "\r\n".len()); + let combined_chunk = combined_chunk.freeze(); + + let expected_chunks = expected_chunks + .iter() + .map(|c| Bytes::from_static(c.as_bytes())); + assert!(expected_chunks.eq(lines_iter(combined_chunk))); + } +} diff --git a/watermelon-proto/src/util/mod.rs b/watermelon-proto/src/util/mod.rs new file mode 100644 index 0000000..0ba393b --- /dev/null +++ b/watermelon-proto/src/util/mod.rs @@ -0,0 +1,10 @@ +pub(crate) use self::buf_list::BufList; +pub(crate) use self::lines_iter::lines_iter; +pub(crate) use self::split_spaces::split_spaces; +pub use self::uint::ParseUintError; +pub(crate) use self::uint::{parse_u16, parse_u64, parse_usize}; + +mod buf_list; +mod lines_iter; +mod split_spaces; +mod uint; diff --git a/watermelon-proto/src/util/split_spaces.rs b/watermelon-proto/src/util/split_spaces.rs new file mode 100644 index 0000000..57b44df --- /dev/null +++ b/watermelon-proto/src/util/split_spaces.rs @@ -0,0 +1,29 @@ +use core::array; + +use bytes::{Buf, Bytes}; + +pub(crate) fn split_spaces(mut bytes: Bytes) -> impl Iterator { + let mut chunks = array::from_fn::<_, 6, _>(|_| Bytes::new()); + let mut found = 0; + + for chunk in &mut chunks { + let Some(i) = memchr::memchr2(b' ', b'\t', &bytes) else { + if !bytes.is_empty() { + *chunk = bytes; + found += 1; + } + break; + }; + + *chunk = bytes.split_to(i); + found += 1; + + let spaces = bytes + .iter() + .take_while(|b| matches!(b, b' ' | b'\t')) + .count(); + bytes.advance(spaces); + } + + chunks.into_iter().take(found) +} diff --git a/watermelon-proto/src/util/uint.rs b/watermelon-proto/src/util/uint.rs new file mode 100644 index 0000000..5fff08e --- /dev/null +++ b/watermelon-proto/src/util/uint.rs @@ -0,0 +1,51 @@ +macro_rules! parse_unsigned { + ($name:ident, $num:ty) => { + pub(crate) fn $name(buf: &[u8]) -> Result<$num, ParseUintError> { + let mut val: $num = 0; + + for &b in buf { + if !b.is_ascii_digit() { + return Err(ParseUintError::InvalidByte(b)); + } + + val = val.checked_mul(10).ok_or(ParseUintError::Overflow)?; + val = val + .checked_add(<$num>::from(b - b'0')) + .ok_or(ParseUintError::Overflow)?; + } + + Ok(val) + } + }; +} + +parse_unsigned!(parse_u16, u16); +parse_unsigned!(parse_u64, u64); +parse_unsigned!(parse_usize, usize); + +#[derive(Debug, thiserror::Error)] +pub enum ParseUintError { + #[error("invalid byte {0:?}")] + InvalidByte(u8), + #[error("overflow")] + Overflow, +} + +#[cfg(test)] +mod tests { + use alloc::string::ToString; + + use claims::assert_ok_eq; + + use super::{parse_u16, parse_u64, parse_usize}; + + #[test] + fn parse_u16_range() { + for n in 0..=u16::MAX { + let s = n.to_string(); + assert_ok_eq!(parse_u16(s.as_bytes()), n); + assert_ok_eq!(parse_usize(s.as_bytes()), usize::from(n)); + assert_ok_eq!(parse_u64(s.as_bytes()), u64::from(n)); + } + } +} diff --git a/watermelon/Cargo.toml b/watermelon/Cargo.toml new file mode 100644 index 0000000..222a6cb --- /dev/null +++ b/watermelon/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "watermelon" +version = "0.1.0" +description = "High level actor based implementation NATS Core and NATS Jetstream client implementation" +categories = ["api-bindings", "network-programming"] +keywords = ["nats", "client", "jetstream"] +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[package.metadata.docs.rs] +features = ["websocket", "non-standard-zstd"] + +[dependencies] +tokio = { version = "1.36", features = ["rt", "sync", "time"] } +arc-swap = "1" +futures-core = "0.3" +futures-util = { version = "0.3", default-features = false } +bytes = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +pin-project-lite = "0.2" +rand = "0.8" +chrono = { version = "0.4", default-features = false, features = ["std", "clock", "serde"] } + +# from-env +envy = { version = "0.4", optional = true } + +# portable-atomic +portable-atomic = { version = "1", optional = true } + +watermelon-mini = { version = "0.1", path = "../watermelon-mini", default-features = false } +watermelon-net = { version = "0.1", path = "../watermelon-net" } +watermelon-proto = { version = "0.1", path = "../watermelon-proto" } +watermelon-nkeys = { version = "0.1", path = "../watermelon-nkeys", default-features = false } + +thiserror = "2" + +[dev-dependencies] +claims = "0.8" + +[features] +default = ["aws-lc-rs", "from-env"] +websocket = ["watermelon-mini/websocket"] +aws-lc-rs = ["watermelon-mini/aws-lc-rs", "watermelon-nkeys/aws-lc-rs"] +ring = ["watermelon-mini/ring", "watermelon-nkeys/ring"] +fips = ["watermelon-mini/fips", "watermelon-nkeys/fips"] +from-env = ["dep:envy"] +portable-atomic = ["dep:portable-atomic"] +non-standard-zstd = ["watermelon-mini/non-standard-zstd", "watermelon-net/non-standard-zstd", "watermelon-proto/non-standard-zstd"] + +[lints] +workspace = true diff --git a/watermelon/LICENSE-APACHE b/watermelon/LICENSE-APACHE new file mode 120000 index 0000000..965b606 --- /dev/null +++ b/watermelon/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/watermelon/LICENSE-MIT b/watermelon/LICENSE-MIT new file mode 120000 index 0000000..76219eb --- /dev/null +++ b/watermelon/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/watermelon/README.md b/watermelon/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/watermelon/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/watermelon/src/atomic.rs b/watermelon/src/atomic.rs new file mode 100644 index 0000000..66fe5ce --- /dev/null +++ b/watermelon/src/atomic.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "portable-atomic")] +pub(crate) use portable_atomic::*; +#[cfg(not(feature = "portable-atomic"))] +pub(crate) use std::sync::atomic::*; diff --git a/watermelon/src/client/builder.rs b/watermelon/src/client/builder.rs new file mode 100644 index 0000000..9c87b90 --- /dev/null +++ b/watermelon/src/client/builder.rs @@ -0,0 +1,201 @@ +use std::time::Duration; + +use watermelon_mini::{AuthenticationMethod, ConnectError}; +use watermelon_proto::{ServerAddr, Subject}; + +#[cfg(feature = "from-env")] +use super::from_env::FromEnv; +use crate::core::Client; + +/// A builder for [`Client`] +/// +/// Obtained from [`Client::builder`]. +#[derive(Debug)] +pub struct ClientBuilder { + pub(crate) auth_method: Option, + pub(crate) flush_interval: Duration, + pub(crate) inbox_prefix: Subject, + pub(crate) echo: Echo, + pub(crate) default_response_timeout: Duration, + #[cfg(feature = "non-standard-zstd")] + pub(crate) non_standard_zstd: bool, +} + +/// Whether or not to allow messages published by this client to be echoed back to it's own subscriptions +#[derive(Debug, Copy, Clone, Default)] +pub enum Echo { + /// Do not allow messages published by this client to be echoed back to it's own [`Subscription`]s + /// + /// [`Subscription`]: crate::core::Subscription + #[default] + Prevent, + /// Allow messages published by this client to be echoed back to it's own [`Subscription`]s + /// + /// [`Subscription`]: crate::core::Subscription + Allow, +} + +impl ClientBuilder { + pub(super) fn new() -> Self { + Self { + auth_method: None, + flush_interval: Duration::ZERO, + inbox_prefix: Subject::from_static("_INBOX"), + echo: Echo::Prevent, + default_response_timeout: Duration::from_secs(5), + #[cfg(feature = "non-standard-zstd")] + non_standard_zstd: true, + } + } + + /// Construct [`ClientBuilder`] from environment variables + /// + /// Reads the following environment variables into [`ClientBuilder`]: + /// + /// Authentication: + /// + /// * `NATS_JWT` and `NATS_NKEY`: use nkey authentication + /// * `NATS_CREDS_FILE`: read JWT and NKEY from the provided `.creds` file + /// * `NATS_USERNAME` and `NATS_PASSWORD`: use username and password authentication + /// + /// # Panics + /// + /// It panics if: + /// + /// - it is not possible to get the environment variables; + /// - an error occurs when trying to read the credentials file; + /// - the credentials file is invalid. + #[cfg(feature = "from-env")] + #[must_use] + pub fn from_env() -> Self { + use super::from_env; + + let env = envy::from_env::().expect("FromEnv deserialization error"); + + let mut this = Self::new(); + + match env.auth { + from_env::AuthenticationMethod::Creds { jwt, nkey } => { + this = this.authentication_method(Some(AuthenticationMethod::Creds { jwt, nkey })); + } + from_env::AuthenticationMethod::CredsFile { creds_file } => { + let contents = std::fs::read_to_string(creds_file).expect("read credentials file"); + let auth = + AuthenticationMethod::from_creds(&contents).expect("parse credentials file"); + this = this.authentication_method(Some(auth)); + } + from_env::AuthenticationMethod::UserAndPassword { username, password } => { + this = this.authentication_method(Some(AuthenticationMethod::UserAndPassword { + username, + password, + })); + } + from_env::AuthenticationMethod::None => { + this = this.authentication_method(None); + } + } + + if let Some(inbox_prefix) = env.inbox_prefix { + this = this.inbox_prefix(inbox_prefix); + } + + this + } + + /// Define an authentication method + #[must_use] + pub fn authentication_method(mut self, auth_method: Option) -> Self { + self.auth_method = auth_method; + self + } + + /// Define a flush interval + /// + /// Setting a non-zero flush interval allows the client to generate + /// larger TLS and TCP packets at the cost of increased latency. Using + /// a value greater than a few seconds may break the client in + /// unexpected ways. + /// + /// Setting this to [`Duration::ZERO`] causes the client to send messages + /// as fast as the network will allow, trading off smaller packets for + /// lower latency. + /// + /// Default: 0 + #[must_use] + pub fn flush_interval(mut self, flush_interval: Duration) -> Self { + self.flush_interval = flush_interval; + self + } + + /// Configure the inbox prefix to which replies from the NATS server will be received + /// + /// Default: `_INBOX` + #[must_use] + pub fn inbox_prefix(mut self, inbox_prefix: Subject) -> Self { + self.inbox_prefix = inbox_prefix; + self + } + + /// Whether or not to allow messages published by this client to be echoed back to it's own [`Subscription`]s + /// + /// Setting this option to [`Echo::Allow`] will allow [`Subscription`]s created by + /// this client to receive messages by itself published. + /// + /// Default: [`Echo::Prevent`]. + /// + /// [`Subscription`]: crate::core::Subscription + #[must_use] + pub fn echo(mut self, echo: Echo) -> Self { + self.echo = echo; + self + } + + /// The default timeout for [`ResponseFut`] + /// + /// Defines how long we should wait for a response in [`Client::request`]. + /// + /// Default: 5 seconds. + /// + /// [`ResponseFut`]: crate::core::request::ResponseFut + #[must_use] + pub fn default_response_timeout(mut self, timeout: Duration) -> Self { + self.default_response_timeout = timeout; + self + } + + /// Have the client compress the connection using zstd when talking to a NATS server + /// behind a custom zstd proxy + /// + /// The NATS protocol and applications developed on top of it can make inefficient + /// use of the network, making applications running on extremely slow or expensive internet + /// connections infeasible. This option adds a non-standard zstd compression + /// feature on top of the client which, when used in conjunction with a custom zstd reverse proxy + /// put in from of the NATS server allows for large bandwidth savings. + /// + /// This option is particularly powerful when combined with [`ClientBuilder::flush_interval`]. + /// + /// This option is automatically disabled when connecting to an unsupported server. + /// + /// Default: `true` when compiled with the `non-standard-zstd` option. + #[cfg(feature = "non-standard-zstd")] + #[must_use] + pub fn non_standard_zstd(mut self, non_standard_zstd: bool) -> Self { + self.non_standard_zstd = non_standard_zstd; + self + } + + /// Creates a new [`Client`], connecting to the given address. + /// + /// # Errors + /// + /// It returns an error if the connection fails. + pub async fn connect(self, addr: ServerAddr) -> Result { + Client::connect(addr, self).await + } +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/watermelon/src/client/commands/mod.rs b/watermelon/src/client/commands/mod.rs new file mode 100644 index 0000000..094fd8b --- /dev/null +++ b/watermelon/src/client/commands/mod.rs @@ -0,0 +1,11 @@ +pub use self::publish::{ + ClientPublish, DoClientPublish, DoOwnedClientPublish, OwnedClientPublish, Publish, + PublishBuilder, +}; +pub use self::request::{ + ClientRequest, DoClientRequest, DoOwnedClientRequest, OwnedClientRequest, Request, + RequestBuilder, ResponseError, ResponseFut, +}; + +mod publish; +mod request; diff --git a/watermelon/src/client/commands/publish.rs b/watermelon/src/client/commands/publish.rs new file mode 100644 index 0000000..9fe4fa4 --- /dev/null +++ b/watermelon/src/client/commands/publish.rs @@ -0,0 +1,302 @@ +use std::{ + fmt::{self, Debug}, + future::IntoFuture, +}; + +use bytes::Bytes; +use futures_core::future::BoxFuture; +use watermelon_proto::{ + headers::{HeaderMap, HeaderName, HeaderValue}, + MessageBase, Subject, +}; + +use crate::{ + client::{Client, ClientClosedError, TryCommandError}, + handler::HandlerCommand, +}; + +use super::Request; + +/// A publishable message +#[derive(Debug, Clone)] +pub struct Publish { + pub(super) subject: Subject, + pub(super) reply_subject: Option, + pub(super) headers: HeaderMap, + pub(super) payload: Bytes, +} + +/// A constructor for a publishable message +/// +/// Obtained from [`Publish::builder`]. +#[derive(Debug)] +pub struct PublishBuilder { + publish: Publish, +} + +/// A constructor for a publishable message to be sent using the given client +/// +/// Obtained from [`Client::publish`]. +pub struct ClientPublish<'a> { + client: &'a Client, + publish: Publish, +} + +/// A publisheable message ready to be published to the given client +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DoClientPublish<'a> { + client: &'a Client, + publish: Publish, +} + +/// A constructor for a publishable message to be sent using the given owned client +/// +/// Obtained from [`Client::publish_owned`]. +pub struct OwnedClientPublish { + client: Client, + publish: Publish, +} + +/// A publisheable message ready to be published to the given owned client +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DoOwnedClientPublish { + client: Client, + publish: Publish, +} + +macro_rules! publish { + () => { + #[must_use] + pub fn reply_subject(mut self, reply_subject: Option) -> Self { + self.publish_mut().reply_subject = reply_subject; + self + } + + #[must_use] + pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self { + self.publish_mut().headers.insert(name, value); + self + } + + #[must_use] + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.publish_mut().headers = headers; + self + } + }; +} + +impl Publish { + /// Build a new [`Publish`] + #[must_use] + pub fn builder(subject: Subject) -> PublishBuilder { + PublishBuilder::subject(subject) + } + + /// Publish this message to `client` + pub fn client(self, client: &Client) -> DoClientPublish<'_> { + DoClientPublish { + client, + publish: self, + } + } + + /// Publish this message to `client`, taking ownership of it + pub fn client_owned(self, client: Client) -> DoOwnedClientPublish { + DoOwnedClientPublish { + client, + publish: self, + } + } + + pub fn into_request(self) -> Request { + Request { + publish: self, + response_timeout: None, + } + } + + fn into_message_base(self) -> MessageBase { + let Self { + subject, + reply_subject, + headers, + payload, + } = self; + MessageBase { + subject, + reply_subject, + headers, + payload, + } + } +} + +impl PublishBuilder { + #[must_use] + pub fn subject(subject: Subject) -> Self { + Self { + publish: Publish { + subject, + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::new(), + }, + } + } + + publish!(); + + #[must_use] + pub fn payload(mut self, payload: Bytes) -> Publish { + self.publish.payload = payload; + self.publish + } + + fn publish_mut(&mut self) -> &mut Publish { + &mut self.publish + } +} + +impl<'a> ClientPublish<'a> { + pub(crate) fn build(client: &'a Client, subject: Subject) -> Self { + Self { + client, + publish: PublishBuilder::subject(subject).publish, + } + } + + publish!(); + + pub fn payload(mut self, payload: Bytes) -> DoClientPublish<'a> { + self.publish.payload = payload; + self.publish.client(self.client) + } + + /// Convert this into [`OwnedClientPublish`] + #[must_use] + pub fn to_owned(self) -> OwnedClientPublish { + OwnedClientPublish { + client: self.client.clone(), + publish: self.publish, + } + } + + fn publish_mut(&mut self) -> &mut Publish { + &mut self.publish + } +} + +impl OwnedClientPublish { + pub(crate) fn build(client: Client, subject: Subject) -> Self { + Self { + client, + publish: PublishBuilder::subject(subject).publish, + } + } + + publish!(); + + pub fn payload(mut self, payload: Bytes) -> DoOwnedClientPublish { + self.publish.payload = payload; + self.publish.client_owned(self.client) + } + + fn publish_mut(&mut self) -> &mut Publish { + &mut self.publish + } +} + +impl DoClientPublish<'_> { + /// Publish this message if there's enough immediately available space in the internal buffers + /// + /// This method will publish the given message only if there's enough + /// immediately available space to enqueue it in the client's + /// networking stack. + /// + /// # Errors + /// + /// It returns an error if the client's buffer is full or if the client has been closed. + pub fn try_publish(self) -> Result<(), TryCommandError> { + try_publish(self.client, self.publish) + } +} + +impl<'a> IntoFuture for DoClientPublish<'a> { + type Output = Result<(), ClientClosedError>; + type IntoFuture = BoxFuture<'a, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { publish(self.client, self.publish).await }) + } +} + +impl DoOwnedClientPublish { + /// Publish this message if there's enough immediately available space in the internal buffers + /// + /// This method will publish the given message only if there's enough + /// immediately available space to enqueue it in the client's + /// networking stack. + /// + /// # Errors + /// + /// It returns an error if the client's buffer is full or if the client has been closed. + pub fn try_publish(self) -> Result<(), TryCommandError> { + try_publish(&self.client, self.publish) + } +} + +impl IntoFuture for DoOwnedClientPublish { + type Output = Result<(), ClientClosedError>; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { publish(&self.client, self.publish).await }) + } +} + +fn try_publish(client: &Client, publish: Publish) -> Result<(), TryCommandError> { + client.try_enqueue_command(HandlerCommand::Publish { + message: publish.into_message_base(), + }) +} + +async fn publish(client: &Client, publish: Publish) -> Result<(), ClientClosedError> { + client + .enqueue_command(HandlerCommand::Publish { + message: publish.into_message_base(), + }) + .await +} + +impl Debug for ClientPublish<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientPublish") + .field("publish", &self.publish) + .finish_non_exhaustive() + } +} + +impl Debug for DoClientPublish<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DoClientPublish") + .field("publish", &self.publish) + .finish_non_exhaustive() + } +} + +impl Debug for OwnedClientPublish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OwnedClientPublish") + .field("publish", &self.publish) + .finish_non_exhaustive() + } +} + +impl Debug for DoOwnedClientPublish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DoOwnedClientPublish") + .field("publish", &self.publish) + .finish_non_exhaustive() + } +} diff --git a/watermelon/src/client/commands/request.rs b/watermelon/src/client/commands/request.rs new file mode 100644 index 0000000..c735319 --- /dev/null +++ b/watermelon/src/client/commands/request.rs @@ -0,0 +1,416 @@ +use std::{ + fmt::{self, Debug}, + future::{Future, IntoFuture}, + num::NonZeroU64, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use bytes::Bytes; +use futures_core::{future::BoxFuture, Stream}; +use pin_project_lite::pin_project; +use tokio::time::{sleep, Sleep}; +use watermelon_proto::{ + error::ServerError, + headers::{HeaderMap, HeaderName, HeaderValue}, + ServerMessage, StatusCode, Subject, +}; + +use crate::{ + client::{Client, ClientClosedError, TryCommandError}, + core::MultiplexedSubscription, + subscription::Subscription, +}; + +use super::Publish; + +/// A publishable request +#[derive(Debug, Clone)] +pub struct Request { + pub(super) publish: Publish, + pub(super) response_timeout: Option, +} + +/// A constructor for a publishable request +/// +/// Obtained from [`Request::builder`]. +#[derive(Debug)] +pub struct RequestBuilder { + request: Request, +} + +/// A constructor for a publishable request to be sent using the given client +/// +/// Obtained from [`Client::request`]. +pub struct ClientRequest<'a> { + client: &'a Client, + request: Request, +} + +/// A publisheable request ready to be published to the given client +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DoClientRequest<'a> { + client: &'a Client, + request: Request, +} + +/// A constructor for a publishable request to be sent using the given owned client +/// +/// Obtained from [`Client::request_owned`]. +pub struct OwnedClientRequest { + client: Client, + request: Request, +} + +/// A publisheable request ready to be published to the given owned client +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DoOwnedClientRequest { + client: Client, + request: Request, +} + +pin_project! { + /// A [`Future`] for receiving a response + #[derive(Debug)] + #[must_use = "consider using a `Publish` instead of `Request` if uninterested in the response"] + pub struct ResponseFut { + subscription: ResponseSubscription, + #[pin] + timeout: Sleep, + } +} + +#[derive(Debug)] +enum ResponseSubscription { + Multiplexed(MultiplexedSubscription), + Subscription(Subscription), +} + +/// An error encountered while waiting for a response +#[derive(Debug, thiserror::Error)] +pub enum ResponseError { + /// The [`Subscription`] encountered a server error + #[error("server error")] + ServerError(#[source] ServerError), + /// The NATS server told us that no subscriptions are present for the requested subject + #[error("no responders")] + NoResponders, + /// A response hasn't been received within the timeout + #[error("received no response within the timeout window")] + TimedOut, + /// The [`Subscription`] was closed without yielding any message + /// + /// On a multiplexed subscription this may mean that the client + /// reconnected to the server + #[error("subscription closed")] + SubscriptionClosed, +} + +macro_rules! request { + () => { + #[must_use] + pub fn reply_subject(mut self, reply_subject: Option) -> Self { + self.request_mut().publish.reply_subject = reply_subject; + self + } + + #[must_use] + pub fn header(mut self, name: HeaderName, value: HeaderValue) -> Self { + self.request_mut().publish.headers.insert(name, value); + self + } + + #[must_use] + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.request_mut().publish.headers = headers; + self + } + + #[must_use] + pub fn response_timeout(mut self, timeout: Duration) -> Self { + self.request_mut().response_timeout = Some(timeout); + self + } + }; +} + +impl Request { + /// Build a new [`Request`] + #[must_use] + pub fn builder(subject: Subject) -> RequestBuilder { + RequestBuilder::subject(subject) + } + + /// Publish this request to `client` + pub fn client(self, client: &Client) -> DoClientRequest<'_> { + DoClientRequest { + client, + request: self, + } + } + + /// Publish this request to `client`, taking ownership of it + pub fn client_owned(self, client: Client) -> DoOwnedClientRequest { + DoOwnedClientRequest { + client, + request: self, + } + } +} + +impl RequestBuilder { + #[must_use] + pub fn subject(subject: Subject) -> Self { + Self { + request: Request { + publish: Publish { + subject, + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::new(), + }, + response_timeout: None, + }, + } + } + + request!(); + + #[must_use] + pub fn payload(mut self, payload: Bytes) -> Request { + self.request.publish.payload = payload; + self.request + } + + fn request_mut(&mut self) -> &mut Request { + &mut self.request + } +} + +impl<'a> ClientRequest<'a> { + pub(crate) fn build(client: &'a Client, subject: Subject) -> Self { + Self { + client, + request: RequestBuilder::subject(subject).request, + } + } + + request!(); + + pub fn payload(mut self, payload: Bytes) -> DoClientRequest<'a> { + self.request.publish.payload = payload; + self.request.client(self.client) + } + + /// Convert this into [`OwnedClientRequest`] + #[must_use] + pub fn to_owned(self) -> OwnedClientRequest { + OwnedClientRequest { + client: self.client.clone(), + request: self.request, + } + } + + fn request_mut(&mut self) -> &mut Request { + &mut self.request + } +} + +impl OwnedClientRequest { + pub(crate) fn build(client: Client, subject: Subject) -> Self { + Self { + client, + request: RequestBuilder::subject(subject).request, + } + } + + request!(); + + pub fn payload(mut self, payload: Bytes) -> DoOwnedClientRequest { + self.request.publish.payload = payload; + self.request.client_owned(self.client) + } + + fn request_mut(&mut self) -> &mut Request { + &mut self.request + } +} + +impl DoClientRequest<'_> { + /// Publish this request if there's enough immediately available space in the internal buffers + /// + /// This method will publish the given request only if there's enough + /// immediately available space to enqueue it in the client's + /// networking stack. + /// + /// # Errors + /// + /// It returns an error if the client's buffer is full or if the client has been closed. + pub fn try_request(self) -> Result { + try_request(self.client, self.request) + } +} + +impl<'a> IntoFuture for DoClientRequest<'a> { + type Output = Result; + type IntoFuture = BoxFuture<'a, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { request(self.client, self.request).await }) + } +} + +impl DoOwnedClientRequest { + /// Request this message if there's enough immediately available space in the internal buffers + /// + /// This method will publish the given request only if there's enough + /// immediately available space to enqueue it in the client's + /// networking stack. + /// + /// # Errors + /// + /// It returns an error if the client's buffer is full or if the client has been closed. + pub fn try_request(self) -> Result { + try_request(&self.client, self.request) + } +} + +impl IntoFuture for DoOwnedClientRequest { + type Output = Result; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { request(&self.client, self.request).await }) + } +} + +impl Future for ResponseFut { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.subscription { + ResponseSubscription::Multiplexed(receiver) => match Pin::new(receiver).poll(cx) { + Poll::Pending => match this.timeout.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => Poll::Ready(Err(ResponseError::TimedOut)), + }, + Poll::Ready(Ok(message)) + if message.status_code == Some(StatusCode::NO_RESPONDERS) => + { + Poll::Ready(Err(ResponseError::NoResponders)) + } + Poll::Ready(Ok(message)) => Poll::Ready(Ok(message)), + Poll::Ready(Err(_err)) => Poll::Ready(Err(ResponseError::SubscriptionClosed)), + }, + ResponseSubscription::Subscription(subscription) => { + match Pin::new(subscription).poll_next(cx) { + Poll::Pending => match this.timeout.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => Poll::Ready(Err(ResponseError::TimedOut)), + }, + Poll::Ready(Some(Ok(message))) + if message.status_code == Some(StatusCode::NO_RESPONDERS) => + { + Poll::Ready(Err(ResponseError::NoResponders)) + } + Poll::Ready(Some(Ok(message))) => Poll::Ready(Ok(message)), + Poll::Ready(Some(Err(server_error))) => { + Poll::Ready(Err(ResponseError::ServerError(server_error))) + } + Poll::Ready(None) => Poll::Ready(Err(ResponseError::SubscriptionClosed)), + } + } + } + } +} + +fn try_request(client: &Client, request: Request) -> Result { + let subscription = if let Some(reply_subject) = &request.publish.reply_subject { + let subscription = client.try_subscribe(reply_subject.clone(), None)?; + client.lazy_unsubscribe(subscription.id, Some(NonZeroU64::new(1).unwrap())); + + request.publish.client(client).try_publish()?; + ResponseSubscription::Subscription(subscription) + } else { + let receiver = client.try_multiplexed_request( + request.publish.subject, + request.publish.headers, + request.publish.payload, + )?; + ResponseSubscription::Multiplexed(receiver) + }; + + let timeout = sleep( + request + .response_timeout + .unwrap_or(client.default_response_timeout()), + ); + Ok(ResponseFut { + subscription, + timeout, + }) +} + +async fn request(client: &Client, request: Request) -> Result { + let subscription = if let Some(reply_subject) = &request.publish.reply_subject { + let subscription = client.subscribe(reply_subject.clone(), None).await?; + client.lazy_unsubscribe(subscription.id, Some(NonZeroU64::new(1).unwrap())); + + request.publish.client(client).await?; + ResponseSubscription::Subscription(subscription) + } else { + let receiver = client + .multiplexed_request( + request.publish.subject, + request.publish.headers, + request.publish.payload, + ) + .await?; + ResponseSubscription::Multiplexed(receiver) + }; + + let timeout = sleep( + request + .response_timeout + .unwrap_or(client.default_response_timeout()), + ); + Ok(ResponseFut { + subscription, + timeout, + }) +} + +impl Debug for ClientRequest<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientRequest") + .field("request", &self.request) + .finish_non_exhaustive() + } +} + +impl Debug for DoClientRequest<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DoClientRequest") + .field("request", &self.request) + .finish_non_exhaustive() + } +} + +impl Debug for OwnedClientRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OwnedClientRequest") + .field("request", &self.request) + .finish_non_exhaustive() + } +} + +impl Debug for DoOwnedClientRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DoOwnedClientRequest") + .field("request", &self.request) + .finish_non_exhaustive() + } +} diff --git a/watermelon/src/client/from_env.rs b/watermelon/src/client/from_env.rs new file mode 100644 index 0000000..20dac9b --- /dev/null +++ b/watermelon/src/client/from_env.rs @@ -0,0 +1,42 @@ +use std::path::PathBuf; + +use serde::{de, Deserialize, Deserializer}; +use watermelon_nkeys::KeyPair; +use watermelon_proto::Subject; + +#[derive(Debug, Deserialize)] +pub(super) struct FromEnv { + #[serde(flatten)] + pub(super) auth: AuthenticationMethod, + pub(super) inbox_prefix: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub(super) enum AuthenticationMethod { + Creds { + #[serde(rename = "nats_jwt")] + jwt: String, + #[serde(rename = "nats_nkey", deserialize_with = "deserialize_nkey")] + nkey: KeyPair, + }, + CredsFile { + #[serde(rename = "nats_creds_file")] + creds_file: PathBuf, + }, + UserAndPassword { + #[serde(rename = "nats_username")] + username: String, + #[serde(rename = "nats_password")] + password: String, + }, + None, +} + +fn deserialize_nkey<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let secret = String::deserialize(deserializer)?; + KeyPair::from_encoded_seed(&secret).map_err(de::Error::custom) +} diff --git a/watermelon/src/client/jetstream/commands/consumer_batch.rs b/watermelon/src/client/jetstream/commands/consumer_batch.rs new file mode 100644 index 0000000..b3051b5 --- /dev/null +++ b/watermelon/src/client/jetstream/commands/consumer_batch.rs @@ -0,0 +1,143 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures_core::{FusedStream, Future, Stream}; +use pin_project_lite::pin_project; +use serde_json::json; +use tokio::time::{sleep, Sleep}; +use watermelon_proto::{error::ServerError, ServerMessage, StatusCode}; + +use crate::{ + client::{Consumer, JetstreamClient, JetstreamError2}, + subscription::Subscription, +}; + +pin_project! { + /// A consumer batch request + /// + /// Obtained from [`JetstreamClient::consumer_batch`]. + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + pub struct ConsumerBatch { + subscription: Subscription, + #[pin] + timeout: Sleep, + pending_msgs: usize, + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ConsumerBatchError { + #[error("an error returned by the server")] + ServerError(#[source] ServerError), + #[error("unexpected status code")] + UnexpectedStatus(ServerMessage), +} + +impl ConsumerBatch { + pub(crate) fn new( + consumer: &Consumer, + client: JetstreamClient, + expires: Duration, + max_msgs: usize, + ) -> impl Future> { + let subject = format!( + "{}.CONSUMER.MSG.NEXT.{}.{}", + client.prefix, consumer.stream_name, consumer.config.name + ) + .try_into(); + + async move { + let subject = subject.map_err(JetstreamError2::Subject)?; + let incoming_subject = client.client.create_inbox_subject(); + let payload = serde_json::to_vec(&if expires.is_zero() { + json!({ + "batch": max_msgs, + "no_wait": true, + }) + } else { + json!({ + "batch": max_msgs, + "expires": expires.as_nanos(), + "no_wait": true + }) + }) + .map_err(JetstreamError2::Json)?; + + let subscription = client + .client + .subscribe(incoming_subject.clone(), None) + .await + .map_err(JetstreamError2::ClientClosed)?; + client + .client + .publish(subject) + .reply_subject(Some(incoming_subject.clone())) + .payload(payload.into()) + .await + .map_err(JetstreamError2::ClientClosed)?; + + let timeout = sleep(expires.saturating_add(client.request_timeout)); + Ok(Self { + subscription, + timeout, + pending_msgs: max_msgs, + }) + } + } +} + +impl Stream for ConsumerBatch { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + if *this.pending_msgs == 0 { + return Poll::Ready(None); + } + + match Pin::new(this.subscription).poll_next(cx) { + Poll::Pending => match this.timeout.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => { + *this.pending_msgs = 0; + Poll::Ready(None) + } + }, + Poll::Ready(Some(Ok(msg))) => match msg.status_code { + None | Some(StatusCode::OK) => { + *this.pending_msgs -= 1; + + Poll::Ready(Some(Ok(msg))) + } + Some(StatusCode::IDLE_HEARTBEAT) => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Some(StatusCode::TIMEOUT | StatusCode::NOT_FOUND) => { + *this.pending_msgs = 0; + Poll::Ready(None) + } + _ => Poll::Ready(Some(Err(ConsumerBatchError::UnexpectedStatus(msg)))), + }, + Poll::Ready(Some(Err(err))) => { + *this.pending_msgs = 0; + Poll::Ready(Some(Err(ConsumerBatchError::ServerError(err)))) + } + Poll::Ready(None) => { + *this.pending_msgs = 0; + Poll::Ready(None) + } + } + } +} + +impl FusedStream for ConsumerBatch { + fn is_terminated(&self) -> bool { + self.pending_msgs == 0 + } +} diff --git a/watermelon/src/client/jetstream/commands/consumer_list.rs b/watermelon/src/client/jetstream/commands/consumer_list.rs new file mode 100644 index 0000000..4f77d5d --- /dev/null +++ b/watermelon/src/client/jetstream/commands/consumer_list.rs @@ -0,0 +1,117 @@ +use std::{ + collections::VecDeque, + fmt::Display, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_core::{future::BoxFuture, FusedStream, Stream}; +use serde::Deserialize; +use serde_json::json; +use watermelon_proto::Subject; + +use crate::client::{self, jetstream::JetstreamError2, JetstreamClient}; + +/// A request to list consumers of a stream +/// +/// Obtained from [`JetstreamClient::consumers`]. +#[must_use = "streams do nothing unless polled"] +pub struct Consumers { + client: JetstreamClient, + offset: u32, + partial_subject: Subject, + fetch: Option>>, + buffer: VecDeque, + exhausted: bool, +} + +#[derive(Debug, Deserialize)] +struct ConsumersResponse { + limit: u32, + consumers: VecDeque, +} + +impl Consumers { + pub(crate) fn new(client: JetstreamClient, stream_name: impl Display) -> Self { + let partial_subject = format!("CONSUMER.LIST.{stream_name}") + .try_into() + .expect("stream name is valid"); + Self { + client, + offset: 0, + partial_subject, + fetch: None, + buffer: VecDeque::new(), + exhausted: false, + } + } +} + +impl Stream for Consumers { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(consumer) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(consumer))); + } + + if this.exhausted { + return Poll::Ready(None); + } + + let fetch = this.fetch.get_or_insert_with(|| { + let client = this.client.clone(); + let partial_subject = this.partial_subject.clone(); + let offset = this.offset; + + Box::pin(async move { + let response_fut = client + .client() + .request(client.subject_for_request(&partial_subject)) + .response_timeout(client.request_timeout) + .payload( + serde_json::to_vec(&json!({ + "offset": offset, + })) + .unwrap() + .into(), + ) + .await + .map_err(JetstreamError2::ClientClosed)?; + let response = response_fut.await.map_err(JetstreamError2::ResponseError)?; + let payload = serde_json::from_slice(&response.base.payload) + .map_err(JetstreamError2::Json)?; + Ok(payload) + }) + }); + + match Pin::new(fetch).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(response)) => { + this.fetch = None; + this.buffer = response.consumers; + if this.buffer.len() < response.limit as usize { + this.exhausted = true; + } else if !this.buffer.is_empty() { + this.offset += 1; + } + + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(Err(err)) => { + this.fetch = None; + Poll::Ready(Some(Err(err))) + } + } + } +} + +impl FusedStream for Consumers { + fn is_terminated(&self) -> bool { + self.buffer.is_empty() && self.exhausted + } +} diff --git a/watermelon/src/client/jetstream/commands/consumer_stream.rs b/watermelon/src/client/jetstream/commands/consumer_stream.rs new file mode 100644 index 0000000..22c3a51 --- /dev/null +++ b/watermelon/src/client/jetstream/commands/consumer_stream.rs @@ -0,0 +1,127 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures_core::{future::BoxFuture, FusedStream, Stream}; +use pin_project_lite::pin_project; +use watermelon_proto::ServerMessage; + +use crate::client::{Consumer, JetstreamClient, JetstreamError2}; + +use super::{consumer_batch::ConsumerBatchError, ConsumerBatch}; + +pin_project! { + /// A consumer stream of batch requests + /// + /// Obtained from [`JetstreamClient::consumer_stream`]. + #[must_use = "streams do nothing unless polled"] + pub struct ConsumerStream { + #[pin] + status: ConsumerStreamStatus, + consumer: Consumer, + client: JetstreamClient, + + expires: Duration, + max_msgs: usize, + } +} + +pin_project! { + #[project = ConsumerStreamStatusProj] + enum ConsumerStreamStatus { + Polling { + future: BoxFuture<'static, Result>, + }, + RunningBatch { + #[pin] + batch: ConsumerBatch, + }, + Broken, + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ConsumerStreamError { + #[error("consumer batch error")] + BatchError(#[source] ConsumerBatchError), + #[error("jetstream error")] + Jetstream(#[source] JetstreamError2), +} + +impl ConsumerStream { + pub(crate) fn new( + consumer: Consumer, + client: JetstreamClient, + expires: Duration, + max_msgs: usize, + ) -> Self { + let poll_fut = { + let client = client.clone(); + Box::pin(ConsumerBatch::new(&consumer, client, expires, max_msgs)) + }; + + Self { + status: ConsumerStreamStatus::Polling { future: poll_fut }, + consumer, + client, + + expires, + max_msgs, + } + } +} + +impl Stream for ConsumerStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match this.status.as_mut().project() { + ConsumerStreamStatusProj::RunningBatch { batch } => match batch.poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(Ok(msg))), + Poll::Ready(Some(Err(err))) => { + this.status.set(ConsumerStreamStatus::Broken); + Poll::Ready(Some(Err(ConsumerStreamError::BatchError(err)))) + } + Poll::Ready(None) => { + this.status.set(ConsumerStreamStatus::Polling { + future: Box::pin(ConsumerBatch::new( + this.consumer, + this.client.clone(), + *this.expires, + *this.max_msgs, + )), + }); + + cx.waker().wake_by_ref(); + Poll::Pending + } + }, + ConsumerStreamStatusProj::Polling { future: fut } => match Pin::new(fut).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(batch)) => { + this.status + .set(ConsumerStreamStatus::RunningBatch { batch }); + + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(Err(err)) => { + this.status.set(ConsumerStreamStatus::Broken); + Poll::Ready(Some(Err(ConsumerStreamError::Jetstream(err)))) + } + }, + ConsumerStreamStatusProj::Broken => Poll::Ready(None), + } + } +} + +impl FusedStream for ConsumerStream { + fn is_terminated(&self) -> bool { + matches!(self.status, ConsumerStreamStatus::Broken) + } +} diff --git a/watermelon/src/client/jetstream/commands/mod.rs b/watermelon/src/client/jetstream/commands/mod.rs new file mode 100644 index 0000000..cecf3bc --- /dev/null +++ b/watermelon/src/client/jetstream/commands/mod.rs @@ -0,0 +1,9 @@ +pub use self::consumer_batch::ConsumerBatch; +pub use self::consumer_list::Consumers; +pub use self::consumer_stream::{ConsumerStream, ConsumerStreamError}; +pub use self::stream_list::Streams; + +mod consumer_batch; +mod consumer_list; +mod consumer_stream; +mod stream_list; diff --git a/watermelon/src/client/jetstream/commands/stream_list.rs b/watermelon/src/client/jetstream/commands/stream_list.rs new file mode 100644 index 0000000..7f911d2 --- /dev/null +++ b/watermelon/src/client/jetstream/commands/stream_list.rs @@ -0,0 +1,110 @@ +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_core::{future::BoxFuture, FusedStream, Stream}; +use serde::Deserialize; +use serde_json::json; +use watermelon_proto::Subject; + +use crate::client::{self, jetstream::JetstreamError2, JetstreamClient}; + +/// A request to list streams +/// +/// Obtained from [`JetstreamClient::streams`]. +#[must_use = "streams do nothing unless polled"] +pub struct Streams { + client: JetstreamClient, + offset: u32, + fetch: Option>>, + buffer: VecDeque, + exhausted: bool, +} + +#[derive(Debug, Deserialize)] +struct StreamsResponse { + limit: u32, + streams: VecDeque, +} + +impl Streams { + pub(crate) fn new(client: JetstreamClient) -> Self { + Self { + client, + offset: 0, + fetch: None, + buffer: VecDeque::new(), + exhausted: false, + } + } +} + +impl Stream for Streams { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(stream) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(stream))); + } + + if this.exhausted { + return Poll::Ready(None); + } + + let fetch = this.fetch.get_or_insert_with(|| { + let client = this.client.clone(); + let offset = this.offset; + + Box::pin(async move { + let response_fut = client + .client() + .request(client.subject_for_request(&Subject::from_static("STREAM.LIST"))) + .response_timeout(client.request_timeout) + .payload( + serde_json::to_vec(&json!({ + "offset": offset, + })) + .unwrap() + .into(), + ) + .await + .map_err(JetstreamError2::ClientClosed)?; + let response = response_fut.await.map_err(JetstreamError2::ResponseError)?; + let payload = serde_json::from_slice(&response.base.payload) + .map_err(JetstreamError2::Json)?; + Ok(payload) + }) + }); + + match Pin::new(fetch).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(response)) => { + this.fetch = None; + this.buffer = response.streams; + if this.buffer.len() < response.limit as usize { + this.exhausted = true; + } else if !this.buffer.is_empty() { + this.offset += 1; + } + + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(Err(err)) => { + this.fetch = None; + Poll::Ready(Some(Err(err))) + } + } + } +} + +impl FusedStream for Streams { + fn is_terminated(&self) -> bool { + self.buffer.is_empty() && self.exhausted + } +} diff --git a/watermelon/src/client/jetstream/mod.rs b/watermelon/src/client/jetstream/mod.rs new file mode 100644 index 0000000..eda6033 --- /dev/null +++ b/watermelon/src/client/jetstream/mod.rs @@ -0,0 +1,250 @@ +use std::{fmt::Display, time::Duration}; + +use bytes::Bytes; +use resources::Response; +use serde::{Deserialize, Serialize}; +use watermelon_proto::StatusCode; +use watermelon_proto::{error::SubjectValidateError, Subject}; + +pub use self::commands::{ConsumerBatch, ConsumerStream, ConsumerStreamError, Consumers, Streams}; +pub use self::resources::{ + AckPolicy, Compression, Consumer, ConsumerConfig, ConsumerDurability, ConsumerSpecificConfig, + ConsumerStorage, DeliverPolicy, DiscardPolicy, ReplayPolicy, RetentionPolicy, Storage, Stream, + StreamConfig, StreamState, +}; +use crate::core::Client; + +use super::{ClientClosedError, ResponseError}; + +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2); + +mod commands; +mod resources; + +/// A NATS Jetstream client +/// +/// `JetstreamClient` is a `Clone`able handle to a NATS [`Client`], +/// with Jetstream specific configurations. +#[derive(Debug, Clone)] +pub struct JetstreamClient { + client: Client, + prefix: Subject, + request_timeout: Duration, +} + +/// A Jetstream API error +#[derive(Debug, Deserialize, thiserror::Error)] +#[error("jetstream error status={status}")] +pub struct JetstreamError { + #[serde(rename = "code")] + status: StatusCode, + #[serde(rename = "err_code")] + code: JetstreamErrorCode, + description: String, +} + +/// The type of error encountered while processing a Jetstream request +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct JetstreamErrorCode(u16); + +/// An error encountered while making a Jetstream request +#[derive(Debug, thiserror::Error)] +pub enum JetstreamError2 { + #[error("invalid subject")] + Subject(#[source] SubjectValidateError), + #[error("client closed")] + ClientClosed(#[source] ClientClosedError), + #[error("client request failure")] + ResponseError(#[source] ResponseError), + #[error("JSON deserialization")] + Json(#[source] serde_json::Error), + #[error("bad response code")] + Status(#[source] JetstreamError), +} + +impl JetstreamClient { + /// Create a Jetstream client using the default configuration + #[must_use] + pub fn new(client: Client) -> Self { + Self::new_with_prefix(client, Subject::from_static("$JS.API")) + } + + /// Create a Jetstream client using the provided `domain` + /// + /// # Errors + /// + /// It returns an error if the subject derived by the `domain` is not valid. + pub fn new_with_domain( + client: Client, + domain: impl Display, + ) -> Result { + let prefix = format!("$JS.{domain}.API").try_into()?; + Ok(Self::new_with_prefix(client, prefix)) + } + + /// Create a Jetstream client using the provided API `prefix` + #[must_use] + pub fn new_with_prefix(client: Client, prefix: Subject) -> Self { + Self { + client, + prefix, + request_timeout: DEFAULT_REQUEST_TIMEOUT, + } + } + + /// List streams present within this client's Jetstream context + pub fn streams(&self) -> Streams { + Streams::new(self.clone()) + } + + /// Obtain a stream present within this client's Jetstream context + /// + /// # Errors + /// + /// It returns an error if the given `name` produces an invalid subject or if an error occurs + /// while creating the stream. + pub async fn stream(&self, name: impl Display) -> Result, JetstreamError2> { + let subject = format!("{}.STREAM.INFO.{}", self.prefix, name) + .try_into() + .map_err(JetstreamError2::Subject)?; + let resp = self + .client + .request(subject) + .response_timeout(self.request_timeout) + .payload(Bytes::new()) + .await + .map_err(JetstreamError2::ClientClosed)?; + let resp = resp.await.map_err(JetstreamError2::ResponseError)?; + + if resp.status_code == Some(StatusCode::NO_RESPONDERS) { + return Err(JetstreamError2::ResponseError(ResponseError::NoResponders)); + } + + let json = serde_json::from_slice::>(&resp.base.payload) + .map_err(JetstreamError2::Json)?; + match json { + Response::Response(stream) => Ok(Some(stream)), + Response::Error { error } if error.code == JetstreamErrorCode::STREAM_NOT_FOUND => { + Ok(None) + } + Response::Error { error } => Err(JetstreamError2::Status(error)), + } + } + + /// List consumers present within this client's Jetstream context + pub fn consumers(&self, stream_name: impl Display) -> Consumers { + Consumers::new(self.clone(), stream_name) + } + + /// Obtain a consumer present within this client's Jetstream context + /// + /// # Errors + /// + /// It returns an error if the given `stream_name` and `consumer_name` produce an invalid + /// subject or if an error occurs while creating the consumer. + pub async fn consumer( + &self, + stream_name: impl Display, + consumer_name: impl Display, + ) -> Result, JetstreamError2> { + let subject = format!( + "{}.CONSUMER.INFO.{}.{}", + self.prefix, stream_name, consumer_name + ) + .try_into() + .map_err(JetstreamError2::Subject)?; + let resp = self + .client + .request(subject) + .response_timeout(self.request_timeout) + .payload(Bytes::new()) + .await + .map_err(JetstreamError2::ClientClosed)?; + let resp = resp.await.map_err(JetstreamError2::ResponseError)?; + + if resp.status_code == Some(StatusCode::NO_RESPONDERS) { + return Err(JetstreamError2::ResponseError(ResponseError::NoResponders)); + } + + let json = serde_json::from_slice::>(&resp.base.payload) + .map_err(JetstreamError2::Json)?; + match json { + Response::Response(stream) => Ok(Some(stream)), + Response::Error { error } if error.code == JetstreamErrorCode::CONSUMER_NOT_FOUND => { + Ok(None) + } + Response::Error { error } => Err(JetstreamError2::Status(error)), + } + } + + /// Run a batch request over the provided `consumer` + /// + /// # Errors + /// + /// An error is returned if the subject is not valid or if the client has been closed. + pub async fn consumer_batch( + &self, + consumer: &Consumer, + expires: Duration, + max_msgs: usize, + ) -> Result { + ConsumerBatch::new(consumer, self.clone(), expires, max_msgs).await + } + + /// Run a stream request over the provided `consumer` + pub fn consumer_stream( + &self, + consumer: Consumer, + expires: Duration, + max_msgs: usize, + ) -> ConsumerStream { + ConsumerStream::new(consumer, self.clone(), expires, max_msgs) + } + + pub(crate) fn subject_for_request(&self, endpoint: &Subject) -> Subject { + Subject::from_dangerous_value(format!("{}.{}", self.prefix, endpoint).into()) + } + + /// Get a reference to the inner NATS Core client + #[must_use] + pub fn client(&self) -> &Client { + &self.client + } + + #[must_use] + pub fn prefix(&self) -> &Subject { + &self.prefix + } +} + +impl JetstreamErrorCode { + pub const NOT_ENABLED: Self = Self(10076); + pub const NOT_ENABLED_FOR_ACCOUNT: Self = Self(10039); + pub const BAD_REQUEST: Self = Self(10003); + + pub const STREAM_NOT_FOUND: Self = Self(10059); + pub const STREAM_NAME_IN_USE: Self = Self(10058); + pub const STREAM_MESSAGE_NOT_FOUND: Self = Self(10037); + pub const STREAM_WRONG_LAST_SEQUENCE: Self = Self(10071); + + pub const COULD_NOT_CREATE_CONSUMER: Self = Self(10012); + pub const CONSUMER_NOT_FOUND: Self = Self(10014); + pub const CONSUMER_NAME_IN_USE: Self = Self(10148); + + pub const CONSUMER_DUPLICATE_FILTER_SUBJECTS: Self = Self(10136); + pub const CONSUMER_OVERLAPPING_FILTER_SUBJECTS: Self = Self(10138); + pub const CONSUMER_FILTER_SUBJECTS_IS_EMPTY: Self = Self(10139); +} + +impl From for JetstreamErrorCode { + fn from(value: u16) -> Self { + Self(value) + } +} + +impl From for u16 { + fn from(value: JetstreamErrorCode) -> Self { + value.0 + } +} diff --git a/watermelon/src/client/jetstream/resources/consumer.rs b/watermelon/src/client/jetstream/resources/consumer.rs new file mode 100644 index 0000000..efb1ae2 --- /dev/null +++ b/watermelon/src/client/jetstream/resources/consumer.rs @@ -0,0 +1,400 @@ +use std::{ + collections::BTreeMap, + num::{NonZeroU32, NonZeroU64}, + time::Duration, +}; + +use chrono::{DateTime, Utc}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use watermelon_proto::{QueueGroup, Subject}; + +use super::{duration, duration_vec, nullable_number, option_nonzero}; + +/// A Jetstream consumer +#[derive(Debug, Serialize, Deserialize)] +pub struct Consumer { + pub stream_name: String, + pub config: ConsumerConfig, + #[serde(rename = "created")] + pub created_at: DateTime, +} + +/// A Jetstream consumer configuration +#[derive(Debug)] +pub struct ConsumerConfig { + pub durability: ConsumerDurability, + pub name: String, + pub description: String, + pub deliver_policy: DeliverPolicy, + pub ack_policy: AckPolicy, + pub max_deliver: Option, + pub backoff: Vec, + pub filter_subjects: Vec, + pub replay_policy: ReplayPolicy, + pub rate_limit: Option, + pub flow_control: Option, + pub idle_heartbeat: Duration, + pub headers_only: bool, + + pub specs: ConsumerSpecificConfig, + + // Inactivity threshold. + pub inactive_threshold: Duration, + pub replicas: Option, + pub storage: ConsumerStorage, + pub metadata: BTreeMap, +} + +/// Pull or Push configuration parameters for a consumer +#[derive(Debug)] +pub enum ConsumerSpecificConfig { + Pull { + max_waiting: Option, + max_request_batch: Option, + max_request_expires: Duration, + max_request_max_bytes: Option, + }, + Push { + deliver_subject: Subject, + deliver_group: Option, + }, +} + +/// The durability of the consumer +#[derive(Debug)] +pub enum ConsumerDurability { + Ephemeral, + Durable, +} + +/// The delivery policy of the consumer +#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] +#[serde(tag = "deliver_policy")] +pub enum DeliverPolicy { + #[default] + #[serde(rename = "all")] + All, + #[serde(rename = "last")] + Last, + #[serde(rename = "last_per_subject")] + LastPerSubject, + #[serde(rename = "new")] + New, + #[serde(rename = "by_start_sequence")] + StartSequence { + #[serde(rename = "opt_start_seq")] + sequence: u64, + }, + #[serde(rename = "by_start_time")] + StartTime { + #[serde(rename = "opt_start_time")] + from: DateTime, + }, +} + +/// The acknowledgment policy of the consumer +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(tag = "ack_policy", rename_all = "lowercase")] +pub enum AckPolicy { + Explicit { + #[serde(rename = "ack_wait", with = "duration")] + wait: Duration, + #[serde( + rename = "max_ack_pending", + with = "nullable_number", + skip_serializing_if = "Option::is_none" + )] + max_pending: Option, + }, + All { + #[serde(rename = "ack_wait", with = "duration")] + wait: Duration, + #[serde( + rename = "max_ack_pending", + with = "nullable_number", + skip_serializing_if = "Option::is_none" + )] + max_pending: Option, + }, + None, +} + +/// The replay policy of the consumer +#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ReplayPolicy { + #[default] + Instant, + Original, +} + +/// Whether the consumer is kept on disk or in memory +#[derive(Debug, Copy, Clone, Default)] +pub enum ConsumerStorage { + #[default] + Disk, + Memory, +} + +#[derive(Debug, Serialize, Deserialize)] +struct RawConsumerConfig { + #[serde(default)] + name: String, + #[serde(default)] + durable_name: String, + #[serde(default)] + description: String, + + #[serde(flatten)] + deliver_policy: DeliverPolicy, + + #[serde(flatten)] + ack_policy: AckPolicy, + + #[serde(skip_serializing_if = "Option::is_none", with = "nullable_number")] + max_deliver: Option, + + #[serde(default, skip_serializing_if = "Vec::is_empty", with = "duration_vec")] + backoff: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + filter_subject: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + filter_subjects: Vec, + replay_policy: ReplayPolicy, + + #[serde(rename = "rate_limit_bps", skip_serializing_if = "Option::is_none")] + rate_limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + flow_control: Option, + #[serde(default, skip_serializing_if = "Duration::is_zero", with = "duration")] + idle_heartbeat: Duration, + #[serde(default)] + headers_only: bool, + + // Pull based options. + #[serde(skip_serializing_if = "Option::is_none")] + max_waiting: Option, + #[serde(rename = "max_batch", skip_serializing_if = "Option::is_none")] + max_request_batch: Option, + #[serde( + default, + rename = "max_expires", + skip_serializing_if = "Duration::is_zero", + with = "duration" + )] + max_request_expires: Duration, + #[serde(rename = "max_bytes", skip_serializing_if = "Option::is_none")] + max_request_max_bytes: Option, + + // Push based consumers. + #[serde(rename = "deliver_subject", skip_serializing_if = "Option::is_none")] + deliver_subject: Option, + #[serde(rename = "deliver_group", skip_serializing_if = "Option::is_none")] + deliver_group: Option, + + #[serde(default, with = "duration")] + inactive_threshold: Duration, + #[serde(rename = "num_replicas", with = "option_nonzero")] + replicas: Option, + #[serde(default, rename = "mem_storage")] + storage: ConsumerStorage, + #[serde(default)] + metadata: BTreeMap, +} + +impl Serialize for ConsumerConfig { + fn serialize(&self, serializer: S) -> Result { + let (name, durable_name) = match self.durability { + ConsumerDurability::Ephemeral => (self.name.clone(), String::new()), + ConsumerDurability::Durable => (self.name.clone(), self.name.clone()), + }; + + let (filter_subject, filter_subjects) = match self.filter_subjects.len() { + 1 => (Some(self.filter_subjects[0].clone()), Vec::new()), + _ => (None, self.filter_subjects.clone()), + }; + + let ( + max_waiting, + max_request_batch, + max_request_expires, + max_request_max_bytes, + deliver_subject, + deliver_group, + ) = match &self.specs { + ConsumerSpecificConfig::Pull { + max_waiting, + max_request_batch, + max_request_expires, + max_request_max_bytes, + } => ( + *max_waiting, + *max_request_batch, + *max_request_expires, + *max_request_max_bytes, + None, + None, + ), + ConsumerSpecificConfig::Push { + deliver_subject, + deliver_group, + } => ( + None, + None, + Duration::ZERO, + None, + Some(deliver_subject.clone()), + deliver_group.clone(), + ), + }; + + RawConsumerConfig { + name, + durable_name, + description: self.description.clone(), + + deliver_policy: self.deliver_policy, + ack_policy: self.ack_policy, + max_deliver: self.max_deliver, + backoff: self.backoff.clone(), + filter_subject, + filter_subjects, + replay_policy: self.replay_policy, + rate_limit: self.rate_limit, + flow_control: self.flow_control, + idle_heartbeat: self.idle_heartbeat, + headers_only: self.headers_only, + + // Pull based options. + max_waiting, + max_request_batch, + max_request_expires, + max_request_max_bytes, + + // Push based consumers. + deliver_subject, + deliver_group, + + inactive_threshold: self.inactive_threshold, + replicas: self.replicas, + storage: self.storage, + metadata: self.metadata.clone(), + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ConsumerConfig { + fn deserialize>(deserializer: D) -> Result { + let RawConsumerConfig { + name, + durable_name, + description, + deliver_policy, + ack_policy, + max_deliver, + backoff, + filter_subject, + filter_subjects, + replay_policy, + rate_limit, + flow_control, + idle_heartbeat, + headers_only, + max_waiting, + max_request_batch, + max_request_expires, + max_request_max_bytes, + deliver_subject, + deliver_group, + inactive_threshold, + replicas, + storage, + metadata, + } = RawConsumerConfig::deserialize(deserializer)?; + let (durability, name) = if !durable_name.is_empty() { + (ConsumerDurability::Durable, durable_name) + } else if !name.is_empty() { + (ConsumerDurability::Ephemeral, name) + } else { + return Err(de::Error::custom( + "consumer neither has a name or a durable name", + )); + }; + + let filter_subjects = if let Some(filter_subject) = filter_subject { + vec![filter_subject] + } else { + filter_subjects + }; + + let specs = match deliver_subject { + Some(deliver_subject) => ConsumerSpecificConfig::Push { + deliver_subject, + deliver_group, + }, + None => ConsumerSpecificConfig::Pull { + max_waiting, + max_request_batch, + max_request_expires, + max_request_max_bytes, + }, + }; + + Ok(Self { + durability, + name, + description, + deliver_policy, + ack_policy, + max_deliver, + backoff, + filter_subjects, + replay_policy, + rate_limit, + flow_control, + idle_heartbeat, + headers_only, + + specs, + + inactive_threshold, + replicas, + storage, + metadata, + }) + } +} + +impl Serialize for ConsumerStorage { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + matches!(self, Self::Memory).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ConsumerStorage { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let b = bool::deserialize(deserializer)?; + Ok(if b { + ConsumerStorage::Memory + } else { + ConsumerStorage::Disk + }) + } +} + +impl Default for AckPolicy { + fn default() -> Self { + Self::All { + wait: Duration::ZERO, + max_pending: None, + } + } +} diff --git a/watermelon/src/client/jetstream/resources/mod.rs b/watermelon/src/client/jetstream/resources/mod.rs new file mode 100644 index 0000000..fc2386d --- /dev/null +++ b/watermelon/src/client/jetstream/resources/mod.rs @@ -0,0 +1,235 @@ +use serde::Deserialize; + +pub use self::consumer::{ + AckPolicy, Consumer, ConsumerConfig, ConsumerDurability, ConsumerSpecificConfig, + ConsumerStorage, DeliverPolicy, ReplayPolicy, +}; +pub use self::stream::{ + Compression, DiscardPolicy, RetentionPolicy, Storage, Stream, StreamConfig, StreamState, +}; + +use super::JetstreamError; + +mod consumer; +mod stream; + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub(crate) enum Response { + Response(T), + Error { error: JetstreamError }, +} + +mod nullable_number { + use std::{any::type_name, fmt::Display}; + + use serde::{ + de::{self, DeserializeOwned}, + ser, Deserialize, Deserializer, Serialize, Serializer, + }; + + pub(crate) trait NullableNumber: Copy + Display { + const NULL_VALUE: Self::SignedValue; + type SignedValue: Copy + + TryFrom + + TryInto + + Display + + Eq + + Serialize + + DeserializeOwned; + } + + impl NullableNumber for u32 { + const NULL_VALUE: Self::SignedValue = -1; + type SignedValue = i32; + } + + impl NullableNumber for u64 { + const NULL_VALUE: Self::SignedValue = -1; + type SignedValue = i64; + } + + #[expect(clippy::ref_option)] + pub(crate) fn serialize(num: &Option, serializer: S) -> Result + where + S: Serializer, + N: NullableNumber, + { + match *num { + Some(num) => num.try_into().map_err(|_| { + ser::Error::custom(format!( + "{num} can't be converted to {}", + type_name::() + )) + })?, + None => N::NULL_VALUE, + } + .serialize(serializer) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>, N: NullableNumber>( + deserializer: D, + ) -> Result, D::Error> { + let num = N::SignedValue::deserialize(deserializer)?; + Ok(if num == N::NULL_VALUE { + None + } else { + Some(num.try_into().map_err(|_| { + de::Error::custom(format!("{num} can't be converted to {}", type_name::())) + })?) + }) + } +} + +mod option_nonzero { + use std::num::NonZeroU32; + + use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer}; + + pub(crate) trait NonZeroNumber: Copy { + type Inner: Copy + Default + From + TryInto + Serialize + DeserializeOwned; + } + + impl NonZeroNumber for NonZeroU32 { + type Inner = u32; + } + + #[expect(clippy::ref_option)] + pub(crate) fn serialize(num: &Option, serializer: S) -> Result + where + S: Serializer, + N: NonZeroNumber, + { + match *num { + Some(num) => >::from(num), + None => Default::default(), + } + .serialize(serializer) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>, N: NonZeroNumber>( + deserializer: D, + ) -> Result, D::Error> { + let num = ::deserialize(deserializer)?; + Ok(num.try_into().ok()) + } +} + +mod nullable_datetime { + use chrono::{DateTime, Datelike, Utc}; + use serde::{Deserialize, Deserializer}; + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result>, D::Error> { + let datetime = >::deserialize(deserializer)?; + Ok(if datetime.year() == 1 { + None + } else { + Some(datetime) + }) + } +} + +mod duration { + use std::time::Duration; + + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub(crate) fn serialize(duration: &Duration, serializer: S) -> Result + where + S: Serializer, + { + duration.as_nanos().serialize(serializer) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result { + Ok(Duration::from_nanos(u64::deserialize(deserializer)?)) + } +} + +mod duration_vec { + use std::time::Duration; + + use serde::{Deserialize, Deserializer, Serializer}; + + #[expect( + clippy::ptr_arg, + reason = "this must follow the signature expected by serde" + )] + pub(crate) fn serialize(durations: &Vec, serializer: S) -> Result + where + S: Serializer, + { + serializer.collect_seq(durations.iter().map(std::time::Duration::as_nanos)) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + let durations = as Deserialize>::deserialize(deserializer)?; + Ok(durations.into_iter().map(Duration::from_nanos).collect()) + } +} + +mod compression { + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "snake_case")] + enum CompressionInner { + None, + S2, + } + + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + use super::Compression; + + #[expect(clippy::ref_option)] + pub(crate) fn serialize( + compression: &Option, + serializer: S, + ) -> Result + where + S: Serializer, + { + match compression { + None => CompressionInner::None, + Some(Compression::S2) => CompressionInner::S2, + } + .serialize(serializer) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + Ok(match CompressionInner::deserialize(deserializer)? { + CompressionInner::None => None, + CompressionInner::S2 => Some(Compression::S2), + }) + } +} + +mod opposite_bool { + use std::ops::Not; + + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + #[expect( + clippy::trivially_copy_pass_by_ref, + reason = "this must follow the signature expected by serde" + )] + pub(crate) fn serialize(val: &bool, serializer: S) -> Result + where + S: Serializer, + { + val.not().serialize(serializer) + } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result { + bool::deserialize(deserializer).map(Not::not) + } +} diff --git a/watermelon/src/client/jetstream/resources/stream.rs b/watermelon/src/client/jetstream/resources/stream.rs new file mode 100644 index 0000000..8137776 --- /dev/null +++ b/watermelon/src/client/jetstream/resources/stream.rs @@ -0,0 +1,101 @@ +use std::{num::NonZeroU32, time::Duration}; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use super::{compression, duration, nullable_datetime, nullable_number, opposite_bool}; + +/// A Jetstream stream +#[derive(Debug, Deserialize)] +pub struct Stream { + pub config: StreamConfig, + #[serde(rename = "created")] + pub created_at: DateTime, + // TODO: `cluster` +} + +/// The state of the stream +#[derive(Debug, Deserialize)] +pub struct StreamState { + pub messages: u64, + pub bytes: u64, + pub first_sequence: u64, + #[serde(with = "nullable_datetime", rename = "first_ts")] + pub first_sequence_timestamp: Option>, + pub last_sequence: u64, + #[serde(with = "nullable_datetime", rename = "last_ts")] + pub last_sequence_timestamp: Option>, + pub consumer_count: u32, +} + +/// A Jetstream stream configuration +#[derive(Debug, Serialize, Deserialize)] +#[expect( + clippy::struct_excessive_bools, + reason = "it is the actual config of a Jetstream" +)] +pub struct StreamConfig { + pub name: String, + pub subjects: Vec, + #[serde(with = "nullable_number")] + pub max_consumers: Option, + #[serde(with = "nullable_number", rename = "max_msgs")] + pub max_messages: Option, + #[serde(with = "nullable_number")] + pub max_bytes: Option, + #[serde(with = "duration")] + pub max_age: Duration, + #[serde(with = "nullable_number", rename = "max_msgs_per_subject")] + pub max_messages_per_subject: Option, + #[serde(with = "nullable_number", rename = "max_msg_size")] + pub max_message_size: Option, + #[serde(rename = "discard")] + pub discard_policy: DiscardPolicy, + pub storage: Storage, + #[serde(rename = "num_replicas")] + pub replicas: NonZeroU32, + #[serde(with = "duration")] + pub duplicate_window: Duration, + #[serde(with = "compression")] + pub compression: Option, + pub allow_direct: bool, + pub mirror_direct: bool, + pub sealed: bool, + #[serde(with = "opposite_bool", rename = "deny_delete")] + pub allow_delete: bool, + #[serde(with = "opposite_bool", rename = "deny_purge")] + pub allow_purge: bool, + pub allow_rollup_hdrs: bool, + // TODO: `consumer_limits` https://github.com/nats-io/nats-server/blob/e25d973a8f389ce3aa415e4bcdfba1f7d0834f7f/server/stream.go#L99 +} + +/// A streams retention policy +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RetentionPolicy { + Limits, + Interest, + WorkQueue, +} + +/// A streams discard policy +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DiscardPolicy { + Old, + New, +} + +/// Whether the disk is stored on disk or in memory +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Storage { + File, + Memory, +} + +/// The compression algorithm used by a stream +#[derive(Debug)] +pub enum Compression { + S2, +} diff --git a/watermelon/src/client/mod.rs b/watermelon/src/client/mod.rs new file mode 100644 index 0000000..d38fd56 --- /dev/null +++ b/watermelon/src/client/mod.rs @@ -0,0 +1,506 @@ +use std::{fmt::Write, num::NonZeroU64, process::abort, sync::Arc, time::Duration}; +#[cfg(test)] +use std::{ + net::{IpAddr, Ipv4Addr}, + num::{NonZeroU16, NonZeroU32}, +}; + +use arc_swap::ArcSwap; +use bytes::Bytes; +use rand::RngCore; +use tokio::{ + sync::{ + mpsc::{self, error::TrySendError, Permit}, + oneshot, + }, + task::JoinHandle, + time::{interval, MissedTickBehavior}, +}; +use watermelon_mini::ConnectError; +#[cfg(test)] +use watermelon_proto::NonStandardServerInfo; +use watermelon_proto::{ + headers::HeaderMap, QueueGroup, ServerAddr, ServerInfo, Subject, SubscriptionId, +}; + +pub use self::builder::{ClientBuilder, Echo}; +pub use self::commands::{ + ClientPublish, ClientRequest, DoClientPublish, DoClientRequest, DoOwnedClientPublish, + DoOwnedClientRequest, OwnedClientPublish, OwnedClientRequest, Publish, PublishBuilder, Request, + RequestBuilder, ResponseError, ResponseFut, +}; +pub use self::jetstream::{ + AckPolicy, Compression, Consumer, ConsumerBatch, ConsumerConfig, ConsumerDurability, + ConsumerSpecificConfig, ConsumerStorage, ConsumerStream, ConsumerStreamError, Consumers, + DeliverPolicy, DiscardPolicy, JetstreamClient, JetstreamError, JetstreamError2, + JetstreamErrorCode, ReplayPolicy, RetentionPolicy, Storage, Stream, StreamConfig, StreamState, + Streams, +}; +pub use self::quick_info::QuickInfo; +pub(crate) use self::quick_info::RawQuickInfo; +#[cfg(test)] +use self::tests::TestHandler; +use crate::{ + atomic::{AtomicU64, Ordering}, + core::{MultiplexedSubscription, Subscription}, + handler::{ + Handler, HandlerCommand, HandlerOutput, RecycledHandler, MULTIPLEXED_SUBSCRIPTION_ID, + }, +}; + +mod builder; +mod commands; +mod jetstream; +mod quick_info; +#[cfg(test)] +pub(crate) mod tests; + +#[cfg(feature = "from-env")] +pub(super) mod from_env; + +const CLIENT_OP_CHANNEL_SIZE: usize = 512; +const SUBSCRIPTION_CHANNEL_SIZE: usize = 256; +const RECONNECT_DELAY: Duration = Duration::from_secs(10); + +/// A NATS client +/// +/// `Client` is a `Clone`able handle to a NATS connection. +/// If the connection is lost, the client will automatically reconnect and +/// resume any currently open subscriptions. +#[derive(Debug, Clone)] +pub struct Client { + inner: Arc, +} + +#[derive(Debug)] +struct ClientInner { + sender: mpsc::Sender, + info: Arc>, + quick_info: Arc, + multiplexed_subscription_prefix: Subject, + next_subscription_id: AtomicU64, + inbox_prefix: Subject, + default_response_timeout: Duration, + handler: JoinHandle<()>, +} + +/// An error encountered while trying to publish a command to a closed [`Client`] +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +#[error("client closed")] +pub struct ClientClosedError; + +#[derive(Debug, thiserror::Error)] +#[error("try command error")] +pub enum TryCommandError { + /// The client's internal buffer is currently full + #[error("buffer full")] + BufferFull, + /// The client has been closed via [`Client::close`] + #[error("client closed")] + Closed(#[source] ClientClosedError), +} + +impl Client { + /// Construct a new client + #[must_use] + pub fn builder() -> ClientBuilder { + ClientBuilder::new() + } + + pub(super) async fn connect( + addr: ServerAddr, + builder: ClientBuilder, + ) -> Result { + let (sender, receiver) = mpsc::channel(CLIENT_OP_CHANNEL_SIZE); + + let quick_info = Arc::new(RawQuickInfo::new()); + let handle = RecycledHandler::new(receiver, Arc::clone(&quick_info), &builder); + let handle = Handler::connect(&addr, &builder, handle) + .await + .map_err(|(err, _recycle)| err)?; + let info = handle.info().clone(); + let multiplexed_subscription_prefix = handle.multiplexed_subscription_prefix().clone(); + let inbox_prefix = builder.inbox_prefix.clone(); + let default_response_timeout = builder.default_response_timeout; + + let handler = tokio::spawn(async move { + let mut handle = handle; + + loop { + match (&mut handle).await { + HandlerOutput::ServerError | HandlerOutput::Disconnected => { + let mut recycle = handle.recycle().await; + + let mut interval = interval(RECONNECT_DELAY); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + interval.tick().await; + + match Handler::connect(&addr, &builder, recycle).await { + Ok(new_handle) => { + handle = new_handle; + break; + } + Err((_err, prev_recycle)) => recycle = prev_recycle, + } + } + } + HandlerOutput::UnexpectedState => { + // Retry and hope for the best + } + HandlerOutput::Closed => break, + } + } + }); + + Ok(Self { + inner: Arc::new(ClientInner { + info, + sender, + quick_info, + multiplexed_subscription_prefix, + next_subscription_id: AtomicU64::new(u64::from(MULTIPLEXED_SUBSCRIPTION_ID) + 1), + inbox_prefix, + default_response_timeout, + handler, + }), + }) + } + + #[cfg(test)] + pub(crate) fn test(client_to_handler_chan_size: usize) -> (Self, TestHandler) { + let builder = Self::builder(); + let (sender, receiver) = mpsc::channel(client_to_handler_chan_size); + let info = Arc::new(ArcSwap::new(Arc::from(ServerInfo { + id: "1234".to_owned(), + name: "watermelon-test".to_owned(), + version: "2.10.17".to_owned(), + go_version: "1.22.5".to_owned(), + host: IpAddr::V4(Ipv4Addr::LOCALHOST), + port: NonZeroU16::new(4222).unwrap(), + supports_headers: true, + max_payload: NonZeroU32::new(1024 * 1024).unwrap(), + protocol_version: 2, + client_id: Some(1), + auth_required: false, + tls_required: false, + tls_verify: false, + tls_available: false, + connect_urls: Vec::new(), + websocket_connect_urls: Vec::new(), + lame_duck_mode: false, + git_commit: None, + supports_jetstream: true, + ip: None, + client_ip: None, + nonce: None, + cluster_name: None, + domain: None, + + non_standard: NonStandardServerInfo::default(), + }))); + let quick_info = Arc::new(RawQuickInfo::new()); + let multiplexed_subscription_prefix = create_inbox_subject(&builder.inbox_prefix); + + let this = Self { + inner: Arc::new(ClientInner { + sender, + info: Arc::clone(&info), + quick_info: Arc::clone(&quick_info), + multiplexed_subscription_prefix, + next_subscription_id: AtomicU64::new(1), + inbox_prefix: builder.inbox_prefix, + default_response_timeout: builder.default_response_timeout, + handler: tokio::spawn(async move {}), + }), + }; + let handler = TestHandler { + receiver, + _info: info, + quick_info, + }; + (this, handler) + } + + /// Publish a new message to the NATS server + /// + /// Consider calling [`Publish::client`] instead if you already have + /// a [`Publish`] instance. + #[must_use] + pub fn publish(&self, subject: Subject) -> ClientPublish { + ClientPublish::build(self, subject) + } + + /// Publish a new message to the NATS server + /// + /// Consider calling [`Request::client`] instead if you already have + /// a [`Request`] instance. + #[must_use] + pub fn request(&self, subject: Subject) -> ClientRequest { + ClientRequest::build(self, subject) + } + + /// Publish a new message to the NATS server, taking ownership of this client + /// + /// When possible consider using [`Client::publish`] instead. + /// + /// Consider calling [`Publish::client_owned`] instead if you already have + /// a [`Publish`] instance. + #[must_use] + pub fn publish_owned(self, subject: Subject) -> OwnedClientPublish { + OwnedClientPublish::build(self, subject) + } + + /// Publish a new message to the NATS server, taking ownership of this client + /// + /// When possible consider using [`Client::request`] instead. + /// + /// Consider calling [`Request::client_owned`] instead if you already have + /// a [`Request`] instance. + #[must_use] + pub fn request_owned(self, subject: Subject) -> OwnedClientRequest { + OwnedClientRequest::build(self, subject) + } + + /// Subscribe to the given filter subject + /// + /// Create a new subscription with the NATS server and ask for all + /// messages matching the given `filter_subject` to be delivered + /// to the client. + /// + /// If `queue_group` is provided and multiple clients subscribe with + /// the same [`QueueGroup`] value, the NATS server will try to deliver + /// these messages to only one of the clients. + /// + /// If the client was built with [`Echo::Allow`], then messages + /// published by this same client may be received by this subscription. + /// + /// # Errors + /// + /// This returns an error if the connection with the client is closed. + pub async fn subscribe( + &self, + filter_subject: Subject, + queue_group: Option, + ) -> Result { + let permit = self + .inner + .sender + .reserve() + .await + .map_err(|_| ClientClosedError)?; + + Ok(self.do_subscribe(permit, filter_subject, queue_group)) + } + + pub(crate) fn try_subscribe( + &self, + filter_subject: Subject, + queue_group: Option, + ) -> Result { + let permit = self + .inner + .sender + .try_reserve() + .map_err(|_| TryCommandError::BufferFull)?; + + Ok(self.do_subscribe(permit, filter_subject, queue_group)) + } + + fn do_subscribe( + &self, + permit: Permit<'_, HandlerCommand>, + filter_subject: Subject, + queue_group: Option, + ) -> Subscription { + let id = self + .inner + .next_subscription_id + .fetch_add(1, Ordering::AcqRel) + .into(); + if id == SubscriptionId::MAX { + abort(); + } + let (sender, receiver) = mpsc::channel(SUBSCRIPTION_CHANNEL_SIZE); + + permit.send(HandlerCommand::Subscribe { + id, + subject: filter_subject, + queue_group, + messages: sender, + }); + Subscription::new(id, self.clone(), receiver) + } + + pub(super) async fn multiplexed_request( + &self, + subject: Subject, + headers: HeaderMap, + payload: Bytes, + ) -> Result { + let permit = self + .inner + .sender + .reserve() + .await + .map_err(|_| ClientClosedError)?; + + Ok(self.do_multiplexed_request(permit, subject, headers, payload)) + } + + pub(super) fn try_multiplexed_request( + &self, + subject: Subject, + headers: HeaderMap, + payload: Bytes, + ) -> Result { + let permit = self + .inner + .sender + .try_reserve() + .map_err(|_| TryCommandError::BufferFull)?; + + Ok(self.do_multiplexed_request(permit, subject, headers, payload)) + } + + fn do_multiplexed_request( + &self, + permit: Permit<'_, HandlerCommand>, + subject: Subject, + headers: HeaderMap, + payload: Bytes, + ) -> MultiplexedSubscription { + let (sender, receiver) = oneshot::channel(); + + let reply_subject = create_inbox_subject(&self.inner.multiplexed_subscription_prefix); + + permit.send(HandlerCommand::RequestMultiplexed { + subject, + reply_subject: reply_subject.clone(), + headers, + payload, + reply: sender, + }); + MultiplexedSubscription::new(reply_subject, receiver, self.clone()) + } + + /// Get the last [`ServerInfo`] sent by the server + /// + /// Consider calling [`Client::quick_info`] if you only need + /// information about Lame Duck Mode. + #[must_use] + pub fn server_info(&self) -> Arc { + self.inner.info.load_full() + } + + /// Get information about the client + #[must_use] + pub fn quick_info(&self) -> QuickInfo { + self.inner.quick_info.get() + } + + pub(crate) fn create_inbox_subject(&self) -> Subject { + create_inbox_subject(&self.inner.inbox_prefix) + } + + pub(crate) fn default_response_timeout(&self) -> Duration { + self.inner.default_response_timeout + } + + pub(crate) fn lazy_unsubscribe_multiplexed(&self, reply_subject: Subject) { + if self + .try_enqueue_command(HandlerCommand::UnsubscribeMultiplexed { reply_subject }) + .is_ok() + { + return; + } + + self.inner.quick_info.store_is_failed_unsubscribe(true); + } + + pub(crate) async fn unsubscribe( + &self, + id: SubscriptionId, + max_messages: Option, + ) -> Result<(), ClientClosedError> { + self.enqueue_command(HandlerCommand::Unsubscribe { id, max_messages }) + .await + } + + pub(crate) fn lazy_unsubscribe(&self, id: SubscriptionId, max_messages: Option) { + if self + .try_enqueue_command(HandlerCommand::Unsubscribe { id, max_messages }) + .is_ok() + { + return; + } + + self.inner.quick_info.store_is_failed_unsubscribe(true); + } + + pub(super) async fn enqueue_command( + &self, + cmd: HandlerCommand, + ) -> Result<(), ClientClosedError> { + self.inner + .sender + .send(cmd) + .await + .map_err(|_| ClientClosedError) + } + + pub(super) fn try_enqueue_command(&self, cmd: HandlerCommand) -> Result<(), TryCommandError> { + self.inner + .sender + .try_send(cmd) + .map_err(TryCommandError::from_try_send_error) + } + + /// Close this client, waiting for any remaining buffered messages to be processed first + /// + /// Attempts to send commands to the NATS server after this method has been called will + /// result into a [`ClientClosedError`] error. + pub async fn close(&self) { + let (sender, receiver) = oneshot::channel(); + if self + .enqueue_command(HandlerCommand::Close(sender)) + .await + .is_err() + { + return; + } + + let _ = receiver.await; + } +} + +impl Drop for ClientInner { + fn drop(&mut self) { + self.handler.abort(); + } +} + +impl TryCommandError { + #[expect( + clippy::needless_pass_by_value, + reason = "this is an auxiliary conversion function" + )] + pub(crate) fn from_try_send_error(err: TrySendError) -> Self { + match err { + TrySendError::Full(_) => Self::BufferFull, + TrySendError::Closed(_) => Self::Closed(ClientClosedError), + } + } +} + +pub(crate) fn create_inbox_subject(prefix: &Subject) -> Subject { + let mut suffix = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut suffix); + + let mut subject = String::with_capacity(prefix.len() + ".".len() + (suffix.len() * 2)); + write!(&mut subject, "{}.{:x}", prefix, u128::from_ne_bytes(suffix)).unwrap(); + + Subject::from_dangerous_value(subject.into()) +} diff --git a/watermelon/src/client/quick_info.rs b/watermelon/src/client/quick_info.rs new file mode 100644 index 0000000..c6073da --- /dev/null +++ b/watermelon/src/client/quick_info.rs @@ -0,0 +1,165 @@ +use crate::atomic::{AtomicU32, Ordering}; + +const IS_CONNECTED: u32 = 1 << 0; +#[cfg(feature = "non-standard-zstd")] +const IS_ZSTD_COMPRESSED: u32 = 1 << 1; +const IS_LAMEDUCK: u32 = 1 << 2; +const IS_FAILED_UNSUBSCRIBE: u32 = 1 << 31; + +#[derive(Debug)] +pub(crate) struct RawQuickInfo(AtomicU32); + +/// Client information +/// +/// Obtained from [`Client::quick_info`]. +/// +/// [`Client::quick_info`]: crate::core::Client::quick_info +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[expect(clippy::struct_excessive_bools)] +pub struct QuickInfo { + pub(crate) is_connected: bool, + #[cfg(feature = "non-standard-zstd")] + pub(crate) is_zstd_compressed: bool, + pub(crate) is_lameduck: bool, + pub(crate) is_failed_unsubscribe: bool, +} + +impl RawQuickInfo { + pub(crate) fn new() -> Self { + Self(AtomicU32::new( + QuickInfo { + is_connected: false, + #[cfg(feature = "non-standard-zstd")] + is_zstd_compressed: false, + is_lameduck: false, + is_failed_unsubscribe: false, + } + .encode(), + )) + } + + pub(crate) fn get(&self) -> QuickInfo { + QuickInfo::decode(self.0.load(Ordering::Acquire)) + } + + pub(crate) fn store(&self, mut f: F) + where + F: FnMut(QuickInfo) -> QuickInfo, + { + let prev_params = self.get(); + self.0.store(f(prev_params).encode(), Ordering::Release); + } + + pub(crate) fn store_is_connected(&self, val: bool) { + self.store_bit(IS_CONNECTED, val); + } + pub(crate) fn store_is_lameduck(&self, val: bool) { + self.store_bit(IS_LAMEDUCK, val); + } + pub(crate) fn store_is_failed_unsubscribe(&self, val: bool) { + self.store_bit(IS_FAILED_UNSUBSCRIBE, val); + } + + #[expect( + clippy::inline_always, + reason = "we want this to be inlined inside the store_* functions" + )] + #[inline(always)] + fn store_bit(&self, mask: u32, val: bool) { + debug_assert_eq!(mask.count_ones(), 1); + + if val { + self.0.fetch_or(mask, Ordering::AcqRel); + } else { + self.0.fetch_and(!mask, Ordering::AcqRel); + } + } +} + +impl QuickInfo { + /// Returns `true` if the client is currently connected to the NATS server + #[must_use] + pub fn is_connected(&self) -> bool { + self.is_connected + } + + /// Returns `true` if the client connection is zstd compressed + #[cfg(feature = "non-standard-zstd")] + #[must_use] + pub fn is_zstd_compressed(&self) -> bool { + self.is_zstd_compressed + } + + /// Returns `true` if the client is currently in Lame Duck Mode + #[must_use] + pub fn is_lameduck(&self) -> bool { + self.is_lameduck + } + + fn encode(self) -> u32 { + let mut val = 0; + + if self.is_connected { + val |= IS_CONNECTED; + } + + #[cfg(feature = "non-standard-zstd")] + if self.is_zstd_compressed { + val |= IS_ZSTD_COMPRESSED; + } + + if self.is_lameduck { + val |= IS_LAMEDUCK; + } + + if self.is_failed_unsubscribe { + val |= IS_FAILED_UNSUBSCRIBE; + } + + val + } + + fn decode(val: u32) -> Self { + Self { + is_connected: (val & IS_CONNECTED) != 0, + #[cfg(feature = "non-standard-zstd")] + is_zstd_compressed: (val & IS_ZSTD_COMPRESSED) != 0, + is_lameduck: (val & IS_LAMEDUCK) != 0, + is_failed_unsubscribe: (val & IS_FAILED_UNSUBSCRIBE) != 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::{QuickInfo, RawQuickInfo}; + + #[test] + fn set_get() { + let quick_info = RawQuickInfo::new(); + let mut expected = QuickInfo { + is_connected: false, + #[cfg(feature = "non-standard-zstd")] + is_zstd_compressed: false, + is_lameduck: false, + is_failed_unsubscribe: false, + }; + + for is_connected in [false, true] { + quick_info.store_is_connected(is_connected); + expected.is_connected = is_connected; + + for is_lameduck in [false, true] { + quick_info.store_is_lameduck(is_lameduck); + expected.is_lameduck = is_lameduck; + + for is_failed_unsubscribe in [false, true] { + quick_info.store_is_failed_unsubscribe(is_failed_unsubscribe); + expected.is_failed_unsubscribe = is_failed_unsubscribe; + + assert_eq!(expected, quick_info.get()); + } + } + } + } +} diff --git a/watermelon/src/client/tests.rs b/watermelon/src/client/tests.rs new file mode 100644 index 0000000..b4f5b1a --- /dev/null +++ b/watermelon/src/client/tests.rs @@ -0,0 +1,14 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use tokio::sync::mpsc; +use watermelon_proto::ServerInfo; + +use crate::{client::RawQuickInfo, handler::HandlerCommand}; + +#[derive(Debug)] +pub(crate) struct TestHandler { + pub(crate) receiver: mpsc::Receiver, + pub(crate) _info: Arc>, + pub(crate) quick_info: Arc, +} diff --git a/watermelon/src/handler.rs b/watermelon/src/handler.rs new file mode 100644 index 0000000..4177c56 --- /dev/null +++ b/watermelon/src/handler.rs @@ -0,0 +1,721 @@ +use std::{ + collections::{BTreeMap, VecDeque}, + future::Future, + num::NonZeroU64, + ops::ControlFlow, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use arc_swap::ArcSwap; +use bytes::Bytes; +use tokio::{ + net::TcpStream, + sync::{ + mpsc::{self, error::TrySendError}, + oneshot, + }, + time::{self, Instant, Sleep}, +}; +use watermelon_mini::{ + easy_connect, ConnectError, ConnectFlags, ConnectionCompression, ConnectionSecurity, +}; +use watermelon_net::Connection; +use watermelon_proto::{ + error::ServerError, + headers::HeaderMap, + proto::{ClientOp, ServerOp}, + MessageBase, QueueGroup, ServerAddr, ServerInfo, ServerMessage, Subject, SubscriptionId, +}; + +use crate::client::{create_inbox_subject, QuickInfo, RawQuickInfo}; +use crate::core::{ClientBuilder, Echo}; + +pub(crate) const MULTIPLEXED_SUBSCRIPTION_ID: SubscriptionId = SubscriptionId::MIN; +const PING_INTERVAL: Duration = Duration::from_secs(10); +const RECV_BUF: usize = 16; + +#[derive(Debug)] +pub(crate) struct Handler { + conn: Connection< + ConnectionCompression>, + ConnectionSecurity, + >, + info: Arc>, + quick_info: Arc, + delayed_flusher: Option, + flushing: bool, + shutting_down: bool, + + ping_interval: Pin>, + pending_pings: u8, + + commands: mpsc::Receiver, + recv_buf: Vec, + in_flight_commands: VecDeque, + + multiplexed_subscription_prefix: Subject, + multiplexed_subscriptions: Option>>, + subscriptions: BTreeMap, + + awaiting_close: Vec>, +} + +#[derive(Debug)] +struct DelayedFlusher { + // INVARIANT: `interval != Duration::ZERO` + interval: Duration, + delay: Pin>>, +} + +#[derive(Debug)] +pub(crate) struct RecycledHandler { + commands: mpsc::Receiver, + quick_info: Arc, + + multiplexed_subscription_prefix: Subject, + subscriptions: BTreeMap, + + awaiting_close: Vec>, +} + +#[derive(Debug)] +struct Subscription { + subject: Subject, + queue_group: Option, + messages: mpsc::Sender>, + remaining: Option, + failed_subscribe: bool, +} + +#[derive(Debug)] +pub(crate) enum HandlerCommand { + Publish { + message: MessageBase, + }, + RequestMultiplexed { + subject: Subject, + reply_subject: Subject, + headers: HeaderMap, + payload: Bytes, + reply: oneshot::Sender, + }, + UnsubscribeMultiplexed { + reply_subject: Subject, + }, + Subscribe { + id: SubscriptionId, + subject: Subject, + queue_group: Option, + messages: mpsc::Sender>, + }, + Unsubscribe { + id: SubscriptionId, + max_messages: Option, + }, + Close(oneshot::Sender<()>), +} + +#[derive(Debug)] +pub(crate) enum InFlightCommand { + Unimportant, + Subscribe { id: SubscriptionId }, +} + +#[derive(Debug)] +pub(crate) enum HandlerOutput { + ServerError, + UnexpectedState, + Disconnected, + Closed, +} + +impl Handler { + pub(crate) async fn connect( + addr: &ServerAddr, + builder: &ClientBuilder, + recycle: RecycledHandler, + ) -> Result { + let mut flags = ConnectFlags::default(); + flags.echo = matches!(builder.echo, Echo::Allow); + #[cfg(feature = "non-standard-zstd")] + { + flags.zstd = builder.non_standard_zstd; + } + + let (mut conn, info) = match easy_connect(addr, builder.auth_method.as_ref(), flags).await { + Ok(items) => items, + Err(err) => return Err((err, recycle)), + }; + + #[cfg(feature = "non-standard-zstd")] + let is_zstd_compressed = if let Connection::Streaming(streaming) = &conn { + streaming.socket().is_zstd_compressed() + } else { + false + }; + recycle.quick_info.store(|quick_info| QuickInfo { + is_connected: true, + #[cfg(feature = "non-standard-zstd")] + is_zstd_compressed, + is_lameduck: false, + ..quick_info + }); + + let mut in_flight_commands = VecDeque::new(); + for (&id, subscription) in &recycle.subscriptions { + in_flight_commands.push_back(InFlightCommand::Subscribe { id }); + conn.enqueue_write_op(&ClientOp::Subscribe { + id, + subject: subscription.subject.clone(), + queue_group: subscription.queue_group.clone(), + }); + + if let Some(remaining) = subscription.remaining { + conn.enqueue_write_op(&ClientOp::Unsubscribe { + id, + max_messages: Some(remaining), + }); + } + } + + let delayed_flusher = if builder.flush_interval.is_zero() { + None + } else { + Some(DelayedFlusher { + interval: builder.flush_interval, + delay: Box::pin(None), + }) + }; + + Ok(Self { + conn, + info: Arc::new(ArcSwap::new(Arc::from(info))), + quick_info: recycle.quick_info, + delayed_flusher, + flushing: false, + shutting_down: false, + ping_interval: Box::pin(time::sleep(PING_INTERVAL)), + pending_pings: 0, + commands: recycle.commands, + recv_buf: Vec::with_capacity(RECV_BUF), + in_flight_commands, + subscriptions: recycle.subscriptions, + multiplexed_subscription_prefix: recycle.multiplexed_subscription_prefix, + multiplexed_subscriptions: None, + awaiting_close: recycle.awaiting_close, + }) + } + + pub(crate) async fn recycle(mut self) -> RecycledHandler { + self.quick_info.store_is_connected(false); + let _ = self.conn.shutdown().await; + + RecycledHandler { + commands: self.commands, + quick_info: self.quick_info, + subscriptions: self.subscriptions, + multiplexed_subscription_prefix: self.multiplexed_subscription_prefix, + awaiting_close: self.awaiting_close, + } + } + + pub(crate) fn info(&self) -> &Arc> { + &self.info + } + + pub(crate) fn multiplexed_subscription_prefix(&self) -> &Subject { + &self.multiplexed_subscription_prefix + } + + fn handle_server_op(&mut self, server_op: ServerOp) -> ControlFlow { + match server_op { + ServerOp::Message { message } + if message.subscription_id == MULTIPLEXED_SUBSCRIPTION_ID => + { + let Some(multiplexed_subscriptions) = &mut self.multiplexed_subscriptions else { + return ControlFlow::Continue(()); + }; + + if let Some(sender) = multiplexed_subscriptions.remove(&message.base.subject) { + let _ = sender.send(message); + } else { + // 🤷 + } + } + ServerOp::Message { message } => { + let subscription_id = message.subscription_id; + + if let Some(subscription) = self.subscriptions.get_mut(&subscription_id) { + match subscription.messages.try_send(Ok(message)) { + Ok(()) => {} + #[expect( + clippy::match_same_arms, + reason = "the case still needs to be implemented" + )] + Err(TrySendError::Full(_)) => { + // TODO + } + Err(TrySendError::Closed(_)) => { + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + self.conn.enqueue_write_op(&ClientOp::Unsubscribe { + id: subscription_id, + max_messages: None, + }); + return ControlFlow::Continue(()); + } + } + + if let Some(remaining) = &mut subscription.remaining { + match NonZeroU64::new(remaining.get() - 1) { + Some(new_remaining) => *remaining = new_remaining, + None => { + self.subscriptions.remove(&subscription_id); + } + } + } + } else { + // 🤷 + } + } + ServerOp::Success => { + let Some(in_flight_command) = self.in_flight_commands.pop_front() else { + return ControlFlow::Break(HandlerOutput::UnexpectedState); + }; + + match in_flight_command { + InFlightCommand::Unimportant | InFlightCommand::Subscribe { .. } => { + // Nothing to do + } + } + } + ServerOp::Error { error } if error.is_fatal() == Some(false) => { + let Some(in_flight_command) = self.in_flight_commands.pop_front() else { + return ControlFlow::Break(HandlerOutput::UnexpectedState); + }; + + match in_flight_command { + InFlightCommand::Unimportant => { + // Nothing to do + } + InFlightCommand::Subscribe { id } => { + if let Some(mut subscription) = self.subscriptions.remove(&id) { + match subscription.messages.try_send(Err(error)) { + Ok(()) | Err(TrySendError::Closed(_)) => { + // Nothing to do + } + Err(TrySendError::Full(_)) => { + // The error is going to be lost + + // We have to put the subscription back in order for the unsubscribe to be handled correctly + subscription.failed_subscribe = true; + self.subscriptions.insert(id, subscription); + self.quick_info.store_is_failed_unsubscribe(true); + } + } + } + } + } + } + ServerOp::Error { error: _ } => return ControlFlow::Break(HandlerOutput::ServerError), + ServerOp::Ping => { + self.conn.enqueue_write_op(&ClientOp::Pong); + } + ServerOp::Pong => { + self.pending_pings = self.pending_pings.saturating_sub(1); + } + ServerOp::Info { info } => { + self.quick_info.store_is_lameduck(info.lame_duck_mode); + self.info.store(Arc::from(info)); + } + } + + ControlFlow::Continue(()) + } + + #[cold] + fn ping(&mut self, cx: &mut Context<'_>) -> Result<(), HandlerOutput> { + if self.pending_pings < 2 { + loop { + self.reset_ping_interval(); + if Pin::new(&mut self.ping_interval).poll(cx).is_pending() { + break; + } + } + + self.conn.enqueue_write_op(&ClientOp::Ping); + self.pending_pings += 1; + Ok(()) + } else { + Err(HandlerOutput::Disconnected) + } + } + + #[cold] + fn failed_unsubscribe(&mut self) { + self.quick_info.store_is_failed_unsubscribe(false); + + if let Some(multiplexed_subscriptions) = &mut self.multiplexed_subscriptions { + multiplexed_subscriptions.retain(|_subject, sender| !sender.is_closed()); + } + + let closed_subscription_ids = self + .subscriptions + .iter() + .filter(|(_id, subscription)| { + subscription.messages.is_closed() || subscription.failed_subscribe + }) + .map(|(&id, _subscription)| id) + .collect::>(); + + for closed_subscription_id in closed_subscription_ids { + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + self.conn.enqueue_write_op(&ClientOp::Unsubscribe { + id: closed_subscription_id, + max_messages: None, + }); + self.subscriptions.remove(&closed_subscription_id); + } + } + + fn reset_ping_interval(&mut self) { + Sleep::reset(self.ping_interval.as_mut(), Instant::now() + PING_INTERVAL); + } +} + +impl Future for Handler { + type Output = HandlerOutput; + + #[expect(clippy::too_many_lines)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + #[derive(Debug, Copy, Clone)] + enum FlushAction { + Start, + Stop, + } + + let this = self.get_mut(); + if Pin::new(&mut this.ping_interval).poll(cx).is_ready() { + if let Err(output) = this.ping(cx) { + return Poll::Ready(output); + } + } + + if this.quick_info.get().is_failed_unsubscribe { + this.failed_unsubscribe(); + } + + let mut handled_server_op = false; + loop { + match this.conn.poll_read_next(cx) { + Poll::Pending => break, + Poll::Ready(Ok(server_op)) => { + this.handle_server_op(server_op); + handled_server_op = true; + } + Poll::Ready(Err(_err)) => return Poll::Ready(HandlerOutput::Disconnected), + } + } + if handled_server_op { + this.reset_ping_interval(); + } + + loop { + let receive_outcome = this.receive_command(cx); + let write_waker_registered = match &mut this.conn { + Connection::Streaming(streaming) => { + if streaming.may_write() { + match streaming.poll_write_next(cx) { + Poll::Pending => true, + Poll::Ready(Ok(_n)) => false, + Poll::Ready(Err(_err)) => { + return Poll::Ready(HandlerOutput::Disconnected); + } + } + } else { + true + } + } + Connection::Websocket(_) => true, + }; + + let flushes_automatically_when_full = this.conn.flushes_automatically_when_full(); + let should_flush = this.conn.should_flush(); + + let flush_action = match ( + receive_outcome, + flushes_automatically_when_full, + should_flush, + ) { + (ReceiveOutcome::NoMoreCommands, _, true) => { + // We have written everything there was to write, + // and some data is buffered + FlushAction::Start + } + (ReceiveOutcome::NoMoreSpace, true, should_flush) => { + debug_assert!(should_flush, "the connection is out space for writing but doesn't report the need to flush"); + + // There's no more space to write, but the implementation automatically + // flushes so we're good + FlushAction::Stop + } + (ReceiveOutcome::NoMoreSpace, false, true) => { + // There's no more space to write, and the implementation doesn't + // flush automatically + FlushAction::Start + } + (_, _, false) => { + // There's nothing to flush + FlushAction::Stop + } + }; + + match flush_action { + FlushAction::Start => { + this.flushing = true; + if let Some(delayed_flusher) = &mut this.delayed_flusher { + if delayed_flusher.delay.is_none() { + delayed_flusher + .delay + .set(Some(time::sleep(delayed_flusher.interval))); + } + } + } + FlushAction::Stop => { + this.flushing = false; + } + } + + match (receive_outcome, write_waker_registered) { + (ReceiveOutcome::NoMoreCommands, true) => { + // There are no more commands to receive and writing is blocked. + // There's no progress to be made + break; + } + (ReceiveOutcome::NoMoreSpace, true) => { + // There's no more space to write and writing is blocked. + // There's no progress to be made + break; + } + (_, false) => { + // At least the write waker must be registered + continue; + } + } + } + + if this.flushing { + let mut can_flush = true; + if let Some(delay_flusher) = &mut this.delayed_flusher { + if let Some(delay) = delay_flusher.delay.as_mut().as_pin_mut() { + if delay.poll(cx).is_ready() { + delay_flusher.delay.set(None); + } else { + can_flush = false; + } + } + } + + if can_flush { + match this.conn.poll_flush(cx) { + Poll::Pending => {} + Poll::Ready(Ok(())) => this.flushing = false, + Poll::Ready(Err(_err)) => return Poll::Ready(HandlerOutput::Disconnected), + } + } + } + + if this.shutting_down { + Poll::Ready(HandlerOutput::Closed) + } else { + Poll::Pending + } + } +} + +#[derive(Debug, Copy, Clone)] +enum ReceiveOutcome { + NoMoreCommands, + NoMoreSpace, +} + +impl Handler { + // TODO: refactor this, a view into Handler is needed in order to split `recv_buf` from the + // rest. + #[expect( + clippy::too_many_lines, + reason = "not good, but a non trivial refactor is needed" + )] + fn receive_command(&mut self, cx: &mut Context<'_>) -> ReceiveOutcome { + while self.conn.may_enqueue_more_ops() { + debug_assert!(self.recv_buf.is_empty()); + + match self + .commands + .poll_recv_many(cx, &mut self.recv_buf, RECV_BUF) + { + Poll::Pending => return ReceiveOutcome::NoMoreCommands, + Poll::Ready(1..) => { + for cmd in self.recv_buf.drain(..) { + match cmd { + HandlerCommand::Publish { message } => { + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + self.conn.enqueue_write_op(&ClientOp::Publish { message }); + } + HandlerCommand::RequestMultiplexed { + subject, + reply_subject, + headers, + payload, + reply, + } => { + debug_assert!(reply_subject + .starts_with(&*self.multiplexed_subscription_prefix)); + + let multiplexed_subscriptions = + if let Some(multiplexed_subscriptions) = + &mut self.multiplexed_subscriptions + { + multiplexed_subscriptions + } else { + init_multiplexed_subscriptions( + &mut self.in_flight_commands, + &mut self.conn, + &self.multiplexed_subscription_prefix, + &mut self.multiplexed_subscriptions, + ) + }; + + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + multiplexed_subscriptions.insert(reply_subject.clone(), reply); + + let message = MessageBase { + subject, + reply_subject: Some(reply_subject), + headers, + payload, + }; + self.conn.enqueue_write_op(&ClientOp::Publish { message }); + } + HandlerCommand::UnsubscribeMultiplexed { reply_subject } => { + debug_assert!(reply_subject + .starts_with(&*self.multiplexed_subscription_prefix)); + + if let Some(multiplexed_subscriptions) = + &mut self.multiplexed_subscriptions + { + let _ = multiplexed_subscriptions.remove(&reply_subject); + } + } + HandlerCommand::Subscribe { + id, + subject, + queue_group, + messages, + } => { + self.subscriptions.insert( + id, + Subscription { + subject: subject.clone(), + queue_group: queue_group.clone(), + messages, + remaining: None, + failed_subscribe: false, + }, + ); + self.in_flight_commands + .push_back(InFlightCommand::Subscribe { id }); + self.conn.enqueue_write_op(&ClientOp::Subscribe { + id, + subject, + queue_group, + }); + } + HandlerCommand::Unsubscribe { + id, + max_messages: Some(max_messages), + } => { + if let Some(subscription) = self.subscriptions.get_mut(&id) { + subscription.remaining = Some(max_messages); + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + self.conn.enqueue_write_op(&ClientOp::Unsubscribe { + id, + max_messages: Some(max_messages), + }); + } + } + HandlerCommand::Unsubscribe { + id, + max_messages: None, + } => { + if self.subscriptions.remove(&id).is_some() { + self.in_flight_commands + .push_back(InFlightCommand::Unimportant); + self.conn.enqueue_write_op(&ClientOp::Unsubscribe { + id, + max_messages: None, + }); + } + } + HandlerCommand::Close(sender) => { + self.shutting_down = true; + self.awaiting_close.push(sender); + self.commands.close(); + } + } + } + } + Poll::Ready(0) => self.shutting_down = true, + } + } + + ReceiveOutcome::NoMoreSpace + } +} + +impl RecycledHandler { + pub(crate) fn new( + commands: mpsc::Receiver, + quick_info: Arc, + builder: &ClientBuilder, + ) -> Self { + Self { + commands, + quick_info, + subscriptions: BTreeMap::new(), + multiplexed_subscription_prefix: create_inbox_subject(&builder.inbox_prefix), + awaiting_close: Vec::new(), + } + } +} + +#[cold] +fn init_multiplexed_subscriptions<'a>( + in_flight_commands: &mut VecDeque, + conn: &mut Connection< + ConnectionCompression>, + ConnectionSecurity, + >, + multiplexed_subscription_prefix: &Subject, + multiplexed_subscriptions: &'a mut Option>>, +) -> &'a mut BTreeMap> { + in_flight_commands.push_back(InFlightCommand::Subscribe { + id: MULTIPLEXED_SUBSCRIPTION_ID, + }); + conn.enqueue_write_op(&ClientOp::Subscribe { + id: MULTIPLEXED_SUBSCRIPTION_ID, + subject: Subject::from_dangerous_value( + format!("{multiplexed_subscription_prefix}.*").into(), + ), + queue_group: None, + }); + + multiplexed_subscriptions.insert(BTreeMap::new()) +} diff --git a/watermelon/src/lib.rs b/watermelon/src/lib.rs new file mode 100644 index 0000000..1cbc970 --- /dev/null +++ b/watermelon/src/lib.rs @@ -0,0 +1,61 @@ +pub use watermelon_proto as proto; + +mod atomic; +mod client; +mod handler; +mod multiplexed_subscription; +mod subscription; +#[cfg(test)] +pub(crate) mod tests; + +pub mod core { + //! NATS Core functionality implementation + + pub use crate::client::{Client, ClientBuilder, Echo, QuickInfo}; + pub(crate) use crate::multiplexed_subscription::MultiplexedSubscription; + pub use crate::subscription::Subscription; + pub use watermelon_mini::AuthenticationMethod; + + pub mod publish { + //! Utilities for publishing messages + + pub use crate::client::{ + ClientPublish, DoClientPublish, DoOwnedClientPublish, OwnedClientPublish, Publish, + PublishBuilder, + }; + } + + pub mod request { + //! Utilities for publishing messages and awaiting for a response + + pub use crate::client::{ + ClientRequest, DoClientRequest, DoOwnedClientRequest, OwnedClientRequest, Request, + RequestBuilder, ResponseFut, + }; + } + + pub mod error { + //! NATS Core specific errors + + pub use crate::client::{ClientClosedError, ResponseError, TryCommandError}; + } +} + +pub mod jetstream { + //! NATS Jetstream functionality implementation + //! + //! Relies on NATS Core to communicate with the NATS server + + pub use crate::client::{ + AckPolicy, Compression, Consumer, ConsumerBatch, ConsumerConfig, ConsumerDurability, + ConsumerSpecificConfig, ConsumerStorage, ConsumerStream, ConsumerStreamError, Consumers, + DeliverPolicy, DiscardPolicy, JetstreamClient, ReplayPolicy, RetentionPolicy, Storage, + Stream, StreamConfig, StreamState, Streams, + }; + + pub mod error { + //! NATS Jetstream specific errors + + pub use crate::client::{JetstreamError, JetstreamError2, JetstreamErrorCode}; + } +} diff --git a/watermelon/src/multiplexed_subscription.rs b/watermelon/src/multiplexed_subscription.rs new file mode 100644 index 0000000..b2940ed --- /dev/null +++ b/watermelon/src/multiplexed_subscription.rs @@ -0,0 +1,69 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::sync::oneshot; +use watermelon_proto::{ServerMessage, Subject}; + +use crate::{client::ClientClosedError, core::Client}; + +#[derive(Debug)] +pub(crate) struct MultiplexedSubscription { + subscription: Option, +} + +#[derive(Debug)] +struct Inner { + reply_subject: Subject, + receiver: oneshot::Receiver, + client: Client, +} + +impl MultiplexedSubscription { + pub(crate) fn new( + reply_subject: Subject, + receiver: oneshot::Receiver, + client: Client, + ) -> Self { + Self { + subscription: Some(Inner { + reply_subject, + receiver, + client, + }), + } + } +} + +impl Future for MultiplexedSubscription { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let subscription = self + .subscription + .as_mut() + .expect("MultiplexedSubscription polled after completing"); + + match Pin::new(&mut subscription.receiver).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => { + self.subscription = None; + Poll::Ready(result.map_err(|_| ClientClosedError)) + } + } + } +} + +impl Drop for MultiplexedSubscription { + fn drop(&mut self) { + let Some(subscription) = self.subscription.take() else { + return; + }; + + subscription + .client + .lazy_unsubscribe_multiplexed(subscription.reply_subject); + } +} diff --git a/watermelon/src/subscription.rs b/watermelon/src/subscription.rs new file mode 100644 index 0000000..99da902 --- /dev/null +++ b/watermelon/src/subscription.rs @@ -0,0 +1,375 @@ +use std::{ + num::NonZeroU64, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_core::{FusedStream, Stream}; +use tokio::sync::mpsc; +use watermelon_proto::{error::ServerError, ServerMessage, SubscriptionId}; + +use crate::core::{error::ClientClosedError, Client}; + +const BATCH_RECEIVE_SIZE: usize = 16; + +/// A NATS subscription +/// +/// Receives messages coming from the NATS server with At Most Once Delivery. +/// +/// Messages are yielded via the [`Stream`] implementation as they are received by client. +/// Errors can only occur immediately after subscribing or after the client reconnects. +/// +/// The subscription MUST be polled continuously. If the subscription is not polled +/// for a relatively long period of time the internal buffers will fill up and any +/// further messages will be dropped. +/// +/// Obtained from [`Client::subscribe`]. +#[derive(Debug)] +pub struct Subscription { + pub(crate) id: SubscriptionId, + client: Client, + receiver: mpsc::Receiver>, + receiver_queue: Vec>, + status: SubscriptionStatus, +} + +#[derive(Debug, Copy, Clone)] +enum SubscriptionStatus { + Subscribed, + Unsubscribed, +} + +impl Subscription { + pub(crate) fn new( + id: SubscriptionId, + client: Client, + receiver: mpsc::Receiver>, + ) -> Self { + Self { + id, + client, + receiver, + receiver_queue: Vec::with_capacity(BATCH_RECEIVE_SIZE), + status: SubscriptionStatus::Subscribed, + } + } + + /// Immediately close the subscription + /// + /// The `Stream` implementation will continue to yield any remaining + /// in-flight or otherwise buffered messages. + /// + /// Calling this method multiple times is a NOOP. + /// + /// # Errors + /// + /// It returns an error if the client is closed. + pub async fn close(&mut self) -> Result<(), ClientClosedError> { + match (self.status, self.receiver.is_closed()) { + (SubscriptionStatus::Subscribed, true) => { + self.status = SubscriptionStatus::Unsubscribed; + } + (SubscriptionStatus::Subscribed, false) => { + self.client.unsubscribe(self.id, None).await?; + self.status = SubscriptionStatus::Unsubscribed; + } + (SubscriptionStatus::Unsubscribed, _) => {} + } + + Ok(()) + } + + /// Close the subscription after `max_messages` have been delivered + /// + /// Ask the NATS Server to automatically close the subscription after + /// `max_messages` have been sent to the client. + /// + ///
+ /// Calling this method does not guarantee that the Stream will + /// deliver the exact number of messages specified in max_messages. + ///
+ /// + /// More or less messages may be delivered to the client due to race conditions + /// or data loss between it and the server. + /// + /// More messages could be delivered if the server receives the `close_after` command + /// after it has already started buffering messages to send to the client. + /// The same race condition could occur after a reconnect. + /// + /// Fewer messages could be delivered if the connection handler is faster + /// at producing messages than the [`Stream`] is at reading them, causing the + /// channel to fill up and starting to drop subsequent messages. + /// + /// # Errors + /// + /// It returns an error if the client is closed. + pub async fn close_after(&mut self, max_messages: NonZeroU64) -> Result<(), ClientClosedError> { + match (self.status, self.receiver.is_closed()) { + (SubscriptionStatus::Subscribed, true) => { + self.status = SubscriptionStatus::Unsubscribed; + } + (SubscriptionStatus::Subscribed, false) => { + self.client.unsubscribe(self.id, Some(max_messages)).await?; + } + (SubscriptionStatus::Unsubscribed, _) => {} + } + + Ok(()) + } +} + +impl Stream for Subscription { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(msg) = this.receiver_queue.pop() { + return Poll::Ready(Some(msg)); + } + + match Pin::new(&mut this.receiver).poll_recv_many( + cx, + &mut this.receiver_queue, + BATCH_RECEIVE_SIZE, + ) { + Poll::Pending => Poll::Pending, + Poll::Ready(n @ 1..) => { + debug_assert_eq!(n, this.receiver_queue.len()); + this.receiver_queue.reverse(); + Poll::Ready(this.receiver_queue.pop()) + } + Poll::Ready(0) => { + this.status = SubscriptionStatus::Unsubscribed; + Poll::Ready(None) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.receiver_queue.len(), None) + } +} + +impl FusedStream for Subscription { + fn is_terminated(&self) -> bool { + self.receiver.is_closed() && self.receiver_queue.is_empty() + } +} + +impl Drop for Subscription { + fn drop(&mut self) { + if matches!(self.status, SubscriptionStatus::Unsubscribed) || self.receiver.is_closed() { + return; + } + + self.client.lazy_unsubscribe(self.id, None); + } +} + +#[cfg(test)] +mod tests { + use std::{ + future::Future, + pin::pin, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use claims::assert_matches; + use futures_util::{task::noop_waker_ref, StreamExt}; + use tokio::sync::mpsc::error::TryRecvError; + use watermelon_proto::{ + headers::HeaderMap, MessageBase, ServerMessage, StatusCode, Subject, SubscriptionId, + }; + + use crate::{core::Client, handler::HandlerCommand}; + + #[tokio::test] + async fn subscribe() { + let (client, mut handler) = Client::test(1); + + let mut subscription = client + .subscribe(Subject::from_static("abcd.>"), None) + .await + .unwrap(); + + let subscribe_command = handler.receiver.try_recv().unwrap(); + let HandlerCommand::Subscribe { + id, + subject, + queue_group, + messages, + } = subscribe_command + else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(1), id); + assert_eq!(Subject::from_static("abcd.>"), subject); + assert_eq!(None, queue_group); + + // Messages are delivered as expected + + let (flag, waker) = crate::tests::FlagWaker::new(); + let mut cx = Context::from_waker(&waker); + + let mut expected_wakes = 0; + for num_messages in 0..32 { + assert!(subscription.poll_next_unpin(&mut cx).is_pending()); + assert_eq!(expected_wakes, flag.wakes()); + + let msgs = (0..num_messages) + .map(|num| ServerMessage { + status_code: Some(StatusCode::OK), + subscription_id: SubscriptionId::from(1), + base: MessageBase { + subject: format!("abcd.{num}").try_into().unwrap(), + reply_subject: None, + headers: HeaderMap::new(), + payload: Bytes::from_static(b"test"), + }, + }) + .collect::>(); + for msg in &msgs { + messages.try_send(Ok(msg.clone())).unwrap(); + } + if num_messages > 0 { + expected_wakes += 1; + } + + assert_eq!(expected_wakes, flag.wakes()); + for msg in msgs { + assert_eq!( + Poll::Ready(Some(Ok(msg))), + subscription.poll_next_unpin(&mut cx) + ); + } + assert!(subscription.poll_next_unpin(&mut cx).is_pending()); + } + + drop(messages); + expected_wakes += 1; + + assert_eq!(expected_wakes, flag.wakes()); + assert_eq!(Poll::Ready(None), subscription.poll_next_unpin(&mut cx)); + } + + #[tokio::test] + async fn unsubscribe() { + let (client, mut handler) = Client::test(1); + + let mut subscription = client + .subscribe(Subject::from_static("abcd.>"), None) + .await + .unwrap(); + + let subscribe_command = handler.receiver.try_recv().unwrap(); + assert_matches!(subscribe_command, HandlerCommand::Subscribe { .. }); + + // Closing the subscription sends `Unsubscribe` + + subscription.close().await.unwrap(); + let HandlerCommand::Unsubscribe { + id, + max_messages: None, + } = handler.receiver.try_recv().unwrap() + else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(1), id); + + // Unsubscribing again is a NOOP + + subscription.close().await.unwrap(); + assert_eq!( + TryRecvError::Empty, + handler.receiver.try_recv().unwrap_err() + ); + + // Same when dropping the subscription + + drop(subscription); + assert_eq!( + TryRecvError::Empty, + handler.receiver.try_recv().unwrap_err() + ); + } + + #[tokio::test] + async fn drop_unsubscribe() { + let (client, mut handler) = Client::test(1); + + let subscription = client + .subscribe(Subject::from_static("abcd.>"), None) + .await + .unwrap(); + + let subscribe_command = handler.receiver.try_recv().unwrap(); + let HandlerCommand::Subscribe { + id, + subject, + queue_group, + messages: _, + } = subscribe_command + else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(1), id); + assert_eq!(Subject::from_static("abcd.>"), subject); + assert_eq!(None, queue_group); + + // Dropping `Subscription` sends `Unsubscribe` + + drop(subscription); + let HandlerCommand::Unsubscribe { + id, + max_messages: None, + } = handler.receiver.try_recv().unwrap() + else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(1), id); + } + + #[tokio::test] + async fn subscribe_is_cancel_safe() { + let (client, mut handler) = Client::test(1); + + let subscription = client + .subscribe(Subject::from_static("abcd.>"), None) + .await + .unwrap(); + + { + let subscribe_future = pin!(client.subscribe(Subject::from_static("dcba.>"), None)); + + let mut cx = Context::from_waker(noop_waker_ref()); + assert!(subscribe_future.poll(&mut cx).is_pending()); + } + + let subscribe_command = handler.receiver.try_recv().unwrap(); + let HandlerCommand::Subscribe { id, .. } = subscribe_command else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(1), id); + + let subscription2 = client + .subscribe(Subject::from_static("abcd.>"), None) + .await + .unwrap(); + + let subscribe_command = handler.receiver.try_recv().unwrap(); + let HandlerCommand::Subscribe { id, .. } = subscribe_command else { + unreachable!() + }; + assert_eq!(SubscriptionId::from(2), id); + + // Failing to unsubscribe triggers `is_failed_unsubscribe = true` + + assert!(!handler.quick_info.get().is_failed_unsubscribe); + drop(subscription); + assert!(!handler.quick_info.get().is_failed_unsubscribe); + drop(subscription2); + assert!(handler.quick_info.get().is_failed_unsubscribe); + } +} diff --git a/watermelon/src/tests.rs b/watermelon/src/tests.rs new file mode 100644 index 0000000..ee3aea0 --- /dev/null +++ b/watermelon/src/tests.rs @@ -0,0 +1,26 @@ +use std::{sync::Arc, task::Waker}; + +use futures_util::task::ArcWake; + +use crate::atomic::{AtomicUsize, Ordering}; + +#[derive(Debug)] +pub(crate) struct FlagWaker(AtomicUsize); + +impl FlagWaker { + pub(crate) fn new() -> (Arc, Waker) { + let this = Arc::new(Self(AtomicUsize::new(0))); + let waker = futures_util::task::waker(Arc::clone(&this)); + (this, waker) + } + + pub(crate) fn wakes(&self) -> usize { + self.0.load(Ordering::Acquire) + } +} + +impl ArcWake for FlagWaker { + fn wake_by_ref(arc_self: &Arc) { + arc_self.0.fetch_add(1, Ordering::AcqRel); + } +}