From f5581d496bd51bcd11a997d788aa22e30c651156 Mon Sep 17 00:00:00 2001 From: Jaydip Gabani Date: Mon, 29 Jul 2024 01:50:39 +0000 Subject: [PATCH] adding eventhub driver Signed-off-by: Jaydip Gabani --- Makefile | 3 + go.mod | 4 + go.sum | 8 + .../pubsub/pubsub_config_controller.go | 8 +- pkg/pubsub/dapr/dapr.go | 38 +- pkg/pubsub/dapr/dapr_test.go | 2 +- pkg/pubsub/dapr/fake_dapr_client.go | 34 +- pkg/pubsub/eventhub/eventhub.go | 91 + pkg/pubsub/provider/provider.go | 2 + pkg/pubsub/system.go | 6 +- test/pubsub/publish-components.yaml | 36 +- .../azure-sdk-for-go/sdk/azcore/CHANGELOG.md | 786 ++++++ .../azure-sdk-for-go/sdk/azcore/LICENSE.txt | 21 + .../azure-sdk-for-go/sdk/azcore/README.md | 39 + .../Azure/azure-sdk-for-go/sdk/azcore/ci.yml | 29 + .../sdk/azcore/cloud/cloud.go | 44 + .../azure-sdk-for-go/sdk/azcore/cloud/doc.go | 53 + .../Azure/azure-sdk-for-go/sdk/azcore/core.go | 173 ++ .../Azure/azure-sdk-for-go/sdk/azcore/doc.go | 264 ++ .../azure-sdk-for-go/sdk/azcore/errors.go | 14 + .../Azure/azure-sdk-for-go/sdk/azcore/etag.go | 57 + .../sdk/azcore/internal/exported/exported.go | 175 ++ .../sdk/azcore/internal/exported/pipeline.go | 77 + .../sdk/azcore/internal/exported/request.go | 223 ++ .../internal/exported/response_error.go | 167 ++ .../sdk/azcore/internal/log/log.go | 50 + .../azcore/internal/pollers/async/async.go | 159 ++ .../sdk/azcore/internal/pollers/body/body.go | 135 ++ .../sdk/azcore/internal/pollers/fake/fake.go | 133 + .../sdk/azcore/internal/pollers/loc/loc.go | 123 + .../sdk/azcore/internal/pollers/op/op.go | 145 ++ .../sdk/azcore/internal/pollers/poller.go | 24 + .../sdk/azcore/internal/pollers/util.go | 200 ++ .../sdk/azcore/internal/shared/constants.go | 44 + .../sdk/azcore/internal/shared/shared.go | 149 ++ .../azure-sdk-for-go/sdk/azcore/log/doc.go | 10 + .../azure-sdk-for-go/sdk/azcore/log/log.go | 55 + .../azure-sdk-for-go/sdk/azcore/policy/doc.go | 10 + .../sdk/azcore/policy/policy.go | 197 ++ .../sdk/azcore/runtime/doc.go | 10 + 
.../sdk/azcore/runtime/errors.go | 27 + .../sdk/azcore/runtime/pager.go | 128 + .../sdk/azcore/runtime/pipeline.go | 94 + .../sdk/azcore/runtime/policy_api_version.go | 75 + .../sdk/azcore/runtime/policy_bearer_token.go | 123 + .../azcore/runtime/policy_body_download.go | 72 + .../sdk/azcore/runtime/policy_http_header.go | 40 + .../sdk/azcore/runtime/policy_http_trace.go | 143 ++ .../azcore/runtime/policy_include_response.go | 35 + .../azcore/runtime/policy_key_credential.go | 64 + .../sdk/azcore/runtime/policy_logging.go | 264 ++ .../sdk/azcore/runtime/policy_request_id.go | 34 + .../sdk/azcore/runtime/policy_retry.go | 255 ++ .../azcore/runtime/policy_sas_credential.go | 55 + .../sdk/azcore/runtime/policy_telemetry.go | 83 + .../sdk/azcore/runtime/poller.go | 389 +++ .../sdk/azcore/runtime/request.go | 265 ++ .../sdk/azcore/runtime/response.go | 109 + .../runtime/transport_default_dialer_other.go | 15 + .../runtime/transport_default_dialer_wasm.go | 15 + .../runtime/transport_default_http_client.go | 48 + .../sdk/azcore/streaming/doc.go | 9 + .../sdk/azcore/streaming/progress.go | 89 + .../azure-sdk-for-go/sdk/azcore/to/doc.go | 9 + .../azure-sdk-for-go/sdk/azcore/to/to.go | 21 + .../sdk/azcore/tracing/constants.go | 41 + .../sdk/azcore/tracing/tracing.go | 191 ++ .../azure-sdk-for-go/sdk/internal/LICENSE.txt | 21 + .../sdk/internal/diag/diag.go | 51 + .../azure-sdk-for-go/sdk/internal/diag/doc.go | 7 + .../sdk/internal/errorinfo/doc.go | 7 + .../sdk/internal/errorinfo/errorinfo.go | 46 + .../sdk/internal/exported/exported.go | 129 + .../azure-sdk-for-go/sdk/internal/log/doc.go | 7 + .../azure-sdk-for-go/sdk/internal/log/log.go | 104 + .../sdk/internal/poller/util.go | 155 ++ .../sdk/internal/telemetry/telemetry.go | 33 + .../sdk/internal/temporal/resource.go | 123 + .../azure-sdk-for-go/sdk/internal/uuid/doc.go | 7 + .../sdk/internal/uuid/uuid.go | 76 + .../sdk/messaging/azeventhubs/CHANGELOG.md | 177 ++ .../sdk/messaging/azeventhubs/LICENSE.txt | 21 + 
.../sdk/messaging/azeventhubs/README.md | 133 + .../sdk/messaging/azeventhubs/amqp_message.go | 271 +++ .../messaging/azeventhubs/checkpoint_store.go | 70 + .../sdk/messaging/azeventhubs/ci.yml | 35 + .../connection_string_properties.go | 21 + .../messaging/azeventhubs/consumer_client.go | 262 ++ .../sdk/messaging/azeventhubs/doc.go | 15 + .../sdk/messaging/azeventhubs/error.go | 31 + .../sdk/messaging/azeventhubs/event_data.go | 195 ++ .../messaging/azeventhubs/event_data_batch.go | 236 ++ .../azeventhubs/internal/amqpInterfaces.go | 21 + .../azeventhubs/internal/amqp_fakes.go | 149 ++ .../azeventhubs/internal/amqpwrap/amqpwrap.go | 307 +++ .../azeventhubs/internal/amqpwrap/error.go | 42 + .../azeventhubs/internal/amqpwrap/rpc.go | 27 + .../azeventhubs/internal/auth/token.go | 39 + .../sdk/messaging/azeventhubs/internal/cbs.go | 78 + .../azeventhubs/internal/constants.go | 7 + .../azeventhubs/internal/eh/eh_internal.go | 21 + .../messaging/azeventhubs/internal/errors.go | 265 ++ .../exported/connection_string_properties.go | 129 + .../azeventhubs/internal/exported/error.go | 58 + .../internal/exported/log_events.go | 23 + .../internal/exported/retry_options.go | 26 + .../exported/websocket_conn_params.go | 13 + .../messaging/azeventhubs/internal/links.go | 395 +++ .../azeventhubs/internal/links_recover.go | 155 ++ .../azeventhubs/internal/namespace.go | 512 ++++ .../azeventhubs/internal/namespace_eh.go | 48 + .../sdk/messaging/azeventhubs/internal/rpc.go | 444 ++++ .../messaging/azeventhubs/internal/sas/sas.go | 179 ++ .../internal/sbauth/token_provider.go | 138 ++ .../azeventhubs/internal/utils/retrier.go | 138 ++ .../sdk/messaging/azeventhubs/log.go | 23 + .../sdk/messaging/azeventhubs/mgmt.go | 253 ++ .../messaging/azeventhubs/migrationguide.md | 106 + .../messaging/azeventhubs/partition_client.go | 380 +++ .../sdk/messaging/azeventhubs/processor.go | 515 ++++ .../azeventhubs/processor_load_balancer.go | 302 +++ .../azeventhubs/processor_partition_client.go | 
73 + .../messaging/azeventhubs/producer_client.go | 312 +++ .../sdk/messaging/azeventhubs/sample.env | 20 + .../azeventhubs/test-resources.bicep | 225 ++ .../github.com/Azure/go-amqp/.gitattributes | 3 + vendor/github.com/Azure/go-amqp/.gitignore | 12 + vendor/github.com/Azure/go-amqp/CHANGELOG.md | 174 ++ .../Azure/go-amqp/CODE_OF_CONDUCT.md | 9 + .../github.com/Azure/go-amqp/CONTRIBUTING.md | 76 + vendor/github.com/Azure/go-amqp/LICENSE | 22 + vendor/github.com/Azure/go-amqp/Makefile | 31 + vendor/github.com/Azure/go-amqp/NOTICE.txt | 29 + vendor/github.com/Azure/go-amqp/README.md | 194 ++ vendor/github.com/Azure/go-amqp/SECURITY.md | 41 + .../Azure/go-amqp/azure-pipelines.yml | 105 + vendor/github.com/Azure/go-amqp/conn.go | 1147 +++++++++ vendor/github.com/Azure/go-amqp/const.go | 93 + vendor/github.com/Azure/go-amqp/creditor.go | 117 + vendor/github.com/Azure/go-amqp/doc.go | 10 + vendor/github.com/Azure/go-amqp/errors.go | 104 + .../Azure/go-amqp/internal/bitmap/bitmap.go | 96 + .../Azure/go-amqp/internal/buffer/buffer.go | 177 ++ .../Azure/go-amqp/internal/debug/debug.go | 17 + .../go-amqp/internal/debug/debug_debug.go | 48 + .../Azure/go-amqp/internal/encoding/decode.go | 1149 +++++++++ .../Azure/go-amqp/internal/encoding/encode.go | 570 +++++ .../Azure/go-amqp/internal/encoding/types.go | 2152 +++++++++++++++++ .../Azure/go-amqp/internal/frames/frames.go | 1540 ++++++++++++ .../Azure/go-amqp/internal/frames/parsing.go | 159 ++ .../Azure/go-amqp/internal/queue/queue.go | 162 ++ .../Azure/go-amqp/internal/shared/shared.go | 34 + vendor/github.com/Azure/go-amqp/link.go | 393 +++ .../github.com/Azure/go-amqp/link_options.go | 238 ++ vendor/github.com/Azure/go-amqp/message.go | 500 ++++ vendor/github.com/Azure/go-amqp/receiver.go | 909 +++++++ vendor/github.com/Azure/go-amqp/sasl.go | 259 ++ vendor/github.com/Azure/go-amqp/sender.go | 505 ++++ vendor/github.com/Azure/go-amqp/session.go | 822 +++++++ vendor/modules.txt | 50 + 160 files changed, 26023 
insertions(+), 71 deletions(-) create mode 100644 pkg/pubsub/eventhub/eventhub.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/CHANGELOG.md create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/LICENSE.txt create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/README.md create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/ci.yml create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/cloud.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/core.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/errors.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/etag.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/exported.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/pipeline.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/request.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/response_error.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log/log.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async/async.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body/body.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake/fake.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc/loc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op/op.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/poller.go 
create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/util.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/constants.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/shared.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/log.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/policy.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/errors.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pager.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pipeline.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_api_version.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_bearer_token.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_body_download.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_header.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_trace.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_include_response.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_key_credential.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_logging.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_request_id.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_retry.go create 
mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_sas_credential.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_telemetry.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/poller.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/request.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/response.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_other.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_wasm.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_http_client.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/progress.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/to.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/constants.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/tracing.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/LICENSE.txt create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/diag.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/errorinfo.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/exported/exported.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/doc.go create mode 100644 
vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/log.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/poller/util.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry/telemetry.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/temporal/resource.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/uuid.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/CHANGELOG.md create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/LICENSE.txt create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/README.md create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/amqp_message.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/checkpoint_store.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/ci.yml create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/connection_string_properties.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/consumer_client.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/doc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/error.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data_batch.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpInterfaces.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqp_fakes.go create mode 100644 
vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/error.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/rpc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth/token.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/cbs.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/constants.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh/eh_internal.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/errors.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/connection_string_properties.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/error.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/log_events.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/retry_options.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/websocket_conn_params.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links_recover.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace_eh.go create mode 100644 
vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/rpc.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas/sas.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth/token_provider.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils/retrier.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/log.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/mgmt.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/migrationguide.md create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/partition_client.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_load_balancer.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_partition_client.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/producer_client.go create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/sample.env create mode 100644 vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/test-resources.bicep create mode 100644 vendor/github.com/Azure/go-amqp/.gitattributes create mode 100644 vendor/github.com/Azure/go-amqp/.gitignore create mode 100644 vendor/github.com/Azure/go-amqp/CHANGELOG.md create mode 100644 vendor/github.com/Azure/go-amqp/CODE_OF_CONDUCT.md create mode 100644 vendor/github.com/Azure/go-amqp/CONTRIBUTING.md create mode 100644 vendor/github.com/Azure/go-amqp/LICENSE create mode 100644 vendor/github.com/Azure/go-amqp/Makefile create mode 100644 vendor/github.com/Azure/go-amqp/NOTICE.txt create mode 100644 
vendor/github.com/Azure/go-amqp/README.md create mode 100644 vendor/github.com/Azure/go-amqp/SECURITY.md create mode 100644 vendor/github.com/Azure/go-amqp/azure-pipelines.yml create mode 100644 vendor/github.com/Azure/go-amqp/conn.go create mode 100644 vendor/github.com/Azure/go-amqp/const.go create mode 100644 vendor/github.com/Azure/go-amqp/creditor.go create mode 100644 vendor/github.com/Azure/go-amqp/doc.go create mode 100644 vendor/github.com/Azure/go-amqp/errors.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/bitmap/bitmap.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/buffer/buffer.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/debug/debug.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/debug/debug_debug.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/encoding/decode.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/encoding/encode.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/encoding/types.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/frames/frames.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/frames/parsing.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/queue/queue.go create mode 100644 vendor/github.com/Azure/go-amqp/internal/shared/shared.go create mode 100644 vendor/github.com/Azure/go-amqp/link.go create mode 100644 vendor/github.com/Azure/go-amqp/link_options.go create mode 100644 vendor/github.com/Azure/go-amqp/message.go create mode 100644 vendor/github.com/Azure/go-amqp/receiver.go create mode 100644 vendor/github.com/Azure/go-amqp/sasl.go create mode 100644 vendor/github.com/Azure/go-amqp/sender.go create mode 100644 vendor/github.com/Azure/go-amqp/session.go diff --git a/Makefile b/Makefile index 1675eb6272d..03505c07d2a 100644 --- a/Makefile +++ b/Makefile @@ -99,6 +99,9 @@ MANAGER_IMAGE_PATCH := "apiVersion: apps/v1\ \n - --default-create-vap-for-templates=${GENERATE_VAP}\ \n - 
--default-create-vap-binding-for-constraints=${GENERATE_VAPBINDING}\ \n - --experimental-enable-k8s-native-validation\ +\n - --enable-pub-sub=${ENABLE_PUBSUB}\ +\n - --audit-connection=${AUDIT_CONNECTION}\ +\n - --audit-channel=${AUDIT_CHANNEL}\ \n" # Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set) diff --git a/go.mod b/go.mod index 96bce7b53f8..1e4728d2026 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,10 @@ require ( cloud.google.com/go/compute/metadata v0.5.0 // indirect cloud.google.com/go/monitoring v1.20.1 // indirect github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1 // indirect + github.com/Azure/go-amqp v1.0.5 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.20.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.44.0 // indirect github.com/Microsoft/hcsshim v0.11.4 // indirect diff --git a/go.sum b/go.sum index a8fa188ea07..b4be5344c31 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,14 @@ cloud.google.com/go/trace v1.10.11 h1:+Y1emOgcyGy6OdJ2KQbT4t2oecPp49GtJn8j3GM1pW cloud.google.com/go/trace v1.10.11/go.mod h1:fUr5L3wSXerNfT0f1bBg08W4axS2VbHGgYcfH4KuTXU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 h1:rTfKOCZGy5ViVrlA74ZPE99a+SgoEE2K/yg3RyW9dFA= 
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1 h1:0f6XnzroY1yCQQwxGf/n/2xlaBF02Qhof2as99dGNsY= +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1/go.mod h1:vMGz6NOUGJ9h5ONl2kkyaqq5E0g7s4CHNSrXN5fl8UY= +github.com/Azure/go-amqp v1.0.5 h1:po5+ljlcNSU8xtapHTe8gIc8yHxCzC03E8afH2g1ftU= +github.com/Azure/go-amqp v1.0.5/go.mod h1:vZAogwdrkbyK3Mla8m/CxSc/aKdnTZ4IbPxl51Y5WZE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.20.0 h1:tk85AYGwOf6VNtoOQi8w/kVDi2vmPxp3/OU2FsUpdcA= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.20.0/go.mod h1:Xx0VKh7GJ4si3rmElbh19Mejxz68ibWg/J30ZOMrqzU= diff --git a/pkg/controller/pubsub/pubsub_config_controller.go b/pkg/controller/pubsub/pubsub_config_controller.go index e70ac499587..f52d8e2c10c 100644 --- a/pkg/controller/pubsub/pubsub_config_controller.go +++ b/pkg/controller/pubsub/pubsub_config_controller.go @@ -2,7 +2,6 @@ package pubsub import ( "context" - "encoding/json" "flag" "fmt" @@ -118,13 +117,8 @@ func (r *Reconciler) Reconcile(ctx context.Context, request reconcile.Request) ( if _, ok := cfg.Data["provider"]; !ok { return reconcile.Result{}, fmt.Errorf(fmt.Sprintf("missing provider field in configmap %s, unable to configure respective pubsub", request.NamespacedName)) } - var config interface{} - err = json.Unmarshal([]byte(cfg.Data["config"]), &config) - if err != nil { - return reconcile.Result{}, err - } - err = r.system.UpsertConnection(ctx, config, request.Name, cfg.Data["provider"]) + err = r.system.UpsertConnection(ctx, cfg.Data["config"], request.Name, cfg.Data["provider"]) if err != nil { return reconcile.Result{}, err } diff --git a/pkg/pubsub/dapr/dapr.go b/pkg/pubsub/dapr/dapr.go index 0db60445494..fe0d04d3187 100644 --- 
a/pkg/pubsub/dapr/dapr.go +++ b/pkg/pubsub/dapr/dapr.go @@ -9,18 +9,13 @@ import ( "github.com/open-policy-agent/gatekeeper/v3/pkg/pubsub/connection" ) -type ClientConfig struct { - // Name of the component to be used for pub sub messaging - Component string `json:"component"` -} - // Dapr represents driver for interacting with pub sub using dapr. type Dapr struct { // Array of clients to talk to different endpoints client daprClient.Client // Name of the pubsub component - pubSubComponent string + Component string `json:"component"` } const ( @@ -33,7 +28,7 @@ func (r *Dapr) Publish(_ context.Context, data interface{}, topic string) error return fmt.Errorf("error marshaling data: %w", err) } - err = r.client.PublishEvent(context.Background(), r.pubSubComponent, topic, jsonData) + err = r.client.PublishEvent(context.Background(), r.Component, topic, jsonData) if err != nil { return fmt.Errorf("error publishing message to dapr: %w", err) } @@ -46,38 +41,35 @@ func (r *Dapr) CloseConnection() error { } func (r *Dapr) UpdateConnection(_ context.Context, config interface{}) error { - var cfg ClientConfig - m, ok := config.(map[string]interface{}) + dClient := &Dapr{} + cfg, ok := config.(string) if !ok { return fmt.Errorf("invalid type assertion, config is not in expected format") } - cfg.Component, ok = m["component"].(string) - if !ok { - return fmt.Errorf("failed to get value of component") + err := json.Unmarshal([]byte(cfg), &dClient) + if err != nil { + return err } - r.pubSubComponent = cfg.Component + r.Component = dClient.Component return nil } // Returns a new client for dapr. 
func NewConnection(_ context.Context, config interface{}) (connection.Connection, error) { - var cfg ClientConfig - m, ok := config.(map[string]interface{}) + dClient := &Dapr{} + cfg, ok := config.(string) if !ok { return nil, fmt.Errorf("invalid type assertion, config is not in expected format") } - cfg.Component, ok = m["component"].(string) - if !ok { - return nil, fmt.Errorf("failed to get value of component") + err := json.Unmarshal([]byte(cfg), &dClient) + if err != nil { + return nil, err } - tmp, err := daprClient.NewClient() + dClient.client, err = daprClient.NewClient() if err != nil { return nil, err } - return &Dapr{ - client: tmp, - pubSubComponent: cfg.Component, - }, nil + return dClient, nil } diff --git a/pkg/pubsub/dapr/dapr_test.go b/pkg/pubsub/dapr/dapr_test.go index 5a2e72615b1..94b6313f6e9 100644 --- a/pkg/pubsub/dapr/dapr_test.go +++ b/pkg/pubsub/dapr/dapr_test.go @@ -144,7 +144,7 @@ func TestDapr_UpdateConnection(t *testing.T) { assert.True(t, ok) tmp, ok := r.(*Dapr) assert.True(t, ok) - assert.Equal(t, cmp, tmp.pubSubComponent) + assert.Equal(t, cmp, tmp.Component) } }) } diff --git a/pkg/pubsub/dapr/fake_dapr_client.go b/pkg/pubsub/dapr/fake_dapr_client.go index 4bd36da5ecd..0939faceff2 100644 --- a/pkg/pubsub/dapr/fake_dapr_client.go +++ b/pkg/pubsub/dapr/fake_dapr_client.go @@ -331,7 +331,7 @@ func FakeConnection() (connection.Connection, func()) { c, f := getTestClient(ctx) return &Dapr{ client: c, - pubSubComponent: "test", + Component: "test", }, f } @@ -340,7 +340,7 @@ type FakeDapr struct { client daprClient.Client // Name of the pubsub component - pubSubComponent string + Component string `json:"component"` // closing function f func() @@ -356,36 +356,32 @@ func (r *FakeDapr) CloseConnection() error { } func (r *FakeDapr) UpdateConnection(_ context.Context, config interface{}) error { - var cfg ClientConfig - m, ok := config.(map[string]interface{}) + fClient := &FakeDapr{} + cfg, ok := config.(string) if !ok { return 
fmt.Errorf("invalid type assertion, config is not in expected format")
 	}
-	cfg.Component, ok = m["component"].(string)
-	if !ok {
-		return fmt.Errorf("failed to get value of component")
+	err := json.Unmarshal([]byte(cfg), &fClient)
+	if err != nil {
+		return err
 	}
 
-	r.pubSubComponent = cfg.Component
+	r.Component = fClient.Component
 	return nil
 }
 
 // Returns a fake client for dapr.
 func FakeNewConnection(ctx context.Context, config interface{}) (connection.Connection, error) {
-	var cfg ClientConfig
-	m, ok := config.(map[string]interface{})
+	fClient := &FakeDapr{}
+	cfg, ok := config.(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid type assertion, config is not in expected format")
 	}
-	cfg.Component, ok = m["component"].(string)
-	if !ok {
-		return nil, fmt.Errorf("failed to get value of component")
+	err := json.Unmarshal([]byte(cfg), &fClient)
+	if err != nil {
+		return nil, err
 	}
 
-	c, f := getTestClient(ctx)
+	fClient.client, fClient.f = getTestClient(ctx)
 
-	return &FakeDapr{
-		client:          c,
-		pubSubComponent: cfg.Component,
-		f:               f,
-	}, nil
+	return fClient, nil
 }
diff --git a/pkg/pubsub/eventhub/eventhub.go b/pkg/pubsub/eventhub/eventhub.go
new file mode 100644
index 00000000000..701fad26186
--- /dev/null
+++ b/pkg/pubsub/eventhub/eventhub.go
@@ -0,0 +1,91 @@
+package eventhub
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs"
+	"github.com/open-policy-agent/gatekeeper/v3/pkg/pubsub/connection"
+)
+
+
+// EventHub represents a driver for interacting with pub sub using Azure Event Hub.
+type EventHub struct {
+	// Producer client used to publish events to the configured Event Hub
+	producerClient *azeventhubs.ProducerClient
+
+	// Connection details for the target Azure Event Hub
+	ConnectionString string `json:"connectionString"`
+	EventHubName     string `json:"eventHubName"`
+}
+
+const (
+	Name = "eventhub"
+)
+
+func (r *EventHub) Publish(ctx context.Context, data interface{}, topic string) error {
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		return fmt.Errorf("error marshaling data: %w", err)
+	}
+
+	batch, err := r.producerClient.NewEventDataBatch(ctx, &azeventhubs.EventDataBatchOptions{})
+	if err != nil {
+		return fmt.Errorf("error creating event data batch: %w", err)
+	}
+
+	err = batch.AddEventData(&azeventhubs.EventData{Body: jsonData}, nil)
+	if err != nil {
+		return fmt.Errorf("error adding event data to batch: %w", err)
+	}
+
+	err = r.producerClient.SendEventDataBatch(ctx, batch, nil)
+	if err != nil {
+		return fmt.Errorf("error publishing message to event hub: %w", err)
+	}
+
+	return nil
+}
+
+func (r *EventHub) CloseConnection() error {
+	return nil
+}
+
+func (r *EventHub) UpdateConnection(_ context.Context, config interface{}) error {
+	cfg, ok := config.(string)
+	if !ok {
+		return fmt.Errorf("invalid type assertion, config is not in expected format")
+	}
+
+	err := json.Unmarshal([]byte(cfg), &r)
+	if err != nil {
+		return err
+	}
+
+	r.producerClient, err = azeventhubs.NewProducerClientFromConnectionString(r.ConnectionString, r.EventHubName, nil)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// NewConnection returns a new client for Azure Event Hub.
+func NewConnection(_ context.Context, config interface{}) (connection.Connection, error) { + cfg, ok := config.(string) + if !ok { + return nil, fmt.Errorf("invalid type assertion, config is not in expected format") + } + client := &EventHub{} + err := json.Unmarshal([]byte(cfg), &client) + if err != nil { + return nil, err + } + + client.producerClient, err = azeventhubs.NewProducerClientFromConnectionString(client.ConnectionString, client.EventHubName, nil) + if err != nil { + return nil, err + } + + return client, nil +} diff --git a/pkg/pubsub/provider/provider.go b/pkg/pubsub/provider/provider.go index 5e1d0601014..e22cc0e49cd 100644 --- a/pkg/pubsub/provider/provider.go +++ b/pkg/pubsub/provider/provider.go @@ -5,10 +5,12 @@ import ( "github.com/open-policy-agent/gatekeeper/v3/pkg/pubsub/connection" "github.com/open-policy-agent/gatekeeper/v3/pkg/pubsub/dapr" + "github.com/open-policy-agent/gatekeeper/v3/pkg/pubsub/eventhub" ) var pubSubs = newPubSubSet(map[string]InitiateConnection{ dapr.Name: dapr.NewConnection, + eventhub.Name: eventhub.NewConnection, }, ) diff --git a/pkg/pubsub/system.go b/pkg/pubsub/system.go index da60a1be8e6..d891c2adaff 100644 --- a/pkg/pubsub/system.go +++ b/pkg/pubsub/system.go @@ -19,19 +19,19 @@ func NewSystem() *System { return &System{} } -func (s *System) Publish(_ context.Context, connection string, topic string, msg interface{}) error { +func (s *System) Publish(ctx context.Context, connection string, topic string, msg interface{}) error { s.mux.RLock() defer s.mux.RUnlock() if len(s.connections) > 0 { if c, ok := s.connections[connection]; ok { - return c.Publish(context.Background(), msg, topic) + return c.Publish(ctx, msg, topic) } return fmt.Errorf("connection is not initialized, name: %s ", connection) } return fmt.Errorf("No connections are established") } -func (s *System) UpsertConnection(ctx context.Context, config interface{}, name string, provider string) error { +func (s *System) UpsertConnection(ctx 
context.Context, config string, name string, provider string) error { s.mux.Lock() defer s.mux.Unlock() // Check if the connection already exists. diff --git a/test/pubsub/publish-components.yaml b/test/pubsub/publish-components.yaml index 9686935dd01..6e8e054b1c0 100644 --- a/test/pubsub/publish-components.yaml +++ b/test/pubsub/publish-components.yaml @@ -1,27 +1,27 @@ ---- -apiVersion: dapr.io/v1alpha1 -kind: Component -metadata: - name: pubsub - namespace: gatekeeper-system -spec: - type: pubsub.redis - version: v1 - metadata: - - name: redisHost - value: redis-master.default.svc.cluster.local:6379 - - name: redisPassword - secretKeyRef: - name: redis - key: redis-password +# --- +# apiVersion: dapr.io/v1alpha1 +# kind: Component +# metadata: +# name: pubsub +# namespace: gatekeeper-system +# spec: +# type: pubsub.redis +# version: v1 +# metadata: +# - name: redisHost +# value: redis-master.default.svc.cluster.local:6379 +# - name: redisPassword +# secretKeyRef: +# name: redis +# key: redis-password --- apiVersion: v1 kind: ConfigMap metadata: - name: audit + name: audit-connection namespace: gatekeeper-system data: - provider: "dapr" + provider: "eventhub" config: | { "component": "pubsub" diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/CHANGELOG.md b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/CHANGELOG.md new file mode 100644 index 00000000000..a6675492b1a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/CHANGELOG.md @@ -0,0 +1,786 @@ +# Release History + +## 1.11.1 (2024-04-02) + +### Bugs Fixed + +* Pollers that use the `Location` header won't consider `http.StatusRequestTimeout` a terminal failure. +* `runtime.Poller[T].Result` won't consider non-terminal error responses as terminal. + +## 1.11.0 (2024-04-01) + +### Features Added + +* Added `StatusCodes` to `arm/policy.RegistrationOptions` to allow supporting non-standard HTTP status codes during registration. 
+* Added field `InsecureAllowCredentialWithHTTP` to `azcore.ClientOptions` and dependent authentication pipeline policies. +* Added type `MultipartContent` to the `streaming` package to support multipart/form payloads with custom Content-Type and file name. + +### Bugs Fixed + +* `runtime.SetMultipartFormData` won't try to stringify `[]byte` values. +* Pollers that use the `Location` header won't consider `http.StatusTooManyRequests` a terminal failure. + +### Other Changes + +* Update dependencies. + +## 1.10.0 (2024-02-29) + +### Features Added + +* Added logging event `log.EventResponseError` that will contain the contents of `ResponseError.Error()` whenever an `azcore.ResponseError` is created. +* Added `runtime.NewResponseErrorWithErrorCode` for creating an `azcore.ResponseError` with a caller-supplied error code. +* Added type `MatchConditions` for use in conditional requests. + +### Bugs Fixed + +* Fixed a potential race condition between `NullValue` and `IsNullValue`. +* `runtime.EncodeQueryParams` will escape semicolons before calling `url.ParseQuery`. + +### Other Changes + +* Update dependencies. + +## 1.9.2 (2024-02-06) + +### Bugs Fixed + +* `runtime.MarshalAsByteArray` and `runtime.MarshalAsJSON` will preserve the preexisting value of the `Content-Type` header. + +### Other Changes + +* Update to latest version of `internal`. + +## 1.9.1 (2023-12-11) + +### Bugs Fixed + +* The `retry-after-ms` and `x-ms-retry-after-ms` headers weren't being checked during retries. + +### Other Changes + +* Update dependencies. + +## 1.9.0 (2023-11-06) + +### Breaking Changes +> These changes affect only code written against previous beta versions of `v1.7.0` and `v1.8.0` +* The function `NewTokenCredential` has been removed from the `fake` package. Use a literal `&fake.TokenCredential{}` instead. +* The field `TracingNamespace` in `runtime.PipelineOptions` has been replaced by `TracingOptions`. 
+ +### Bugs Fixed + +* Fixed an issue that could cause some allowed HTTP header values to not show up in logs. +* Include error text instead of error type in traces when the transport returns an error. +* Fixed an issue that could cause an HTTP/2 request to hang when the TCP connection becomes unresponsive. +* Block key and SAS authentication for non TLS protected endpoints. +* Passing a `nil` credential value will no longer cause a panic. Instead, the authentication is skipped. +* Calling `Error` on a zero-value `azcore.ResponseError` will no longer panic. +* Fixed an issue in `fake.PagerResponder[T]` that would cause a trailing error to be omitted when iterating over pages. +* Context values created by `azcore` will no longer flow across disjoint HTTP requests. + +### Other Changes + +* Skip generating trace info for no-op tracers. +* The `clientName` paramater in client constructors has been renamed to `moduleName`. + +## 1.9.0-beta.1 (2023-10-05) + +### Other Changes + +* The beta features for tracing and fakes have been reinstated. + +## 1.8.0 (2023-10-05) + +### Features Added + +* This includes the following features from `v1.8.0-beta.N` releases. + * Claims and CAE for authentication. + * New `messaging` package. + * Various helpers in the `runtime` package. + * Deprecation of `runtime.With*` funcs and their replacements in the `policy` package. +* Added types `KeyCredential` and `SASCredential` to the `azcore` package. + * Includes their respective constructor functions. +* Added types `KeyCredentialPolicy` and `SASCredentialPolicy` to the `azcore/runtime` package. + * Includes their respective constructor functions and options types. + +### Breaking Changes +> These changes affect only code written against beta versions of `v1.8.0` +* The beta features for tracing and fakes have been omitted for this release. + +### Bugs Fixed + +* Fixed an issue that could cause some ARM RPs to not be automatically registered. 
+* Block bearer token authentication for non TLS protected endpoints. + +### Other Changes + +* Updated dependencies. + +## 1.8.0-beta.3 (2023-09-07) + +### Features Added + +* Added function `FetcherForNextLink` and `FetcherForNextLinkOptions` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL. + +### Bugs Fixed + +* Suppress creating spans for nested SDK API calls. The HTTP span will be a child of the outer API span. + +### Other Changes + +* The following functions in the `runtime` package are now exposed from the `policy` package, and the `runtime` versions have been deprecated. + * `WithCaptureResponse` + * `WithHTTPHeader` + * `WithRetryOptions` + +## 1.7.2 (2023-09-06) + +### Bugs Fixed + +* Fix default HTTP transport to work in WASM modules. + +## 1.8.0-beta.2 (2023-08-14) + +### Features Added + +* Added function `SanitizePagerPollerPath` to the `server` package to centralize sanitization and formalize the contract. +* Added `TokenRequestOptions.EnableCAE` to indicate whether to request a CAE token. + +### Breaking Changes + +> This change affects only code written against beta version `v1.8.0-beta.1`. +* `messaging.CloudEvent` deserializes JSON objects as `[]byte`, instead of `json.RawMessage`. See the documentation for CloudEvent.Data for more information. + +> This change affects only code written against beta versions `v1.7.0-beta.2` and `v1.8.0-beta.1`. +* Removed parameter from method `Span.End()` and its type `tracing.SpanEndOptions`. This API GA'ed in `v1.2.0` so we cannot change it. + +### Bugs Fixed + +* Propagate any query parameters when constructing a fake poller and/or injecting next links. + +## 1.7.1 (2023-08-14) + +## Bugs Fixed + +* Enable TLS renegotiation in the default transport policy. 
+ +## 1.8.0-beta.1 (2023-07-12) + +### Features Added + +- `messaging/CloudEvent` allows you to serialize/deserialize CloudEvents, as described in the CloudEvents 1.0 specification: [link](https://github.com/cloudevents/spec) + +### Other Changes + +* The beta features for CAE, tracing, and fakes have been reinstated. + +## 1.7.0 (2023-07-12) + +### Features Added +* Added method `WithClientName()` to type `azcore.Client` to support shallow cloning of a client with a new name used for tracing. + +### Breaking Changes +> These changes affect only code written against beta versions v1.7.0-beta.1 or v1.7.0-beta.2 +* The beta features for CAE, tracing, and fakes have been omitted for this release. + +## 1.7.0-beta.2 (2023-06-06) + +### Breaking Changes +> These changes affect only code written against beta version v1.7.0-beta.1 +* Method `SpanFromContext()` on type `tracing.Tracer` had the `bool` return value removed. + * This includes the field `SpanFromContext` in supporting type `tracing.TracerOptions`. +* Method `AddError()` has been removed from type `tracing.Span`. +* Method `Span.End()` now requires an argument of type `*tracing.SpanEndOptions`. + +## 1.6.1 (2023-06-06) + +### Bugs Fixed +* Fixed an issue in `azcore.NewClient()` and `arm.NewClient()` that could cause an incorrect module name to be used in telemetry. + +### Other Changes +* This version contains all bug fixes from `v1.7.0-beta.1` + +## 1.7.0-beta.1 (2023-05-24) + +### Features Added +* Restored CAE support for ARM clients. +* Added supporting features to enable distributed tracing. + * Added func `runtime.StartSpan()` for use by SDKs to start spans. + * Added method `WithContext()` to `runtime.Request` to support shallow cloning with a new context. + * Added field `TracingNamespace` to `runtime.PipelineOptions`. + * Added field `Tracer` to `runtime.NewPollerOptions` and `runtime.NewPollerFromResumeTokenOptions` types. + * Added field `SpanFromContext` to `tracing.TracerOptions`. 
+ * Added methods `Enabled()`, `SetAttributes()`, and `SpanFromContext()` to `tracing.Tracer`. + * Added supporting pipeline policies to include HTTP spans when creating clients. +* Added package `fake` to support generated fakes packages in SDKs. + * The package contains public surface area exposed by fake servers and supporting APIs intended only for use by the fake server implementations. + * Added an internal fake poller implementation. + +### Bugs Fixed +* Retry policy always clones the underlying `*http.Request` before invoking the next policy. +* Added some non-standard error codes to the list of error codes for unregistered resource providers. + +## 1.6.0 (2023-05-04) + +### Features Added +* Added support for ARM cross-tenant authentication. Set the `AuxiliaryTenants` field of `arm.ClientOptions` to enable. +* Added `TenantID` field to `policy.TokenRequestOptions`. + +## 1.5.0 (2023-04-06) + +### Features Added +* Added `ShouldRetry` to `policy.RetryOptions` for finer-grained control over when to retry. + +### Breaking Changes +> These changes affect only code written against a beta version such as v1.5.0-beta.1 +> These features will return in v1.6.0-beta.1. +* Removed `TokenRequestOptions.Claims` and `.TenantID` +* Removed ARM client support for CAE and cross-tenant auth. + +### Bugs Fixed +* Added non-conformant LRO terminal states `Cancelled` and `Completed`. + +### Other Changes +* Updated to latest `internal` module. + +## 1.5.0-beta.1 (2023-03-02) + +### Features Added +* This release includes the features added in v1.4.0-beta.1 + +## 1.4.0 (2023-03-02) +> This release doesn't include features added in v1.4.0-beta.1. They will return in v1.5.0-beta.1. + +### Features Added +* Add `Clone()` method for `arm/policy.ClientOptions`. + +### Bugs Fixed +* ARM's RP registration policy will no longer swallow unrecognized errors. +* Fixed an issue in `runtime.NewPollerFromResumeToken()` when resuming a `Poller` with a custom `PollingHandler`. 
+* Fixed wrong policy copy in `arm/runtime.NewPipeline()`. + +## 1.4.0-beta.1 (2023-02-02) + +### Features Added +* Added support for ARM cross-tenant authentication. Set the `AuxiliaryTenants` field of `arm.ClientOptions` to enable. +* Added `Claims` and `TenantID` fields to `policy.TokenRequestOptions`. +* ARM bearer token policy handles CAE challenges. + +## 1.3.1 (2023-02-02) + +### Other Changes +* Update dependencies to latest versions. + +## 1.3.0 (2023-01-06) + +### Features Added +* Added `BearerTokenOptions.AuthorizationHandler` to enable extending `runtime.BearerTokenPolicy` + with custom authorization logic +* Added `Client` types and matching constructors to the `azcore` and `arm` packages. These represent a basic client for HTTP and ARM respectively. + +### Other Changes +* Updated `internal` module to latest version. +* `policy/Request.SetBody()` allows replacing a request's body with an empty one + +## 1.2.0 (2022-11-04) + +### Features Added +* Added `ClientOptions.APIVersion` field, which overrides the default version a client + requests of the service, if the client supports this (all ARM clients do). +* Added package `tracing` that contains the building blocks for distributed tracing. +* Added field `TracingProvider` to type `policy.ClientOptions` that will be used to set the per-client tracing implementation. + +### Bugs Fixed +* Fixed an issue in `runtime.SetMultipartFormData` to properly handle slices of `io.ReadSeekCloser`. +* Fixed the MaxRetryDelay default to be 60s. +* Failure to poll the state of an LRO will now return an `*azcore.ResponseError` for poller types that require this behavior. +* Fixed a bug in `runtime.NewPipeline` that would cause pipeline-specified allowed headers and query parameters to be lost. + +### Other Changes +* Retain contents of read-only fields when sending requests. 
+ +## 1.1.4 (2022-10-06) + +### Bugs Fixed +* Don't retry a request if the `Retry-After` delay is greater than the configured `RetryOptions.MaxRetryDelay`. +* `runtime.JoinPaths`: do not unconditionally add a forward slash before the query string + +### Other Changes +* Removed logging URL from retry policy as it's redundant. +* Retry policy logs when it exits due to a non-retriable status code. + +## 1.1.3 (2022-09-01) + +### Bugs Fixed +* Adjusted the initial retry delay to 800ms per the Azure SDK guidelines. + +## 1.1.2 (2022-08-09) + +### Other Changes +* Fixed various doc bugs. + +## 1.1.1 (2022-06-30) + +### Bugs Fixed +* Avoid polling when a RELO LRO synchronously terminates. + +## 1.1.0 (2022-06-03) + +### Other Changes +* The one-second floor for `Frequency` when calling `PollUntilDone()` has been removed when running tests. + +## 1.0.0 (2022-05-12) + +### Features Added +* Added interface `runtime.PollingHandler` to support custom poller implementations. + * Added field `PollingHandler` of this type to `runtime.NewPollerOptions[T]` and `runtime.NewPollerFromResumeTokenOptions[T]`. + +### Breaking Changes +* Renamed `cloud.Configuration.LoginEndpoint` to `.ActiveDirectoryAuthorityHost` +* Renamed `cloud.AzurePublicCloud` to `cloud.AzurePublic` +* Removed `AuxiliaryTenants` field from `arm/ClientOptions` and `arm/policy/BearerTokenOptions` +* Removed `TokenRequestOptions.TenantID` +* `Poller[T].PollUntilDone()` now takes an `options *PollUntilDoneOptions` param instead of `freq time.Duration` +* Removed `arm/runtime.Poller[T]`, `arm/runtime.NewPoller[T]()` and `arm/runtime.NewPollerFromResumeToken[T]()` +* Removed `arm/runtime.FinalStateVia` and related `const` values +* Renamed `runtime.PageProcessor` to `runtime.PagingHandler` +* The `arm/runtime.ProviderRepsonse` and `arm/runtime.Provider` types are no longer exported. +* Renamed `NewRequestIdPolicy()` to `NewRequestIDPolicy()` +* `TokenCredential.GetToken` now returns `AccessToken` by value. 
+ +### Bugs Fixed +* When per-try timeouts are enabled, only cancel the context after the body has been read and closed. +* The `Operation-Location` poller now properly handles `final-state-via` values. +* Improvements in `runtime.Poller[T]` + * `Poll()` shouldn't cache errors, allowing for additional retries when in a non-terminal state. + * `Result()` will cache the terminal result or error but not transient errors, allowing for additional retries. + +### Other Changes +* Updated to latest `internal` module and absorbed breaking changes. + * Use `temporal.Resource` and deleted copy. +* The internal poller implementation has been refactored. + * The implementation in `internal/pollers/poller.go` has been merged into `runtime/poller.go` with some slight modification. + * The internal poller types had their methods updated to conform to the `runtime.PollingHandler` interface. + * The creation of resume tokens has been refactored so that implementers of `runtime.PollingHandler` don't need to know about it. +* `NewPipeline()` places policies from `ClientOptions` after policies from `PipelineOptions` +* Default User-Agent headers no longer include `azcore` version information + +## 0.23.1 (2022-04-14) + +### Bugs Fixed +* Include XML header when marshalling XML content. +* Handle XML namespaces when searching for error code. +* Handle `odata.error` when searching for error code. + +## 0.23.0 (2022-04-04) + +### Features Added +* Added `runtime.Pager[T any]` and `runtime.Poller[T any]` supporting types for central, generic, implementations. +* Added `cloud` package with a new API for cloud configuration +* Added `FinalStateVia` field to `runtime.NewPollerOptions[T any]` type. + +### Breaking Changes +* Removed the `Poller` type-alias to the internal poller implementation. +* Added `Ptr[T any]` and `SliceOfPtrs[T any]` in the `to` package and removed all non-generic implementations. 
+* `NullValue` and `IsNullValue` now take a generic type parameter instead of an interface func parameter. +* Replaced `arm.Endpoint` with `cloud` API + * Removed the `endpoint` parameter from `NewRPRegistrationPolicy()` + * `arm/runtime.NewPipeline()` and `.NewRPRegistrationPolicy()` now return an `error` +* Refactored `NewPoller` and `NewPollerFromResumeToken` funcs in `arm/runtime` and `runtime` packages. + * Removed the `pollerID` parameter as it's no longer required. + * Created optional parameter structs and moved optional parameters into them. +* Changed `FinalStateVia` field to a `const` type. + +### Other Changes +* Converted expiring resource and dependent types to use generics. + +## 0.22.0 (2022-03-03) + +### Features Added +* Added header `WWW-Authenticate` to the default allow-list of headers for logging. +* Added a pipeline policy that enables the retrieval of HTTP responses from API calls. + * Added `runtime.WithCaptureResponse` to enable the policy at the API level (off by default). + +### Breaking Changes +* Moved `WithHTTPHeader` and `WithRetryOptions` from the `policy` package to the `runtime` package. + +## 0.21.1 (2022-02-04) + +### Bugs Fixed +* Restore response body after reading in `Poller.FinalResponse()`. (#16911) +* Fixed bug in `NullValue` that could lead to incorrect comparisons for empty maps/slices (#16969) + +### Other Changes +* `BearerTokenPolicy` is more resilient to transient authentication failures. (#16789) + +## 0.21.0 (2022-01-11) + +### Features Added +* Added `AllowedHeaders` and `AllowedQueryParams` to `policy.LogOptions` to control which headers and query parameters are written to the logger. +* Added `azcore.ResponseError` type which is returned from APIs when a non-success HTTP status code is received. 
+ +### Breaking Changes +* Moved `[]policy.Policy` parameters of `arm/runtime.NewPipeline` and `runtime.NewPipeline` into a new struct, `runtime.PipelineOptions` +* Renamed `arm/ClientOptions.Host` to `.Endpoint` +* Moved `Request.SkipBodyDownload` method to function `runtime.SkipBodyDownload` +* Removed `azcore.HTTPResponse` interface type +* `arm.NewPoller()` and `runtime.NewPoller()` no longer require an `eu` parameter +* `runtime.NewResponseError()` no longer requires an `error` parameter + +## 0.20.0 (2021-10-22) + +### Breaking Changes +* Removed `arm.Connection` +* Removed `azcore.Credential` and `.NewAnonymousCredential()` + * `NewRPRegistrationPolicy` now requires an `azcore.TokenCredential` +* `runtime.NewPipeline` has a new signature that simplifies implementing custom authentication +* `arm/runtime.RegistrationOptions` embeds `policy.ClientOptions` +* Contents in the `log` package have been slightly renamed. +* Removed `AuthenticationOptions` in favor of `policy.BearerTokenOptions` +* Changed parameters for `NewBearerTokenPolicy()` +* Moved policy config options out of `arm/runtime` and into `arm/policy` + +### Features Added +* Updating Documentation +* Added string typdef `arm.Endpoint` to provide a hint toward expected ARM client endpoints +* `azcore.ClientOptions` contains common pipeline configuration settings +* Added support for multi-tenant authorization in `arm/runtime` +* Require one second minimum when calling `PollUntilDone()` + +### Bug Fixes +* Fixed a potential panic when creating the default Transporter. +* Close LRO initial response body when creating a poller. +* Fixed a panic when recursively cloning structs that contain time.Time. + +## 0.19.0 (2021-08-25) + +### Breaking Changes +* Split content out of `azcore` into various packages. The intent is to separate content based on its usage (common, uncommon, SDK authors). + * `azcore` has all core functionality. + * `log` contains facilities for configuring in-box logging. 
+ * `policy` is used for configuring pipeline options and creating custom pipeline policies. + * `runtime` contains various helpers used by SDK authors and generated content. + * `streaming` has helpers for streaming IO operations. +* `NewTelemetryPolicy()` now requires module and version parameters and the `Value` option has been removed. + * As a result, the `Request.Telemetry()` method has been removed. +* The telemetry policy now includes the SDK prefix `azsdk-go-` so callers no longer need to provide it. +* The `*http.Request` in `runtime.Request` is no longer anonymously embedded. Use the `Raw()` method to access it. +* The `UserAgent` and `Version` constants have been made internal, `Module` and `Version` respectively. + +### Bug Fixes +* Fixed an issue in the retry policy where the request body could be overwritten after a rewind. + +### Other Changes +* Moved modules `armcore` and `to` content into `arm` and `to` packages respectively. + * The `Pipeline()` method on `armcore.Connection` has been replaced by `NewPipeline()` in `arm.Connection`. It takes module and version parameters used by the telemetry policy. +* Poller logic has been consolidated across ARM and core implementations. + * This required some changes to the internal interfaces for core pollers. +* The core poller types have been improved, including more logging and test coverage. + +## 0.18.1 (2021-08-20) + +### Features Added +* Adds an `ETag` type for comparing etags and handling etags on requests +* Simplifies the `requestBodyProgess` and `responseBodyProgress` into a single `progress` object + +### Bugs Fixed +* `JoinPaths` will preserve query parameters encoded in the `root` url. + +### Other Changes +* Bumps dependency on `internal` module to the latest version (v0.7.0) + +## 0.18.0 (2021-07-29) +### Features Added +* Replaces methods from Logger type with two package methods for interacting with the logging functionality. 
+* `azcore.SetClassifications` replaces `azcore.Logger().SetClassifications` +* `azcore.SetListener` replaces `azcore.Logger().SetListener` + +### Breaking Changes +* Removes `Logger` type from `azcore` + + +## 0.17.0 (2021-07-27) +### Features Added +* Adding TenantID to TokenRequestOptions (https://github.com/Azure/azure-sdk-for-go/pull/14879) +* Adding AuxiliaryTenants to AuthenticationOptions (https://github.com/Azure/azure-sdk-for-go/pull/15123) + +### Breaking Changes +* Rename `AnonymousCredential` to `NewAnonymousCredential` (https://github.com/Azure/azure-sdk-for-go/pull/15104) +* rename `AuthenticationPolicyOptions` to `AuthenticationOptions` (https://github.com/Azure/azure-sdk-for-go/pull/15103) +* Make Header constants private (https://github.com/Azure/azure-sdk-for-go/pull/15038) + + +## 0.16.2 (2021-05-26) +### Features Added +* Improved support for byte arrays [#14715](https://github.com/Azure/azure-sdk-for-go/pull/14715) + + +## 0.16.1 (2021-05-19) +### Features Added +* Add license.txt to azcore module [#14682](https://github.com/Azure/azure-sdk-for-go/pull/14682) + + +## 0.16.0 (2021-05-07) +### Features Added +* Remove extra `*` in UnmarshalAsByteArray() [#14642](https://github.com/Azure/azure-sdk-for-go/pull/14642) + + +## 0.15.1 (2021-05-06) +### Features Added +* Cache the original request body on Request [#14634](https://github.com/Azure/azure-sdk-for-go/pull/14634) + + +## 0.15.0 (2021-05-05) +### Features Added +* Add support for null map and slice +* Export `Response.Payload` method + +### Breaking Changes +* remove `Response.UnmarshalError` as it's no longer required + + +## 0.14.5 (2021-04-23) +### Features Added +* Add `UnmarshalError()` on `azcore.Response` + + +## 0.14.4 (2021-04-22) +### Features Added +* Support for basic LRO polling +* Added type `LROPoller` and supporting types for basic polling on long running operations. 
+* rename poller param and added doc comment + +### Bugs Fixed +* Fixed content type detection bug in logging. + + +## 0.14.3 (2021-03-29) +### Features Added +* Add support for multi-part form data +* Added method `WriteMultipartFormData()` to Request. + + +## 0.14.2 (2021-03-17) +### Features Added +* Add support for encoding JSON null values +* Adds `NullValue()` and `IsNullValue()` functions for setting and detecting sentinel values used for encoding a JSON null. +* Documentation fixes + +### Bugs Fixed +* Fixed improper error wrapping + + +## 0.14.1 (2021-02-08) +### Features Added +* Add `Pager` and `Poller` interfaces to azcore + + +## 0.14.0 (2021-01-12) +### Features Added +* Accept zero-value options for default values +* Specify zero-value options structs to accept default values. +* Remove `DefaultXxxOptions()` methods. +* Do not silently change TryTimeout on negative values +* make per-try timeout opt-in + + +## 0.13.4 (2020-11-20) +### Features Added +* Include telemetry string in User Agent + + +## 0.13.3 (2020-11-20) +### Features Added +* Updating response body handling on `azcore.Response` + + +## 0.13.2 (2020-11-13) +### Features Added +* Remove implementation of stateless policies as first-class functions. + + +## 0.13.1 (2020-11-05) +### Features Added +* Add `Telemetry()` method to `azcore.Request()` + + +## 0.13.0 (2020-10-14) +### Features Added +* Rename `log` to `logger` to avoid name collision with the log package. +* Documentation improvements +* Simplified `DefaultHTTPClientTransport()` implementation + + +## 0.12.1 (2020-10-13) +### Features Added +* Update `internal` module dependence to `v0.5.0` + + +## 0.12.0 (2020-10-08) +### Features Added +* Removed storage specific content +* Removed internal content to prevent API clutter +* Refactored various policy options to conform with our options pattern + + +## 0.11.0 (2020-09-22) +### Features Added + +* Removed `LogError` and `LogSlowResponse`. 
+* Renamed `options` in `RequestLogOptions`. +* Updated `NewRequestLogPolicy()` to follow standard pattern for options. +* Refactored `requestLogPolicy.Do()` per above changes. +* Cleaned up/added logging in retry policy. +* Export `NewResponseError()` +* Fix `RequestLogOptions` comment + + +## 0.10.1 (2020-09-17) +### Features Added +* Add default console logger +* Default console logger writes to stderr. To enable it, set env var `AZURE_SDK_GO_LOGGING` to the value 'all'. +* Added `Logger.Writef()` to reduce the need for `ShouldLog()` checks. +* Add `LogLongRunningOperation` + + +## 0.10.0 (2020-09-10) +### Features Added +* The `request` and `transport` interfaces have been refactored to align with the patterns in the standard library. +* `NewRequest()` now uses `http.NewRequestWithContext()` and performs additional validation, it also requires a context parameter. +* The `Policy` and `Transport` interfaces have had their context parameter removed as the context is associated with the underlying `http.Request`. +* `Pipeline.Do()` will validate the HTTP request before sending it through the pipeline, avoiding retries on a malformed request. +* The `Retrier` interface has been replaced with the `NonRetriableError` interface, and the retry policy updated to test for this. +* `Request.SetBody()` now requires a content type parameter for setting the request's MIME type. +* moved path concatenation into `JoinPaths()` func + + +## 0.9.6 (2020-08-18) +### Features Added +* Improvements to body download policy +* Always download the response body for error responses, i.e. HTTP status codes >= 400. +* Simplify variable declarations + + +## 0.9.5 (2020-08-11) +### Features Added +* Set the Content-Length header in `Request.SetBody` + + +## 0.9.4 (2020-08-03) +### Features Added +* Fix cancellation of per try timeout +* Per try timeout is used to ensure that an HTTP operation doesn't take too long, e.g. that a GET on some URL doesn't take an inordinant amount of time. 
+* Once the HTTP request returns, the per try timeout should be cancelled, not when the response has been read to completion. +* Do not drain response body if there are no more retries +* Do not retry non-idempotent operations when body download fails + + +## 0.9.3 (2020-07-28) +### Features Added +* Add support for custom HTTP request headers +* Inserts an internal policy into the pipeline that can extract HTTP header values from the caller's context, adding them to the request. +* Use `azcore.WithHTTPHeader` to add HTTP headers to a context. +* Remove method specific to Go 1.14 + + +## 0.9.2 (2020-07-28) +### Features Added +* Omit read-only content from request payloads +* If any field in a payload's object graph contains `azure:"ro"`, make a clone of the object graph, omitting all fields with this annotation. +* Verify no fields were dropped +* Handle embedded struct types +* Added test for cloning by value +* Add messages to failures + + +## 0.9.1 (2020-07-22) +### Features Added +* Updated dependency on internal module to fix race condition. + + +## 0.9.0 (2020-07-09) +### Features Added +* Add `HTTPResponse` interface to be used by callers to access the raw HTTP response from an error in the event of an API call failure. +* Updated `sdk/internal` dependency to latest version. +* Rename package alias + + +## 0.8.2 (2020-06-29) +### Features Added +* Added missing documentation comments + +### Bugs Fixed +* Fixed a bug in body download policy. + + +## 0.8.1 (2020-06-26) +### Features Added +* Miscellaneous clean-up reported by linters + + +## 0.8.0 (2020-06-01) +### Features Added +* Differentiate between standard and URL encoding. + + +## 0.7.1 (2020-05-27) +### Features Added +* Add support for for base64 encoding and decoding of payloads. + + +## 0.7.0 (2020-05-12) +### Features Added +* Change `RetryAfter()` to a function. 
+ + +## 0.6.0 (2020-04-29) +### Features Added +* Updating `RetryAfter` to only return the detaion in the RetryAfter header + + +## 0.5.0 (2020-03-23) +### Features Added +* Export `TransportFunc` + +### Breaking Changes +* Removed `IterationDone` + + +## 0.4.1 (2020-02-25) +### Features Added +* Ensure per-try timeout is properly cancelled +* Explicitly call cancel the per-try timeout when the response body has been read/closed by the body download policy. +* When the response body is returned to the caller for reading/closing, wrap it in a `responseBodyReader` that will cancel the timeout when the body is closed. +* `Logger.Should()` will return false if no listener is set. + + +## 0.4.0 (2020-02-18) +### Features Added +* Enable custom `RetryOptions` to be specified per API call +* Added `WithRetryOptions()` that adds a custom `RetryOptions` to the provided context, allowing custom settings per API call. +* Remove 429 from the list of default HTTP status codes for retry. +* Change StatusCodesForRetry to a slice so consumers can append to it. +* Added support for retry-after in HTTP-date format. +* Cleaned up some comments specific to storage. +* Remove `Request.SetQueryParam()` +* Renamed `MaxTries` to `MaxRetries` + +## 0.3.0 (2020-01-16) +### Features Added +* Added `DefaultRetryOptions` to create initialized default options. 
+ +### Breaking Changes +* Removed `Response.CheckStatusCode()` + + +## 0.2.0 (2020-01-15) +### Features Added +* Add support for marshalling and unmarshalling JSON +* Removed `Response.Payload` field +* Exit early when unmarsahlling if there is no payload + + +## 0.1.0 (2020-01-10) +### Features Added +* Initial release diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/LICENSE.txt b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/LICENSE.txt new file mode 100644 index 00000000000..48ea6616b5b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/README.md b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/README.md new file mode 100644 index 00000000000..35a74e18d09 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/README.md @@ -0,0 +1,39 @@ +# Azure Core Client Module for Go + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/Azure/azure-sdk-for-go/sdk/azcore)](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore) +[![Build Status](https://dev.azure.com/azure-sdk/public/_apis/build/status/go/go%20-%20azcore%20-%20ci?branchName=main)](https://dev.azure.com/azure-sdk/public/_build/latest?definitionId=1843&branchName=main) +[![Code Coverage](https://img.shields.io/azure-devops/coverage/azure-sdk/public/1843/main)](https://img.shields.io/azure-devops/coverage/azure-sdk/public/1843/main) + +The `azcore` module provides a set of common interfaces and types for Go SDK client modules. +These modules follow the [Azure SDK Design Guidelines for Go](https://azure.github.io/azure-sdk/golang_introduction.html). + +## Getting started + +This project uses [Go modules](https://github.com/golang/go/wiki/Modules) for versioning and dependency management. + +Typically, you will not need to explicitly install `azcore` as it will be installed as a client module dependency. +To add the latest version to your `go.mod` file, execute the following command. + +```bash +go get github.com/Azure/azure-sdk-for-go/sdk/azcore +``` + +General documentation and examples can be found on [pkg.go.dev](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore). + +## Contributing +This project welcomes contributions and suggestions. 
Most contributions require +you to agree to a Contributor License Agreement (CLA) declaring that you have +the right to, and actually do, grant us the rights to use your contribution. +For details, visit [https://cla.microsoft.com](https://cla.microsoft.com). + +When you submit a pull request, a CLA-bot will automatically determine whether +you need to provide a CLA and decorate the PR appropriately (e.g., label, +comment). Simply follow the instructions provided by the bot. You will only +need to do this once across all repos using our CLA. + +This project has adopted the +[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information, see the +[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any +additional questions or comments. diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/ci.yml b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/ci.yml new file mode 100644 index 00000000000..99348527b54 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/ci.yml @@ -0,0 +1,29 @@ +# NOTE: Please refer to https://aka.ms/azsdk/engsys/ci-yaml before editing this file. 
+trigger: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/azcore/ + - eng/ + +pr: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/azcore/ + - eng/ + +extends: + template: /eng/pipelines/templates/jobs/archetype-sdk-client.yml + parameters: + ServiceDirectory: azcore diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/cloud.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/cloud.go new file mode 100644 index 00000000000..9d077a3e126 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/cloud.go @@ -0,0 +1,44 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cloud + +var ( + // AzureChina contains configuration for Azure China. + AzureChina = Configuration{ + ActiveDirectoryAuthorityHost: "https://login.chinacloudapi.cn/", Services: map[ServiceName]ServiceConfiguration{}, + } + // AzureGovernment contains configuration for Azure Government. + AzureGovernment = Configuration{ + ActiveDirectoryAuthorityHost: "https://login.microsoftonline.us/", Services: map[ServiceName]ServiceConfiguration{}, + } + // AzurePublic contains configuration for Azure Public Cloud. + AzurePublic = Configuration{ + ActiveDirectoryAuthorityHost: "https://login.microsoftonline.com/", Services: map[ServiceName]ServiceConfiguration{}, + } +) + +// ServiceName identifies a cloud service. +type ServiceName string + +// ResourceManager is a global constant identifying Azure Resource Manager. +const ResourceManager ServiceName = "resourceManager" + +// ServiceConfiguration configures a specific cloud service such as Azure Resource Manager. +type ServiceConfiguration struct { + // Audience is the audience the client will request for its access tokens. + Audience string + // Endpoint is the service's base URL. 
+ Endpoint string +} + +// Configuration configures a cloud. +type Configuration struct { + // ActiveDirectoryAuthorityHost is the base URL of the cloud's Azure Active Directory. + ActiveDirectoryAuthorityHost string + // Services contains configuration for the cloud's services. + Services map[ServiceName]ServiceConfiguration +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/doc.go new file mode 100644 index 00000000000..985b1bde2f2 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud/doc.go @@ -0,0 +1,53 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* +Package cloud implements a configuration API for applications deployed to sovereign or private Azure clouds. + +Azure SDK client configuration defaults are appropriate for Azure Public Cloud (sometimes referred to as +"Azure Commercial" or simply "Microsoft Azure"). This package enables applications deployed to other +Azure Clouds to configure clients appropriately. + +This package contains predefined configuration for well-known sovereign clouds such as Azure Government and +Azure China. Azure SDK clients accept this configuration via the Cloud field of azcore.ClientOptions. 
For +example, configuring a credential and ARM client for Azure Government: + + opts := azcore.ClientOptions{Cloud: cloud.AzureGovernment} + cred, err := azidentity.NewDefaultAzureCredential( + &azidentity.DefaultAzureCredentialOptions{ClientOptions: opts}, + ) + handle(err) + + client, err := armsubscription.NewClient( + cred, &arm.ClientOptions{ClientOptions: opts}, + ) + handle(err) + +Applications deployed to a private cloud such as Azure Stack create a Configuration object with +appropriate values: + + c := cloud.Configuration{ + ActiveDirectoryAuthorityHost: "https://...", + Services: map[cloud.ServiceName]cloud.ServiceConfiguration{ + cloud.ResourceManager: { + Audience: "...", + Endpoint: "https://...", + }, + }, + } + opts := azcore.ClientOptions{Cloud: c} + + cred, err := azidentity.NewDefaultAzureCredential( + &azidentity.DefaultAzureCredentialOptions{ClientOptions: opts}, + ) + handle(err) + + client, err := armsubscription.NewClient( + cred, &arm.ClientOptions{ClientOptions: opts}, + ) + handle(err) +*/ +package cloud diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/core.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/core.go new file mode 100644 index 00000000000..9d1c2f0c053 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/core.go @@ -0,0 +1,173 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "reflect" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +// AccessToken represents an Azure service bearer access token with expiry information. 
+type AccessToken = exported.AccessToken + +// TokenCredential represents a credential capable of providing an OAuth token. +type TokenCredential = exported.TokenCredential + +// KeyCredential contains an authentication key used to authenticate to an Azure service. +type KeyCredential = exported.KeyCredential + +// NewKeyCredential creates a new instance of [KeyCredential] with the specified values. +// - key is the authentication key +func NewKeyCredential(key string) *KeyCredential { + return exported.NewKeyCredential(key) +} + +// SASCredential contains a shared access signature used to authenticate to an Azure service. +type SASCredential = exported.SASCredential + +// NewSASCredential creates a new instance of [SASCredential] with the specified values. +// - sas is the shared access signature +func NewSASCredential(sas string) *SASCredential { + return exported.NewSASCredential(sas) +} + +// holds sentinel values used to send nulls +var nullables map[reflect.Type]any = map[reflect.Type]any{} +var nullablesMu sync.RWMutex + +// NullValue is used to send an explicit 'null' within a request. +// This is typically used in JSON-MERGE-PATCH operations to delete a value. +func NullValue[T any]() T { + t := shared.TypeOfT[T]() + + nullablesMu.RLock() + v, found := nullables[t] + nullablesMu.RUnlock() + + if found { + // return the sentinel object + return v.(T) + } + + // promote to exclusive lock and check again (double-checked locking pattern) + nullablesMu.Lock() + defer nullablesMu.Unlock() + v, found = nullables[t] + + if !found { + var o reflect.Value + if k := t.Kind(); k == reflect.Map { + o = reflect.MakeMap(t) + } else if k == reflect.Slice { + // empty slices appear to all point to the same data block + // which causes comparisons to become ambiguous. so we create + // a slice with len/cap of one which ensures a unique address. 
+ o = reflect.MakeSlice(t, 1, 1) + } else { + o = reflect.New(t.Elem()) + } + v = o.Interface() + nullables[t] = v + } + // return the sentinel object + return v.(T) +} + +// IsNullValue returns true if the field contains a null sentinel value. +// This is used by custom marshallers to properly encode a null value. +func IsNullValue[T any](v T) bool { + // see if our map has a sentinel object for this *T + t := reflect.TypeOf(v) + nullablesMu.RLock() + defer nullablesMu.RUnlock() + + if o, found := nullables[t]; found { + o1 := reflect.ValueOf(o) + v1 := reflect.ValueOf(v) + // we found it; return true if v points to the sentinel object. + // NOTE: maps and slices can only be compared to nil, else you get + // a runtime panic. so we compare addresses instead. + return o1.Pointer() == v1.Pointer() + } + // no sentinel object for this *t + return false +} + +// ClientOptions contains optional settings for a client's pipeline. +// Instances can be shared across calls to SDK client constructors when uniform configuration is desired. +// Zero-value fields will have their specified default values applied during use. +type ClientOptions = policy.ClientOptions + +// Client is a basic HTTP client. It consists of a pipeline and tracing provider. +type Client struct { + pl runtime.Pipeline + tr tracing.Tracer + + // cached on the client to support shallow copying with new values + tp tracing.Provider + modVer string + namespace string +} + +// NewClient creates a new Client instance with the provided values. +// - moduleName - the fully qualified name of the module where the client is defined; used by the telemetry policy and tracing provider. +// - moduleVersion - the semantic version of the module; used by the telemetry policy and tracing provider. 
+// - plOpts - pipeline configuration options; can be the zero-value +// - options - optional client configurations; pass nil to accept the default values +func NewClient(moduleName, moduleVersion string, plOpts runtime.PipelineOptions, options *ClientOptions) (*Client, error) { + if options == nil { + options = &ClientOptions{} + } + + if !options.Telemetry.Disabled { + if err := shared.ValidateModVer(moduleVersion); err != nil { + return nil, err + } + } + + pl := runtime.NewPipeline(moduleName, moduleVersion, plOpts, options) + + tr := options.TracingProvider.NewTracer(moduleName, moduleVersion) + if tr.Enabled() && plOpts.Tracing.Namespace != "" { + tr.SetAttributes(tracing.Attribute{Key: shared.TracingNamespaceAttrName, Value: plOpts.Tracing.Namespace}) + } + + return &Client{ + pl: pl, + tr: tr, + tp: options.TracingProvider, + modVer: moduleVersion, + namespace: plOpts.Tracing.Namespace, + }, nil +} + +// Pipeline returns the pipeline for this client. +func (c *Client) Pipeline() runtime.Pipeline { + return c.pl +} + +// Tracer returns the tracer for this client. +func (c *Client) Tracer() tracing.Tracer { + return c.tr +} + +// WithClientName returns a shallow copy of the Client with its tracing client name changed to clientName. +// Note that the values for module name and version will be preserved from the source Client. 
+// - clientName - the fully qualified name of the client ("package.Client"); this is used by the tracing provider when creating spans +func (c *Client) WithClientName(clientName string) *Client { + tr := c.tp.NewTracer(clientName, c.modVer) + if tr.Enabled() && c.namespace != "" { + tr.SetAttributes(tracing.Attribute{Key: shared.TracingNamespaceAttrName, Value: c.namespace}) + } + return &Client{pl: c.pl, tr: tr, tp: c.tp, modVer: c.modVer, namespace: c.namespace} +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/doc.go new file mode 100644 index 00000000000..654a5f40431 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/doc.go @@ -0,0 +1,264 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. + +/* +Package azcore implements an HTTP request/response middleware pipeline used by Azure SDK clients. + +The middleware consists of three components. + + - One or more Policy instances. + - A Transporter instance. + - A Pipeline instance that combines the Policy and Transporter instances. + +# Implementing the Policy Interface + +A Policy can be implemented in two ways; as a first-class function for a stateless Policy, or as +a method on a type for a stateful Policy. Note that HTTP requests made via the same pipeline share +the same Policy instances, so if a Policy mutates its state it MUST be properly synchronized to +avoid race conditions. + +A Policy's Do method is called when an HTTP request wants to be sent over the network. The Do method can +perform any operation(s) it desires. For example, it can log the outgoing request, mutate the URL, headers, +and/or query parameters, inject a failure, etc. 
Once the Policy has successfully completed its request +work, it must call the Next() method on the *policy.Request instance in order to pass the request to the +next Policy in the chain. + +When an HTTP response comes back, the Policy then gets a chance to process the response/error. The Policy instance +can log the response, retry the operation if it failed due to a transient error or timeout, unmarshal the response +body, etc. Once the Policy has successfully completed its response work, it must return the *http.Response +and error instances to its caller. + +Template for implementing a stateless Policy: + + type policyFunc func(*policy.Request) (*http.Response, error) + + // Do implements the Policy interface on policyFunc. + func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) + } + + func NewMyStatelessPolicy() policy.Policy { + return policyFunc(func(req *policy.Request) (*http.Response, error) { + // TODO: mutate/process Request here + + // forward Request to next Policy & get Response/error + resp, err := req.Next() + + // TODO: mutate/process Response/error here + + // return Response/error to previous Policy + return resp, err + }) + } + +Template for implementing a stateful Policy: + + type MyStatefulPolicy struct { + // TODO: add configuration/setting fields here + } + + // TODO: add initialization args to NewMyStatefulPolicy() + func NewMyStatefulPolicy() policy.Policy { + return &MyStatefulPolicy{ + // TODO: initialize configuration/setting fields here + } + } + + func (p *MyStatefulPolicy) Do(req *policy.Request) (resp *http.Response, err error) { + // TODO: mutate/process Request here + + // forward Request to next Policy & get Response/error + resp, err := req.Next() + + // TODO: mutate/process Response/error here + + // return Response/error to previous Policy + return resp, err + } + +# Implementing the Transporter Interface + +The Transporter interface is responsible for sending the HTTP request and returning 
the corresponding +HTTP response or error. The Transporter is invoked by the last Policy in the chain. The default Transporter +implementation uses a shared http.Client from the standard library. + +The same stateful/stateless rules for Policy implementations apply to Transporter implementations. + +# Using Policy and Transporter Instances Via a Pipeline + +To use the Policy and Transporter instances, an application passes them to the runtime.NewPipeline function. + + func NewPipeline(transport Transporter, policies ...Policy) Pipeline + +The specified Policy instances form a chain and are invoked in the order provided to NewPipeline +followed by the Transporter. + +Once the Pipeline has been created, create a runtime.Request instance and pass it to Pipeline's Do method. + + func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) + + func (p Pipeline) Do(req *Request) (*http.Request, error) + +The Pipeline.Do method sends the specified Request through the chain of Policy and Transporter +instances. The response/error is then sent through the same chain of Policy instances in reverse +order. For example, assuming there are Policy types PolicyA, PolicyB, and PolicyC along with +TransportA. + + pipeline := NewPipeline(TransportA, PolicyA, PolicyB, PolicyC) + +The flow of Request and Response looks like the following: + + policy.Request -> PolicyA -> PolicyB -> PolicyC -> TransportA -----+ + | + HTTP(S) endpoint + | + caller <--------- PolicyA <- PolicyB <- PolicyC <- http.Response-+ + +# Creating a Request Instance + +The Request instance passed to Pipeline's Do method is a wrapper around an *http.Request. It also +contains some internal state and provides various convenience methods. You create a Request instance +by calling the runtime.NewRequest function: + + func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) + +If the Request should contain a body, call the SetBody method. 
+ + func (req *Request) SetBody(body ReadSeekCloser, contentType string) error + +A seekable stream is required so that upon retry, the retry Policy instance can seek the stream +back to the beginning before retrying the network request and re-uploading the body. + +# Sending an Explicit Null + +Operations like JSON-MERGE-PATCH send a JSON null to indicate a value should be deleted. + + { + "delete-me": null + } + +This requirement conflicts with the SDK's default marshalling that specifies "omitempty" as +a means to resolve the ambiguity between a field to be excluded and its zero-value. + + type Widget struct { + Name *string `json:",omitempty"` + Count *int `json:",omitempty"` + } + +In the above example, Name and Count are defined as pointer-to-type to disambiguate between +a missing value (nil) and a zero-value (0) which might have semantic differences. + +In a PATCH operation, any fields left as nil are to have their values preserved. When updating +a Widget's count, one simply specifies the new value for Count, leaving Name nil. + +To fulfill the requirement for sending a JSON null, the NullValue() function can be used. + + w := Widget{ + Count: azcore.NullValue[*int](), + } + +This sends an explict "null" for Count, indicating that any current value for Count should be deleted. + +# Processing the Response + +When the HTTP response is received, the *http.Response is returned directly. Each Policy instance +can inspect/mutate the *http.Response. + +# Built-in Logging + +To enable logging, set environment variable AZURE_SDK_GO_LOGGING to "all" before executing your program. + +By default the logger writes to stderr. This can be customized by calling log.SetListener, providing +a callback that writes to the desired location. Any custom logging implementation MUST provide its +own synchronization to handle concurrent invocations. + +See the docs for the log package for further details. 
+ +# Pageable Operations + +Pageable operations return potentially large data sets spread over multiple GET requests. The result of +each GET is a "page" of data consisting of a slice of items. + +Pageable operations can be identified by their New*Pager naming convention and return type of *runtime.Pager[T]. + + func (c *WidgetClient) NewListWidgetsPager(o *Options) *runtime.Pager[PageResponse] + +The call to WidgetClient.NewListWidgetsPager() returns an instance of *runtime.Pager[T] for fetching pages +and determining if there are more pages to fetch. No IO calls are made until the NextPage() method is invoked. + + pager := widgetClient.NewListWidgetsPager(nil) + for pager.More() { + page, err := pager.NextPage(context.TODO()) + // handle err + for _, widget := range page.Values { + // process widget + } + } + +# Long-Running Operations + +Long-running operations (LROs) are operations consisting of an initial request to start the operation followed +by polling to determine when the operation has reached a terminal state. An LRO's terminal state is one +of the following values. + + - Succeeded - the LRO completed successfully + - Failed - the LRO failed to complete + - Canceled - the LRO was canceled + +LROs can be identified by their Begin* prefix and their return type of *runtime.Poller[T]. + + func (c *WidgetClient) BeginCreateOrUpdate(ctx context.Context, w Widget, o *Options) (*runtime.Poller[Response], error) + +When a call to WidgetClient.BeginCreateOrUpdate() returns a nil error, it means that the LRO has started. +It does _not_ mean that the widget has been created or updated (or failed to be created/updated). + +The *runtime.Poller[T] provides APIs for determining the state of the LRO. To wait for the LRO to complete, +call the PollUntilDone() method. 
+ + poller, err := widgetClient.BeginCreateOrUpdate(context.TODO(), Widget{}, nil) + // handle err + result, err := poller.PollUntilDone(context.TODO(), nil) + // handle err + // use result + +The call to PollUntilDone() will block the current goroutine until the LRO has reached a terminal state or the +context is canceled/timed out. + +Note that LROs can take anywhere from several seconds to several minutes. The duration is operation-dependent. Due to +this variant behavior, pollers do _not_ have a preconfigured time-out. Use a context with the appropriate cancellation +mechanism as required. + +# Resume Tokens + +Pollers provide the ability to serialize their state into a "resume token" which can be used by another process to +recreate the poller. This is achieved via the runtime.Poller[T].ResumeToken() method. + + token, err := poller.ResumeToken() + // handle error + +Note that a token can only be obtained for a poller that's in a non-terminal state. Also note that any subsequent calls +to poller.Poll() might change the poller's state. In this case, a new token should be created. + +After the token has been obtained, it can be used to recreate an instance of the originating poller. + + poller, err := widgetClient.BeginCreateOrUpdate(nil, Widget{}, &Options{ + ResumeToken: token, + }) + +When resuming a poller, no IO is performed, and zero-value arguments can be used for everything but the Options.ResumeToken. + +Resume tokens are unique per service client and operation. Attempting to resume a poller for LRO BeginB() with a token from LRO +BeginA() will result in an error. + +# Fakes + +The fake package contains types used for constructing in-memory fake servers used in unit tests. +This allows writing tests to cover various success/error conditions without the need for connecting to a live service. + +Please see https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/samples/fakes for details and examples on how to use fakes. 
+*/ +package azcore diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/errors.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/errors.go new file mode 100644 index 00000000000..17bd50c6732 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/errors.go @@ -0,0 +1,14 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + +// ResponseError is returned when a request is made to a service and +// the service returns a non-success HTTP status code. +// Use errors.As() to access this type in the error chain. +type ResponseError = exported.ResponseError diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/etag.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/etag.go new file mode 100644 index 00000000000..2b19d01f76e --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/etag.go @@ -0,0 +1,57 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcore + +import ( + "strings" +) + +// ETag is a property used for optimistic concurrency during updates +// ETag is a validator based on https://tools.ietf.org/html/rfc7232#section-2.3.2 +// An ETag can be empty (""). +type ETag string + +// ETagAny is an ETag that represents everything, the value is "*" +const ETagAny ETag = "*" + +// Equals does a strong comparison of two ETags. Equals returns true when both +// ETags are not weak and the values of the underlying strings are equal. +func (e ETag) Equals(other ETag) bool { + return !e.IsWeak() && !other.IsWeak() && e == other +} + +// WeakEquals does a weak comparison of two ETags. Two ETags are equivalent if their opaque-tags match +// character-by-character, regardless of either or both being tagged as "weak". 
+func (e ETag) WeakEquals(other ETag) bool { + getStart := func(e1 ETag) int { + if e1.IsWeak() { + return 2 + } + return 0 + } + aStart := getStart(e) + bStart := getStart(other) + + aVal := e[aStart:] + bVal := other[bStart:] + + return aVal == bVal +} + +// IsWeak specifies whether the ETag is strong or weak. +func (e ETag) IsWeak() bool { + return len(e) >= 4 && strings.HasPrefix(string(e), "W/\"") && strings.HasSuffix(string(e), "\"") +} + +// MatchConditions specifies HTTP options for conditional requests. +type MatchConditions struct { + // Optionally limit requests to resources that have a matching ETag. + IfMatch *ETag + + // Optionally limit requests to resources that do not match the ETag. + IfNoneMatch *ETag +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/exported.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/exported.go new file mode 100644 index 00000000000..f2b296b6dc7 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/exported.go @@ -0,0 +1,175 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "sync/atomic" + "time" +) + +type nopCloser struct { + io.ReadSeeker +} + +func (n nopCloser) Close() error { + return nil +} + +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +// Exported as streaming.NopCloser(). +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return nopCloser{rs} +} + +// HasStatusCode returns true if the Response's status code is one of the specified values. +// Exported as runtime.HasStatusCode(). 
+func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { + return false + } + for _, sc := range statusCodes { + if resp.StatusCode == sc { + return true + } + } + return false +} + +// AccessToken represents an Azure service bearer access token with expiry information. +// Exported as azcore.AccessToken. +type AccessToken struct { + Token string + ExpiresOn time.Time +} + +// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. +// Exported as policy.TokenRequestOptions. +type TokenRequestOptions struct { + // Claims are any additional claims required for the token to satisfy a conditional access policy, such as a + // service may return in a claims challenge following an authorization failure. If a service returned the + // claims value base64 encoded, it must be decoded before setting this field. + Claims string + + // EnableCAE indicates whether to enable Continuous Access Evaluation (CAE) for the requested token. When true, + // azidentity credentials request CAE tokens for resource APIs supporting CAE. Clients are responsible for + // handling CAE challenges. If a client that doesn't handle CAE challenges receives a CAE token, it may end up + // in a loop retrying an API call with a token that has been revoked due to CAE. + EnableCAE bool + + // Scopes contains the list of permission scopes required for the token. + Scopes []string + + // TenantID identifies the tenant from which to request the token. azidentity credentials authenticate in + // their configured default tenants when this field isn't set. + TenantID string +} + +// TokenCredential represents a credential capable of providing an OAuth token. +// Exported as azcore.TokenCredential. +type TokenCredential interface { + // GetToken requests an access token for the specified set of scopes. 
+ GetToken(ctx context.Context, options TokenRequestOptions) (AccessToken, error) +} + +// DecodeByteArray will base-64 decode the provided string into v. +// Exported as runtime.DecodeByteArray() +func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { + if len(s) == 0 { + return nil + } + payload := string(s) + if payload[0] == '"' { + // remove surrounding quotes + payload = payload[1 : len(payload)-1] + } + switch format { + case Base64StdFormat: + decoded, err := base64.StdEncoding.DecodeString(payload) + if err == nil { + *v = decoded + return nil + } + return err + case Base64URLFormat: + // use raw encoding as URL format should not contain any '=' characters + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err == nil { + *v = decoded + return nil + } + return err + default: + return fmt.Errorf("unrecognized byte array format: %d", format) + } +} + +// KeyCredential contains an authentication key used to authenticate to an Azure service. +// Exported as azcore.KeyCredential. +type KeyCredential struct { + cred *keyCredential +} + +// NewKeyCredential creates a new instance of [KeyCredential] with the specified values. +// - key is the authentication key +func NewKeyCredential(key string) *KeyCredential { + return &KeyCredential{cred: newKeyCredential(key)} +} + +// Update replaces the existing key with the specified value. +func (k *KeyCredential) Update(key string) { + k.cred.Update(key) +} + +// SASCredential contains a shared access signature used to authenticate to an Azure service. +// Exported as azcore.SASCredential. +type SASCredential struct { + cred *keyCredential +} + +// NewSASCredential creates a new instance of [SASCredential] with the specified values. +// - sas is the shared access signature +func NewSASCredential(sas string) *SASCredential { + return &SASCredential{cred: newKeyCredential(sas)} +} + +// Update replaces the existing shared access signature with the specified value. 
+func (k *SASCredential) Update(sas string) { + k.cred.Update(sas) +} + +// KeyCredentialGet returns the key for cred. +func KeyCredentialGet(cred *KeyCredential) string { + return cred.cred.Get() +} + +// SASCredentialGet returns the shared access sig for cred. +func SASCredentialGet(cred *SASCredential) string { + return cred.cred.Get() +} + +type keyCredential struct { + key atomic.Value // string +} + +func newKeyCredential(key string) *keyCredential { + keyCred := keyCredential{} + keyCred.key.Store(key) + return &keyCred +} + +func (k *keyCredential) Get() string { + return k.key.Load().(string) +} + +func (k *keyCredential) Update(key string) { + k.key.Store(key) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/pipeline.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/pipeline.go new file mode 100644 index 00000000000..e45f831ed2a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/pipeline.go @@ -0,0 +1,77 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "errors" + "net/http" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +// Exported as policy.Policy. +type Policy interface { + // Do applies the policy to the specified Request. When implementing a Policy, mutate the + // request before calling req.Next() to move on to the next policy, and respond to the result + // before returning to the caller. + Do(req *Request) (*http.Response, error) +} + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +// Exported as runtime.Pipeline. 
+type Pipeline struct { + policies []Policy +} + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +// Exported as policy.Transporter. +type Transporter interface { + // Do sends the HTTP request and returns the HTTP response or error. + Do(req *http.Request) (*http.Response, error) +} + +// used to adapt a TransportPolicy to a Policy +type transportPolicy struct { + trans Transporter +} + +func (tp transportPolicy) Do(req *Request) (*http.Response, error) { + if tp.trans == nil { + return nil, errors.New("missing transporter") + } + resp, err := tp.trans.Do(req.Raw()) + if err != nil { + return nil, err + } else if resp == nil { + // there was no response and no error (rare but can happen) + // this ensures the retry policy will retry the request + return nil, errors.New("received nil response") + } + return resp, nil +} + +// NewPipeline creates a new Pipeline object from the specified Policies. +// Not directly exported, but used as part of runtime.NewPipeline(). +func NewPipeline(transport Transporter, policies ...Policy) Pipeline { + // transport policy must always be the last in the slice + policies = append(policies, transportPolicy{trans: transport}) + return Pipeline{ + policies: policies, + } +} + +// Do is called for each and every HTTP request. It passes the request through all +// the Policy objects (which can transform the Request's URL/query parameters/headers) +// and ultimately sends the transformed HTTP request over the network. 
+func (p Pipeline) Do(req *Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + req.policies = p.policies + return req.Next() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/request.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/request.go new file mode 100644 index 00000000000..3041984d9b1 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/request.go @@ -0,0 +1,223 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when +// encoding/decoding a slice of bytes to/from a string. +// Exported as runtime.Base64Encoding +type Base64Encoding int + +const ( + // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. + Base64StdFormat Base64Encoding = 0 + + // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. + Base64URLFormat Base64Encoding = 1 +) + +// EncodeByteArray will base-64 encode the byte slice v. +// Exported as runtime.EncodeByteArray() +func EncodeByteArray(v []byte, format Base64Encoding) string { + if format == Base64URLFormat { + return base64.RawURLEncoding.EncodeToString(v) + } + return base64.StdEncoding.EncodeToString(v) +} + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. +// Don't use this type directly, use NewRequest() instead. +// Exported as policy.Request. 
+type Request struct { + req *http.Request + body io.ReadSeekCloser + policies []Policy + values opValues +} + +type opValues map[reflect.Type]any + +// Set adds/changes a value +func (ov opValues) set(value any) { + ov[reflect.TypeOf(value)] = value +} + +// Get looks for a value set by SetValue first +func (ov opValues) get(value any) bool { + v, ok := ov[reflect.ValueOf(value).Elem().Type()] + if ok { + reflect.ValueOf(value).Elem().Set(reflect.ValueOf(v)) + } + return ok +} + +// NewRequest creates a new Request with the specified input. +// Exported as runtime.NewRequest(). +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { + req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) + if err != nil { + return nil, err + } + if req.URL.Host == "" { + return nil, errors.New("no Host in request URL") + } + if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { + return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) + } + return &Request{req: req}, nil +} + +// Body returns the original body specified when the Request was created. +func (req *Request) Body() io.ReadSeekCloser { + return req.body +} + +// Raw returns the underlying HTTP request. +func (req *Request) Raw() *http.Request { + return req.req +} + +// Next calls the next policy in the pipeline. +// If there are no more policies, nil and an error are returned. +// This method is intended to be called from pipeline policies. +// To send a request through a pipeline call Pipeline.Do(). +func (req *Request) Next() (*http.Response, error) { + if len(req.policies) == 0 { + return nil, errors.New("no more policies") + } + nextPolicy := req.policies[0] + nextReq := *req + nextReq.policies = nextReq.policies[1:] + return nextPolicy.Do(&nextReq) +} + +// SetOperationValue adds/changes a mutable key/value associated with a single operation. 
+func (req *Request) SetOperationValue(value any) { + if req.values == nil { + req.values = opValues{} + } + req.values.set(value) +} + +// OperationValue looks for a value set by SetOperationValue(). +func (req *Request) OperationValue(value any) bool { + if req.values == nil { + return false + } + return req.values.get(value) +} + +// SetBody sets the specified ReadSeekCloser as the HTTP request body, and sets Content-Type and Content-Length +// accordingly. If the ReadSeekCloser is nil or empty, Content-Length won't be set. If contentType is "", +// Content-Type won't be set, and if it was set, will be deleted. +// Use streaming.NopCloser to turn an io.ReadSeeker into an io.ReadSeekCloser. +func (req *Request) SetBody(body io.ReadSeekCloser, contentType string) error { + // clobber the existing Content-Type to preserve behavior + return SetBody(req, body, contentType, true) +} + +// RewindBody seeks the request's Body stream back to the beginning so it can be resent when retrying an operation. +func (req *Request) RewindBody() error { + if req.body != nil { + // Reset the stream back to the beginning and restore the body + _, err := req.body.Seek(0, io.SeekStart) + req.req.Body = req.body + return err + } + return nil +} + +// Close closes the request body. +func (req *Request) Close() error { + if req.body == nil { + return nil + } + return req.body.Close() +} + +// Clone returns a deep copy of the request with its context changed to ctx. +func (req *Request) Clone(ctx context.Context) *Request { + r2 := *req + r2.req = req.req.Clone(ctx) + return &r2 +} + +// WithContext returns a shallow copy of the request with its context changed to ctx. +func (req *Request) WithContext(ctx context.Context) *Request { + r2 := new(Request) + *r2 = *req + r2.req = r2.req.WithContext(ctx) + return r2 +} + +// not exported but dependent on Request + +// PolicyFunc is a type that implements the Policy interface. 
+// Use this type when implementing a stateless policy as a first-class function. +type PolicyFunc func(*Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf PolicyFunc) Do(req *Request) (*http.Response, error) { + return pf(req) +} + +// SetBody sets the specified ReadSeekCloser as the HTTP request body, and sets Content-Type and Content-Length accordingly. +// - req is the request to modify +// - body is the request body; if nil or empty, Content-Length won't be set +// - contentType is the value for the Content-Type header; if empty, Content-Type will be deleted +// - clobberContentType when true, will overwrite the existing value of Content-Type with contentType +func SetBody(req *Request, body io.ReadSeekCloser, contentType string, clobberContentType bool) error { + var err error + var size int64 + if body != nil { + size, err = body.Seek(0, io.SeekEnd) // Seek to the end to get the stream's size + if err != nil { + return err + } + } + if size == 0 { + // treat an empty stream the same as a nil one: assign req a nil body + body = nil + // RFC 9110 specifies a client shouldn't set Content-Length on a request containing no content + // (Del is a no-op when the header has no value) + req.req.Header.Del(shared.HeaderContentLength) + } else { + _, err = body.Seek(0, io.SeekStart) + if err != nil { + return err + } + req.req.Header.Set(shared.HeaderContentLength, strconv.FormatInt(size, 10)) + req.Raw().GetBody = func() (io.ReadCloser, error) { + _, err := body.Seek(0, io.SeekStart) // Seek back to the beginning of the stream + return body, err + } + } + // keep a copy of the body argument. this is to handle cases + // where req.Body is replaced, e.g. httputil.DumpRequest and friends. 
+ req.body = body + req.req.Body = body + req.req.ContentLength = size + if contentType == "" { + // Del is a no-op when the header has no value + req.req.Header.Del(shared.HeaderContentType) + } else if req.req.Header.Get(shared.HeaderContentType) == "" || clobberContentType { + req.req.Header.Set(shared.HeaderContentType, contentType) + } + return nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/response_error.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/response_error.go new file mode 100644 index 00000000000..08a95458730 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported/response_error.go @@ -0,0 +1,167 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "regexp" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/exported" +) + +// NewResponseError creates a new *ResponseError from the provided HTTP response. +// Exported as runtime.NewResponseError(). +func NewResponseError(resp *http.Response) error { + // prefer the error code in the response header + if ec := resp.Header.Get(shared.HeaderXMSErrorCode); ec != "" { + return NewResponseErrorWithErrorCode(resp, ec) + } + + // if we didn't get x-ms-error-code, check in the response body + body, err := exported.Payload(resp, nil) + if err != nil { + // since we're not returning the ResponseError in this + // case we also don't want to write it to the log. 
+ return err + } + + var errorCode string + if len(body) > 0 { + if fromJSON := extractErrorCodeJSON(body); fromJSON != "" { + errorCode = fromJSON + } else if fromXML := extractErrorCodeXML(body); fromXML != "" { + errorCode = fromXML + } + } + + return NewResponseErrorWithErrorCode(resp, errorCode) +} + +// NewResponseErrorWithErrorCode creates an *azcore.ResponseError from the provided HTTP response and errorCode. +// Exported as runtime.NewResponseErrorWithErrorCode(). +func NewResponseErrorWithErrorCode(resp *http.Response, errorCode string) error { + respErr := &ResponseError{ + ErrorCode: errorCode, + StatusCode: resp.StatusCode, + RawResponse: resp, + } + log.Write(log.EventResponseError, respErr.Error()) + return respErr +} + +func extractErrorCodeJSON(body []byte) string { + var rawObj map[string]any + if err := json.Unmarshal(body, &rawObj); err != nil { + // not a JSON object + return "" + } + + // check if this is a wrapped error, i.e. { "error": { ... } } + // if so then unwrap it + if wrapped, ok := rawObj["error"]; ok { + unwrapped, ok := wrapped.(map[string]any) + if !ok { + return "" + } + rawObj = unwrapped + } else if wrapped, ok := rawObj["odata.error"]; ok { + // check if this a wrapped odata error, i.e. { "odata.error": { ... 
} } + unwrapped, ok := wrapped.(map[string]any) + if !ok { + return "" + } + rawObj = unwrapped + } + + // now check for the error code + code, ok := rawObj["code"] + if !ok { + return "" + } + codeStr, ok := code.(string) + if !ok { + return "" + } + return codeStr +} + +func extractErrorCodeXML(body []byte) string { + // regular expression is much easier than dealing with the XML parser + rx := regexp.MustCompile(`<(?:\w+:)?[c|C]ode>\s*(\w+)\s*<\/(?:\w+:)?[c|C]ode>`) + res := rx.FindStringSubmatch(string(body)) + if len(res) != 2 { + return "" + } + // first submatch is the entire thing, second one is the captured error code + return res[1] +} + +// ResponseError is returned when a request is made to a service and +// the service returns a non-success HTTP status code. +// Use errors.As() to access this type in the error chain. +// Exported as azcore.ResponseError. +type ResponseError struct { + // ErrorCode is the error code returned by the resource provider if available. + ErrorCode string + + // StatusCode is the HTTP status code as defined in https://pkg.go.dev/net/http#pkg-constants. + StatusCode int + + // RawResponse is the underlying HTTP response. + RawResponse *http.Response +} + +// Error implements the error interface for type ResponseError. +// Note that the message contents are not contractual and can change over time. 
+func (e *ResponseError) Error() string { + const separator = "--------------------------------------------------------------------------------" + // write the request method and URL with response status code + msg := &bytes.Buffer{} + if e.RawResponse != nil { + if e.RawResponse.Request != nil { + fmt.Fprintf(msg, "%s %s://%s%s\n", e.RawResponse.Request.Method, e.RawResponse.Request.URL.Scheme, e.RawResponse.Request.URL.Host, e.RawResponse.Request.URL.Path) + } else { + fmt.Fprintln(msg, "Request information not available") + } + fmt.Fprintln(msg, separator) + fmt.Fprintf(msg, "RESPONSE %d: %s\n", e.RawResponse.StatusCode, e.RawResponse.Status) + } else { + fmt.Fprintln(msg, "Missing RawResponse") + fmt.Fprintln(msg, separator) + } + if e.ErrorCode != "" { + fmt.Fprintf(msg, "ERROR CODE: %s\n", e.ErrorCode) + } else { + fmt.Fprintln(msg, "ERROR CODE UNAVAILABLE") + } + if e.RawResponse != nil { + fmt.Fprintln(msg, separator) + body, err := exported.Payload(e.RawResponse, nil) + if err != nil { + // this really shouldn't fail at this point as the response + // body is already cached (it was read in NewResponseError) + fmt.Fprintf(msg, "Error reading response body: %v", err) + } else if len(body) > 0 { + if err := json.Indent(msg, body, "", " "); err != nil { + // failed to pretty-print so just dump it verbatim + fmt.Fprint(msg, string(body)) + } + // the standard library doesn't have a pretty-printer for XML + fmt.Fprintln(msg) + } else { + fmt.Fprintln(msg, "Response contained no body") + } + } + fmt.Fprintln(msg, separator) + + return msg.String() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log/log.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log/log.go new file mode 100644 index 00000000000..6fc6d1400e7 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log/log.go @@ -0,0 +1,50 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +// This is an internal helper package to combine the complete logging APIs. +package log + +import ( + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +type Event = log.Event + +const ( + EventRequest = azlog.EventRequest + EventResponse = azlog.EventResponse + EventResponseError = azlog.EventResponseError + EventRetryPolicy = azlog.EventRetryPolicy + EventLRO = azlog.EventLRO +) + +// Write invokes the underlying listener with the specified event and message. +// If the event shouldn't be logged or there is no listener then Write does nothing. +func Write(cls log.Event, msg string) { + log.Write(cls, msg) +} + +// Writef invokes the underlying listener with the specified event and formatted message. +// If the event shouldn't be logged or there is no listener then Writef does nothing. +func Writef(cls log.Event, format string, a ...any) { + log.Writef(cls, format, a...) +} + +// SetListener will set the Logger to write to the specified listener. +func SetListener(lst func(Event, string)) { + log.SetListener(lst) +} + +// Should returns true if the specified log event should be written to the log. +// By default all log events will be logged. Call SetEvents() to limit +// the log events for logging. +// If no listener has been set this will return false. +// Calling this method is useful when the message to log is computationally expensive +// and you want to avoid the overhead if its log event is not enabled. 
+func Should(cls log.Event) bool { + return log.Should(cls) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async/async.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async/async.go new file mode 100644 index 00000000000..ccd4794e9e9 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async/async.go @@ -0,0 +1,159 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package async + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// see https://github.com/Azure/azure-resource-manager-rpc/blob/master/v1.0/async-api-reference.md + +// Applicable returns true if the LRO is using Azure-AsyncOperation. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderAzureAsync) != "" +} + +// CanResume returns true if the token can rehydrate this poller type. +func CanResume(token map[string]any) bool { + _, ok := token["asyncURL"] + return ok +} + +// Poller is an LRO poller that uses the Azure-AsyncOperation pattern. +type Poller[T any] struct { + pl exported.Pipeline + + resp *http.Response + + // The URL from Azure-AsyncOperation header. + AsyncURL string `json:"asyncURL"` + + // The URL from Location header. + LocURL string `json:"locURL"` + + // The URL from the initial LRO request. + OrigURL string `json:"origURL"` + + // The HTTP method from the initial LRO request. + Method string `json:"method"` + + // The value of final-state-via from swagger, can be the empty string. 
+ FinalState pollers.FinalStateVia `json:"finalState"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response and final-state type. +// Pass nil for response to create an empty Poller for rehydration. +func New[T any](pl exported.Pipeline, resp *http.Response, finalState pollers.FinalStateVia) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Azure-AsyncOperation poller.") + return &Poller[T]{pl: pl}, nil + } + log.Write(log.EventLRO, "Using Azure-AsyncOperation poller.") + asyncURL := resp.Header.Get(shared.HeaderAzureAsync) + if asyncURL == "" { + return nil, errors.New("response is missing Azure-AsyncOperation header") + } + if !poller.IsValidURL(asyncURL) { + return nil, fmt.Errorf("invalid polling URL %s", asyncURL) + } + // check for provisioning state. if the operation is a RELO + // and terminates synchronously this will prevent extra polling. + // it's ok if there's no provisioning state. + state, _ := poller.GetProvisioningState(resp) + if state == "" { + state = poller.StatusInProgress + } + p := &Poller[T]{ + pl: pl, + resp: resp, + AsyncURL: asyncURL, + LocURL: resp.Header.Get(shared.HeaderLocation), + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + CurState: state, + } + return p, nil +} + +// Done returns true if the LRO is in a terminal state. +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.CurState) +} + +// Poll retrieves the current state of the LRO. 
+func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + err := pollers.PollHelper(ctx, p.AsyncURL, p.pl, func(resp *http.Response) (string, error) { + if !poller.StatusCodeValid(resp) { + p.resp = resp + return "", exported.NewResponseError(resp) + } + state, err := poller.GetStatus(resp) + if err != nil { + return "", err + } else if state == "" { + return "", errors.New("the response did not contain a status") + } + p.resp = resp + p.CurState = state + return p.CurState, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + +func (p *Poller[T]) Result(ctx context.Context, out *T) error { + if p.resp.StatusCode == http.StatusNoContent { + return nil + } else if poller.Failed(p.CurState) { + return exported.NewResponseError(p.resp) + } + var req *exported.Request + var err error + if p.Method == http.MethodPatch || p.Method == http.MethodPut { + // for PATCH and PUT, the final GET is on the original resource URL + req, err = exported.NewRequest(ctx, http.MethodGet, p.OrigURL) + } else if p.Method == http.MethodPost { + if p.FinalState == pollers.FinalStateViaAzureAsyncOp { + // no final GET required + } else if p.FinalState == pollers.FinalStateViaOriginalURI { + req, err = exported.NewRequest(ctx, http.MethodGet, p.OrigURL) + } else if p.LocURL != "" { + // ideally FinalState would be set to "location" but it isn't always. + // must check last due to more permissive condition. 
+ req, err = exported.NewRequest(ctx, http.MethodGet, p.LocURL) + } + } + if err != nil { + return err + } + + // if a final GET request has been created, execute it + if req != nil { + resp, err := p.pl.Do(req) + if err != nil { + return err + } + p.resp = resp + } + + return pollers.ResultHelper(p.resp, poller.Failed(p.CurState), out) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body/body.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body/body.go new file mode 100644 index 00000000000..0d781b31d0c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body/body.go @@ -0,0 +1,135 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package body + +import ( + "context" + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// Kind is the identifier of this type in a resume token. +const kind = "body" + +// Applicable returns true if the LRO is using no headers, just provisioning state. +// This is only applicable to PATCH and PUT methods and assumes no polling headers. +func Applicable(resp *http.Response) bool { + // we can't check for absense of headers due to some misbehaving services + // like redis that return a Location header but don't actually use that protocol + return resp.Request.Method == http.MethodPatch || resp.Request.Method == http.MethodPut +} + +// CanResume returns true if the token can rehydrate this poller type. 
+func CanResume(token map[string]any) bool { + t, ok := token["type"] + if !ok { + return false + } + tt, ok := t.(string) + if !ok { + return false + } + return tt == kind +} + +// Poller is an LRO poller that uses the Body pattern. +type Poller[T any] struct { + pl exported.Pipeline + + resp *http.Response + + // The poller's type, used for resume token processing. + Type string `json:"type"` + + // The URL for polling. + PollURL string `json:"pollURL"` + + // The LRO's current state. + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +// Pass nil for response to create an empty Poller for rehydration. +func New[T any](pl exported.Pipeline, resp *http.Response) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Body poller.") + return &Poller[T]{pl: pl}, nil + } + log.Write(log.EventLRO, "Using Body poller.") + p := &Poller[T]{ + pl: pl, + resp: resp, + Type: kind, + PollURL: resp.Request.URL.String(), + } + // default initial state to InProgress. depending on the HTTP + // status code and provisioning state, we might change the value. 
+ curState := poller.StatusInProgress + provState, err := poller.GetProvisioningState(resp) + if err != nil && !errors.Is(err, poller.ErrNoBody) { + return nil, err + } + if resp.StatusCode == http.StatusCreated && provState != "" { + // absense of provisioning state is ok for a 201, means the operation is in progress + curState = provState + } else if resp.StatusCode == http.StatusOK { + if provState != "" { + curState = provState + } else if provState == "" { + // for a 200, absense of provisioning state indicates success + curState = poller.StatusSucceeded + } + } else if resp.StatusCode == http.StatusNoContent { + curState = poller.StatusSucceeded + } + p.CurState = curState + return p, nil +} + +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.CurState) +} + +func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + err := pollers.PollHelper(ctx, p.PollURL, p.pl, func(resp *http.Response) (string, error) { + if !poller.StatusCodeValid(resp) { + p.resp = resp + return "", exported.NewResponseError(resp) + } + if resp.StatusCode == http.StatusNoContent { + p.resp = resp + p.CurState = poller.StatusSucceeded + return p.CurState, nil + } + state, err := poller.GetProvisioningState(resp) + if errors.Is(err, poller.ErrNoBody) { + // a missing response body in non-204 case is an error + return "", err + } else if state == "" { + // a response body without provisioning state is considered terminal success + state = poller.StatusSucceeded + } else if err != nil { + return "", err + } + p.resp = resp + p.CurState = state + return p.CurState, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + +func (p *Poller[T]) Result(ctx context.Context, out *T) error { + return pollers.ResultHelper(p.resp, poller.Failed(p.CurState), out) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake/fake.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake/fake.go new file 
mode 100644 index 00000000000..51aede8a2b8 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake/fake.go @@ -0,0 +1,133 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package fake + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// Applicable returns true if the LRO is a fake. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderFakePollerStatus) != "" +} + +// CanResume returns true if the token can rehydrate this poller type. +func CanResume(token map[string]any) bool { + _, ok := token["fakeURL"] + return ok +} + +// Poller is an LRO poller that uses the Core-Fake-Poller pattern. +type Poller[T any] struct { + pl exported.Pipeline + + resp *http.Response + + // The API name from CtxAPINameKey + APIName string `json:"apiName"` + + // The URL from Core-Fake-Poller header. + FakeURL string `json:"fakeURL"` + + // The LRO's current state. + FakeStatus string `json:"status"` +} + +// lroStatusURLSuffix is the URL path suffix for a faked LRO. +const lroStatusURLSuffix = "/get/fake/status" + +// New creates a new Poller from the provided initial response. +// Pass nil for response to create an empty Poller for rehydration. 
+func New[T any](pl exported.Pipeline, resp *http.Response) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Core-Fake-Poller poller.") + return &Poller[T]{pl: pl}, nil + } + + log.Write(log.EventLRO, "Using Core-Fake-Poller poller.") + fakeStatus := resp.Header.Get(shared.HeaderFakePollerStatus) + if fakeStatus == "" { + return nil, errors.New("response is missing Fake-Poller-Status header") + } + + ctxVal := resp.Request.Context().Value(shared.CtxAPINameKey{}) + if ctxVal == nil { + return nil, errors.New("missing value for CtxAPINameKey") + } + + apiName, ok := ctxVal.(string) + if !ok { + return nil, fmt.Errorf("expected string for CtxAPINameKey, the type was %T", ctxVal) + } + + qp := "" + if resp.Request.URL.RawQuery != "" { + qp = "?" + resp.Request.URL.RawQuery + } + + p := &Poller[T]{ + pl: pl, + resp: resp, + APIName: apiName, + // NOTE: any changes to this path format MUST be reflected in SanitizePollerPath() + FakeURL: fmt.Sprintf("%s://%s%s%s%s", resp.Request.URL.Scheme, resp.Request.URL.Host, resp.Request.URL.Path, lroStatusURLSuffix, qp), + FakeStatus: fakeStatus, + } + return p, nil +} + +// Done returns true if the LRO is in a terminal state. +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.FakeStatus) +} + +// Poll retrieves the current state of the LRO. 
+func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + ctx = context.WithValue(ctx, shared.CtxAPINameKey{}, p.APIName) + err := pollers.PollHelper(ctx, p.FakeURL, p.pl, func(resp *http.Response) (string, error) { + if !poller.StatusCodeValid(resp) { + p.resp = resp + return "", exported.NewResponseError(resp) + } + fakeStatus := resp.Header.Get(shared.HeaderFakePollerStatus) + if fakeStatus == "" { + return "", errors.New("response is missing Fake-Poller-Status header") + } + p.resp = resp + p.FakeStatus = fakeStatus + return p.FakeStatus, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + +func (p *Poller[T]) Result(ctx context.Context, out *T) error { + if p.resp.StatusCode == http.StatusNoContent { + return nil + } else if poller.Failed(p.FakeStatus) { + return exported.NewResponseError(p.resp) + } + + return pollers.ResultHelper(p.resp, poller.Failed(p.FakeStatus), out) +} + +// SanitizePollerPath removes any fake-appended suffix from a URL's path. +func SanitizePollerPath(path string) string { + return strings.TrimSuffix(path, lroStatusURLSuffix) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc/loc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc/loc.go new file mode 100644 index 00000000000..7a56c5211b7 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc/loc.go @@ -0,0 +1,123 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package loc + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// Kind is the identifier of this type in a resume token. +const kind = "loc" + +// Applicable returns true if the LRO is using Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderLocation) != "" +} + +// CanResume returns true if the token can rehydrate this poller type. +func CanResume(token map[string]any) bool { + t, ok := token["type"] + if !ok { + return false + } + tt, ok := t.(string) + if !ok { + return false + } + return tt == kind +} + +// Poller is an LRO poller that uses the Location pattern. +type Poller[T any] struct { + pl exported.Pipeline + resp *http.Response + + Type string `json:"type"` + PollURL string `json:"pollURL"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. +// Pass nil for response to create an empty Poller for rehydration. +func New[T any](pl exported.Pipeline, resp *http.Response) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Location poller.") + return &Poller[T]{pl: pl}, nil + } + log.Write(log.EventLRO, "Using Location poller.") + locURL := resp.Header.Get(shared.HeaderLocation) + if locURL == "" { + return nil, errors.New("response is missing Location header") + } + if !poller.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid polling URL %s", locURL) + } + // check for provisioning state. if the operation is a RELO + // and terminates synchronously this will prevent extra polling. + // it's ok if there's no provisioning state. 
+ state, _ := poller.GetProvisioningState(resp) + if state == "" { + state = poller.StatusInProgress + } + return &Poller[T]{ + pl: pl, + resp: resp, + Type: kind, + PollURL: locURL, + CurState: state, + }, nil +} + +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.CurState) +} + +func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + err := pollers.PollHelper(ctx, p.PollURL, p.pl, func(resp *http.Response) (string, error) { + // location polling can return an updated polling URL + if h := resp.Header.Get(shared.HeaderLocation); h != "" { + p.PollURL = h + } + // if provisioning state is available, use that. this is only + // for some ARM LRO scenarios (e.g. DELETE with a Location header) + // so if it's missing then use HTTP status code. + provState, _ := poller.GetProvisioningState(resp) + p.resp = resp + if provState != "" { + p.CurState = provState + } else if resp.StatusCode == http.StatusAccepted { + p.CurState = poller.StatusInProgress + } else if resp.StatusCode > 199 && resp.StatusCode < 300 { + // any 2xx other than a 202 indicates success + p.CurState = poller.StatusSucceeded + } else if pollers.IsNonTerminalHTTPStatusCode(resp) { + // the request timed out or is being throttled. + // DO NOT include this as a terminal failure. preserve + // the existing state and return the response. 
+ } else { + p.CurState = poller.StatusFailed + } + return p.CurState, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + +func (p *Poller[T]) Result(ctx context.Context, out *T) error { + return pollers.ResultHelper(p.resp, poller.Failed(p.CurState), out) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op/op.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op/op.go new file mode 100644 index 00000000000..ac1c0efb5ac --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op/op.go @@ -0,0 +1,145 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package op + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// Applicable returns true if the LRO is using Operation-Location. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderOperationLocation) != "" +} + +// CanResume returns true if the token can rehydrate this poller type. +func CanResume(token map[string]any) bool { + _, ok := token["oplocURL"] + return ok +} + +// Poller is an LRO poller that uses the Operation-Location pattern. +type Poller[T any] struct { + pl exported.Pipeline + resp *http.Response + + OpLocURL string `json:"oplocURL"` + LocURL string `json:"locURL"` + OrigURL string `json:"origURL"` + Method string `json:"method"` + FinalState pollers.FinalStateVia `json:"finalState"` + CurState string `json:"state"` +} + +// New creates a new Poller from the provided initial response. 
+// Pass nil for response to create an empty Poller for rehydration. +func New[T any](pl exported.Pipeline, resp *http.Response, finalState pollers.FinalStateVia) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Operation-Location poller.") + return &Poller[T]{pl: pl}, nil + } + log.Write(log.EventLRO, "Using Operation-Location poller.") + opURL := resp.Header.Get(shared.HeaderOperationLocation) + if opURL == "" { + return nil, errors.New("response is missing Operation-Location header") + } + if !poller.IsValidURL(opURL) { + return nil, fmt.Errorf("invalid Operation-Location URL %s", opURL) + } + locURL := resp.Header.Get(shared.HeaderLocation) + // Location header is optional + if locURL != "" && !poller.IsValidURL(locURL) { + return nil, fmt.Errorf("invalid Location URL %s", locURL) + } + // default initial state to InProgress. if the + // service sent us a status then use that instead. + curState := poller.StatusInProgress + status, err := poller.GetStatus(resp) + if err != nil && !errors.Is(err, poller.ErrNoBody) { + return nil, err + } + if status != "" { + curState = status + } + + return &Poller[T]{ + pl: pl, + resp: resp, + OpLocURL: opURL, + LocURL: locURL, + OrigURL: resp.Request.URL.String(), + Method: resp.Request.Method, + FinalState: finalState, + CurState: curState, + }, nil +} + +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.CurState) +} + +func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + err := pollers.PollHelper(ctx, p.OpLocURL, p.pl, func(resp *http.Response) (string, error) { + if !poller.StatusCodeValid(resp) { + p.resp = resp + return "", exported.NewResponseError(resp) + } + state, err := poller.GetStatus(resp) + if err != nil { + return "", err + } else if state == "" { + return "", errors.New("the response did not contain a status") + } + p.resp = resp + p.CurState = state + return p.CurState, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + 
+func (p *Poller[T]) Result(ctx context.Context, out *T) error { + var req *exported.Request + var err error + if p.FinalState == pollers.FinalStateViaLocation && p.LocURL != "" { + req, err = exported.NewRequest(ctx, http.MethodGet, p.LocURL) + } else if p.FinalState == pollers.FinalStateViaOpLocation && p.Method == http.MethodPost { + // no final GET required, terminal response should have it + } else if rl, rlErr := poller.GetResourceLocation(p.resp); rlErr != nil && !errors.Is(rlErr, poller.ErrNoBody) { + return rlErr + } else if rl != "" { + req, err = exported.NewRequest(ctx, http.MethodGet, rl) + } else if p.Method == http.MethodPatch || p.Method == http.MethodPut { + req, err = exported.NewRequest(ctx, http.MethodGet, p.OrigURL) + } else if p.Method == http.MethodPost && p.LocURL != "" { + req, err = exported.NewRequest(ctx, http.MethodGet, p.LocURL) + } + if err != nil { + return err + } + + // if a final GET request has been created, execute it + if req != nil { + resp, err := p.pl.Do(req) + if err != nil { + return err + } + p.resp = resp + } + + return pollers.ResultHelper(p.resp, poller.Failed(p.CurState), out) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/poller.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/poller.go new file mode 100644 index 00000000000..37ed647f4e0 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/poller.go @@ -0,0 +1,24 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +// FinalStateVia is the enumerated type for the possible final-state-via values. +type FinalStateVia string + +const ( + // FinalStateViaAzureAsyncOp indicates the final payload comes from the Azure-AsyncOperation URL. 
+ FinalStateViaAzureAsyncOp FinalStateVia = "azure-async-operation" + + // FinalStateViaLocation indicates the final payload comes from the Location URL. + FinalStateViaLocation FinalStateVia = "location" + + // FinalStateViaOriginalURI indicates the final payload comes from the original URL. + FinalStateViaOriginalURI FinalStateVia = "original-uri" + + // FinalStateViaOpLocation indicates the final payload comes from the Operation-Location URL. + FinalStateViaOpLocation FinalStateVia = "operation-location" +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/util.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/util.go new file mode 100644 index 00000000000..eb3cf651db0 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/util.go @@ -0,0 +1,200 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pollers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + + azexported "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// getTokenTypeName creates a type name from the type parameter T. +func getTokenTypeName[T any]() (string, error) { + tt := shared.TypeOfT[T]() + var n string + if tt.Kind() == reflect.Pointer { + n = "*" + tt = tt.Elem() + } + n += tt.Name() + if n == "" { + return "", errors.New("nameless types are not allowed") + } + return n, nil +} + +type resumeTokenWrapper[T any] struct { + Type string `json:"type"` + Token T `json:"token"` +} + +// NewResumeToken creates a resume token from the specified type. +// An error is returned if the generic type has no name (e.g. 
struct{}). +func NewResumeToken[TResult, TSource any](from TSource) (string, error) { + n, err := getTokenTypeName[TResult]() + if err != nil { + return "", err + } + b, err := json.Marshal(resumeTokenWrapper[TSource]{ + Type: n, + Token: from, + }) + if err != nil { + return "", err + } + return string(b), nil +} + +// ExtractToken returns the poller-specific token information from the provided token value. +func ExtractToken(token string) ([]byte, error) { + raw := map[string]json.RawMessage{} + if err := json.Unmarshal([]byte(token), &raw); err != nil { + return nil, err + } + // this is dependent on the type resumeTokenWrapper[T] + tk, ok := raw["token"] + if !ok { + return nil, errors.New("missing token value") + } + return tk, nil +} + +// IsTokenValid returns an error if the specified token isn't applicable for generic type T. +func IsTokenValid[T any](token string) error { + raw := map[string]any{} + if err := json.Unmarshal([]byte(token), &raw); err != nil { + return err + } + t, ok := raw["type"] + if !ok { + return errors.New("missing type value") + } + tt, ok := t.(string) + if !ok { + return fmt.Errorf("invalid type format %T", t) + } + n, err := getTokenTypeName[T]() + if err != nil { + return err + } + if tt != n { + return fmt.Errorf("cannot resume from this poller token. token is for type %s, not %s", tt, n) + } + return nil +} + +// used if the operation synchronously completed +type NopPoller[T any] struct { + resp *http.Response + result T +} + +// NewNopPoller creates a NopPoller from the provided response. +// It unmarshals the response body into an instance of T. 
+func NewNopPoller[T any](resp *http.Response) (*NopPoller[T], error) { + np := &NopPoller[T]{resp: resp} + if resp.StatusCode == http.StatusNoContent { + return np, nil + } + payload, err := exported.Payload(resp, nil) + if err != nil { + return nil, err + } + if len(payload) == 0 { + return np, nil + } + if err = json.Unmarshal(payload, &np.result); err != nil { + return nil, err + } + return np, nil +} + +func (*NopPoller[T]) Done() bool { + return true +} + +func (p *NopPoller[T]) Poll(context.Context) (*http.Response, error) { + return p.resp, nil +} + +func (p *NopPoller[T]) Result(ctx context.Context, out *T) error { + *out = p.result + return nil +} + +// PollHelper creates and executes the request, calling update() with the response. +// If the request fails, the update func is not called. +// The update func returns the state of the operation for logging purposes or an error +// if it fails to extract the required state from the response. +func PollHelper(ctx context.Context, endpoint string, pl azexported.Pipeline, update func(resp *http.Response) (string, error)) error { + req, err := azexported.NewRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return err + } + resp, err := pl.Do(req) + if err != nil { + return err + } + state, err := update(resp) + if err != nil { + return err + } + log.Writef(log.EventLRO, "State %s", state) + return nil +} + +// ResultHelper processes the response as success or failure. +// In the success case, it unmarshals the payload into either a new instance of T or out. +// In the failure case, it creates an *azcore.Response error from the response. +func ResultHelper[T any](resp *http.Response, failed bool, out *T) error { + // short-circuit the simple success case with no response body to unmarshal + if resp.StatusCode == http.StatusNoContent { + return nil + } + + defer resp.Body.Close() + if !poller.StatusCodeValid(resp) || failed { + // the LRO failed. 
unmarshall the error and update state + return azexported.NewResponseError(resp) + } + + // success case + payload, err := exported.Payload(resp, nil) + if err != nil { + return err + } + if len(payload) == 0 { + return nil + } + + if err = json.Unmarshal(payload, out); err != nil { + return err + } + return nil +} + +// IsNonTerminalHTTPStatusCode returns true if the HTTP status code should be +// considered non-terminal thus eligible for retry. +func IsNonTerminalHTTPStatusCode(resp *http.Response) bool { + return exported.HasStatusCode(resp, + http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + ) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/constants.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/constants.go new file mode 100644 index 00000000000..03691cbf024 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/constants.go @@ -0,0 +1,44 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package shared + +const ( + ContentTypeAppJSON = "application/json" + ContentTypeAppXML = "application/xml" + ContentTypeTextPlain = "text/plain" +) + +const ( + HeaderAuthorization = "Authorization" + HeaderAuxiliaryAuthorization = "x-ms-authorization-auxiliary" + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderFakePollerStatus = "Fake-Poller-Status" + HeaderLocation = "Location" + HeaderOperationLocation = "Operation-Location" + HeaderRetryAfter = "Retry-After" + HeaderRetryAfterMS = "Retry-After-Ms" + HeaderUserAgent = "User-Agent" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXMSClientRequestID = "x-ms-client-request-id" + HeaderXMSRequestID = "x-ms-request-id" + HeaderXMSErrorCode = "x-ms-error-code" + HeaderXMSRetryAfterMS = "x-ms-retry-after-ms" +) + +const BearerTokenPrefix = "Bearer " + +const TracingNamespaceAttrName = "az.namespace" + +const ( + // Module is the name of the calling module used in telemetry data. + Module = "azcore" + + // Version is the semantic version (see http://semver.org) of this module. + Version = "v1.11.1" +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/shared.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/shared.go new file mode 100644 index 00000000000..d3da2c5fdfa --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared/shared.go @@ -0,0 +1,149 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "context" + "fmt" + "net/http" + "reflect" + "regexp" + "strconv" + "time" +) + +// NOTE: when adding a new context key type, it likely needs to be +// added to the deny-list of key types in ContextWithDeniedValues + +// CtxWithHTTPHeaderKey is used as a context key for adding/retrieving http.Header. 
+type CtxWithHTTPHeaderKey struct{} + +// CtxWithRetryOptionsKey is used as a context key for adding/retrieving RetryOptions. +type CtxWithRetryOptionsKey struct{} + +// CtxWithCaptureResponse is used as a context key for retrieving the raw response. +type CtxWithCaptureResponse struct{} + +// CtxWithTracingTracer is used as a context key for adding/retrieving tracing.Tracer. +type CtxWithTracingTracer struct{} + +// CtxAPINameKey is used as a context key for adding/retrieving the API name. +type CtxAPINameKey struct{} + +// Delay waits for the duration to elapse or the context to be cancelled. +func Delay(ctx context.Context, delay time.Duration) error { + select { + case <-time.After(delay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// RetryAfter returns non-zero if the response contains one of the headers with a "retry after" value. +// Headers are checked in the following order: retry-after-ms, x-ms-retry-after-ms, retry-after +func RetryAfter(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + + type retryData struct { + header string + units time.Duration + + // custom is used when the regular algorithm failed and is optional. + // the returned duration is used verbatim (units is not applied). 
+ custom func(string) time.Duration + } + + nop := func(string) time.Duration { return 0 } + + // the headers are listed in order of preference + retries := []retryData{ + { + header: HeaderRetryAfterMS, + units: time.Millisecond, + custom: nop, + }, + { + header: HeaderXMSRetryAfterMS, + units: time.Millisecond, + custom: nop, + }, + { + header: HeaderRetryAfter, + units: time.Second, + + // retry-after values are expressed in either number of + // seconds or an HTTP-date indicating when to try again + custom: func(ra string) time.Duration { + t, err := time.Parse(time.RFC1123, ra) + if err != nil { + return 0 + } + return time.Until(t) + }, + }, + } + + for _, retry := range retries { + v := resp.Header.Get(retry.header) + if v == "" { + continue + } + if retryAfter, _ := strconv.Atoi(v); retryAfter > 0 { + return time.Duration(retryAfter) * retry.units + } else if d := retry.custom(v); d > 0 { + return d + } + } + + return 0 +} + +// TypeOfT returns the type of the generic type param. +func TypeOfT[T any]() reflect.Type { + // you can't, at present, obtain the type of + // a type parameter, so this is the trick + return reflect.TypeOf((*T)(nil)).Elem() +} + +// TransportFunc is a helper to use a first-class func to satisfy the Transporter interface. +type TransportFunc func(*http.Request) (*http.Response, error) + +// Do implements the Transporter interface for the TransportFunc type. +func (pf TransportFunc) Do(req *http.Request) (*http.Response, error) { + return pf(req) +} + +// ValidateModVer verifies that moduleVersion is a valid semver 2.0 string. +func ValidateModVer(moduleVersion string) error { + modVerRegx := regexp.MustCompile(`^v\d+\.\d+\.\d+(?:-[a-zA-Z0-9_.-]+)?$`) + if !modVerRegx.MatchString(moduleVersion) { + return fmt.Errorf("malformed moduleVersion param value %s", moduleVersion) + } + return nil +} + +// ContextWithDeniedValues wraps an existing [context.Context], denying access to certain context values. 
+// Pipeline policies that create new requests to be sent down their own pipeline MUST wrap the caller's +// context with an instance of this type. This is to prevent context values from flowing across disjoint +// requests which can have unintended side-effects. +type ContextWithDeniedValues struct { + context.Context +} + +// Value implements part of the [context.Context] interface. +// It acts as a deny-list for certain context keys. +func (c *ContextWithDeniedValues) Value(key any) any { + switch key.(type) { + case CtxAPINameKey, CtxWithCaptureResponse, CtxWithHTTPHeaderKey, CtxWithRetryOptionsKey, CtxWithTracingTracer: + return nil + default: + return c.Context.Value(key) + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/doc.go new file mode 100644 index 00000000000..2f3901bff3c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/doc.go @@ -0,0 +1,10 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. + +// Package log contains functionality for configuring logging behavior. +// Default logging to stderr can be enabled by setting environment variable AZURE_SDK_GO_LOGGING to "all". +package log diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/log.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/log.go new file mode 100644 index 00000000000..f260dac3637 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/log/log.go @@ -0,0 +1,55 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package log provides functionality for configuring logging facilities. 
+package log + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// Event is used to group entries. Each group can be toggled on or off. +type Event = log.Event + +const ( + // EventRequest entries contain information about HTTP requests. + // This includes information like the URL, query parameters, and headers. + EventRequest Event = "Request" + + // EventResponse entries contain information about HTTP responses. + // This includes information like the HTTP status code, headers, and request URL. + EventResponse Event = "Response" + + // EventResponseError entries contain information about HTTP responses that returned + // an *azcore.ResponseError (i.e. responses with a non 2xx HTTP status code). + // This includes the contents of ResponseError.Error(). + EventResponseError Event = "ResponseError" + + // EventRetryPolicy entries contain information specific to the retry policy in use. + EventRetryPolicy Event = "Retry" + + // EventLRO entries contain information specific to long-running operations. + // This includes information like polling location, operation state, and sleep intervals. + EventLRO Event = "LongRunningOperation" +) + +// SetEvents is used to control which events are written to +// the log. By default all log events are writen. +// NOTE: this is not goroutine safe and should be called before using SDK clients. +func SetEvents(cls ...Event) { + log.SetEvents(cls...) +} + +// SetListener will set the Logger to write to the specified Listener. +// NOTE: this is not goroutine safe and should be called before using SDK clients. 
+func SetListener(lst func(Event, string)) { + log.SetListener(lst) +} + +// for testing purposes +func resetEvents() { + log.TestResetEvents() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/doc.go new file mode 100644 index 00000000000..fad2579ed6c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/doc.go @@ -0,0 +1,10 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. + +// Package policy contains the definitions needed for configuring in-box pipeline policies +// and creating custom policies. +package policy diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/policy.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/policy.go new file mode 100644 index 00000000000..8d984535887 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/policy/policy.go @@ -0,0 +1,197 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "context" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +// Policy represents an extensibility point for the Pipeline that can mutate the specified +// Request and react to the received Response. +type Policy = exported.Policy + +// Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +type Transporter = exported.Transporter + +// Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. 
+// Don't use this type directly, use runtime.NewRequest() instead. +type Request = exported.Request + +// ClientOptions contains optional settings for a client's pipeline. +// Instances can be shared across calls to SDK client constructors when uniform configuration is desired. +// Zero-value fields will have their specified default values applied during use. +type ClientOptions struct { + // APIVersion overrides the default version requested of the service. + // Set with caution as this package version has not been tested with arbitrary service versions. + APIVersion string + + // Cloud specifies a cloud for the client. The default is Azure Public Cloud. + Cloud cloud.Configuration + + // InsecureAllowCredentialWithHTTP enables authenticated requests over HTTP. + // By default, authenticated requests to an HTTP endpoint are rejected by the client. + // WARNING: setting this to true will allow sending the credential in clear text. Use with caution. + InsecureAllowCredentialWithHTTP bool + + // Logging configures the built-in logging policy. + Logging LogOptions + + // Retry configures the built-in retry policy. + Retry RetryOptions + + // Telemetry configures the built-in telemetry policy. + Telemetry TelemetryOptions + + // TracingProvider configures the tracing provider. + // It defaults to a no-op tracer. + TracingProvider tracing.Provider + + // Transport sets the transport for HTTP requests. + Transport Transporter + + // PerCallPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request. + PerCallPolicies []Policy + + // PerRetryPolicies contains custom policies to inject into the pipeline. + // Each policy is executed once per request, and for each retry of that request. + PerRetryPolicies []Policy +} + +// LogOptions configures the logging policy's behavior. +type LogOptions struct { + // IncludeBody indicates if request and response bodies should be included in logging. + // The default value is false. 
+ // NOTE: enabling this can lead to disclosure of sensitive information, use with care. + IncludeBody bool + + // AllowedHeaders is the slice of headers to log with their values intact. + // All headers not in the slice will have their values REDACTED. + // Applies to request and response headers. + AllowedHeaders []string + + // AllowedQueryParams is the slice of query parameters to log with their values intact. + // All query parameters not in the slice will have their values REDACTED. + AllowedQueryParams []string +} + +// RetryOptions configures the retry policy's behavior. +// Zero-value fields will have their specified default values applied during use. +// This allows for modification of a subset of fields. +type RetryOptions struct { + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. + // The default value is three. A value less than zero means one try and no retries. + MaxRetries int32 + + // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. + // This is disabled by default. Specify a value greater than zero to enable. + // NOTE: Setting this to a small value might cause premature HTTP request time-outs. + TryTimeout time.Duration + + // RetryDelay specifies the initial amount of delay to use before retrying an operation. + // The value is used only if the HTTP response does not contain a Retry-After header. + // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. + // The default value is four seconds. A value less than zero means no delay between retries. + RetryDelay time.Duration + + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. + // Typically the value is greater than or equal to the value specified in RetryDelay. + // The default Value is 60 seconds. A value less than zero means there is no cap. 
+ MaxRetryDelay time.Duration + + // StatusCodes specifies the HTTP status codes that indicate the operation should be retried. + // A nil slice will use the following values. + // http.StatusRequestTimeout 408 + // http.StatusTooManyRequests 429 + // http.StatusInternalServerError 500 + // http.StatusBadGateway 502 + // http.StatusServiceUnavailable 503 + // http.StatusGatewayTimeout 504 + // Specifying values will replace the default values. + // Specifying an empty slice will disable retries for HTTP status codes. + StatusCodes []int + + // ShouldRetry evaluates if the retry policy should retry the request. + // When specified, the function overrides comparison against the list of + // HTTP status codes and error checking within the retry policy. Context + // and NonRetriable errors remain evaluated before calling ShouldRetry. + // The *http.Response and error parameters are mutually exclusive, i.e. + // if one is nil, the other is not nil. + // A return value of true means the retry policy should retry. + ShouldRetry func(*http.Response, error) bool +} + +// TelemetryOptions configures the telemetry policy's behavior. +type TelemetryOptions struct { + // ApplicationID is an application-specific identification string to add to the User-Agent. + // It has a maximum length of 24 characters and must not contain any spaces. + ApplicationID string + + // Disabled will prevent the addition of any telemetry data to the User-Agent. + Disabled bool +} + +// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. +type TokenRequestOptions = exported.TokenRequestOptions + +// BearerTokenOptions configures the bearer token policy's behavior. +type BearerTokenOptions struct { + // AuthorizationHandler allows SDK developers to run client-specific logic when BearerTokenPolicy must authorize a request. 
+ // When this field isn't set, the policy follows its default behavior of authorizing every request with a bearer token from + // its given credential. + AuthorizationHandler AuthorizationHandler + + // InsecureAllowCredentialWithHTTP enables authenticated requests over HTTP. + // By default, authenticated requests to an HTTP endpoint are rejected by the client. + // WARNING: setting this to true will allow sending the bearer token in clear text. Use with caution. + InsecureAllowCredentialWithHTTP bool +} + +// AuthorizationHandler allows SDK developers to insert custom logic that runs when BearerTokenPolicy must authorize a request. +type AuthorizationHandler struct { + // OnRequest is called each time the policy receives a request. Its func parameter authorizes the request with a token + // from the policy's given credential. Implementations that need to perform I/O should use the Request's context, + // available from Request.Raw().Context(). When OnRequest returns an error, the policy propagates that error and doesn't + // send the request. When OnRequest is nil, the policy follows its default behavior, authorizing the request with a + // token from its credential according to its configuration. + OnRequest func(*Request, func(TokenRequestOptions) error) error + + // OnChallenge is called when the policy receives a 401 response, allowing the AuthorizationHandler to re-authorize the + // request according to an authentication challenge (the Response's WWW-Authenticate header). OnChallenge is responsible + // for parsing parameters from the challenge. Its func parameter will authorize the request with a token from the policy's + // given credential. Implementations that need to perform I/O should use the Request's context, available from + // Request.Raw().Context(). When OnChallenge returns nil, the policy will send the request again. When OnChallenge is nil, + // the policy will return any 401 response to the client. 
+ OnChallenge func(*Request, *http.Response, func(TokenRequestOptions) error) error +} + +// WithCaptureResponse applies the HTTP response retrieval annotation to the parent context. +// The resp parameter will contain the HTTP response after the request has completed. +func WithCaptureResponse(parent context.Context, resp **http.Response) context.Context { + return context.WithValue(parent, shared.CtxWithCaptureResponse{}, resp) +} + +// WithHTTPHeader adds the specified http.Header to the parent context. +// Use this to specify custom HTTP headers at the API-call level. +// Any overlapping headers will have their values replaced with the values specified here. +func WithHTTPHeader(parent context.Context, header http.Header) context.Context { + return context.WithValue(parent, shared.CtxWithHTTPHeaderKey{}, header) +} + +// WithRetryOptions adds the specified RetryOptions to the parent context. +// Use this to specify custom RetryOptions at the API-call level. +func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { + return context.WithValue(parent, shared.CtxWithRetryOptionsKey{}, options) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/doc.go new file mode 100644 index 00000000000..c9cfa438cb3 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/doc.go @@ -0,0 +1,10 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. + +// Package runtime contains various facilities for creating requests and handling responses. +// The content is intended for SDK authors. 
+package runtime diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/errors.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/errors.go new file mode 100644 index 00000000000..c0d56158e22 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/errors.go @@ -0,0 +1,27 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" +) + +// NewResponseError creates an *azcore.ResponseError from the provided HTTP response. +// Call this when a service request returns a non-successful status code. +// The error code will be extracted from the *http.Response, either from the x-ms-error-code +// header (preferred) or attempted to be parsed from the response body. +func NewResponseError(resp *http.Response) error { + return exported.NewResponseError(resp) +} + +// NewResponseErrorWithErrorCode creates an *azcore.ResponseError from the provided HTTP response and errorCode. +// Use this variant when the error code is in a non-standard location. +func NewResponseErrorWithErrorCode(resp *http.Response, errorCode string) error { + return exported.NewResponseErrorWithErrorCode(resp, errorCode) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pager.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pager.go new file mode 100644 index 00000000000..cffe692d7e3 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pager.go @@ -0,0 +1,128 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package runtime + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +// PagingHandler contains the required data for constructing a Pager. +type PagingHandler[T any] struct { + // More returns a boolean indicating if there are more pages to fetch. + // It uses the provided page to make the determination. + More func(T) bool + + // Fetcher fetches the first and subsequent pages. + Fetcher func(context.Context, *T) (T, error) + + // Tracer contains the Tracer from the client that's creating the Pager. + Tracer tracing.Tracer +} + +// Pager provides operations for iterating over paged responses. +type Pager[T any] struct { + current *T + handler PagingHandler[T] + tracer tracing.Tracer + firstPage bool +} + +// NewPager creates an instance of Pager using the specified PagingHandler. +// Pass a non-nil T for firstPage if the first page has already been retrieved. +func NewPager[T any](handler PagingHandler[T]) *Pager[T] { + return &Pager[T]{ + handler: handler, + tracer: handler.Tracer, + firstPage: true, + } +} + +// More returns true if there are more pages to retrieve. +func (p *Pager[T]) More() bool { + if p.current != nil { + return p.handler.More(*p.current) + } + return true +} + +// NextPage advances the pager to the next page. 
+func (p *Pager[T]) NextPage(ctx context.Context) (T, error) { + if p.current != nil { + if p.firstPage { + // we get here if it's an LRO-pager, we already have the first page + p.firstPage = false + return *p.current, nil + } else if !p.handler.More(*p.current) { + return *new(T), errors.New("no more pages") + } + } else { + // non-LRO case, first page + p.firstPage = false + } + + var err error + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.NextPage", shortenTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + resp, err := p.handler.Fetcher(ctx, p.current) + if err != nil { + return *new(T), err + } + p.current = &resp + return *p.current, nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Pager[T]. +func (p *Pager[T]) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &p.current) +} + +// FetcherForNextLinkOptions contains the optional values for [FetcherForNextLink]. +type FetcherForNextLinkOptions struct { + // NextReq is the func to be called when requesting subsequent pages. + // Used for paged operations that have a custom next link operation. + NextReq func(context.Context, string) (*policy.Request, error) +} + +// FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL. +// - ctx is the [context.Context] controlling the lifetime of the HTTP operation +// - pl is the [Pipeline] used to dispatch the HTTP request +// - nextLink is the URL used to fetch the next page. 
the empty string indicates the first page is to be requested +// - firstReq is the func to be called when creating the request for the first page +// - options contains any optional parameters, pass nil to accept the default values +func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, firstReq func(context.Context) (*policy.Request, error), options *FetcherForNextLinkOptions) (*http.Response, error) { + var req *policy.Request + var err error + if nextLink == "" { + req, err = firstReq(ctx) + } else if nextLink, err = EncodeQueryParams(nextLink); err == nil { + if options != nil && options.NextReq != nil { + req, err = options.NextReq(ctx, nextLink) + } else { + req, err = NewRequest(ctx, http.MethodGet, nextLink) + } + } + if err != nil { + return nil, err + } + resp, err := pl.Do(req) + if err != nil { + return nil, err + } + if !HasStatusCode(resp, http.StatusOK) { + return nil, NewResponseError(resp) + } + return resp, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pipeline.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pipeline.go new file mode 100644 index 00000000000..6b1f5c083eb --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/pipeline.go @@ -0,0 +1,94 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// PipelineOptions contains Pipeline options for SDK developers +type PipelineOptions struct { + // AllowedHeaders is the slice of headers to log with their values intact. + // All headers not in the slice will have their values REDACTED. + // Applies to request and response headers. + AllowedHeaders []string + + // AllowedQueryParameters is the slice of query parameters to log with their values intact. 
+ // All query parameters not in the slice will have their values REDACTED. + AllowedQueryParameters []string + + // APIVersion overrides the default version requested of the service. + // Set with caution as this package version has not been tested with arbitrary service versions. + APIVersion APIVersionOptions + + // PerCall contains custom policies to inject into the pipeline. + // Each policy is executed once per request. + PerCall []policy.Policy + + // PerRetry contains custom policies to inject into the pipeline. + // Each policy is executed once per request, and for each retry of that request. + PerRetry []policy.Policy + + // Tracing contains options used to configure distributed tracing. + Tracing TracingOptions +} + +// TracingOptions contains tracing options for SDK developers. +type TracingOptions struct { + // Namespace contains the value to use for the az.namespace span attribute. + Namespace string +} + +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline = exported.Pipeline + +// NewPipeline creates a pipeline from connection options, with any additional policies as specified. +// Policies from ClientOptions are placed after policies from PipelineOptions. +// The module and version parameters are used by the telemetry policy, when enabled. +func NewPipeline(module, version string, plOpts PipelineOptions, options *policy.ClientOptions) Pipeline { + cp := policy.ClientOptions{} + if options != nil { + cp = *options + } + if len(plOpts.AllowedHeaders) > 0 { + headers := make([]string, len(plOpts.AllowedHeaders)+len(cp.Logging.AllowedHeaders)) + copy(headers, plOpts.AllowedHeaders) + headers = append(headers, cp.Logging.AllowedHeaders...) 
+ cp.Logging.AllowedHeaders = headers + } + if len(plOpts.AllowedQueryParameters) > 0 { + qp := make([]string, len(plOpts.AllowedQueryParameters)+len(cp.Logging.AllowedQueryParams)) + copy(qp, plOpts.AllowedQueryParameters) + qp = append(qp, cp.Logging.AllowedQueryParams...) + cp.Logging.AllowedQueryParams = qp + } + // we put the includeResponsePolicy at the very beginning so that the raw response + // is populated with the final response (some policies might mutate the response) + policies := []policy.Policy{exported.PolicyFunc(includeResponsePolicy)} + if cp.APIVersion != "" { + policies = append(policies, newAPIVersionPolicy(cp.APIVersion, &plOpts.APIVersion)) + } + if !cp.Telemetry.Disabled { + policies = append(policies, NewTelemetryPolicy(module, version, &cp.Telemetry)) + } + policies = append(policies, plOpts.PerCall...) + policies = append(policies, cp.PerCallPolicies...) + policies = append(policies, NewRetryPolicy(&cp.Retry)) + policies = append(policies, plOpts.PerRetry...) + policies = append(policies, cp.PerRetryPolicies...) + policies = append(policies, exported.PolicyFunc(httpHeaderPolicy)) + policies = append(policies, newHTTPTracePolicy(cp.Logging.AllowedQueryParams)) + policies = append(policies, NewLogPolicy(&cp.Logging)) + policies = append(policies, exported.PolicyFunc(bodyDownloadPolicy)) + transport := cp.Transport + if transport == nil { + transport = defaultHTTPClient + } + return exported.NewPipeline(transport, policies...) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_api_version.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_api_version.go new file mode 100644 index 00000000000..e5309aa6c15 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_api_version.go @@ -0,0 +1,75 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package runtime + +import ( + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// APIVersionOptions contains options for API versions +type APIVersionOptions struct { + // Location indicates where to set the version on a request, for example in a header or query param + Location APIVersionLocation + // Name is the name of the header or query parameter, for example "api-version" + Name string +} + +// APIVersionLocation indicates which part of a request identifies the service version +type APIVersionLocation int + +const ( + // APIVersionLocationQueryParam indicates a query parameter + APIVersionLocationQueryParam = 0 + // APIVersionLocationHeader indicates a header + APIVersionLocationHeader = 1 +) + +// newAPIVersionPolicy constructs an APIVersionPolicy. If version is "", Do will be a no-op. If version +// isn't empty and opts.Name is empty, Do will return an error. +func newAPIVersionPolicy(version string, opts *APIVersionOptions) *apiVersionPolicy { + if opts == nil { + opts = &APIVersionOptions{} + } + return &apiVersionPolicy{location: opts.Location, name: opts.Name, version: version} +} + +// apiVersionPolicy enables users to set the API version of every request a client sends. +type apiVersionPolicy struct { + // location indicates whether "name" refers to a query parameter or header. + location APIVersionLocation + + // name of the query param or header whose value should be overridden; provided by the client. + name string + + // version is the value (provided by the user) that replaces the default version value. + version string +} + +// Do sets the request's API version, if the policy is configured to do so, replacing any prior value. 
+func (a *apiVersionPolicy) Do(req *policy.Request) (*http.Response, error) { + if a.version != "" { + if a.name == "" { + // user set ClientOptions.APIVersion but the client ctor didn't set PipelineOptions.APIVersionOptions + return nil, errors.New("this client doesn't support overriding its API version") + } + switch a.location { + case APIVersionLocationHeader: + req.Raw().Header.Set(a.name, a.version) + case APIVersionLocationQueryParam: + q := req.Raw().URL.Query() + q.Set(a.name, a.version) + req.Raw().URL.RawQuery = q.Encode() + default: + return nil, fmt.Errorf("unknown APIVersionLocation %d", a.location) + } + } + return req.Next() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_bearer_token.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_bearer_token.go new file mode 100644 index 00000000000..cb2a6952805 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_bearer_token.go @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "errors" + "net/http" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" + "github.com/Azure/azure-sdk-for-go/sdk/internal/temporal" +) + +// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. 
+type BearerTokenPolicy struct { + // mainResource is the resource to be retreived using the tenant specified in the credential + mainResource *temporal.Resource[exported.AccessToken, acquiringResourceState] + // the following fields are read-only + authzHandler policy.AuthorizationHandler + cred exported.TokenCredential + scopes []string + allowHTTP bool +} + +type acquiringResourceState struct { + req *policy.Request + p *BearerTokenPolicy + tro policy.TokenRequestOptions +} + +// acquire acquires or updates the resource; only one +// thread/goroutine at a time ever calls this function +func acquire(state acquiringResourceState) (newResource exported.AccessToken, newExpiration time.Time, err error) { + tk, err := state.p.cred.GetToken(&shared.ContextWithDeniedValues{Context: state.req.Raw().Context()}, state.tro) + if err != nil { + return exported.AccessToken{}, time.Time{}, err + } + return tk, tk.ExpiresOn, nil +} + +// NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. +// cred: an azcore.TokenCredential implementation such as a credential object from azidentity +// scopes: the list of permission scopes required for the token. +// opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. 
+func NewBearerTokenPolicy(cred exported.TokenCredential, scopes []string, opts *policy.BearerTokenOptions) *BearerTokenPolicy { + if opts == nil { + opts = &policy.BearerTokenOptions{} + } + return &BearerTokenPolicy{ + authzHandler: opts.AuthorizationHandler, + cred: cred, + scopes: scopes, + mainResource: temporal.NewResource(acquire), + allowHTTP: opts.InsecureAllowCredentialWithHTTP, + } +} + +// authenticateAndAuthorize returns a function which authorizes req with a token from the policy's credential +func (b *BearerTokenPolicy) authenticateAndAuthorize(req *policy.Request) func(policy.TokenRequestOptions) error { + return func(tro policy.TokenRequestOptions) error { + as := acquiringResourceState{p: b, req: req, tro: tro} + tk, err := b.mainResource.Get(as) + if err != nil { + return err + } + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token) + return nil + } +} + +// Do authorizes a request with a bearer token +func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { + // skip adding the authorization header if no TokenCredential was provided. + // this prevents a panic that might be hard to diagnose and allows testing + // against http endpoints that don't require authentication. 
+ if b.cred == nil { + return req.Next() + } + + if err := checkHTTPSForAuth(req, b.allowHTTP); err != nil { + return nil, err + } + + var err error + if b.authzHandler.OnRequest != nil { + err = b.authzHandler.OnRequest(req, b.authenticateAndAuthorize(req)) + } else { + err = b.authenticateAndAuthorize(req)(policy.TokenRequestOptions{Scopes: b.scopes}) + } + if err != nil { + return nil, errorinfo.NonRetriableError(err) + } + + res, err := req.Next() + if err != nil { + return nil, err + } + + if res.StatusCode == http.StatusUnauthorized { + b.mainResource.Expire() + if res.Header.Get("WWW-Authenticate") != "" && b.authzHandler.OnChallenge != nil { + if err = b.authzHandler.OnChallenge(req, res, b.authenticateAndAuthorize(req)); err == nil { + res, err = req.Next() + } + } + } + if err != nil { + err = errorinfo.NonRetriableError(err) + } + return res, err +} + +func checkHTTPSForAuth(req *policy.Request, allowHTTP bool) error { + if strings.ToLower(req.Raw().URL.Scheme) != "https" && !allowHTTP { + return errorinfo.NonRetriableError(errors.New("authenticated requests are not permitted for non TLS protected (https) endpoints")) + } + return nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_body_download.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_body_download.go new file mode 100644 index 00000000000..99dc029f0c1 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_body_download.go @@ -0,0 +1,72 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "fmt" + "net/http" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +// bodyDownloadPolicy creates a policy object that downloads the response's body to a []byte. 
+func bodyDownloadPolicy(req *policy.Request) (*http.Response, error) { + resp, err := req.Next() + if err != nil { + return resp, err + } + var opValues bodyDownloadPolicyOpValues + // don't skip downloading error response bodies + if req.OperationValue(&opValues); opValues.Skip && resp.StatusCode < 400 { + return resp, err + } + // Either bodyDownloadPolicyOpValues was not specified (so skip is false) + // or it was specified and skip is false: don't skip downloading the body + _, err = Payload(resp) + if err != nil { + return resp, newBodyDownloadError(err, req) + } + return resp, err +} + +// bodyDownloadPolicyOpValues is the struct containing the per-operation values +type bodyDownloadPolicyOpValues struct { + Skip bool +} + +type bodyDownloadError struct { + err error +} + +func newBodyDownloadError(err error, req *policy.Request) error { + // on failure, only retry the request for idempotent operations. + // we currently identify them as DELETE, GET, and PUT requests. + if m := strings.ToUpper(req.Raw().Method); m == http.MethodDelete || m == http.MethodGet || m == http.MethodPut { + // error is safe for retry + return err + } + // wrap error to avoid retries + return &bodyDownloadError{ + err: err, + } +} + +func (b *bodyDownloadError) Error() string { + return fmt.Sprintf("body download policy: %s", b.err.Error()) +} + +func (b *bodyDownloadError) NonRetriable() { + // marker method +} + +func (b *bodyDownloadError) Unwrap() error { + return b.err +} + +var _ errorinfo.NonRetriable = (*bodyDownloadError)(nil) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_header.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_header.go new file mode 100644 index 00000000000..c230af0afa8 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_header.go @@ -0,0 +1,40 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +package runtime + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request +func httpHeaderPolicy(req *policy.Request) (*http.Response, error) { + // check if any custom HTTP headers have been specified + if header := req.Raw().Context().Value(shared.CtxWithHTTPHeaderKey{}); header != nil { + for k, v := range header.(http.Header) { + // use Set to replace any existing value + // it also canonicalizes the header key + req.Raw().Header.Set(k, v[0]) + // add any remaining values + for i := 1; i < len(v); i++ { + req.Raw().Header.Add(k, v[i]) + } + } + } + return req.Next() +} + +// WithHTTPHeader adds the specified http.Header to the parent context. +// Use this to specify custom HTTP headers at the API-call level. +// Any overlapping headers will have their values replaced with the values specified here. +// Deprecated: use [policy.WithHTTPHeader] instead. +func WithHTTPHeader(parent context.Context, header http.Header) context.Context { + return policy.WithHTTPHeader(parent, header) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_trace.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_trace.go new file mode 100644 index 00000000000..3df1c121890 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_http_trace.go @@ -0,0 +1,143 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package runtime + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" +) + +const ( + attrHTTPMethod = "http.method" + attrHTTPURL = "http.url" + attrHTTPUserAgent = "http.user_agent" + attrHTTPStatusCode = "http.status_code" + + attrAZClientReqID = "az.client_request_id" + attrAZServiceReqID = "az.service_request_id" + + attrNetPeerName = "net.peer.name" +) + +// newHTTPTracePolicy creates a new instance of the httpTracePolicy. +// - allowedQueryParams contains the user-specified query parameters that don't need to be redacted from the trace +func newHTTPTracePolicy(allowedQueryParams []string) exported.Policy { + return &httpTracePolicy{allowedQP: getAllowedQueryParams(allowedQueryParams)} +} + +// httpTracePolicy is a policy that creates a trace for the HTTP request and its response +type httpTracePolicy struct { + allowedQP map[string]struct{} +} + +// Do implements the pipeline.Policy interfaces for the httpTracePolicy type. 
+func (h *httpTracePolicy) Do(req *policy.Request) (resp *http.Response, err error) { + rawTracer := req.Raw().Context().Value(shared.CtxWithTracingTracer{}) + if tracer, ok := rawTracer.(tracing.Tracer); ok && tracer.Enabled() { + attributes := []tracing.Attribute{ + {Key: attrHTTPMethod, Value: req.Raw().Method}, + {Key: attrHTTPURL, Value: getSanitizedURL(*req.Raw().URL, h.allowedQP)}, + {Key: attrNetPeerName, Value: req.Raw().URL.Host}, + } + + if ua := req.Raw().Header.Get(shared.HeaderUserAgent); ua != "" { + attributes = append(attributes, tracing.Attribute{Key: attrHTTPUserAgent, Value: ua}) + } + if reqID := req.Raw().Header.Get(shared.HeaderXMSClientRequestID); reqID != "" { + attributes = append(attributes, tracing.Attribute{Key: attrAZClientReqID, Value: reqID}) + } + + ctx := req.Raw().Context() + ctx, span := tracer.Start(ctx, "HTTP "+req.Raw().Method, &tracing.SpanOptions{ + Kind: tracing.SpanKindClient, + Attributes: attributes, + }) + + defer func() { + if resp != nil { + span.SetAttributes(tracing.Attribute{Key: attrHTTPStatusCode, Value: resp.StatusCode}) + if resp.StatusCode > 399 { + span.SetStatus(tracing.SpanStatusError, resp.Status) + } + if reqID := resp.Header.Get(shared.HeaderXMSRequestID); reqID != "" { + span.SetAttributes(tracing.Attribute{Key: attrAZServiceReqID, Value: reqID}) + } + } else if err != nil { + var urlErr *url.Error + if errors.As(err, &urlErr) { + // calling *url.Error.Error() will include the unsanitized URL + // which we don't want. in addition, we already have the HTTP verb + // and sanitized URL in the trace so we aren't losing any info + err = urlErr.Err + } + span.SetStatus(tracing.SpanStatusError, err.Error()) + } + span.End() + }() + + req = req.WithContext(ctx) + } + resp, err = req.Next() + return +} + +// StartSpanOptions contains the optional values for StartSpan. +type StartSpanOptions struct { + // for future expansion +} + +// StartSpan starts a new tracing span. 
+// You must call the returned func to terminate the span. Pass the applicable error +// if the span will exit with an error condition. +// - ctx is the parent context of the newly created context +// - name is the name of the span. this is typically the fully qualified name of an API ("Client.Method") +// - tracer is the client's Tracer for creating spans +// - options contains optional values. pass nil to accept any default values +func StartSpan(ctx context.Context, name string, tracer tracing.Tracer, options *StartSpanOptions) (context.Context, func(error)) { + if !tracer.Enabled() { + return ctx, func(err error) {} + } + + // we MUST propagate the active tracer before returning so that the trace policy can access it + ctx = context.WithValue(ctx, shared.CtxWithTracingTracer{}, tracer) + + const newSpanKind = tracing.SpanKindInternal + if activeSpan := ctx.Value(ctxActiveSpan{}); activeSpan != nil { + // per the design guidelines, if a SDK method Foo() calls SDK method Bar(), + // then the span for Bar() must be suppressed. however, if Bar() makes a REST + // call, then Bar's HTTP span must be a child of Foo's span. + // however, there is an exception to this rule. if the SDK method Foo() is a + // messaging producer/consumer, and it takes a callback that's a SDK method + // Bar(), then the span for Bar() must _not_ be suppressed. + if kind := activeSpan.(tracing.SpanKind); kind == tracing.SpanKindClient || kind == tracing.SpanKindInternal { + return ctx, func(err error) {} + } + } + ctx, span := tracer.Start(ctx, name, &tracing.SpanOptions{ + Kind: newSpanKind, + }) + ctx = context.WithValue(ctx, ctxActiveSpan{}, newSpanKind) + return ctx, func(err error) { + if err != nil { + errType := strings.Replace(fmt.Sprintf("%T", err), "*exported.", "*azcore.", 1) + span.SetStatus(tracing.SpanStatusError, fmt.Sprintf("%s:\n%s", errType, err.Error())) + } + span.End() + } +} + +// ctxActiveSpan is used as a context key for indicating a SDK client span is in progress. 
+type ctxActiveSpan struct{} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_include_response.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_include_response.go new file mode 100644 index 00000000000..bb00f6c2fdb --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_include_response.go @@ -0,0 +1,35 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// includeResponsePolicy creates a policy that retrieves the raw HTTP response upon request +func includeResponsePolicy(req *policy.Request) (*http.Response, error) { + resp, err := req.Next() + if resp == nil { + return resp, err + } + if httpOutRaw := req.Raw().Context().Value(shared.CtxWithCaptureResponse{}); httpOutRaw != nil { + httpOut := httpOutRaw.(**http.Response) + *httpOut = resp + } + return resp, err +} + +// WithCaptureResponse applies the HTTP response retrieval annotation to the parent context. +// The resp parameter will contain the HTTP response after the request has completed. +// Deprecated: use [policy.WithCaptureResponse] instead. +func WithCaptureResponse(parent context.Context, resp **http.Response) context.Context { + return policy.WithCaptureResponse(parent, resp) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_key_credential.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_key_credential.go new file mode 100644 index 00000000000..eeb1c09cc12 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_key_credential.go @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// KeyCredentialPolicy authorizes requests with a [azcore.KeyCredential]. +type KeyCredentialPolicy struct { + cred *exported.KeyCredential + header string + prefix string + allowHTTP bool +} + +// KeyCredentialPolicyOptions contains the optional values configuring [KeyCredentialPolicy]. +type KeyCredentialPolicyOptions struct { + // InsecureAllowCredentialWithHTTP enables authenticated requests over HTTP. + // By default, authenticated requests to an HTTP endpoint are rejected by the client. + // WARNING: setting this to true will allow sending the authentication key in clear text. Use with caution. + InsecureAllowCredentialWithHTTP bool + + // Prefix is used if the key requires a prefix before it's inserted into the HTTP request. + Prefix string +} + +// NewKeyCredentialPolicy creates a new instance of [KeyCredentialPolicy]. +// - cred is the [azcore.KeyCredential] used to authenticate with the service +// - header is the name of the HTTP request header in which the key is placed +// - options contains optional configuration, pass nil to accept the default values +func NewKeyCredentialPolicy(cred *exported.KeyCredential, header string, options *KeyCredentialPolicyOptions) *KeyCredentialPolicy { + if options == nil { + options = &KeyCredentialPolicyOptions{} + } + return &KeyCredentialPolicy{ + cred: cred, + header: header, + prefix: options.Prefix, + allowHTTP: options.InsecureAllowCredentialWithHTTP, + } +} + +// Do implementes the Do method on the [policy.Polilcy] interface. +func (k *KeyCredentialPolicy) Do(req *policy.Request) (*http.Response, error) { + // skip adding the authorization header if no KeyCredential was provided. + // this prevents a panic that might be hard to diagnose and allows testing + // against http endpoints that don't require authentication. 
+ if k.cred != nil { + if err := checkHTTPSForAuth(req, k.allowHTTP); err != nil { + return nil, err + } + val := exported.KeyCredentialGet(k.cred) + if k.prefix != "" { + val = k.prefix + val + } + req.Raw().Header.Add(k.header, val) + } + return req.Next() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_logging.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_logging.go new file mode 100644 index 00000000000..f048d7fb53f --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_logging.go @@ -0,0 +1,264 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/diag" +) + +type logPolicy struct { + includeBody bool + allowedHeaders map[string]struct{} + allowedQP map[string]struct{} +} + +// NewLogPolicy creates a request/response logging policy object configured using the specified options. +// Pass nil to accept the default values; this is the same as passing a zero-value options. 
+func NewLogPolicy(o *policy.LogOptions) policy.Policy { + if o == nil { + o = &policy.LogOptions{} + } + // construct default hash set of allowed headers + allowedHeaders := map[string]struct{}{ + "accept": {}, + "cache-control": {}, + "connection": {}, + "content-length": {}, + "content-type": {}, + "date": {}, + "etag": {}, + "expires": {}, + "if-match": {}, + "if-modified-since": {}, + "if-none-match": {}, + "if-unmodified-since": {}, + "last-modified": {}, + "ms-cv": {}, + "pragma": {}, + "request-id": {}, + "retry-after": {}, + "server": {}, + "traceparent": {}, + "transfer-encoding": {}, + "user-agent": {}, + "www-authenticate": {}, + "x-ms-request-id": {}, + "x-ms-client-request-id": {}, + "x-ms-return-client-request-id": {}, + } + // add any caller-specified allowed headers to the set + for _, ah := range o.AllowedHeaders { + allowedHeaders[strings.ToLower(ah)] = struct{}{} + } + // now do the same thing for query params + allowedQP := getAllowedQueryParams(o.AllowedQueryParams) + return &logPolicy{ + includeBody: o.IncludeBody, + allowedHeaders: allowedHeaders, + allowedQP: allowedQP, + } +} + +// getAllowedQueryParams merges the default set of allowed query parameters +// with a custom set (usually comes from client options). +func getAllowedQueryParams(customAllowedQP []string) map[string]struct{} { + allowedQP := map[string]struct{}{ + "api-version": {}, + } + for _, qp := range customAllowedQP { + allowedQP[strings.ToLower(qp)] = struct{}{} + } + return allowedQP +} + +// logPolicyOpValues is the struct containing the per-operation values +type logPolicyOpValues struct { + try int32 + start time.Time +} + +func (p *logPolicy) Do(req *policy.Request) (*http.Response, error) { + // Get the per-operation values. These are saved in the Message's map so that they persist across each retry calling into this policy object. 
+ var opValues logPolicyOpValues + if req.OperationValue(&opValues); opValues.start.IsZero() { + opValues.start = time.Now() // If this is the 1st try, record this operation's start time + } + opValues.try++ // The first try is #1 (not #0) + req.SetOperationValue(opValues) + + // Log the outgoing request as informational + if log.Should(log.EventRequest) { + b := &bytes.Buffer{} + fmt.Fprintf(b, "==> OUTGOING REQUEST (Try=%d)\n", opValues.try) + p.writeRequestWithResponse(b, req, nil, nil) + var err error + if p.includeBody { + err = writeReqBody(req, b) + } + log.Write(log.EventRequest, b.String()) + if err != nil { + return nil, err + } + } + + // Set the time for this particular retry operation and then Do the operation. + tryStart := time.Now() + response, err := req.Next() // Make the request + tryEnd := time.Now() + tryDuration := tryEnd.Sub(tryStart) + opDuration := tryEnd.Sub(opValues.start) + + if log.Should(log.EventResponse) { + // We're going to log this; build the string to log + b := &bytes.Buffer{} + fmt.Fprintf(b, "==> REQUEST/RESPONSE (Try=%d/%v, OpTime=%v) -- ", opValues.try, tryDuration, opDuration) + if err != nil { // This HTTP request did not get a response from the service + fmt.Fprint(b, "REQUEST ERROR\n") + } else { + fmt.Fprint(b, "RESPONSE RECEIVED\n") + } + + p.writeRequestWithResponse(b, req, response, err) + if err != nil { + // skip frames runtime.Callers() and runtime.StackTrace() + b.WriteString(diag.StackTrace(2, 32)) + } else if p.includeBody { + err = writeRespBody(response, b) + } + log.Write(log.EventResponse, b.String()) + } + return response, err +} + +const redactedValue = "REDACTED" + +// getSanitizedURL returns a sanitized string for the provided url.URL +func getSanitizedURL(u url.URL, allowedQueryParams map[string]struct{}) string { + // redact applicable query params + qp := u.Query() + for k := range qp { + if _, ok := allowedQueryParams[strings.ToLower(k)]; !ok { + qp.Set(k, redactedValue) + } + } + u.RawQuery = 
qp.Encode() + return u.String() +} + +// writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are +// not nil, then these are also written into the Buffer. +func (p *logPolicy) writeRequestWithResponse(b *bytes.Buffer, req *policy.Request, resp *http.Response, err error) { + // Write the request into the buffer. + fmt.Fprint(b, " "+req.Raw().Method+" "+getSanitizedURL(*req.Raw().URL, p.allowedQP)+"\n") + p.writeHeader(b, req.Raw().Header) + if resp != nil { + fmt.Fprintln(b, " --------------------------------------------------------------------------------") + fmt.Fprint(b, " RESPONSE Status: "+resp.Status+"\n") + p.writeHeader(b, resp.Header) + } + if err != nil { + fmt.Fprintln(b, " --------------------------------------------------------------------------------") + fmt.Fprint(b, " ERROR:\n"+err.Error()+"\n") + } +} + +// formatHeaders appends an HTTP request's or response's header into a Buffer. +func (p *logPolicy) writeHeader(b *bytes.Buffer, header http.Header) { + if len(header) == 0 { + b.WriteString(" (no headers)\n") + return + } + keys := make([]string, 0, len(header)) + // Alphabetize the headers + for k := range header { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + // don't use Get() as it will canonicalize k which might cause a mismatch + value := header[k][0] + // redact all header values not in the allow-list + if _, ok := p.allowedHeaders[strings.ToLower(k)]; !ok { + value = redactedValue + } + fmt.Fprintf(b, " %s: %+v\n", k, value) + } +} + +// returns true if the request/response body should be logged. +// this is determined by looking at the content-type header value. 
+func shouldLogBody(b *bytes.Buffer, contentType string) bool { + contentType = strings.ToLower(contentType) + if strings.HasPrefix(contentType, "text") || + strings.Contains(contentType, "json") || + strings.Contains(contentType, "xml") { + return true + } + fmt.Fprintf(b, " Skip logging body for %s\n", contentType) + return false +} + +// writes to a buffer, used for logging purposes +func writeReqBody(req *policy.Request, b *bytes.Buffer) error { + if req.Raw().Body == nil { + fmt.Fprint(b, " Request contained no body\n") + return nil + } + if ct := req.Raw().Header.Get(shared.HeaderContentType); !shouldLogBody(b, ct) { + return nil + } + body, err := io.ReadAll(req.Raw().Body) + if err != nil { + fmt.Fprintf(b, " Failed to read request body: %s\n", err.Error()) + return err + } + if err := req.RewindBody(); err != nil { + return err + } + logBody(b, body) + return nil +} + +// writes to a buffer, used for logging purposes +func writeRespBody(resp *http.Response, b *bytes.Buffer) error { + ct := resp.Header.Get(shared.HeaderContentType) + if ct == "" { + fmt.Fprint(b, " Response contained no body\n") + return nil + } else if !shouldLogBody(b, ct) { + return nil + } + body, err := Payload(resp) + if err != nil { + fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) + return err + } + if len(body) > 0 { + logBody(b, body) + } else { + fmt.Fprint(b, " Response contained no body\n") + } + return nil +} + +func logBody(b *bytes.Buffer, body []byte) { + fmt.Fprintln(b, " --------------------------------------------------------------------------------") + fmt.Fprintln(b, string(body)) + fmt.Fprintln(b, " --------------------------------------------------------------------------------") +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_request_id.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_request_id.go new file mode 100644 index 00000000000..360a7f2118a --- /dev/null +++ 
b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_request_id.go @@ -0,0 +1,34 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" +) + +type requestIDPolicy struct{} + +// NewRequestIDPolicy returns a policy that add the x-ms-client-request-id header +func NewRequestIDPolicy() policy.Policy { + return &requestIDPolicy{} +} + +func (r *requestIDPolicy) Do(req *policy.Request) (*http.Response, error) { + if req.Raw().Header.Get(shared.HeaderXMSClientRequestID) == "" { + id, err := uuid.New() + if err != nil { + return nil, err + } + req.Raw().Header.Set(shared.HeaderXMSClientRequestID, id.String()) + } + + return req.Next() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_retry.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_retry.go new file mode 100644 index 00000000000..04d7bb4ecbc --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_retry.go @@ -0,0 +1,255 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+package runtime
+
+import (
+	"context"
+	"errors"
+	"io"
+	"math"
+	"math/rand"
+	"net/http"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+	"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
+	"github.com/Azure/azure-sdk-for-go/sdk/internal/exported"
+)
+
+const (
+	defaultMaxRetries = 3
+)
+
+func setDefaults(o *policy.RetryOptions) {
+	if o.MaxRetries == 0 {
+		o.MaxRetries = defaultMaxRetries
+	} else if o.MaxRetries < 0 {
+		o.MaxRetries = 0
+	}
+
+	// SDK guidelines specify the default MaxRetryDelay is 60 seconds
+	if o.MaxRetryDelay == 0 {
+		o.MaxRetryDelay = 60 * time.Second
+	} else if o.MaxRetryDelay < 0 {
+		// not really an unlimited cap, but sufficiently large enough to be considered as such
+		o.MaxRetryDelay = math.MaxInt64
+	}
+	if o.RetryDelay == 0 {
+		o.RetryDelay = 800 * time.Millisecond
+	} else if o.RetryDelay < 0 {
+		o.RetryDelay = 0
+	}
+	if o.StatusCodes == nil {
+		// NOTE: if you change this list, you MUST update the docs in policy/policy.go
+		o.StatusCodes = []int{
+			http.StatusRequestTimeout,      // 408
+			http.StatusTooManyRequests,     // 429
+			http.StatusInternalServerError, // 500
+			http.StatusBadGateway,          // 502
+			http.StatusServiceUnavailable,  // 503
+			http.StatusGatewayTimeout,      // 504
+		}
+	}
+}
+
+func calcDelay(o policy.RetryOptions, try int32) time.Duration { // try is >=1; never 0
+	delay := time.Duration((1<<try)-1) * o.RetryDelay
+
+	// Introduce some jitter:  [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3)
+	delay = time.Duration(delay.Seconds() * (rand.Float64()/2 + 0.8) * float64(time.Second)) // NOTE: We want math/rand; not crypto/rand
+	if delay > o.MaxRetryDelay {
+		delay = o.MaxRetryDelay
+	}
+	return delay
+}
+
+// NewRetryPolicy creates a policy object configured using the specified options.
+// Pass nil to accept the default values; this is the same as passing a zero-value options.
+func NewRetryPolicy(o *policy.RetryOptions) policy.Policy { + if o == nil { + o = &policy.RetryOptions{} + } + p := &retryPolicy{options: *o} + return p +} + +type retryPolicy struct { + options policy.RetryOptions +} + +func (p *retryPolicy) Do(req *policy.Request) (resp *http.Response, err error) { + options := p.options + // check if the retry options have been overridden for this call + if override := req.Raw().Context().Value(shared.CtxWithRetryOptionsKey{}); override != nil { + options = override.(policy.RetryOptions) + } + setDefaults(&options) + // Exponential retry algorithm: ((2 ^ attempt) - 1) * delay * random(0.8, 1.2) + // When to retry: connection failure or temporary/timeout. + var rwbody *retryableRequestBody + if req.Body() != nil { + // wrap the body so we control when it's actually closed. + // do this outside the for loop so defers don't accumulate. + rwbody = &retryableRequestBody{body: req.Body()} + defer rwbody.realClose() + } + try := int32(1) + for { + resp = nil // reset + log.Writef(log.EventRetryPolicy, "=====> Try=%d", try) + + // For each try, seek to the beginning of the Body stream. We do this even for the 1st try because + // the stream may not be at offset 0 when we first get it and we want the same behavior for the + // 1st try as for additional tries. + err = req.RewindBody() + if err != nil { + return + } + // RewindBody() restores Raw().Body to its original state, so set our rewindable after + if rwbody != nil { + req.Raw().Body = rwbody + } + + if options.TryTimeout == 0 { + clone := req.Clone(req.Raw().Context()) + resp, err = clone.Next() + } else { + // Set the per-try time for this particular retry operation and then Do the operation. 
+ tryCtx, tryCancel := context.WithTimeout(req.Raw().Context(), options.TryTimeout) + clone := req.Clone(tryCtx) + resp, err = clone.Next() // Make the request + // if the body was already downloaded or there was an error it's safe to cancel the context now + if err != nil { + tryCancel() + } else if exported.PayloadDownloaded(resp) { + tryCancel() + } else { + // must cancel the context after the body has been read and closed + resp.Body = &contextCancelReadCloser{cf: tryCancel, body: resp.Body} + } + } + if err == nil { + log.Writef(log.EventRetryPolicy, "response %d", resp.StatusCode) + } else { + log.Writef(log.EventRetryPolicy, "error %v", err) + } + + if ctxErr := req.Raw().Context().Err(); ctxErr != nil { + // don't retry if the parent context has been cancelled or its deadline exceeded + err = ctxErr + log.Writef(log.EventRetryPolicy, "abort due to %v", err) + return + } + + // check if the error is not retriable + var nre errorinfo.NonRetriable + if errors.As(err, &nre) { + // the error says it's not retriable so don't retry + log.Writef(log.EventRetryPolicy, "non-retriable error %T", nre) + return + } + + if options.ShouldRetry != nil { + // a non-nil ShouldRetry overrides our HTTP status code check + if !options.ShouldRetry(resp, err) { + // predicate says we shouldn't retry + log.Write(log.EventRetryPolicy, "exit due to ShouldRetry") + return + } + } else if err == nil && !HasStatusCode(resp, options.StatusCodes...) { + // if there is no error and the response code isn't in the list of retry codes then we're done. 
+ log.Write(log.EventRetryPolicy, "exit due to non-retriable status code") + return + } + + if try == options.MaxRetries+1 { + // max number of tries has been reached, don't sleep again + log.Writef(log.EventRetryPolicy, "MaxRetries %d exceeded", options.MaxRetries) + return + } + + // use the delay from retry-after if available + delay := shared.RetryAfter(resp) + if delay <= 0 { + delay = calcDelay(options, try) + } else if delay > options.MaxRetryDelay { + // the retry-after delay exceeds the the cap so don't retry + log.Writef(log.EventRetryPolicy, "Retry-After delay %s exceeds MaxRetryDelay of %s", delay, options.MaxRetryDelay) + return + } + + // drain before retrying so nothing is leaked + Drain(resp) + + log.Writef(log.EventRetryPolicy, "End Try #%d, Delay=%v", try, delay) + select { + case <-time.After(delay): + try++ + case <-req.Raw().Context().Done(): + err = req.Raw().Context().Err() + log.Writef(log.EventRetryPolicy, "abort due to %v", err) + return + } + } +} + +// WithRetryOptions adds the specified RetryOptions to the parent context. +// Use this to specify custom RetryOptions at the API-call level. +// Deprecated: use [policy.WithRetryOptions] instead. 
+func WithRetryOptions(parent context.Context, options policy.RetryOptions) context.Context { + return policy.WithRetryOptions(parent, options) +} + +// ********** The following type/methods implement the retryableRequestBody (a ReadSeekCloser) + +// This struct is used when sending a body to the network +type retryableRequestBody struct { + body io.ReadSeeker // Seeking is required to support retries +} + +// Read reads a block of data from an inner stream and reports progress +func (b *retryableRequestBody) Read(p []byte) (n int, err error) { + return b.body.Read(p) +} + +func (b *retryableRequestBody) Seek(offset int64, whence int) (offsetFromStart int64, err error) { + return b.body.Seek(offset, whence) +} + +func (b *retryableRequestBody) Close() error { + // We don't want the underlying transport to close the request body on transient failures so this is a nop. + // The retry policy closes the request body upon success. + return nil +} + +func (b *retryableRequestBody) realClose() error { + if c, ok := b.body.(io.Closer); ok { + return c.Close() + } + return nil +} + +// ********** The following type/methods implement the contextCancelReadCloser + +// contextCancelReadCloser combines an io.ReadCloser with a cancel func. +// it ensures the cancel func is invoked once the body has been read and closed. 
+type contextCancelReadCloser struct { + cf context.CancelFunc + body io.ReadCloser +} + +func (rc *contextCancelReadCloser) Read(p []byte) (n int, err error) { + return rc.body.Read(p) +} + +func (rc *contextCancelReadCloser) Close() error { + err := rc.body.Close() + rc.cf() + return err +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_sas_credential.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_sas_credential.go new file mode 100644 index 00000000000..3964beea862 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_sas_credential.go @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// SASCredentialPolicy authorizes requests with a [azcore.SASCredential]. +type SASCredentialPolicy struct { + cred *exported.SASCredential + header string + allowHTTP bool +} + +// SASCredentialPolicyOptions contains the optional values configuring [SASCredentialPolicy]. +type SASCredentialPolicyOptions struct { + // InsecureAllowCredentialWithHTTP enables authenticated requests over HTTP. + // By default, authenticated requests to an HTTP endpoint are rejected by the client. + // WARNING: setting this to true will allow sending the authentication key in clear text. Use with caution. + InsecureAllowCredentialWithHTTP bool +} + +// NewSASCredentialPolicy creates a new instance of [SASCredentialPolicy]. 
+// - cred is the [azcore.SASCredential] used to authenticate with the service +// - header is the name of the HTTP request header in which the shared access signature is placed +// - options contains optional configuration, pass nil to accept the default values +func NewSASCredentialPolicy(cred *exported.SASCredential, header string, options *SASCredentialPolicyOptions) *SASCredentialPolicy { + if options == nil { + options = &SASCredentialPolicyOptions{} + } + return &SASCredentialPolicy{ + cred: cred, + header: header, + allowHTTP: options.InsecureAllowCredentialWithHTTP, + } +} + +// Do implementes the Do method on the [policy.Polilcy] interface. +func (k *SASCredentialPolicy) Do(req *policy.Request) (*http.Response, error) { + // skip adding the authorization header if no SASCredential was provided. + // this prevents a panic that might be hard to diagnose and allows testing + // against http endpoints that don't require authentication. + if k.cred != nil { + if err := checkHTTPSForAuth(req, k.allowHTTP); err != nil { + return nil, err + } + req.Raw().Header.Add(k.header, exported.SASCredentialGet(k.cred)) + } + return req.Next() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_telemetry.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_telemetry.go new file mode 100644 index 00000000000..80a90354619 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/policy_telemetry.go @@ -0,0 +1,83 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+package runtime
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+	"os"
+	"runtime"
+	"strings"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+)
+
+type telemetryPolicy struct {
+	telemetryValue string
+}
+
+// NewTelemetryPolicy creates a telemetry policy object that adds telemetry information to outgoing HTTP requests.
+// The format is [<application_id> ]azsdk-go-<mod>/<ver> <platform_info>.
+// Pass nil to accept the default values; this is the same as passing a zero-value options.
+func NewTelemetryPolicy(mod, ver string, o *policy.TelemetryOptions) policy.Policy {
+	if o == nil {
+		o = &policy.TelemetryOptions{}
+	}
+	tp := telemetryPolicy{}
+	if o.Disabled {
+		return &tp
+	}
+	b := &bytes.Buffer{}
+	// normalize ApplicationID
+	if o.ApplicationID != "" {
+		o.ApplicationID = strings.ReplaceAll(o.ApplicationID, " ", "/")
+		if len(o.ApplicationID) > 24 {
+			o.ApplicationID = o.ApplicationID[:24]
+		}
+		b.WriteString(o.ApplicationID)
+		b.WriteRune(' ')
+	}
+	// mod might be the fully qualified name.
in that case, we just want the package name + if i := strings.LastIndex(mod, "/"); i > -1 { + mod = mod[i+1:] + } + b.WriteString(formatTelemetry(mod, ver)) + b.WriteRune(' ') + b.WriteString(platformInfo) + tp.telemetryValue = b.String() + return &tp +} + +func formatTelemetry(comp, ver string) string { + return fmt.Sprintf("azsdk-go-%s/%s", comp, ver) +} + +func (p telemetryPolicy) Do(req *policy.Request) (*http.Response, error) { + if p.telemetryValue == "" { + return req.Next() + } + // preserve the existing User-Agent string + if ua := req.Raw().Header.Get(shared.HeaderUserAgent); ua != "" { + p.telemetryValue = fmt.Sprintf("%s %s", p.telemetryValue, ua) + } + req.Raw().Header.Set(shared.HeaderUserAgent, p.telemetryValue) + return req.Next() +} + +// NOTE: the ONLY function that should write to this variable is this func +var platformInfo = func() string { + operatingSystem := runtime.GOOS // Default OS string + switch operatingSystem { + case "windows": + operatingSystem = os.Getenv("OS") // Get more specific OS information + case "linux": // accept default OS info + case "freebsd": // accept default OS info + } + return fmt.Sprintf("(%s; %s)", runtime.Version(), operatingSystem) +}() diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/poller.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/poller.go new file mode 100644 index 00000000000..03f76c9aa8e --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/poller.go @@ -0,0 +1,389 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package runtime + +import ( + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "net/http" + "reflect" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// FinalStateVia is the enumerated type for the possible final-state-via values. +type FinalStateVia = pollers.FinalStateVia + +const ( + // FinalStateViaAzureAsyncOp indicates the final payload comes from the Azure-AsyncOperation URL. + FinalStateViaAzureAsyncOp = pollers.FinalStateViaAzureAsyncOp + + // FinalStateViaLocation indicates the final payload comes from the Location URL. + FinalStateViaLocation = pollers.FinalStateViaLocation + + // FinalStateViaOriginalURI indicates the final payload comes from the original URL. + FinalStateViaOriginalURI = pollers.FinalStateViaOriginalURI + + // FinalStateViaOpLocation indicates the final payload comes from the Operation-Location URL. + FinalStateViaOpLocation = pollers.FinalStateViaOpLocation +) + +// NewPollerOptions contains the optional parameters for NewPoller. +type NewPollerOptions[T any] struct { + // FinalStateVia contains the final-state-via value for the LRO. + FinalStateVia FinalStateVia + + // Response contains a preconstructed response type. + // The final payload will be unmarshaled into it and returned. 
+ Response *T + + // Handler[T] contains a custom polling implementation. + Handler PollingHandler[T] + + // Tracer contains the Tracer from the client that's creating the Poller. + Tracer tracing.Tracer +} + +// NewPoller creates a Poller based on the provided initial response. +func NewPoller[T any](resp *http.Response, pl exported.Pipeline, options *NewPollerOptions[T]) (*Poller[T], error) { + if options == nil { + options = &NewPollerOptions[T]{} + } + result := options.Response + if result == nil { + result = new(T) + } + if options.Handler != nil { + return &Poller[T]{ + op: options.Handler, + resp: resp, + result: result, + tracer: options.Tracer, + }, nil + } + + defer resp.Body.Close() + // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). + // ideally the codegen should return an error if the initial response failed and not even create a poller. + if !poller.StatusCodeValid(resp) { + return nil, errors.New("the operation failed or was cancelled") + } + + // determine the polling method + var opr PollingHandler[T] + var err error + if fake.Applicable(resp) { + opr, err = fake.New[T](pl, resp) + } else if async.Applicable(resp) { + // async poller must be checked first as it can also have a location header + opr, err = async.New[T](pl, resp, options.FinalStateVia) + } else if op.Applicable(resp) { + // op poller must be checked before loc as it can also have a location header + opr, err = op.New[T](pl, resp, options.FinalStateVia) + } else if loc.Applicable(resp) { + opr, err = loc.New[T](pl, resp) + } else if body.Applicable(resp) { + // must test body poller last as it's a subset of the other pollers. 
+ // TODO: this is ambiguous for PATCH/PUT if it returns a 200 with no polling headers (sync completion) + opr, err = body.New[T](pl, resp) + } else if m := resp.Request.Method; resp.StatusCode == http.StatusAccepted && (m == http.MethodDelete || m == http.MethodPost) { + // if we get here it means we have a 202 with no polling headers. + // for DELETE and POST this is a hard error per ARM RPC spec. + return nil, errors.New("response is missing polling URL") + } else { + opr, err = pollers.NewNopPoller[T](resp) + } + + if err != nil { + return nil, err + } + return &Poller[T]{ + op: opr, + resp: resp, + result: result, + tracer: options.Tracer, + }, nil +} + +// NewPollerFromResumeTokenOptions contains the optional parameters for NewPollerFromResumeToken. +type NewPollerFromResumeTokenOptions[T any] struct { + // Response contains a preconstructed response type. + // The final payload will be unmarshaled into it and returned. + Response *T + + // Handler[T] contains a custom polling implementation. + Handler PollingHandler[T] + + // Tracer contains the Tracer from the client that's creating the Poller. + Tracer tracing.Tracer +} + +// NewPollerFromResumeToken creates a Poller from a resume token string. 
+func NewPollerFromResumeToken[T any](token string, pl exported.Pipeline, options *NewPollerFromResumeTokenOptions[T]) (*Poller[T], error) { + if options == nil { + options = &NewPollerFromResumeTokenOptions[T]{} + } + result := options.Response + if result == nil { + result = new(T) + } + + if err := pollers.IsTokenValid[T](token); err != nil { + return nil, err + } + raw, err := pollers.ExtractToken(token) + if err != nil { + return nil, err + } + var asJSON map[string]any + if err := json.Unmarshal(raw, &asJSON); err != nil { + return nil, err + } + + opr := options.Handler + // now rehydrate the poller based on the encoded poller type + if fake.CanResume(asJSON) { + opr, _ = fake.New[T](pl, nil) + } else if opr != nil { + log.Writef(log.EventLRO, "Resuming custom poller %T.", opr) + } else if async.CanResume(asJSON) { + opr, _ = async.New[T](pl, nil, "") + } else if body.CanResume(asJSON) { + opr, _ = body.New[T](pl, nil) + } else if loc.CanResume(asJSON) { + opr, _ = loc.New[T](pl, nil) + } else if op.CanResume(asJSON) { + opr, _ = op.New[T](pl, nil, "") + } else { + return nil, fmt.Errorf("unhandled poller token %s", string(raw)) + } + if err := json.Unmarshal(raw, &opr); err != nil { + return nil, err + } + return &Poller[T]{ + op: opr, + result: result, + tracer: options.Tracer, + }, nil +} + +// PollingHandler[T] abstracts the differences among poller implementations. +type PollingHandler[T any] interface { + // Done returns true if the LRO has reached a terminal state. + Done() bool + + // Poll fetches the latest state of the LRO. + Poll(context.Context) (*http.Response, error) + + // Result is called once the LRO has reached a terminal state. It populates the out parameter + // with the result of the operation. + Result(ctx context.Context, out *T) error +} + +// Poller encapsulates a long-running operation, providing polling facilities until the operation reaches a terminal state. 
+type Poller[T any] struct { + op PollingHandler[T] + resp *http.Response + err error + result *T + tracer tracing.Tracer + done bool +} + +// PollUntilDoneOptions contains the optional values for the Poller[T].PollUntilDone() method. +type PollUntilDoneOptions struct { + // Frequency is the time to wait between polling intervals in absence of a Retry-After header. Allowed minimum is one second. + // Pass zero to accept the default value (30s). + Frequency time.Duration +} + +// PollUntilDone will poll the service endpoint until a terminal state is reached, an error is received, or the context expires. +// It internally uses Poll(), Done(), and Result() in its polling loop, sleeping for the specified duration between intervals. +// options: pass nil to accept the default values. +// NOTE: the default polling frequency is 30 seconds which works well for most operations. However, some operations might +// benefit from a shorter or longer duration. +func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOptions) (res T, err error) { + if options == nil { + options = &PollUntilDoneOptions{} + } + cp := *options + if cp.Frequency == 0 { + cp.Frequency = 30 * time.Second + } + + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.PollUntilDone", shortenTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + // skip the floor check when executing tests so they don't take so long + if isTest := flag.Lookup("test.v"); isTest == nil && cp.Frequency < time.Second { + err = errors.New("polling frequency minimum is one second") + return + } + + start := time.Now() + logPollUntilDoneExit := func(v any) { + log.Writef(log.EventLRO, "END PollUntilDone() for %T: %v, total time: %s", p.op, v, time.Since(start)) + } + log.Writef(log.EventLRO, "BEGIN PollUntilDone() for %T", p.op) + if p.resp != nil { + // initial check for a retry-after header existing on the initial response + if retryAfter := shared.RetryAfter(p.resp); retryAfter 
> 0 { + log.Writef(log.EventLRO, "initial Retry-After delay for %s", retryAfter.String()) + if err = shared.Delay(ctx, retryAfter); err != nil { + logPollUntilDoneExit(err) + return + } + } + } + // begin polling the endpoint until a terminal state is reached + for { + var resp *http.Response + resp, err = p.Poll(ctx) + if err != nil { + logPollUntilDoneExit(err) + return + } + if p.Done() { + logPollUntilDoneExit("succeeded") + res, err = p.Result(ctx) + return + } + d := cp.Frequency + if retryAfter := shared.RetryAfter(resp); retryAfter > 0 { + log.Writef(log.EventLRO, "Retry-After delay for %s", retryAfter.String()) + d = retryAfter + } else { + log.Writef(log.EventLRO, "delay for %s", d.String()) + } + if err = shared.Delay(ctx, d); err != nil { + logPollUntilDoneExit(err) + return + } + } +} + +// Poll fetches the latest state of the LRO. It returns an HTTP response or error. +// If Poll succeeds, the poller's state is updated and the HTTP response is returned. +// If Poll fails, the poller's state is unmodified and the error is returned. +// Calling Poll on an LRO that has reached a terminal state will return the last HTTP response. +func (p *Poller[T]) Poll(ctx context.Context) (resp *http.Response, err error) { + if p.Done() { + // the LRO has reached a terminal state, don't poll again + resp = p.resp + return + } + + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.Poll", shortenTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + resp, err = p.op.Poll(ctx) + if err != nil { + return + } + p.resp = resp + return +} + +// Done returns true if the LRO has reached a terminal state. +// Once a terminal state is reached, call Result(). +func (p *Poller[T]) Done() bool { + return p.op.Done() +} + +// Result returns the result of the LRO and is meant to be used in conjunction with Poll and Done. +// If the LRO completed successfully, a populated instance of T is returned. 
+// If the LRO failed or was canceled, an *azcore.ResponseError error is returned. +// Calling this on an LRO in a non-terminal state will return an error. +func (p *Poller[T]) Result(ctx context.Context) (res T, err error) { + if !p.Done() { + err = errors.New("poller is in a non-terminal state") + return + } + if p.done { + // the result has already been retrieved, return the cached value + if p.err != nil { + err = p.err + return + } + res = *p.result + return + } + + ctx, endSpan := StartSpan(ctx, fmt.Sprintf("%s.Result", shortenTypeName(reflect.TypeOf(*p).Name())), p.tracer, nil) + defer func() { endSpan(err) }() + + err = p.op.Result(ctx, p.result) + var respErr *exported.ResponseError + if errors.As(err, &respErr) { + if pollers.IsNonTerminalHTTPStatusCode(respErr.RawResponse) { + // the request failed in a non-terminal way. + // don't cache the error or mark the Poller as done + return + } + // the LRO failed. record the error + p.err = err + } else if err != nil { + // the call to Result failed, don't cache anything in this case + return + } + p.done = true + if p.err != nil { + err = p.err + return + } + res = *p.result + return +} + +// ResumeToken returns a value representing the poller that can be used to resume +// the LRO at a later time. ResumeTokens are unique per service operation. +// The token's format should be considered opaque and is subject to change. +// Calling this on an LRO in a terminal state will return an error. 
+func (p *Poller[T]) ResumeToken() (string, error) { + if p.Done() { + return "", errors.New("poller is in a terminal state") + } + tk, err := pollers.NewResumeToken[T](p.op) + if err != nil { + return "", err + } + return tk, err +} + +// extracts the type name from the string returned from reflect.Value.Name() +func shortenTypeName(s string) string { + // the value is formatted as follows + // Poller[module/Package.Type].Method + // we want to shorten the generic type parameter string to Type + // anything we don't recognize will be left as-is + begin := strings.Index(s, "[") + end := strings.Index(s, "]") + if begin == -1 || end == -1 { + return s + } + + typeName := s[begin+1 : end] + if i := strings.LastIndex(typeName, "."); i > -1 { + typeName = typeName[i+1:] + } + return s[:begin+1] + typeName + s[end:] +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/request.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/request.go new file mode 100644 index 00000000000..06ac95b1b71 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/request.go @@ -0,0 +1,265 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "mime/multipart" + "net/textproto" + "net/url" + "path" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" +) + +// Base64Encoding is usesd to specify which base-64 encoder/decoder to use when +// encoding/decoding a slice of bytes to/from a string. +type Base64Encoding = exported.Base64Encoding + +const ( + // Base64StdFormat uses base64.StdEncoding for encoding and decoding payloads. 
+ Base64StdFormat Base64Encoding = exported.Base64StdFormat + + // Base64URLFormat uses base64.RawURLEncoding for encoding and decoding payloads. + Base64URLFormat Base64Encoding = exported.Base64URLFormat +) + +// NewRequest creates a new policy.Request with the specified input. +// The endpoint MUST be properly encoded before calling this function. +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*policy.Request, error) { + return exported.NewRequest(ctx, httpMethod, endpoint) +} + +// EncodeQueryParams will parse and encode any query parameters in the specified URL. +// Any semicolons will automatically be escaped. +func EncodeQueryParams(u string) (string, error) { + before, after, found := strings.Cut(u, "?") + if !found { + return u, nil + } + // starting in Go 1.17, url.ParseQuery will reject semicolons in query params. + // so, we must escape them first. note that this assumes that semicolons aren't + // being used as query param separators which is per the current RFC. + // for more info: + // https://github.com/golang/go/issues/25192 + // https://github.com/golang/go/issues/50034 + qp, err := url.ParseQuery(strings.ReplaceAll(after, ";", "%3B")) + if err != nil { + return "", err + } + return before + "?" + qp.Encode(), nil +} + +// JoinPaths concatenates multiple URL path segments into one path, +// inserting path separation characters as required. JoinPaths will preserve +// query parameters in the root path +func JoinPaths(root string, paths ...string) string { + if len(paths) == 0 { + return root + } + + qps := "" + if strings.Contains(root, "?") { + splitPath := strings.Split(root, "?") + root, qps = splitPath[0], splitPath[1] + } + + p := path.Join(paths...) + // path.Join will remove any trailing slashes. + // if one was provided, preserve it. + if strings.HasSuffix(paths[len(paths)-1], "/") && !strings.HasSuffix(p, "/") { + p += "/" + } + + if qps != "" { + p = p + "?" 
+ qps + } + + if strings.HasSuffix(root, "/") && strings.HasPrefix(p, "/") { + root = root[:len(root)-1] + } else if !strings.HasSuffix(root, "/") && !strings.HasPrefix(p, "/") { + p = "/" + p + } + return root + p +} + +// EncodeByteArray will base-64 encode the byte slice v. +func EncodeByteArray(v []byte, format Base64Encoding) string { + return exported.EncodeByteArray(v, format) +} + +// MarshalAsByteArray will base-64 encode the byte slice v, then calls SetBody. +// The encoded value is treated as a JSON string. +func MarshalAsByteArray(req *policy.Request, v []byte, format Base64Encoding) error { + // send as a JSON string + encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) + // tsp generated code can set Content-Type so we must prefer that + return exported.SetBody(req, exported.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON, false) +} + +// MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. +func MarshalAsJSON(req *policy.Request, v any) error { + b, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + // tsp generated code can set Content-Type so we must prefer that + return exported.SetBody(req, exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON, false) +} + +// MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. +func MarshalAsXML(req *policy.Request, v any) error { + b, err := xml.Marshal(v) + if err != nil { + return fmt.Errorf("error marshalling type %T: %s", v, err) + } + // inclue the XML header as some services require it + b = []byte(xml.Header + string(b)) + return req.SetBody(exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppXML) +} + +// SetMultipartFormData writes the specified keys/values as multi-part form fields with the specified value. +// File content must be specified as an [io.ReadSeekCloser] or [streaming.MultipartContent]. +// Byte slices will be treated as JSON. 
All other values are treated as string values. +func SetMultipartFormData(req *policy.Request, formData map[string]any) error { + body := bytes.Buffer{} + writer := multipart.NewWriter(&body) + + writeContent := func(fieldname, filename string, src io.Reader) error { + fd, err := writer.CreateFormFile(fieldname, filename) + if err != nil { + return err + } + // copy the data to the form file + if _, err = io.Copy(fd, src); err != nil { + return err + } + return nil + } + + quoteEscaper := strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + + writeMultipartContent := func(fieldname string, mpc streaming.MultipartContent) error { + if mpc.Body == nil { + return errors.New("streaming.MultipartContent.Body cannot be nil") + } + + // use fieldname for the file name when unspecified + filename := fieldname + + if mpc.ContentType == "" && mpc.Filename == "" { + return writeContent(fieldname, filename, mpc.Body) + } + if mpc.Filename != "" { + filename = mpc.Filename + } + // this is pretty much copied from multipart.Writer.CreateFormFile + // but lets us set the caller provided Content-Type and filename + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", + fmt.Sprintf(`form-data; name="%s"; filename="%s"`, + quoteEscaper.Replace(fieldname), quoteEscaper.Replace(filename))) + contentType := "application/octet-stream" + if mpc.ContentType != "" { + contentType = mpc.ContentType + } + h.Set("Content-Type", contentType) + fd, err := writer.CreatePart(h) + if err != nil { + return err + } + // copy the data to the form file + if _, err = io.Copy(fd, mpc.Body); err != nil { + return err + } + return nil + } + + // the same as multipart.Writer.WriteField but lets us specify the Content-Type + writeField := func(fieldname, contentType string, value string) error { + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", + fmt.Sprintf(`form-data; name="%s"`, quoteEscaper.Replace(fieldname))) + h.Set("Content-Type", contentType) + fd, err := writer.CreatePart(h) + 
if err != nil { + return err + } + if _, err = fd.Write([]byte(value)); err != nil { + return err + } + return nil + } + + for k, v := range formData { + if rsc, ok := v.(io.ReadSeekCloser); ok { + if err := writeContent(k, k, rsc); err != nil { + return err + } + continue + } else if rscs, ok := v.([]io.ReadSeekCloser); ok { + for _, rsc := range rscs { + if err := writeContent(k, k, rsc); err != nil { + return err + } + } + continue + } else if mpc, ok := v.(streaming.MultipartContent); ok { + if err := writeMultipartContent(k, mpc); err != nil { + return err + } + continue + } else if mpcs, ok := v.([]streaming.MultipartContent); ok { + for _, mpc := range mpcs { + if err := writeMultipartContent(k, mpc); err != nil { + return err + } + } + continue + } + + var content string + contentType := shared.ContentTypeTextPlain + switch tt := v.(type) { + case []byte: + // JSON, don't quote it + content = string(tt) + contentType = shared.ContentTypeAppJSON + case string: + content = tt + default: + // ensure the value is in string format + content = fmt.Sprintf("%v", v) + } + + if err := writeField(k, contentType, content); err != nil { + return err + } + } + if err := writer.Close(); err != nil { + return err + } + return req.SetBody(exported.NopCloser(bytes.NewReader(body.Bytes())), writer.FormDataContentType()) +} + +// SkipBodyDownload will disable automatic downloading of the response body. +func SkipBodyDownload(req *policy.Request) { + req.SetOperationValue(bodyDownloadPolicyOpValues{Skip: true}) +} + +// CtxAPINameKey is used as a context key for adding/retrieving the API name. 
+type CtxAPINameKey = shared.CtxAPINameKey diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/response.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/response.go new file mode 100644 index 00000000000..048566e02c0 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/response.go @@ -0,0 +1,109 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + + azexported "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/internal/exported" +) + +// Payload reads and returns the response body or an error. +// On a successful read, the response body is cached. +// Subsequent reads will access the cached value. +func Payload(resp *http.Response) ([]byte, error) { + return exported.Payload(resp, nil) +} + +// HasStatusCode returns true if the Response's status code is one of the specified values. +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + return exported.HasStatusCode(resp, statusCodes...) +} + +// UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. +func UnmarshalAsByteArray(resp *http.Response, v *[]byte, format Base64Encoding) error { + p, err := Payload(resp) + if err != nil { + return err + } + return DecodeByteArray(string(p), v, format) +} + +// UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v. 
+func UnmarshalAsJSON(resp *http.Response, v any) error { + payload, err := Payload(resp) + if err != nil { + return err + } + // TODO: verify early exit is correct + if len(payload) == 0 { + return nil + } + err = removeBOM(resp) + if err != nil { + return err + } + err = json.Unmarshal(payload, v) + if err != nil { + err = fmt.Errorf("unmarshalling type %T: %s", v, err) + } + return err +} + +// UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v. +func UnmarshalAsXML(resp *http.Response, v any) error { + payload, err := Payload(resp) + if err != nil { + return err + } + // TODO: verify early exit is correct + if len(payload) == 0 { + return nil + } + err = removeBOM(resp) + if err != nil { + return err + } + err = xml.Unmarshal(payload, v) + if err != nil { + err = fmt.Errorf("unmarshalling type %T: %s", v, err) + } + return err +} + +// Drain reads the response body to completion then closes it. The bytes read are discarded. +func Drain(resp *http.Response) { + if resp != nil && resp.Body != nil { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } +} + +// removeBOM removes any byte-order mark prefix from the payload if present. +func removeBOM(resp *http.Response) error { + _, err := exported.Payload(resp, &exported.PayloadOptions{ + BytesModifier: func(b []byte) []byte { + // UTF8 + return bytes.TrimPrefix(b, []byte("\xef\xbb\xbf")) + }, + }) + if err != nil { + return err + } + return nil +} + +// DecodeByteArray will base-64 decode the provided string into v. 
+func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { + return azexported.DecodeByteArray(s, v, format) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_other.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_other.go new file mode 100644 index 00000000000..1c75d771f2e --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_other.go @@ -0,0 +1,15 @@ +//go:build !wasm + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_wasm.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_wasm.go new file mode 100644 index 00000000000..3dc9eeecddf --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_dialer_wasm.go @@ -0,0 +1,15 @@ +//go:build (js && wasm) || wasip1 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_http_client.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_http_client.go new file mode 100644 index 00000000000..2124c1d48b9 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime/transport_default_http_client.go @@ -0,0 +1,48 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "crypto/tls" + "net" + "net/http" + "time" + + "golang.org/x/net/http2" +) + +var defaultHTTPClient *http.Client + +func init() { + defaultTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: defaultTransportDialContext(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }), + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + Renegotiation: tls.RenegotiateFreelyAsClient, + }, + } + // TODO: evaluate removing this once https://github.com/golang/go/issues/59690 has been fixed + if http2Transport, err := http2.ConfigureTransports(defaultTransport); err == nil { + // if the connection has been idle for 10 seconds, send a ping frame for a health check + http2Transport.ReadIdleTimeout = 10 * time.Second + // if there's no response to the ping within the timeout, the connection will be closed + http2Transport.PingTimeout = 5 * time.Second + } + defaultHTTPClient = &http.Client{ + Transport: defaultTransport, + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/doc.go new file mode 100644 index 00000000000..cadaef3d584 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/doc.go @@ -0,0 +1,9 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. + +// Package streaming contains helpers for streaming IO operations and progress reporting. 
+package streaming diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/progress.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/progress.go new file mode 100644 index 00000000000..2468540bd75 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming/progress.go @@ -0,0 +1,89 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package streaming + +import ( + "io" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" +) + +type progress struct { + rc io.ReadCloser + rsc io.ReadSeekCloser + pr func(bytesTransferred int64) + offset int64 +} + +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +// In addition to adding a Close method to an io.ReadSeeker, this can also be used to wrap an +// io.ReadSeekCloser with a no-op Close method to allow explicit control of when the io.ReedSeekCloser +// has its underlying stream closed. +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return exported.NopCloser(rs) +} + +// NewRequestProgress adds progress reporting to an HTTP request's body stream. +func NewRequestProgress(body io.ReadSeekCloser, pr func(bytesTransferred int64)) io.ReadSeekCloser { + return &progress{ + rc: body, + rsc: body, + pr: pr, + offset: 0, + } +} + +// NewResponseProgress adds progress reporting to an HTTP response's body stream. 
+func NewResponseProgress(body io.ReadCloser, pr func(bytesTransferred int64)) io.ReadCloser { + return &progress{ + rc: body, + rsc: nil, + pr: pr, + offset: 0, + } +} + +// Read reads a block of data from an inner stream and reports progress +func (p *progress) Read(b []byte) (n int, err error) { + n, err = p.rc.Read(b) + if err != nil && err != io.EOF { + return + } + p.offset += int64(n) + // Invokes the user's callback method to report progress + p.pr(p.offset) + return +} + +// Seek only expects a zero or from beginning. +func (p *progress) Seek(offset int64, whence int) (int64, error) { + // This should only ever be called with offset = 0 and whence = io.SeekStart + n, err := p.rsc.Seek(offset, whence) + if err == nil { + p.offset = int64(n) + } + return n, err +} + +// requestBodyProgress supports Close but the underlying stream may not; if it does, Close will close it. +func (p *progress) Close() error { + return p.rc.Close() +} + +// MultipartContent contains streaming content used in multipart/form payloads. +type MultipartContent struct { + // Body contains the required content body. + Body io.ReadSeekCloser + + // ContentType optionally specifies the HTTP Content-Type for this Body. + // The default value is application/octet-stream. + ContentType string + + // Filename optionally specifies the filename for this Body. + // The default value is the field name for the multipart/form section. + Filename string +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/doc.go new file mode 100644 index 00000000000..faa98c9dc51 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/doc.go @@ -0,0 +1,9 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright 2017 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by an MIT +// license that can be found in the LICENSE file. 
+ +// Package to contains various type-conversion helper functions. +package to diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/to.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/to.go new file mode 100644 index 00000000000..e0e4817b90d --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/to/to.go @@ -0,0 +1,21 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package to + +// Ptr returns a pointer to the provided value. +func Ptr[T any](v T) *T { + return &v +} + +// SliceOfPtrs returns a slice of *T from the specified values. +func SliceOfPtrs[T any](vv ...T) []*T { + slc := make([]*T, len(vv)) + for i := range vv { + slc[i] = Ptr(vv[i]) + } + return slc +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/constants.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/constants.go new file mode 100644 index 00000000000..80282d4ab0a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/constants.go @@ -0,0 +1,41 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package tracing + +// SpanKind represents the role of a Span inside a Trace. Often, this defines how a Span will be processed and visualized by various backends. +type SpanKind int + +const ( + // SpanKindInternal indicates the span represents an internal operation within an application. + SpanKindInternal SpanKind = 1 + + // SpanKindServer indicates the span covers server-side handling of a request. + SpanKindServer SpanKind = 2 + + // SpanKindClient indicates the span describes a request to a remote service. + SpanKindClient SpanKind = 3 + + // SpanKindProducer indicates the span was created by a messaging producer. 
+ SpanKindProducer SpanKind = 4 + + // SpanKindConsumer indicates the span was created by a messaging consumer. + SpanKindConsumer SpanKind = 5 +) + +// SpanStatus represents the status of a span. +type SpanStatus int + +const ( + // SpanStatusUnset is the default status code. + SpanStatusUnset SpanStatus = 0 + + // SpanStatusError indicates the operation contains an error. + SpanStatusError SpanStatus = 1 + + // SpanStatusOK indicates the operation completed successfully. + SpanStatusOK SpanStatus = 2 +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/tracing.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/tracing.go new file mode 100644 index 00000000000..1ade7c560ff --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing/tracing.go @@ -0,0 +1,191 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package tracing contains the definitions needed to support distributed tracing. +package tracing + +import ( + "context" +) + +// ProviderOptions contains the optional values when creating a Provider. +type ProviderOptions struct { + // for future expansion +} + +// NewProvider creates a new Provider with the specified values. +// - newTracerFn is the underlying implementation for creating Tracer instances +// - options contains optional values; pass nil to accept the default value +func NewProvider(newTracerFn func(name, version string) Tracer, options *ProviderOptions) Provider { + return Provider{ + newTracerFn: newTracerFn, + } +} + +// Provider is the factory that creates Tracer instances. +// It defaults to a no-op provider. +type Provider struct { + newTracerFn func(name, version string) Tracer +} + +// NewTracer creates a new Tracer for the specified module name and version. 
+// - module - the fully qualified name of the module +// - version - the version of the module +func (p Provider) NewTracer(module, version string) (tracer Tracer) { + if p.newTracerFn != nil { + tracer = p.newTracerFn(module, version) + } + return +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// TracerOptions contains the optional values when creating a Tracer. +type TracerOptions struct { + // SpanFromContext contains the implementation for the Tracer.SpanFromContext method. + SpanFromContext func(context.Context) Span +} + +// NewTracer creates a Tracer with the specified values. +// - newSpanFn is the underlying implementation for creating Span instances +// - options contains optional values; pass nil to accept the default value +func NewTracer(newSpanFn func(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span), options *TracerOptions) Tracer { + if options == nil { + options = &TracerOptions{} + } + return Tracer{ + newSpanFn: newSpanFn, + spanFromContextFn: options.SpanFromContext, + } +} + +// Tracer is the factory that creates Span instances. +type Tracer struct { + attrs []Attribute + newSpanFn func(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span) + spanFromContextFn func(ctx context.Context) Span +} + +// Start creates a new span and a context.Context that contains it. +// - ctx is the parent context for this span. 
If it contains a Span, the newly created span will be a child of that span, else it will be a root span +// - spanName identifies the span within a trace, it's typically the fully qualified API name +// - options contains optional values for the span, pass nil to accept any defaults +func (t Tracer) Start(ctx context.Context, spanName string, options *SpanOptions) (context.Context, Span) { + if t.newSpanFn != nil { + opts := SpanOptions{} + if options != nil { + opts = *options + } + opts.Attributes = append(opts.Attributes, t.attrs...) + return t.newSpanFn(ctx, spanName, &opts) + } + return ctx, Span{} +} + +// SetAttributes sets attrs to be applied to each Span. If a key from attrs +// already exists for an attribute of the Span it will be overwritten with +// the value contained in attrs. +func (t *Tracer) SetAttributes(attrs ...Attribute) { + t.attrs = append(t.attrs, attrs...) +} + +// Enabled returns true if this Tracer is capable of creating Spans. +func (t Tracer) Enabled() bool { + return t.newSpanFn != nil +} + +// SpanFromContext returns the Span associated with the current context. +// If the provided context has no Span, false is returned. +func (t Tracer) SpanFromContext(ctx context.Context) Span { + if t.spanFromContextFn != nil { + return t.spanFromContextFn(ctx) + } + return Span{} +} + +// SpanOptions contains optional settings for creating a span. +type SpanOptions struct { + // Kind indicates the kind of Span. + Kind SpanKind + + // Attributes contains key-value pairs of attributes for the span. + Attributes []Attribute +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// SpanImpl abstracts the underlying implementation for Span, +// allowing it to work with various tracing implementations. +// Any zero-values will have their default, no-op behavior. +type SpanImpl struct { + // End contains the implementation for the Span.End method. 
+ End func() + + // SetAttributes contains the implementation for the Span.SetAttributes method. + SetAttributes func(...Attribute) + + // AddEvent contains the implementation for the Span.AddEvent method. + AddEvent func(string, ...Attribute) + + // SetStatus contains the implementation for the Span.SetStatus method. + SetStatus func(SpanStatus, string) +} + +// NewSpan creates a Span with the specified implementation. +func NewSpan(impl SpanImpl) Span { + return Span{ + impl: impl, + } +} + +// Span is a single unit of a trace. A trace can contain multiple spans. +// A zero-value Span provides a no-op implementation. +type Span struct { + impl SpanImpl +} + +// End terminates the span and MUST be called before the span leaves scope. +// Any further updates to the span will be ignored after End is called. +func (s Span) End() { + if s.impl.End != nil { + s.impl.End() + } +} + +// SetAttributes sets the specified attributes on the Span. +// Any existing attributes with the same keys will have their values overwritten. +func (s Span) SetAttributes(attrs ...Attribute) { + if s.impl.SetAttributes != nil { + s.impl.SetAttributes(attrs...) + } +} + +// AddEvent adds a named event with an optional set of attributes to the span. +func (s Span) AddEvent(name string, attrs ...Attribute) { + if s.impl.AddEvent != nil { + s.impl.AddEvent(name, attrs...) + } +} + +// SetStatus sets the status on the span along with a description. +func (s Span) SetStatus(code SpanStatus, desc string) { + if s.impl.SetStatus != nil { + s.impl.SetStatus(code, desc) + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Attribute is a key-value pair. +type Attribute struct { + // Key is the name of the attribute. + Key string + + // Value is the attribute's value. + // Types that are natively supported include int64, float64, int, bool, string. + // Any other type will be formatted per rules of fmt.Sprintf("%v"). 
+ Value any +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/LICENSE.txt b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/LICENSE.txt new file mode 100644 index 00000000000..48ea6616b5b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/diag.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/diag.go new file mode 100644 index 00000000000..245af7d2bec --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/diag.go @@ -0,0 +1,51 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package diag + +import ( + "fmt" + "runtime" + "strings" +) + +// Caller returns the file and line number of a frame on the caller's stack. 
+// If the funtion fails an empty string is returned. +// skipFrames - the number of frames to skip when determining the caller. +// Passing a value of 0 will return the immediate caller of this function. +func Caller(skipFrames int) string { + if pc, file, line, ok := runtime.Caller(skipFrames + 1); ok { + // the skipFrames + 1 is to skip ourselves + frame := runtime.FuncForPC(pc) + return fmt.Sprintf("%s()\n\t%s:%d", frame.Name(), file, line) + } + return "" +} + +// StackTrace returns a formatted stack trace string. +// If the funtion fails an empty string is returned. +// skipFrames - the number of stack frames to skip before composing the trace string. +// totalFrames - the maximum number of stack frames to include in the trace string. +func StackTrace(skipFrames, totalFrames int) string { + pcCallers := make([]uintptr, totalFrames) + if frames := runtime.Callers(skipFrames, pcCallers); frames == 0 { + return "" + } + frames := runtime.CallersFrames(pcCallers) + sb := strings.Builder{} + for { + frame, more := frames.Next() + sb.WriteString(frame.Function) + sb.WriteString("()\n\t") + sb.WriteString(frame.File) + sb.WriteRune(':') + sb.WriteString(fmt.Sprintf("%d\n", frame.Line)) + if !more { + break + } + } + return sb.String() +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/doc.go new file mode 100644 index 00000000000..66bf13e5f04 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/diag/doc.go @@ -0,0 +1,7 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package diag diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/doc.go new file mode 100644 index 00000000000..8c6eacb618a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/doc.go @@ -0,0 +1,7 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package errorinfo diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/errorinfo.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/errorinfo.go new file mode 100644 index 00000000000..8ee66b52676 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo/errorinfo.go @@ -0,0 +1,46 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package errorinfo + +// NonRetriable represents a non-transient error. This works in +// conjunction with the retry policy, indicating that the error condition +// is idempotent, so no retries will be attempted. +// Use errors.As() to access this interface in the error chain. +type NonRetriable interface { + error + NonRetriable() +} + +// NonRetriableError marks the specified error as non-retriable. +// This function takes an error as input and returns a new error that is marked as non-retriable. +func NonRetriableError(err error) error { + return &nonRetriableError{err} +} + +// nonRetriableError is a struct that embeds the error interface. +// It is used to represent errors that should not be retried. +type nonRetriableError struct { + error +} + +// Error method for nonRetriableError struct. +// It returns the error message of the embedded error. +func (p *nonRetriableError) Error() string { + return p.error.Error() +} + +// NonRetriable is a marker method for nonRetriableError struct. 
+// Non-functional and indicates that the error is non-retriable. +func (*nonRetriableError) NonRetriable() { + // marker method +} + +// Unwrap method for nonRetriableError struct. +// It returns the original error that was marked as non-retriable. +func (p *nonRetriableError) Unwrap() error { + return p.error +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/exported/exported.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/exported/exported.go new file mode 100644 index 00000000000..9948f604b30 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/exported/exported.go @@ -0,0 +1,129 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "errors" + "io" + "net/http" +) + +// HasStatusCode returns true if the Response's status code is one of the specified values. +// Exported as runtime.HasStatusCode(). +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { + return false + } + for _, sc := range statusCodes { + if resp.StatusCode == sc { + return true + } + } + return false +} + +// PayloadOptions contains the optional values for the Payload func. +// NOT exported but used by azcore. +type PayloadOptions struct { + // BytesModifier receives the downloaded byte slice and returns an updated byte slice. + // Use this to modify the downloaded bytes in a payload (e.g. removing a BOM). + BytesModifier func([]byte) []byte +} + +// Payload reads and returns the response body or an error. +// On a successful read, the response body is cached. +// Subsequent reads will access the cached value. +// Exported as runtime.Payload() WITHOUT the opts parameter. 
+func Payload(resp *http.Response, opts *PayloadOptions) ([]byte, error) { + if resp.Body == nil { + // this shouldn't happen in real-world scenarios as a + // response with no body should set it to http.NoBody + return nil, nil + } + modifyBytes := func(b []byte) []byte { return b } + if opts != nil && opts.BytesModifier != nil { + modifyBytes = opts.BytesModifier + } + + // r.Body won't be a nopClosingBytesReader if downloading was skipped + if buf, ok := resp.Body.(*nopClosingBytesReader); ok { + bytesBody := modifyBytes(buf.Bytes()) + buf.Set(bytesBody) + return bytesBody, nil + } + + bytesBody, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + + bytesBody = modifyBytes(bytesBody) + resp.Body = &nopClosingBytesReader{s: bytesBody} + return bytesBody, nil +} + +// PayloadDownloaded returns true if the response body has already been downloaded. +// This implies that the Payload() func above has been previously called. +// NOT exported but used by azcore. +func PayloadDownloaded(resp *http.Response) bool { + _, ok := resp.Body.(*nopClosingBytesReader) + return ok +} + +// nopClosingBytesReader is an io.ReadSeekCloser around a byte slice. +// It also provides direct access to the byte slice to avoid rereading. +type nopClosingBytesReader struct { + s []byte + i int64 +} + +// Bytes returns the underlying byte slice. +func (r *nopClosingBytesReader) Bytes() []byte { + return r.s +} + +// Close implements the io.Closer interface. +func (*nopClosingBytesReader) Close() error { + return nil +} + +// Read implements the io.Reader interface. +func (r *nopClosingBytesReader) Read(b []byte) (n int, err error) { + if r.i >= int64(len(r.s)) { + return 0, io.EOF + } + n = copy(b, r.s[r.i:]) + r.i += int64(n) + return +} + +// Set replaces the existing byte slice with the specified byte slice and resets the reader. +func (r *nopClosingBytesReader) Set(b []byte) { + r.s = b + r.i = 0 +} + +// Seek implements the io.Seeker interface. 
+func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) { + var i int64 + switch whence { + case io.SeekStart: + i = offset + case io.SeekCurrent: + i = r.i + offset + case io.SeekEnd: + i = int64(len(r.s)) + offset + default: + return 0, errors.New("nopClosingBytesReader: invalid whence") + } + if i < 0 { + return 0, errors.New("nopClosingBytesReader: negative position") + } + r.i = i + return i, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/doc.go new file mode 100644 index 00000000000..d7876d297ae --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/doc.go @@ -0,0 +1,7 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package log diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/log.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/log.go new file mode 100644 index 00000000000..4f1dcf1b78a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/log/log.go @@ -0,0 +1,104 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package log + +import ( + "fmt" + "os" + "time" +) + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// NOTE: The following are exported as public surface area from azcore. DO NOT MODIFY +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Event is used to group entries. Each group can be toggled on or off. +type Event string + +// SetEvents is used to control which events are written to +// the log. By default all log events are writen. +func SetEvents(cls ...Event) { + log.cls = cls +} + +// SetListener will set the Logger to write to the specified listener. 
+func SetListener(lst func(Event, string)) { + log.lst = lst +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// END PUBLIC SURFACE AREA +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Should returns true if the specified log event should be written to the log. +// By default all log events will be logged. Call SetEvents() to limit +// the log events for logging. +// If no listener has been set this will return false. +// Calling this method is useful when the message to log is computationally expensive +// and you want to avoid the overhead if its log event is not enabled. +func Should(cls Event) bool { + if log.lst == nil { + return false + } + if log.cls == nil || len(log.cls) == 0 { + return true + } + for _, c := range log.cls { + if c == cls { + return true + } + } + return false +} + +// Write invokes the underlying listener with the specified event and message. +// If the event shouldn't be logged or there is no listener then Write does nothing. +func Write(cls Event, message string) { + if !Should(cls) { + return + } + log.lst(cls, message) +} + +// Writef invokes the underlying listener with the specified event and formatted message. +// If the event shouldn't be logged or there is no listener then Writef does nothing. +func Writef(cls Event, format string, a ...interface{}) { + if !Should(cls) { + return + } + log.lst(cls, fmt.Sprintf(format, a...)) +} + +// TestResetEvents is used for TESTING PURPOSES ONLY. +func TestResetEvents() { + log.cls = nil +} + +// logger controls which events to log and writing to the underlying log. 
+type logger struct { + cls []Event + lst func(Event, string) +} + +// the process-wide logger +var log logger + +func init() { + initLogging() +} + +// split out for testing purposes +func initLogging() { + if cls := os.Getenv("AZURE_SDK_GO_LOGGING"); cls == "all" { + // cls could be enhanced to support a comma-delimited list of log events + log.lst = func(cls Event, msg string) { + // simple console logger, it writes to stderr in the following format: + // [time-stamp] Event: message + fmt.Fprintf(os.Stderr, "[%s] %s: %s\n", time.Now().Format(time.StampMicro), cls, msg) + } + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/poller/util.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/poller/util.go new file mode 100644 index 00000000000..db8269627d3 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/poller/util.go @@ -0,0 +1,155 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package poller + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/exported" +) + +// the well-known set of LRO status/provisioning state values. +const ( + StatusSucceeded = "Succeeded" + StatusCanceled = "Canceled" + StatusFailed = "Failed" + StatusInProgress = "InProgress" +) + +// these are non-conformant states that we've seen in the wild. +// we support them for back-compat. +const ( + StatusCancelled = "Cancelled" + StatusCompleted = "Completed" +) + +// IsTerminalState returns true if the LRO's state is terminal. +func IsTerminalState(s string) bool { + return Failed(s) || Succeeded(s) +} + +// Failed returns true if the LRO's state is terminal failure. 
+func Failed(s string) bool { + return strings.EqualFold(s, StatusFailed) || strings.EqualFold(s, StatusCanceled) || strings.EqualFold(s, StatusCancelled) +} + +// Succeeded returns true if the LRO's state is terminal success. +func Succeeded(s string) bool { + return strings.EqualFold(s, StatusSucceeded) || strings.EqualFold(s, StatusCompleted) +} + +// returns true if the LRO response contains a valid HTTP status code +func StatusCodeValid(resp *http.Response) bool { + return exported.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) +} + +// IsValidURL verifies that the URL is valid and absolute. +func IsValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.IsAbs() +} + +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") + +// GetJSON reads the response body into a raw JSON object. +// It returns ErrNoBody if there was no content. +func GetJSON(resp *http.Response) (map[string]any, error) { + body, err := exported.Payload(resp, nil) + if err != nil { + return nil, err + } + if len(body) == 0 { + return nil, ErrNoBody + } + // unmarshall the body to get the value + var jsonBody map[string]any + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + +// provisioningState returns the provisioning state from the response or the empty string. +func provisioningState(jsonBody map[string]any) string { + jsonProps, ok := jsonBody["properties"] + if !ok { + return "" + } + props, ok := jsonProps.(map[string]any) + if !ok { + return "" + } + rawPs, ok := props["provisioningState"] + if !ok { + return "" + } + ps, ok := rawPs.(string) + if !ok { + return "" + } + return ps +} + +// status returns the status from the response or the empty string. 
+func status(jsonBody map[string]any) string { + rawStatus, ok := jsonBody["status"] + if !ok { + return "" + } + status, ok := rawStatus.(string) + if !ok { + return "" + } + return status +} + +// GetStatus returns the LRO's status from the response body. +// Typically used for Azure-AsyncOperation flows. +// If there is no status in the response body the empty string is returned. +func GetStatus(resp *http.Response) (string, error) { + jsonBody, err := GetJSON(resp) + if err != nil { + return "", err + } + return status(jsonBody), nil +} + +// GetProvisioningState returns the LRO's state from the response body. +// If there is no state in the response body the empty string is returned. +func GetProvisioningState(resp *http.Response) (string, error) { + jsonBody, err := GetJSON(resp) + if err != nil { + return "", err + } + return provisioningState(jsonBody), nil +} + +// GetResourceLocation returns the LRO's resourceLocation value from the response body. +// Typically used for Operation-Location flows. +// If there is no resourceLocation in the response body the empty string is returned. +func GetResourceLocation(resp *http.Response) (string, error) { + jsonBody, err := GetJSON(resp) + if err != nil { + return "", err + } + v, ok := jsonBody["resourceLocation"] + if !ok { + // it might be ok if the field doesn't exist, the caller must make that determination + return "", nil + } + vv, ok := v.(string) + if !ok { + return "", fmt.Errorf("the resourceLocation value %v was not in string format", v) + } + return vv, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry/telemetry.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry/telemetry.go new file mode 100644 index 00000000000..8ac589ade56 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry/telemetry.go @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package telemetry + +import ( + "fmt" + "os" + "runtime" +) + +// Format creates the properly formatted SDK component for a User-Agent string. +// Ex: azsdk-go-azservicebus/v1.0.0 (go1.19.3; linux) +// comp - the package name for a component (ex: azservicebus) +// ver - the version of the component (ex: v1.0.0) +func Format(comp, ver string) string { + // ex: azsdk-go-azservicebus/v1.0.0 (go1.19.3; windows) + return fmt.Sprintf("azsdk-go-%s/%s %s", comp, ver, platformInfo) +} + +// platformInfo is the Go version and OS, formatted properly for insertion +// into a User-Agent string. (ex: '(go1.19.3; windows') +// NOTE: the ONLY function that should write to this variable is this func +var platformInfo = func() string { + operatingSystem := runtime.GOOS // Default OS string + switch operatingSystem { + case "windows": + operatingSystem = os.Getenv("OS") // Get more specific OS information + case "linux": // accept default OS info + case "freebsd": // accept default OS info + } + return fmt.Sprintf("(%s; %s)", runtime.Version(), operatingSystem) +}() diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/temporal/resource.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/temporal/resource.go new file mode 100644 index 00000000000..238ef42ed03 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/temporal/resource.go @@ -0,0 +1,123 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package temporal + +import ( + "sync" + "time" +) + +// AcquireResource abstracts a method for refreshing a temporal resource. +type AcquireResource[TResource, TState any] func(state TState) (newResource TResource, newExpiration time.Time, err error) + +// Resource is a temporal resource (usually a credential) that requires periodic refreshing. 
+type Resource[TResource, TState any] struct { + // cond is used to synchronize access to the shared resource embodied by the remaining fields + cond *sync.Cond + + // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource + acquiring bool + + // resource contains the value of the shared resource + resource TResource + + // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired + expiration time.Time + + // lastAttempt indicates when a thread/goroutine last attempted to acquire/update the resource + lastAttempt time.Time + + // acquireResource is the callback function that actually acquires the resource + acquireResource AcquireResource[TResource, TState] +} + +// NewResource creates a new Resource that uses the specified AcquireResource for refreshing. +func NewResource[TResource, TState any](ar AcquireResource[TResource, TState]) *Resource[TResource, TState] { + return &Resource[TResource, TState]{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +} + +// Get returns the underlying resource. +// If the resource is fresh, no refresh is performed. +func (er *Resource[TResource, TState]) Get(state TState) (TResource, error) { + // If the resource is expiring within this time window, update it eagerly. + // This allows other threads/goroutines to keep running by using the not-yet-expired + // resource value while one thread/goroutine updates the resource. 
+ const window = 5 * time.Minute // This example updates the resource 5 minutes prior to expiration + const backoff = 30 * time.Second // Minimum wait time between eager update attempts + + now, acquire, expired := time.Now(), false, false + + // acquire exclusive lock + er.cond.L.Lock() + resource := er.resource + + for { + expired = er.expiration.IsZero() || er.expiration.Before(now) + if expired { + // The resource was never acquired or has expired + if !er.acquiring { + // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // Getting here means that this thread/goroutine will wait for the updated resource + } else if er.expiration.Add(-window).Before(now) { + // The resource is valid but is expiring within the time window + if !er.acquiring && er.lastAttempt.Add(backoff).Before(now) { + // If another thread/goroutine is not acquiring/renewing the resource, and none has attempted + // to do so within the last 30 seconds, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // This thread/goroutine will use the existing resource value while another updates it + resource = er.resource + break + } else { + // The resource is not close to expiring, this thread/goroutine should use its current value + resource = er.resource + break + } + // If we get here, wait for the new resource value to be acquired/updated + er.cond.Wait() + } + er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked + + var err error + if acquire { + // This thread/goroutine has been selected to acquire/update the resource + var expiration time.Time + var newValue TResource + er.lastAttempt = now + newValue, expiration, err = er.acquireResource(state) + + // Atomically, update the shared resource's new value & expiration. 
+ er.cond.L.Lock() + if err == nil { + // Update resource & expiration, return the new value + resource = newValue + er.resource, er.expiration = resource, expiration + } else if !expired { + // An eager update failed. Discard the error and return the current--still valid--resource value + err = nil + } + er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resource + + // Wake up any waiting threads/goroutines since there is a resource they can ALL use + er.cond.L.Unlock() + er.cond.Broadcast() + } + return resource, err // Return the resource this thread/goroutine can use +} + +// Expire marks the resource as expired, ensuring it's refreshed on the next call to Get(). +func (er *Resource[TResource, TState]) Expire() { + er.cond.L.Lock() + defer er.cond.L.Unlock() + + // Reset the expiration as if we never got this resource to begin with + er.expiration = time.Time{} +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/doc.go new file mode 100644 index 00000000000..a3824bee8b5 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/doc.go @@ -0,0 +1,7 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package uuid diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/uuid.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/uuid.go new file mode 100644 index 00000000000..278ac9cd1c2 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/internal/uuid/uuid.go @@ -0,0 +1,76 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package uuid + +import ( + "crypto/rand" + "errors" + "fmt" + "strconv" +) + +// The UUID reserved variants. 
+const ( + reservedRFC4122 byte = 0x40 +) + +// A UUID representation compliant with specification in RFC4122 document. +type UUID [16]byte + +// New returns a new UUID using the RFC4122 algorithm. +func New() (UUID, error) { + u := UUID{} + // Set all bits to pseudo-random values. + // NOTE: this takes a process-wide lock + _, err := rand.Read(u[:]) + if err != nil { + return u, err + } + u[8] = (u[8] | reservedRFC4122) & 0x7F // u.setVariant(ReservedRFC4122) + + var version byte = 4 + u[6] = (u[6] & 0xF) | (version << 4) // u.setVersion(4) + return u, nil +} + +// String returns the UUID in "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" format. +func (u UUID) String() string { + return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) +} + +// Parse parses a string formatted as "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" +// or "{xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}" into a UUID. +func Parse(s string) (UUID, error) { + var uuid UUID + // ensure format + switch len(s) { + case 36: + // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + case 38: + // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} + s = s[1:37] + default: + return uuid, errors.New("invalid UUID format") + } + if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { + return uuid, errors.New("invalid UUID format") + } + // parse chunks + for i, x := range [16]int{ + 0, 2, 4, 6, + 9, 11, + 14, 16, + 19, 21, + 24, 26, 28, 30, 32, 34} { + b, err := strconv.ParseUint(s[x:x+2], 16, 8) + if err != nil { + return uuid, fmt.Errorf("invalid UUID format: %s", err) + } + uuid[i] = byte(b) + } + return uuid, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/CHANGELOG.md b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/CHANGELOG.md new file mode 100644 index 00000000000..35ed4b321ca --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/CHANGELOG.md @@ -0,0 +1,177 @@ +# Release History + +## 1.2.1 (2024-05-20) + +### Bugs Fixed 
+
+- Emulator strings should allow for hosts other than localhost (PR#22898)
+
+## 1.2.0 (2024-05-07)
+
+### Bugs Fixed
+
+Processor.Run had unclear behavior for some cases:
+- Run() now returns an explicit error when called more than once on a single
+  Processor instance or if multiple Run calls are made concurrently. (PR#22833)
+- NextProcessorClient now properly terminates (and returns nil) if called on a
+  stopped Processor. (PR#22833)
+
+## 1.1.0 (2024-04-02)
+
+### Features Added
+
+- Add in ability to handle emulator connection strings. (PR#22663)
+
+### Bugs Fixed
+
+- Fixed a race condition between Processor.Run() and Processor.NextPartitionClient() where cancelling Run() quickly could lead to NextPartitionClient hanging indefinitely. (PR#22541)
+
+## 1.0.4 (2024-03-05)
+
+### Bugs Fixed
+
+- Fixed case where closing a Receiver/Sender after an idle period would take > 20 seconds. (PR#22509)
+
+## 1.0.3 (2024-01-16)
+
+### Bugs Fixed
+
+- Processor distributes partitions optimally, which would result in idle or over-assigned processors. (PR#22153)
+
+## 1.0.2 (2023-11-07)
+
+### Bugs Fixed
+
+- Processor now relinquishes ownership of partitions when it shuts down, making them immediately available to other active Processor instances. (PR#21899)
+
+## 1.0.1 (2023-06-06)
+
+### Bugs Fixed
+
+- GetPartitionProperties and GetEventHubProperties now retry properly on failures. (PR#20893)
+- Connection recovery could artificially fail, prolonging recovery. (PR#20883)
+
+## 1.0.0 (2023-05-09)
+
+### Features Added
+
+- First stable release of the azeventhubs package.
+- Authentication errors are indicated with an `azeventhubs.Error`, with a `Code` of `azeventhubs.ErrorCodeUnauthorizedAccess`. (PR#20450)
+
+### Bugs Fixed
+
+- Authentication errors could cause unnecessary retries, making calls take longer to fail. (PR#20450)
+- Recovery now includes internal timeouts and also handles restarting a connection if AMQP primitives aren't closed cleanly. 
+- Potential leaks for $cbs and $management when there was a partial failure. (PR#20564) +- Latest go-amqp changes have been merged in with fixes for robustness. +- Sending a message to an entity that is full will no longer retry. (PR#20722) +- Checkpoint store handles multiple initial owners properly, allowing only one through. (PR#20727) + +## 0.6.0 (2023-03-07) + +### Features Added + +- Added the `ConsumerClientOptions.InstanceID` field. This optional field can enhance error messages from + Event Hubs. For example, error messages related to ownership changes for a partition will contain the + name of the link that has taken ownership, which can help with traceability. + +### Breaking Changes + +- `ConsumerClient.ID()` renamed to `ConsumerClient.InstanceID()`. + +### Bugs Fixed + +- Recover the connection when the $cbs Receiver/Sender is not closed properly. This would cause + clients to return an error saying "$cbs node has already been opened." (PR#20334) + +## 0.5.0 (2023-02-07) + +### Features Added + +- Adds ProcessorOptions.Prefetch field, allowing configuration of Prefetch values for PartitionClients created using the Processor. (PR#19786) +- Added new function to parse connection string into values using `ParseConnectionString` and `ConnectionStringProperties`. (PR#19855) + +### Breaking Changes + +- ProcessorOptions.OwnerLevel has been removed. The Processor uses 0 as the owner level. +- Uses the public release of `github.com/Azure/azure-sdk-for-go/sdk/storage/azblob` package rather than using an internal copy. + For an example, see [example_consuming_with_checkpoints_test.go](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go). + +## 0.4.0 (2023-01-10) + +### Bugs Fixed + +- User-Agent was incorrectly formatted in our AMQP-based clients. 
(PR#19712)
+- Connection recovery has been improved, removing some unnecessary retries as well as adding a bound around
+  some operations (Close) that could potentially block recovery for a long time. (PR#19683)
+
+## 0.3.0 (2022-11-10)
+
+### Bugs Fixed
+
+- $cbs link is properly closed, even on cancellation (#19492)
+
+### Breaking Changes
+
+- ProducerClient.SendEventBatch renamed to ProducerClient.SendEventDataBatch, to align with
+  the name of the type.
+
+## 0.2.0 (2022-10-17)
+
+### Features Added
+
+- Raw AMQP message support, including full support for encoding Body (Value, Sequence and also multiple byte slices for Data). See ExampleEventDataBatch_AddEventData_rawAMQPMessages for some concrete examples. (PR#19156)
+- Prefetch is now enabled by default. Prefetch allows the Event Hubs client to maintain a continuously full cache of events, controlled by PartitionClientOptions.Prefetch. (PR#19281)
+- ConsumerClient.ID() returns a unique ID representing each instance of ConsumerClient.
+
+### Breaking Changes
+
+- EventDataBatch.NumMessages() renamed to EventDataBatch.NumEvents()
+- Prefetch is now enabled by default. To disable it set PartitionClientOptions.Prefetch to -1.
+- NewWebSocketConnArgs renamed to WebSocketConnParams
+- Code renamed to ErrorCode, including associated constants like `ErrorCodeOwnershipLost`.
+- OwnershipData, CheckpointData, and CheckpointStoreAddress have been folded into their individual structs: Ownership and Checkpoint.
+- StartPosition and OwnerLevel were erroneously included in the ConsumerClientOptions struct - they've been removed. These can be
+  configured in the PartitionClientOptions.
+
+### Bugs Fixed
+
+- Retries now respect cancellation when they're in the "delay before next try" phase. (PR#19295)
+- Fixed a potential leak which could cause us to open and leak a $cbs link connection, resulting in errors. 
(PR#19326) + +## 0.1.1 (2022-09-08) + +### Features Added + +- Adding in the new Processor type, which can be used to do distributed (and load balanced) consumption of events, using a + CheckpointStore. The built-in checkpoints.BlobStore uses Azure Blob Storage for persistence. A full example is + in [example_consuming_with_checkpoints_test.go](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go). + +### Breaking Changes + +- In the first beta, ConsumerClient took constructor parameter that required a partition ID, which meant you had to create + multiple ConsumerClients if you wanted to consume multiple partitions. ConsumerClient can now create multiple PartitionClient + instances (using ConsumerClient.NewPartitionClient), which allows you to share the same AMQP connection and receive from multiple + partitions simultaneously. +- Changes to EventData/ReceivedEventData: + + - ReceivedEventData now embeds EventData for fields common between the two, making it easier to change and resend. + - `ApplicationProperties` renamed to `Properties`. + - `PartitionKey` removed from `EventData`. To send events using a PartitionKey you must set it in the options + when creating the EventDataBatch: + + ```go + batch, err := producerClient.NewEventDataBatch(context.TODO(), &azeventhubs.NewEventDataBatchOptions{ + PartitionKey: to.Ptr("partition key"), + }) + ``` + +### Bugs Fixed + +- ReceivedEventData.Offset was incorrectly parsed, resulting in it always being 0. +- Added missing fields to ReceivedEventData and EventData (CorrelationID) +- PartitionKey property was not populated for messages sent via batch. + +## 0.1.0 (2022-08-11) + +- Initial preview for the new version of the Azure Event Hubs Go SDK. 
diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/LICENSE.txt b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/LICENSE.txt new file mode 100644 index 00000000000..b2f52a2bad4 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/LICENSE.txt @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/README.md b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/README.md new file mode 100644 index 00000000000..bd724a1810b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/README.md @@ -0,0 +1,133 @@ +# Azure Event Hubs Client Module for Go + +[Azure Event Hubs](https://azure.microsoft.com/services/event-hubs/) is a big data streaming platform and event ingestion service from Microsoft. 
For more information about Event Hubs see: [link](https://docs.microsoft.com/azure/event-hubs/event-hubs-about). + +Use the client library `github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs` in your application to: + +- Send events to an event hub. +- Consume events from an event hub. + +Key links: +- [Source code][source] +- [API Reference Documentation][godoc] +- [Product documentation](https://azure.microsoft.com/services/event-hubs/) +- [Samples][godoc_examples] + +## Getting started + +### Install the package + +Install the Azure Event Hubs client module for Go with `go get`: + +```bash +go get github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs +``` + +### Prerequisites + +- Go, version 1.18 or higher +- An [Azure subscription](https://azure.microsoft.com/free/) +- An [Event Hub namespace](https://docs.microsoft.com/azure/event-hubs/). +- An Event Hub. You can create an event hub in your Event Hubs Namespace using the [Azure Portal](https://docs.microsoft.com/azure/event-hubs/event-hubs-create), or the [Azure CLI](https://docs.microsoft.com/azure/event-hubs/event-hubs-quickstart-cli). + +### Authenticate the client + +Event Hub clients are created using a TokenCredential from the [Azure Identity package][azure_identity_pkg], like [DefaultAzureCredential][default_azure_credential]. +You can also create a client using a connection string. 
+
+#### Using a service principal
+ - ConsumerClient: [link](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#example-NewConsumerClient)
+ - ProducerClient: [link](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#example-NewProducerClient)
+
+#### Using a connection string
+ - ConsumerClient: [link](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#example-NewConsumerClientFromConnectionString)
+ - ProducerClient: [link](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#example-NewProducerClientFromConnectionString)
+
+# Key concepts
+
+An Event Hub [**namespace**](https://docs.microsoft.com/azure/event-hubs/event-hubs-features#namespace) can have multiple event hubs. Each event hub, in turn, contains [**partitions**](https://docs.microsoft.com/azure/event-hubs/event-hubs-features#partitions) which store events.
+
+Events are published to an event hub using an [event publisher](https://docs.microsoft.com/azure/event-hubs/event-hubs-features#event-publishers). In this package, the event publisher is the [ProducerClient](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#ProducerClient)
+
+Events can be consumed from an event hub using an [event consumer](https://docs.microsoft.com/azure/event-hubs/event-hubs-features#event-consumers). In this package there are two types for consuming events:
+- The basic event consumer is in the [ConsumerClient](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#ConsumerClient). This consumer is useful if you already know which partitions you want to receive from.
+- A distributed event consumer, which uses Azure Blobs for checkpointing and coordination. This is implemented in the [Processor](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#Processor). 
This is useful when you want to have the partition assignment be dynamically chosen, and balanced with other Processor instances.
+
+More information about Event Hubs features and terminology can be found here: [link](https://docs.microsoft.com/azure/event-hubs/event-hubs-features)
+
+# Examples
+
+Examples for various scenarios can be found on [pkg.go.dev](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#pkg-examples) or in the example*_test.go files in our GitHub repo for [azeventhubs](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs).
+
+# Troubleshooting
+
+### Logging
+
+This module uses the classification-based logging implementation in `azcore`. To enable console logging for all SDK modules, set the environment variable `AZURE_SDK_GO_LOGGING` to `all`.
+
+Use the `azcore/log` package to control log event output or to enable logs for `azeventhubs` only. For example:
+
+```go
+import (
+  "fmt"
+  azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
+)
+
+// print log output to stdout
+azlog.SetListener(func(event azlog.Event, s string) {
+  fmt.Printf("[%s] %s\n", event, s)
+})
+
+// pick the set of events to log
+azlog.SetEvents(
+  azeventhubs.EventConn,
+  azeventhubs.EventAuth,
+  azeventhubs.EventProducer,
+  azeventhubs.EventConsumer,
+)
+```
+
+## Contributing
+For details on contributing to this repository, see the [contributing guide][azure_sdk_for_go_contributing].
+
+This project welcomes contributions and suggestions. Most contributions require you to agree to a
+Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
+the rights to use your contribution. For details, visit https://cla.microsoft.com.
+
+When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
+a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
+provided by the bot. 
You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +### Additional Helpful Links for Contributors +Many people all over the world have helped make this project better. You'll want to check out: + +* [What are some good first issues for new contributors to the repo?](https://github.com/azure/azure-sdk-for-go/issues?q=is%3Aopen+is%3Aissue+label%3A%22up+for+grabs%22) +* [How to build and test your change][azure_sdk_for_go_contributing_developer_guide] +* [How you can make a change happen!][azure_sdk_for_go_contributing_pull_requests] +* Frequently Asked Questions (FAQ) and Conceptual Topics in the detailed [Azure SDK for Go wiki](https://github.com/azure/azure-sdk-for-go/wiki). + + +### Reporting security issues and security bugs + +Security issues and bugs should be reported privately, via email, to the Microsoft Security Response Center (MSRC) . You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Further information, including the MSRC PGP key, can be found in the [Security TechCenter](https://www.microsoft.com/msrc/faqs-report-an-issue). + +### License + +Azure SDK for Go is licensed under the [MIT](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/LICENSE.txt) license. 
+ + +[azure_sdk_for_go_contributing]: https://github.com/Azure/azure-sdk-for-go/blob/main/CONTRIBUTING.md +[azure_sdk_for_go_contributing_developer_guide]: https://github.com/Azure/azure-sdk-for-go/blob/main/CONTRIBUTING.md#developer-guide +[azure_sdk_for_go_contributing_pull_requests]: https://github.com/Azure/azure-sdk-for-go/blob/main/CONTRIBUTING.md#pull-requests + +[azure_identity_pkg]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity +[default_azure_credential]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity#NewDefaultAzureCredential +[source]: https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/messaging/azeventhubs +[godoc]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs +[godoc_examples]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#pkg-examples + +![Impressions](https://azure-sdk-impressions.azurewebsites.net/api/impressions/azure-sdk-for-go%2Fsdk%2Fmessaging%2Fazeventhubs%2FREADME.png) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/amqp_message.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/amqp_message.go new file mode 100644 index 00000000000..2e0bc54045f --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/amqp_message.go @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "time" + + "github.com/Azure/go-amqp" +) + +// AMQPAnnotatedMessage represents the AMQP message, as received from Event Hubs. +// For details about these properties, refer to the AMQP specification: +// +// https://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#section-message-format +// +// Some fields in this struct are typed 'any', which means they will accept AMQP primitives, or in some +// cases slices and maps. 
+// +// AMQP simple types include: +// - int (any size), uint (any size) +// - float (any size) +// - string +// - bool +// - time.Time +type AMQPAnnotatedMessage struct { + // ApplicationProperties corresponds to the "application-properties" section of an AMQP message. + // + // The values of the map are restricted to AMQP simple types, as listed in the comment for AMQPAnnotatedMessage. + ApplicationProperties map[string]any + + // Body represents the body of an AMQP message. + Body AMQPAnnotatedMessageBody + + // DeliveryAnnotations corresponds to the "delivery-annotations" section in an AMQP message. + // + // The values of the map are restricted to AMQP simple types, as listed in the comment for AMQPAnnotatedMessage. + DeliveryAnnotations map[any]any + + // DeliveryTag corresponds to the delivery-tag property of the TRANSFER frame + // for this message. + DeliveryTag []byte + + // Footer is the transport footers for this AMQP message. + // + // The values of the map are restricted to AMQP simple types, as listed in the comment for AMQPAnnotatedMessage. + Footer map[any]any + + // Header is the transport headers for this AMQP message. + Header *AMQPAnnotatedMessageHeader + + // MessageAnnotations corresponds to the message-annotations section of an AMQP message. + // + // The values of the map are restricted to AMQP simple types, as listed in the comment for AMQPAnnotatedMessage. + MessageAnnotations map[any]any + + // Properties corresponds to the properties section of an AMQP message. + Properties *AMQPAnnotatedMessageProperties +} + +// AMQPAnnotatedMessageProperties represents the properties of an AMQP message. +// See here for more details: +// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties +type AMQPAnnotatedMessageProperties struct { + // AbsoluteExpiryTime corresponds to the 'absolute-expiry-time' property. + AbsoluteExpiryTime *time.Time + + // ContentEncoding corresponds to the 'content-encoding' property. 
+ ContentEncoding *string + + // ContentType corresponds to the 'content-type' property + ContentType *string + + // CorrelationID corresponds to the 'correlation-id' property. + // The type of CorrelationID can be a uint64, UUID, []byte, or a string + CorrelationID any + + // CreationTime corresponds to the 'creation-time' property. + CreationTime *time.Time + + // GroupID corresponds to the 'group-id' property. + GroupID *string + + // GroupSequence corresponds to the 'group-sequence' property. + GroupSequence *uint32 + + // MessageID corresponds to the 'message-id' property. + // The type of MessageID can be a uint64, UUID, []byte, or string + MessageID any + + // ReplyTo corresponds to the 'reply-to' property. + ReplyTo *string + + // ReplyToGroupID corresponds to the 'reply-to-group-id' property. + ReplyToGroupID *string + + // Subject corresponds to the 'subject' property. + Subject *string + + // To corresponds to the 'to' property. + To *string + + // UserID corresponds to the 'user-id' property. + UserID []byte +} + +// AMQPAnnotatedMessageBody represents the body of an AMQP message. +// Only one of these fields can be used a a time. They are mutually exclusive. +type AMQPAnnotatedMessageBody struct { + // Data is encoded/decoded as multiple data sections in the body. + Data [][]byte + + // Sequence is encoded/decoded as one or more amqp-sequence sections in the body. + // + // The values of the slices are are restricted to AMQP simple types, as listed in the comment for AMQPAnnotatedMessage. + Sequence [][]any + + // Value is encoded/decoded as the amqp-value section in the body. + // + // The type of Value can be any of the AMQP simple types, as listed in the comment for AMQPAnnotatedMessage, + // as well as slices or maps of AMQP simple types. + Value any +} + +// AMQPAnnotatedMessageHeader carries standard delivery details about the transfer +// of a message. 
+// See https://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-header +// for more details. +type AMQPAnnotatedMessageHeader struct { + // DeliveryCount is the number of unsuccessful previous attempts to deliver this message. + // It corresponds to the 'delivery-count' property. + DeliveryCount uint32 + + // Durable corresponds to the 'durable' property. + Durable bool + + // FirstAcquirer corresponds to the 'first-acquirer' property. + FirstAcquirer bool + + // Priority corresponds to the 'priority' property. + Priority uint8 + + // TTL corresponds to the 'ttl' property. + TTL time.Duration +} + +// toAMQPMessage converts between our (azeventhubs) AMQP message +// to the underlying message used by go-amqp. +func (am *AMQPAnnotatedMessage) toAMQPMessage() *amqp.Message { + var header *amqp.MessageHeader + + if am.Header != nil { + header = &amqp.MessageHeader{ + DeliveryCount: am.Header.DeliveryCount, + Durable: am.Header.Durable, + FirstAcquirer: am.Header.FirstAcquirer, + Priority: am.Header.Priority, + TTL: am.Header.TTL, + } + } + + var properties *amqp.MessageProperties + + if am.Properties != nil { + properties = &amqp.MessageProperties{ + AbsoluteExpiryTime: am.Properties.AbsoluteExpiryTime, + ContentEncoding: am.Properties.ContentEncoding, + ContentType: am.Properties.ContentType, + CorrelationID: am.Properties.CorrelationID, + CreationTime: am.Properties.CreationTime, + GroupID: am.Properties.GroupID, + GroupSequence: am.Properties.GroupSequence, + MessageID: am.Properties.MessageID, + ReplyTo: am.Properties.ReplyTo, + ReplyToGroupID: am.Properties.ReplyToGroupID, + Subject: am.Properties.Subject, + To: am.Properties.To, + UserID: am.Properties.UserID, + } + } else { + properties = &amqp.MessageProperties{} + } + + var footer amqp.Annotations + + if am.Footer != nil { + footer = (amqp.Annotations)(am.Footer) + } + + return &amqp.Message{ + Annotations: copyAnnotations(am.MessageAnnotations), + ApplicationProperties: 
am.ApplicationProperties, + Data: am.Body.Data, + DeliveryAnnotations: amqp.Annotations(am.DeliveryAnnotations), + DeliveryTag: am.DeliveryTag, + Footer: footer, + Header: header, + Properties: properties, + Sequence: am.Body.Sequence, + Value: am.Body.Value, + } +} + +func copyAnnotations(src map[any]any) amqp.Annotations { + if src == nil { + return amqp.Annotations{} + } + + dest := amqp.Annotations{} + + for k, v := range src { + dest[k] = v + } + + return dest +} + +func newAMQPAnnotatedMessage(goAMQPMessage *amqp.Message) *AMQPAnnotatedMessage { + var header *AMQPAnnotatedMessageHeader + + if goAMQPMessage.Header != nil { + header = &AMQPAnnotatedMessageHeader{ + DeliveryCount: goAMQPMessage.Header.DeliveryCount, + Durable: goAMQPMessage.Header.Durable, + FirstAcquirer: goAMQPMessage.Header.FirstAcquirer, + Priority: goAMQPMessage.Header.Priority, + TTL: goAMQPMessage.Header.TTL, + } + } + + var properties *AMQPAnnotatedMessageProperties + + if goAMQPMessage.Properties != nil { + properties = &AMQPAnnotatedMessageProperties{ + AbsoluteExpiryTime: goAMQPMessage.Properties.AbsoluteExpiryTime, + ContentEncoding: goAMQPMessage.Properties.ContentEncoding, + ContentType: goAMQPMessage.Properties.ContentType, + CorrelationID: goAMQPMessage.Properties.CorrelationID, + CreationTime: goAMQPMessage.Properties.CreationTime, + GroupID: goAMQPMessage.Properties.GroupID, + GroupSequence: goAMQPMessage.Properties.GroupSequence, + MessageID: goAMQPMessage.Properties.MessageID, + ReplyTo: goAMQPMessage.Properties.ReplyTo, + ReplyToGroupID: goAMQPMessage.Properties.ReplyToGroupID, + Subject: goAMQPMessage.Properties.Subject, + To: goAMQPMessage.Properties.To, + UserID: goAMQPMessage.Properties.UserID, + } + } + + var footer map[any]any + + if goAMQPMessage.Footer != nil { + footer = (map[any]any)(goAMQPMessage.Footer) + } + + return &AMQPAnnotatedMessage{ + MessageAnnotations: map[any]any(goAMQPMessage.Annotations), + ApplicationProperties: goAMQPMessage.ApplicationProperties, 
+ Body: AMQPAnnotatedMessageBody{ + Data: goAMQPMessage.Data, + Sequence: goAMQPMessage.Sequence, + Value: goAMQPMessage.Value, + }, + DeliveryAnnotations: map[any]any(goAMQPMessage.DeliveryAnnotations), + DeliveryTag: goAMQPMessage.DeliveryTag, + Footer: footer, + Header: header, + Properties: properties, + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/checkpoint_store.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/checkpoint_store.go new file mode 100644 index 00000000000..83c1c3e54fa --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/checkpoint_store.go @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// CheckpointStore is used by multiple consumers to coordinate progress and ownership for partitions. +type CheckpointStore interface { + // ClaimOwnership attempts to claim ownership of the partitions in partitionOwnership and returns + // the actual partitions that were claimed. + ClaimOwnership(ctx context.Context, partitionOwnership []Ownership, options *ClaimOwnershipOptions) ([]Ownership, error) + + // ListCheckpoints lists all the available checkpoints. + ListCheckpoints(ctx context.Context, fullyQualifiedNamespace string, eventHubName string, consumerGroup string, options *ListCheckpointsOptions) ([]Checkpoint, error) + + // ListOwnership lists all ownerships. + ListOwnership(ctx context.Context, fullyQualifiedNamespace string, eventHubName string, consumerGroup string, options *ListOwnershipOptions) ([]Ownership, error) + + // SetCheckpoint updates a specific checkpoint with a sequence and offset. + SetCheckpoint(ctx context.Context, checkpoint Checkpoint, options *SetCheckpointOptions) error +} + +// Ownership tracks which consumer owns a particular partition. 
+type Ownership struct {
+ ConsumerGroup string
+ EventHubName string
+ FullyQualifiedNamespace string
+ PartitionID string
+
+ OwnerID string // the owner ID of the Processor
+ LastModifiedTime time.Time // used when calculating if ownership has expired
+ ETag *azcore.ETag // the ETag, used when attempting to claim or update ownership of a partition.
+}
+
+// Checkpoint tracks the last successfully processed event in a partition.
+type Checkpoint struct {
+ ConsumerGroup string
+ EventHubName string
+ FullyQualifiedNamespace string
+ PartitionID string
+
+ Offset *int64 // the last successfully processed Offset.
+ SequenceNumber *int64 // the last successfully processed SequenceNumber.
+}
+
+// ListCheckpointsOptions contains optional parameters for the ListCheckpoints function
+type ListCheckpointsOptions struct {
+ // For future expansion
+}
+
+// ListOwnershipOptions contains optional parameters for the ListOwnership function
+type ListOwnershipOptions struct {
+ // For future expansion
+}
+
+// SetCheckpointOptions contains optional parameters for the SetCheckpoint function
+type SetCheckpointOptions struct {
+ // For future expansion
+}
+
+// ClaimOwnershipOptions contains optional parameters for the ClaimOwnership function
+type ClaimOwnershipOptions struct {
+ // For future expansion
+}
diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/ci.yml b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/ci.yml
new file mode 100644
index 00000000000..ab79ba00916
--- /dev/null
+++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/ci.yml
@@ -0,0 +1,35 @@
+# NOTE: Please refer to https://aka.ms/azsdk/engsys/ci-yaml before editing this file.
+trigger: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/messaging/azeventhubs + +pr: + branches: + include: + - main + - feature/* + - hotfix/* + - release/* + paths: + include: + - sdk/messaging/azeventhubs + +extends: + template: /eng/pipelines/templates/jobs/archetype-sdk-client.yml + parameters: + ServiceDirectory: 'messaging/azeventhubs' + # (live tests not yet ready to run) + RunLiveTests: true + SupportedClouds: 'Public,UsGov,China' + EnvVars: + AZURE_CLIENT_ID: $(AZEVENTHUBS_CLIENT_ID) + AZURE_TENANT_ID: $(AZEVENTHUBS_TENANT_ID) + AZURE_CLIENT_SECRET: $(AZEVENTHUBS_CLIENT_SECRET) + AZURE_SUBSCRIPTION_ID: $(AZEVENTHUBS_SUBSCRIPTION_ID) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/connection_string_properties.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/connection_string_properties.go new file mode 100644 index 00000000000..d28e837e237 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/connection_string_properties.go @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + +// ConnectionStringProperties are the properties of a connection string +// as returned by [ParseConnectionString]. +type ConnectionStringProperties = exported.ConnectionStringProperties + +// ParseConnectionString takes a connection string from the Azure portal and returns the +// parsed representation. +// +// There are two supported formats: +// 1. Connection strings generated from the portal (or elsewhere) that contain an embedded key and keyname. +// 2. 
A connection string with an embedded SharedAccessSignature: +// Endpoint=sb://.servicebus.windows.net;SharedAccessSignature=SharedAccessSignature sr=.servicebus.windows.net&sig=&se=&skn=" +func ParseConnectionString(connStr string) (ConnectionStringProperties, error) { + return exported.ParseConnectionString(connStr) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/consumer_client.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/consumer_client.go new file mode 100644 index 00000000000..84716b4d012 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/consumer_client.go @@ -0,0 +1,262 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" +) + +// ConsumerClientOptions configures optional parameters for a ConsumerClient. +type ConsumerClientOptions struct { + // ApplicationID is used as the identifier when setting the User-Agent property. + ApplicationID string + + // InstanceID is a unique name used to identify the consumer. This can help with + // diagnostics as this name will be returned in error messages. By default, + // an identifier will be automatically generated. + InstanceID string + + // NewWebSocketConn is a function that can create a net.Conn for use with websockets. + // For an example, see ExampleNewClient_usingWebsockets() function in example_client_test.go. + NewWebSocketConn func(ctx context.Context, args WebSocketConnParams) (net.Conn, error) + + // RetryOptions controls how often operations are retried from this client and any + // Receivers and Senders created from this client. 
+ RetryOptions RetryOptions + + // TLSConfig configures a client with a custom *tls.Config. + TLSConfig *tls.Config +} + +// ConsumerClient can create PartitionClient instances, which can read events from +// a partition. +type ConsumerClient struct { + consumerGroup string + eventHub string + + // instanceID is a customer supplied instanceID that can be passed to Event Hubs. + // It'll be returned in error messages and can be useful for customers when + // troubleshooting. + instanceID string + + links *internal.Links[amqpwrap.RPCLink] + namespace *internal.Namespace + retryOptions RetryOptions +} + +// NewConsumerClient creates a ConsumerClient which uses an azcore.TokenCredential for authentication. You +// MUST call [ConsumerClient.Close] on this client to avoid leaking resources. +// +// The fullyQualifiedNamespace is the Event Hubs namespace name (ex: myeventhub.servicebus.windows.net) +// The credential is one of the credentials in the [azidentity] package. +// +// [azidentity]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/azidentity +func NewConsumerClient(fullyQualifiedNamespace string, eventHub string, consumerGroup string, credential azcore.TokenCredential, options *ConsumerClientOptions) (*ConsumerClient, error) { + return newConsumerClient(consumerClientArgs{ + consumerGroup: consumerGroup, + fullyQualifiedNamespace: fullyQualifiedNamespace, + eventHub: eventHub, + credential: credential, + }, options) +} + +// NewConsumerClientFromConnectionString creates a ConsumerClient from a connection string. You +// MUST call [ConsumerClient.Close] on this client to avoid leaking resources. +// +// connectionString can be one of two formats - with or without an EntityPath key. +// +// When the connection string does not have an entity path, as shown below, the eventHub parameter cannot +// be empty and should contain the name of your event hub. 
+// +// Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey= +// +// When the connection string DOES have an entity path, as shown below, the eventHub parameter must be empty. +// +// Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey=;EntityPath=; +func NewConsumerClientFromConnectionString(connectionString string, eventHub string, consumerGroup string, options *ConsumerClientOptions) (*ConsumerClient, error) { + props, err := parseConn(connectionString, eventHub) + + if err != nil { + return nil, err + } + + return newConsumerClient(consumerClientArgs{ + consumerGroup: consumerGroup, + connectionString: connectionString, + eventHub: *props.EntityPath, + }, options) +} + +// PartitionClientOptions provides options for the NewPartitionClient function. +type PartitionClientOptions struct { + // StartPosition is the position we will start receiving events from, + // either an offset (inclusive) with Offset, or receiving events received + // after a specific time using EnqueuedTime. + // + // NOTE: you can also use the [Processor], which will automatically manage the start + // value using a [CheckpointStore]. See [example_consuming_with_checkpoints_test.go] for an + // example. + // + // [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go + StartPosition StartPosition + + // OwnerLevel is the priority for this partition client, also known as the 'epoch' level. + // When used, a partition client with a higher OwnerLevel will take ownership of a partition + // from partition clients with a lower OwnerLevel. + // Default is off. + OwnerLevel *int64 + + // Prefetch represents the size of the internal prefetch buffer. 
When set, + // this client will attempt to always maintain an internal cache of events of + // this size, asynchronously, increasing the odds that ReceiveEvents() will use + // a locally stored cache of events, rather than having to wait for events to + // arrive from the network. + // + // Defaults to 300 events if Prefetch == 0. + // Disabled if Prefetch < 0. + Prefetch int32 +} + +// NewPartitionClient creates a client that can receive events from a partition. By default it starts +// at the latest point in the partition. This can be changed using the options parameter. +// You MUST call [azeventhubs.PartitionClient.Close] on the returned client to avoid leaking resources. +func (cc *ConsumerClient) NewPartitionClient(partitionID string, options *PartitionClientOptions) (*PartitionClient, error) { + return newPartitionClient(partitionClientArgs{ + namespace: cc.namespace, + eventHub: cc.eventHub, + partitionID: partitionID, + instanceID: cc.instanceID, + consumerGroup: cc.consumerGroup, + retryOptions: cc.retryOptions, + }, options) +} + +// GetEventHubProperties gets event hub properties, like the available partition IDs and when the Event Hub was created. +func (cc *ConsumerClient) GetEventHubProperties(ctx context.Context, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { + return getEventHubProperties(ctx, EventConsumer, cc.namespace, cc.links, cc.eventHub, cc.retryOptions, options) +} + +// GetPartitionProperties gets properties for a specific partition. This includes data like the +// last enqueued sequence number, the first sequence number and when an event was last enqueued +// to the partition. 
+func (cc *ConsumerClient) GetPartitionProperties(ctx context.Context, partitionID string, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { + return getPartitionProperties(ctx, EventConsumer, cc.namespace, cc.links, cc.eventHub, partitionID, cc.retryOptions, options) +} + +// InstanceID is the identifier for this ConsumerClient. +func (cc *ConsumerClient) InstanceID() string { + return cc.instanceID +} + +type consumerClientDetails struct { + FullyQualifiedNamespace string + ConsumerGroup string + EventHubName string + ClientID string +} + +func (cc *ConsumerClient) getDetails() consumerClientDetails { + return consumerClientDetails{ + FullyQualifiedNamespace: cc.namespace.FQDN, + ConsumerGroup: cc.consumerGroup, + EventHubName: cc.eventHub, + ClientID: cc.InstanceID(), + } +} + +// Close releases resources for this client. +func (cc *ConsumerClient) Close(ctx context.Context) error { + return cc.namespace.Close(ctx, true) +} + +type consumerClientArgs struct { + connectionString string + + // the Event Hubs namespace name (ex: myservicebus.servicebus.windows.net) + fullyQualifiedNamespace string + credential azcore.TokenCredential + + consumerGroup string + eventHub string +} + +func newConsumerClient(args consumerClientArgs, options *ConsumerClientOptions) (*ConsumerClient, error) { + if options == nil { + options = &ConsumerClientOptions{} + } + + instanceID, err := getInstanceID(options.InstanceID) + + if err != nil { + return nil, err + } + + client := &ConsumerClient{ + consumerGroup: args.consumerGroup, + eventHub: args.eventHub, + instanceID: instanceID, + } + + var nsOptions []internal.NamespaceOption + + if args.connectionString != "" { + nsOptions = append(nsOptions, internal.NamespaceWithConnectionString(args.connectionString)) + } else if args.credential != nil { + option := internal.NamespaceWithTokenCredential( + args.fullyQualifiedNamespace, + args.credential) + + nsOptions = append(nsOptions, option) + } + + 
client.retryOptions = options.RetryOptions + + if options.TLSConfig != nil { + nsOptions = append(nsOptions, internal.NamespaceWithTLSConfig(options.TLSConfig)) + } + + if options.NewWebSocketConn != nil { + nsOptions = append(nsOptions, internal.NamespaceWithWebSocket(options.NewWebSocketConn)) + } + + if options.ApplicationID != "" { + nsOptions = append(nsOptions, internal.NamespaceWithUserAgent(options.ApplicationID)) + } + + nsOptions = append(nsOptions, internal.NamespaceWithRetryOptions(options.RetryOptions)) + + tempNS, err := internal.NewNamespace(nsOptions...) + + if err != nil { + return nil, err + } + + client.namespace = tempNS + client.links = internal.NewLinks[amqpwrap.RPCLink](tempNS, fmt.Sprintf("%s/$management", client.eventHub), nil, nil) + + return client, nil +} + +func getInstanceID(optionalID string) (string, error) { + if optionalID != "" { + return optionalID, nil + } + + // generate a new one + id, err := uuid.New() + + if err != nil { + return "", err + } + + return id.String(), nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/doc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/doc.go new file mode 100644 index 00000000000..25375f6dc9c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/doc.go @@ -0,0 +1,15 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package azeventhubs provides clients for sending events and consuming events. +// +// For sending events, use the [ProducerClient]. +// +// There are two clients for consuming events: +// - [Processor], which handles checkpointing and load balancing using durable storage. +// - [ConsumerClient], which is fully manual, but provides full control. 
+
+package azeventhubs
diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/error.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/error.go
new file mode 100644
index 00000000000..39a7eaa016d
--- /dev/null
+++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/error.go
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package azeventhubs
+
+import "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported"
+
+// Error represents an Event Hub specific error.
+// NOTE: the Code is considered part of the published API but the message that
+// comes back from Error(), as well as the underlying wrapped error, are NOT and
+// are subject to change.
+type Error = exported.Error
+
+// ErrorCode is an error code, usable by consuming code to work with
+// programmatically.
+type ErrorCode = exported.ErrorCode
+
+const (
+ // ErrorCodeUnauthorizedAccess means the credentials provided are not valid for use with
+ // a particular entity, or have expired.
+ ErrorCodeUnauthorizedAccess ErrorCode = exported.ErrorCodeUnauthorizedAccess
+
+ // ErrorCodeConnectionLost means our connection was lost and all retry attempts failed.
+ // This typically reflects an extended outage or connection disruption and may
+ // require manual intervention.
+ ErrorCodeConnectionLost ErrorCode = exported.ErrorCodeConnectionLost
+
+ // ErrorCodeOwnershipLost means that a partition that you were reading from was opened
+ // by another link with a higher epoch/owner level.
+ ErrorCodeOwnershipLost ErrorCode = exported.ErrorCodeOwnershipLost +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data.go new file mode 100644 index 00000000000..00b89a3ca0e --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data.go @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "errors" + "strconv" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh" + "github.com/Azure/go-amqp" +) + +// EventData is an event that can be sent, using the ProducerClient, to an Event Hub. +type EventData struct { + // Properties can be used to store custom metadata for a message. + Properties map[string]any + + // Body is the payload for a message. + Body []byte + + // ContentType describes the payload of the message, with a descriptor following + // the format of Content-Type, specified by RFC2045 (ex: "application/json"). + ContentType *string + + // CorrelationID is a client-specific id that can be used to mark or identify messages + // between clients. + // CorrelationID can be a uint64, UUID, []byte, or string + CorrelationID any + + // MessageID is an application-defined value that uniquely identifies + // the message and its payload. The identifier is a free-form string. + // + // If enabled, the duplicate detection feature identifies and removes further submissions + // of messages with the same MessageId. + MessageID *string +} + +// ReceivedEventData is an event that has been received using the ConsumerClient. +type ReceivedEventData struct { + EventData + + // EnqueuedTime is the UTC time when the message was accepted and stored by Event Hubs. 
+ EnqueuedTime *time.Time
+
+ // PartitionKey is used with a partitioned entity and enables assigning related messages
+ // to the same internal partition. This ensures that the submission sequence order is correctly
+ // recorded. The partition is chosen by a hash function in Event Hubs and cannot be chosen
+ // directly.
+ PartitionKey *string
+
+ // Offset is the offset of the event.
+ Offset int64
+
+ // RawAMQPMessage is the AMQP message, as received by the client. This can be useful to get access
+ // to properties that are not exposed by ReceivedEventData such as payloads encoded into the
+ // Value or Sequence section, payloads sent as multiple Data sections, as well as Footer
+ // and Header fields.
+ RawAMQPMessage *AMQPAnnotatedMessage
+
+ // SequenceNumber is a unique number assigned to a message by Event Hubs.
+ SequenceNumber int64
+
+ // Properties set by the Event Hubs service.
+ SystemProperties map[string]any
+}
+
+// Event Hubs custom properties
+const (
+ // Annotation properties
+ partitionKeyAnnotation = "x-opt-partition-key"
+ sequenceNumberAnnotation = "x-opt-sequence-number"
+ offsetNumberAnnotation = "x-opt-offset"
+ enqueuedTimeAnnotation = "x-opt-enqueued-time"
+)
+
+func (e *EventData) toAMQPMessage() *amqp.Message {
+ amqpMsg := amqp.NewMessage(e.Body)
+
+ var messageID any
+
+ if e.MessageID != nil {
+ messageID = *e.MessageID
+ }
+
+ amqpMsg.Properties = &amqp.MessageProperties{
+ MessageID: messageID,
+ }
+
+ amqpMsg.Properties.ContentType = e.ContentType
+ amqpMsg.Properties.CorrelationID = e.CorrelationID
+
+ if len(e.Properties) > 0 {
+ amqpMsg.ApplicationProperties = make(map[string]any)
+ for key, value := range e.Properties {
+ amqpMsg.ApplicationProperties[key] = value
+ }
+ }
+
+ return amqpMsg
+}
+
+// newReceivedEventData creates a received message from an AMQP message.
+// NOTE: this converter assumes that the Body of this message will be the first
+// serialized byte array in the Data section of the message.
+func newReceivedEventData(amqpMsg *amqp.Message) (*ReceivedEventData, error) { + re := &ReceivedEventData{ + RawAMQPMessage: newAMQPAnnotatedMessage(amqpMsg), + } + + if len(amqpMsg.Data) == 1 { + re.Body = amqpMsg.Data[0] + } + + if amqpMsg.Properties != nil { + if id, ok := amqpMsg.Properties.MessageID.(string); ok { + re.MessageID = &id + } + + re.ContentType = amqpMsg.Properties.ContentType + re.CorrelationID = amqpMsg.Properties.CorrelationID + } + + if amqpMsg.ApplicationProperties != nil { + re.Properties = make(map[string]any, len(amqpMsg.ApplicationProperties)) + for key, value := range amqpMsg.ApplicationProperties { + re.Properties[key] = value + } + } + + if err := updateFromAMQPAnnotations(amqpMsg, re); err != nil { + return nil, err + } + + return re, nil +} + +// the "SystemProperties" in an EventData are any annotations that are +// NOT available at the top level as normal fields. So excluding sequence +// number, offset, enqueued time, and partition key. +func updateFromAMQPAnnotations(src *amqp.Message, dest *ReceivedEventData) error { + if src.Annotations == nil { + return nil + } + + for kAny, v := range src.Annotations { + keyStr, keyIsString := kAny.(string) + + if !keyIsString { + continue + } + + switch keyStr { + case sequenceNumberAnnotation: + if asInt64, ok := eh.ConvertToInt64(v); ok { + dest.SequenceNumber = asInt64 + continue + } + + return errors.New("sequence number cannot be converted to an int64") + case partitionKeyAnnotation: + if asString, ok := v.(string); ok { + dest.PartitionKey = to.Ptr(asString) + continue + } + + return errors.New("partition key cannot be converted to a string") + case enqueuedTimeAnnotation: + if asTime, ok := v.(time.Time); ok { + dest.EnqueuedTime = &asTime + continue + } + + return errors.New("enqueued time cannot be converted to a time.Time") + case offsetNumberAnnotation: + if offsetStr, ok := v.(string); ok { + if offset, err := strconv.ParseInt(offsetStr, 10, 64); err == nil { + dest.Offset = 
offset + continue + } + } + return errors.New("offset cannot be converted to an int64") + default: + if dest.SystemProperties == nil { + dest.SystemProperties = map[string]any{} + } + + dest.SystemProperties[keyStr] = v + } + } + + return nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data_batch.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data_batch.go new file mode 100644 index 00000000000..edc6517b90b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/event_data_batch.go @@ -0,0 +1,236 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "errors" + "fmt" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/go-amqp" +) + +// ErrEventDataTooLarge is returned when a message cannot fit into a batch when using the [azeventhubs.EventDataBatch.AddEventData] function. +var ErrEventDataTooLarge = errors.New("the EventData could not be added because it is too large for the batch") + +type ( + // EventDataBatch is used to efficiently pack up EventData before sending it to Event Hubs. + // + // EventDataBatch's are not meant to be created directly. Use [ProducerClient.NewEventDataBatch], + // which will create them with the proper size limit for your Event Hub. + EventDataBatch struct { + mu sync.RWMutex + + marshaledMessages [][]byte + batchEnvelope *amqp.Message + + maxBytes uint64 + currentSize uint64 + + partitionID *string + partitionKey *string + } +) + +const ( + batchMessageFormat uint32 = 0x80013700 +) + +// AddEventDataOptions contains optional parameters for the AddEventData function. 
+type AddEventDataOptions struct { + // For future expansion +} + +// AddEventData adds an EventData to the batch, failing if the EventData would +// cause the EventDataBatch to be too large to send. +// +// This size limit was set when the EventDataBatch was created, in options to +// [ProducerClient.NewEventDataBatch], or (by default) from Event +// Hubs itself. +// +// Returns ErrMessageTooLarge if the event cannot fit, or a non-nil error for +// other failures. +func (b *EventDataBatch) AddEventData(ed *EventData, options *AddEventDataOptions) error { + return b.addAMQPMessage(ed.toAMQPMessage()) +} + +// AddAMQPAnnotatedMessage adds an AMQPAnnotatedMessage to the batch, failing +// if the AMQPAnnotatedMessage would cause the EventDataBatch to be too large to send. +// +// This size limit was set when the EventDataBatch was created, in options to +// [ProducerClient.NewEventDataBatch], or (by default) from Event +// Hubs itself. +// +// Returns ErrMessageTooLarge if the message cannot fit, or a non-nil error for +// other failures. +func (b *EventDataBatch) AddAMQPAnnotatedMessage(annotatedMessage *AMQPAnnotatedMessage, options *AddEventDataOptions) error { + return b.addAMQPMessage(annotatedMessage.toAMQPMessage()) +} + +// NumBytes is the number of bytes in the batch. +func (b *EventDataBatch) NumBytes() uint64 { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.currentSize +} + +// NumEvents returns the number of events in the batch. +func (b *EventDataBatch) NumEvents() int32 { + b.mu.RLock() + defer b.mu.RUnlock() + + return int32(len(b.marshaledMessages)) +} + +// toAMQPMessage converts this batch into a sendable *amqp.Message +// NOTE: not idempotent! 
+func (b *EventDataBatch) toAMQPMessage() (*amqp.Message, error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if len(b.marshaledMessages) == 0 {
+ return nil, internal.NewErrNonRetriable("batch is nil or empty")
+ }
+
+ b.batchEnvelope.Data = make([][]byte, len(b.marshaledMessages))
+ b.batchEnvelope.Format = batchMessageFormat
+
+ if b.partitionKey != nil {
+ if b.batchEnvelope.Annotations == nil {
+ b.batchEnvelope.Annotations = make(amqp.Annotations)
+ }
+
+ b.batchEnvelope.Annotations[partitionKeyAnnotation] = *b.partitionKey
+ }
+
+ copy(b.batchEnvelope.Data, b.marshaledMessages)
+ return b.batchEnvelope, nil
+}
+
+func (b *EventDataBatch) addAMQPMessage(msg *amqp.Message) error {
+ if msg.Properties.MessageID == nil || msg.Properties.MessageID == "" {
+ uid, err := uuid.New()
+ if err != nil {
+ return err
+ }
+ msg.Properties.MessageID = uid.String()
+ }
+
+ if b.partitionKey != nil {
+ if msg.Annotations == nil {
+ msg.Annotations = make(amqp.Annotations)
+ }
+
+ msg.Annotations[partitionKeyAnnotation] = *b.partitionKey
+ }
+
+ bin, err := msg.MarshalBinary()
+ if err != nil {
+ return err
+ }
+
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if len(b.marshaledMessages) == 0 {
+ // the first message is special - we use its properties and annotations as the
+ // actual envelope for the batch message.
+ batchEnv, batchEnvLen, err := createBatchEnvelope(msg)
+
+ if err != nil {
+ return err
+ }
+
+ // (we'll undo this if it turns out the message was too big)
+ b.currentSize = uint64(batchEnvLen)
+ b.batchEnvelope = batchEnv
+ }
+
+ actualPayloadSize := calcActualSizeForPayload(bin)
+
+ if b.currentSize+actualPayloadSize > b.maxBytes {
+ if len(b.marshaledMessages) == 0 {
+ // reset our properties, this didn't end up being our first message.
+ b.currentSize = 0 + b.batchEnvelope = nil + } + + return ErrEventDataTooLarge + } + + b.currentSize += actualPayloadSize + b.marshaledMessages = append(b.marshaledMessages, bin) + + return nil +} + +// createBatchEnvelope makes a copy of the properties of the message, minus any +// payload fields (like Data, Value or Sequence). The data field will be +// filled in with all the messages when the batch is completed. +func createBatchEnvelope(am *amqp.Message) (*amqp.Message, int, error) { + batchEnvelope := *am + + batchEnvelope.Data = nil + batchEnvelope.Value = nil + batchEnvelope.Sequence = nil + + bytes, err := batchEnvelope.MarshalBinary() + + if err != nil { + return nil, 0, err + } + + return &batchEnvelope, len(bytes), nil +} + +// calcActualSizeForPayload calculates the payload size based +// on overhead from AMQP encoding. +func calcActualSizeForPayload(payload []byte) uint64 { + const vbin8Overhead = 5 + const vbin32Overhead = 8 + + if len(payload) < 256 { + return uint64(vbin8Overhead + len(payload)) + } + + return uint64(vbin32Overhead + len(payload)) +} + +func newEventDataBatch(sender amqpwrap.AMQPSenderCloser, options *EventDataBatchOptions) (*EventDataBatch, error) { + if options == nil { + options = &EventDataBatchOptions{} + } + + if options.PartitionID != nil && options.PartitionKey != nil { + return nil, errors.New("either PartitionID or PartitionKey can be set, but not both") + } + + var batch EventDataBatch + + if options.PartitionID != nil { + // they want to send to a particular partition. The batch size should be the same for any + // link but we might as well use the one they're going to send to. 
+ pid := *options.PartitionID + batch.partitionID = &pid + } else if options.PartitionKey != nil { + partKey := *options.PartitionKey + batch.partitionKey = &partKey + } + + if options.MaxBytes == 0 { + batch.maxBytes = sender.MaxMessageSize() + return &batch, nil + } + + if options.MaxBytes > sender.MaxMessageSize() { + return nil, internal.NewErrNonRetriable(fmt.Sprintf("maximum message size for batch was set to %d bytes, which is larger than the maximum size allowed by link (%d)", options.MaxBytes, sender.MaxMessageSize())) + } + + batch.maxBytes = options.MaxBytes + return &batch, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpInterfaces.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpInterfaces.go new file mode 100644 index 00000000000..f6ea7f0cc37 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpInterfaces.go @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" +) + +type AMQPReceiver = amqpwrap.AMQPReceiver +type AMQPReceiverCloser = amqpwrap.AMQPReceiverCloser +type AMQPSender = amqpwrap.AMQPSender +type AMQPSenderCloser = amqpwrap.AMQPSenderCloser + +// Closeable is implemented by pretty much any AMQP link/client +// including our own higher level Receiver/Sender. 
+type Closeable interface { + Close(ctx context.Context) error +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqp_fakes.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqp_fakes.go new file mode 100644 index 00000000000..a08b084f767 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqp_fakes.go @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/go-amqp" +) + +type FakeNSForPartClient struct { + NamespaceForAMQPLinks + + Receiver *FakeAMQPReceiver + NewReceiverErr error + NewReceiverCalled int + + Sender *FakeAMQPSender + NewSenderErr error + NewSenderCalled int + + RecoverFn func(ctx context.Context, clientRevision uint64) error +} + +type FakeAMQPSession struct { + amqpwrap.AMQPSession + NS *FakeNSForPartClient + CloseCalled int +} + +type FakeAMQPReceiver struct { + amqpwrap.AMQPReceiverCloser + + // ActiveCredits are incremented and decremented by IssueCredit and Receive. + ActiveCredits int32 + + // IssuedCredit just accumulates, so we can get an idea of how many credits we issued overall. + IssuedCredit []uint32 + + // CreditsSetFromOptions is similar to issuedCredit, but only tracks credits added in via the LinkOptions.Credit + // field (ie, enabling prefetch). + CreditsSetFromOptions int32 + + // ManualCreditsSetFromOptions is the value of the LinkOptions.ManualCredits value. 
+ ManualCreditsSetFromOptions bool + + Messages []*amqp.Message + + NameForLink string + + CloseCalled int + CloseError error +} + +func (ns *FakeNSForPartClient) Recover(ctx context.Context, clientRevision uint64) error { + return ns.RecoverFn(ctx, clientRevision) +} + +func (ns *FakeNSForPartClient) NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) { + ctx, cancel := context.WithCancel(ctx) + return cancel, ctx.Done(), nil +} + +func (ns *FakeNSForPartClient) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) { + return &FakeAMQPSession{ + NS: ns, + }, 1, nil +} + +func (sess *FakeAMQPSession) NewReceiver(ctx context.Context, source string, partitionID string, opts *amqp.ReceiverOptions) (amqpwrap.AMQPReceiverCloser, error) { + sess.NS.NewReceiverCalled++ + sess.NS.Receiver.ManualCreditsSetFromOptions = opts.Credit == -1 + sess.NS.Receiver.CreditsSetFromOptions = opts.Credit + + if opts.Credit > 0 { + sess.NS.Receiver.ActiveCredits = opts.Credit + } + + return sess.NS.Receiver, sess.NS.NewReceiverErr +} + +func (sess *FakeAMQPSession) NewSender(ctx context.Context, target string, partitionID string, opts *amqp.SenderOptions) (AMQPSenderCloser, error) { + sess.NS.NewSenderCalled++ + return sess.NS.Sender, sess.NS.NewSenderErr +} + +func (sess *FakeAMQPSession) Close(ctx context.Context) error { + sess.CloseCalled++ + return nil +} + +func (r *FakeAMQPReceiver) Credits() uint32 { + return uint32(r.ActiveCredits) +} + +func (r *FakeAMQPReceiver) IssueCredit(credit uint32) error { + r.ActiveCredits += int32(credit) + r.IssuedCredit = append(r.IssuedCredit, credit) + return nil +} + +func (r *FakeAMQPReceiver) LinkName() string { + return r.NameForLink +} + +func (r *FakeAMQPReceiver) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { + if len(r.Messages) > 0 { + r.ActiveCredits-- + m := r.Messages[0] + r.Messages = r.Messages[1:] + return m, nil + } else { + 
<-ctx.Done() + return nil, ctx.Err() + } +} + +func (r *FakeAMQPReceiver) Close(ctx context.Context) error { + r.CloseCalled++ + return r.CloseError +} + +type FakeAMQPSender struct { + amqpwrap.AMQPSenderCloser + CloseCalled int + CloseError error +} + +func (s *FakeAMQPSender) Close(ctx context.Context) error { + s.CloseCalled++ + return s.CloseError +} + +type fakeAMQPClient struct { + amqpwrap.AMQPClient + closeCalled int + session *FakeAMQPSession +} + +func (f *fakeAMQPClient) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpwrap.AMQPSession, error) { + return f.session, nil +} + +func (f *fakeAMQPClient) Close() error { + f.closeCalled++ + return nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go new file mode 100644 index 00000000000..750b80c55ea --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package amqpwrap has some simple wrappers to make it easier to +// abstract the go-amqp types. 
+package amqpwrap + +import ( + "context" + "errors" + "time" + + "github.com/Azure/go-amqp" +) + +// AMQPReceiver is implemented by *amqp.Receiver +type AMQPReceiver interface { + IssueCredit(credit uint32) error + Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) + Prefetched() *amqp.Message + + // settlement functions + AcceptMessage(ctx context.Context, msg *amqp.Message) error + RejectMessage(ctx context.Context, msg *amqp.Message, e *amqp.Error) error + ReleaseMessage(ctx context.Context, msg *amqp.Message) error + ModifyMessage(ctx context.Context, msg *amqp.Message, options *amqp.ModifyMessageOptions) error + + LinkName() string + LinkSourceFilterValue(name string) any + + // wrapper only functions + + // Credits returns the # of credits still active on this link. + Credits() uint32 + + ConnID() uint64 +} + +// AMQPReceiverCloser is implemented by *amqp.Receiver +type AMQPReceiverCloser interface { + AMQPReceiver + Close(ctx context.Context) error +} + +// AMQPSender is implemented by *amqp.Sender +type AMQPSender interface { + Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error + MaxMessageSize() uint64 + LinkName() string + ConnID() uint64 +} + +// AMQPSenderCloser is implemented by *amqp.Sender +type AMQPSenderCloser interface { + AMQPSender + Close(ctx context.Context) error +} + +// AMQPSession is a simple interface, implemented by *AMQPSessionWrapper. +// It exists only so we can return AMQPReceiver/AMQPSender interfaces. 
+type AMQPSession interface { + Close(ctx context.Context) error + ConnID() uint64 + NewReceiver(ctx context.Context, source string, partitionID string, opts *amqp.ReceiverOptions) (AMQPReceiverCloser, error) + NewSender(ctx context.Context, target string, partitionID string, opts *amqp.SenderOptions) (AMQPSenderCloser, error) +} + +type AMQPClient interface { + Close() error + NewSession(ctx context.Context, opts *amqp.SessionOptions) (AMQPSession, error) + ID() uint64 +} + +type goamqpConn interface { + NewSession(ctx context.Context, opts *amqp.SessionOptions) (*amqp.Session, error) + Close() error +} + +type goamqpSession interface { + Close(ctx context.Context) error + NewReceiver(ctx context.Context, source string, opts *amqp.ReceiverOptions) (*amqp.Receiver, error) + NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (*amqp.Sender, error) +} + +type goamqpReceiver interface { + IssueCredit(credit uint32) error + Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) + Prefetched() *amqp.Message + + // settlement functions + AcceptMessage(ctx context.Context, msg *amqp.Message) error + RejectMessage(ctx context.Context, msg *amqp.Message, e *amqp.Error) error + ReleaseMessage(ctx context.Context, msg *amqp.Message) error + ModifyMessage(ctx context.Context, msg *amqp.Message, options *amqp.ModifyMessageOptions) error + + LinkName() string + LinkSourceFilterValue(name string) any + Close(ctx context.Context) error +} + +type goamqpSender interface { + Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error + MaxMessageSize() uint64 + LinkName() string + Close(ctx context.Context) error +} + +// AMQPClientWrapper is a simple interface, implemented by *AMQPClientWrapper +// It exists only so we can return AMQPSession, which itself only exists so we can +// return interfaces for AMQPSender and AMQPReceiver from AMQPSession. 
+type AMQPClientWrapper struct { + ConnID uint64 + Inner goamqpConn +} + +func (w *AMQPClientWrapper) ID() uint64 { + return w.ConnID +} + +func (w *AMQPClientWrapper) Close() error { + err := w.Inner.Close() + return WrapError(err, w.ConnID, "", "") +} + +func (w *AMQPClientWrapper) NewSession(ctx context.Context, opts *amqp.SessionOptions) (AMQPSession, error) { + sess, err := w.Inner.NewSession(ctx, opts) + + if err != nil { + return nil, WrapError(err, w.ConnID, "", "") + } + + return &AMQPSessionWrapper{ + connID: w.ConnID, + Inner: sess, + ContextWithTimeoutFn: context.WithTimeout, + }, nil +} + +type AMQPSessionWrapper struct { + connID uint64 + Inner goamqpSession + ContextWithTimeoutFn ContextWithTimeoutFn +} + +func (w *AMQPSessionWrapper) ConnID() uint64 { + return w.connID +} + +func (w *AMQPSessionWrapper) Close(ctx context.Context) error { + ctx, cancel := w.ContextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + err := w.Inner.Close(ctx) + return WrapError(err, w.connID, "", "") +} + +func (w *AMQPSessionWrapper) NewReceiver(ctx context.Context, source string, partitionID string, opts *amqp.ReceiverOptions) (AMQPReceiverCloser, error) { + receiver, err := w.Inner.NewReceiver(ctx, source, opts) + + if err != nil { + return nil, WrapError(err, w.connID, "", partitionID) + } + + return &AMQPReceiverWrapper{ + connID: w.connID, + partitionID: partitionID, + Inner: receiver, + ContextWithTimeoutFn: context.WithTimeout}, nil +} + +func (w *AMQPSessionWrapper) NewSender(ctx context.Context, target string, partitionID string, opts *amqp.SenderOptions) (AMQPSenderCloser, error) { + sender, err := w.Inner.NewSender(ctx, target, opts) + + if err != nil { + return nil, WrapError(err, w.connID, "", partitionID) + } + + return &AMQPSenderWrapper{ + connID: w.connID, + partitionID: partitionID, + Inner: sender, + ContextWithTimeoutFn: context.WithTimeout}, nil +} + +type AMQPReceiverWrapper struct { + connID uint64 + partitionID string + Inner 
goamqpReceiver + credits uint32 + ContextWithTimeoutFn ContextWithTimeoutFn +} + +func (rw *AMQPReceiverWrapper) ConnID() uint64 { + return rw.connID +} + +func (rw *AMQPReceiverWrapper) Credits() uint32 { + return rw.credits +} + +func (rw *AMQPReceiverWrapper) IssueCredit(credit uint32) error { + err := rw.Inner.IssueCredit(credit) + + if err == nil { + rw.credits += credit + } + + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +func (rw *AMQPReceiverWrapper) Receive(ctx context.Context, o *amqp.ReceiveOptions) (*amqp.Message, error) { + message, err := rw.Inner.Receive(ctx, o) + + if err != nil { + return nil, WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) + } + + rw.credits-- + return message, nil +} + +func (rw *AMQPReceiverWrapper) Prefetched() *amqp.Message { + msg := rw.Inner.Prefetched() + + if msg == nil { + return nil + } + + rw.credits-- + return msg +} + +// settlement functions +func (rw *AMQPReceiverWrapper) AcceptMessage(ctx context.Context, msg *amqp.Message) error { + err := rw.Inner.AcceptMessage(ctx, msg) + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +func (rw *AMQPReceiverWrapper) RejectMessage(ctx context.Context, msg *amqp.Message, e *amqp.Error) error { + err := rw.Inner.RejectMessage(ctx, msg, e) + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +func (rw *AMQPReceiverWrapper) ReleaseMessage(ctx context.Context, msg *amqp.Message) error { + err := rw.Inner.ReleaseMessage(ctx, msg) + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +func (rw *AMQPReceiverWrapper) ModifyMessage(ctx context.Context, msg *amqp.Message, options *amqp.ModifyMessageOptions) error { + err := rw.Inner.ModifyMessage(ctx, msg, options) + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +func (rw *AMQPReceiverWrapper) LinkName() string { + return rw.Inner.LinkName() +} + +func (rw *AMQPReceiverWrapper) LinkSourceFilterValue(name string) any { + return 
rw.Inner.LinkSourceFilterValue(name) +} + +func (rw *AMQPReceiverWrapper) Close(ctx context.Context) error { + ctx, cancel := rw.ContextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + err := rw.Inner.Close(ctx) + + return WrapError(err, rw.connID, rw.LinkName(), rw.partitionID) +} + +type AMQPSenderWrapper struct { + connID uint64 + partitionID string + Inner goamqpSender + ContextWithTimeoutFn ContextWithTimeoutFn +} + +func (sw *AMQPSenderWrapper) ConnID() uint64 { + return sw.connID +} + +func (sw *AMQPSenderWrapper) Send(ctx context.Context, msg *amqp.Message, o *amqp.SendOptions) error { + err := sw.Inner.Send(ctx, msg, o) + return WrapError(err, sw.connID, sw.LinkName(), sw.partitionID) +} + +func (sw *AMQPSenderWrapper) MaxMessageSize() uint64 { + return sw.Inner.MaxMessageSize() +} + +func (sw *AMQPSenderWrapper) LinkName() string { + return sw.Inner.LinkName() +} + +func (sw *AMQPSenderWrapper) Close(ctx context.Context) error { + ctx, cancel := sw.ContextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + err := sw.Inner.Close(ctx) + + return WrapError(err, sw.connID, sw.LinkName(), sw.partitionID) +} + +var ErrConnResetNeeded = errors.New("connection must be reset, link/connection state may be inconsistent") + +const defaultCloseTimeout = time.Minute + +// ContextWithTimeoutFn matches the signature for `context.WithTimeout` and is used when we want to +// stub things out for tests. +type ContextWithTimeoutFn func(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/error.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/error.go new file mode 100644 index 00000000000..5953fd18c37 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/error.go @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +package amqpwrap + +import ( + "errors" +) + +// Error is a wrapper that has the context of which connection and +// link the error happened with. +type Error struct { + ConnID uint64 + LinkName string + PartitionID string + Err error +} + +func (e Error) Error() string { + return e.Err.Error() +} + +func (e Error) As(target any) bool { + return errors.As(e.Err, target) +} + +func (e Error) Is(target error) bool { + return errors.Is(e.Err, target) +} + +func WrapError(err error, connID uint64, linkName string, partitionID string) error { + if err == nil { + return nil + } + + return Error{ + ConnID: connID, + LinkName: linkName, + PartitionID: partitionID, + Err: err, + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/rpc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/rpc.go new file mode 100644 index 00000000000..ced17fbc493 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap/rpc.go @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package amqpwrap + +import ( + "context" + + "github.com/Azure/go-amqp" +) + +// RPCResponse is the simplified response structure from an RPC like call +type RPCResponse struct { + // Code is the response code - these originate from Service Bus. Some + // common values are called out below, with the RPCResponseCode* constants. 
+ Code int + Description string + Message *amqp.Message +} + +// RPCLink is implemented by *rpc.Link +type RPCLink interface { + Close(ctx context.Context) error + ConnID() uint64 + RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) + LinkName() string +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth/token.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth/token.go new file mode 100644 index 00000000000..9aed3b521d5 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth/token.go @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package auth provides an abstraction over claims-based security for Azure Event Hub and Service Bus. +package auth + +const ( + // CBSTokenTypeJWT is the type of token to be used for JWTs. For example Azure Active Directory tokens. + CBSTokenTypeJWT TokenType = "jwt" + // CBSTokenTypeSAS is the type of token to be used for SAS tokens. 
+ CBSTokenTypeSAS TokenType = "servicebus.windows.net:sastoken" +) + +type ( + // TokenType represents types of tokens known for claims-based auth + TokenType string + + // Token contains all of the information to negotiate authentication + Token struct { + // TokenType is the type of CBS token + TokenType TokenType + Token string + Expiry string + } + + // TokenProvider abstracts the fetching of authentication tokens + TokenProvider interface { + GetToken(uri string) (*Token, error) + } +) + +// NewToken constructs a new auth token +func NewToken(tokenType TokenType, token, expiry string) *Token { + return &Token{ + TokenType: tokenType, + Token: token, + Expiry: expiry, + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/cbs.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/cbs.go new file mode 100644 index 00000000000..103f71a9a92 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/cbs.go @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package internal + +import ( + "context" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/go-amqp" +) + +const ( + cbsAddress = "$cbs" + cbsOperationKey = "operation" + cbsOperationPutToken = "put-token" + cbsTokenTypeKey = "type" + cbsAudienceKey = "name" + cbsExpirationKey = "expiration" +) + +// NegotiateClaim attempts to put a token to the $cbs management endpoint to negotiate auth for the given audience +func NegotiateClaim(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + link, err := NewRPCLink(ctx, RPCLinkArgs{ + Client: conn, + Address: cbsAddress, + LogEvent: exported.EventAuth, + }) + + if err != nil { + // In some circumstances we can end up in a situation where the link closing was cancelled + // or interrupted, leaving $cbs still open by some dangling receiver or sender. The only way + // to fix this is to restart the connection. 
+ if IsNotAllowedError(err) { + azlog.Writef(exported.EventAuth, "Not allowed to open, connection will be reset: %s", err) + return amqpwrap.ErrConnResetNeeded + } + + return err + } + + closeLink := func(ctx context.Context, origErr error) error { + if err := link.Close(ctx); err != nil { + azlog.Writef(exported.EventAuth, "Failed closing claim link: %s", err.Error()) + return err + } + + return origErr + } + + token, err := provider.GetToken(audience) + if err != nil { + azlog.Writef(exported.EventAuth, "Failed to get token from provider: %s", err) + return closeLink(ctx, err) + } + + azlog.Writef(exported.EventAuth, "negotiating claim for audience %s with token type %s and expiry of %s", audience, token.TokenType, token.Expiry) + + msg := &amqp.Message{ + Value: token.Token, + ApplicationProperties: map[string]any{ + cbsOperationKey: cbsOperationPutToken, + cbsTokenTypeKey: string(token.TokenType), + cbsAudienceKey: audience, + cbsExpirationKey: token.Expiry, + }, + } + + if _, err := link.RPC(ctx, msg); err != nil { + azlog.Writef(exported.EventAuth, "Failed to send/receive RPC message: %s", err) + return closeLink(ctx, err) + } + + return closeLink(ctx, nil) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/constants.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/constants.go new file mode 100644 index 00000000000..154eda8786c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/constants.go @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package internal + +// Version is the semantic version number +const Version = "v1.2.1" diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh/eh_internal.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh/eh_internal.go new file mode 100644 index 00000000000..17e0c7f138b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh/eh_internal.go @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package eh + +// ConvertToInt64 converts any int-like value to be an int64. +func ConvertToInt64(intValue any) (int64, bool) { + switch v := intValue.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return int64(v), true + } + + return 0, false +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/errors.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/errors.go new file mode 100644 index 00000000000..d86c09f8bf6 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/errors.go @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + "errors" + "io" + "net" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/go-amqp" +) + +type errNonRetriable struct { + Message string +} + +func NewErrNonRetriable(message string) error { + return errNonRetriable{Message: message} +} + +func (e errNonRetriable) Error() string { return e.Message } + +// RecoveryKind dictates what kind of recovery is possible. 
Used with +// GetRecoveryKind(). +type RecoveryKind string + +const ( + RecoveryKindNone RecoveryKind = "" + RecoveryKindFatal RecoveryKind = "fatal" + RecoveryKindLink RecoveryKind = "link" + RecoveryKindConn RecoveryKind = "connection" +) + +func IsFatalEHError(err error) bool { + return GetRecoveryKind(err) == RecoveryKindFatal +} + +// TransformError will create a proper error type that users +// can potentially inspect. +// If the error is actionable then it'll be of type exported.Error which +// has a 'Code' field that can be used programatically. +// If it's not actionable or if it's nil it'll just be returned. +func TransformError(err error) error { + if err == nil { + return nil + } + + _, ok := err.(*exported.Error) + + if ok { + // it's already been wrapped. + return err + } + + if IsOwnershipLostError(err) { + return exported.NewError(exported.ErrorCodeOwnershipLost, err) + } + + // there are a few errors that all boil down to "bad creds or unauthorized" + var amqpErr *amqp.Error + + if errors.As(err, &amqpErr) && amqpErr.Condition == amqp.ErrCondUnauthorizedAccess { + return exported.NewError(exported.ErrorCodeUnauthorizedAccess, err) + } + + var rpcErr RPCError + if errors.As(err, &rpcErr) && rpcErr.Resp.Code == http.StatusUnauthorized { + return exported.NewError(exported.ErrorCodeUnauthorizedAccess, err) + } + + rk := GetRecoveryKind(err) + + switch rk { + case RecoveryKindLink: + // note that we could give back a more differentiated error code + // here but it's probably best to just give the customer the simplest + // recovery mechanism possible. + return exported.NewError(exported.ErrorCodeConnectionLost, err) + case RecoveryKindConn: + return exported.NewError(exported.ErrorCodeConnectionLost, err) + default: + // isn't one of our specifically called out cases so we'll just return it. 
+ return err + } +} + +func IsQuickRecoveryError(err error) bool { + if IsOwnershipLostError(err) { + return false + } + + var de *amqp.LinkError + return errors.As(err, &de) +} + +func IsCancelError(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + if err.Error() == "context canceled" { // go-amqp is returning this when I cancel + return true + } + + return false +} + +const errorConditionLockLost = amqp.ErrCond("com.microsoft:message-lock-lost") + +var amqpConditionsToRecoveryKind = map[amqp.ErrCond]RecoveryKind{ + // no recovery needed, these are temporary errors. + amqp.ErrCond("com.microsoft:server-busy"): RecoveryKindNone, + amqp.ErrCond("com.microsoft:timeout"): RecoveryKindNone, + amqp.ErrCond("com.microsoft:operation-cancelled"): RecoveryKindNone, + + // Link recovery needed + amqp.ErrCondDetachForced: RecoveryKindLink, // "amqp:link:detach-forced" + amqp.ErrCondTransferLimitExceeded: RecoveryKindLink, // "amqp:link:transfer-limit-exceeded" + + // Connection recovery needed + amqp.ErrCondConnectionForced: RecoveryKindConn, // "amqp:connection:forced" + amqp.ErrCondInternalError: RecoveryKindConn, // "amqp:internal-error" + + // No recovery possible - this operation is non retriable. + + // ErrCondResourceLimitExceeded comes back if the entity is actually full. 
+ amqp.ErrCondResourceLimitExceeded: RecoveryKindFatal, // "amqp:resource-limit-exceeded" + amqp.ErrCondMessageSizeExceeded: RecoveryKindFatal, // "amqp:link:message-size-exceeded" + amqp.ErrCondUnauthorizedAccess: RecoveryKindFatal, // creds are bad + amqp.ErrCondNotFound: RecoveryKindFatal, // "amqp:not-found" + amqp.ErrCondNotAllowed: RecoveryKindFatal, // "amqp:not-allowed" + amqp.ErrCond("com.microsoft:entity-disabled"): RecoveryKindFatal, // entity is disabled in the portal + amqp.ErrCond("com.microsoft:session-cannot-be-locked"): RecoveryKindFatal, + amqp.ErrCond("com.microsoft:argument-out-of-range"): RecoveryKindFatal, // asked for a partition ID that doesn't exist + errorConditionLockLost: RecoveryKindFatal, +} + +// GetRecoveryKind determines the recovery type for non-session based links. +func GetRecoveryKind(err error) RecoveryKind { + if err == nil { + return RecoveryKindNone + } + + if errors.Is(err, RPCLinkClosedErr) { + return RecoveryKindFatal + } + + if IsCancelError(err) { + return RecoveryKindFatal + } + + if errors.Is(err, amqpwrap.ErrConnResetNeeded) { + return RecoveryKindConn + } + + var netErr net.Error + + // these are errors that can flow from the go-amqp connection to + // us. There's work underway to improve this but for now we can handle + // these as "catastrophic" errors and reset everything. + if errors.Is(err, io.EOF) || errors.As(err, &netErr) { + return RecoveryKindConn + } + + var errNonRetriable errNonRetriable + + if errors.As(err, &errNonRetriable) { + return RecoveryKindFatal + } + + // azidentity returns errors that match this for auth failures. + var errNonRetriableMarker interface { + NonRetriable() + error + } + + if errors.As(err, &errNonRetriableMarker) { + return RecoveryKindFatal + } + + if IsOwnershipLostError(err) { + return RecoveryKindFatal + } + + // check the "special" AMQP errors that aren't condition-based. 
+ if IsQuickRecoveryError(err) { + return RecoveryKindLink + } + + var connErr *amqp.ConnError + var sessionErr *amqp.SessionError + + if errors.As(err, &connErr) || + // session closures appear to leak through when the connection itself is going down. + errors.As(err, &sessionErr) { + return RecoveryKindConn + } + + // then it's _probably_ an actual *amqp.Error, in which case we bucket it by + // the 'condition'. + var amqpError *amqp.Error + + if errors.As(err, &amqpError) { + recoveryKind, ok := amqpConditionsToRecoveryKind[amqpError.Condition] + + if ok { + return recoveryKind + } + } + + var rpcErr RPCError + + if errors.As(err, &rpcErr) { + // Described more here: + // https://www.oasis-open.org/committees/download.php/54441/AMQP%20Management%20v1.0%20WD09 + // > Unsuccessful operations MUST NOT result in a statusCode in the 2xx range as defined in Section 10.2 of [RFC2616] + // RFC2616 is the specification for HTTP. + code := rpcErr.RPCCode() + + if code == http.StatusNotFound || + code == http.StatusUnauthorized { + return RecoveryKindFatal + } + + // simple timeouts + if rpcErr.Resp.Code == http.StatusRequestTimeout || rpcErr.Resp.Code == http.StatusServiceUnavailable || + // internal server errors are worth retrying (they will typically lead + // to a more actionable error). A simple example of this is when you're + // in the middle of an operation and the link is detached. Sometimes you'll get + // the detached event immediately, but sometimes you'll get an intermediate 500 + // indicating your original operation was cancelled. + rpcErr.Resp.Code == http.StatusInternalServerError { + return RecoveryKindNone + } + } + + // this is some error type we've never seen - recover the entire connection. 
+ return RecoveryKindConn +} + +func IsNotAllowedError(err error) bool { + var e *amqp.Error + + return errors.As(err, &e) && + e.Condition == amqp.ErrCondNotAllowed +} + +func IsOwnershipLostError(err error) bool { + var de *amqp.LinkError + + if errors.As(err, &de) { + return de.RemoteErr != nil && de.RemoteErr.Condition == "amqp:link:stolen" + } + + return false +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/connection_string_properties.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/connection_string_properties.go new file mode 100644 index 00000000000..b77d22305c1 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/connection_string_properties.go @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "errors" + "fmt" + "net/url" + "strconv" + "strings" +) + +// ConnectionStringProperties are the properties of a connection string +// as returned by [ParseConnectionString]. +type ConnectionStringProperties struct { + // Endpoint is the Endpoint value in the connection string. + // Ex: sb://example.servicebus.windows.net + Endpoint string + + // EntityPath is EntityPath value in the connection string. + EntityPath *string + + // FullyQualifiedNamespace is the Endpoint value without the protocol scheme. + // Ex: example.servicebus.windows.net + FullyQualifiedNamespace string + + // SharedAccessKey is the SharedAccessKey value in the connection string. + SharedAccessKey *string + + // SharedAccessKeyName is the SharedAccessKeyName value in the connection string. + SharedAccessKeyName *string + + // SharedAccessSignature is the SharedAccessSignature value in the connection string. 
+ SharedAccessSignature *string + + // Emulator indicates that the connection string is for an emulator: + // ex: Endpoint=localhost:6765;SharedAccessKeyName=<< REDACTED >>;SharedAccessKey=<< REDACTED >>;UseDevelopmentEmulator=true + Emulator bool +} + +// ParseConnectionString takes a connection string from the Azure portal and returns the +// parsed representation. +// +// There are two supported formats: +// +// 1. Connection strings generated from the portal (or elsewhere) that contain an embedded key and keyname. +// +// 2. A connection string with an embedded SharedAccessSignature: +// Endpoint=sb://.servicebus.windows.net;SharedAccessSignature=SharedAccessSignature sr=.servicebus.windows.net&sig=&se=&skn=" +func ParseConnectionString(connStr string) (ConnectionStringProperties, error) { + const ( + endpointKey = "Endpoint" + sharedAccessKeyNameKey = "SharedAccessKeyName" + sharedAccessKeyKey = "SharedAccessKey" + entityPathKey = "EntityPath" + sharedAccessSignatureKey = "SharedAccessSignature" + useEmulator = "UseDevelopmentEmulator" + ) + + csp := ConnectionStringProperties{} + + splits := strings.Split(connStr, ";") + + for _, split := range splits { + if split == "" { + continue + } + + keyAndValue := strings.SplitN(split, "=", 2) + if len(keyAndValue) < 2 { + return ConnectionStringProperties{}, errors.New("failed parsing connection string due to unmatched key value separated by '='") + } + + // if a key value pair has `=` in the value, recombine them + key := keyAndValue[0] + value := strings.Join(keyAndValue[1:], "=") + switch { + case strings.EqualFold(endpointKey, key): + u, err := url.Parse(value) + if err != nil { + return ConnectionStringProperties{}, errors.New("failed parsing connection string due to an incorrectly formatted Endpoint value") + } + csp.Endpoint = value + csp.FullyQualifiedNamespace = u.Host + case strings.EqualFold(sharedAccessKeyNameKey, key): + csp.SharedAccessKeyName = &value + case strings.EqualFold(sharedAccessKeyKey, key): 
+ csp.SharedAccessKey = &value + case strings.EqualFold(entityPathKey, key): + csp.EntityPath = &value + case strings.EqualFold(sharedAccessSignatureKey, key): + csp.SharedAccessSignature = &value + case strings.EqualFold(useEmulator, key): + v, err := strconv.ParseBool(value) + + if err != nil { + return ConnectionStringProperties{}, err + } + + csp.Emulator = v + } + } + + if csp.Emulator { + endpointParts := strings.SplitN(csp.Endpoint, ":", 3) // allow for a port, if it exists. + + if len(endpointParts) < 2 || endpointParts[0] != "sb" { + // there should always be at least two parts "sb:" and "//" + // with an optional 3rd piece that's the port "1111". + // (we don't need to validate it's a valid host since it's been through url.Parse() above) + return ConnectionStringProperties{}, fmt.Errorf("UseDevelopmentEmulator=true can only be used with sb:// or sb://:, not %s", csp.Endpoint) + } + } + + if csp.FullyQualifiedNamespace == "" { + return ConnectionStringProperties{}, fmt.Errorf("key %q must not be empty", endpointKey) + } + + if csp.SharedAccessSignature == nil && csp.SharedAccessKeyName == nil { + return ConnectionStringProperties{}, fmt.Errorf("key %q must not be empty", sharedAccessKeyNameKey) + } + + if csp.SharedAccessKey == nil && csp.SharedAccessSignature == nil { + return ConnectionStringProperties{}, fmt.Errorf("key %q or %q cannot both be empty", sharedAccessKeyKey, sharedAccessSignatureKey) + } + + return csp, nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/error.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/error.go new file mode 100644 index 00000000000..23a920a61c1 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/error.go @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package exported + +import "fmt" + +// ErrorCode is an error code, usable by consuming code to work with +// programatically. +type ErrorCode string + +const ( + // ErrorCodeUnauthorizedAccess means the credentials provided are not valid for use with + // a particular entity, or have expired. + ErrorCodeUnauthorizedAccess ErrorCode = "unauthorized" + + // ErrorCodeConnectionLost means our connection was lost and all retry attempts failed. + // This typically reflects an extended outage or connection disruption and may + // require manual intervention. + ErrorCodeConnectionLost ErrorCode = "connlost" + + // ErrorCodeOwnershipLost means that a partition that you were reading from was opened + // by another link with an epoch/owner level greater or equal to your [PartitionClient]. + // + // When using types like the [Processor], partition ownership will change as instances + // rebalance. + ErrorCodeOwnershipLost ErrorCode = "ownershiplost" +) + +// Error represents an Event Hub specific error. +// NOTE: the Code is considered part of the published API but the message that +// comes back from Error(), as well as the underlying wrapped error, are NOT and +// are subject to change. +type Error struct { + // Code is a stable error code which can be used as part of programatic error handling. + // The codes can expand in the future, but the values (and their meaning) will remain the same. + Code ErrorCode + innerErr error +} + +// Error is an error message containing the code and a user friendly message, if any. +func (e *Error) Error() string { + msg := "unknown error" + if e.innerErr != nil { + msg = e.innerErr.Error() + } + return fmt.Sprintf("(%s): %s", e.Code, msg) +} + +// NewError creates a new `Error` instance. +// NOTE: this function is only exported so it can be used by the `internal` +// package. It is not available for customers. 
+func NewError(code ErrorCode, innerErr error) error { + return &Error{ + Code: code, + innerErr: innerErr, + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/log_events.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/log_events.go new file mode 100644 index 00000000000..2c4a36f403b --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/log_events.go @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// NOTE: these are publicly exported via type-aliasing in azeventhubs/log.go +const ( + // EventConn is used whenever we create a connection or any links (ie: receivers, senders). + EventConn log.Event = "azeh.Conn" + + // EventAuth is used when we're doing authentication/claims negotiation. + EventAuth log.Event = "azeh.Auth" + + // EventProducer represents operations that happen on Producers. + EventProducer log.Event = "azeh.Producer" + + // EventConsumer represents operations that happen on Consumers. + EventConsumer log.Event = "azeh.Consumer" +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/retry_options.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/retry_options.go new file mode 100644 index 00000000000..6bed306ad5c --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/retry_options.go @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import "time" + +// NOTE: this is exposed via type-aliasing in azeventhubs/client.go + +// RetryOptions represent the options for retries. 
+type RetryOptions struct { + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. + // The default value is three. A value less than zero means one try and no retries. + MaxRetries int32 + + // RetryDelay specifies the initial amount of delay to use before retrying an operation. + // The delay increases exponentially with each retry up to the maximum specified by MaxRetryDelay. + // The default value is four seconds. A value less than zero means no delay between retries. + RetryDelay time.Duration + + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. + // Typically the value is greater than or equal to the value specified in RetryDelay. + // The default Value is 120 seconds. A value less than zero means there is no cap. + MaxRetryDelay time.Duration +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/websocket_conn_params.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/websocket_conn_params.go new file mode 100644 index 00000000000..5bc28602450 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported/websocket_conn_params.go @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +// NOTE: this struct is exported via client.go:WebSocketConnParams + +// WebSocketConnParams are the arguments to the NewWebSocketConn function you pass if you want +// to enable websockets. 
+type WebSocketConnParams struct { + // Host is the the `wss://` to connect to + Host string +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links.go new file mode 100644 index 00000000000..b20fa6f62fd --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links.go @@ -0,0 +1,395 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + "fmt" + "sync" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" +) + +type AMQPLink interface { + Close(ctx context.Context) error + LinkName() string +} + +// LinksForPartitionClient are the functions that the PartitionClient uses within Links[T] +// (for unit testing only) +type LinksForPartitionClient[LinkT AMQPLink] interface { + Retry(ctx context.Context, eventName azlog.Event, operation string, partitionID string, retryOptions exported.RetryOptions, fn func(ctx context.Context, lwid LinkWithID[LinkT]) error) error + Close(ctx context.Context) error +} + +type Links[LinkT AMQPLink] struct { + ns NamespaceForAMQPLinks + + linksMu *sync.RWMutex + links map[string]*linkState[LinkT] + + managementLinkMu *sync.RWMutex + managementLink *linkState[amqpwrap.RPCLink] + + managementPath string + newLinkFn NewLinksFn[LinkT] + entityPathFn func(partitionID string) string + + lr LinkRetrier[LinkT] + mr LinkRetrier[amqpwrap.RPCLink] +} + +type NewLinksFn[LinkT AMQPLink] func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string, partitionID string) (LinkT, error) + +func NewLinks[LinkT AMQPLink](ns NamespaceForAMQPLinks, managementPath string, entityPathFn func(partitionID string) string, 
newLinkFn NewLinksFn[LinkT]) *Links[LinkT] { + l := &Links[LinkT]{ + ns: ns, + linksMu: &sync.RWMutex{}, + links: map[string]*linkState[LinkT]{}, + managementLinkMu: &sync.RWMutex{}, + managementPath: managementPath, + + newLinkFn: newLinkFn, + entityPathFn: entityPathFn, + } + + l.lr = LinkRetrier[LinkT]{ + GetLink: l.GetLink, + CloseLink: l.closePartitionLinkIfMatch, + NSRecover: l.ns.Recover, + } + + l.mr = LinkRetrier[amqpwrap.RPCLink]{ + GetLink: func(ctx context.Context, partitionID string) (LinkWithID[amqpwrap.RPCLink], error) { + return l.GetManagementLink(ctx) + }, + CloseLink: func(ctx context.Context, _, linkName string) error { + return l.closeManagementLinkIfMatch(ctx, linkName) + }, + NSRecover: l.ns.Recover, + } + + return l +} + +func (l *Links[LinkT]) RetryManagement(ctx context.Context, eventName azlog.Event, operation string, retryOptions exported.RetryOptions, fn func(ctx context.Context, lwid LinkWithID[amqpwrap.RPCLink]) error) error { + return l.mr.Retry(ctx, eventName, operation, "", retryOptions, fn) +} + +func (l *Links[LinkT]) Retry(ctx context.Context, eventName azlog.Event, operation string, partitionID string, retryOptions exported.RetryOptions, fn func(ctx context.Context, lwid LinkWithID[LinkT]) error) error { + return l.lr.Retry(ctx, eventName, operation, partitionID, retryOptions, fn) +} + +func (l *Links[LinkT]) GetLink(ctx context.Context, partitionID string) (LinkWithID[LinkT], error) { + if err := l.checkOpen(); err != nil { + return nil, err + } + + l.linksMu.RLock() + current := l.links[partitionID] + l.linksMu.RUnlock() + + if current != nil { + return current, nil + } + + // no existing link, let's create a new one within the write lock. 
+ l.linksMu.Lock() + defer l.linksMu.Unlock() + + // check again now that we have the write lock + current = l.links[partitionID] + + if current == nil { + ls, err := l.newLinkState(ctx, partitionID) + + if err != nil { + return nil, err + } + + l.links[partitionID] = ls + current = ls + } + + return current, nil +} + +func (l *Links[LinkT]) GetManagementLink(ctx context.Context) (LinkWithID[amqpwrap.RPCLink], error) { + if err := l.checkOpen(); err != nil { + return nil, err + } + + l.managementLinkMu.Lock() + defer l.managementLinkMu.Unlock() + + if l.managementLink == nil { + ls, err := l.newManagementLinkState(ctx) + + if err != nil { + return nil, err + } + + l.managementLink = ls + } + + return l.managementLink, nil +} + +func (l *Links[LinkT]) newLinkState(ctx context.Context, partitionID string) (*linkState[LinkT], error) { + azlog.Writef(exported.EventConn, "Creating link for partition ID '%s'", partitionID) + + // check again now that we have the write lock + ls := &linkState[LinkT]{ + partitionID: partitionID, + } + + cancelAuth, _, err := l.ns.NegotiateClaim(ctx, l.entityPathFn(partitionID)) + + if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to negotiate claim for partition ID '%s': %s", ls.String(), partitionID, err) + return nil, err + } + + ls.cancelAuth = cancelAuth + + session, connID, err := l.ns.NewAMQPSession(ctx) + + if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to create AMQP session for partition ID '%s': %s", ls.String(), partitionID, err) + _ = ls.Close(ctx) + return nil, err + } + + ls.session = session + ls.connID = connID + + tmpLink, err := l.newLinkFn(ctx, session, l.entityPathFn(partitionID), partitionID) + + if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to create link for partition ID '%s': %s", ls.String(), partitionID, err) + _ = ls.Close(ctx) + return nil, err + } + + ls.link = &tmpLink + + azlog.Writef(exported.EventConn, "(%s): Succesfully created link for partition ID 
'%s'", ls.String(), partitionID) + return ls, nil +} + +func (l *Links[LinkT]) newManagementLinkState(ctx context.Context) (*linkState[amqpwrap.RPCLink], error) { + ls := &linkState[amqpwrap.RPCLink]{} + + cancelAuth, _, err := l.ns.NegotiateClaim(ctx, l.managementPath) + + if err != nil { + return nil, err + } + + ls.cancelAuth = cancelAuth + + tmpRPCLink, connID, err := l.ns.NewRPCLink(ctx, "$management") + + if err != nil { + _ = ls.Close(ctx) + return nil, err + } + + ls.connID = connID + ls.link = &tmpRPCLink + + return ls, nil +} + +func (l *Links[LinkT]) Close(ctx context.Context) error { + return l.closeLinks(ctx, true) +} + +func (l *Links[LinkT]) closeLinks(ctx context.Context, permanent bool) error { + cancelled := false + + // clear out the management link + func() { + l.managementLinkMu.Lock() + defer l.managementLinkMu.Unlock() + + if l.managementLink == nil { + return + } + + mgmtLink := l.managementLink + l.managementLink = nil + + if err := mgmtLink.Close(ctx); err != nil { + azlog.Writef(exported.EventConn, "Error while cleaning up management link while doing connection recovery: %s", err.Error()) + + if IsCancelError(err) { + cancelled = true + } + } + }() + + l.linksMu.Lock() + defer l.linksMu.Unlock() + + tmpLinks := l.links + l.links = nil + + for partitionID, link := range tmpLinks { + if err := link.Close(ctx); err != nil { + azlog.Writef(exported.EventConn, "Error while cleaning up link for partition ID '%s' while doing connection recovery: %s", partitionID, err.Error()) + + if IsCancelError(err) { + cancelled = true + } + } + } + + if !permanent { + l.links = map[string]*linkState[LinkT]{} + } + + if cancelled { + // this is the only kind of error I'd consider usable from Close() - it'll indicate + // that some of the links haven't been cleanly closed. 
+ return ctx.Err() + } + + return nil +} + +func (l *Links[LinkT]) checkOpen() error { + l.linksMu.RLock() + defer l.linksMu.RUnlock() + + if l.links == nil { + return NewErrNonRetriable("client has been closed by user") + } + + return nil +} + +// closePartitionLinkIfMatch will close the link in the cache if it matches the passed in linkName. +// This is similar to how an etag works - we'll only close it if you are working with the latest link - +// if not, it's a no-op since somebody else has already 'saved' (recovered) before you. +// +// Note that the only error that can be returned here will come from go-amqp. Cleanup of _our_ internal state +// will always happen, if needed. +func (l *Links[LinkT]) closePartitionLinkIfMatch(ctx context.Context, partitionID string, linkName string) error { + l.linksMu.RLock() + current, exists := l.links[partitionID] + l.linksMu.RUnlock() + + if !exists || + current.Link().LinkName() != linkName { // we've already created a new link, their link was stale. + return nil + } + + l.linksMu.Lock() + defer l.linksMu.Unlock() + + current, exists = l.links[partitionID] + + if !exists || + current.Link().LinkName() != linkName { // we've already created a new link, their link was stale. + return nil + } + + delete(l.links, partitionID) + return current.Close(ctx) +} + +func (l *Links[LinkT]) closeManagementLinkIfMatch(ctx context.Context, linkName string) error { + l.managementLinkMu.Lock() + defer l.managementLinkMu.Unlock() + + if l.managementLink != nil && l.managementLink.Link().LinkName() == linkName { + err := l.managementLink.Close(ctx) + l.managementLink = nil + return err + } + + return nil +} + +type linkState[LinkT AMQPLink] struct { + // connID is an arbitrary (but unique) integer that represents the + // current connection. This comes back from the Namespace, anytime + // it hands back a connection. 
+ connID uint64 + + // link will be either an [amqpwrap.AMQPSenderCloser], [amqpwrap.AMQPReceiverCloser] or [amqpwrap.RPCLink] + link *LinkT + + // partitionID, if available. + partitionID string + + // cancelAuth cancels the backround claim negotation for this link. + cancelAuth func() + + // optional session, if we created one for this + // link. + session amqpwrap.AMQPSession +} + +// String returns a string that can be used for logging, of the format: +// (c:,l:<5 characters of link id>) +// +// It can also handle nil and partial initialization. +func (ls *linkState[LinkT]) String() string { + if ls == nil { + return "none" + } + + linkName := "" + + if ls.link != nil { + linkName = ls.Link().LinkName() + } + + return formatLogPrefix(ls.connID, linkName, ls.partitionID) +} + +// Close cancels the background authentication loop for this link and +// then closes the AMQP links. +// NOTE: this avoids any issues where closing fails on the broker-side or +// locally and we leak a goroutine. +func (ls *linkState[LinkT]) Close(ctx context.Context) error { + if ls.cancelAuth != nil { + ls.cancelAuth() + } + + if ls.link != nil { + return ls.Link().Close(ctx) + } + + return nil +} + +func (ls *linkState[LinkT]) PartitionID() string { + return ls.partitionID +} + +func (ls *linkState[LinkT]) ConnID() uint64 { + return ls.connID +} + +func (ls *linkState[LinkT]) Link() LinkT { + return *ls.link +} + +// LinkWithID is a readonly interface over the top of a linkState. 
+type LinkWithID[LinkT AMQPLink] interface { + ConnID() uint64 + Link() LinkT + PartitionID() string + Close(ctx context.Context) error + String() string +} + +func formatLogPrefix(connID uint64, linkName, partitionID string) string { + return fmt.Sprintf("c:%d,l:%.5s,p:%s", connID, linkName, partitionID) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links_recover.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links_recover.go new file mode 100644 index 00000000000..b1da12ccd8f --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/links_recover.go @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + "errors" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils" +) + +type LinkRetrier[LinkT AMQPLink] struct { + GetLink func(ctx context.Context, partitionID string) (LinkWithID[LinkT], error) + CloseLink func(ctx context.Context, partitionID string, linkName string) error + NSRecover func(ctx context.Context, connID uint64) error +} + +type RetryCallback[LinkT AMQPLink] func(ctx context.Context, lwid LinkWithID[LinkT]) error + +// Retry runs the fn argument in a loop, respecting retry counts. +// If connection/link failures occur it also takes care of running recovery logic +// to bring them back, or return an appropriate error if retries are exhausted. 
+func (l LinkRetrier[LinkT]) Retry(ctx context.Context, + eventName azlog.Event, + operation string, + partitionID string, + retryOptions exported.RetryOptions, + fn RetryCallback[LinkT]) error { + didQuickRetry := false + + isFatalErrorFunc := func(err error) bool { + return GetRecoveryKind(err) == RecoveryKindFatal + } + + currentPrefix := "" + + prefix := func() string { + return currentPrefix + } + + return utils.Retry(ctx, eventName, prefix, retryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { + if err := l.RecoverIfNeeded(ctx, args.LastErr); err != nil { + return err + } + + linkWithID, err := l.GetLink(ctx, partitionID) + + if err != nil { + return err + } + + currentPrefix = linkWithID.String() + + if err := fn(ctx, linkWithID); err != nil { + if args.I == 0 && !didQuickRetry && IsQuickRecoveryError(err) { + // go-amqp will asynchronously handle detaches. This means errors that you get + // back from Send(), for instance, can actually be from much earlier in time + // depending on the last time you called into Send(). + // + // This means we'll sometimes do an unneeded sleep after a failed retry when + // it would have just immediately worked. To counteract that we'll do a one-time + // quick attempt to recreate link immediately if we see a detach error. This might + // waste a bit of time attempting to do the creation, but since it's just link creation + // it should be fairly fast. + // + // So when we've received a detach is: + // 0th attempt + // extra immediate 0th attempt (if last error was detach) + // (actual retries) + // + // Whereas normally you'd do (for non-detach errors): + // 0th attempt + // (actual retries) + azlog.Writef(exported.EventConn, "(%s, %s) Link was previously detached. 
Attempting quick reconnect to recover from error: %s", linkWithID.String(), operation, err.Error()) + didQuickRetry = true + args.ResetAttempts() + } + + return err + } + + return nil + }, isFatalErrorFunc) +} + +func (l LinkRetrier[LinkT]) RecoverIfNeeded(ctx context.Context, err error) error { + rk := GetRecoveryKind(err) + + switch rk { + case RecoveryKindNone: + return nil + case RecoveryKindLink: + var awErr amqpwrap.Error + + if !errors.As(err, &awErr) { + azlog.Writef(exported.EventConn, "RecoveryKindLink, but not an amqpwrap.Error: %T,%v", err, err) + return nil + } + + if err := l.CloseLink(ctx, awErr.PartitionID, awErr.LinkName); err != nil { + azlog.Writef(exported.EventConn, "(%s) Error when cleaning up old link for link recovery: %s", formatLogPrefix(awErr.ConnID, awErr.LinkName, awErr.PartitionID), err) + return err + } + + return nil + case RecoveryKindConn: + var awErr amqpwrap.Error + + if !errors.As(err, &awErr) { + azlog.Writef(exported.EventConn, "RecoveryKindConn, but not an amqpwrap.Error: %T,%v", err, err) + return nil + } + + // We only close _this_ partition's link. Other partitions will also get an error, and will recover. + // We used to close _all_ the links, but no longer do that since it's possible (when we do receiver + // redirect) to have more than one active connection at a time which means not all links would be + // affected when a single connection goes down. + if err := l.CloseLink(ctx, awErr.PartitionID, awErr.LinkName); err != nil { + azlog.Writef(exported.EventConn, "(%s) Error when cleaning up old link: %s", formatLogPrefix(awErr.ConnID, awErr.LinkName, awErr.PartitionID), err) + + // NOTE: this is best effort - it's probable the connection is dead anyways so we'll log + // but ignore the error for recovery purposes. + } + + // There are two possibilities here: + // + // 1. 
(stale) The caller got this error but the `lwid` they're passing us is 'stale' - ie, ' + // the connection the error happened on doesn't exist anymore (we recovered already) or + // the link itself is no longer active in our cache. + // + // 2. (current) The caller got this error and is the current link and/or connection, so we're going to + // need to recycle the connection (possibly) and links. + // + // For #1, we basically don't need to do anything. Recover(old-connection-id) will be a no-op + // and the closePartitionLinkIfMatch() will no-op as well since the link they passed us will + // not match the current link. + // + // For #2, we may recreate the connection. It's possible we won't if the connection itself + // has already been recovered by another goroutine. + err := l.NSRecover(ctx, awErr.ConnID) + + if err != nil { + azlog.Writef(exported.EventConn, "(%s) Failure recovering connection for link: %s", formatLogPrefix(awErr.ConnID, awErr.LinkName, awErr.PartitionID), err) + return err + } + + return nil + default: + return err + } +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace.go new file mode 100644 index 00000000000..dd19b713a7a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace.go @@ -0,0 +1,512 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package internal + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "runtime" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils" + "github.com/Azure/go-amqp" +) + +var rootUserAgent = telemetry.Format("azeventhubs", Version) + +type ( + // Namespace is an abstraction over an amqp.Client, allowing us to hold onto a single + // instance of a connection per client.. + Namespace struct { + // NOTE: values need to be 64-bit aligned. Simplest way to make sure this happens + // is just to make it the first value in the struct + // See: + // Godoc: https://pkg.go.dev/sync/atomic#pkg-note-BUG + // PR: https://github.com/Azure/azure-sdk-for-go/pull/16847 + connID uint64 + + FQDN string + TokenProvider *sbauth.TokenProvider + tlsConfig *tls.Config + userAgent string + + newWebSocketConn func(ctx context.Context, args exported.WebSocketConnParams) (net.Conn, error) + + // NOTE: exported only so it can be checked in a test + RetryOptions exported.RetryOptions + + clientMu sync.RWMutex + client amqpwrap.AMQPClient + negotiateClaimMu sync.Mutex + // indicates that the client was closed permanently, and not just + // for recovery. + closedPermanently bool + + // newClientFn exists so we can stub out newClient for unit tests. 
+ newClientFn func(ctx context.Context, connID uint64) (amqpwrap.AMQPClient, error) + } + + // NamespaceOption provides structure for configuring a new Event Hub namespace + NamespaceOption func(h *Namespace) error +) + +// NamespaceWithNewAMQPLinks is the Namespace surface for consumers of AMQPLinks. +type NamespaceWithNewAMQPLinks interface { + Check() error +} + +// NamespaceForAMQPLinks is the Namespace surface needed for the internals of AMQPLinks. +type NamespaceForAMQPLinks interface { + NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) + NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) + NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, uint64, error) + GetEntityAudience(entityPath string) string + + // Recover destroys the currently held AMQP connection and recreates it, if needed. + // + // NOTE: cancelling the context only cancels the initialization of a new AMQP + // connection - the previous connection is always closed. + Recover(ctx context.Context, clientRevision uint64) error + + Close(ctx context.Context, permanently bool) error +} + +// NamespaceWithConnectionString configures a namespace with the information provided in a Event Hub connection string +func NamespaceWithConnectionString(connStr string) NamespaceOption { + return func(ns *Namespace) error { + props, err := exported.ParseConnectionString(connStr) + if err != nil { + return err + } + + ns.FQDN = props.FullyQualifiedNamespace + + provider, err := sbauth.NewTokenProviderWithConnectionString(props) + if err != nil { + return err + } + + ns.TokenProvider = provider + return nil + } +} + +// NamespaceWithTLSConfig appends to the TLS config. +func NamespaceWithTLSConfig(tlsConfig *tls.Config) NamespaceOption { + return func(ns *Namespace) error { + ns.tlsConfig = tlsConfig + return nil + } +} + +// NamespaceWithUserAgent appends to the root user-agent value. 
+func NamespaceWithUserAgent(userAgent string) NamespaceOption { + return func(ns *Namespace) error { + ns.userAgent = userAgent + return nil + } +} + +// NamespaceWithWebSocket configures the namespace and all entities to use wss:// rather than amqps:// +func NamespaceWithWebSocket(newWebSocketConn func(ctx context.Context, args exported.WebSocketConnParams) (net.Conn, error)) NamespaceOption { + return func(ns *Namespace) error { + ns.newWebSocketConn = newWebSocketConn + return nil + } +} + +// NamespaceWithTokenCredential sets the token provider on the namespace +// fullyQualifiedNamespace is the Event Hub namespace name (ex: myservicebus.servicebus.windows.net) +func NamespaceWithTokenCredential(fullyQualifiedNamespace string, tokenCredential azcore.TokenCredential) NamespaceOption { + return func(ns *Namespace) error { + ns.TokenProvider = sbauth.NewTokenProvider(tokenCredential) + ns.FQDN = fullyQualifiedNamespace + return nil + } +} + +func NamespaceWithRetryOptions(retryOptions exported.RetryOptions) NamespaceOption { + return func(ns *Namespace) error { + ns.RetryOptions = retryOptions + return nil + } +} + +// NewNamespace creates a new namespace configured through NamespaceOption(s) +func NewNamespace(opts ...NamespaceOption) (*Namespace, error) { + ns := &Namespace{} + + ns.newClientFn = ns.newClientImpl + + for _, opt := range opts { + err := opt(ns) + if err != nil { + return nil, err + } + } + + return ns, nil +} + +func (ns *Namespace) newClientImpl(ctx context.Context, connID uint64) (amqpwrap.AMQPClient, error) { + connOptions := amqp.ConnOptions{ + SASLType: amqp.SASLTypeAnonymous(), + MaxSessions: 65535, + Properties: map[string]any{ + "product": "MSGolangClient", + "version": Version, + "platform": runtime.GOOS, + "framework": runtime.Version(), + "user-agent": ns.getUserAgent(), + }, + } + + if ns.tlsConfig != nil { + connOptions.TLSConfig = ns.tlsConfig + } + + if ns.newWebSocketConn != nil { + nConn, err := ns.newWebSocketConn(ctx, 
exported.WebSocketConnParams{ + Host: ns.getWSSHostURI() + "$servicebus/websocket", + }) + + if err != nil { + return nil, err + } + + connOptions.HostName = ns.FQDN + client, err := amqp.NewConn(ctx, nConn, &connOptions) + return &amqpwrap.AMQPClientWrapper{Inner: client, ConnID: connID}, err + } + + client, err := amqp.Dial(ctx, ns.getAMQPHostURI(), &connOptions) + return &amqpwrap.AMQPClientWrapper{Inner: client, ConnID: connID}, err +} + +// NewAMQPSession creates a new AMQP session with the internally cached *amqp.Client. +// Returns a closeable AMQP session and the current client revision. +func (ns *Namespace) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) { + client, clientRevision, err := ns.GetAMQPClientImpl(ctx) + + if err != nil { + return nil, 0, err + } + + session, err := client.NewSession(ctx, nil) + + if err != nil { + return nil, 0, err + } + + return session, clientRevision, err +} + +// Close closes the current cached client. +func (ns *Namespace) Close(ctx context.Context, permanently bool) error { + ns.clientMu.Lock() + defer ns.clientMu.Unlock() + + if permanently { + ns.closedPermanently = true + } + + if ns.client != nil { + err := ns.client.Close() + ns.client = nil + + if err != nil { + log.Writef(exported.EventConn, "Failed when closing AMQP connection: %s", err) + } + } + + return nil +} + +// Check returns an error if the namespace cannot be used (ie, closed permanently), or nil otherwise. +func (ns *Namespace) Check() error { + ns.clientMu.RLock() + defer ns.clientMu.RUnlock() + + if ns.closedPermanently { + return ErrClientClosed + } + + return nil +} + +var ErrClientClosed = NewErrNonRetriable("client has been closed by user") + +// Recover destroys the currently held AMQP connection and recreates it, if needed. +// +// NOTE: cancelling the context only cancels the initialization of a new AMQP +// connection - the previous connection is always closed. 
+func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) error { + if err := ns.Check(); err != nil { + return err + } + + ns.clientMu.Lock() + defer ns.clientMu.Unlock() + + if ns.closedPermanently { + return ErrClientClosed + } + + if ns.connID != theirConnID { + log.Writef(exported.EventConn, "Skipping connection recovery, already recovered: %d vs %d. Links will still be recovered.", ns.connID, theirConnID) + return nil + } + + if ns.client != nil { + oldClient := ns.client + ns.client = nil + + if err := oldClient.Close(); err != nil { + // the error on close isn't critical, we don't need to exit or + // return it. + log.Writef(exported.EventConn, "Error closing old client: %s", err.Error()) + } + } + + log.Writef(exported.EventConn, "Creating a new client (rev:%d)", ns.connID) + + if _, _, err := ns.updateClientWithoutLock(ctx); err != nil { + return err + } + + return nil +} + +// negotiateClaimFn matches the signature for NegotiateClaim, and is used when we want to stub things out for tests. +type negotiateClaimFn func( + ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error + +// negotiateClaim performs initial authentication and starts periodic refresh of credentials. +// the returned func is to cancel() the refresh goroutine. +func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) { + return ns.startNegotiateClaimRenewer(ctx, + entityPath, + NegotiateClaim, + nextClaimRefreshDuration) +} + +// startNegotiateClaimRenewer does an initial claim request and then starts a goroutine that +// continues to automatically refresh in the background. +// Returns a func() that can be used to cancel the background renewal, a channel that will be closed +// when the background renewal stops or an error. 
+func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, + entityPath string, + cbsNegotiateClaim negotiateClaimFn, + nextClaimRefreshDurationFn func(expirationTime time.Time, currentTime time.Time) time.Duration) (func(), <-chan struct{}, error) { + audience := ns.GetEntityAudience(entityPath) + + refreshClaim := func(ctx context.Context) (time.Time, error) { + log.Writef(exported.EventAuth, "(%s) refreshing claim", entityPath) + + amqpClient, clientRevision, err := ns.GetAMQPClientImpl(ctx) + + if err != nil { + return time.Time{}, err + } + + token, expiration, err := ns.TokenProvider.GetTokenAsTokenProvider(audience) + + if err != nil { + log.Writef(exported.EventAuth, "(%s) negotiate claim, failed getting token: %s", entityPath, err.Error()) + return time.Time{}, err + } + + log.Writef(exported.EventAuth, "(%s) negotiate claim, token expires on %s", entityPath, expiration.Format(time.RFC3339)) + + // You're not allowed to have multiple $cbs links open in a single connection. + // The current cbs.NegotiateClaim implementation automatically creates and shuts + // down it's own link so we have to guard against that here. + ns.negotiateClaimMu.Lock() + err = cbsNegotiateClaim(ctx, audience, amqpClient, token) + ns.negotiateClaimMu.Unlock() + + if err != nil { + // Note we only handle connection recovery here since (currently) + // the negotiateClaim code creates it's own link each time. 
+ if GetRecoveryKind(err) == RecoveryKindConn { + if err := ns.Recover(ctx, clientRevision); err != nil { + log.Writef(exported.EventAuth, "(%s) negotiate claim, failed in connection recovery: %s", entityPath, err) + } + } + + log.Writef(exported.EventAuth, "(%s) negotiate claim, failed: %s", entityPath, err.Error()) + return time.Time{}, err + } + + return expiration, nil + } + + expiresOn, err := refreshClaim(ctx) + + if err != nil { + return nil, nil, err + } + + // start the periodic refresh of credentials + refreshCtx, cancelRefreshCtx := context.WithCancel(context.Background()) + refreshStoppedCh := make(chan struct{}) + + // connection strings with embedded SAS tokens will return a zero expiration time since they can't be renewed. + if expiresOn.IsZero() { + log.Writef(exported.EventAuth, "Token does not have an expiration date, no background renewal needed.") + + // cancel everything related to the claims refresh loop. + cancelRefreshCtx() + close(refreshStoppedCh) + + return func() {}, refreshStoppedCh, nil + } + + go func() { + defer cancelRefreshCtx() + defer close(refreshStoppedCh) + + TokenRefreshLoop: + for { + nextClaimAt := nextClaimRefreshDurationFn(expiresOn, time.Now()) + + log.Writef(exported.EventAuth, "(%s) next refresh in %s", entityPath, nextClaimAt) + + select { + case <-refreshCtx.Done(): + return + case <-time.After(nextClaimAt): + for { + err := utils.Retry(refreshCtx, exported.EventAuth, func() string { return "NegotiateClaimRefresh" }, ns.RetryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { + tmpExpiresOn, err := refreshClaim(ctx) + + if err != nil { + return err + } + + expiresOn = tmpExpiresOn + return nil + }, IsFatalEHError) + + if err == nil { + break + } + + if GetRecoveryKind(err) == RecoveryKindFatal { + log.Writef(exported.EventAuth, "[%s] fatal error, stopping token refresh loop: %s", entityPath, err.Error()) + break TokenRefreshLoop + } + } + } + } + }() + + return func() { + cancelRefreshCtx() + 
<-refreshStoppedCh + }, refreshStoppedCh, nil +} + +func (ns *Namespace) GetAMQPClientImpl(ctx context.Context) (amqpwrap.AMQPClient, uint64, error) { + if err := ns.Check(); err != nil { + return nil, 0, err + } + + ns.clientMu.Lock() + defer ns.clientMu.Unlock() + + if ns.closedPermanently { + return nil, 0, ErrClientClosed + } + + return ns.updateClientWithoutLock(ctx) +} + +// updateClientWithoutLock takes care of initializing a client (if needed) +// and returns the initialized client and it's connection ID, or an error. +func (ns *Namespace) updateClientWithoutLock(ctx context.Context) (amqpwrap.AMQPClient, uint64, error) { + if ns.client != nil { + return ns.client, ns.connID, nil + } + + connStart := time.Now() + log.Writef(exported.EventConn, "Creating new client, current rev: %d", ns.connID) + + newConnID := ns.connID + 1 + tempClient, err := ns.newClientFn(ctx, newConnID) + + if err != nil { + return nil, 0, err + } + + ns.connID = newConnID + ns.client = tempClient + log.Writef(exported.EventConn, "Client created, new rev: %d, took %dms", ns.connID, time.Since(connStart)/time.Millisecond) + + return ns.client, ns.connID, err +} + +func (ns *Namespace) getWSSHostURI() string { + return fmt.Sprintf("wss://%s/", ns.FQDN) +} + +func (ns *Namespace) getAMQPHostURI() string { + if ns.TokenProvider.InsecureDisableTLS { + return fmt.Sprintf("amqp://%s/", ns.FQDN) + } else { + return fmt.Sprintf("amqps://%s/", ns.FQDN) + } +} + +func (ns *Namespace) GetHTTPSHostURI() string { + return fmt.Sprintf("https://%s/", ns.FQDN) +} + +func (ns *Namespace) GetEntityAudience(entityPath string) string { + return ns.getAMQPHostURI() + entityPath +} + +func (ns *Namespace) getUserAgent() string { + userAgent := rootUserAgent + if ns.userAgent != "" { + userAgent = fmt.Sprintf("%s %s", ns.userAgent, userAgent) + } + return userAgent +} + +// nextClaimRefreshDuration figures out the proper interval for the next authorization +// refresh. 
+// +// It applies a few real world adjustments: +// - We assume the expiration time is 10 minutes ahead of when it actually is, to adjust for clock drift. +// - We don't let the refresh interval fall below 2 minutes +// - We don't let the refresh interval go above 49 days +// +// This logic is from here: +// https://github.com/Azure/azure-sdk-for-net/blob/bfd3109d0f9afa763131731d78a31e39c81101b3/sdk/servicebus/Azure.Messaging.ServiceBus/src/Amqp/AmqpConnectionScope.cs#L998 +func nextClaimRefreshDuration(expirationTime time.Time, currentTime time.Time) time.Duration { + const min = 2 * time.Minute + const max = 49 * 24 * time.Hour + const clockDrift = 10 * time.Minute + + var refreshDuration = expirationTime.Sub(currentTime) - clockDrift + + if refreshDuration < min { + return min + } else if refreshDuration > max { + return max + } + + return refreshDuration +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace_eh.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace_eh.go new file mode 100644 index 00000000000..3d827c40671 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/namespace_eh.go @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package internal + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" +) + +func (l *rpcLink) LinkName() string { + return l.sender.LinkName() +} + +func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, uint64, error) { + client, connID, err := ns.GetAMQPClientImpl(ctx) + + if err != nil { + return nil, 0, err + } + + rpcLink, err := NewRPCLink(ctx, RPCLinkArgs{ + Client: client, + Address: managementPath, + LogEvent: exported.EventProducer, + }) + + if err != nil { + return nil, 0, err + } + + return rpcLink, connID, nil +} + +func (ns *Namespace) GetTokenForEntity(eventHub string) (*auth.Token, error) { + audience := ns.GetEntityAudience(eventHub) + return ns.TokenProvider.GetToken(audience) +} + +type NamespaceForManagementOps interface { + NamespaceForAMQPLinks + GetTokenForEntity(eventHub string) (*auth.Token, error) +} + +// TODO: might just consolidate. +type NamespaceForProducerOrConsumer = NamespaceForManagementOps diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/rpc.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/rpc.go new file mode 100644 index 00000000000..056e55251a1 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/rpc.go @@ -0,0 +1,444 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package internal + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/go-amqp" +) + +const ( + replyPostfix = "-reply-to-" + statusCodeKey = "status-code" + descriptionKey = "status-description" + defaultReceiverCredits = 1000 +) + +var RPCLinkClosedErr = errors.New("rpc link closed") + +type ( + // rpcLink is the bidirectional communication structure used for CBS negotiation + rpcLink struct { + session amqpwrap.AMQPSession + receiver amqpwrap.AMQPReceiverCloser // *amqp.Receiver + sender amqpwrap.AMQPSenderCloser // *amqp.Sender + + clientAddress string + sessionID *string + id string + + responseMu sync.Mutex + responseRouterClosed chan struct{} + + responseMap map[string]chan rpcResponse + rpcLinkCtx context.Context + rpcLinkCtxCancel context.CancelFunc + broadcastErr error // the error that caused the responseMap to be nil'd + + logEvent azlog.Event + + // for unit tests + uuidNewV4 func() (uuid.UUID, error) + } + + // RPCLinkOption provides a way to customize the construction of a Link + RPCLinkOption func(link *rpcLink) error + + rpcResponse struct { + message *amqp.Message + err error + } +) + +// RPCError is an error from an RPCLink. +// RPCLinks are used for communication with the $management and $cbs links. +type RPCError struct { + Resp *amqpwrap.RPCResponse + Message string +} + +// Error is a string representation of the error. +func (e RPCError) Error() string { + return e.Message +} + +// RPCCode is the code that comes back in the rpc response. This code is intended +// for programs toreact to programatically. 
+func (e RPCError) RPCCode() int { + return e.Resp.Code +} + +type RPCLinkArgs struct { + Client amqpwrap.AMQPClient + Address string + LogEvent azlog.Event +} + +// NewRPCLink will build a new request response link +func NewRPCLink(ctx context.Context, args RPCLinkArgs) (amqpwrap.RPCLink, error) { + session, err := args.Client.NewSession(ctx, nil) + + if err != nil { + return nil, err + } + + linkID, err := uuid.New() + if err != nil { + _ = session.Close(ctx) + return nil, err + } + + id := linkID.String() + + link := &rpcLink{ + session: session, + clientAddress: strings.Replace("$", "", args.Address, -1) + replyPostfix + id, + id: id, + + uuidNewV4: uuid.New, + responseMap: map[string]chan rpcResponse{}, + responseRouterClosed: make(chan struct{}), + logEvent: args.LogEvent, + } + + sender, err := session.NewSender( + ctx, + args.Address, + "", + nil, + ) + if err != nil { + _ = session.Close(ctx) + return nil, err + } + + receiverOpts := &amqp.ReceiverOptions{ + TargetAddress: link.clientAddress, + Credit: defaultReceiverCredits, + } + + if link.sessionID != nil { + const name = "com.microsoft:session-filter" + const code = uint64(0x00000137000000C) + if link.sessionID == nil { + receiverOpts.Filters = append(receiverOpts.Filters, amqp.NewLinkFilter(name, code, nil)) + } else { + receiverOpts.Filters = append(receiverOpts.Filters, amqp.NewLinkFilter(name, code, link.sessionID)) + } + } + + receiver, err := session.NewReceiver(ctx, args.Address, "", receiverOpts) + if err != nil { + _ = session.Close(ctx) + return nil, err + } + + link.sender = sender + link.receiver = receiver + link.rpcLinkCtx, link.rpcLinkCtxCancel = context.WithCancel(context.Background()) + + go link.responseRouter() + + return link, nil +} + +const responseRouterShutdownMessage = "Response router has shut down" + +// responseRouter is responsible for taking any messages received on the 'response' +// link and forwarding it to the proper channel. 
The channel is being select'd by the +// original `RPC` call. +func (l *rpcLink) responseRouter() { + defer azlog.Writef(l.logEvent, responseRouterShutdownMessage) + defer close(l.responseRouterClosed) + + for { + res, err := l.receiver.Receive(l.rpcLinkCtx, nil) + + if err != nil { + // if the link or connection has a malfunction that would require it to restart then + // we need to bail out, broadcasting to all affected callers/consumers. + if GetRecoveryKind(err) != RecoveryKindNone { + if IsCancelError(err) { + err = RPCLinkClosedErr + } else { + azlog.Writef(l.logEvent, "Error in RPCLink, stopping response router: %s", err.Error()) + } + + l.broadcastError(err) + break + } + + azlog.Writef(l.logEvent, "Non-fatal error in RPCLink, starting to receive again: %s", err.Error()) + continue + } + + // I don't believe this should happen. The JS version of this same code + // ignores errors as well since responses should always be correlated + // to actual send requests. So this is just here for completeness. 
+ if res == nil { + azlog.Writef(l.logEvent, "RPCLink received no error, but also got no response") + continue + } + + autogenMessageId, ok := res.Properties.CorrelationID.(string) + + if !ok { + azlog.Writef(l.logEvent, "RPCLink message received without a CorrelationID %v", res) + continue + } + + ch := l.deleteChannelFromMap(autogenMessageId) + + if ch == nil { + azlog.Writef(l.logEvent, "RPCLink had no response channel for correlation ID %v", autogenMessageId) + continue + } + + ch <- rpcResponse{message: res, err: err} + } +} + +func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*amqpwrap.RPCResponse, error) { + resp, err := l.internalRPC(ctx, msg) + + if err != nil { + return nil, amqpwrap.WrapError(err, l.ConnID(), l.LinkName(), "") + } + + return resp, nil +} + +// RPC sends a request and waits on a response for that request +func (l *rpcLink) internalRPC(ctx context.Context, msg *amqp.Message) (*amqpwrap.RPCResponse, error) { + copiedMessage, messageID, err := addMessageID(msg, l.uuidNewV4) + + if err != nil { + return nil, err + } + + // use the copiedMessage from this point + msg = copiedMessage + + const altStatusCodeKey, altDescriptionKey = "statusCode", "statusDescription" + + msg.Properties.ReplyTo = &l.clientAddress + + if msg.ApplicationProperties == nil { + msg.ApplicationProperties = make(map[string]any) + } + + if _, ok := msg.ApplicationProperties["server-timeout"]; !ok { + if deadline, ok := ctx.Deadline(); ok { + msg.ApplicationProperties["server-timeout"] = uint(time.Until(deadline) / time.Millisecond) + } + } + + responseCh := l.addChannelToMap(messageID) + + if responseCh == nil { + return nil, l.broadcastErr + } + + err = l.sender.Send(ctx, msg, nil) + + if err != nil { + l.deleteChannelFromMap(messageID) + return nil, fmt.Errorf("failed to send message with ID %s: %w", messageID, err) + } + + var res *amqp.Message + + select { + case <-ctx.Done(): + l.deleteChannelFromMap(messageID) + res, err = nil, ctx.Err() + case resp := 
<-responseCh: + // this will get triggered by the loop in 'startReceiverRouter' when it receives + // a message with our autoGenMessageID set in the correlation_id property. + res, err = resp.message, resp.err + } + + if err != nil { + return nil, err + } + + var statusCode int + statusCodeCandidates := []string{statusCodeKey, altStatusCodeKey} + for i := range statusCodeCandidates { + if rawStatusCode, ok := res.ApplicationProperties[statusCodeCandidates[i]]; ok { + if cast, ok := rawStatusCode.(int32); ok { + statusCode = int(cast) + break + } + + return nil, errors.New("status code was not of expected type int32") + } + } + if statusCode == 0 { + return nil, errors.New("status codes was not found on rpc message") + } + + var description string + descriptionCandidates := []string{descriptionKey, altDescriptionKey} + for i := range descriptionCandidates { + if rawDescription, ok := res.ApplicationProperties[descriptionCandidates[i]]; ok { + if description, ok = rawDescription.(string); ok || rawDescription == nil { + break + } else { + return nil, errors.New("status description was not of expected type string") + } + } + } + + response := &amqpwrap.RPCResponse{ + Code: int(statusCode), + Description: description, + Message: res, + } + + if err := l.receiver.AcceptMessage(ctx, res); err != nil { + return response, fmt.Errorf("failed accepting message on rpc link: %w", err) + } + + var rpcErr RPCError + + if asRPCError(response, &rpcErr) { + return nil, rpcErr + } + + return response, err +} + +func (l *rpcLink) ConnID() uint64 { + return l.session.ConnID() +} + +// Close the link receiver, sender and session +func (l *rpcLink) Close(ctx context.Context) error { + l.rpcLinkCtxCancel() + + select { + case <-l.responseRouterClosed: + case <-ctx.Done(): + } + + if l.session != nil { + return l.session.Close(ctx) + } + + return nil +} + +// addChannelToMap adds a channel which will be used by the response router to +// notify when there is a response to the request. 
+// If l.responseMap is nil (for instance, via broadcastError) this function will +// return nil. +func (l *rpcLink) addChannelToMap(messageID string) chan rpcResponse { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + if l.responseMap == nil { + return nil + } + + responseCh := make(chan rpcResponse, 1) + l.responseMap[messageID] = responseCh + + return responseCh +} + +// deleteChannelFromMap removes the message from our internal map and returns +// a channel that the corresponding RPC() call is waiting on. +// If l.responseMap is nil (for instance, via broadcastError) this function will +// return nil. +func (l *rpcLink) deleteChannelFromMap(messageID string) chan rpcResponse { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + if l.responseMap == nil { + return nil + } + + ch := l.responseMap[messageID] + delete(l.responseMap, messageID) + + return ch +} + +// broadcastError notifies the anyone waiting for a response that the link/session/connection +// has closed. +func (l *rpcLink) broadcastError(err error) { + l.responseMu.Lock() + defer l.responseMu.Unlock() + + for _, ch := range l.responseMap { + ch <- rpcResponse{err: err} + } + + l.broadcastErr = err + l.responseMap = nil +} + +// addMessageID generates a unique UUID for the message. When the service +// responds it will fill out the correlation ID property of the response +// with this ID, allowing us to link the request and response together. +// +// NOTE: this function copies 'message', adding in a 'Properties' object +// if it does not already exist. 
+func addMessageID(message *amqp.Message, uuidNewV4 func() (uuid.UUID, error)) (*amqp.Message, string, error) { + uuid, err := uuidNewV4() + + if err != nil { + return nil, "", err + } + + autoGenMessageID := uuid.String() + + // we need to modify the message so we'll make a copy + copiedMessage := *message + + if message.Properties == nil { + copiedMessage.Properties = &amqp.MessageProperties{ + MessageID: autoGenMessageID, + } + } else { + // properties already exist, make a copy and then update + // the message ID + copiedProperties := *message.Properties + copiedProperties.MessageID = autoGenMessageID + + copiedMessage.Properties = &copiedProperties + } + + return &copiedMessage, autoGenMessageID, nil +} + +// asRPCError checks to see if the res is actually a failed request +// (where failed means the status code was non-2xx). If so, +// it returns true and updates the struct pointed to by err. +func asRPCError(res *amqpwrap.RPCResponse, err *RPCError) bool { + if res == nil { + return false + } + + if res.Code >= 200 && res.Code < 300 { + return false + } + + *err = RPCError{ + Message: fmt.Sprintf("rpc: failed, status code %d and description: %s", res.Code, res.Description), + Resp: res, + } + + return true +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas/sas.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas/sas.go new file mode 100644 index 00000000000..0b5854ea277 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas/sas.go @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package sas provides SAS token functionality which implements TokenProvider from package auth for use with Azure +// Event Hubs and Service Bus. 
+ +package sas + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" +) + +type ( + // Signer provides SAS token generation for use in Service Bus and Event Hub + Signer struct { + KeyName string + Key string + + // getNow is stubabble for unit tests and is just an alias for time.Now() + getNow func() time.Time + } + + // TokenProvider is a SAS claims-based security token provider + TokenProvider struct { + // expiryDuration is only used when we're generating SAS tokens. It gets used + // to calculate the expiration timestamp for a token. Pre-computed SAS tokens + // passed in TokenProviderWithSAS() are not affected. + expiryDuration time.Duration + + signer *Signer + + // sas is a precomputed SAS token. This implies that the caller has some other + // method for generating tokens. + sas string + } + + // TokenProviderOption provides configuration options for SAS Token Providers + TokenProviderOption func(*TokenProvider) error +) + +// TokenProviderWithKey configures a SAS TokenProvider to use the given key name and key (secret) for signing +func TokenProviderWithKey(keyName, key string, expiryDuration time.Duration) TokenProviderOption { + return func(provider *TokenProvider) error { + + if expiryDuration == 0 { + expiryDuration = 2 * time.Hour + } + + provider.expiryDuration = expiryDuration + provider.signer = NewSigner(keyName, key) + return nil + } +} + +// TokenProviderWithSAS configures the token provider with a pre-created SharedAccessSignature. +// auth.Token's coming back from this TokenProvider instance will always have '0' as the expiration +// date. 
+func TokenProviderWithSAS(sas string) TokenProviderOption { + return func(provider *TokenProvider) error { + provider.sas = sas + return nil + } +} + +// NewTokenProvider builds a SAS claims-based security token provider +func NewTokenProvider(opts ...TokenProviderOption) (*TokenProvider, error) { + provider := new(TokenProvider) + + for _, opt := range opts { + err := opt(provider) + if err != nil { + return nil, err + } + } + return provider, nil +} + +// GetToken gets a CBS SAS token +func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) { + if t.sas != "" { + // the expiration date doesn't matter here so we'll just set it 0. + return auth.NewToken(auth.CBSTokenTypeSAS, t.sas, "0"), nil + } + + signature, expiry, err := t.signer.SignWithDuration(audience, t.expiryDuration) + + if err != nil { + return nil, err + } + + return auth.NewToken(auth.CBSTokenTypeSAS, signature, expiry), nil +} + +// NewSigner builds a new SAS signer for use in generation Service Bus and Event Hub SAS tokens +func NewSigner(keyName, key string) *Signer { + return &Signer{ + KeyName: keyName, + Key: key, + + getNow: time.Now, + } +} + +// SignWithDuration signs a given for a period of time from now +func (s *Signer) SignWithDuration(uri string, interval time.Duration) (signature, expiry string, err error) { + expiry = signatureExpiry(s.getNow().UTC(), interval) + sig, err := s.SignWithExpiry(uri, expiry) + + if err != nil { + return "", "", err + } + + return sig, expiry, nil +} + +// SignWithExpiry signs a given uri with a given expiry string +func (s *Signer) SignWithExpiry(uri, expiry string) (string, error) { + audience := strings.ToLower(url.QueryEscape(uri)) + sts := stringToSign(audience, expiry) + sig, err := s.signString(sts) + + if err != nil { + return "", err + } + + return fmt.Sprintf("SharedAccessSignature sr=%s&sig=%s&se=%s&skn=%s", audience, sig, expiry, s.KeyName), nil +} + +// CreateConnectionStringWithSharedAccessSignature generates a new connection 
string with +// an embedded SharedAccessSignature and expiration. +// Ex: Endpoint=sb://.servicebus.windows.net;SharedAccessSignature=SharedAccessSignature sr=.servicebus.windows.net&sig=&se=&skn=" +func CreateConnectionStringWithSASUsingExpiry(connectionString string, expiry time.Time) (string, error) { + parsed, err := exported.ParseConnectionString(connectionString) + + if err != nil { + return "", err + } + + signer := NewSigner(*parsed.SharedAccessKeyName, *parsed.SharedAccessKey) + + sig, err := signer.SignWithExpiry(parsed.FullyQualifiedNamespace, fmt.Sprintf("%d", expiry.Unix())) + + if err != nil { + return "", err + } + + return fmt.Sprintf("Endpoint=sb://%s;SharedAccessSignature=%s", parsed.FullyQualifiedNamespace, sig), nil +} + +func signatureExpiry(from time.Time, interval time.Duration) string { + t := from.Add(interval).Round(time.Second).Unix() + return strconv.FormatInt(t, 10) +} + +func stringToSign(uri, expiry string) string { + return uri + "\n" + expiry +} + +func (s *Signer) signString(str string) (string, error) { + h := hmac.New(sha256.New, []byte(s.Key)) + _, err := h.Write([]byte(str)) + + if err != nil { + return "", err + } + + encodedSig := base64.StdEncoding.EncodeToString(h.Sum(nil)) + return url.QueryEscape(encodedSig), nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth/token_provider.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth/token_provider.go new file mode 100644 index 00000000000..f44dc22aad0 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth/token_provider.go @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package sbauth + +import ( + "context" + "strconv" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas" +) + +// TokenProvider handles access tokens and expiration calculation for SAS +// keys (via connection strings) or TokenCredentials from Azure Identity. +type TokenProvider struct { + tokenCred azcore.TokenCredential + sasTokenProvider *sas.TokenProvider + + // InsecureDisableTLS disables TLS. This is only used if the user is connecting to localhost + // and is using an emulator connection string. See [ConnectionStringProperties.Emulator] for + // details. + InsecureDisableTLS bool +} + +// NewTokenProvider creates a tokenProvider from azcore.TokenCredential. +func NewTokenProvider(tokenCredential azcore.TokenCredential) *TokenProvider { + return &TokenProvider{tokenCred: tokenCredential} +} + +// NewTokenProviderWithConnectionString creates a tokenProvider from a connection string. +func NewTokenProviderWithConnectionString(props exported.ConnectionStringProperties) (*TokenProvider, error) { + // NOTE: this is the value we've been using since forever. AFAIK, it's arbitrary. 
+ const defaultTokenExpiry = 2 * time.Hour + + var authOption sas.TokenProviderOption + + if props.SharedAccessSignature == nil { + authOption = sas.TokenProviderWithKey(*props.SharedAccessKeyName, *props.SharedAccessKey, defaultTokenExpiry) + } else { + authOption = sas.TokenProviderWithSAS(*props.SharedAccessSignature) + } + + provider, err := sas.NewTokenProvider(authOption) + + if err != nil { + return nil, err + } + + return &TokenProvider{sasTokenProvider: provider, InsecureDisableTLS: props.Emulator}, nil +} + +// singleUseTokenProvider allows you to wrap an *auth.Token so it can be used +// with functions that require a TokenProvider, but only actually should get +// a single token (like cbs.NegotiateClaim) +type singleUseTokenProvider auth.Token + +// GetToken will return this token. +// This function makes us compatible with auth.TokenProvider. +func (tp *singleUseTokenProvider) GetToken(uri string) (*auth.Token, error) { + return (*auth.Token)(tp), nil +} + +// GetToken will retrieve a new token. +// This function makes us compatible with auth.TokenProvider. +func (tp *TokenProvider) GetToken(uri string) (*auth.Token, error) { + token, _, err := tp.getTokenImpl(uri) + return token, err +} + +// GetToken returns a token (that is compatible as an auth.TokenProvider) and +// the calculated time when you should renew your token. +func (tp *TokenProvider) GetTokenAsTokenProvider(uri string) (*singleUseTokenProvider, time.Time, error) { + token, renewAt, err := tp.getTokenImpl(uri) + + if err != nil { + return nil, time.Time{}, err + } + + return (*singleUseTokenProvider)(token), renewAt, nil +} + +func (tp *TokenProvider) getTokenImpl(uri string) (*auth.Token, time.Time, error) { + if tp.sasTokenProvider != nil { + return tp.getSASToken(uri) + } else { + return tp.getAZCoreToken() + } +} + +func (tpa *TokenProvider) getAZCoreToken() (*auth.Token, time.Time, error) { + // not sure if URI plays in here. 
+ accessToken, err := tpa.tokenCred.GetToken(context.TODO(), policy.TokenRequestOptions{ + Scopes: []string{ + "https://eventhubs.azure.net//.default", + }, + }) + + if err != nil { + return nil, time.Time{}, err + } + + authToken := &auth.Token{ + TokenType: auth.CBSTokenTypeJWT, + Token: accessToken.Token, + Expiry: strconv.FormatInt(accessToken.ExpiresOn.Unix(), 10), + } + + return authToken, + accessToken.ExpiresOn, + nil +} + +func (tpa *TokenProvider) getSASToken(uri string) (*auth.Token, time.Time, error) { + authToken, err := tpa.sasTokenProvider.GetToken(uri) + + if err != nil { + return nil, time.Time{}, err + } + + // we can ignore the error here since we did the string-izing of the time + // in the first place. + var expiryTime time.Time + + if authToken.Expiry != "0" { + // TODO: I'd like to just use the actual Expiry time we generated + // Filed here https://github.com/Azure/azure-sdk-for-go/issues/20468 + expiryTime = time.Now().Add(time.Minute * 15) + } + + return authToken, + expiryTime, + nil +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils/retrier.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils/retrier.go new file mode 100644 index 00000000000..a61eb134934 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils/retrier.go @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package utils + +import ( + "context" + "errors" + "math" + "math/rand" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" +) + +type RetryFnArgs struct { + // I is the iteration of the retry "loop" and starts at 0. + // The 0th iteration is the first call, and doesn't count as a retry. 
+ // The last try will equal RetryOptions.MaxRetries + I int32 + // LastErr is the returned error from the previous loop. + // If you have potentially expensive + LastErr error + + resetAttempts bool +} + +// ResetAttempts resets all Retry() attempts, starting back +// at iteration 0. +func (rf *RetryFnArgs) ResetAttempts() { + rf.resetAttempts = true +} + +// Retry runs a standard retry loop. It executes your passed in fn as the body of the loop. +// It returns if it exceeds the number of configured retry options or if 'isFatal' returns true. +func Retry(ctx context.Context, eventName log.Event, prefix func() string, o exported.RetryOptions, fn func(ctx context.Context, callbackArgs *RetryFnArgs) error, isFatalFn func(err error) bool) error { + if isFatalFn == nil { + panic("isFatalFn is nil, errors would panic") + } + + var ro exported.RetryOptions = o + setDefaults(&ro) + + var err error + + for i := int32(0); i <= ro.MaxRetries; i++ { + if i > 0 { + sleep := calcDelay(ro, i) + log.Writef(eventName, "(%s) Retry attempt %d sleeping for %s", prefix(), i, sleep) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(sleep): + } + } + + args := RetryFnArgs{ + I: i, + LastErr: err, + } + err = fn(ctx, &args) + + if args.resetAttempts { + log.Writef(eventName, "(%s) Resetting retry attempts", prefix()) + + // it looks weird, but we're doing -1 here because the post-increment + // will set it back to 0, which is what we want - go back to the 0th + // iteration so we don't sleep before the attempt. + // + // You'll use this when you want to get another "fast" retry attempt. 
+// (adapted from azcore/policy_retry)
b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/log.go @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" +) + +const ( + // EventConn is used whenever we create a connection or any links (ie: producers, consumers). + EventConn log.Event = exported.EventConn + + // EventAuth is used when we're doing authentication/claims negotiation. + EventAuth log.Event = exported.EventAuth + + // EventProducer represents operations that happen on Producers. + EventProducer log.Event = exported.EventProducer + + // EventConsumer represents operations that happen on Consumers. + EventConsumer log.Event = exported.EventConsumer +) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/mgmt.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/mgmt.go new file mode 100644 index 00000000000..3ac3dba46bd --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/mgmt.go @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh" + "github.com/Azure/go-amqp" +) + +// EventHubProperties represents properties of the Event Hub, like the number of partitions. 
+type EventHubProperties struct { + CreatedOn time.Time + Name string + PartitionIDs []string +} + +// GetEventHubPropertiesOptions contains optional parameters for the GetEventHubProperties function +type GetEventHubPropertiesOptions struct { + // For future expansion +} + +// getEventHubProperties gets event hub properties, like the available partition IDs and when the Event Hub was created. +func getEventHubProperties[LinkT internal.AMQPLink](ctx context.Context, eventName log.Event, ns internal.NamespaceForManagementOps, links *internal.Links[LinkT], eventHub string, retryOptions RetryOptions, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { + var props EventHubProperties + + err := links.RetryManagement(ctx, eventName, "getEventHubProperties", retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.RPCLink]) error { + tmpProps, err := getEventHubPropertiesInternal(ctx, ns, lwid.Link(), eventHub, options) + + if err != nil { + return err + } + + props = tmpProps + return nil + }) + + return props, err + +} + +func getEventHubPropertiesInternal(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink amqpwrap.RPCLink, eventHub string, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { + token, err := ns.GetTokenForEntity(eventHub) + + if err != nil { + return EventHubProperties{}, internal.TransformError(err) + } + + amqpMsg := &amqp.Message{ + ApplicationProperties: map[string]any{ + "operation": "READ", + "name": eventHub, + "type": "com.microsoft:eventhub", + "security_token": token.Token, + }, + } + + resp, err := rpcLink.RPC(context.Background(), amqpMsg) + + if err != nil { + return EventHubProperties{}, err + } + + if resp.Code >= 300 { + return EventHubProperties{}, fmt.Errorf("failed getting partition properties: %v", resp.Description) + } + + return newEventHubProperties(resp.Message.Value) +} + +// PartitionProperties are the properties for a single partition. 
+type PartitionProperties struct { + // BeginningSequenceNumber is the first sequence number for a partition. + BeginningSequenceNumber int64 + // EventHubName is the name of the Event Hub for this partition. + EventHubName string + + // IsEmpty is true if the partition is empty, false otherwise. + IsEmpty bool + + // LastEnqueuedOffset is the offset of latest enqueued event. + LastEnqueuedOffset int64 + + // LastEnqueuedOn is the date of latest enqueued event. + LastEnqueuedOn time.Time + + // LastEnqueuedSequenceNumber is the sequence number of the latest enqueued event. + LastEnqueuedSequenceNumber int64 + + // PartitionID is the partition ID of this partition. + PartitionID string +} + +// GetPartitionPropertiesOptions are the options for the GetPartitionProperties function. +type GetPartitionPropertiesOptions struct { + // For future expansion +} + +// getPartitionProperties gets properties for a specific partition. This includes data like the last enqueued sequence number, the first sequence +// number and when an event was last enqueued to the partition. 
+func getPartitionProperties[LinkT internal.AMQPLink](ctx context.Context, eventName log.Event, ns internal.NamespaceForManagementOps, links *internal.Links[LinkT], eventHub string, partitionID string, retryOptions RetryOptions, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { + var props PartitionProperties + + err := links.RetryManagement(ctx, eventName, "getPartitionProperties", retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.RPCLink]) error { + tmpProps, err := getPartitionPropertiesInternal(ctx, ns, lwid.Link(), eventHub, partitionID, options) + + if err != nil { + return err + } + + props = tmpProps + return nil + }) + + return props, err +} + +func getPartitionPropertiesInternal(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink amqpwrap.RPCLink, eventHub string, partitionID string, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { + token, err := ns.GetTokenForEntity(eventHub) + + if err != nil { + return PartitionProperties{}, err + } + + amqpMsg := &amqp.Message{ + ApplicationProperties: map[string]any{ + "operation": "READ", + "name": eventHub, + "type": "com.microsoft:partition", + "partition": partitionID, + "security_token": token.Token, + }, + } + + resp, err := rpcLink.RPC(context.Background(), amqpMsg) + + if err != nil { + return PartitionProperties{}, internal.TransformError(err) + } + + if resp.Code >= 300 { + return PartitionProperties{}, fmt.Errorf("failed getting partition properties: %v", resp.Description) + } + + return newPartitionProperties(resp.Message.Value) +} + +func newEventHubProperties(amqpValue any) (EventHubProperties, error) { + m, ok := amqpValue.(map[string]any) + + if !ok { + return EventHubProperties{}, nil + } + + partitionIDs, ok := m["partition_ids"].([]string) + + if !ok { + return EventHubProperties{}, fmt.Errorf("invalid message format") + } + + name, ok := m["name"].(string) + + if !ok { + return EventHubProperties{}, 
fmt.Errorf("invalid message format") + } + + createdOn, ok := m["created_at"].(time.Time) + + if !ok { + return EventHubProperties{}, fmt.Errorf("invalid message format") + } + + return EventHubProperties{ + Name: name, + CreatedOn: createdOn, + PartitionIDs: partitionIDs, + }, nil +} + +func newPartitionProperties(amqpValue any) (PartitionProperties, error) { + m, ok := amqpValue.(map[string]any) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + eventHubName, ok := m["name"].(string) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + partition, ok := m["partition"].(string) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + beginningSequenceNumber, ok := eh.ConvertToInt64(m["begin_sequence_number"]) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + lastEnqueuedSequenceNumber, ok := eh.ConvertToInt64(m["last_enqueued_sequence_number"]) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + lastEnqueuedOffsetStr, ok := m["last_enqueued_offset"].(string) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + lastEnqueuedOffset, err := strconv.ParseInt(lastEnqueuedOffsetStr, 10, 64) + + if err != nil { + return PartitionProperties{}, fmt.Errorf("invalid message format: %w", err) + } + + lastEnqueuedTime, ok := m["last_enqueued_time_utc"].(time.Time) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + isEmpty, ok := m["is_partition_empty"].(bool) + + if !ok { + return PartitionProperties{}, errors.New("invalid message format") + } + + return PartitionProperties{ + BeginningSequenceNumber: beginningSequenceNumber, + LastEnqueuedSequenceNumber: lastEnqueuedSequenceNumber, + LastEnqueuedOffset: lastEnqueuedOffset, + LastEnqueuedOn: lastEnqueuedTime, + IsEmpty: isEmpty, + PartitionID: partition, + EventHubName: 
+In AMQP there is a concept of a connection and links. AMQP Connections are TCP connections. Links are a logical conduit within an AMQP connection and there are typically many of them but they use the same connection and do not require their own socket.
ConsumerClient works similarly - it has a single TCP connection and calling ConsumerClient.NewPartitionClient creates new links, but not new TCP connections. + +If you want to split activity across multiple TCP connections you can still do so by creating multiple instances of ProducerClient or ConsumerClient. + +Some examples: + +```go +// consumerClient will own a TCP connection. +consumerClient, err := azeventhubs.NewConsumerClient(/* arguments elided for example */) + +// Close the TCP connection (and any child links) +defer consumerClient.Close(context.TODO()) + +// this call will lazily create a set of AMQP links using the consumerClient's TCP connection. +partClient0, err := consumerClient.NewPartitionClient("0", nil) +defer partClient0.Close(context.TODO()) // will close the AMQP link, not the connection + +// this call will also lazily create a set of AMQP links using the consumerClient's TCP connection. +partClient1, err := consumerClient.NewPartitionClient("1", nil) +defer partClient1.Close(context.TODO()) // will close the AMQP link, not the connection +``` + +```go +// will lazily create an AMQP connection +producerClient, err := azeventhubs.NewProducerClient(/* arguments elided for example */) + +// close the TCP connection (and any child links created for sending events) +defer producerClient.Close(context.TODO()) + +// these calls will lazily create a set of AMQP links using the producerClient's TCP connection. 
+The older Event Hubs package provided some authentication methods like hub.NewHubFromEnvironment. These have been replaced by Azure Identity credentials from [azidentity](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity#section-readme).
+ +* `azeventhubs.ConsumerClient`: [using azidentity](https://github.com/Azure/azure-sdk-for-go/blob/a46bd74e113d6a045541b82a0f3f6497011d8417/sdk/messaging/azeventhubs/example_consumerclient_test.go#L16) | [using a connection string](https://github.com/Azure/azure-sdk-for-go/blob/a46bd74e113d6a045541b82a0f3f6497011d8417/sdk/messaging/azeventhubs/example_consumerclient_test.go#L30) + +* `azeventhubs.ProducerClient`: [using azidentity](https://github.com/Azure/azure-sdk-for-go/blob/a46bd74e113d6a045541b82a0f3f6497011d8417/sdk/messaging/azeventhubs/example_producerclient_test.go#L16) | [using a connection string](https://github.com/Azure/azure-sdk-for-go/blob/a46bd74e113d6a045541b82a0f3f6497011d8417/sdk/messaging/azeventhubs/example_producerclient_test.go#L30) + +## EventBatchIterator + +Sending events has changed to be more explicit about when batches are formed and sent. + +The older module had a type (EventBatchIterator). This type has been removed and replaced +with explicit batching, using `azeventhubs.EventDataBatch`. See here for an example: [link](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_producing_events_test.go). + +## Getting hub/partition information + +In the older module functions to get the partition IDs, as well as runtime properties +like the last enqueued sequence number were on the `Hub` type. These are now on both +of the client types instead (`ProducerClient`, `ConsumerClient`). 
+ +```go +// old +hub.GetPartitionInformation(context.TODO(), "0") +hub.GetRuntimeInformation(context.TODO()) +``` + +```go +// new + +// equivalent to: hub.GetRuntimeInformation(context.TODO()) +consumerClient.GetEventHubProperties(context.TODO(), nil) + +// equivalent to: hub.GetPartitionInformation +consumerClient.GetPartitionProperties(context.TODO(), "partition-id", nil) + +// +// or, using the ProducerClient +// + +producerClient.GetEventHubProperties(context.TODO(), nil) +producerClient.GetPartitionProperties(context.TODO(), "partition-id", nil) +``` + +## Migrating from a previous checkpoint store + +See here for an example: [link](https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_checkpoint_migration_test.go) diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/partition_client.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/partition_client.go new file mode 100644 index 00000000000..8eb01be00d7 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/partition_client.go @@ -0,0 +1,380 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + + "github.com/Azure/go-amqp" +) + +// DefaultConsumerGroup is the name of the default consumer group in the Event Hubs service. +const DefaultConsumerGroup = "$Default" + +const defaultPrefetchSize = int32(300) + +// defaultLinkRxBuffer is the maximum number of transfer frames we can handle +// on the Receiver. This matches the current default window size that go-amqp +// uses for sessions. 
+const defaultMaxCreditSize = uint32(5000) + +// StartPosition indicates the position to start receiving events within a partition. +// The default position is Latest. +// +// You can set this in the options for [ConsumerClient.NewPartitionClient]. +type StartPosition struct { + // Offset will start the consumer after the specified offset. Can be exclusive + // or inclusive, based on the Inclusive property. + // NOTE: offsets are not stable values, and might refer to different events over time + // as the Event Hub events reach their age limit and are discarded. + Offset *int64 + + // SequenceNumber will start the consumer after the specified sequence number. Can be exclusive + // or inclusive, based on the Inclusive property. + SequenceNumber *int64 + + // EnqueuedTime will start the consumer before events that were enqueued on or after EnqueuedTime. + // Can be exclusive or inclusive, based on the Inclusive property. + EnqueuedTime *time.Time + + // Inclusive configures whether the events directly at Offset, SequenceNumber or EnqueuedTime will be included (true) + // or excluded (false). + Inclusive bool + + // Earliest will start the consumer at the earliest event. + Earliest *bool + + // Latest will start the consumer after the last event. + Latest *bool +} + +// PartitionClient is used to receive events from an Event Hub partition. +// +// This type is instantiated from the [ConsumerClient] type, using [ConsumerClient.NewPartitionClient]. 
+type PartitionClient struct { + consumerGroup string + eventHub string + instanceID string + links internal.LinksForPartitionClient[amqpwrap.AMQPReceiverCloser] + offsetExpression string + ownerLevel *int64 + partitionID string + prefetch int32 + retryOptions RetryOptions +} + +// ReceiveEventsOptions contains optional parameters for the ReceiveEvents function +type ReceiveEventsOptions struct { + // For future expansion +} + +// ReceiveEvents receives events until 'count' events have been received or the context has +// expired or been cancelled. +// +// If your ReceiveEvents call appears to be stuck there are some common causes: +// +// 1. The PartitionClientOptions.StartPosition defaults to "Latest" when the client is created. The connection +// is lazily initialized, so it's possible the link was initialized to a position after events you've sent. +// To make this deterministic, you can choose an explicit start point using sequence number, offset or a +// timestamp. See the [PartitionClientOptions.StartPosition] field for more details. +// +// 2. You might have sent the events to a different partition than intended. By default, batches that are +// created using [ProducerClient.NewEventDataBatch] do not target a specific partition. When a partition +// is not specified, Azure Event Hubs service will choose the partition the events will be sent to. +// +// To fix this, you can specify a PartitionID as part of your [EventDataBatchOptions.PartitionID] options or +// open multiple [PartitionClient] instances, one for each partition. You can get the full list of partitions +// at runtime using [ConsumerClient.GetEventHubProperties]. See the "example_consuming_events_test.go" for +// an example of this pattern. +// +// 3. Network issues can cause internal retries. To see log messages related to this use the instructions in +// the example function "Example_enableLogging". 
+func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options *ReceiveEventsOptions) ([]*ReceivedEventData, error) { + var events []*ReceivedEventData + + prefetchDisabled := pc.prefetch < 0 + + if count <= 0 { + return nil, internal.NewErrNonRetriable("count should be greater than 0") + } + + if prefetchDisabled && count > int(defaultMaxCreditSize) { + return nil, internal.NewErrNonRetriable(fmt.Sprintf("count cannot exceed %d", defaultMaxCreditSize)) + } + + err := pc.links.Retry(ctx, EventConsumer, "ReceiveEvents", pc.partitionID, pc.retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.AMQPReceiverCloser]) error { + events = nil + + if prefetchDisabled { + remainingCredits := lwid.Link().Credits() + + if count > int(remainingCredits) { + newCredits := uint32(count) - remainingCredits + + log.Writef(EventConsumer, "(%s) Have %d outstanding credit, only issuing %d credits", lwid.String(), remainingCredits, newCredits) + + if err := lwid.Link().IssueCredit(newCredits); err != nil { + log.Writef(EventConsumer, "(%s) Error when issuing credits: %s", lwid.String(), err) + return err + } + } + } + + for { + amqpMessage, err := lwid.Link().Receive(ctx, nil) + + if internal.IsOwnershipLostError(err) { + log.Writef(EventConsumer, "(%s) Error, link ownership lost: %s", lwid.String(), err) + events = nil + return err + } + + if err != nil { + prefetched := getAllPrefetched(lwid.Link(), count-len(events)) + + for _, amqpMsg := range prefetched { + re, err := newReceivedEventData(amqpMsg) + + if err != nil { + log.Writef(EventConsumer, "(%s) Failed converting AMQP message to EventData: %s", lwid.String(), err) + return err + } + + events = append(events, re) + + if len(events) == count { + return nil + } + } + + // this lets cancel errors just return + return err + } + + receivedEvent, err := newReceivedEventData(amqpMessage) + + if err != nil { + log.Writef(EventConsumer, "(%s) Failed converting AMQP message to EventData: %s", 
lwid.String(), err) + return err + } + + events = append(events, receivedEvent) + + if len(events) == count { + return nil + } + } + }) + + if err != nil && len(events) == 0 { + transformedErr := internal.TransformError(err) + log.Writef(EventConsumer, "No events received, returning error %s", transformedErr.Error()) + return nil, transformedErr + } + + numEvents := len(events) + lastSequenceNumber := events[numEvents-1].SequenceNumber + + pc.offsetExpression = formatStartExpressionForSequence(">", lastSequenceNumber) + log.Writef(EventConsumer, "%d Events received, moving sequence to %d", numEvents, lastSequenceNumber) + return events, nil +} + +// Close releases resources for this client. +func (pc *PartitionClient) Close(ctx context.Context) error { + if pc.links != nil { + return pc.links.Close(ctx) + } + + return nil +} + +func (pc *PartitionClient) getEntityPath(partitionID string) string { + return fmt.Sprintf("%s/ConsumerGroups/%s/Partitions/%s", pc.eventHub, pc.consumerGroup, partitionID) +} + +func (pc *PartitionClient) newEventHubConsumerLink(ctx context.Context, session amqpwrap.AMQPSession, entityPath string, partitionID string) (internal.AMQPReceiverCloser, error) { + props := map[string]any{ + // this lets Event Hubs return error messages that identify which Receiver stole ownership (and other things) within + // error messages. + // Ex: (ownershiplost): link detached, reason: *Error{Condition: amqp:link:stolen, Description: New receiver 'EventHubConsumerClientTestID-Interloper' with higher epoch of '1' is created hence current receiver 'EventHubConsumerClientTestID' with epoch '0' is getting disconnected. If you are recreating the receiver, make sure a higher epoch is used. 
TrackingId:8031553f0000a5060009a59b63f517a0_G4_B22, SystemTracker:riparkdev:eventhub:tests~10922|$default, Timestamp:2023-02-21T19:12:41, Info: map[]} + "com.microsoft:receiver-name": pc.instanceID, + } + + if pc.ownerLevel != nil { + props["com.microsoft:epoch"] = *pc.ownerLevel + } + + receiverOptions := &amqp.ReceiverOptions{ + SettlementMode: to.Ptr(amqp.ReceiverSettleModeFirst), + Filters: []amqp.LinkFilter{ + amqp.NewSelectorFilter(pc.offsetExpression), + }, + Properties: props, + TargetAddress: pc.instanceID, + } + + if pc.prefetch > 0 { + log.Writef(EventConsumer, "Enabling prefetch with %d credits", pc.prefetch) + receiverOptions.Credit = pc.prefetch + } else if pc.prefetch == 0 { + log.Writef(EventConsumer, "Enabling prefetch with %d credits", defaultPrefetchSize) + receiverOptions.Credit = defaultPrefetchSize + } else { + // prefetch is disabled, enable manual credits and enable + // a reasonable default max for the buffer. + log.Writef(EventConsumer, "Disabling prefetch") + receiverOptions.Credit = -1 + } + + log.Writef(EventConsumer, "Creating receiver:\n source:%s\n instanceID: %s\n owner level: %d\n offset: %s\n manual: %v\n prefetch: %d", + entityPath, + pc.instanceID, + pc.ownerLevel, + pc.offsetExpression, + receiverOptions.Credit == -1, + pc.prefetch) + + receiver, err := session.NewReceiver(ctx, entityPath, partitionID, receiverOptions) + + if err != nil { + return nil, err + } + + return receiver, nil +} + +func (pc *PartitionClient) init(ctx context.Context) error { + return pc.links.Retry(ctx, EventConsumer, "Init", pc.partitionID, pc.retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.AMQPReceiverCloser]) error { + return nil + }) +} + +type partitionClientArgs struct { + namespace internal.NamespaceForAMQPLinks + + consumerGroup string + eventHub string + instanceID string + partitionID string + retryOptions RetryOptions +} + +func newPartitionClient(args partitionClientArgs, options *PartitionClientOptions) 
(*PartitionClient, error) { + if options == nil { + options = &PartitionClientOptions{} + } + + offsetExpr, err := getStartExpression(options.StartPosition) + + if err != nil { + return nil, err + } + + if options.Prefetch > int32(defaultMaxCreditSize) { + // don't allow them to set the prefetch above the session window size. + return nil, internal.NewErrNonRetriable(fmt.Sprintf("options.Prefetch cannot exceed %d", defaultMaxCreditSize)) + } + + client := &PartitionClient{ + consumerGroup: args.consumerGroup, + eventHub: args.eventHub, + offsetExpression: offsetExpr, + ownerLevel: options.OwnerLevel, + partitionID: args.partitionID, + prefetch: options.Prefetch, + retryOptions: args.retryOptions, + instanceID: args.instanceID, + } + + client.links = internal.NewLinks(args.namespace, fmt.Sprintf("%s/$management", client.eventHub), client.getEntityPath, client.newEventHubConsumerLink) + + return client, nil +} + +func getAllPrefetched(receiver amqpwrap.AMQPReceiver, max int) []*amqp.Message { + var messages []*amqp.Message + + for i := 0; i < max; i++ { + msg := receiver.Prefetched() + + if msg == nil { + break + } + + messages = append(messages, msg) + } + + return messages +} + +func getStartExpression(startPosition StartPosition) (string, error) { + gt := ">" + + if startPosition.Inclusive { + gt = ">=" + } + + var errMultipleFieldsSet = errors.New("only a single start point can be set: Earliest, EnqueuedTime, Latest, Offset, or SequenceNumber") + + offsetExpr := "" + + if startPosition.EnqueuedTime != nil { + // time-based, non-inclusive + offsetExpr = fmt.Sprintf("amqp.annotation.x-opt-enqueued-time %s '%d'", gt, startPosition.EnqueuedTime.UnixMilli()) + } + + if startPosition.Offset != nil { + // offset-based, non-inclusive + // ex: amqp.annotation.x-opt-enqueued-time %s '165805323000' + if offsetExpr != "" { + return "", errMultipleFieldsSet + } + + offsetExpr = fmt.Sprintf("amqp.annotation.x-opt-offset %s '%d'", gt, *startPosition.Offset) + } + + if 
startPosition.Latest != nil && *startPosition.Latest { + if offsetExpr != "" { + return "", errMultipleFieldsSet + } + + offsetExpr = fmt.Sprintf("amqp.annotation.x-opt-offset %s '@latest'", gt) + } + + if startPosition.SequenceNumber != nil { + if offsetExpr != "" { + return "", errMultipleFieldsSet + } + + offsetExpr = formatStartExpressionForSequence(gt, *startPosition.SequenceNumber) + } + + if startPosition.Earliest != nil && *startPosition.Earliest { + if offsetExpr != "" { + return "", errMultipleFieldsSet + } + + return "amqp.annotation.x-opt-offset > '-1'", nil + } + + if offsetExpr != "" { + return offsetExpr, nil + } + + // default to the start + return "amqp.annotation.x-opt-offset > '@latest'", nil +} + +func formatStartExpressionForSequence(op string, sequenceNumber int64) string { + return fmt.Sprintf("amqp.annotation.x-opt-sequence-number %s '%d'", op, sequenceNumber) +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor.go new file mode 100644 index 00000000000..e7bc3f6039e --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor.go @@ -0,0 +1,515 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "errors" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +// processorOwnerLevel is the owner level we assign to every ProcessorPartitionClient +// created by this Processor. +var processorOwnerLevel = to.Ptr[int64](0) + +// ProcessorStrategy specifies the load balancing strategy used by the Processor. 
+type ProcessorStrategy string + +const ( + // ProcessorStrategyBalanced will attempt to claim a single partition at a time, until each active + // owner has an equal share of partitions. + // This is the default strategy. + ProcessorStrategyBalanced ProcessorStrategy = "balanced" + + // ProcessorStrategyGreedy will attempt to claim as many partitions at a time as it can, ignoring + // balance. + ProcessorStrategyGreedy ProcessorStrategy = "greedy" +) + +// ProcessorOptions are the options for the NewProcessor +// function. +type ProcessorOptions struct { + // LoadBalancingStrategy dictates how concurrent Processor instances distribute + // ownership of partitions between them. + // The default strategy is ProcessorStrategyBalanced. + LoadBalancingStrategy ProcessorStrategy + + // UpdateInterval controls how often attempt to claim partitions. + // The default value is 10 seconds. + UpdateInterval time.Duration + + // PartitionExpirationDuration is the amount of time before a partition is considered + // unowned. + // The default value is 60 seconds. + PartitionExpirationDuration time.Duration + + // StartPositions are the default start positions (configurable per partition, or with an overall + // default value) if a checkpoint is not found in the CheckpointStore. + // The default position is Latest. + StartPositions StartPositions + + // Prefetch represents the size of the internal prefetch buffer for each ProcessorPartitionClient + // created by this Processor. When set, this client will attempt to always maintain + // an internal cache of events of this size, asynchronously, increasing the odds that + // ReceiveEvents() will use a locally stored cache of events, rather than having to + // wait for events to arrive from the network. + // + // Defaults to 300 events if Prefetch == 0. + // Disabled if Prefetch < 0. + Prefetch int32 +} + +// StartPositions are used if there is no checkpoint for a partition in +// the checkpoint store. 
+type StartPositions struct { + // PerPartition controls the start position for a specific partition, + // by partition ID. If a partition is not configured here it will default + // to Default start position. + PerPartition map[string]StartPosition + + // Default is used if the partition is not found in the PerPartition map. + Default StartPosition +} + +type state int32 + +const ( + stateNone state = 0 + stateStopped state = 1 + stateRunning state = 2 +) + +// Processor uses a [ConsumerClient] and [CheckpointStore] to provide automatic +// load balancing between multiple Processor instances, even in separate +// processes or on separate machines. +// +// See [example_consuming_with_checkpoints_test.go] for an example, and the function documentation +// for [Run] for a more detailed description of how load balancing works. +// +// [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go +type Processor struct { + stateMu sync.Mutex + state state + + ownershipUpdateInterval time.Duration + defaultStartPositions StartPositions + checkpointStore CheckpointStore + prefetch int32 + + // consumerClient is actually a *azeventhubs.ConsumerClient + // it's an interface here to make testing easier. + consumerClient consumerClientForProcessor + + nextClients chan *ProcessorPartitionClient + nextClientsReady chan struct{} + consumerClientDetails consumerClientDetails + + lb *processorLoadBalancer + + // claimedOwnerships is set to whatever our current ownerships are. The underlying + // value is a []Ownership. 
+ currentOwnerships *atomic.Value +} + +type consumerClientForProcessor interface { + GetEventHubProperties(ctx context.Context, options *GetEventHubPropertiesOptions) (EventHubProperties, error) + NewPartitionClient(partitionID string, options *PartitionClientOptions) (*PartitionClient, error) + getDetails() consumerClientDetails +} + +// NewProcessor creates a Processor. +// +// More information can be found in the documentation for the [Processor] +// type or the [example_consuming_with_checkpoints_test.go] for an example. +// +// [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go +func NewProcessor(consumerClient *ConsumerClient, checkpointStore CheckpointStore, options *ProcessorOptions) (*Processor, error) { + return newProcessorImpl(consumerClient, checkpointStore, options) +} + +func newProcessorImpl(consumerClient consumerClientForProcessor, checkpointStore CheckpointStore, options *ProcessorOptions) (*Processor, error) { + if options == nil { + options = &ProcessorOptions{} + } + + updateInterval := 10 * time.Second + + if options.UpdateInterval != 0 { + updateInterval = options.UpdateInterval + } + + partitionDurationExpiration := time.Minute + + if options.PartitionExpirationDuration != 0 { + partitionDurationExpiration = options.PartitionExpirationDuration + } + + startPosPerPartition := map[string]StartPosition{} + + if options.StartPositions.PerPartition != nil { + for k, v := range options.StartPositions.PerPartition { + startPosPerPartition[k] = v + } + } + + strategy := options.LoadBalancingStrategy + + switch strategy { + case ProcessorStrategyBalanced: + case ProcessorStrategyGreedy: + case "": + strategy = ProcessorStrategyBalanced + default: + return nil, fmt.Errorf("invalid load balancing strategy '%s'", strategy) + } + + currentOwnerships := &atomic.Value{} + currentOwnerships.Store([]Ownership{}) + + return &Processor{ + 
ownershipUpdateInterval: updateInterval, + consumerClient: consumerClient, + checkpointStore: checkpointStore, + + defaultStartPositions: StartPositions{ + PerPartition: startPosPerPartition, + Default: options.StartPositions.Default, + }, + prefetch: options.Prefetch, + consumerClientDetails: consumerClient.getDetails(), + nextClientsReady: make(chan struct{}), + lb: newProcessorLoadBalancer(checkpointStore, consumerClient.getDetails(), strategy, partitionDurationExpiration), + currentOwnerships: currentOwnerships, + + // `nextClients` will be properly initialized when the user calls + // Run() since it needs to query the # of partitions on the Event Hub. + nextClients: make(chan *ProcessorPartitionClient), + }, nil +} + +// NextPartitionClient will get the next owned [ProcessorPartitionClient] if one is acquired +// or will block until a new one arrives or [Processor.Run] is cancelled. When the Processor +// stops running this function will return nil. +// +// NOTE: You MUST call [ProcessorPartitionClient.Close] on the returned client to avoid +// leaking resources. +// +// See [example_consuming_with_checkpoints_test.go] for an example of typical usage. +// +// [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go +func (p *Processor) NextPartitionClient(ctx context.Context) *ProcessorPartitionClient { + select { + case <-ctx.Done(): + return nil + case <-p.nextClientsReady: + } + + select { + case nextClient := <-p.nextClients: + return nextClient + case <-ctx.Done(): + return nil + } +} + +func (p *Processor) checkState() error { + switch p.state { + case stateNone: + // not running so we can start. And lock out any other users. + p.state = stateRunning + return nil + case stateRunning: + return errors.New("the Processor is currently running. 
Concurrent calls to Run() are not allowed.") + case stateStopped: + return errors.New("the Processor has been stopped. Create a new instance to start processing again") + default: + return fmt.Errorf("unhandled state value %v", p.state) + } +} + +// Run handles the load balancing loop, blocking until the passed in context is cancelled +// or it encounters an unrecoverable error. On cancellation, it will return a nil error. +// +// This function should run for the lifetime of your application, or for as long as you want +// to continue to claim and process partitions. +// +// Once a Processor has been stopped it cannot be restarted and a new instance must +// be created. +// +// As partitions are claimed new [ProcessorPartitionClient] instances will be returned from +// [Processor.NextPartitionClient]. This can happen at any time, based on new Processor instances +// coming online, as well as other Processors exiting. +// +// [ProcessorPartitionClient] are used like a [PartitionClient] but provide an [ProcessorPartitionClient.UpdateCheckpoint] +// function that will store a checkpoint into the [CheckpointStore]. If the client were to crash, or be restarted +// it will pick up from the last checkpoint. +// +// See [example_consuming_with_checkpoints_test.go] for an example of typical usage. +// +// [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go +func (p *Processor) Run(ctx context.Context) error { + p.stateMu.Lock() + err := p.checkState() + p.stateMu.Unlock() + + if err != nil { + return err + } + + err = p.runImpl(ctx) + + // the context is the proper way to close down the Run() loop, so it's not + // an error and doesn't need to be returned. 
+ if ctx.Err() != nil { + return nil + } + + return err +} + +func (p *Processor) runImpl(ctx context.Context) error { + consumers := &sync.Map{} + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + p.close(ctx, consumers) + }() + + // size the channel to the # of partitions. We can never exceed this size since + // we'll never reclaim a partition that we already have ownership of. + eventHubProperties, err := p.initNextClientsCh(ctx) + + if err != nil { + return err + } + + // do one dispatch immediately + if err := p.dispatch(ctx, eventHubProperties, consumers); err != nil { + return err + } + + // note randSource is not thread-safe but it's not currently used in a way that requires + // it to be. + rnd := rand.New(rand.NewSource(time.Now().UnixNano())) + + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(calculateUpdateInterval(rnd, p.ownershipUpdateInterval)): + if err := p.dispatch(ctx, eventHubProperties, consumers); err != nil { + return err + } + } + } +} + +func calculateUpdateInterval(rnd *rand.Rand, updateInterval time.Duration) time.Duration { + // Introduce some jitter: [0.0, 1.0) / 2 = [0.0, 0.5) + 0.8 = [0.8, 1.3) + // (copied from the retry code for calculating jitter) + return time.Duration(updateInterval.Seconds() * (rnd.Float64()/2 + 0.8) * float64(time.Second)) +} + +func (p *Processor) initNextClientsCh(ctx context.Context) (EventHubProperties, error) { + eventHubProperties, err := p.consumerClient.GetEventHubProperties(ctx, nil) + + if err != nil { + return EventHubProperties{}, err + } + + p.nextClients = make(chan *ProcessorPartitionClient, len(eventHubProperties.PartitionIDs)) + close(p.nextClientsReady) + + return eventHubProperties, nil +} + +// dispatch uses the checkpoint store to figure out which partitions should be processed by this +// instance and starts a PartitionClient, if there isn't one. 
+// NOTE: due to random number usage in the load balancer, this function is not thread safe. +func (p *Processor) dispatch(ctx context.Context, eventHubProperties EventHubProperties, consumers *sync.Map) error { + ownerships, err := p.lb.LoadBalance(ctx, eventHubProperties.PartitionIDs) + + if err != nil { + return err + } + + checkpoints, err := p.getCheckpointsMap(ctx) + + if err != nil { + return err + } + + wg := sync.WaitGroup{} + + // store off the set of ownerships we claimed this round - when the processor + // shuts down we'll clear them (if we still own them). + tmpOwnerships := make([]Ownership, len(ownerships)) + copy(tmpOwnerships, ownerships) + p.currentOwnerships.Store(tmpOwnerships) + + for _, ownership := range ownerships { + wg.Add(1) + + go func(o Ownership) { + defer wg.Done() + + err := p.addPartitionClient(ctx, o, checkpoints, consumers) + + if err != nil { + azlog.Writef(EventConsumer, "failed to create partition client for partition '%s': %s", o.PartitionID, err.Error()) + } + }(ownership) + } + + wg.Wait() + + return nil +} + +// addPartitionClient creates a ProcessorPartitionClient +func (p *Processor) addPartitionClient(ctx context.Context, ownership Ownership, checkpoints map[string]Checkpoint, consumers *sync.Map) error { + processorPartClient := &ProcessorPartitionClient{ + consumerClientDetails: p.consumerClientDetails, + checkpointStore: p.checkpointStore, + innerClient: nil, + partitionID: ownership.PartitionID, + cleanupFn: func() { + consumers.Delete(ownership.PartitionID) + }, + } + + // RP: I don't want to accidentally end up doing this logic because the user was closing it as we + // were doing our next load balance. 
+ if _, alreadyExists := consumers.LoadOrStore(ownership.PartitionID, processorPartClient); alreadyExists { + return nil + } + + sp, err := p.getStartPosition(checkpoints, ownership) + + if err != nil { + return err + } + + partClient, err := p.consumerClient.NewPartitionClient(ownership.PartitionID, &PartitionClientOptions{ + StartPosition: sp, + OwnerLevel: processorOwnerLevel, + Prefetch: p.prefetch, + }) + + if err != nil { + consumers.Delete(ownership.PartitionID) + return err + } + + // make sure we create the link _now_ - if we're stealing we want to stake a claim _now_, rather than + // later when the user actually calls ReceiveEvents(), since the acquisition of the link is lazy. + if err := partClient.init(ctx); err != nil { + consumers.Delete(ownership.PartitionID) + _ = partClient.Close(ctx) + return err + } + + processorPartClient.innerClient = partClient + + select { + case p.nextClients <- processorPartClient: + return nil + default: + processorPartClient.Close(ctx) + return fmt.Errorf("partitions channel full, consumer for partition %s could not be returned", ownership.PartitionID) + } +} + +func (p *Processor) getStartPosition(checkpoints map[string]Checkpoint, ownership Ownership) (StartPosition, error) { + startPosition := p.defaultStartPositions.Default + cp, hasCheckpoint := checkpoints[ownership.PartitionID] + + if hasCheckpoint { + if cp.Offset != nil { + startPosition = StartPosition{ + Offset: cp.Offset, + } + } else if cp.SequenceNumber != nil { + startPosition = StartPosition{ + SequenceNumber: cp.SequenceNumber, + } + } else { + return StartPosition{}, fmt.Errorf("invalid checkpoint for %s, no offset or sequence number", ownership.PartitionID) + } + } else if p.defaultStartPositions.PerPartition != nil { + defaultStartPosition, exists := p.defaultStartPositions.PerPartition[ownership.PartitionID] + + if exists { + startPosition = defaultStartPosition + } + } + + return startPosition, nil +} + +func (p *Processor) getCheckpointsMap(ctx 
context.Context) (map[string]Checkpoint, error) { + details := p.consumerClient.getDetails() + checkpoints, err := p.checkpointStore.ListCheckpoints(ctx, details.FullyQualifiedNamespace, details.EventHubName, details.ConsumerGroup, nil) + + if err != nil { + return nil, err + } + + m := map[string]Checkpoint{} + + for _, cp := range checkpoints { + m[cp.PartitionID] = cp + } + + return m, nil +} + +func (p *Processor) close(ctx context.Context, consumersMap *sync.Map) { + consumersMap.Range(func(key, value any) bool { + client := value.(*ProcessorPartitionClient) + + if client != nil { + client.Close(ctx) + } + + return true + }) + + currentOwnerships := p.currentOwnerships.Load().([]Ownership) + + for i := 0; i < len(currentOwnerships); i++ { + currentOwnerships[i].OwnerID = relinquishedOwnershipID + } + + _, err := p.checkpointStore.ClaimOwnership(ctx, currentOwnerships, nil) + + if err != nil { + azlog.Writef(EventConsumer, "Failed to relinquish ownerships. New processors will have to wait for ownerships to expire: %s", err.Error()) + } + + p.stateMu.Lock() + p.state = stateStopped + p.stateMu.Unlock() + + // NextPartitionClient() will quit out now that p.nextClients is closed. + close(p.nextClients) + + select { + case <-p.nextClientsReady: + // already closed + default: + close(p.nextClientsReady) + } +} + +// relinquishedOwnershipID indicates that a partition is immediately available, similar to +// how we treat an ownership that is expired as available. +const relinquishedOwnershipID = "" diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_load_balancer.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_load_balancer.go new file mode 100644 index 00000000000..62ec59a88e4 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_load_balancer.go @@ -0,0 +1,302 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "fmt" + "math" + "math/rand" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" +) + +type processorLoadBalancer struct { + checkpointStore CheckpointStore + details consumerClientDetails + strategy ProcessorStrategy + partitionExpirationDuration time.Duration + + // NOTE: when you create your own *rand.Rand it is not thread safe. + rnd *rand.Rand +} + +func newProcessorLoadBalancer(checkpointStore CheckpointStore, details consumerClientDetails, strategy ProcessorStrategy, partitionExpiration time.Duration) *processorLoadBalancer { + return &processorLoadBalancer{ + checkpointStore: checkpointStore, + details: details, + strategy: strategy, + partitionExpirationDuration: partitionExpiration, + rnd: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +type loadBalancerInfo struct { + // current are the partitions that _we_ own + current []Ownership + + // unownedOrExpired partitions either had no claim _ever_ or were once + // owned but the ownership claim has expired. + unownedOrExpired []Ownership + + // aboveMax are ownerships where the specific owner has too many partitions + // it contains _all_ the partitions for that particular consumer. + aboveMax []Ownership + + // claimMorePartitions is true when we should try to claim more partitions + // because we're under the limit, or we're in a situation where we could claim + // one extra partition. + claimMorePartitions bool + + // maxAllowed is the maximum number of partitions that other processors are allowed + // to own during this round. It can change based on how many partitions we own and whether + // an 'extra' partition is allowed (ie, partitions %owners is not 0). Look at + // [processorLoadBalancer.getAvailablePartitions] for more details. + maxAllowed int + + raw []Ownership +} + +// loadBalance calls through to the user's configured load balancing algorithm. 
+// NOTE: this function is NOT thread safe! +func (lb *processorLoadBalancer) LoadBalance(ctx context.Context, partitionIDs []string) ([]Ownership, error) { + lbinfo, err := lb.getAvailablePartitions(ctx, partitionIDs) + + if err != nil { + return nil, err + } + + ownerships := lbinfo.current + + if lbinfo.claimMorePartitions { + switch lb.strategy { + case ProcessorStrategyGreedy: + log.Writef(EventConsumer, "[%s] Using greedy strategy to claim partitions", lb.details.ClientID) + ownerships = lb.greedyLoadBalancer(ctx, lbinfo) + case ProcessorStrategyBalanced: + log.Writef(EventConsumer, "[%s] Using balanced strategy to claim partitions", lb.details.ClientID) + o := lb.balancedLoadBalancer(ctx, lbinfo) + + if o != nil { + ownerships = append(lbinfo.current, *o) + } + default: + return nil, fmt.Errorf("[%s] invalid load balancing strategy '%s'", lb.details.ClientID, lb.strategy) + } + } + + actual, err := lb.checkpointStore.ClaimOwnership(ctx, ownerships, nil) + + if err != nil { + return nil, err + } + + if log.Should(EventConsumer) { + log.Writef(EventConsumer, "[%0.5s] Asked for %s, got %s", lb.details.ClientID, partitionsForOwnerships(ownerships), partitionsForOwnerships(actual)) + } + + return actual, nil +} + +func partitionsForOwnerships(all []Ownership) string { + var parts []string + + for _, o := range all { + parts = append(parts, o.PartitionID) + } + + return strings.Join(parts, ",") +} + +// getAvailablePartitions looks through the ownership list (using the checkpointstore.ListOwnership) and evaluates: +// - Whether we should claim more partitions +// - Which partitions are available - unowned/relinquished, expired or processors that own more than the maximum allowed. 
+// +// Load balancing happens in individual functions +func (lb *processorLoadBalancer) getAvailablePartitions(ctx context.Context, partitionIDs []string) (loadBalancerInfo, error) { + log.Writef(EventConsumer, "[%s] Listing ownership for %s/%s/%s", lb.details.ClientID, lb.details.FullyQualifiedNamespace, lb.details.EventHubName, lb.details.ConsumerGroup) + + ownerships, err := lb.checkpointStore.ListOwnership(ctx, lb.details.FullyQualifiedNamespace, lb.details.EventHubName, lb.details.ConsumerGroup, nil) + + if err != nil { + return loadBalancerInfo{}, err + } + + alreadyAdded := map[string]bool{} + groupedByOwner := map[string][]Ownership{ + lb.details.ClientID: nil, + } + + var unownedOrExpired []Ownership + + // split out partitions by whether they're currently owned + // and if they're expired/relinquished. + for _, o := range ownerships { + alreadyAdded[o.PartitionID] = true + + if time.Since(o.LastModifiedTime.UTC()) > lb.partitionExpirationDuration { + unownedOrExpired = append(unownedOrExpired, o) + continue + } + + if o.OwnerID == relinquishedOwnershipID { + unownedOrExpired = append(unownedOrExpired, o) + continue + } + + groupedByOwner[o.OwnerID] = append(groupedByOwner[o.OwnerID], o) + } + + numExpired := len(unownedOrExpired) + + // add in all the unowned partitions + for _, partID := range partitionIDs { + if alreadyAdded[partID] { + continue + } + + unownedOrExpired = append(unownedOrExpired, Ownership{ + FullyQualifiedNamespace: lb.details.FullyQualifiedNamespace, + ConsumerGroup: lb.details.ConsumerGroup, + EventHubName: lb.details.EventHubName, + PartitionID: partID, + OwnerID: lb.details.ClientID, + // note that we don't have etag info here since nobody has + // ever owned this partition. + }) + } + + minRequired := len(partitionIDs) / len(groupedByOwner) + maxAllowed := minRequired + allowExtraPartition := len(partitionIDs)%len(groupedByOwner) > 0 + + // only allow owners to keep extra partitions if we've already met our minimum bar. 
Otherwise + // above the minimum is fair game. + if allowExtraPartition && len(groupedByOwner[lb.details.ClientID]) >= minRequired { + maxAllowed += 1 + } + + var aboveMax []Ownership + + for id, ownerships := range groupedByOwner { + if id == lb.details.ClientID { + continue + } + + if len(ownerships) > maxAllowed { + aboveMax = append(aboveMax, ownerships...) + } + } + + claimMorePartitions := true + current := groupedByOwner[lb.details.ClientID] + + if len(current) >= maxAllowed { + // - I have _exactly_ the right amount + // or + // - I have too many. We expect to have some stolen from us, but we'll maintain + // ownership for now. + claimMorePartitions = false + } else if allowExtraPartition && len(current) == maxAllowed-1 { + // In the 'allowExtraPartition' scenario, some consumers will have an extra partition + // since things don't divide up evenly. We're one under the max, which means we _might_ + // be able to claim another one. + // + // We will attempt to grab _one_ more but only if there are free partitions available + // or if one of the consumers has more than the max allowed. + claimMorePartitions = len(unownedOrExpired) > 0 || len(aboveMax) > 0 + } + + log.Writef(EventConsumer, "[%s] claimMorePartitions: %t, owners: %d, current: %d, unowned: %d, expired: %d, above: %d", + lb.details.ClientID, + claimMorePartitions, + len(groupedByOwner), + len(current), + len(unownedOrExpired)-numExpired, + numExpired, + len(aboveMax)) + + return loadBalancerInfo{ + current: current, + unownedOrExpired: unownedOrExpired, + aboveMax: aboveMax, + claimMorePartitions: claimMorePartitions, + raw: ownerships, + maxAllowed: maxAllowed, + }, nil +} + +// greedyLoadBalancer will attempt to grab as many free partitions as it needs to balance +// in each round. 
+func (lb *processorLoadBalancer) greedyLoadBalancer(ctx context.Context, lbinfo loadBalancerInfo) []Ownership { + ours := lbinfo.current + + // try claiming from the completely unowned or expires ownerships _first_ + randomOwnerships := getRandomOwnerships(lb.rnd, lbinfo.unownedOrExpired, lbinfo.maxAllowed-len(ours)) + ours = append(ours, randomOwnerships...) + + if len(ours) < lbinfo.maxAllowed { + log.Writef(EventConsumer, "Not enough expired or unowned partitions, will need to steal from other processors") + + // if that's not enough then we'll randomly steal from any owners that had partitions + // above the maximum. + randomOwnerships := getRandomOwnerships(lb.rnd, lbinfo.aboveMax, lbinfo.maxAllowed-len(ours)) + ours = append(ours, randomOwnerships...) + } + + for i := 0; i < len(ours); i++ { + ours[i] = lb.resetOwnership(ours[i]) + } + + return ours +} + +// balancedLoadBalancer attempts to split the partition load out between the available +// consumers so each one has an even amount (or even + 1, if the # of consumers and # +// of partitions doesn't divide evenly). +// +// NOTE: the checkpoint store itself does not have a concept of 'presence' that doesn't +// ALSO involve owning a partition. It's possible for a consumer to get boxed out for a +// bit until it manages to steal at least one partition since the other consumers don't +// know it exists until then. 
+func (lb *processorLoadBalancer) balancedLoadBalancer(ctx context.Context, lbinfo loadBalancerInfo) *Ownership { + if len(lbinfo.unownedOrExpired) > 0 { + idx := lb.rnd.Intn(len(lbinfo.unownedOrExpired)) + o := lb.resetOwnership(lbinfo.unownedOrExpired[idx]) + return &o + } + + if len(lbinfo.aboveMax) > 0 { + idx := lb.rnd.Intn(len(lbinfo.aboveMax)) + o := lb.resetOwnership(lbinfo.aboveMax[idx]) + return &o + } + + return nil +} + +func (lb *processorLoadBalancer) resetOwnership(o Ownership) Ownership { + o.OwnerID = lb.details.ClientID + return o +} + +func getRandomOwnerships(rnd *rand.Rand, ownerships []Ownership, count int) []Ownership { + limit := int(math.Min(float64(count), float64(len(ownerships)))) + + if limit == 0 { + return nil + } + + choices := rnd.Perm(limit) + + var newOwnerships []Ownership + + for i := 0; i < len(choices); i++ { + newOwnerships = append(newOwnerships, ownerships[choices[i]]) + } + + return newOwnerships +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_partition_client.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_partition_client.go new file mode 100644 index 00000000000..cc52c533da5 --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/processor_partition_client.go @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import "context" + +// ProcessorPartitionClient allows you to receive events, similar to a [PartitionClient], with a +// checkpoint store for tracking progress. +// +// This type is instantiated from [Processor.NextPartitionClient], which handles load balancing +// of partition ownership between multiple [Processor] instances. +// +// See [example_consuming_with_checkpoints_test.go] for an example. 
+// +// NOTE: If you do NOT want to use dynamic load balancing, and would prefer to track state and ownership +// manually, use the [ConsumerClient] instead. +// +// [example_consuming_with_checkpoints_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_consuming_with_checkpoints_test.go +type ProcessorPartitionClient struct { + partitionID string + innerClient *PartitionClient + checkpointStore CheckpointStore + cleanupFn func() + consumerClientDetails consumerClientDetails +} + +// ReceiveEvents receives events until 'count' events have been received or the context +// has been cancelled. +// +// See [PartitionClient.ReceiveEvents] for more information, including troubleshooting. +func (c *ProcessorPartitionClient) ReceiveEvents(ctx context.Context, count int, options *ReceiveEventsOptions) ([]*ReceivedEventData, error) { + return c.innerClient.ReceiveEvents(ctx, count, options) +} + +// UpdateCheckpoint updates the checkpoint in the CheckpointStore. New Processors will resume after +// this checkpoint for this partition. +func (p *ProcessorPartitionClient) UpdateCheckpoint(ctx context.Context, latestEvent *ReceivedEventData, options *UpdateCheckpointOptions) error { + seq := latestEvent.SequenceNumber + offset := latestEvent.Offset + + return p.checkpointStore.SetCheckpoint(ctx, Checkpoint{ + ConsumerGroup: p.consumerClientDetails.ConsumerGroup, + EventHubName: p.consumerClientDetails.EventHubName, + FullyQualifiedNamespace: p.consumerClientDetails.FullyQualifiedNamespace, + PartitionID: p.partitionID, + SequenceNumber: &seq, + Offset: &offset, + }, nil) +} + +// PartitionID is the partition ID of the partition we're receiving from. +// This will not change during the lifetime of this ProcessorPartitionClient. +func (p *ProcessorPartitionClient) PartitionID() string { + return p.partitionID +} + +// Close releases resources for the partition client. 
+// This does not close the ConsumerClient that the Processor was started with. +func (c *ProcessorPartitionClient) Close(ctx context.Context) error { + c.cleanupFn() + + if c.innerClient != nil { + return c.innerClient.Close(ctx) + } + + return nil +} + +// UpdateCheckpointOptions contains optional parameters for the [ProcessorPartitionClient.UpdateCheckpoint] function. +type UpdateCheckpointOptions struct { + // For future expansion +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/producer_client.go b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/producer_client.go new file mode 100644 index 00000000000..56e5c9d953a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/producer_client.go @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azeventhubs + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" + "github.com/Azure/go-amqp" +) + +// WebSocketConnParams are passed to your web socket creation function (ClientOptions.NewWebSocketConn) +type WebSocketConnParams = exported.WebSocketConnParams + +// RetryOptions represent the options for retries. +type RetryOptions = exported.RetryOptions + +// ProducerClientOptions contains options for the `NewProducerClient` and `NewProducerClientFromConnectionString` +// functions. +type ProducerClientOptions struct { + // Application ID that will be passed to the namespace. 
+ ApplicationID string + + // NewWebSocketConn is a function that can create a net.Conn for use with websockets. + // For an example, see ExampleNewClient_usingWebsockets() function in example_client_test.go. + NewWebSocketConn func(ctx context.Context, params WebSocketConnParams) (net.Conn, error) + + // RetryOptions controls how often operations are retried from this client and any + // Receivers and Senders created from this client. + RetryOptions RetryOptions + + // TLSConfig configures a client with a custom *tls.Config. + TLSConfig *tls.Config +} + +// ProducerClient can be used to send events to an Event Hub. +type ProducerClient struct { + eventHub string + links *internal.Links[amqpwrap.AMQPSenderCloser] + namespace internal.NamespaceForProducerOrConsumer + retryOptions RetryOptions +} + +// anyPartitionID is what we target if we want to send a message and let Event Hubs pick a partition +// or if we're doing an operation that isn't partition specific, such as querying the management link +// to get event hub properties or partition properties. +const anyPartitionID = "" + +// NewProducerClient creates a ProducerClient which uses an azcore.TokenCredential for authentication. You +// MUST call [ProducerClient.Close] on this client to avoid leaking resources. +// +// The fullyQualifiedNamespace is the Event Hubs namespace name (ex: myeventhub.servicebus.windows.net) +// The credential is one of the credentials in the [azidentity] package. +// +// [azidentity]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/azidentity +func NewProducerClient(fullyQualifiedNamespace string, eventHub string, credential azcore.TokenCredential, options *ProducerClientOptions) (*ProducerClient, error) { + return newProducerClientImpl(producerClientCreds{ + fullyQualifiedNamespace: fullyQualifiedNamespace, + credential: credential, + eventHub: eventHub, + }, options) +} + +// NewProducerClientFromConnectionString creates a ProducerClient from a connection string. 
You +// MUST call [ProducerClient.Close] on this client to avoid leaking resources. +// +// connectionString can be one of two formats - with or without an EntityPath key. +// +// When the connection string does not have an entity path, as shown below, the eventHub parameter cannot +// be empty and should contain the name of your event hub. +// +// Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey= +// +// When the connection string DOES have an entity path, as shown below, the eventHub parameter must be empty. +// +// Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey=;EntityPath=; +func NewProducerClientFromConnectionString(connectionString string, eventHub string, options *ProducerClientOptions) (*ProducerClient, error) { + props, err := parseConn(connectionString, eventHub) + + if err != nil { + return nil, err + } + + return newProducerClientImpl(producerClientCreds{ + connectionString: connectionString, + eventHub: *props.EntityPath, + }, options) +} + +// EventDataBatchOptions contains optional parameters for the [ProducerClient.NewEventDataBatch] function. +// +// If both PartitionKey and PartitionID are nil, Event Hubs will choose an arbitrary partition +// for any events in this [EventDataBatch]. +type EventDataBatchOptions struct { + // MaxBytes overrides the max size (in bytes) for a batch. + // By default NewEventDataBatch will use the max message size provided by the service. + MaxBytes uint64 + + // PartitionKey is hashed to calculate the partition assignment. Messages and message + // batches with the same PartitionKey are guaranteed to end up in the same partition. + // Note that if you use this option then PartitionID cannot be set. + PartitionKey *string + + // PartitionID is the ID of the partition to send these messages to. + // Note that if you use this option then PartitionKey cannot be set. 
+ PartitionID *string +} + +// NewEventDataBatch can be used to create an EventDataBatch, which can contain multiple +// events. +// +// EventDataBatch contains logic to make sure that the it doesn't exceed the maximum size +// for the Event Hubs link, using it's [azeventhubs.EventDataBatch.AddEventData] function. +// A lower size limit can also be configured through the options. +// +// NOTE: if options is nil or empty, Event Hubs will choose an arbitrary partition for any +// events in this [EventDataBatch]. +// +// If the operation fails it can return an azeventhubs.Error type if the failure is actionable. +func (pc *ProducerClient) NewEventDataBatch(ctx context.Context, options *EventDataBatchOptions) (*EventDataBatch, error) { + var batch *EventDataBatch + + partitionID := anyPartitionID + + if options != nil && options.PartitionID != nil { + partitionID = *options.PartitionID + } + + err := pc.links.Retry(ctx, exported.EventProducer, "NewEventDataBatch", partitionID, pc.retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.AMQPSenderCloser]) error { + tmpBatch, err := newEventDataBatch(lwid.Link(), options) + + if err != nil { + return err + } + + batch = tmpBatch + return nil + }) + + if err != nil { + return nil, internal.TransformError(err) + } + + return batch, nil +} + +// SendEventDataBatchOptions contains optional parameters for the SendEventDataBatch function +type SendEventDataBatchOptions struct { + // For future expansion +} + +// SendEventDataBatch sends an event data batch to Event Hubs. 
+func (pc *ProducerClient) SendEventDataBatch(ctx context.Context, batch *EventDataBatch, options *SendEventDataBatchOptions) error { + amqpMessage, err := batch.toAMQPMessage() + + if err != nil { + return err + } + + partID := getPartitionID(batch.partitionID) + + err = pc.links.Retry(ctx, exported.EventProducer, "SendEventDataBatch", partID, pc.retryOptions, func(ctx context.Context, lwid internal.LinkWithID[amqpwrap.AMQPSenderCloser]) error { + azlog.Writef(EventProducer, "[%s] Sending message with ID %v to partition %q", lwid.String(), amqpMessage.Properties.MessageID, partID) + return lwid.Link().Send(ctx, amqpMessage, nil) + }) + return internal.TransformError(err) +} + +// GetPartitionProperties gets properties for a specific partition. This includes data like the last enqueued sequence number, the first sequence +// number and when an event was last enqueued to the partition. +func (pc *ProducerClient) GetPartitionProperties(ctx context.Context, partitionID string, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { + return getPartitionProperties(ctx, EventProducer, pc.namespace, pc.links, pc.eventHub, partitionID, pc.retryOptions, options) +} + +// GetEventHubProperties gets event hub properties, like the available partition IDs and when the Event Hub was created. +func (pc *ProducerClient) GetEventHubProperties(ctx context.Context, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { + return getEventHubProperties(ctx, EventProducer, pc.namespace, pc.links, pc.eventHub, pc.retryOptions, options) +} + +// Close releases resources for this client. 
+func (pc *ProducerClient) Close(ctx context.Context) error { + if err := pc.links.Close(ctx); err != nil { + azlog.Writef(EventProducer, "Failed when closing links while shutting down producer client: %s", err.Error()) + } + return pc.namespace.Close(ctx, true) +} + +func (pc *ProducerClient) getEntityPath(partitionID string) string { + if partitionID != anyPartitionID { + return fmt.Sprintf("%s/Partitions/%s", pc.eventHub, partitionID) + } else { + // this is the "let Event Hubs" decide link - any sends that occur here will + // end up getting distributed to different partitions on the service side, rather + // then being specified in the client. + return pc.eventHub + } +} + +func (pc *ProducerClient) newEventHubProducerLink(ctx context.Context, session amqpwrap.AMQPSession, entityPath string, partitionID string) (amqpwrap.AMQPSenderCloser, error) { + sender, err := session.NewSender(ctx, entityPath, partitionID, &amqp.SenderOptions{ + SettlementMode: to.Ptr(amqp.SenderSettleModeMixed), + RequestedReceiverSettleMode: to.Ptr(amqp.ReceiverSettleModeFirst), + }) + + if err != nil { + return nil, err + } + + return sender, nil +} + +type producerClientCreds struct { + connectionString string + + // the Event Hubs namespace name (ex: myservicebus.servicebus.windows.net) + fullyQualifiedNamespace string + credential azcore.TokenCredential + + eventHub string +} + +func newProducerClientImpl(creds producerClientCreds, options *ProducerClientOptions) (*ProducerClient, error) { + client := &ProducerClient{ + eventHub: creds.eventHub, + } + + var nsOptions []internal.NamespaceOption + + if creds.connectionString != "" { + nsOptions = append(nsOptions, internal.NamespaceWithConnectionString(creds.connectionString)) + } else if creds.credential != nil { + option := internal.NamespaceWithTokenCredential( + creds.fullyQualifiedNamespace, + creds.credential) + + nsOptions = append(nsOptions, option) + } + + if options != nil { + client.retryOptions = options.RetryOptions + + 
if options.TLSConfig != nil { + nsOptions = append(nsOptions, internal.NamespaceWithTLSConfig(options.TLSConfig)) + } + + if options.NewWebSocketConn != nil { + nsOptions = append(nsOptions, internal.NamespaceWithWebSocket(options.NewWebSocketConn)) + } + + if options.ApplicationID != "" { + nsOptions = append(nsOptions, internal.NamespaceWithUserAgent(options.ApplicationID)) + } + + nsOptions = append(nsOptions, internal.NamespaceWithRetryOptions(options.RetryOptions)) + } + + tmpNS, err := internal.NewNamespace(nsOptions...) + + if err != nil { + return nil, err + } + + client.namespace = tmpNS + + client.links = internal.NewLinks(tmpNS, fmt.Sprintf("%s/$management", client.eventHub), client.getEntityPath, client.newEventHubProducerLink) + + return client, err +} + +// parseConn parses the connection string and ensures that the returned [exported.ConnectionStringProperties] +// has an EntityPath set, either from the connection string or using the eventHub parameter. +// +// If the connection string has an EntityPath then eventHub must be empty. +// If the connection string does not have an entity path then the eventHub must contain a value. +func parseConn(connectionString string, eventHub string) (exported.ConnectionStringProperties, error) { + props, err := exported.ParseConnectionString(connectionString) + + if err != nil { + return exported.ConnectionStringProperties{}, err + } + + if props.EntityPath == nil { + if eventHub == "" { + return exported.ConnectionStringProperties{}, errors.New("connection string does not contain an EntityPath. eventHub cannot be an empty string") + } + props.EntityPath = &eventHub + } else { + if eventHub != "" { + return exported.ConnectionStringProperties{}, errors.New("connection string contains an EntityPath. 
eventHub must be an empty string") + } + } + + return props, nil +} + +func getPartitionID(partitionID *string) string { + if partitionID != nil { + return *partitionID + } + + return anyPartitionID +} diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/sample.env b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/sample.env new file mode 100644 index 00000000000..f8687bcf89a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/sample.env @@ -0,0 +1,20 @@ +# These are environment variables you'll need to run tests and +# samples in this package. +# NOTE: Rename this file to .env before running any tests. + +# The connection string for your event hub: +# Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=;SharedAccessKey=;EntityPath= +EVENTHUB_CONNECTION_STRING=event-hub-connection-string + +# Your Event Hub namespace: +# .servicebus.windows.net> +EVENTHUB_NAMESPACE=event-hub-namespace + +# The name of the event hub, within your Event Hub namespace +EVENTHUB_NAME=event-hub-name + +# Checkpoint store information + +# Azure storage account connection string +# DefaultEndpointsProtocol=https;AccountName=;AccountKey=;EndpointSuffix=core.windows.net +CHECKPOINTSTORE_STORAGE_CONNECTION_STRING=storage-connection-string diff --git a/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/test-resources.bicep b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/test-resources.bicep new file mode 100644 index 00000000000..1310fee901a --- /dev/null +++ b/vendor/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/test-resources.bicep @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +@description('The base resource name.') +param baseName string = resourceGroup().name + +#disable-next-line no-hardcoded-env-urls // it's flagging the help string. +@description('Storage endpoint suffix. 
The default value uses Azure Public Cloud (ie: core.windows.net)') +param storageEndpointSuffix string = environment().suffixes.storage + +@description('The resource location') +param location string = resourceGroup().location + +var apiVersion = '2017-04-01' +var storageApiVersion = '2019-04-01' +var namespaceName = baseName +var storageAccountName = 'storage${baseName}' +var containerName = 'container' +var iotName = 'iot${baseName}' +var authorizationName = '${baseName}/RootManageSharedAccessKey' + +resource namespace 'Microsoft.EventHub/namespaces@2017-04-01' = { + name: namespaceName + location: location + sku: { + name: 'Standard' + tier: 'Standard' + capacity: 5 + } + properties: { + isAutoInflateEnabled: false + maximumThroughputUnits: 0 + } +} + +resource authorization 'Microsoft.EventHub/namespaces/AuthorizationRules@2017-04-01' = { + name: authorizationName + properties: { + rights: [ + 'Listen' + 'Manage' + 'Send' + ] + } + dependsOn: [ + namespace + ] +} + +resource authorizedListenOnly 'Microsoft.EventHub/namespaces/AuthorizationRules@2017-04-01' = { + name: 'ListenOnly' + parent: namespace + properties: { + rights: [ + 'Listen' + ] + } +} + +resource authorizedSendOnly 'Microsoft.EventHub/namespaces/AuthorizationRules@2017-04-01' = { + name: 'SendOnly' + parent: namespace + properties: { + rights: [ + 'Send' + ] + } +} + +resource eventHub 'Microsoft.EventHub/namespaces/eventhubs@2017-04-01' = { + name: 'eventhub' + properties: { + messageRetentionInDays: 1 + partitionCount: 4 + } + parent: namespace +} + +resource linksonly 'Microsoft.EventHub/namespaces/eventhubs@2017-04-01' = { + name: 'linksonly' + properties: { + messageRetentionInDays: 1 + partitionCount: 1 + } + parent: namespace +} + +resource namespaceName_default 'Microsoft.EventHub/namespaces/networkRuleSets@2017-04-01' = { + name: 'default' + parent: namespace + properties: { + defaultAction: 'Deny' + virtualNetworkRules: [] + ipRules: [] + } +} + +resource eventHubNameFull_Default 
'Microsoft.EventHub/namespaces/eventhubs/consumergroups@2017-04-01' = { + name: '$Default' + properties: {} + parent: eventHub +} + +resource storageAccount 'Microsoft.Storage/storageAccounts@2019-04-01' = { + name: storageAccountName + location: location + sku: { + name: 'Standard_RAGRS' + } + kind: 'StorageV2' + properties: { + networkAcls: { + bypass: 'AzureServices' + virtualNetworkRules: [] + ipRules: [] + defaultAction: 'Allow' + } + supportsHttpsTrafficOnly: true + encryption: { + services: { + file: { + enabled: true + } + blob: { + enabled: true + } + } + keySource: 'Microsoft.Storage' + } + accessTier: 'Hot' + } +} + +resource storageAccountName_default_container 'Microsoft.Storage/storageAccounts/blobServices/containers@2019-04-01' = { + name: '${storageAccountName}/default/${containerName}' + dependsOn: [ + storageAccount + ] +} + +resource iot 'Microsoft.Devices/IotHubs@2018-04-01' = { + name: iotName + location: location + sku: { + name: 'S1' + capacity: 1 + } + properties: { + ipFilterRules: [] + eventHubEndpoints: { + events: { + retentionTimeInDays: 1 + partitionCount: 4 + } + } + routing: { + endpoints: { + serviceBusQueues: [] + serviceBusTopics: [] + eventHubs: [] + storageContainers: [] + } + routes: [] + fallbackRoute: { + name: '$fallback' + source: 'DeviceMessages' + condition: 'true' + endpointNames: [ + 'events' + ] + isEnabled: true + } + } + storageEndpoints: { + '$default': { + sasTtlAsIso8601: 'PT1H' + connectionString: 'DefaultEndpointsProtocol=https;AccountName=${storageAccountName};AccountKey=${listKeys(storageAccount.id, storageApiVersion).keys[0].value};EndpointSuffix=${storageEndpointSuffix}' + containerName: containerName + } + } + messagingEndpoints: { + fileNotifications: { + lockDurationAsIso8601: 'PT1M' + ttlAsIso8601: 'PT1H' + maxDeliveryCount: 10 + } + } + enableFileUploadNotifications: false + cloudToDevice: { + maxDeliveryCount: 10 + defaultTtlAsIso8601: 'PT1H' + feedback: { + lockDurationAsIso8601: 'PT1M' + 
ttlAsIso8601: 'PT1H' + maxDeliveryCount: 10 + } + } + features: 'None' + } +} +output IOTHUB_CONNECTION_STRING string = 'HostName=${reference(iot.id, providers('Microsoft.Devices', 'IoTHubs').apiVersions[0]).hostName};SharedAccessKeyName=iothubowner;SharedAccessKey=${listKeys(iot.id, providers('Microsoft.Devices', 'IoTHubs').apiVersions[0]).value[0].primaryKey}' + +// used for TokenCredential tests +output EVENTHUB_NAMESPACE string = '${namespace.name}.servicebus.windows.net' +output CHECKPOINTSTORE_STORAGE_ENDPOINT string = storageAccount.properties.primaryEndpoints.blob +output EVENTHUB_NAME string = eventHub.name +output EVENTHUB_LINKSONLY_NAME string = linksonly.name + +// connection strings +output EVENTHUB_CONNECTION_STRING string = listKeys( + resourceId('Microsoft.EventHub/namespaces/authorizationRules', namespaceName, 'RootManageSharedAccessKey'), + apiVersion +).primaryConnectionString +output EVENTHUB_CONNECTION_STRING_LISTEN_ONLY string = listKeys( + resourceId('Microsoft.EventHub/namespaces/authorizationRules', namespaceName, authorizedListenOnly.name), + apiVersion +).primaryConnectionString +output EVENTHUB_CONNECTION_STRING_SEND_ONLY string = listKeys( + resourceId('Microsoft.EventHub/namespaces/authorizationRules', namespaceName, authorizedSendOnly.name), + apiVersion +).primaryConnectionString +output CHECKPOINTSTORE_STORAGE_CONNECTION_STRING string = 'DefaultEndpointsProtocol=https;AccountName=${storageAccountName};AccountKey=${listKeys(storageAccount.id, storageApiVersion).keys[0].value};EndpointSuffix=${storageEndpointSuffix}' + +output RESOURCE_GROUP string = resourceGroup().name +output AZURE_SUBSCRIPTION_ID string = subscription().subscriptionId diff --git a/vendor/github.com/Azure/go-amqp/.gitattributes b/vendor/github.com/Azure/go-amqp/.gitattributes new file mode 100644 index 00000000000..854069f01fe --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/.gitattributes @@ -0,0 +1,3 @@ +# Binary files (no line-ending conversions), diff using 
hexdump +*.bin binary diff=hex + diff --git a/vendor/github.com/Azure/go-amqp/.gitignore b/vendor/github.com/Azure/go-amqp/.gitignore new file mode 100644 index 00000000000..7241bede312 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/.gitignore @@ -0,0 +1,12 @@ +amqp.test +/fuzz/*/* +!/fuzz/*/corpus +/fuzz/*.zip +*.log +/cmd +cover.out +.envrc +recordings +.vscode +.idea +*.env \ No newline at end of file diff --git a/vendor/github.com/Azure/go-amqp/CHANGELOG.md b/vendor/github.com/Azure/go-amqp/CHANGELOG.md new file mode 100644 index 00000000000..51bc9ab57ae --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/CHANGELOG.md @@ -0,0 +1,174 @@ +# Release History + +## 1.0.5 (2024-03-04) + +### Bugs Fixed + +* Fixed an issue that could cause delays when parsing small frames. + +## 1.0.4 (2024-01-16) + +### Other Changes + +* A `Receiver`'s unsettled messages are tracked as a count (currently used for diagnostic purposes only). + +## 1.0.3 (2024-01-09) + +### Bugs Fixed + +* Fixed an issue that could cause a memory leak when settling messages across `Receiver` instances. + +## 1.0.2 (2023-09-05) + +### Bugs Fixed + +* Fixed an issue that could cause frames to be sent even when the provided `context.Context` was cancelled. +* Fixed a potential hang in `Sender.Send()` that could happen in rare circumstances. +* Ensure that `Sender`'s delivery count and link credit are updated when a transfer fails to send due to context cancellation/timeout. + +## 1.0.1 (2023-06-08) + +### Bugs Fixed + +* Fixed an issue that could cause links to terminate with error "received disposition frame with unknown link handle X". + +## 1.0.0 (2023-05-04) + +### Features Added + +* Added `ConnOptions.WriteTimeout` to control the write deadline when writing to `net.Conn`. + +### Bugs Fixed + +* Calling `Dial()` with a cancelled context doesn't create a connection. +* Context cancellation is properly honored in calls to `Dial()` and `NewConn()`. +* Fixed potential race during `Conn.Close()`. 
+* Disable sending frames when closing `Session`, `Sender`, and `Receiver`. +* Don't leak in-flight messages when a message settlement API is cancelled or times out waiting for acknowledgement. +* `Sender.Send()` will return an `*amqp.Error` with condition `amqp.ErrCondTransferLimitExceeded` when attempting to send a transfer on a link with no credit. +* `Sender.Send()` will return an `*amqp.Error` with condition `amqp.ErrCondMessageSizeExceeded` if the message or delivery tag size exceeds the maximum allowed size for the link. + +### Other Changes + +* Debug logging includes the address of the object that's writing a log entry. +* Context expiration or cancellation when creating instances of `Session`, `Receiver`, and `Sender` no longer result in the potential for `Conn` to unexpectedly terminate. +* Session channel and link handle exhaustion will now return `*ConnError` and `*SessionError` respectively, closing the respective `Conn` or `Session`. +* If a `context.Context` contains a deadline/timeout, that value will be used as the write deadline when writing to `net.Conn`. + +## 0.19.1 (2023-03-31) + +### Bugs Fixed + +* Fixed a race closing a `Session`, `Receiver`, or `Sender` in succession when the first attempt times out. +* Check the `LinkError.RemoteErr` field when determining if a link was cleanly closed. + +## 0.19.0 (2023-03-30) + +### Breaking Changes + +* `Dial()` and `NewConn()` now require a `context.Context` as their first parameter. + * As a result, the `ConnOptions.Timeout` field has been removed. +* Methods `Sender.Send()` and `Receiver.Receive()` now take their respective options-type as the final argument. +* The `ManualCredits` field in `ReceiverOptions` has been consolidated into field `Credit`. +* Renamed fields in the `ReceiverOptions` for configuring options on the source. +* Renamed `DetachError` to `LinkError` as "detach" has a specific meaning which doesn't equate to the returned link errors. 
+* The `Receiver.DrainCredit()` API has been removed. +* Removed fields `Batching` and `BatchMaxAge` in `ReceiverOptions`. +* The `IncomingWindow` and `OutgoingWindow` fields in `SessionOptions` have been removed. +* The field `SenderOptions.IgnoreDispositionErrors` has been removed. + * By default, messages that are rejected by the peer no longer close the `Sender`. +* The field `SendSettled` in type `Message` has been moved to type `SendOptions` and renamed as `Settled`. +* The following type aliases have been removed. + * `Address`, `Binary`, `MessageID`, `SequenceNumber`, `Symbol` +* Method `Message.LinkName()` has been removed. + +### Bugs Fixed + +* Don't discard incoming frames while closing a Session. +* Client-side termination of a Session due to invalid state will wait for the peer to acknowledge the Session's end. +* Fixed an issue that could cause `creditor.Drain()` to return the wrong error when a link is terminated. +* Ensure that `Receiver.Receive()` drains prefetched messages when the link closed. +* Fixed an issue that could cause closing a `Receiver` to hang under certain circumstances. +* In `Receiver.Drain()`, wake up `Receiver.mux()` after the drain bit has been set. + +### Other Changes + +* Debug logging has been cleaned up to reduce the number of redundant entries and consolidate the entry format. + * DEBUG_LEVEL 1 now captures all sent/received frames along with basic flow control information. + * Higher debug levels add entries when a frame transitions across mux boundaries and other diagnostics info. +* Document default values for incoming and outgoing windows. +* Refactored handling of incoming frames to eliminate potential deadlocks due to "mux pumping". +* Disallow sending of frames once the end performative has been sent. +* Clean up client-side state when a `context.Context` expires or is cancelled and document the potential side-effects. +* Unexpected frames will now terminate a `Session`, `Receiver`, or `Sender` as required. 
+* Cleaned up tests that triggered the race detector. + +## 0.18.1 (2023-01-17) + +### Bugs Fixed + +* Fixed an issue that could cause `Conn.connReader()` to become blocked in rare circumstances. +* Fixed an issue that could cause outgoing transfers to be rejected by some brokers due to out-of-sequence delivery IDs. +* Fixed an issue that could cause senders and receivers within the same session to deadlock if the receiver was configured with `ReceiverSettleModeFirst`. +* Enabled support for senders in an at-most-once configuration. + +### Other Changes + +* The connection mux goroutine has been removed, eliminating a potential source of deadlocks. +* Automatic link flow control is built on the manual creditor. +* Clarified docs that messages received from a sender configured in a mode other than `SenderSettleModeSettled` must be acknowledged. +* Clarified default value for `Conn.IdleTimeout` and removed unit prefix. + +## 0.18.0 (2022-12-06) + +### Features Added +* Added `ConnError` type that's returned when a connection is no longer functional. +* Added `SessionError` type that's returned when a session has been closed. +* Added `SASLType` used when configuring the SASL authentication mechanism. +* Added `Ptr()` method to `SenderSettleMode` and `ReceiverSettleMode` types. + +### Breaking Changes +* The minimum version of Go required to build this module is now 1.18. +* The type `Client` has been renamed to `Conn`, and its constructor `New()` renamed to `NewConn()`. +* Removed `ErrConnClosed`, `ErrSessionClosed`, `ErrLinkClosed`, and `ErrTimeout` sentinel error types. +* The following methods now require a `context.Context` as their first parameter. + * `Conn.NewSession()`, `Session.NewReceiver()`, `Session.NewSender()` +* Removed `context.Context` parameter and `error` return from method `Receiver.Prefetched()`. +* The following type names had the prefix `AMQP` removed to prevent stuttering. 
+ * `AMQPAddress`, `AMQPMessageID`, `AMQPSymbol`, `AMQPSequenceNumber`, `AMQPBinary` +* Various `Default*` constants are no longer exported. +* The args to `Receiver.ModifyMessage()` have changed. +* The "variadic config" pattern for `Conn`, `Session`, `Sender`, and `Receiver` constructors has been replaced with a struct-based config. + * This removes the `ConnOption`, `SessionOption`, and `LinkOption` types and all of the associated configuration funcs. + * The sender and receiver specific link options have been moved into their respective options types. + * The `ConnTLS()` option was removed as part of this change. +* The `Dial()` and `New()` constructors now require an `*ConnOptions` parameter. +* `Conn.NewSession()` now requires a `*SessionOptions` parameter. +* `Session.NewSender()` now requires `target` address and `*SenderOptions` parameters. +* `Session.NewReceiver()` now requires `source` address and `*ReceiverOptions` parameters. +* The various SASL configuration funcs have been slightly renamed. +* The following constant types had their values renamed in accordance with the SDK design guidelines. + * `SenderSettleMode`, `ReceiverSettleMode`, `ExpiryPolicy` +* Constant type `ErrorCondition` has been renamed to `ErrCond`. + * The `ErrCond` values have had their names updated to include the `ErrCond` prefix. +* `LinkFilterSource` and `LinkFilterSelector` have been renamed to `NewLinkFilter` and `NewSelectorFilter` respectively. +* The `RemoteError` field in `DetachError` has been renamed. + +### Bugs Fixed +* Fixed potential panic in `muxHandleFrame()` when checking for manual creditor. +* Fixed potential panic in `attachLink()` when copying source filters. +* `NewConn()` will no longer return a broken `*Conn` in some instances. +* Incoming transfer frames received during initial link detach are no longer discarded. +* Session will no longer flood peer with flow frames when half its incoming window is consumed. 
+* Newly created `Session` won't leak if the context passed to `Conn.NewSession()` expires before exit. +* Newly created `link` won't leak if the context passed to `link.attach()` expires before exit. +* Fixed an issue causing dispositions to hang indefinitely with batching enabled when the receiver link is detached. + +### Other Changes +* Errors when reading/writing to the underlying `net.Conn` are now wrapped in a `ConnError` type. +* Disambiguate error message for distinct cases where a session wasn't found for the specified remote channel. +* Removed `link.Paused` as it didn't add much value and was broken in some cases. +* Only send one flow frame when a drain has been requested. +* Session window size increased to 5000. +* Creation and deletion of `Session` instances have been made deterministic. +* Allocation and deallocation of link handles has been made deterministic. diff --git a/vendor/github.com/Azure/go-amqp/CODE_OF_CONDUCT.md b/vendor/github.com/Azure/go-amqp/CODE_OF_CONDUCT.md new file mode 100644 index 00000000000..c72a5749c52 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/vendor/github.com/Azure/go-amqp/CONTRIBUTING.md b/vendor/github.com/Azure/go-amqp/CONTRIBUTING.md new file mode 100644 index 00000000000..8275c5bca55 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/CONTRIBUTING.md @@ -0,0 +1,76 @@ +# Azure/go-amqp Contributing Guide + +Thank you for your interest in contributing to go-amqp. 
+ +- For reporting bugs, requesting features, or asking for support, please file an issue in the [issues](https://github.com/Azure/go-amqp/issues) section of the project. + +- If you would like to become an active contributor to this project please follow the instructions provided in [Microsoft Azure Projects Contribution Guidelines](https://azure.github.io/azure-sdk/policies_opensource.html). + +- To make code changes, or contribute something new, please follow the [GitHub Forks / Pull requests model](https://help.github.com/articles/fork-a-repo/): Fork the repo, make the change and propose it back by submitting a pull request. + +## Pull Requests + +- **DO** follow the API design and implementation [Go Guidelines](https://azure.github.io/azure-sdk/golang_introduction.html). + - When submitting large changes or features, **DO** have an issue or spec doc that describes the design, usage, and motivating scenario. +- **DO** submit all code changes via pull requests (PRs) rather than through a direct commit. PRs will be reviewed and potentially merged by the repo maintainers after a peer review that includes at least one maintainer. +- **DO** review your own PR to make sure there are no unintended changes or commits before submitting it. +- **DO NOT** submit "work in progress" PRs. A PR should only be submitted when it is considered ready for review and subsequent merging by the contributor. + - If the change is work-in-progress or an experiment, **DO** start off as a temporary draft PR. +- **DO** give PRs short-but-descriptive names (e.g. "Improve code coverage for sender by 10%", not "Fix #1234") and add a description which explains why the change is being made. +- **DO** refer to any relevant issues, and include [keywords](https://help.github.com/articles/closing-issues-via-commit-messages/) that automatically close issues when the PR is merged. +- **DO** tag any users that should know about and/or review the change. 
+- **DO** ensure each commit successfully builds. The entire PR must pass all tests in the Continuous Integration (CI) system before it'll be merged. +- **DO** address PR feedback in an additional commit(s) rather than amending the existing commits, and only rebase/squash them when necessary. This makes it easier for reviewers to track changes. +- **DO** assume that ["Squash and Merge"](https://github.com/blog/2141-squash-your-commits) will be used to merge your commit unless you request otherwise in the PR. +- **DO NOT** mix independent, unrelated changes in one PR. Separate real product/test code changes from larger code formatting/dead code removal changes. Separate unrelated fixes into separate PRs, especially if they are in different modules or files that otherwise wouldn't be changed. +- **DO** comment your code focusing on "why", where necessary. Otherwise, aim to keep it self-documenting with appropriate names and style. +- **DO** add [GoDoc style comments](https://azure.github.io/azure-sdk/golang_introduction.html#documentation-style) when adding new APIs or modifying header files. +- **DO** make sure there are no typos or spelling errors, especially in user-facing documentation. +- **DO** verify if your changes have impact elsewhere. For instance, do you need to update other docs or exiting markdown files that might be impacted? +- **DO** add relevant unit tests to ensure CI will catch future regressions. + +## Merging Pull Requests (for project contributors with write access) + +- **DO** use ["Squash and Merge"](https://github.com/blog/2141-squash-your-commits) by default for individual contributions unless requested by the PR author. + Do so, even if the PR contains only one commit. It creates a simpler history than "Create a Merge Commit". + Reasons that PR authors may request "Merge and Commit" may include (but are not limited to): + + - The change is easier to understand as a series of focused commits. 
Each commit in the series must be buildable so as not to break `git bisect`. + - Contributor is using an e-mail address other than the primary GitHub address and wants that preserved in the history. Contributor must be willing to squash + the commits manually before acceptance. + +## Developer Guide + +### Logging + +To enable debug logging, build with `-tags debug`. This enables debug level 1 by default. You can increase the level by setting the `DEBUG_LEVEL` environment variable to 2 or higher. (Debug logging is disabled entirely without `-tags debug`, regardless of `DEBUG_LEVEL` setting.) + +To add additional logging, use the `debug.Log(level int, format string, v ...any)` function, which is similar to `fmt.Printf` but takes a level as its first argument. + +### Packet Capture + +Wireshark can be very helpful in diagnosing interactions between client and server. If the connection is not encrypted Wireshark can natively decode AMQP 1.0. If the connection is encrypted with TLS you'll need to log out the keys. + +Example of logging the TLS keys: + +```go +// Create the file +f, err := os.OpenFile("key.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + +// Configure TLS +tlsConfig := &tls.Config{ + KeyLogWriter: f, +} + +// Dial the host +const host = "my.amqp.server" +conn, err := tls.Dial("tcp", host+":5671", tlsConfig) + +// Create the connections +client, err := amqp.New(conn, + amqp.ConnSASLPlain("username", "password"), + amqp.ConnServerHostname(host), +) +``` + +You'll need to configure Wireshark to read the key.log file in Preferences > Protocols > SSL > (Pre)-Master-Secret log filename. 
diff --git a/vendor/github.com/Azure/go-amqp/LICENSE b/vendor/github.com/Azure/go-amqp/LICENSE new file mode 100644 index 00000000000..930bd6bdd7b --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/LICENSE @@ -0,0 +1,22 @@ + MIT License + + Copyright (C) 2017 Kale Blankenship + Portions Copyright (C) Microsoft Corporation + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/vendor/github.com/Azure/go-amqp/Makefile b/vendor/github.com/Azure/go-amqp/Makefile new file mode 100644 index 00000000000..f6ee05dfae0 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/Makefile @@ -0,0 +1,31 @@ +PACKAGE := github.com/Azure/go-amqp +FUZZ_DIR := ./fuzz + +all: test + +fuzzconn: + go-fuzz-build -o $(FUZZ_DIR)/conn.zip -func FuzzConn $(PACKAGE) + go-fuzz -bin $(FUZZ_DIR)/conn.zip -workdir $(FUZZ_DIR)/conn + +fuzzmarshal: + go-fuzz-build -o $(FUZZ_DIR)/marshal.zip -func FuzzUnmarshal $(PACKAGE) + go-fuzz -bin $(FUZZ_DIR)/marshal.zip -workdir $(FUZZ_DIR)/marshal + +fuzzclean: + rm -f $(FUZZ_DIR)/**/{crashers,suppressions}/* + rm -f $(FUZZ_DIR)/*.zip + +test: + TEST_CORPUS=1 go test -race -run=Corpus + go test -v -race ./... + +#integration: + #go test -tags "integration" -count=1 -v -race . + +test386: + TEST_CORPUS=1 go test -count=1 -v . + +ci: test386 coverage + +coverage: + TEST_CORPUS=1 go test -cover -coverprofile=cover.out -v diff --git a/vendor/github.com/Azure/go-amqp/NOTICE.txt b/vendor/github.com/Azure/go-amqp/NOTICE.txt new file mode 100644 index 00000000000..8ccf678159c --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/NOTICE.txt @@ -0,0 +1,29 @@ +NOTICES AND INFORMATION +Do Not Translate or Localize + +This software incorporates material from third parties. 
Microsoft makes certain +open source code available at https://3rdpartysource.microsoft.com, or you may +send a check or money order for US $5.00, including the product name, the open +source component name, and version number, to: + +Source Code Compliance Team +Microsoft Corporation +One Microsoft Way +Redmond, WA 98052 +USA + +Notwithstanding any other terms, you may reverse engineer this software to the +extent required to debug changes to any libraries licensed under the GNU Lesser +General Public License. + +------------------------------------------------------------------------------ + +go-amqp uses third-party libraries or other resources that may be +distributed under licenses different than the go-amqp software. + +In the event that we accidentally failed to list a required notice, please +bring it to our attention. Post an issue or email us: + + azgosdkhelp@microsoft.com + +The attached notices are provided for information only. diff --git a/vendor/github.com/Azure/go-amqp/README.md b/vendor/github.com/Azure/go-amqp/README.md new file mode 100644 index 00000000000..764505214db --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/README.md @@ -0,0 +1,194 @@ +# AMQP 1.0 Client Module for Go + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/Azure/go-amqp)](https://pkg.go.dev/github.com/Azure/go-amqp) +[![Build Status](https://dev.azure.com/azure-sdk/public/_apis/build/status/go/Azure.go-amqp?branchName=main)](https://dev.azure.com/azure-sdk/public/_build/latest?definitionId=1292&branchName=main) +[![Go Report Card](https://goreportcard.com/badge/github.com/Azure/go-amqp)](https://goreportcard.com/report/github.com/Azure/go-amqp) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://raw.githubusercontent.com/Azure/go-amqp/main/LICENSE) + +The [amqp][godoc_amqp] module is an AMQP 1.0 client implementation for Go. + +[AMQP 1.0][amqp_spec] is not compatible with AMQP 0-9-1 or 0-10. 
+ +## Getting Started + +### Prerequisites + +- Go 1.18 or later +- An AMQP 1.0 compliant [broker][broker_listing] + +### Install the module + +```sh +go get github.com/Azure/go-amqp +``` + +### Connect to a broker + +Call [amqp.Dial()][godoc_dial] to connect to an AMQP broker. This creates an [*amqp.Conn][godoc_conn]. + +```go +conn, err := amqp.Dial(context.TODO(), "amqp[s]://", nil) +if err != nil { + // handle error +} +``` + +### Sending and receiving messages + +In order to send or receive messages, first create an [*amqp.Session][godoc_session] from the [*amqp.Conn][godoc_conn] by calling [Conn.NewSession()][godoc_conn_session]. + +```go +session, err := conn.NewSession(context.TODO(), nil) +if err != nil { + // handle error +} +``` + +Once the session has been created, create an [*amqp.Sender][godoc_sender] to send messages and/or an [*amqp.Receiver][godoc_receiver] to receive messages by calling [Session.NewSender()][godoc_session_sender] and/or [Session.NewReceiver()][godoc_session_receiver] respectively. + +```go +// create a new sender +sender, err := session.NewSender(context.TODO(), "", nil) +if err != nil { + // handle error +} + +// send a message +err = sender.Send(context.TODO(), amqp.NewMessage([]byte("Hello!")), nil) +if err != nil { + // handle error +} + +// create a new receiver +receiver, err := session.NewReceiver(context.TODO(), "", nil) +if err != nil { + // handle error +} + +// receive the next message +msg, err := receiver.Receive(context.TODO(), nil) +if err != nil { + // handle error +} +``` + +## Key concepts + +- An [*amqp.Conn][godoc_conn] connects a client to a broker (e.g. Azure Service Bus). +- Once a connection has been established, create one or more [*amqp.Session][godoc_session] instances. +- From an [*amqp.Session][godoc_session] instance, create one or more senders and/or receivers. + - An [*amqp.Sender][godoc_sender] is used to send messages from the client to a broker. 
+ - An [*amqp.Receiver][godoc_receiver] is used to receive messages from a broker to the client. + +For a complete overview of AMQP's conceptual model, please consult section [2.1 Transport][section_2_1] of the AMQP 1.0 specification. + +## Examples + +The following examples cover common scenarios for sending and receiving messages: + +- [Create a message](#create-a-message) +- [Send message](#send-message) +- [Receive messages](#receive-messages) + +### Create a message + +A message can be created in two different ways. The first is to simply instantiate a new instance of the [*amqp.Message][godoc_message] type, populating the required fields. + +```go +msg := &amqp.Message{ + // populate fields (Data is the most common) +} +``` + +The second is the [amqp.NewMessage][godoc_message_ctor] constructor. It passes the provided `[]byte` to the first entry in the `*amqp.Message.Data` slice. + +```go +msg := amqp.NewMessage(/* some []byte */) +``` + +This is purely a convenience constructor as many AMQP brokers expect a message's data in the `Data` field. + +### Send message + +Once an [*amqp.Session][godoc_session] has been created, create an [*amqp.Sender][godoc_sender] in order to send messages. + +```go +sender, err := session.NewSender(context.TODO(), "", nil) +``` + +Once the [*amqp.Sender][godoc_sender] has been created, call [Sender.Send()][godoc_sender_send] to send an [*amqp.Message][godoc_message]. + +```go +err := sender.Send(context.TODO(), msg, nil) +``` + +Depending on the sender's configuration, the call to [Sender.Send()][godoc_sender_send] will block until the peer has acknowledged the message was received. +The amount of time the call will block is dependent upon network latency and the peer's load, but is usually in a few dozen milliseconds. + +### Receive messages + +Once an [*amqp.Session][godoc_session] has been created, create an [*amqp.Receiver][godoc_receiver] in order to receive messages. 
+ +```go +receiver, err := session.NewReceiver(context.TODO(), "", nil) +``` + +Once the [*amqp.Receiver][godoc_receiver] has been created, call [Receiver.Receive()][godoc_receiver_receive] to wait for an incoming message. + +```go +msg, err := receiver.Receive(context.TODO(), nil) +``` + +Note that calls to [Receiver.Receive()][godoc_receiver_receive] will block until either a message has been received or, if applicable, the provided [context.Context][godoc_context] has been cancelled and/or its deadline exceeded. + +After an [*amqp.Message][godoc_message] message has been received and processed, as the final step it's **imperative** that the [*amqp.Message][godoc_message] is passed to one of the acknowledgement methods on the [*amqp.Receiver][godoc_receiver]. + +- [Receiver.AcceptMessage][godoc_receiver_accept] - the client has accepted the message and no redelivery is required (most common) +- [Receiver.ModifyMessage][godoc_receiver_modify] - the client has modified the message and released it for redelivery with the specified modifications +- [Receiver.RejectMessage][godoc_receiver_reject] - the message is invalid and therefore cannot be processed +- [Receiver.ReleaseMessage][godoc_receiver_release] - the client has released the message for redelivery without any modifications + +```go +err := receiver.AcceptMessage(context.TODO(), msg) +``` + +## Next steps + +See the [examples][godoc_examples] for complete end-to-end examples on how to use this module. + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a +Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us +the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide +a CLA and decorate the PR appropriately (e.g., status check, comment). 
Simply follow the instructions +provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +[amqp_spec]: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-overview-v1.0-os.html +[broker_listing]: https://github.com/xinchen10/awesome-amqp +[section_2_1]: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-transport-v1.0-os.html#section-transport +[godoc_amqp]: https://pkg.go.dev/github.com/Azure/go-amqp +[godoc_examples]: https://pkg.go.dev/github.com/Azure/go-amqp#pkg-examples +[godoc_conn]: https://pkg.go.dev/github.com/Azure/go-amqp#Conn +[godoc_conn_session]: https://pkg.go.dev/github.com/Azure/go-amqp#Conn.NewSession +[godoc_dial]: https://pkg.go.dev/github.com/Azure/go-amqp#Dial +[godoc_context]: https://pkg.go.dev/context#Context +[godoc_message]: https://pkg.go.dev/github.com/Azure/go-amqp#Message +[godoc_message_ctor]: https://pkg.go.dev/github.com/Azure/go-amqp#NewMessage +[godoc_session]: https://pkg.go.dev/github.com/Azure/go-amqp#Session +[godoc_session_sender]: https://pkg.go.dev/github.com/Azure/go-amqp#Session.NewSender +[godoc_session_receiver]: https://pkg.go.dev/github.com/Azure/go-amqp#Session.NewReceiver +[godoc_sender]: https://pkg.go.dev/github.com/Azure/go-amqp#Sender +[godoc_sender_send]: https://pkg.go.dev/github.com/Azure/go-amqp#Sender.Send +[godoc_receiver]: https://pkg.go.dev/github.com/Azure/go-amqp#Receiver +[godoc_receiver_accept]: https://pkg.go.dev/github.com/Azure/go-amqp#Receiver.AcceptMessage +[godoc_receiver_modify]: https://pkg.go.dev/github.com/Azure/go-amqp#Receiver.ModifyMessage +[godoc_receiver_reject]: 
https://pkg.go.dev/github.com/Azure/go-amqp#Receiver.RejectMessage +[godoc_receiver_release]: https://pkg.go.dev/github.com/Azure/go-amqp#Receiver.ReleaseMessage +[godoc_receiver_receive]: https://pkg.go.dev/github.com/Azure/go-amqp#Receiver.Receive diff --git a/vendor/github.com/Azure/go-amqp/SECURITY.md b/vendor/github.com/Azure/go-amqp/SECURITY.md new file mode 100644 index 00000000000..7ab49eb8296 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
+ + diff --git a/vendor/github.com/Azure/go-amqp/azure-pipelines.yml b/vendor/github.com/Azure/go-amqp/azure-pipelines.yml new file mode 100644 index 00000000000..73168ab31b2 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/azure-pipelines.yml @@ -0,0 +1,105 @@ +variables: + GO111MODULE: 'on' + AMQP_BROKER_ADDR: 'amqp://127.0.0.1:25672' + +jobs: + - job: 'goamqp' + displayName: 'Run go-amqp CI Checks' + + strategy: + matrix: + Linux_Go118: + pool.name: 'azsdk-pool-mms-ubuntu-2004-general' + vm.image: 'ubuntu-20.04' + go.version: '1.18.10' + Linux_Go121: + pool.name: 'azsdk-pool-mms-ubuntu-2004-general' + vm.image: 'ubuntu-20.04' + go.version: '1.21.7' + Linux_Go122: + pool.name: 'azsdk-pool-mms-ubuntu-2004-general' + vm.image: 'ubuntu-20.04' + go.version: '1.22.0' + + pool: + name: '$(pool.name)' + vmImage: '$(vm.image)' + + steps: + - task: GoTool@0 + inputs: + version: '$(go.version)' + displayName: "Select Go Version" + + - script: | + set -e + export gopathbin=$(go env GOPATH)/bin + echo "##vso[task.prependpath]$gopathbin" + go install github.com/jstemmer/go-junit-report/v2@v2.1.0 + go install github.com/axw/gocov/gocov@v1.1.0 + go install github.com/AlekSi/gocov-xml@v1.1.0 + go install github.com/matm/gocov-html/cmd/gocov-html@v1.4.0 + displayName: 'Install Dependencies' + + - script: | + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.56.2 + golangci-lint --version + golangci-lint run + displayName: 'Install and Run GoLintCLI.' + + - script: | + go build -v ./... + displayName: 'Build' + + - script: | + go vet ./... 
+ displayName: 'Vet' + + - task: UseDotNet@2 + displayName: 'Use .NET sdk' + inputs: + packageType: sdk + version: 6.0.x + installationPath: $(Agent.ToolsDirectory)/dotnet + + - script: | + git clone https://github.com/Azure/azure-amqp $(Pipeline.Workspace)/azure-amqp + git checkout v2.6.5 + pushd $(Pipeline.Workspace)/azure-amqp/test/TestAmqpBroker + dotnet restore + dotnet build + chmod +x $(Pipeline.Workspace)/azure-amqp/bin/Debug/TestAmqpBroker/net462/TestAmqpBroker.exe + displayName: 'Clone and Build Broker' + + - script: | + set -e + export TEST_CORPUS=1 + echo '##[command]Starting broker at $(AMQP_BROKER_ADDR)' + $(Pipeline.Workspace)/azure-amqp/bin/Debug/TestAmqpBroker/net462/TestAmqpBroker.exe $AMQP_BROKER_ADDR /headless & + brokerPID=$! + echo '##[section]Starting tests' + go test -race -v -coverprofile=coverage.txt -covermode atomic ./... 2>&1 | tee gotestoutput.log + go-junit-report < gotestoutput.log > report.xml + kill $brokerPID + gocov convert coverage.txt > coverage.json + gocov-xml < coverage.json > coverage.xml + gocov-html < coverage.json > coverage.html + displayName: 'Run Tests' + + - script: | + gofmt -s -l -w . 
>&2 + displayName: 'Format Check' + failOnStderr: true + condition: succeededOrFailed() + + - task: PublishTestResults@2 + inputs: + testRunner: JUnit + testResultsFiles: report.xml + failTaskOnFailedTests: true + + - task: PublishCodeCoverageResults@1 + inputs: + codeCoverageTool: Cobertura + summaryFileLocation: coverage.xml + additionalCodeCoverageFiles: coverage.html diff --git a/vendor/github.com/Azure/go-amqp/conn.go b/vendor/github.com/Azure/go-amqp/conn.go new file mode 100644 index 00000000000..098a8bf8655 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/conn.go @@ -0,0 +1,1147 @@ +package amqp + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "math" + "net" + "net/url" + "sync" + "time" + + "github.com/Azure/go-amqp/internal/bitmap" + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" + "github.com/Azure/go-amqp/internal/shared" +) + +// Default connection options +const ( + defaultIdleTimeout = 1 * time.Minute + defaultMaxFrameSize = 65536 + defaultMaxSessions = 65536 + defaultWriteTimeout = 30 * time.Second +) + +// ConnOptions contains the optional settings for configuring an AMQP connection. +type ConnOptions struct { + // ContainerID sets the container-id to use when opening the connection. + // + // A container ID will be randomly generated if this option is not used. + ContainerID string + + // HostName sets the hostname sent in the AMQP + // Open frame and TLS ServerName (if not otherwise set). + HostName string + + // IdleTimeout specifies the maximum period between + // receiving frames from the peer. + // + // Specify a value less than zero to disable idle timeout. + // + // Default: 1 minute (60000000000). + IdleTimeout time.Duration + + // MaxFrameSize sets the maximum frame size that + // the connection will accept. + // + // Must be 512 or greater. + // + // Default: 65536. 
+ MaxFrameSize uint32 + + // MaxSessions sets the maximum number of channels. + // The value must be greater than zero. + // + // Default: 65536. + MaxSessions uint16 + + // Properties sets an entry in the connection properties map sent to the server. + Properties map[string]any + + // SASLType contains the specified SASL authentication mechanism. + SASLType SASLType + + // TLSConfig sets the tls.Config to be used during + // TLS negotiation. + // + // This option is for advanced usage, in most scenarios + // providing a URL scheme of "amqps://" is sufficient. + TLSConfig *tls.Config + + // WriteTimeout controls the write deadline when writing AMQP frames to the + // underlying net.Conn and no caller provided context.Context is available or + // the context contains no deadline (e.g. context.Background()). + // The timeout is set per write. + // + // Setting to a value less than zero means no timeout is set, so writes + // defer to the underlying behavior of net.Conn with no write deadline. + // + // Default: 30s + WriteTimeout time.Duration + + // test hook + dialer dialer +} + +// Dial connects to an AMQP broker. +// +// If the addr includes a scheme, it must be "amqp", "amqps", or "amqp+ssl". +// If no port is provided, 5672 will be used for "amqp" and 5671 for "amqps" or "amqp+ssl". +// +// If username and password information is not empty it's used as SASL PLAIN +// credentials, equal to passing ConnSASLPlain option. +// +// opts: pass nil to accept the default values. +func Dial(ctx context.Context, addr string, opts *ConnOptions) (*Conn, error) { + c, err := dialConn(ctx, addr, opts) + if err != nil { + return nil, err + } + err = c.start(ctx) + if err != nil { + return nil, err + } + return c, nil +} + +// NewConn establishes a new AMQP client connection over conn. +// NOTE: [Conn] takes ownership of the provided [net.Conn] and will close it as required. +// opts: pass nil to accept the default values. 
+func NewConn(ctx context.Context, conn net.Conn, opts *ConnOptions) (*Conn, error) { + c, err := newConn(conn, opts) + if err != nil { + return nil, err + } + err = c.start(ctx) + if err != nil { + return nil, err + } + return c, nil +} + +// Conn is an AMQP connection. +type Conn struct { + net net.Conn // underlying connection + dialer dialer // used for testing purposes, it allows faking dialing TCP/TLS endpoints + writeTimeout time.Duration // controls write deadline in absense of a context + + // TLS + tlsNegotiation bool // negotiate TLS + tlsComplete bool // TLS negotiation complete + tlsConfig *tls.Config // TLS config, default used if nil (ServerName set to Client.hostname) + + // SASL + saslHandlers map[encoding.Symbol]stateFunc // map of supported handlers keyed by SASL mechanism, SASL not negotiated if nil + saslComplete bool // SASL negotiation complete; internal *except* for SASL auth methods + + // local settings + maxFrameSize uint32 // max frame size to accept + channelMax uint16 // maximum number of channels to allow + hostname string // hostname of remote server (set explicitly or parsed from URL) + idleTimeout time.Duration // maximum period between receiving frames + properties map[encoding.Symbol]any // additional properties sent upon connection open + containerID string // set explicitly or randomly generated + + // peer settings + peerIdleTimeout time.Duration // maximum period between sending frames + peerMaxFrameSize uint32 // maximum frame size peer will accept + + // conn state + done chan struct{} // indicates the connection has terminated + doneErr error // contains the error state returned from Close(); DO NOT TOUCH outside of conn.go until done has been closed! 
+ + // connReader and connWriter management + rxtxExit chan struct{} // signals connReader and connWriter to exit + closeOnce sync.Once // ensures that close() is only called once + + // session tracking + channels *bitmap.Bitmap + sessionsByChannel map[uint16]*Session + sessionsByChannelMu sync.RWMutex + + abandonedSessionsMu sync.Mutex + abandonedSessions []*Session + + // connReader + rxBuf buffer.Buffer // incoming bytes buffer + rxDone chan struct{} // closed when connReader exits + rxErr error // contains last error reading from c.net; DO NOT TOUCH outside of connReader until rxDone has been closed! + + // connWriter + txFrame chan frameEnvelope // AMQP frames to be sent by connWriter + txBuf buffer.Buffer // buffer for marshaling frames before transmitting + txDone chan struct{} // closed when connWriter exits + txErr error // contains last error writing to c.net; DO NOT TOUCH outside of connWriter until txDone has been closed! +} + +// used to abstract the underlying dialer for testing purposes +type dialer interface { + NetDialerDial(ctx context.Context, c *Conn, host, port string) error + TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) error +} + +// implements the dialer interface +type defaultDialer struct{} + +func (defaultDialer) NetDialerDial(ctx context.Context, c *Conn, host, port string) (err error) { + dialer := &net.Dialer{} + c.net, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) + return +} + +func (defaultDialer) TLSDialWithDialer(ctx context.Context, c *Conn, host, port string) (err error) { + dialer := &tls.Dialer{Config: c.tlsConfig} + c.net, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) + return +} + +func dialConn(ctx context.Context, addr string, opts *ConnOptions) (*Conn, error) { + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + host, port := u.Hostname(), u.Port() + if port == "" { + port = "5672" + if u.Scheme == "amqps" || u.Scheme == "amqp+ssl" { + port 
= "5671" + } + } + + var cp ConnOptions + if opts != nil { + cp = *opts + } + + // prepend SASL credentials when the user/pass segment is not empty + if u.User != nil { + pass, _ := u.User.Password() + cp.SASLType = SASLTypePlain(u.User.Username(), pass) + } + + if cp.HostName == "" { + cp.HostName = host + } + + c, err := newConn(nil, &cp) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "amqp", "": + err = c.dialer.NetDialerDial(ctx, c, host, port) + case "amqps", "amqp+ssl": + c.initTLSConfig() + c.tlsNegotiation = false + err = c.dialer.TLSDialWithDialer(ctx, c, host, port) + default: + err = fmt.Errorf("unsupported scheme %q", u.Scheme) + } + + if err != nil { + return nil, err + } + return c, nil +} + +func newConn(netConn net.Conn, opts *ConnOptions) (*Conn, error) { + c := &Conn{ + dialer: defaultDialer{}, + net: netConn, + maxFrameSize: defaultMaxFrameSize, + peerMaxFrameSize: defaultMaxFrameSize, + channelMax: defaultMaxSessions - 1, // -1 because channel-max starts at zero + idleTimeout: defaultIdleTimeout, + containerID: shared.RandString(40), + done: make(chan struct{}), + rxtxExit: make(chan struct{}), + rxDone: make(chan struct{}), + txFrame: make(chan frameEnvelope), + txDone: make(chan struct{}), + sessionsByChannel: map[uint16]*Session{}, + writeTimeout: defaultWriteTimeout, + } + + // apply options + if opts == nil { + opts = &ConnOptions{} + } + + if opts.WriteTimeout > 0 { + c.writeTimeout = opts.WriteTimeout + } else if opts.WriteTimeout < 0 { + c.writeTimeout = 0 + } + if opts.ContainerID != "" { + c.containerID = opts.ContainerID + } + if opts.HostName != "" { + c.hostname = opts.HostName + } + if opts.IdleTimeout > 0 { + c.idleTimeout = opts.IdleTimeout + } else if opts.IdleTimeout < 0 { + c.idleTimeout = 0 + } + if opts.MaxFrameSize > 0 && opts.MaxFrameSize < 512 { + return nil, fmt.Errorf("invalid MaxFrameSize value %d", opts.MaxFrameSize) + } else if opts.MaxFrameSize > 512 { + c.maxFrameSize = opts.MaxFrameSize + } + 
if opts.MaxSessions > 0 { + c.channelMax = opts.MaxSessions + } + if opts.SASLType != nil { + if err := opts.SASLType(c); err != nil { + return nil, err + } + } + if opts.Properties != nil { + c.properties = make(map[encoding.Symbol]any) + for key, val := range opts.Properties { + c.properties[encoding.Symbol(key)] = val + } + } + if opts.TLSConfig != nil { + c.tlsConfig = opts.TLSConfig.Clone() + } + if opts.dialer != nil { + c.dialer = opts.dialer + } + return c, nil +} + +func (c *Conn) initTLSConfig() { + // create a new config if not already set + if c.tlsConfig == nil { + c.tlsConfig = new(tls.Config) + } + + // TLS config must have ServerName or InsecureSkipVerify set + if c.tlsConfig.ServerName == "" && !c.tlsConfig.InsecureSkipVerify { + c.tlsConfig.ServerName = c.hostname + } +} + +// start establishes the connection and begins multiplexing network IO. +// It is an error to call Start() on a connection that's been closed. +func (c *Conn) start(ctx context.Context) (err error) { + // if the context has a deadline or is cancellable, start the interruptor goroutine. + // this will close the underlying net.Conn in response to the context. + + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // return context error to caller + err = ctxErr + } + }() + + go func() { + select { + case <-ctx.Done(): + c.closeDuringStart() + interruptRes <- ctx.Err() + case <-done: + interruptRes <- nil + } + }() + } + + if err = c.startImpl(ctx); err != nil { + return err + } + + // we can't create the channel bitmap until the connection has been established. + // this is because our peer can tell us the max channels they support. 
+ c.channels = bitmap.New(uint32(c.channelMax)) + + go c.connWriter() + go c.connReader() + + return +} + +func (c *Conn) startImpl(ctx context.Context) error { + // set connection establishment deadline as required + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + _ = c.net.SetDeadline(deadline) + + // remove connection establishment deadline + defer func() { + _ = c.net.SetDeadline(time.Time{}) + }() + } + + // run connection establishment state machine + for state := c.negotiateProto; state != nil; { + var err error + state, err = state(ctx) + // check if err occurred + if err != nil { + c.closeDuringStart() + return err + } + } + + return nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + c.close() + + // wait until the reader/writer goroutines have exited before proceeding. + // this is to prevent a race between calling Close() and a reader/writer + // goroutine calling close() due to a terminal error. + <-c.txDone + <-c.rxDone + + var connErr *ConnError + if errors.As(c.doneErr, &connErr) && connErr.RemoteErr == nil && connErr.inner == nil { + // an empty ConnectionError means the connection was closed by the caller + return nil + } + + // there was an error during shut-down or connReader/connWriter + // experienced a terminal error + return c.doneErr +} + +// close is called once, either from Close() or when connReader/connWriter exits +func (c *Conn) close() { + c.closeOnce.Do(func() { + defer close(c.done) + + close(c.rxtxExit) + + // wait for writing to stop, allows it to send the final close frame + <-c.txDone + + closeErr := c.net.Close() + + // check rxDone after closing net, otherwise may block + // for up to c.idleTimeout + <-c.rxDone + + if errors.Is(c.rxErr, net.ErrClosed) { + // this is the expected error when the connection is closed, swallow it + c.rxErr = nil + } + + if c.txErr == nil && c.rxErr == nil && closeErr == nil { + // if there are no errors, it means user initiated close() and we shut down 
cleanly + c.doneErr = &ConnError{} + } else if amqpErr, ok := c.rxErr.(*Error); ok { + // we experienced a peer-initiated close that contained an Error. return it + c.doneErr = &ConnError{RemoteErr: amqpErr} + } else if c.txErr != nil { + // c.txErr is already wrapped in a ConnError + c.doneErr = c.txErr + } else if c.rxErr != nil { + c.doneErr = &ConnError{inner: c.rxErr} + } else { + c.doneErr = &ConnError{inner: closeErr} + } + }) +} + +// closeDuringStart is a special close to be used only during startup (i.e. c.start() and any of its children) +func (c *Conn) closeDuringStart() { + c.closeOnce.Do(func() { + c.net.Close() + }) +} + +// NewSession starts a new session on the connection. +// - ctx controls waiting for the peer to acknowledge the session +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. If the Session was successfully +// created, it will be cleaned up in future calls to NewSession. 
+func (c *Conn) NewSession(ctx context.Context, opts *SessionOptions) (*Session, error) { + // clean up any abandoned sessions first + if err := c.freeAbandonedSessions(ctx); err != nil { + return nil, err + } + + session, err := c.newSession(opts) + if err != nil { + return nil, err + } + + if err := session.begin(ctx); err != nil { + c.abandonSession(session) + return nil, err + } + + return session, nil +} + +func (c *Conn) freeAbandonedSessions(ctx context.Context) error { + c.abandonedSessionsMu.Lock() + defer c.abandonedSessionsMu.Unlock() + + debug.Log(3, "TX (Conn %p): cleaning up %d abandoned sessions", c, len(c.abandonedSessions)) + + for _, s := range c.abandonedSessions { + fr := frames.PerformEnd{} + if err := s.txFrameAndWait(ctx, &fr); err != nil { + return err + } + } + + c.abandonedSessions = nil + return nil +} + +func (c *Conn) newSession(opts *SessionOptions) (*Session, error) { + c.sessionsByChannelMu.Lock() + defer c.sessionsByChannelMu.Unlock() + + // create the next session to allocate + // note that channel always start at 0 + channel, ok := c.channels.Next() + if !ok { + if err := c.Close(); err != nil { + return nil, err + } + return nil, &ConnError{inner: fmt.Errorf("reached connection channel max (%d)", c.channelMax)} + } + session := newSession(c, uint16(channel), opts) + c.sessionsByChannel[session.channel] = session + + return session, nil +} + +func (c *Conn) deleteSession(s *Session) { + c.sessionsByChannelMu.Lock() + defer c.sessionsByChannelMu.Unlock() + + delete(c.sessionsByChannel, s.channel) + c.channels.Remove(uint32(s.channel)) +} + +func (c *Conn) abandonSession(s *Session) { + c.abandonedSessionsMu.Lock() + defer c.abandonedSessionsMu.Unlock() + c.abandonedSessions = append(c.abandonedSessions, s) +} + +// connReader reads from the net.Conn, decodes frames, and either handles +// them here as appropriate or sends them to the session.rx channel. 
+func (c *Conn) connReader() { + defer func() { + close(c.rxDone) + c.close() + }() + + var sessionsByRemoteChannel = make(map[uint16]*Session) + var err error + for { + if err != nil { + debug.Log(0, "RX (connReader %p): terminal error: %v", c, err) + c.rxErr = err + return + } + + var fr frames.Frame + fr, err = c.readFrame() + if err != nil { + continue + } + + debug.Log(0, "RX (connReader %p): %s", c, fr) + + var ( + session *Session + ok bool + ) + + switch body := fr.Body.(type) { + // Server initiated close. + case *frames.PerformClose: + // connWriter will send the close performative ack on its way out. + // it's a SHOULD though, not a MUST. + if body.Error == nil { + return + } + err = body.Error + continue + + // RemoteChannel should be used when frame is Begin + case *frames.PerformBegin: + if body.RemoteChannel == nil { + // since we only support remotely-initiated sessions, this is an error + // TODO: it would be ideal to not have this kill the connection + err = fmt.Errorf("%T: nil RemoteChannel", fr.Body) + continue + } + c.sessionsByChannelMu.RLock() + session, ok = c.sessionsByChannel[*body.RemoteChannel] + c.sessionsByChannelMu.RUnlock() + if !ok { + // this can happen if NewSession() exits due to the context expiring/cancelled + // before the begin ack is received. + err = fmt.Errorf("unexpected remote channel number %d", *body.RemoteChannel) + continue + } + + session.remoteChannel = fr.Channel + sessionsByRemoteChannel[fr.Channel] = session + + case *frames.PerformEnd: + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel (PerformEnd)", fr.Body, fr.Channel) + continue + } + // we MUST remove the remote channel from our map as soon as we receive + // the ack (i.e. before passing it on to the session mux) on the session + // ending since the numbers are recycled. 
+ delete(sessionsByRemoteChannel, fr.Channel) + c.deleteSession(session) + + default: + // pass on performative to the correct session + session, ok = sessionsByRemoteChannel[fr.Channel] + if !ok { + err = fmt.Errorf("%T: didn't find channel %d in sessionsByRemoteChannel", fr.Body, fr.Channel) + continue + } + } + + q := session.rxQ.Acquire() + q.Enqueue(fr.Body) + session.rxQ.Release(q) + debug.Log(2, "RX (connReader %p): mux frame to Session (%p): %s", c, session, fr) + } +} + +// readFrame reads a complete frame from c.net. +// it assumes that any read deadline has already been applied. +// used externally by SASL only. +func (c *Conn) readFrame() (frames.Frame, error) { + switch { + // Cheaply reuse free buffer space when fully read. + case c.rxBuf.Len() == 0: + c.rxBuf.Reset() + + // Prevent excessive/unbounded growth by shifting data to beginning of buffer. + case int64(c.rxBuf.Size()) > int64(c.maxFrameSize): + c.rxBuf.Reclaim() + } + + var ( + currentHeader frames.Header // keep track of the current header, for frames split across multiple TCP packets + frameInProgress bool // true if in the middle of receiving data for currentHeader + ) + + for { + // need to read more if buf doesn't contain the complete frame + // or there's not enough in buf to parse the header + if frameInProgress || c.rxBuf.Len() < frames.HeaderSize { + // we MUST reset the idle timeout before each read from net.Conn + if c.idleTimeout > 0 { + _ = c.net.SetReadDeadline(time.Now().Add(c.idleTimeout)) + } + err := c.rxBuf.ReadFromOnce(c.net) + if err != nil { + return frames.Frame{}, err + } + } + + // parse the header if a frame isn't in progress + if !frameInProgress { + // read more if buf doesn't contain enough to parse the header + // NOTE: we MUST do this ONLY if a frame isn't in progress else we can + // end up stalling when reading frames with bodies smaller than HeaderSize + if c.rxBuf.Len() < frames.HeaderSize { + continue + } + + var err error + currentHeader, err = 
frames.ParseHeader(&c.rxBuf) + if err != nil { + return frames.Frame{}, err + } + frameInProgress = true + } + + // check size is reasonable + if currentHeader.Size > math.MaxInt32 { // make max size configurable + return frames.Frame{}, errors.New("payload too large") + } + + bodySize := int64(currentHeader.Size - frames.HeaderSize) + + // the full frame hasn't been received, keep reading + if int64(c.rxBuf.Len()) < bodySize { + continue + } + frameInProgress = false + + // check if body is empty (keepalive) + if bodySize == 0 { + debug.Log(3, "RX (connReader %p): received keep-alive frame", c) + continue + } + + // parse the frame + b, ok := c.rxBuf.Next(bodySize) + if !ok { + return frames.Frame{}, fmt.Errorf("buffer EOF; requested bytes: %d, actual size: %d", bodySize, c.rxBuf.Len()) + } + + parsedBody, err := frames.ParseBody(buffer.New(b)) + if err != nil { + return frames.Frame{}, err + } + + return frames.Frame{Channel: currentHeader.Channel, Body: parsedBody}, nil + } +} + +// frameContext is an extended context.Context used to track writes to the network. +// this is required in order to remove ambiguities that can arise when simply waiting +// on context.Context.Done() to be signaled. +type frameContext struct { + // Ctx contains the caller's context and is used to set the write deadline. + Ctx context.Context + + // Done is closed when the frame was successfully written to net.Conn or Ctx was cancelled/timed out. + // Can be nil, but shouldn't be for callers that care about confirmation of sending. + Done chan struct{} + + // Err contains the context error. MUST be set before closing Done and ONLY read if Done is closed. + // ONLY Conn.connWriter may write to this field. 
+ Err error +} + +// frameEnvelope is used when sending a frame to connWriter to be written to net.Conn +type frameEnvelope struct { + FrameCtx *frameContext + Frame frames.Frame +} + +func (c *Conn) connWriter() { + defer func() { + close(c.txDone) + c.close() + }() + + var ( + // keepalives are sent at a rate of 1/2 idle timeout + keepaliveInterval = c.peerIdleTimeout / 2 + // 0 disables keepalives + keepalivesEnabled = keepaliveInterval > 0 + // set if enable, nil if not; nil channels block forever + keepalive <-chan time.Time + ) + + if keepalivesEnabled { + ticker := time.NewTicker(keepaliveInterval) + defer ticker.Stop() + keepalive = ticker.C + } + + var err error + for { + if err != nil { + debug.Log(0, "TX (connWriter %p): terminal error: %v", c, err) + c.txErr = err + return + } + + select { + // frame write request + case env := <-c.txFrame: + timeout, ctxErr := c.getWriteTimeout(env.FrameCtx.Ctx) + if ctxErr != nil { + debug.Log(1, "TX (connWriter %p) getWriteTimeout: %s: %s", c, ctxErr.Error(), env.Frame) + if env.FrameCtx.Done != nil { + // the error MUST be set before closing the channel + env.FrameCtx.Err = ctxErr + close(env.FrameCtx.Done) + } + continue + } + + debug.Log(0, "TX (connWriter %p) timeout %s: %s", c, timeout, env.Frame) + err = c.writeFrame(timeout, env.Frame) + if err == nil && env.FrameCtx.Done != nil { + close(env.FrameCtx.Done) + } + // in the event of write failure, Conn will close and a + // *ConnError will be propagated to all of the sessions/link. + + // keepalive timer + case <-keepalive: + debug.Log(3, "TX (connWriter %p): sending keep-alive frame", c) + _ = c.net.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + if _, err = c.net.Write(keepaliveFrame); err != nil { + err = &ConnError{inner: err} + } + // It would be slightly more efficient in terms of network + // resources to reset the timer each time a frame is sent. + // However, keepalives are small (8 bytes) and the interval + // is usually on the order of minutes. 
It does not seem + // worth it to add extra operations in the write path to + // avoid. (To properly reset a timer it needs to be stopped, + // possibly drained, then reset.) + + // connection complete + case <-c.rxtxExit: + // send close performative. note that the spec says we + // SHOULD wait for the ack but we don't HAVE to, in order + // to be resilient to bad actors etc. so we just send + // the close performative and exit. + fr := frames.Frame{ + Type: frames.TypeAMQP, + Body: &frames.PerformClose{}, + } + debug.Log(1, "TX (connWriter %p): %s", c, fr) + c.txErr = c.writeFrame(c.writeTimeout, fr) + return + } + } +} + +// writeFrame writes a frame to the network. +// used externally by SASL only. +// - timeout - the write deadline to set. zero means no deadline +// +// errors are wrapped in a ConnError as they can be returned to outside callers. +func (c *Conn) writeFrame(timeout time.Duration, fr frames.Frame) error { + // writeFrame into txBuf + c.txBuf.Reset() + err := frames.Write(&c.txBuf, fr) + if err != nil { + return &ConnError{inner: err} + } + + // validate the frame isn't exceeding peer's max frame size + requiredFrameSize := c.txBuf.Len() + if uint64(requiredFrameSize) > uint64(c.peerMaxFrameSize) { + return &ConnError{inner: fmt.Errorf("%T frame size %d larger than peer's max frame size %d", fr, requiredFrameSize, c.peerMaxFrameSize)} + } + + if timeout == 0 { + _ = c.net.SetWriteDeadline(time.Time{}) + } else if timeout > 0 { + _ = c.net.SetWriteDeadline(time.Now().Add(timeout)) + } + + // write to network + n, err := c.net.Write(c.txBuf.Bytes()) + if l := c.txBuf.Len(); n > 0 && n < l && err != nil { + debug.Log(1, "TX (writeFrame %p): wrote %d bytes less than len %d: %v", c, n, l, err) + } + if err != nil { + err = &ConnError{inner: err} + } + return err +} + +// writeProtoHeader writes an AMQP protocol header to the +// network +func (c *Conn) writeProtoHeader(pID protoID) error { + _, err := c.net.Write([]byte{'A', 'M', 'Q', 'P', byte(pID), 
1, 0, 0}) + return err +} + +// keepaliveFrame is an AMQP frame with no body, used for keepalives +var keepaliveFrame = []byte{0x00, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x00} + +// SendFrame is used by sessions and links to send frames across the network. +func (c *Conn) sendFrame(frameEnv frameEnvelope) { + select { + case c.txFrame <- frameEnv: + debug.Log(2, "TX (Conn %p): mux frame to connWriter: %s", c, frameEnv.Frame) + case <-c.done: + // Conn has closed + } +} + +// stateFunc is a state in a state machine. +// +// The state is advanced by returning the next state. +// The state machine concludes when nil is returned. +type stateFunc func(context.Context) (stateFunc, error) + +// negotiateProto determines which proto to negotiate next. +// used externally by SASL only. +func (c *Conn) negotiateProto(ctx context.Context) (stateFunc, error) { + // in the order each must be negotiated + switch { + case c.tlsNegotiation && !c.tlsComplete: + return c.exchangeProtoHeader(protoTLS) + case c.saslHandlers != nil && !c.saslComplete: + return c.exchangeProtoHeader(protoSASL) + default: + return c.exchangeProtoHeader(protoAMQP) + } +} + +type protoID uint8 + +// protocol IDs received in protoHeaders +const ( + protoAMQP protoID = 0x0 + protoTLS protoID = 0x2 + protoSASL protoID = 0x3 +) + +// exchangeProtoHeader performs the round trip exchange of protocol +// headers, validation, and returns the protoID specific next state. 
+func (c *Conn) exchangeProtoHeader(pID protoID) (stateFunc, error) { + // write the proto header + if err := c.writeProtoHeader(pID); err != nil { + return nil, err + } + + // read response header + p, err := c.readProtoHeader() + if err != nil { + return nil, err + } + + if pID != p.ProtoID { + return nil, fmt.Errorf("unexpected protocol header %#00x, expected %#00x", p.ProtoID, pID) + } + + // go to the proto specific state + switch pID { + case protoAMQP: + return c.openAMQP, nil + case protoTLS: + return c.startTLS, nil + case protoSASL: + return c.negotiateSASL, nil + default: + return nil, fmt.Errorf("unknown protocol ID %#02x", p.ProtoID) + } +} + +// readProtoHeader reads a protocol header packet from c.rxProto. +func (c *Conn) readProtoHeader() (protoHeader, error) { + const protoHeaderSize = 8 + + // only read from the network once our buffer has been exhausted. + // TODO: this preserves existing behavior as some tests rely on this + // implementation detail (it lets you replay a stream of bytes). we + // might want to consider removing this and fixing the tests as the + // protocol doesn't actually work this way. 
+ if c.rxBuf.Len() == 0 { + for { + err := c.rxBuf.ReadFromOnce(c.net) + if err != nil { + return protoHeader{}, err + } + + // read more if buf doesn't contain enough to parse the header + if c.rxBuf.Len() >= protoHeaderSize { + break + } + } + } + + buf, ok := c.rxBuf.Next(protoHeaderSize) + if !ok { + return protoHeader{}, errors.New("invalid protoHeader") + } + // bounds check hint to compiler; see golang.org/issue/14808 + _ = buf[protoHeaderSize-1] + + if !bytes.Equal(buf[:4], []byte{'A', 'M', 'Q', 'P'}) { + return protoHeader{}, fmt.Errorf("unexpected protocol %q", buf[:4]) + } + + p := protoHeader{ + ProtoID: protoID(buf[4]), + Major: buf[5], + Minor: buf[6], + Revision: buf[7], + } + + if p.Major != 1 || p.Minor != 0 || p.Revision != 0 { + return protoHeader{}, fmt.Errorf("unexpected protocol version %d.%d.%d", p.Major, p.Minor, p.Revision) + } + + return p, nil +} + +// startTLS wraps the conn with TLS and returns to Client.negotiateProto +func (c *Conn) startTLS(ctx context.Context) (stateFunc, error) { + c.initTLSConfig() + + _ = c.net.SetReadDeadline(time.Time{}) // clear timeout + + // wrap existing net.Conn and perform TLS handshake + tlsConn := tls.Client(c.net, c.tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + + // swap net.Conn + c.net = tlsConn + c.tlsComplete = true + + // go to next protocol + return c.negotiateProto, nil +} + +// openAMQP round trips the AMQP open performative +func (c *Conn) openAMQP(ctx context.Context) (stateFunc, error) { + // send open frame + open := &frames.PerformOpen{ + ContainerID: c.containerID, + Hostname: c.hostname, + MaxFrameSize: c.maxFrameSize, + ChannelMax: c.channelMax, + IdleTimeout: c.idleTimeout / 2, // per spec, advertise half our idle timeout + Properties: c.properties, + } + fr := frames.Frame{ + Type: frames.TypeAMQP, + Body: open, + Channel: 0, + } + debug.Log(1, "TX (openAMQP %p): %s", c, fr) + timeout, err := c.getWriteTimeout(ctx) + if err != nil { + 
return nil, err + } + if err = c.writeFrame(timeout, fr); err != nil { + return nil, err + } + + // get the response + fr, err = c.readSingleFrame() + if err != nil { + return nil, err + } + debug.Log(1, "RX (openAMQP %p): %s", c, fr) + o, ok := fr.Body.(*frames.PerformOpen) + if !ok { + return nil, fmt.Errorf("openAMQP: unexpected frame type %T", fr.Body) + } + + // update peer settings + if o.MaxFrameSize > 0 { + c.peerMaxFrameSize = o.MaxFrameSize + } + if o.IdleTimeout > 0 { + // TODO: reject very small idle timeouts + c.peerIdleTimeout = o.IdleTimeout + } + if o.ChannelMax < c.channelMax { + c.channelMax = o.ChannelMax + } + + // connection established, exit state machine + return nil, nil +} + +// negotiateSASL returns the SASL handler for the first matched +// mechanism specified by the server +func (c *Conn) negotiateSASL(context.Context) (stateFunc, error) { + // read mechanisms frame + fr, err := c.readSingleFrame() + if err != nil { + return nil, err + } + debug.Log(1, "RX (negotiateSASL %p): %s", c, fr) + sm, ok := fr.Body.(*frames.SASLMechanisms) + if !ok { + return nil, fmt.Errorf("negotiateSASL: unexpected frame type %T", fr.Body) + } + + // return first match in c.saslHandlers based on order received + for _, mech := range sm.Mechanisms { + if state, ok := c.saslHandlers[mech]; ok { + return state, nil + } + } + + // no match + return nil, fmt.Errorf("no supported auth mechanism (%v)", sm.Mechanisms) // TODO: send "auth not supported" frame? +} + +// saslOutcome processes the SASL outcome frame and return Client.negotiateProto +// on success. +// +// SASL handlers return this stateFunc when the mechanism specific negotiation +// has completed. +// used externally by SASL only. 
+func (c *Conn) saslOutcome(context.Context) (stateFunc, error) { + // read outcome frame + fr, err := c.readSingleFrame() + if err != nil { + return nil, err + } + debug.Log(1, "RX (saslOutcome %p): %s", c, fr) + so, ok := fr.Body.(*frames.SASLOutcome) + if !ok { + return nil, fmt.Errorf("saslOutcome: unexpected frame type %T", fr.Body) + } + + // check if auth succeeded + if so.Code != encoding.CodeSASLOK { + return nil, fmt.Errorf("SASL PLAIN auth failed with code %#00x: %s", so.Code, so.AdditionalData) // implement Stringer for so.Code + } + + // return to c.negotiateProto + c.saslComplete = true + return c.negotiateProto, nil +} + +// readSingleFrame is used during connection establishment to read a single frame. +// +// After setup, conn.connReader handles incoming frames. +func (c *Conn) readSingleFrame() (frames.Frame, error) { + fr, err := c.readFrame() + if err != nil { + return frames.Frame{}, err + } + + return fr, nil +} + +// getWriteTimeout returns the timeout as calculated from the context's deadline +// or the default write timeout if the context has no deadline. +// if the context has timed out or was cancelled, an error is returned. +func (c *Conn) getWriteTimeout(ctx context.Context) (time.Duration, error) { + if ctx.Err() != nil { + // if the context is already cancelled we can just bail. 
+ return 0, ctx.Err() + } + + if deadline, ok := ctx.Deadline(); ok { + until := time.Until(deadline) + if until <= 0 { + return 0, context.DeadlineExceeded + } + return until, nil + } + return c.writeTimeout, nil +} + +type protoHeader struct { + ProtoID protoID + Major uint8 + Minor uint8 + Revision uint8 +} diff --git a/vendor/github.com/Azure/go-amqp/const.go b/vendor/github.com/Azure/go-amqp/const.go new file mode 100644 index 00000000000..1fb1214bafe --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/const.go @@ -0,0 +1,93 @@ +package amqp + +import "github.com/Azure/go-amqp/internal/encoding" + +// Sender Settlement Modes +const ( + // Sender will send all deliveries initially unsettled to the receiver. + SenderSettleModeUnsettled SenderSettleMode = encoding.SenderSettleModeUnsettled + + // Sender will send all deliveries settled to the receiver. + SenderSettleModeSettled SenderSettleMode = encoding.SenderSettleModeSettled + + // Sender MAY send a mixture of settled and unsettled deliveries to the receiver. + SenderSettleModeMixed SenderSettleMode = encoding.SenderSettleModeMixed +) + +// SenderSettleMode specifies how the sender will settle messages. +type SenderSettleMode = encoding.SenderSettleMode + +func senderSettleModeValue(m *SenderSettleMode) SenderSettleMode { + if m == nil { + return SenderSettleModeMixed + } + return *m +} + +// Receiver Settlement Modes +const ( + // Receiver is the first to consider the message as settled. + // Once the corresponding disposition frame is sent, the message + // is considered to be settled. + ReceiverSettleModeFirst ReceiverSettleMode = encoding.ReceiverSettleModeFirst + + // Receiver is the second to consider the message as settled. + // Once the corresponding disposition frame is sent, the settlement + // is considered in-flight and the message will not be considered as + // settled until the sender replies acknowledging the settlement. 
+ ReceiverSettleModeSecond ReceiverSettleMode = encoding.ReceiverSettleModeSecond +) + +// ReceiverSettleMode specifies how the receiver will settle messages. +type ReceiverSettleMode = encoding.ReceiverSettleMode + +func receiverSettleModeValue(m *ReceiverSettleMode) ReceiverSettleMode { + if m == nil { + return ReceiverSettleModeFirst + } + return *m +} + +// Durability Policies +const ( + // No terminus state is retained durably. + DurabilityNone Durability = encoding.DurabilityNone + + // Only the existence and configuration of the terminus is + // retained durably. + DurabilityConfiguration Durability = encoding.DurabilityConfiguration + + // In addition to the existence and configuration of the + // terminus, the unsettled state for durable messages is + // retained durably. + DurabilityUnsettledState Durability = encoding.DurabilityUnsettledState +) + +// Durability specifies the durability of a link. +type Durability = encoding.Durability + +// Expiry Policies +const ( + // The expiry timer starts when terminus is detached. + ExpiryPolicyLinkDetach ExpiryPolicy = encoding.ExpiryLinkDetach + + // The expiry timer starts when the most recently + // associated session is ended. + ExpiryPolicySessionEnd ExpiryPolicy = encoding.ExpirySessionEnd + + // The expiry timer starts when most recently associated + // connection is closed. + ExpiryPolicyConnectionClose ExpiryPolicy = encoding.ExpiryConnectionClose + + // The terminus never expires. + ExpiryPolicyNever ExpiryPolicy = encoding.ExpiryNever +) + +// ExpiryPolicy specifies when the expiry timer of a terminus +// starts counting down from the timeout value. +// +// If the link is subsequently re-attached before the terminus is expired, +// then the count down is aborted. If the conditions for the +// terminus-expiry-policy are subsequently re-met, the expiry timer restarts +// from its originally configured timeout value. 
+type ExpiryPolicy = encoding.ExpiryPolicy diff --git a/vendor/github.com/Azure/go-amqp/creditor.go b/vendor/github.com/Azure/go-amqp/creditor.go new file mode 100644 index 00000000000..f4b6a1718af --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/creditor.go @@ -0,0 +1,117 @@ +package amqp
+
+import (
+ "context"
+ "errors"
+ "sync"
+)
+
+type creditor struct {
+ mu sync.Mutex
+
+ // future values for the next flow frame.
+ pendingDrain bool
+ creditsToAdd uint32
+
+ // drained is set when a drain is active and we're waiting
+ // for the corresponding flow from the remote.
+ drained chan struct{}
+}
+
+var (
+ errLinkDraining = errors.New("link is currently draining, no credits can be added")
+ errAlreadyDraining = errors.New("drain already in process")
+)
+
+// EndDrain ends the current drain, unblocking any active Drain calls.
+func (mc *creditor) EndDrain() {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ if mc.drained != nil {
+ close(mc.drained)
+ mc.drained = nil
+ }
+}
+
+// FlowBits gets the proper values for the next flow frame
+// and resets the internal state.
+// Returns:
+//
+// (drain: true, credits: 0) if a flow is needed (drain)
+// (drain: false, credits > 0) if a flow is needed (issue credit)
+// (drain: false, credits == 0) if no flow needed.
+func (mc *creditor) FlowBits(currentCredits uint32) (bool, uint32) {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ drain := mc.pendingDrain
+ var credits uint32
+
+ if mc.pendingDrain {
+ // only send one drain request
+ mc.pendingDrain = false
+ }
+
+ // either:
+ // drain is true (ie, we're going to send a drain frame, and the credits for it should be 0)
+ // mc.creditsToAdd == 0 (no flow frame needed, no new credits are being issued)
+ if drain || mc.creditsToAdd == 0 {
+ credits = 0
+ } else {
+ credits = mc.creditsToAdd + currentCredits
+ }
+
+ mc.creditsToAdd = 0
+
+ return drain, credits
+}
+
+// Drain initiates a drain and blocks until EndDrain is called.
+// If the context's deadline expires or is cancelled before the operation +// completes, the drain might not have happened. +func (mc *creditor) Drain(ctx context.Context, r *Receiver) error { + mc.mu.Lock() + + if mc.drained != nil { + mc.mu.Unlock() + return errAlreadyDraining + } + + mc.drained = make(chan struct{}) + // use a local copy to avoid racing with EndDrain() + drained := mc.drained + mc.pendingDrain = true + + mc.mu.Unlock() + + // cause mux() to check our flow conditions. + select { + case r.receiverReady <- struct{}{}: + default: + } + + // send drain, wait for responding flow frame + select { + case <-drained: + return nil + case <-r.l.done: + return r.l.doneErr + case <-ctx.Done(): + return ctx.Err() + } +} + +// IssueCredit queues up additional credits to be requested at the next +// call of FlowBits() +func (mc *creditor) IssueCredit(credits uint32) error { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.drained != nil { + return errLinkDraining + } + + mc.creditsToAdd += credits + return nil +} diff --git a/vendor/github.com/Azure/go-amqp/doc.go b/vendor/github.com/Azure/go-amqp/doc.go new file mode 100644 index 00000000000..ba5158300c2 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/doc.go @@ -0,0 +1,10 @@ +/* +Package amqp provides an AMQP 1.0 client implementation. + +AMQP 1.0 is not compatible with AMQP 0-9-1 or 0-10, which are +the most common AMQP protocols in use today. + +The example below shows how to use this package to connect +to a Microsoft Azure Service Bus queue. +*/ +package amqp // import "github.com/Azure/go-amqp" diff --git a/vendor/github.com/Azure/go-amqp/errors.go b/vendor/github.com/Azure/go-amqp/errors.go new file mode 100644 index 00000000000..c2e3b68a090 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/errors.go @@ -0,0 +1,104 @@ +package amqp + +import ( + "github.com/Azure/go-amqp/internal/encoding" +) + +// ErrCond is an AMQP defined error condition. 
+// See http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-transport-v1.0-os.html#type-amqp-error for info on their meaning. +type ErrCond = encoding.ErrCond + +// Error Conditions +const ( + // AMQP Errors + ErrCondDecodeError ErrCond = "amqp:decode-error" + ErrCondFrameSizeTooSmall ErrCond = "amqp:frame-size-too-small" + ErrCondIllegalState ErrCond = "amqp:illegal-state" + ErrCondInternalError ErrCond = "amqp:internal-error" + ErrCondInvalidField ErrCond = "amqp:invalid-field" + ErrCondNotAllowed ErrCond = "amqp:not-allowed" + ErrCondNotFound ErrCond = "amqp:not-found" + ErrCondNotImplemented ErrCond = "amqp:not-implemented" + ErrCondPreconditionFailed ErrCond = "amqp:precondition-failed" + ErrCondResourceDeleted ErrCond = "amqp:resource-deleted" + ErrCondResourceLimitExceeded ErrCond = "amqp:resource-limit-exceeded" + ErrCondResourceLocked ErrCond = "amqp:resource-locked" + ErrCondUnauthorizedAccess ErrCond = "amqp:unauthorized-access" + + // Connection Errors + ErrCondConnectionForced ErrCond = "amqp:connection:forced" + ErrCondConnectionRedirect ErrCond = "amqp:connection:redirect" + ErrCondFramingError ErrCond = "amqp:connection:framing-error" + + // Session Errors + ErrCondErrantLink ErrCond = "amqp:session:errant-link" + ErrCondHandleInUse ErrCond = "amqp:session:handle-in-use" + ErrCondUnattachedHandle ErrCond = "amqp:session:unattached-handle" + ErrCondWindowViolation ErrCond = "amqp:session:window-violation" + + // Link Errors + ErrCondDetachForced ErrCond = "amqp:link:detach-forced" + ErrCondLinkRedirect ErrCond = "amqp:link:redirect" + ErrCondMessageSizeExceeded ErrCond = "amqp:link:message-size-exceeded" + ErrCondStolen ErrCond = "amqp:link:stolen" + ErrCondTransferLimitExceeded ErrCond = "amqp:link:transfer-limit-exceeded" +) + +// Error is an AMQP error. +type Error = encoding.Error + +// LinkError is returned by methods on Sender/Receiver when the link has closed. 
+type LinkError struct { + // RemoteErr contains any error information provided by the peer if the peer detached the link. + RemoteErr *Error + + inner error +} + +// Error implements the error interface for LinkError. +func (e *LinkError) Error() string { + if e.RemoteErr == nil && e.inner == nil { + return "amqp: link closed" + } else if e.RemoteErr != nil { + return e.RemoteErr.Error() + } + return e.inner.Error() +} + +// ConnError is returned by methods on Conn and propagated to Session and Senders/Receivers +// when the connection has been closed. +type ConnError struct { + // RemoteErr contains any error information provided by the peer if the peer closed the AMQP connection. + RemoteErr *Error + + inner error +} + +// Error implements the error interface for ConnectionError. +func (e *ConnError) Error() string { + if e.RemoteErr == nil && e.inner == nil { + return "amqp: connection closed" + } else if e.RemoteErr != nil { + return e.RemoteErr.Error() + } + return e.inner.Error() +} + +// SessionError is returned by methods on Session and propagated to Senders/Receivers +// when the session has been closed. +type SessionError struct { + // RemoteErr contains any error information provided by the peer if the peer closed the session. + RemoteErr *Error + + inner error +} + +// Error implements the error interface for SessionError. 
+func (e *SessionError) Error() string {
+ if e.RemoteErr == nil && e.inner == nil {
+ return "amqp: session closed"
+ } else if e.RemoteErr != nil {
+ return e.RemoteErr.Error()
+ }
+ return e.inner.Error()
+} diff --git a/vendor/github.com/Azure/go-amqp/internal/bitmap/bitmap.go b/vendor/github.com/Azure/go-amqp/internal/bitmap/bitmap.go new file mode 100644 index 00000000000..ba04b008de8 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/bitmap/bitmap.go @@ -0,0 +1,96 @@ +package bitmap
+
+import (
+ "math/bits"
+)
+
+// bitmap is a lazily initialized bitmap
+type Bitmap struct {
+ max uint32
+ bits []uint64
+}
+
+func New(max uint32) *Bitmap {
+ return &Bitmap{max: max}
+}
+
+// add sets n in the bitmap.
+//
+// bits will be expanded as needed.
+//
+// If n is greater than max, the call has no effect.
+func (b *Bitmap) Add(n uint32) {
+ if n > b.max {
+ return
+ }
+
+ var (
+ idx = n / 64
+ offset = n % 64
+ )
+
+ if l := len(b.bits); int(idx) >= l {
+ b.bits = append(b.bits, make([]uint64, int(idx)-l+1)...)
+ }
+
+ b.bits[idx] |= 1 << offset
+}
+
+// remove clears n from the bitmap.
+//
+// If n is not set or greater than max the call has no effect.
+func (b *Bitmap) Remove(n uint32) {
+ var (
+ idx = n / 64
+ offset = n % 64
+ )
+
+ if int(idx) >= len(b.bits) {
+ return
+ }
+
+ b.bits[idx] &= ^uint64(1 << offset)
+}
+
+// next sets and returns the lowest unset bit in the bitmap.
+//
+// bits will be expanded if necessary.
+//
+// If there are no unset bits below max, the second return
+// value will be false.
+func (b *Bitmap) Next() (uint32, bool) { + // find the first unset bit + for i, v := range b.bits { + // skip if all bits are set + if v == ^uint64(0) { + continue + } + + var ( + offset = bits.TrailingZeros64(^v) // invert and count zeroes + next = uint32(i*64 + offset) + ) + + // check if in bounds + if next > b.max { + return next, false + } + + // set bit + b.bits[i] |= 1 << uint32(offset) + return next, true + } + + // no unset bits in the current slice, + // check if the full range has been allocated + if uint64(len(b.bits)*64) > uint64(b.max) { + return 0, false + } + + // full range not allocated, append entry with first + // bit set + b.bits = append(b.bits, 1) + + // return the value of the first bit + return uint32(len(b.bits)-1) * 64, true +} diff --git a/vendor/github.com/Azure/go-amqp/internal/buffer/buffer.go b/vendor/github.com/Azure/go-amqp/internal/buffer/buffer.go new file mode 100644 index 00000000000..b8009376f82 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/buffer/buffer.go @@ -0,0 +1,177 @@ +package buffer + +import ( + "encoding/binary" + "io" +) + +// buffer is similar to bytes.Buffer but specialized for this package +type Buffer struct { + b []byte + i int +} + +func New(b []byte) *Buffer { + return &Buffer{b: b} +} + +func (b *Buffer) Next(n int64) ([]byte, bool) { + if b.readCheck(n) { + buf := b.b[b.i:len(b.b)] + b.i = len(b.b) + return buf, false + } + + buf := b.b[b.i : b.i+int(n)] + b.i += int(n) + return buf, true +} + +func (b *Buffer) Skip(n int) { + b.i += n +} + +func (b *Buffer) Reset() { + b.b = b.b[:0] + b.i = 0 +} + +// reclaim shifts used buffer space to the beginning of the +// underlying slice. 
+func (b *Buffer) Reclaim() { + l := b.Len() + copy(b.b[:l], b.b[b.i:]) + b.b = b.b[:l] + b.i = 0 +} + +func (b *Buffer) readCheck(n int64) bool { + return int64(b.i)+n > int64(len(b.b)) +} + +func (b *Buffer) ReadByte() (byte, error) { + if b.readCheck(1) { + return 0, io.EOF + } + + byte_ := b.b[b.i] + b.i++ + return byte_, nil +} + +func (b *Buffer) PeekByte() (byte, error) { + if b.readCheck(1) { + return 0, io.EOF + } + + return b.b[b.i], nil +} + +func (b *Buffer) ReadUint16() (uint16, error) { + if b.readCheck(2) { + return 0, io.EOF + } + + n := binary.BigEndian.Uint16(b.b[b.i:]) + b.i += 2 + return n, nil +} + +func (b *Buffer) ReadUint32() (uint32, error) { + if b.readCheck(4) { + return 0, io.EOF + } + + n := binary.BigEndian.Uint32(b.b[b.i:]) + b.i += 4 + return n, nil +} + +func (b *Buffer) ReadUint64() (uint64, error) { + if b.readCheck(8) { + return 0, io.EOF + } + + n := binary.BigEndian.Uint64(b.b[b.i : b.i+8]) + b.i += 8 + return n, nil +} + +func (b *Buffer) ReadFromOnce(r io.Reader) error { + const minRead = 512 + + l := len(b.b) + if cap(b.b)-l < minRead { + total := l * 2 + if total == 0 { + total = minRead + } + new := make([]byte, l, total) + copy(new, b.b) + b.b = new + } + + n, err := r.Read(b.b[l:cap(b.b)]) + b.b = b.b[:l+n] + return err +} + +func (b *Buffer) Append(p []byte) { + b.b = append(b.b, p...) +} + +func (b *Buffer) AppendByte(bb byte) { + b.b = append(b.b, bb) +} + +func (b *Buffer) AppendString(s string) { + b.b = append(b.b, s...) 
+} + +func (b *Buffer) Len() int { + return len(b.b) - b.i +} + +func (b *Buffer) Size() int { + return b.i +} + +func (b *Buffer) Bytes() []byte { + return b.b[b.i:] +} + +func (b *Buffer) Detach() []byte { + temp := b.b + b.b = nil + b.i = 0 + return temp +} + +func (b *Buffer) AppendUint16(n uint16) { + b.b = append(b.b, + byte(n>>8), + byte(n), + ) +} + +func (b *Buffer) AppendUint32(n uint32) { + b.b = append(b.b, + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} + +func (b *Buffer) AppendUint64(n uint64) { + b.b = append(b.b, + byte(n>>56), + byte(n>>48), + byte(n>>40), + byte(n>>32), + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} diff --git a/vendor/github.com/Azure/go-amqp/internal/debug/debug.go b/vendor/github.com/Azure/go-amqp/internal/debug/debug.go new file mode 100644 index 00000000000..c25a0f0b061 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/debug/debug.go @@ -0,0 +1,17 @@ +//go:build !debug +// +build !debug + +package debug + +// dummy functions used when debugging is not enabled + +// Log writes the formatted string to stderr. +// Level indicates the verbosity of the messages to log. +// The greater the value, the more verbose messages will be logged. +func Log(_ int, _ string, _ ...any) {} + +// Assert panics if the specified condition is false. +func Assert(bool) {} + +// Assert panics with the provided message if the specified condition is false. 
+func Assertf(bool, string, ...any) {} diff --git a/vendor/github.com/Azure/go-amqp/internal/debug/debug_debug.go b/vendor/github.com/Azure/go-amqp/internal/debug/debug_debug.go new file mode 100644 index 00000000000..fb20f4121a5 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/debug/debug_debug.go @@ -0,0 +1,48 @@ +//go:build debug +// +build debug + +package debug + +import ( + "fmt" + "log" + "os" + "strconv" +) + +var ( + debugLevel = 1 + logger = log.New(os.Stderr, "", log.Lmicroseconds) +) + +func init() { + level, err := strconv.Atoi(os.Getenv("DEBUG_LEVEL")) + if err != nil { + return + } + + debugLevel = level +} + +// Log writes the formatted string to stderr. +// Level indicates the verbosity of the messages to log. +// The greater the value, the more verbose messages will be logged. +func Log(level int, format string, v ...any) { + if level <= debugLevel { + logger.Printf(format, v...) + } +} + +// Assert panics if the specified condition is false. +func Assert(condition bool) { + if !condition { + panic("assertion failed!") + } +} + +// Assert panics with the provided message if the specified condition is false. +func Assertf(condition bool, msg string, v ...any) { + if !condition { + panic(fmt.Sprintf(msg, v...)) + } +} diff --git a/vendor/github.com/Azure/go-amqp/internal/encoding/decode.go b/vendor/github.com/Azure/go-amqp/internal/encoding/decode.go new file mode 100644 index 00000000000..cdcae726a06 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/encoding/decode.go @@ -0,0 +1,1149 @@ +// Copyright (C) 2017 Kale Blankenship +// Portions Copyright (c) Microsoft Corporation +package encoding + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "reflect" + "time" + + "github.com/Azure/go-amqp/internal/buffer" +) + +// unmarshaler is fulfilled by types that can unmarshal +// themselves from AMQP data. +type unmarshaler interface { + Unmarshal(r *buffer.Buffer) error +} + +// unmarshal decodes AMQP encoded data into i. 
+// +// The decoding method is based on the type of i. +// +// If i implements unmarshaler, i.Unmarshal() will be called. +// +// Pointers to primitive types will be decoded via the appropriate read[Type] function. +// +// If i is a pointer to a pointer (**Type), it will be dereferenced and a new instance +// of (*Type) is allocated via reflection. +// +// Common map types (map[string]string, map[Symbol]any, and +// map[any]any), will be decoded via conversion to the mapStringAny, +// mapSymbolAny, and mapAnyAny types. +func Unmarshal(r *buffer.Buffer, i any) error { + if tryReadNull(r) { + return nil + } + + switch t := i.(type) { + case *int: + val, err := readInt(r) + if err != nil { + return err + } + *t = val + case *int8: + val, err := readSbyte(r) + if err != nil { + return err + } + *t = val + case *int16: + val, err := readShort(r) + if err != nil { + return err + } + *t = val + case *int32: + val, err := readInt32(r) + if err != nil { + return err + } + *t = val + case *int64: + val, err := readLong(r) + if err != nil { + return err + } + *t = val + case *uint64: + val, err := readUlong(r) + if err != nil { + return err + } + *t = val + case *uint32: + val, err := readUint32(r) + if err != nil { + return err + } + *t = val + case **uint32: // fastpath for uint32 pointer fields + val, err := readUint32(r) + if err != nil { + return err + } + *t = &val + case *uint16: + val, err := readUshort(r) + if err != nil { + return err + } + *t = val + case *uint8: + val, err := ReadUbyte(r) + if err != nil { + return err + } + *t = val + case *float32: + val, err := readFloat(r) + if err != nil { + return err + } + *t = val + case *float64: + val, err := readDouble(r) + if err != nil { + return err + } + *t = val + case *string: + val, err := ReadString(r) + if err != nil { + return err + } + *t = val + case *Symbol: + s, err := ReadString(r) + if err != nil { + return err + } + *t = Symbol(s) + case *[]byte: + val, err := readBinary(r) + if err != nil { + return 
err + } + *t = val + case *bool: + b, err := readBool(r) + if err != nil { + return err + } + *t = b + case *time.Time: + ts, err := readTimestamp(r) + if err != nil { + return err + } + *t = ts + case *[]int8: + return (*arrayInt8)(t).Unmarshal(r) + case *[]uint16: + return (*arrayUint16)(t).Unmarshal(r) + case *[]int16: + return (*arrayInt16)(t).Unmarshal(r) + case *[]uint32: + return (*arrayUint32)(t).Unmarshal(r) + case *[]int32: + return (*arrayInt32)(t).Unmarshal(r) + case *[]uint64: + return (*arrayUint64)(t).Unmarshal(r) + case *[]int64: + return (*arrayInt64)(t).Unmarshal(r) + case *[]float32: + return (*arrayFloat)(t).Unmarshal(r) + case *[]float64: + return (*arrayDouble)(t).Unmarshal(r) + case *[]bool: + return (*arrayBool)(t).Unmarshal(r) + case *[]string: + return (*arrayString)(t).Unmarshal(r) + case *[]Symbol: + return (*arraySymbol)(t).Unmarshal(r) + case *[][]byte: + return (*arrayBinary)(t).Unmarshal(r) + case *[]time.Time: + return (*arrayTimestamp)(t).Unmarshal(r) + case *[]UUID: + return (*arrayUUID)(t).Unmarshal(r) + case *[]any: + return (*list)(t).Unmarshal(r) + case *map[any]any: + return (*mapAnyAny)(t).Unmarshal(r) + case *map[string]any: + return (*mapStringAny)(t).Unmarshal(r) + case *map[Symbol]any: + return (*mapSymbolAny)(t).Unmarshal(r) + case *DeliveryState: + type_, _, err := PeekMessageType(r.Bytes()) + if err != nil { + return err + } + + switch AMQPType(type_) { + case TypeCodeStateAccepted: + *t = new(StateAccepted) + case TypeCodeStateModified: + *t = new(StateModified) + case TypeCodeStateReceived: + *t = new(StateReceived) + case TypeCodeStateRejected: + *t = new(StateRejected) + case TypeCodeStateReleased: + *t = new(StateReleased) + default: + return fmt.Errorf("unexpected type %d for deliveryState", type_) + } + return Unmarshal(r, *t) + + case *any: + v, err := ReadAny(r) + if err != nil { + return err + } + *t = v + + case unmarshaler: + return t.Unmarshal(r) + default: + // handle **T + v := 
reflect.Indirect(reflect.ValueOf(i)) + + // can't unmarshal into a non-pointer + if v.Kind() != reflect.Ptr { + return fmt.Errorf("unable to unmarshal %T", i) + } + + // if nil pointer, allocate a new value to + // unmarshal into + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + return Unmarshal(r, v.Interface()) + } + return nil +} + +// unmarshalComposite is a helper for use in a composite's unmarshal() function. +// +// The composite from r will be unmarshaled into zero or more fields. An error +// will be returned if typ does not match the decoded type. +func UnmarshalComposite(r *buffer.Buffer, type_ AMQPType, fields ...UnmarshalField) error { + cType, numFields, err := readCompositeHeader(r) + if err != nil { + return err + } + + // check type matches expectation + if cType != type_ { + return fmt.Errorf("invalid header %#0x for %#0x", cType, type_) + } + + // Validate the field count is less than or equal to the number of fields + // provided. Fields may be omitted by the sender if they are not set. + if numFields > int64(len(fields)) { + return fmt.Errorf("invalid field count %d for %#0x", numFields, type_) + } + + for i, field := range fields[:numFields] { + // If the field is null and handleNull is set, call it. + if tryReadNull(r) { + if field.HandleNull != nil { + err = field.HandleNull() + if err != nil { + return err + } + } + continue + } + + // Unmarshal each of the received fields. + err = Unmarshal(r, field.Field) + if err != nil { + return fmt.Errorf("unmarshaling field %d: %v", i, err) + } + } + + // check and call handleNull for the remaining fields + for _, field := range fields[numFields:] { + if field.HandleNull != nil { + err = field.HandleNull() + if err != nil { + return err + } + } + } + + return nil +} + +// unmarshalField is a struct that contains a field to be unmarshaled into. +// +// An optional nullHandler can be set. 
If the composite field being unmarshaled +// is null and handleNull is not nil, nullHandler will be called. +type UnmarshalField struct { + Field any + HandleNull NullHandler +} + +// nullHandler is a function to be called when a composite's field +// is null. +type NullHandler func() error + +func readType(r *buffer.Buffer) (AMQPType, error) { + n, err := r.ReadByte() + return AMQPType(n), err +} + +func peekType(r *buffer.Buffer) (AMQPType, error) { + n, err := r.PeekByte() + return AMQPType(n), err +} + +// readCompositeHeader reads and consumes the composite header from r. +func readCompositeHeader(r *buffer.Buffer) (_ AMQPType, fields int64, _ error) { + type_, err := readType(r) + if err != nil { + return 0, 0, err + } + + // compsites always start with 0x0 + if type_ != 0 { + return 0, 0, fmt.Errorf("invalid composite header %#02x", type_) + } + + // next, the composite type is encoded as an AMQP uint8 + v, err := readUlong(r) + if err != nil { + return 0, 0, err + } + + // fields are represented as a list + fields, err = readListHeader(r) + + return AMQPType(v), fields, err +} + +func readListHeader(r *buffer.Buffer) (length int64, _ error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + listLength := r.Len() + + switch type_ { + case TypeCodeList0: + return 0, nil + case TypeCodeList8: + buf, ok := r.Next(2) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[1] + + size := int(buf[0]) + if size > listLength-1 { + return 0, errors.New("invalid length") + } + length = int64(buf[1]) + case TypeCodeList32: + buf, ok := r.Next(8) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[7] + + size := int(binary.BigEndian.Uint32(buf[:4])) + if size > listLength-4 { + return 0, errors.New("invalid length") + } + length = int64(binary.BigEndian.Uint32(buf[4:8])) + default: + return 0, fmt.Errorf("type code %#02x is not a recognized list type", type_) + } + + return length, nil +} + +func readArrayHeader(r 
*buffer.Buffer) (length int64, _ error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + arrayLength := r.Len() + + switch type_ { + case TypeCodeArray8: + buf, ok := r.Next(2) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[1] + + size := int(buf[0]) + if size > arrayLength-1 { + return 0, errors.New("invalid length") + } + length = int64(buf[1]) + case TypeCodeArray32: + buf, ok := r.Next(8) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[7] + + size := binary.BigEndian.Uint32(buf[:4]) + if int(size) > arrayLength-4 { + return 0, fmt.Errorf("invalid length for type %02x", type_) + } + length = int64(binary.BigEndian.Uint32(buf[4:8])) + default: + return 0, fmt.Errorf("type code %#02x is not a recognized array type", type_) + } + return length, nil +} + +func ReadString(r *buffer.Buffer) (string, error) { + type_, err := readType(r) + if err != nil { + return "", err + } + + var length int64 + switch type_ { + case TypeCodeStr8, TypeCodeSym8: + n, err := r.ReadByte() + if err != nil { + return "", err + } + length = int64(n) + case TypeCodeStr32, TypeCodeSym32: + buf, ok := r.Next(4) + if !ok { + return "", fmt.Errorf("invalid length for type %#02x", type_) + } + length = int64(binary.BigEndian.Uint32(buf)) + default: + return "", fmt.Errorf("type code %#02x is not a recognized string type", type_) + } + + buf, ok := r.Next(length) + if !ok { + return "", errors.New("invalid length") + } + return string(buf), nil +} + +func readBinary(r *buffer.Buffer) ([]byte, error) { + type_, err := readType(r) + if err != nil { + return nil, err + } + + var length int64 + switch type_ { + case TypeCodeVbin8: + n, err := r.ReadByte() + if err != nil { + return nil, err + } + length = int64(n) + case TypeCodeVbin32: + buf, ok := r.Next(4) + if !ok { + return nil, fmt.Errorf("invalid length for type %#02x", type_) + } + length = int64(binary.BigEndian.Uint32(buf)) + default: + return nil, fmt.Errorf("type code %#02x is 
not a recognized binary type", type_) + } + + if length == 0 { + // An empty value and a nil value are distinct, + // ensure that the returned value is not nil in this case. + return make([]byte, 0), nil + } + + buf, ok := r.Next(length) + if !ok { + return nil, errors.New("invalid length") + } + return append([]byte(nil), buf...), nil +} + +func ReadAny(r *buffer.Buffer) (any, error) { + if tryReadNull(r) { + return nil, nil + } + + type_, err := peekType(r) + if err != nil { + return nil, errors.New("invalid length") + } + + switch type_ { + // composite + case 0x0: + return readComposite(r) + + // bool + case TypeCodeBool, TypeCodeBoolTrue, TypeCodeBoolFalse: + return readBool(r) + + // uint + case TypeCodeUbyte: + return ReadUbyte(r) + case TypeCodeUshort: + return readUshort(r) + case TypeCodeUint, + TypeCodeSmallUint, + TypeCodeUint0: + return readUint32(r) + case TypeCodeUlong, + TypeCodeSmallUlong, + TypeCodeUlong0: + return readUlong(r) + + // int + case TypeCodeByte: + return readSbyte(r) + case TypeCodeShort: + return readShort(r) + case TypeCodeInt, + TypeCodeSmallint: + return readInt32(r) + case TypeCodeLong, + TypeCodeSmalllong: + return readLong(r) + + // floating point + case TypeCodeFloat: + return readFloat(r) + case TypeCodeDouble: + return readDouble(r) + + // binary + case TypeCodeVbin8, TypeCodeVbin32: + return readBinary(r) + + // strings + case TypeCodeStr8, TypeCodeStr32: + return ReadString(r) + case TypeCodeSym8, TypeCodeSym32: + // symbols currently decoded as string to avoid + // exposing symbol type in message, this may need + // to change if users need to distinguish strings + // from symbols + return ReadString(r) + + // timestamp + case TypeCodeTimestamp: + return readTimestamp(r) + + // UUID + case TypeCodeUUID: + return readUUID(r) + + // arrays + case TypeCodeArray8, TypeCodeArray32: + return readAnyArray(r) + + // lists + case TypeCodeList0, TypeCodeList8, TypeCodeList32: + return readAnyList(r) + + // maps + case TypeCodeMap8: 
+ return readAnyMap(r) + case TypeCodeMap32: + return readAnyMap(r) + + // TODO: implement + case TypeCodeDecimal32: + return nil, errors.New("decimal32 not implemented") + case TypeCodeDecimal64: + return nil, errors.New("decimal64 not implemented") + case TypeCodeDecimal128: + return nil, errors.New("decimal128 not implemented") + case TypeCodeChar: + return nil, errors.New("char not implemented") + default: + return nil, fmt.Errorf("unknown type %#02x", type_) + } +} + +func readAnyMap(r *buffer.Buffer) (any, error) { + var m map[any]any + err := (*mapAnyAny)(&m).Unmarshal(r) + if err != nil { + return nil, err + } + + if len(m) == 0 { + return m, nil + } + + stringKeys := true +Loop: + for key := range m { + switch key.(type) { + case string: + case Symbol: + default: + stringKeys = false + break Loop + } + } + + if stringKeys { + mm := make(map[string]any, len(m)) + for key, value := range m { + switch key := key.(type) { + case string: + mm[key] = value + case Symbol: + mm[string(key)] = value + } + } + return mm, nil + } + + return m, nil +} + +func readAnyList(r *buffer.Buffer) (any, error) { + var a []any + err := (*list)(&a).Unmarshal(r) + return a, err +} + +func readAnyArray(r *buffer.Buffer) (any, error) { + // get the array type + buf := r.Bytes() + if len(buf) < 1 { + return nil, errors.New("invalid length") + } + + var typeIdx int + switch AMQPType(buf[0]) { + case TypeCodeArray8: + typeIdx = 3 + case TypeCodeArray32: + typeIdx = 9 + default: + return nil, fmt.Errorf("invalid array type %02x", buf[0]) + } + if len(buf) < typeIdx+1 { + return nil, errors.New("invalid length") + } + + switch AMQPType(buf[typeIdx]) { + case TypeCodeByte: + var a []int8 + err := (*arrayInt8)(&a).Unmarshal(r) + return a, err + case TypeCodeUbyte: + var a ArrayUByte + err := a.Unmarshal(r) + return a, err + case TypeCodeUshort: + var a []uint16 + err := (*arrayUint16)(&a).Unmarshal(r) + return a, err + case TypeCodeShort: + var a []int16 + err := 
(*arrayInt16)(&a).Unmarshal(r) + return a, err + case TypeCodeUint0, TypeCodeSmallUint, TypeCodeUint: + var a []uint32 + err := (*arrayUint32)(&a).Unmarshal(r) + return a, err + case TypeCodeSmallint, TypeCodeInt: + var a []int32 + err := (*arrayInt32)(&a).Unmarshal(r) + return a, err + case TypeCodeUlong0, TypeCodeSmallUlong, TypeCodeUlong: + var a []uint64 + err := (*arrayUint64)(&a).Unmarshal(r) + return a, err + case TypeCodeSmalllong, TypeCodeLong: + var a []int64 + err := (*arrayInt64)(&a).Unmarshal(r) + return a, err + case TypeCodeFloat: + var a []float32 + err := (*arrayFloat)(&a).Unmarshal(r) + return a, err + case TypeCodeDouble: + var a []float64 + err := (*arrayDouble)(&a).Unmarshal(r) + return a, err + case TypeCodeBool, TypeCodeBoolTrue, TypeCodeBoolFalse: + var a []bool + err := (*arrayBool)(&a).Unmarshal(r) + return a, err + case TypeCodeStr8, TypeCodeStr32: + var a []string + err := (*arrayString)(&a).Unmarshal(r) + return a, err + case TypeCodeSym8, TypeCodeSym32: + var a []Symbol + err := (*arraySymbol)(&a).Unmarshal(r) + return a, err + case TypeCodeVbin8, TypeCodeVbin32: + var a [][]byte + err := (*arrayBinary)(&a).Unmarshal(r) + return a, err + case TypeCodeTimestamp: + var a []time.Time + err := (*arrayTimestamp)(&a).Unmarshal(r) + return a, err + case TypeCodeUUID: + var a []UUID + err := (*arrayUUID)(&a).Unmarshal(r) + return a, err + default: + return nil, fmt.Errorf("array decoding not implemented for %#02x", buf[typeIdx]) + } +} + +func readComposite(r *buffer.Buffer) (any, error) { + buf := r.Bytes() + + if len(buf) < 2 { + return nil, errors.New("invalid length for composite") + } + + // compsites start with 0x0 + if AMQPType(buf[0]) != 0x0 { + return nil, fmt.Errorf("invalid composite header %#02x", buf[0]) + } + + var compositeType uint64 + switch AMQPType(buf[1]) { + case TypeCodeSmallUlong: + if len(buf) < 3 { + return nil, errors.New("invalid length for smallulong") + } + compositeType = uint64(buf[2]) + case TypeCodeUlong: + if 
len(buf) < 10 { + return nil, errors.New("invalid length for ulong") + } + compositeType = binary.BigEndian.Uint64(buf[2:]) + } + + if compositeType > math.MaxUint8 { + // try as described type + var dt DescribedType + err := dt.Unmarshal(r) + return dt, err + } + + switch AMQPType(compositeType) { + // Error + case TypeCodeError: + t := new(Error) + err := t.Unmarshal(r) + return t, err + + // Lifetime Policies + case TypeCodeDeleteOnClose: + t := DeleteOnClose + err := t.Unmarshal(r) + return t, err + case TypeCodeDeleteOnNoMessages: + t := DeleteOnNoMessages + err := t.Unmarshal(r) + return t, err + case TypeCodeDeleteOnNoLinks: + t := DeleteOnNoLinks + err := t.Unmarshal(r) + return t, err + case TypeCodeDeleteOnNoLinksOrMessages: + t := DeleteOnNoLinksOrMessages + err := t.Unmarshal(r) + return t, err + + // Delivery States + case TypeCodeStateAccepted: + t := new(StateAccepted) + err := t.Unmarshal(r) + return t, err + case TypeCodeStateModified: + t := new(StateModified) + err := t.Unmarshal(r) + return t, err + case TypeCodeStateReceived: + t := new(StateReceived) + err := t.Unmarshal(r) + return t, err + case TypeCodeStateRejected: + t := new(StateRejected) + err := t.Unmarshal(r) + return t, err + case TypeCodeStateReleased: + t := new(StateReleased) + err := t.Unmarshal(r) + return t, err + + case TypeCodeOpen, + TypeCodeBegin, + TypeCodeAttach, + TypeCodeFlow, + TypeCodeTransfer, + TypeCodeDisposition, + TypeCodeDetach, + TypeCodeEnd, + TypeCodeClose, + TypeCodeSource, + TypeCodeTarget, + TypeCodeMessageHeader, + TypeCodeDeliveryAnnotations, + TypeCodeMessageAnnotations, + TypeCodeMessageProperties, + TypeCodeApplicationProperties, + TypeCodeApplicationData, + TypeCodeAMQPSequence, + TypeCodeAMQPValue, + TypeCodeFooter, + TypeCodeSASLMechanism, + TypeCodeSASLInit, + TypeCodeSASLChallenge, + TypeCodeSASLResponse, + TypeCodeSASLOutcome: + return nil, fmt.Errorf("readComposite unmarshal not implemented for %#02x", compositeType) + + default: + // try as 
described type + var dt DescribedType + err := dt.Unmarshal(r) + return dt, err + } +} + +func readTimestamp(r *buffer.Buffer) (time.Time, error) { + type_, err := readType(r) + if err != nil { + return time.Time{}, err + } + + if type_ != TypeCodeTimestamp { + return time.Time{}, fmt.Errorf("invalid type for timestamp %02x", type_) + } + + n, err := r.ReadUint64() + ms := int64(n) + return time.Unix(ms/1000, (ms%1000)*1000000).UTC(), err +} + +func readInt(r *buffer.Buffer) (int, error) { + type_, err := peekType(r) + if err != nil { + return 0, err + } + + switch type_ { + // Unsigned + case TypeCodeUbyte: + n, err := ReadUbyte(r) + return int(n), err + case TypeCodeUshort: + n, err := readUshort(r) + return int(n), err + case TypeCodeUint0, TypeCodeSmallUint, TypeCodeUint: + n, err := readUint32(r) + return int(n), err + case TypeCodeUlong0, TypeCodeSmallUlong, TypeCodeUlong: + n, err := readUlong(r) + return int(n), err + + // Signed + case TypeCodeByte: + n, err := readSbyte(r) + return int(n), err + case TypeCodeShort: + n, err := readShort(r) + return int(n), err + case TypeCodeSmallint, TypeCodeInt: + n, err := readInt32(r) + return int(n), err + case TypeCodeSmalllong, TypeCodeLong: + n, err := readLong(r) + return int(n), err + default: + return 0, fmt.Errorf("type code %#02x is not a recognized number type", type_) + } +} + +func readLong(r *buffer.Buffer) (int64, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + switch type_ { + case TypeCodeSmalllong: + n, err := r.ReadByte() + return int64(int8(n)), err + case TypeCodeLong: + n, err := r.ReadUint64() + return int64(n), err + default: + return 0, fmt.Errorf("invalid type for uint32 %02x", type_) + } +} + +func readInt32(r *buffer.Buffer) (int32, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + switch type_ { + case TypeCodeSmallint: + n, err := r.ReadByte() + return int32(int8(n)), err + case TypeCodeInt: + n, err := r.ReadUint32() + return 
int32(n), err + default: + return 0, fmt.Errorf("invalid type for int32 %02x", type_) + } +} + +func readShort(r *buffer.Buffer) (int16, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + if type_ != TypeCodeShort { + return 0, fmt.Errorf("invalid type for short %02x", type_) + } + + n, err := r.ReadUint16() + return int16(n), err +} + +func readSbyte(r *buffer.Buffer) (int8, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + if type_ != TypeCodeByte { + return 0, fmt.Errorf("invalid type for int8 %02x", type_) + } + + n, err := r.ReadByte() + return int8(n), err +} + +func ReadUbyte(r *buffer.Buffer) (uint8, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + if type_ != TypeCodeUbyte { + return 0, fmt.Errorf("invalid type for ubyte %02x", type_) + } + + return r.ReadByte() +} + +func readUshort(r *buffer.Buffer) (uint16, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + if type_ != TypeCodeUshort { + return 0, fmt.Errorf("invalid type for ushort %02x", type_) + } + + return r.ReadUint16() +} + +func readUint32(r *buffer.Buffer) (uint32, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + switch type_ { + case TypeCodeUint0: + return 0, nil + case TypeCodeSmallUint: + n, err := r.ReadByte() + return uint32(n), err + case TypeCodeUint: + return r.ReadUint32() + default: + return 0, fmt.Errorf("invalid type for uint32 %02x", type_) + } +} + +func readUlong(r *buffer.Buffer) (uint64, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + switch type_ { + case TypeCodeUlong0: + return 0, nil + case TypeCodeSmallUlong: + n, err := r.ReadByte() + return uint64(n), err + case TypeCodeUlong: + return r.ReadUint64() + default: + return 0, fmt.Errorf("invalid type for uint32 %02x", type_) + } +} + +func readFloat(r *buffer.Buffer) (float32, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } 
+ + if type_ != TypeCodeFloat { + return 0, fmt.Errorf("invalid type for float32 %02x", type_) + } + + bits, err := r.ReadUint32() + return math.Float32frombits(bits), err +} + +func readDouble(r *buffer.Buffer) (float64, error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + if type_ != TypeCodeDouble { + return 0, fmt.Errorf("invalid type for float64 %02x", type_) + } + + bits, err := r.ReadUint64() + return math.Float64frombits(bits), err +} + +func readBool(r *buffer.Buffer) (bool, error) { + type_, err := readType(r) + if err != nil { + return false, err + } + + switch type_ { + case TypeCodeBool: + b, err := r.ReadByte() + return b != 0, err + case TypeCodeBoolTrue: + return true, nil + case TypeCodeBoolFalse: + return false, nil + default: + return false, fmt.Errorf("type code %#02x is not a recognized bool type", type_) + } +} + +func readUint(r *buffer.Buffer) (value uint64, _ error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + switch type_ { + case TypeCodeUint0, TypeCodeUlong0: + return 0, nil + case TypeCodeUbyte, TypeCodeSmallUint, TypeCodeSmallUlong: + n, err := r.ReadByte() + return uint64(n), err + case TypeCodeUshort: + n, err := r.ReadUint16() + return uint64(n), err + case TypeCodeUint: + n, err := r.ReadUint32() + return uint64(n), err + case TypeCodeUlong: + return r.ReadUint64() + default: + return 0, fmt.Errorf("type code %#02x is not a recognized number type", type_) + } +} + +func readUUID(r *buffer.Buffer) (UUID, error) { + var uuid UUID + + type_, err := readType(r) + if err != nil { + return uuid, err + } + + if type_ != TypeCodeUUID { + return uuid, fmt.Errorf("type code %#00x is not a UUID", type_) + } + + buf, ok := r.Next(16) + if !ok { + return uuid, errors.New("invalid length") + } + copy(uuid[:], buf) + + return uuid, nil +} + +func readMapHeader(r *buffer.Buffer) (count uint32, _ error) { + type_, err := readType(r) + if err != nil { + return 0, err + } + + length := r.Len() + + 
switch type_ { + case TypeCodeMap8: + buf, ok := r.Next(2) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[1] + + size := int(buf[0]) + if size > length-1 { + return 0, errors.New("invalid length") + } + count = uint32(buf[1]) + case TypeCodeMap32: + buf, ok := r.Next(8) + if !ok { + return 0, errors.New("invalid length") + } + _ = buf[7] + + size := int(binary.BigEndian.Uint32(buf[:4])) + if size > length-4 { + return 0, errors.New("invalid length") + } + count = binary.BigEndian.Uint32(buf[4:8]) + default: + return 0, fmt.Errorf("invalid map type %#02x", type_) + } + + if int(count) > r.Len() { + return 0, errors.New("invalid length") + } + return count, nil +} diff --git a/vendor/github.com/Azure/go-amqp/internal/encoding/encode.go b/vendor/github.com/Azure/go-amqp/internal/encoding/encode.go new file mode 100644 index 00000000000..898efe6a61a --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/encoding/encode.go @@ -0,0 +1,570 @@ +package encoding + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "time" + "unicode/utf8" + + "github.com/Azure/go-amqp/internal/buffer" +) + +type marshaler interface { + Marshal(*buffer.Buffer) error +} + +func Marshal(wr *buffer.Buffer, i any) error { + switch t := i.(type) { + case nil: + wr.AppendByte(byte(TypeCodeNull)) + case bool: + if t { + wr.AppendByte(byte(TypeCodeBoolTrue)) + } else { + wr.AppendByte(byte(TypeCodeBoolFalse)) + } + case *bool: + if *t { + wr.AppendByte(byte(TypeCodeBoolTrue)) + } else { + wr.AppendByte(byte(TypeCodeBoolFalse)) + } + case uint: + writeUint64(wr, uint64(t)) + case *uint: + writeUint64(wr, uint64(*t)) + case uint64: + writeUint64(wr, t) + case *uint64: + writeUint64(wr, *t) + case uint32: + writeUint32(wr, t) + case *uint32: + writeUint32(wr, *t) + case uint16: + wr.AppendByte(byte(TypeCodeUshort)) + wr.AppendUint16(t) + case *uint16: + wr.AppendByte(byte(TypeCodeUshort)) + wr.AppendUint16(*t) + case uint8: + wr.Append([]byte{ + byte(TypeCodeUbyte), + 
t, + }) + case *uint8: + wr.Append([]byte{ + byte(TypeCodeUbyte), + *t, + }) + case int: + writeInt64(wr, int64(t)) + case *int: + writeInt64(wr, int64(*t)) + case int8: + wr.Append([]byte{ + byte(TypeCodeByte), + uint8(t), + }) + case *int8: + wr.Append([]byte{ + byte(TypeCodeByte), + uint8(*t), + }) + case int16: + wr.AppendByte(byte(TypeCodeShort)) + wr.AppendUint16(uint16(t)) + case *int16: + wr.AppendByte(byte(TypeCodeShort)) + wr.AppendUint16(uint16(*t)) + case int32: + writeInt32(wr, t) + case *int32: + writeInt32(wr, *t) + case int64: + writeInt64(wr, t) + case *int64: + writeInt64(wr, *t) + case float32: + writeFloat(wr, t) + case *float32: + writeFloat(wr, *t) + case float64: + writeDouble(wr, t) + case *float64: + writeDouble(wr, *t) + case string: + return writeString(wr, t) + case *string: + return writeString(wr, *t) + case []byte: + return WriteBinary(wr, t) + case *[]byte: + return WriteBinary(wr, *t) + case map[any]any: + return writeMap(wr, t) + case *map[any]any: + return writeMap(wr, *t) + case map[string]any: + return writeMap(wr, t) + case *map[string]any: + return writeMap(wr, *t) + case map[Symbol]any: + return writeMap(wr, t) + case *map[Symbol]any: + return writeMap(wr, *t) + case Unsettled: + return writeMap(wr, t) + case *Unsettled: + return writeMap(wr, *t) + case time.Time: + writeTimestamp(wr, t) + case *time.Time: + writeTimestamp(wr, *t) + case []int8: + return arrayInt8(t).Marshal(wr) + case *[]int8: + return arrayInt8(*t).Marshal(wr) + case []uint16: + return arrayUint16(t).Marshal(wr) + case *[]uint16: + return arrayUint16(*t).Marshal(wr) + case []int16: + return arrayInt16(t).Marshal(wr) + case *[]int16: + return arrayInt16(*t).Marshal(wr) + case []uint32: + return arrayUint32(t).Marshal(wr) + case *[]uint32: + return arrayUint32(*t).Marshal(wr) + case []int32: + return arrayInt32(t).Marshal(wr) + case *[]int32: + return arrayInt32(*t).Marshal(wr) + case []uint64: + return arrayUint64(t).Marshal(wr) + case *[]uint64: + return 
arrayUint64(*t).Marshal(wr) + case []int64: + return arrayInt64(t).Marshal(wr) + case *[]int64: + return arrayInt64(*t).Marshal(wr) + case []float32: + return arrayFloat(t).Marshal(wr) + case *[]float32: + return arrayFloat(*t).Marshal(wr) + case []float64: + return arrayDouble(t).Marshal(wr) + case *[]float64: + return arrayDouble(*t).Marshal(wr) + case []bool: + return arrayBool(t).Marshal(wr) + case *[]bool: + return arrayBool(*t).Marshal(wr) + case []string: + return arrayString(t).Marshal(wr) + case *[]string: + return arrayString(*t).Marshal(wr) + case []Symbol: + return arraySymbol(t).Marshal(wr) + case *[]Symbol: + return arraySymbol(*t).Marshal(wr) + case [][]byte: + return arrayBinary(t).Marshal(wr) + case *[][]byte: + return arrayBinary(*t).Marshal(wr) + case []time.Time: + return arrayTimestamp(t).Marshal(wr) + case *[]time.Time: + return arrayTimestamp(*t).Marshal(wr) + case []UUID: + return arrayUUID(t).Marshal(wr) + case *[]UUID: + return arrayUUID(*t).Marshal(wr) + case []any: + return list(t).Marshal(wr) + case *[]any: + return list(*t).Marshal(wr) + case marshaler: + return t.Marshal(wr) + default: + return fmt.Errorf("marshal not implemented for %T", i) + } + return nil +} + +func writeInt32(wr *buffer.Buffer, n int32) { + if n < 128 && n >= -128 { + wr.Append([]byte{ + byte(TypeCodeSmallint), + byte(n), + }) + return + } + + wr.AppendByte(byte(TypeCodeInt)) + wr.AppendUint32(uint32(n)) +} + +func writeInt64(wr *buffer.Buffer, n int64) { + if n < 128 && n >= -128 { + wr.Append([]byte{ + byte(TypeCodeSmalllong), + byte(n), + }) + return + } + + wr.AppendByte(byte(TypeCodeLong)) + wr.AppendUint64(uint64(n)) +} + +func writeUint32(wr *buffer.Buffer, n uint32) { + if n == 0 { + wr.AppendByte(byte(TypeCodeUint0)) + return + } + + if n < 256 { + wr.Append([]byte{ + byte(TypeCodeSmallUint), + byte(n), + }) + return + } + + wr.AppendByte(byte(TypeCodeUint)) + wr.AppendUint32(n) +} + +func writeUint64(wr *buffer.Buffer, n uint64) { + if n == 0 { + 
wr.AppendByte(byte(TypeCodeUlong0)) + return + } + + if n < 256 { + wr.Append([]byte{ + byte(TypeCodeSmallUlong), + byte(n), + }) + return + } + + wr.AppendByte(byte(TypeCodeUlong)) + wr.AppendUint64(n) +} + +func writeFloat(wr *buffer.Buffer, f float32) { + wr.AppendByte(byte(TypeCodeFloat)) + wr.AppendUint32(math.Float32bits(f)) +} + +func writeDouble(wr *buffer.Buffer, f float64) { + wr.AppendByte(byte(TypeCodeDouble)) + wr.AppendUint64(math.Float64bits(f)) +} + +func writeTimestamp(wr *buffer.Buffer, t time.Time) { + wr.AppendByte(byte(TypeCodeTimestamp)) + ms := t.UnixNano() / int64(time.Millisecond) + wr.AppendUint64(uint64(ms)) +} + +// marshalField is a field to be marshaled +type MarshalField struct { + Value any // value to be marshaled, use pointers to avoid interface conversion overhead + Omit bool // indicates that this field should be omitted (set to null) +} + +// marshalComposite is a helper for us in a composite's marshal() function. +// +// The returned bytes include the composite header and fields. Fields with +// omit set to true will be encoded as null or omitted altogether if there are +// no non-null fields after them. +func MarshalComposite(wr *buffer.Buffer, code AMQPType, fields []MarshalField) error { + // lastSetIdx is the last index to have a non-omitted field. + // start at -1 as it's possible to have no fields in a composite + lastSetIdx := -1 + + // marshal each field into it's index in rawFields, + // null fields are skipped, leaving the index nil. 
+ for i, f := range fields { + if f.Omit { + continue + } + lastSetIdx = i + } + + // write header only + if lastSetIdx == -1 { + wr.Append([]byte{ + 0x0, + byte(TypeCodeSmallUlong), + byte(code), + byte(TypeCodeList0), + }) + return nil + } + + // write header + WriteDescriptor(wr, code) + + // write fields + wr.AppendByte(byte(TypeCodeList32)) + + // write temp size, replace later + sizeIdx := wr.Len() + wr.Append([]byte{0, 0, 0, 0}) + preFieldLen := wr.Len() + + // field count + wr.AppendUint32(uint32(lastSetIdx + 1)) + + // write null to each index up to lastSetIdx + for _, f := range fields[:lastSetIdx+1] { + if f.Omit { + wr.AppendByte(byte(TypeCodeNull)) + continue + } + err := Marshal(wr, f.Value) + if err != nil { + return err + } + } + + // fix size + size := uint32(wr.Len() - preFieldLen) + buf := wr.Bytes() + binary.BigEndian.PutUint32(buf[sizeIdx:], size) + + return nil +} + +func WriteDescriptor(wr *buffer.Buffer, code AMQPType) { + wr.Append([]byte{ + 0x0, + byte(TypeCodeSmallUlong), + byte(code), + }) +} + +func writeString(wr *buffer.Buffer, str string) error { + if !utf8.ValidString(str) { + return errors.New("not a valid UTF-8 string") + } + l := len(str) + + switch { + // Str8 + case l < 256: + wr.Append([]byte{ + byte(TypeCodeStr8), + byte(l), + }) + wr.AppendString(str) + return nil + + // Str32 + case uint(l) < math.MaxUint32: + wr.AppendByte(byte(TypeCodeStr32)) + wr.AppendUint32(uint32(l)) + wr.AppendString(str) + return nil + + default: + return errors.New("too long") + } +} + +func WriteBinary(wr *buffer.Buffer, bin []byte) error { + l := len(bin) + + switch { + // List8 + case l < 256: + wr.Append([]byte{ + byte(TypeCodeVbin8), + byte(l), + }) + wr.Append(bin) + return nil + + // List32 + case uint(l) < math.MaxUint32: + wr.AppendByte(byte(TypeCodeVbin32)) + wr.AppendUint32(uint32(l)) + wr.Append(bin) + return nil + + default: + return errors.New("too long") + } +} + +func writeMap(wr *buffer.Buffer, m any) error { + startIdx := wr.Len() 
+ wr.Append([]byte{ + byte(TypeCodeMap32), // type + 0, 0, 0, 0, // size placeholder + 0, 0, 0, 0, // length placeholder + }) + + var pairs int + switch m := m.(type) { + case map[any]any: + pairs = len(m) * 2 + for key, val := range m { + err := Marshal(wr, key) + if err != nil { + return err + } + err = Marshal(wr, val) + if err != nil { + return err + } + } + case map[string]any: + pairs = len(m) * 2 + for key, val := range m { + err := writeString(wr, key) + if err != nil { + return err + } + err = Marshal(wr, val) + if err != nil { + return err + } + } + case map[Symbol]any: + pairs = len(m) * 2 + for key, val := range m { + err := key.Marshal(wr) + if err != nil { + return err + } + err = Marshal(wr, val) + if err != nil { + return err + } + } + case Unsettled: + pairs = len(m) * 2 + for key, val := range m { + err := writeString(wr, key) + if err != nil { + return err + } + err = Marshal(wr, val) + if err != nil { + return err + } + } + case Filter: + pairs = len(m) * 2 + for key, val := range m { + err := key.Marshal(wr) + if err != nil { + return err + } + err = val.Marshal(wr) + if err != nil { + return err + } + } + case Annotations: + pairs = len(m) * 2 + for key, val := range m { + switch key := key.(type) { + case string: + err := Symbol(key).Marshal(wr) + if err != nil { + return err + } + case Symbol: + err := key.Marshal(wr) + if err != nil { + return err + } + case int64: + writeInt64(wr, key) + case int: + writeInt64(wr, int64(key)) + default: + return fmt.Errorf("unsupported Annotations key type %T", key) + } + + err := Marshal(wr, val) + if err != nil { + return err + } + } + default: + return fmt.Errorf("unsupported map type %T", m) + } + + if uint(pairs) > math.MaxUint32-4 { + return errors.New("map contains too many elements") + } + + // overwrite placeholder size and length + bytes := wr.Bytes()[startIdx+1 : startIdx+9] + _ = bytes[7] // bounds check hint + + length := wr.Len() - startIdx - 1 - 4 // -1 for type, -4 for length + 
binary.BigEndian.PutUint32(bytes[:4], uint32(length)) + binary.BigEndian.PutUint32(bytes[4:8], uint32(pairs)) + + return nil +} + +// type length sizes +const ( + array8TLSize = 2 + array32TLSize = 5 +) + +func writeArrayHeader(wr *buffer.Buffer, length, typeSize int, type_ AMQPType) { + size := length * typeSize + + // array type + if size+array8TLSize <= math.MaxUint8 { + wr.Append([]byte{ + byte(TypeCodeArray8), // type + byte(size + array8TLSize), // size + byte(length), // length + byte(type_), // element type + }) + } else { + wr.AppendByte(byte(TypeCodeArray32)) //type + wr.AppendUint32(uint32(size + array32TLSize)) // size + wr.AppendUint32(uint32(length)) // length + wr.AppendByte(byte(type_)) // element type + } +} + +func writeVariableArrayHeader(wr *buffer.Buffer, length, elementsSizeTotal int, type_ AMQPType) { + // 0xA_ == 1, 0xB_ == 4 + // http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#doc-idp82960 + elementTypeSize := 1 + if type_&0xf0 == 0xb0 { + elementTypeSize = 4 + } + + size := elementsSizeTotal + (length * elementTypeSize) // size excluding array length + if size+array8TLSize <= math.MaxUint8 { + wr.Append([]byte{ + byte(TypeCodeArray8), // type + byte(size + array8TLSize), // size + byte(length), // length + byte(type_), // element type + }) + } else { + wr.AppendByte(byte(TypeCodeArray32)) // type + wr.AppendUint32(uint32(size + array32TLSize)) // size + wr.AppendUint32(uint32(length)) // length + wr.AppendByte(byte(type_)) // element type + } +} diff --git a/vendor/github.com/Azure/go-amqp/internal/encoding/types.go b/vendor/github.com/Azure/go-amqp/internal/encoding/types.go new file mode 100644 index 00000000000..ffc93652ee6 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/encoding/types.go @@ -0,0 +1,2152 @@ +package encoding + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "math" + "reflect" + "time" + "unicode/utf8" + + "github.com/Azure/go-amqp/internal/buffer" +) + +type 
AMQPType uint8 + +// Type codes +const ( + TypeCodeNull AMQPType = 0x40 + + // Bool + TypeCodeBool AMQPType = 0x56 // boolean with the octet 0x00 being false and octet 0x01 being true + TypeCodeBoolTrue AMQPType = 0x41 + TypeCodeBoolFalse AMQPType = 0x42 + + // Unsigned + TypeCodeUbyte AMQPType = 0x50 // 8-bit unsigned integer (1) + TypeCodeUshort AMQPType = 0x60 // 16-bit unsigned integer in network byte order (2) + TypeCodeUint AMQPType = 0x70 // 32-bit unsigned integer in network byte order (4) + TypeCodeSmallUint AMQPType = 0x52 // unsigned integer value in the range 0 to 255 inclusive (1) + TypeCodeUint0 AMQPType = 0x43 // the uint value 0 (0) + TypeCodeUlong AMQPType = 0x80 // 64-bit unsigned integer in network byte order (8) + TypeCodeSmallUlong AMQPType = 0x53 // unsigned long value in the range 0 to 255 inclusive (1) + TypeCodeUlong0 AMQPType = 0x44 // the ulong value 0 (0) + + // Signed + TypeCodeByte AMQPType = 0x51 // 8-bit two's-complement integer (1) + TypeCodeShort AMQPType = 0x61 // 16-bit two's-complement integer in network byte order (2) + TypeCodeInt AMQPType = 0x71 // 32-bit two's-complement integer in network byte order (4) + TypeCodeSmallint AMQPType = 0x54 // 8-bit two's-complement integer (1) + TypeCodeLong AMQPType = 0x81 // 64-bit two's-complement integer in network byte order (8) + TypeCodeSmalllong AMQPType = 0x55 // 8-bit two's-complement integer + + // Decimal + TypeCodeFloat AMQPType = 0x72 // IEEE 754-2008 binary32 (4) + TypeCodeDouble AMQPType = 0x82 // IEEE 754-2008 binary64 (8) + TypeCodeDecimal32 AMQPType = 0x74 // IEEE 754-2008 decimal32 using the Binary Integer Decimal encoding (4) + TypeCodeDecimal64 AMQPType = 0x84 // IEEE 754-2008 decimal64 using the Binary Integer Decimal encoding (8) + TypeCodeDecimal128 AMQPType = 0x94 // IEEE 754-2008 decimal128 using the Binary Integer Decimal encoding (16) + + // Other + TypeCodeChar AMQPType = 0x73 // a UTF-32BE encoded Unicode character (4) + TypeCodeTimestamp AMQPType = 0x83 // 
64-bit two's-complement integer representing milliseconds since the unix epoch + TypeCodeUUID AMQPType = 0x98 // UUID as defined in section 4.1.2 of RFC-4122 + + // Variable Length + TypeCodeVbin8 AMQPType = 0xa0 // up to 2^8 - 1 octets of binary data (1 + variable) + TypeCodeVbin32 AMQPType = 0xb0 // up to 2^32 - 1 octets of binary data (4 + variable) + TypeCodeStr8 AMQPType = 0xa1 // up to 2^8 - 1 octets worth of UTF-8 Unicode (with no byte order mark) (1 + variable) + TypeCodeStr32 AMQPType = 0xb1 // up to 2^32 - 1 octets worth of UTF-8 Unicode (with no byte order mark) (4 +variable) + TypeCodeSym8 AMQPType = 0xa3 // up to 2^8 - 1 seven bit ASCII characters representing a symbolic value (1 + variable) + TypeCodeSym32 AMQPType = 0xb3 // up to 2^32 - 1 seven bit ASCII characters representing a symbolic value (4 + variable) + + // Compound + TypeCodeList0 AMQPType = 0x45 // the empty list (i.e. the list with no elements) (0) + TypeCodeList8 AMQPType = 0xc0 // up to 2^8 - 1 list elements with total size less than 2^8 octets (1 + compound) + TypeCodeList32 AMQPType = 0xd0 // up to 2^32 - 1 list elements with total size less than 2^32 octets (4 + compound) + TypeCodeMap8 AMQPType = 0xc1 // up to 2^8 - 1 octets of encoded map data (1 + compound) + TypeCodeMap32 AMQPType = 0xd1 // up to 2^32 - 1 octets of encoded map data (4 + compound) + TypeCodeArray8 AMQPType = 0xe0 // up to 2^8 - 1 array elements with total size less than 2^8 octets (1 + array) + TypeCodeArray32 AMQPType = 0xf0 // up to 2^32 - 1 array elements with total size less than 2^32 octets (4 + array) + + // Composites + TypeCodeOpen AMQPType = 0x10 + TypeCodeBegin AMQPType = 0x11 + TypeCodeAttach AMQPType = 0x12 + TypeCodeFlow AMQPType = 0x13 + TypeCodeTransfer AMQPType = 0x14 + TypeCodeDisposition AMQPType = 0x15 + TypeCodeDetach AMQPType = 0x16 + TypeCodeEnd AMQPType = 0x17 + TypeCodeClose AMQPType = 0x18 + + TypeCodeSource AMQPType = 0x28 + TypeCodeTarget AMQPType = 0x29 + TypeCodeError AMQPType = 0x1d + 
+ TypeCodeMessageHeader AMQPType = 0x70 + TypeCodeDeliveryAnnotations AMQPType = 0x71 + TypeCodeMessageAnnotations AMQPType = 0x72 + TypeCodeMessageProperties AMQPType = 0x73 + TypeCodeApplicationProperties AMQPType = 0x74 + TypeCodeApplicationData AMQPType = 0x75 + TypeCodeAMQPSequence AMQPType = 0x76 + TypeCodeAMQPValue AMQPType = 0x77 + TypeCodeFooter AMQPType = 0x78 + + TypeCodeStateReceived AMQPType = 0x23 + TypeCodeStateAccepted AMQPType = 0x24 + TypeCodeStateRejected AMQPType = 0x25 + TypeCodeStateReleased AMQPType = 0x26 + TypeCodeStateModified AMQPType = 0x27 + + TypeCodeSASLMechanism AMQPType = 0x40 + TypeCodeSASLInit AMQPType = 0x41 + TypeCodeSASLChallenge AMQPType = 0x42 + TypeCodeSASLResponse AMQPType = 0x43 + TypeCodeSASLOutcome AMQPType = 0x44 + + TypeCodeDeleteOnClose AMQPType = 0x2b + TypeCodeDeleteOnNoLinks AMQPType = 0x2c + TypeCodeDeleteOnNoMessages AMQPType = 0x2d + TypeCodeDeleteOnNoLinksOrMessages AMQPType = 0x2e +) + +// Durability Policies +const ( + // No terminus state is retained durably. + DurabilityNone Durability = 0 + + // Only the existence and configuration of the terminus is + // retained durably. + DurabilityConfiguration Durability = 1 + + // In addition to the existence and configuration of the + // terminus, the unsettled state for durable messages is + // retained durably. + DurabilityUnsettledState Durability = 2 +) + +// Durability specifies the durability of a link. 
+type Durability uint32 + +func (d *Durability) String() string { + if d == nil { + return "" + } + + switch *d { + case DurabilityNone: + return "none" + case DurabilityConfiguration: + return "configuration" + case DurabilityUnsettledState: + return "unsettled-state" + default: + return fmt.Sprintf("unknown durability %d", *d) + } +} + +func (d Durability) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, uint32(d)) +} + +func (d *Durability) Unmarshal(r *buffer.Buffer) error { + return Unmarshal(r, (*uint32)(d)) +} + +// Expiry Policies +const ( + // The expiry timer starts when terminus is detached. + ExpiryLinkDetach ExpiryPolicy = "link-detach" + + // The expiry timer starts when the most recently + // associated session is ended. + ExpirySessionEnd ExpiryPolicy = "session-end" + + // The expiry timer starts when most recently associated + // connection is closed. + ExpiryConnectionClose ExpiryPolicy = "connection-close" + + // The terminus never expires. + ExpiryNever ExpiryPolicy = "never" +) + +// ExpiryPolicy specifies when the expiry timer of a terminus +// starts counting down from the timeout value. +// +// If the link is subsequently re-attached before the terminus is expired, +// then the count down is aborted. If the conditions for the +// terminus-expiry-policy are subsequently re-met, the expiry timer restarts +// from its originally configured timeout value. 
+type ExpiryPolicy Symbol + +func ValidateExpiryPolicy(e ExpiryPolicy) error { + switch e { + case ExpiryLinkDetach, + ExpirySessionEnd, + ExpiryConnectionClose, + ExpiryNever: + return nil + default: + return fmt.Errorf("unknown expiry-policy %q", e) + } +} + +func (e ExpiryPolicy) Marshal(wr *buffer.Buffer) error { + return Symbol(e).Marshal(wr) +} + +func (e *ExpiryPolicy) Unmarshal(r *buffer.Buffer) error { + err := Unmarshal(r, (*Symbol)(e)) + if err != nil { + return err + } + return ValidateExpiryPolicy(*e) +} + +func (e *ExpiryPolicy) String() string { + if e == nil { + return "" + } + return string(*e) +} + +// Sender Settlement Modes +const ( + // Sender will send all deliveries initially unsettled to the receiver. + SenderSettleModeUnsettled SenderSettleMode = 0 + + // Sender will send all deliveries settled to the receiver. + SenderSettleModeSettled SenderSettleMode = 1 + + // Sender MAY send a mixture of settled and unsettled deliveries to the receiver. + SenderSettleModeMixed SenderSettleMode = 2 +) + +// SenderSettleMode specifies how the sender will settle messages. +type SenderSettleMode uint8 + +func (m SenderSettleMode) Ptr() *SenderSettleMode { + return &m +} + +func (m *SenderSettleMode) String() string { + if m == nil { + return "" + } + + switch *m { + case SenderSettleModeUnsettled: + return "unsettled" + + case SenderSettleModeSettled: + return "settled" + + case SenderSettleModeMixed: + return "mixed" + + default: + return fmt.Sprintf("unknown sender mode %d", uint8(*m)) + } +} + +func (m SenderSettleMode) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, uint8(m)) +} + +func (m *SenderSettleMode) Unmarshal(r *buffer.Buffer) error { + n, err := ReadUbyte(r) + *m = SenderSettleMode(n) + return err +} + +// Receiver Settlement Modes +const ( + // Receiver will spontaneously settle all incoming transfers. 
+ ReceiverSettleModeFirst ReceiverSettleMode = 0 + + // Receiver will only settle after sending the disposition to the + // sender and receiving a disposition indicating settlement of + // the delivery from the sender. + ReceiverSettleModeSecond ReceiverSettleMode = 1 +) + +// ReceiverSettleMode specifies how the receiver will settle messages. +type ReceiverSettleMode uint8 + +func (m ReceiverSettleMode) Ptr() *ReceiverSettleMode { + return &m +} + +func (m *ReceiverSettleMode) String() string { + if m == nil { + return "" + } + + switch *m { + case ReceiverSettleModeFirst: + return "first" + + case ReceiverSettleModeSecond: + return "second" + + default: + return fmt.Sprintf("unknown receiver mode %d", uint8(*m)) + } +} + +func (m ReceiverSettleMode) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, uint8(m)) +} + +func (m *ReceiverSettleMode) Unmarshal(r *buffer.Buffer) error { + n, err := ReadUbyte(r) + *m = ReceiverSettleMode(n) + return err +} + +type Role bool + +const ( + RoleSender Role = false + RoleReceiver Role = true +) + +func (rl Role) String() string { + if rl { + return "Receiver" + } + return "Sender" +} + +func (rl *Role) Unmarshal(r *buffer.Buffer) error { + b, err := readBool(r) + *rl = Role(b) + return err +} + +func (rl Role) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, (bool)(rl)) +} + +type SASLCode uint8 + +// SASL Codes +const ( + CodeSASLOK SASLCode = iota // Connection authentication succeeded. + CodeSASLAuth // Connection authentication failed due to an unspecified problem with the supplied credentials. + CodeSASLSysPerm // Connection authentication failed due to a system error that is unlikely to be corrected without intervention. +) + +func (s SASLCode) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, uint8(s)) +} + +func (s *SASLCode) Unmarshal(r *buffer.Buffer) error { + n, err := ReadUbyte(r) + *s = SASLCode(n) + return err +} + +// DeliveryState encapsulates the various concrete delivery states. 
+// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#section-delivery-state +// TODO: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-transactions-v1.0-os.html#type-declared +type DeliveryState interface { + deliveryState() // marker method +} + +type Unsettled map[string]DeliveryState + +func (u Unsettled) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, u) +} + +func (u *Unsettled) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + m := make(Unsettled, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadString(r) + if err != nil { + return err + } + var value DeliveryState + err = Unmarshal(r, &value) + if err != nil { + return err + } + m[key] = value + } + *u = m + return nil +} + +type Filter map[Symbol]*DescribedType + +func (f Filter) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, f) +} + +func (f *Filter) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + m := make(Filter, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadString(r) + if err != nil { + return err + } + var value DescribedType + err = Unmarshal(r, &value) + if err != nil { + return err + } + m[Symbol(key)] = &value + } + *f = m + return nil +} + +// peekMessageType reads the message type without +// modifying any data. 
+func PeekMessageType(buf []byte) (uint8, uint8, error) { + if len(buf) < 3 { + return 0, 0, errors.New("invalid message") + } + + if buf[0] != 0 { + return 0, 0, fmt.Errorf("invalid composite header %02x", buf[0]) + } + + // copied from readUlong to avoid allocations + t := AMQPType(buf[1]) + if t == TypeCodeUlong0 { + return 0, 2, nil + } + + if t == TypeCodeSmallUlong { + if len(buf[2:]) == 0 { + return 0, 0, errors.New("invalid ulong") + } + return buf[2], 3, nil + } + + if t != TypeCodeUlong { + return 0, 0, fmt.Errorf("invalid type for uint32 %02x", t) + } + + if len(buf[2:]) < 8 { + return 0, 0, errors.New("invalid ulong") + } + v := binary.BigEndian.Uint64(buf[2:10]) + + return uint8(v), 10, nil +} + +func tryReadNull(r *buffer.Buffer) bool { + if r.Len() > 0 && AMQPType(r.Bytes()[0]) == TypeCodeNull { + r.Skip(1) + return true + } + return false +} + +// Annotations keys must be of type string, int, or int64. +// +// String keys are encoded as AMQP Symbols. +type Annotations map[any]any + +func (a Annotations) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, a) +} + +func (a *Annotations) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + m := make(Annotations, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadAny(r) + if err != nil { + return err + } + value, err := ReadAny(r) + if err != nil { + return err + } + m[key] = value + } + *a = m + return nil +} + +// ErrCond is one of the error conditions defined in the AMQP spec. +type ErrCond string + +func (ec ErrCond) Marshal(wr *buffer.Buffer) error { + return (Symbol)(ec).Marshal(wr) +} + +func (ec *ErrCond) Unmarshal(r *buffer.Buffer) error { + s, err := ReadString(r) + *ec = ErrCond(s) + return err +} + +/* + + + + + + +*/ + +// Error is an AMQP error. +type Error struct { + // A symbolic value indicating the error condition. 
+ Condition ErrCond + + // descriptive text about the error condition + // + // This text supplies any supplementary details not indicated by the condition field. + // This text can be logged as an aid to resolving issues. + Description string + + // map carrying information about the error condition + Info map[string]any +} + +func (e *Error) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeError, []MarshalField{ + {Value: &e.Condition, Omit: false}, + {Value: &e.Description, Omit: e.Description == ""}, + {Value: e.Info, Omit: len(e.Info) == 0}, + }) +} + +func (e *Error) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeError, []UnmarshalField{ + {Field: &e.Condition, HandleNull: func() error { return errors.New("Error.Condition is required") }}, + {Field: &e.Description}, + {Field: &e.Info}, + }...) +} + +func (e *Error) String() string { + if e == nil { + return "*Error(nil)" + } + return fmt.Sprintf("*Error{Condition: %s, Description: %s, Info: %v}", + e.Condition, + e.Description, + e.Info, + ) +} + +func (e *Error) Error() string { + return e.String() +} + +/* + + + + + +*/ + +type StateReceived struct { + // When sent by the sender this indicates the first section of the message + // (with section-number 0 being the first section) for which data can be resent. + // Data from sections prior to the given section cannot be retransmitted for + // this delivery. + // + // When sent by the receiver this indicates the first section of the message + // for which all data might not yet have been received. + SectionNumber uint32 + + // When sent by the sender this indicates the first byte of the encoded section + // data of the section given by section-number for which data can be resent + // (with section-offset 0 being the first byte). Bytes from the same section + // prior to the given offset section cannot be retransmitted for this delivery. 
+ // + // When sent by the receiver this indicates the first byte of the given section + // which has not yet been received. Note that if a receiver has received all of + // section number X (which contains N bytes of data), but none of section number + // X + 1, then it can indicate this by sending either Received(section-number=X, + // section-offset=N) or Received(section-number=X+1, section-offset=0). The state + // Received(section-number=0, section-offset=0) indicates that no message data + // at all has been transferred. + SectionOffset uint64 +} + +func (sr *StateReceived) deliveryState() {} + +func (sr *StateReceived) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeStateReceived, []MarshalField{ + {Value: &sr.SectionNumber, Omit: false}, + {Value: &sr.SectionOffset, Omit: false}, + }) +} + +func (sr *StateReceived) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeStateReceived, []UnmarshalField{ + {Field: &sr.SectionNumber, HandleNull: func() error { return errors.New("StateReceiver.SectionNumber is required") }}, + {Field: &sr.SectionOffset, HandleNull: func() error { return errors.New("StateReceiver.SectionOffset is required") }}, + }...) 
+} + +/* + + + +*/ + +type StateAccepted struct{} + +func (sr *StateAccepted) deliveryState() {} + +func (sa *StateAccepted) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeStateAccepted, nil) +} + +func (sa *StateAccepted) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeStateAccepted) +} + +func (sa *StateAccepted) String() string { + return "Accepted" +} + +/* + + + + +*/ + +type StateRejected struct { + Error *Error +} + +func (sr *StateRejected) deliveryState() {} + +func (sr *StateRejected) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeStateRejected, []MarshalField{ + {Value: sr.Error, Omit: sr.Error == nil}, + }) +} + +func (sr *StateRejected) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeStateRejected, + UnmarshalField{Field: &sr.Error}, + ) +} + +func (sr *StateRejected) String() string { + return fmt.Sprintf("Rejected{Error: %v}", sr.Error) +} + +/* + + + +*/ + +type StateReleased struct{} + +func (sr *StateReleased) deliveryState() {} + +func (sr *StateReleased) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeStateReleased, nil) +} + +func (sr *StateReleased) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeStateReleased) +} + +func (sr *StateReleased) String() string { + return "Released" +} + +/* + + + + + + +*/ + +type StateModified struct { + // count the transfer as an unsuccessful delivery attempt + // + // If the delivery-failed flag is set, any messages modified + // MUST have their delivery-count incremented. + DeliveryFailed bool + + // prevent redelivery + // + // If the undeliverable-here is set, then any messages released MUST NOT + // be redelivered to the modifying link endpoint. + UndeliverableHere bool + + // message attributes + // Map containing attributes to combine with the existing message-annotations + // held in the message's header section. 
Where the existing message-annotations + // of the message contain an entry with the same key as an entry in this field, + // the value in this field associated with that key replaces the one in the + // existing headers; where the existing message-annotations has no such value, + // the value in this map is added. + MessageAnnotations Annotations +} + +func (sr *StateModified) deliveryState() {} + +func (sm *StateModified) Marshal(wr *buffer.Buffer) error { + return MarshalComposite(wr, TypeCodeStateModified, []MarshalField{ + {Value: &sm.DeliveryFailed, Omit: !sm.DeliveryFailed}, + {Value: &sm.UndeliverableHere, Omit: !sm.UndeliverableHere}, + {Value: sm.MessageAnnotations, Omit: sm.MessageAnnotations == nil}, + }) +} + +func (sm *StateModified) Unmarshal(r *buffer.Buffer) error { + return UnmarshalComposite(r, TypeCodeStateModified, []UnmarshalField{ + {Field: &sm.DeliveryFailed}, + {Field: &sm.UndeliverableHere}, + {Field: &sm.MessageAnnotations}, + }...) +} + +func (sm *StateModified) String() string { + return fmt.Sprintf("Modified{DeliveryFailed: %t, UndeliverableHere: %t, MessageAnnotations: %v}", sm.DeliveryFailed, sm.UndeliverableHere, sm.MessageAnnotations) +} + +// symbol is an AMQP symbolic string. 
+type Symbol string + +func (s Symbol) Marshal(wr *buffer.Buffer) error { + l := len(s) + switch { + // Sym8 + case l < 256: + wr.Append([]byte{ + byte(TypeCodeSym8), + byte(l), + }) + wr.AppendString(string(s)) + + // Sym32 + case uint(l) < math.MaxUint32: + wr.AppendByte(uint8(TypeCodeSym32)) + wr.AppendUint32(uint32(l)) + wr.AppendString(string(s)) + default: + return errors.New("too long") + } + return nil +} + +type Milliseconds time.Duration + +func (m Milliseconds) Marshal(wr *buffer.Buffer) error { + writeUint32(wr, uint32(m/Milliseconds(time.Millisecond))) + return nil +} + +func (m *Milliseconds) Unmarshal(r *buffer.Buffer) error { + n, err := readUint(r) + *m = Milliseconds(time.Duration(n) * time.Millisecond) + return err +} + +// mapAnyAny is used to decode AMQP maps who's keys are undefined or +// inconsistently typed. +type mapAnyAny map[any]any + +func (m mapAnyAny) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, map[any]any(m)) +} + +func (m *mapAnyAny) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + mm := make(mapAnyAny, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadAny(r) + if err != nil { + return err + } + value, err := ReadAny(r) + if err != nil { + return err + } + + // https://golang.org/ref/spec#Map_types: + // The comparison operators == and != must be fully defined + // for operands of the key type; thus the key type must not + // be a function, map, or slice. 
+ switch reflect.ValueOf(key).Kind() { + case reflect.Slice, reflect.Func, reflect.Map: + return errors.New("invalid map key") + } + + mm[key] = value + } + *m = mm + return nil +} + +// mapStringAny is used to decode AMQP maps that have string keys +type mapStringAny map[string]any + +func (m mapStringAny) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, map[string]any(m)) +} + +func (m *mapStringAny) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + mm := make(mapStringAny, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadString(r) + if err != nil { + return err + } + value, err := ReadAny(r) + if err != nil { + return err + } + mm[key] = value + } + *m = mm + + return nil +} + +// mapStringAny is used to decode AMQP maps that have Symbol keys +type mapSymbolAny map[Symbol]any + +func (m mapSymbolAny) Marshal(wr *buffer.Buffer) error { + return writeMap(wr, map[Symbol]any(m)) +} + +func (m *mapSymbolAny) Unmarshal(r *buffer.Buffer) error { + count, err := readMapHeader(r) + if err != nil { + return err + } + + mm := make(mapSymbolAny, count/2) + for i := uint32(0); i < count; i += 2 { + key, err := ReadString(r) + if err != nil { + return err + } + value, err := ReadAny(r) + if err != nil { + return err + } + mm[Symbol(key)] = value + } + *m = mm + return nil +} + +// UUID is a 128 bit identifier as defined in RFC 4122. +type UUID [16]byte + +// String returns the hex encoded representation described in RFC 4122, Section 3. 
+func (u UUID) String() string { + var buf [36]byte + hex.Encode(buf[:8], u[:4]) + buf[8] = '-' + hex.Encode(buf[9:13], u[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], u[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], u[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], u[10:]) + return string(buf[:]) +} + +func (u UUID) Marshal(wr *buffer.Buffer) error { + wr.AppendByte(byte(TypeCodeUUID)) + wr.Append(u[:]) + return nil +} + +func (u *UUID) Unmarshal(r *buffer.Buffer) error { + un, err := readUUID(r) + *u = un + return err +} + +type LifetimePolicy uint8 + +const ( + DeleteOnClose = LifetimePolicy(TypeCodeDeleteOnClose) + DeleteOnNoLinks = LifetimePolicy(TypeCodeDeleteOnNoLinks) + DeleteOnNoMessages = LifetimePolicy(TypeCodeDeleteOnNoMessages) + DeleteOnNoLinksOrMessages = LifetimePolicy(TypeCodeDeleteOnNoLinksOrMessages) +) + +func (p LifetimePolicy) Marshal(wr *buffer.Buffer) error { + wr.Append([]byte{ + 0x0, + byte(TypeCodeSmallUlong), + byte(p), + byte(TypeCodeList0), + }) + return nil +} + +func (p *LifetimePolicy) Unmarshal(r *buffer.Buffer) error { + typ, fields, err := readCompositeHeader(r) + if err != nil { + return err + } + if fields != 0 { + return fmt.Errorf("invalid size %d for lifetime-policy", fields) + } + *p = LifetimePolicy(typ) + return nil +} + +type DescribedType struct { + Descriptor any + Value any +} + +func (t DescribedType) Marshal(wr *buffer.Buffer) error { + wr.AppendByte(0x0) // descriptor constructor + err := Marshal(wr, t.Descriptor) + if err != nil { + return err + } + return Marshal(wr, t.Value) +} + +func (t *DescribedType) Unmarshal(r *buffer.Buffer) error { + b, err := r.ReadByte() + if err != nil { + return err + } + + if b != 0x0 { + return fmt.Errorf("invalid described type header %02x", b) + } + + err = Unmarshal(r, &t.Descriptor) + if err != nil { + return err + } + return Unmarshal(r, &t.Value) +} + +func (t DescribedType) String() string { + return fmt.Sprintf("DescribedType{descriptor: %v, value: %v}", + t.Descriptor, + 
t.Value, + ) +} + +// SLICES + +// ArrayUByte allows encoding []uint8/[]byte as an array +// rather than binary data. +type ArrayUByte []uint8 + +func (a ArrayUByte) Marshal(wr *buffer.Buffer) error { + const typeSize = 1 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeUbyte) + wr.Append(a) + + return nil +} + +func (a *ArrayUByte) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeUbyte { + return fmt.Errorf("invalid type for []uint16 %02x", type_) + } + + buf, ok := r.Next(length) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + *a = append([]byte(nil), buf...) + + return nil +} + +type arrayInt8 []int8 + +func (a arrayInt8) Marshal(wr *buffer.Buffer) error { + const typeSize = 1 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeByte) + + for _, value := range a { + wr.AppendByte(uint8(value)) + } + + return nil +} + +func (a *arrayInt8) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeByte { + return fmt.Errorf("invalid type for []uint16 %02x", type_) + } + + buf, ok := r.Next(length) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]int8, length) + } else { + aa = aa[:length] + } + + for i, value := range buf { + aa[i] = int8(value) + } + + *a = aa + return nil +} + +type arrayUint16 []uint16 + +func (a arrayUint16) Marshal(wr *buffer.Buffer) error { + const typeSize = 2 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeUshort) + + for _, element := range a { + wr.AppendUint16(element) + } + + return nil +} + +func (a *arrayUint16) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != 
nil { + return err + } + if type_ != TypeCodeUshort { + return fmt.Errorf("invalid type for []uint16 %02x", type_) + } + + const typeSize = 2 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]uint16, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = binary.BigEndian.Uint16(buf[bufIdx:]) + bufIdx += 2 + } + + *a = aa + return nil +} + +type arrayInt16 []int16 + +func (a arrayInt16) Marshal(wr *buffer.Buffer) error { + const typeSize = 2 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeShort) + + for _, element := range a { + wr.AppendUint16(uint16(element)) + } + + return nil +} + +func (a *arrayInt16) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeShort { + return fmt.Errorf("invalid type for []uint16 %02x", type_) + } + + const typeSize = 2 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]int16, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = int16(binary.BigEndian.Uint16(buf[bufIdx : bufIdx+2])) + bufIdx += 2 + } + + *a = aa + return nil +} + +type arrayUint32 []uint32 + +func (a arrayUint32) Marshal(wr *buffer.Buffer) error { + var ( + typeSize = 1 + TypeCode = TypeCodeSmallUint + ) + for _, n := range a { + if n > math.MaxUint8 { + typeSize = 4 + TypeCode = TypeCodeUint + break + } + } + + writeArrayHeader(wr, len(a), typeSize, TypeCode) + + if TypeCode == TypeCodeUint { + for _, element := range a { + wr.AppendUint32(element) + } + } else { + for _, element := range a { + wr.AppendByte(byte(element)) + } + } + + return nil +} + +func (a *arrayUint32) Unmarshal(r *buffer.Buffer) error { + length, 
err := readArrayHeader(r) + if err != nil { + return err + } + + aa := (*a)[:0] + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeUint0: + if int64(cap(aa)) < length { + aa = make([]uint32, length) + } else { + aa = aa[:length] + for i := range aa { + aa[i] = 0 + } + } + case TypeCodeSmallUint: + buf, ok := r.Next(length) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]uint32, length) + } else { + aa = aa[:length] + } + + for i, n := range buf { + aa[i] = uint32(n) + } + case TypeCodeUint: + const typeSize = 4 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + if int64(cap(aa)) < length { + aa = make([]uint32, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = binary.BigEndian.Uint32(buf[bufIdx : bufIdx+4]) + bufIdx += 4 + } + default: + return fmt.Errorf("invalid type for []uint32 %02x", type_) + } + + *a = aa + return nil +} + +type arrayInt32 []int32 + +func (a arrayInt32) Marshal(wr *buffer.Buffer) error { + var ( + typeSize = 1 + TypeCode = TypeCodeSmallint + ) + for _, n := range a { + if n > math.MaxInt8 { + typeSize = 4 + TypeCode = TypeCodeInt + break + } + } + + writeArrayHeader(wr, len(a), typeSize, TypeCode) + + if TypeCode == TypeCodeInt { + for _, element := range a { + wr.AppendUint32(uint32(element)) + } + } else { + for _, element := range a { + wr.AppendByte(byte(element)) + } + } + + return nil +} + +func (a *arrayInt32) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + aa := (*a)[:0] + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeSmallint: + buf, ok := r.Next(length) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]int32, length) + } else { + aa = aa[:length] + } + + for i, n := range buf { + 
aa[i] = int32(int8(n)) + } + case TypeCodeInt: + const typeSize = 4 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + if int64(cap(aa)) < length { + aa = make([]int32, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = int32(binary.BigEndian.Uint32(buf[bufIdx:])) + bufIdx += 4 + } + default: + return fmt.Errorf("invalid type for []int32 %02x", type_) + } + + *a = aa + return nil +} + +type arrayUint64 []uint64 + +func (a arrayUint64) Marshal(wr *buffer.Buffer) error { + var ( + typeSize = 1 + TypeCode = TypeCodeSmallUlong + ) + for _, n := range a { + if n > math.MaxUint8 { + typeSize = 8 + TypeCode = TypeCodeUlong + break + } + } + + writeArrayHeader(wr, len(a), typeSize, TypeCode) + + if TypeCode == TypeCodeUlong { + for _, element := range a { + wr.AppendUint64(element) + } + } else { + for _, element := range a { + wr.AppendByte(byte(element)) + } + } + + return nil +} + +func (a *arrayUint64) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + aa := (*a)[:0] + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeUlong0: + if int64(cap(aa)) < length { + aa = make([]uint64, length) + } else { + aa = aa[:length] + for i := range aa { + aa[i] = 0 + } + } + case TypeCodeSmallUlong: + buf, ok := r.Next(length) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]uint64, length) + } else { + aa = aa[:length] + } + + for i, n := range buf { + aa[i] = uint64(n) + } + case TypeCodeUlong: + const typeSize = 8 + buf, ok := r.Next(length * typeSize) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]uint64, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = binary.BigEndian.Uint64(buf[bufIdx : bufIdx+8]) + bufIdx += 8 + } + default: + return 
fmt.Errorf("invalid type for []uint64 %02x", type_) + } + + *a = aa + return nil +} + +type arrayInt64 []int64 + +func (a arrayInt64) Marshal(wr *buffer.Buffer) error { + var ( + typeSize = 1 + TypeCode = TypeCodeSmalllong + ) + for _, n := range a { + if n > math.MaxInt8 { + typeSize = 8 + TypeCode = TypeCodeLong + break + } + } + + writeArrayHeader(wr, len(a), typeSize, TypeCode) + + if TypeCode == TypeCodeLong { + for _, element := range a { + wr.AppendUint64(uint64(element)) + } + } else { + for _, element := range a { + wr.AppendByte(byte(element)) + } + } + + return nil +} + +func (a *arrayInt64) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + aa := (*a)[:0] + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeSmalllong: + buf, ok := r.Next(length) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]int64, length) + } else { + aa = aa[:length] + } + + for i, n := range buf { + aa[i] = int64(int8(n)) + } + case TypeCodeLong: + const typeSize = 8 + buf, ok := r.Next(length * typeSize) + if !ok { + return errors.New("invalid length") + } + + if int64(cap(aa)) < length { + aa = make([]int64, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + aa[i] = int64(binary.BigEndian.Uint64(buf[bufIdx:])) + bufIdx += 8 + } + default: + return fmt.Errorf("invalid type for []uint64 %02x", type_) + } + + *a = aa + return nil +} + +type arrayFloat []float32 + +func (a arrayFloat) Marshal(wr *buffer.Buffer) error { + const typeSize = 4 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeFloat) + + for _, element := range a { + wr.AppendUint32(math.Float32bits(element)) + } + + return nil +} + +func (a *arrayFloat) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ 
!= TypeCodeFloat { + return fmt.Errorf("invalid type for []float32 %02x", type_) + } + + const typeSize = 4 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]float32, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + bits := binary.BigEndian.Uint32(buf[bufIdx:]) + aa[i] = math.Float32frombits(bits) + bufIdx += typeSize + } + + *a = aa + return nil +} + +type arrayDouble []float64 + +func (a arrayDouble) Marshal(wr *buffer.Buffer) error { + const typeSize = 8 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeDouble) + + for _, element := range a { + wr.AppendUint64(math.Float64bits(element)) + } + + return nil +} + +func (a *arrayDouble) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeDouble { + return fmt.Errorf("invalid type for []float64 %02x", type_) + } + + const typeSize = 8 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]float64, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + bits := binary.BigEndian.Uint64(buf[bufIdx:]) + aa[i] = math.Float64frombits(bits) + bufIdx += typeSize + } + + *a = aa + return nil +} + +type arrayBool []bool + +func (a arrayBool) Marshal(wr *buffer.Buffer) error { + const typeSize = 1 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeBool) + + for _, element := range a { + value := byte(0) + if element { + value = 1 + } + wr.AppendByte(value) + } + + return nil +} + +func (a *arrayBool) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]bool, length) + } else { + aa = 
aa[:length] + } + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeBool: + buf, ok := r.Next(length) + if !ok { + return errors.New("invalid length") + } + + for i, value := range buf { + if value == 0 { + aa[i] = false + } else { + aa[i] = true + } + } + + case TypeCodeBoolTrue: + for i := range aa { + aa[i] = true + } + case TypeCodeBoolFalse: + for i := range aa { + aa[i] = false + } + default: + return fmt.Errorf("invalid type for []bool %02x", type_) + } + + *a = aa + return nil +} + +type arrayString []string + +func (a arrayString) Marshal(wr *buffer.Buffer) error { + var ( + elementType = TypeCodeStr8 + elementsSizeTotal int + ) + for _, element := range a { + if !utf8.ValidString(element) { + return errors.New("not a valid UTF-8 string") + } + + elementsSizeTotal += len(element) + + if len(element) > math.MaxUint8 { + elementType = TypeCodeStr32 + } + } + + writeVariableArrayHeader(wr, len(a), elementsSizeTotal, elementType) + + if elementType == TypeCodeStr32 { + for _, element := range a { + wr.AppendUint32(uint32(len(element))) + wr.AppendString(element) + } + } else { + for _, element := range a { + wr.AppendByte(byte(len(element))) + wr.AppendString(element) + } + } + + return nil +} + +func (a *arrayString) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + const typeSize = 2 // assume all strings are at least 2 bytes + if length*typeSize > int64(r.Len()) { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]string, length) + } else { + aa = aa[:length] + } + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeStr8: + for i := range aa { + size, err := r.ReadByte() + if err != nil { + return err + } + + buf, ok := r.Next(int64(size)) + if !ok { + return errors.New("invalid length") + } + + aa[i] = string(buf) + } + case TypeCodeStr32: + for 
i := range aa { + buf, ok := r.Next(4) + if !ok { + return errors.New("invalid length") + } + size := int64(binary.BigEndian.Uint32(buf)) + + buf, ok = r.Next(size) + if !ok { + return errors.New("invalid length") + } + aa[i] = string(buf) + } + default: + return fmt.Errorf("invalid type for []string %02x", type_) + } + + *a = aa + return nil +} + +type arraySymbol []Symbol + +func (a arraySymbol) Marshal(wr *buffer.Buffer) error { + var ( + elementType = TypeCodeSym8 + elementsSizeTotal int + ) + for _, element := range a { + elementsSizeTotal += len(element) + + if len(element) > math.MaxUint8 { + elementType = TypeCodeSym32 + } + } + + writeVariableArrayHeader(wr, len(a), elementsSizeTotal, elementType) + + if elementType == TypeCodeSym32 { + for _, element := range a { + wr.AppendUint32(uint32(len(element))) + wr.AppendString(string(element)) + } + } else { + for _, element := range a { + wr.AppendByte(byte(len(element))) + wr.AppendString(string(element)) + } + } + + return nil +} + +func (a *arraySymbol) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + const typeSize = 2 // assume all symbols are at least 2 bytes + if length*typeSize > int64(r.Len()) { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]Symbol, length) + } else { + aa = aa[:length] + } + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeSym8: + for i := range aa { + size, err := r.ReadByte() + if err != nil { + return err + } + + buf, ok := r.Next(int64(size)) + if !ok { + return errors.New("invalid length") + } + aa[i] = Symbol(buf) + } + case TypeCodeSym32: + for i := range aa { + buf, ok := r.Next(4) + if !ok { + return errors.New("invalid length") + } + size := int64(binary.BigEndian.Uint32(buf)) + + buf, ok = r.Next(size) + if !ok { + return errors.New("invalid length") + } + aa[i] = Symbol(buf) + } + default: + return 
fmt.Errorf("invalid type for []Symbol %02x", type_) + } + + *a = aa + return nil +} + +type arrayBinary [][]byte + +func (a arrayBinary) Marshal(wr *buffer.Buffer) error { + var ( + elementType = TypeCodeVbin8 + elementsSizeTotal int + ) + for _, element := range a { + elementsSizeTotal += len(element) + + if len(element) > math.MaxUint8 { + elementType = TypeCodeVbin32 + } + } + + writeVariableArrayHeader(wr, len(a), elementsSizeTotal, elementType) + + if elementType == TypeCodeVbin32 { + for _, element := range a { + wr.AppendUint32(uint32(len(element))) + wr.Append(element) + } + } else { + for _, element := range a { + wr.AppendByte(byte(len(element))) + wr.Append(element) + } + } + + return nil +} + +func (a *arrayBinary) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + const typeSize = 2 // assume all binary is at least 2 bytes + if length*typeSize > int64(r.Len()) { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([][]byte, length) + } else { + aa = aa[:length] + } + + type_, err := readType(r) + if err != nil { + return err + } + switch type_ { + case TypeCodeVbin8: + for i := range aa { + size, err := r.ReadByte() + if err != nil { + return err + } + + buf, ok := r.Next(int64(size)) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + aa[i] = append([]byte(nil), buf...) + } + case TypeCodeVbin32: + for i := range aa { + buf, ok := r.Next(4) + if !ok { + return errors.New("invalid length") + } + size := binary.BigEndian.Uint32(buf) + + buf, ok = r.Next(int64(size)) + if !ok { + return errors.New("invalid length") + } + aa[i] = append([]byte(nil), buf...) 
+ } + default: + return fmt.Errorf("invalid type for [][]byte %02x", type_) + } + + *a = aa + return nil +} + +type arrayTimestamp []time.Time + +func (a arrayTimestamp) Marshal(wr *buffer.Buffer) error { + const typeSize = 8 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeTimestamp) + + for _, element := range a { + ms := element.UnixNano() / int64(time.Millisecond) + wr.AppendUint64(uint64(ms)) + } + + return nil +} + +func (a *arrayTimestamp) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeTimestamp { + return fmt.Errorf("invalid type for []time.Time %02x", type_) + } + + const typeSize = 8 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]time.Time, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + ms := int64(binary.BigEndian.Uint64(buf[bufIdx:])) + bufIdx += typeSize + aa[i] = time.Unix(ms/1000, (ms%1000)*1000000).UTC() + } + + *a = aa + return nil +} + +type arrayUUID []UUID + +func (a arrayUUID) Marshal(wr *buffer.Buffer) error { + const typeSize = 16 + + writeArrayHeader(wr, len(a), typeSize, TypeCodeUUID) + + for _, element := range a { + wr.Append(element[:]) + } + + return nil +} + +func (a *arrayUUID) Unmarshal(r *buffer.Buffer) error { + length, err := readArrayHeader(r) + if err != nil { + return err + } + + type_, err := readType(r) + if err != nil { + return err + } + if type_ != TypeCodeUUID { + return fmt.Errorf("invalid type for []UUID %#02x", type_) + } + + const typeSize = 16 + buf, ok := r.Next(length * typeSize) + if !ok { + return fmt.Errorf("invalid length %d", length) + } + + aa := (*a)[:0] + if int64(cap(aa)) < length { + aa = make([]UUID, length) + } else { + aa = aa[:length] + } + + var bufIdx int + for i := range aa { + copy(aa[i][:], 
buf[bufIdx:bufIdx+16]) + bufIdx += 16 + } + + *a = aa + return nil +} + +// LIST + +type list []any + +func (l list) Marshal(wr *buffer.Buffer) error { + length := len(l) + + // type + if length == 0 { + wr.AppendByte(byte(TypeCodeList0)) + return nil + } + wr.AppendByte(byte(TypeCodeList32)) + + // size + sizeIdx := wr.Len() + wr.Append([]byte{0, 0, 0, 0}) + + // length + wr.AppendUint32(uint32(length)) + + for _, element := range l { + err := Marshal(wr, element) + if err != nil { + return err + } + } + + // overwrite size + binary.BigEndian.PutUint32(wr.Bytes()[sizeIdx:], uint32(wr.Len()-(sizeIdx+4))) + + return nil +} + +func (l *list) Unmarshal(r *buffer.Buffer) error { + length, err := readListHeader(r) + if err != nil { + return err + } + + // assume that all types are at least 1 byte + if length > int64(r.Len()) { + return fmt.Errorf("invalid length %d", length) + } + + ll := *l + if int64(cap(ll)) < length { + ll = make([]any, length) + } else { + ll = ll[:length] + } + + for i := range ll { + ll[i], err = ReadAny(r) + if err != nil { + return err + } + } + + *l = ll + return nil +} + +// multiSymbol can decode a single symbol or an array. 
+type MultiSymbol []Symbol + +func (ms MultiSymbol) Marshal(wr *buffer.Buffer) error { + return Marshal(wr, []Symbol(ms)) +} + +func (ms *MultiSymbol) Unmarshal(r *buffer.Buffer) error { + type_, err := peekType(r) + if err != nil { + return err + } + + if type_ == TypeCodeSym8 || type_ == TypeCodeSym32 { + s, err := ReadString(r) + if err != nil { + return err + } + + *ms = []Symbol{Symbol(s)} + return nil + } + + return Unmarshal(r, (*[]Symbol)(ms)) +} diff --git a/vendor/github.com/Azure/go-amqp/internal/frames/frames.go b/vendor/github.com/Azure/go-amqp/internal/frames/frames.go new file mode 100644 index 00000000000..7255af0c42c --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/frames/frames.go @@ -0,0 +1,1540 @@ +package frames + +import ( + "errors" + "fmt" + "strconv" + "time" + + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/encoding" +) + +// Type contains the values for a frame's type. +type Type uint8 + +const ( + TypeAMQP Type = 0x0 + TypeSASL Type = 0x1 +) + +// String implements the fmt.Stringer interface for type Type. +func (t Type) String() string { + if t == 0 { + return "AMQP" + } + return "SASL" +} + +/* + + + + + + + + + + + + + + + + +*/ +type Source struct { + // the address of the source + // + // The address of the source MUST NOT be set when sent on a attach frame sent by + // the receiving link endpoint where the dynamic flag is set to true (that is where + // the receiver is requesting the sender to create an addressable node). + // + // The address of the source MUST be set when sent on a attach frame sent by the + // sending link endpoint where the dynamic flag is set to true (that is where the + // sender has created an addressable node at the request of the receiver and is now + // communicating the address of that created node). The generated name of the address + // SHOULD include the link name and the container-id of the remote container to allow + // for ease of identification. 
+ Address string + + // indicates the durability of the terminus + // + // Indicates what state of the terminus will be retained durably: the state of durable + // messages, only existence and configuration of the terminus, or no state at all. + // + // 0: none + // 1: configuration + // 2: unsettled-state + Durable encoding.Durability + + // the expiry policy of the source + // + // link-detach: The expiry timer starts when terminus is detached. + // session-end: The expiry timer starts when the most recently associated session is + // ended. + // connection-close: The expiry timer starts when most recently associated connection + // is closed. + // never: The terminus never expires. + ExpiryPolicy encoding.ExpiryPolicy + + // duration that an expiring source will be retained + // + // The source starts expiring as indicated by the expiry-policy. + Timeout uint32 // seconds + + // request dynamic creation of a remote node + // + // When set to true by the receiving link endpoint, this field constitutes a request + // for the sending peer to dynamically create a node at the source. In this case the + // address field MUST NOT be set. + // + // When set to true by the sending link endpoint this field indicates creation of a + // dynamically created node. In this case the address field will contain the address + // of the created node. The generated address SHOULD include the link name and other + // available information on the initiator of the request (such as the remote + // container-id) in some recognizable form for ease of traceability. + Dynamic bool + + // properties of the dynamically created node + // + // If the dynamic field is not set to true this field MUST be left unset. + // + // When set by the receiving link endpoint, this field contains the desired + // properties of the node the receiver wishes to be created. When set by the + // sending link endpoint this field contains the actual properties of the + // dynamically created node. 
See subsection 3.5.9 for standard node properties. + // http://www.amqp.org/specification/1.0/node-properties + // + // lifetime-policy: The lifetime of a dynamically generated node. + // Definitionally, the lifetime will never be less than the lifetime + // of the link which caused its creation, however it is possible to + // extend the lifetime of dynamically created node using a lifetime + // policy. The value of this entry MUST be of a type which provides + // the lifetime-policy archetype. The following standard + // lifetime-policies are defined below: delete-on-close, + // delete-on-no-links, delete-on-no-messages or + // delete-on-no-links-or-messages. + // supported-dist-modes: The distribution modes that the node supports. + // The value of this entry MUST be one or more symbols which are valid + // distribution-modes. That is, the value MUST be of the same type as + // would be valid in a field defined with the following attributes: + // type="symbol" multiple="true" requires="distribution-mode" + DynamicNodeProperties map[encoding.Symbol]any // TODO: implement custom type with validation + + // the distribution mode of the link + // + // This field MUST be set by the sending end of the link if the endpoint supports more + // than one distribution-mode. This field MAY be set by the receiving end of the link + // to indicate a preference when a node supports multiple distribution modes. + DistributionMode encoding.Symbol + + // a set of predicates to filter the messages admitted onto the link + // + // The receiving endpoint sets its desired filter, the sending endpoint sets the filter + // actually in place (including any filters defaulted at the node). The receiving + // endpoint MUST check that the filter in place meets its needs and take responsibility + // for detaching if it does not. 
+ Filter encoding.Filter + + // default outcome for unsettled transfers + // + // Indicates the outcome to be used for transfers that have not reached a terminal + // state at the receiver when the transfer is settled, including when the source + // is destroyed. The value MUST be a valid outcome (e.g., released or rejected). + DefaultOutcome any + + // descriptors for the outcomes that can be chosen on this link + // + // The values in this field are the symbolic descriptors of the outcomes that can + // be chosen on this link. This field MAY be empty, indicating that the default-outcome + // will be assumed for all message transfers (if the default-outcome is not set, and no + // outcomes are provided, then the accepted outcome MUST be supported by the source). + // + // When present, the values MUST be a symbolic descriptor of a valid outcome, + // e.g., "amqp:accepted:list". + Outcomes encoding.MultiSymbol + + // the extension capabilities the sender supports/desires + // + // http://www.amqp.org/specification/1.0/source-capabilities + Capabilities encoding.MultiSymbol +} + +func (s *Source) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSource, []encoding.MarshalField{ + {Value: &s.Address, Omit: s.Address == ""}, + {Value: &s.Durable, Omit: s.Durable == encoding.DurabilityNone}, + {Value: &s.ExpiryPolicy, Omit: s.ExpiryPolicy == "" || s.ExpiryPolicy == encoding.ExpirySessionEnd}, + {Value: &s.Timeout, Omit: s.Timeout == 0}, + {Value: &s.Dynamic, Omit: !s.Dynamic}, + {Value: s.DynamicNodeProperties, Omit: len(s.DynamicNodeProperties) == 0}, + {Value: &s.DistributionMode, Omit: s.DistributionMode == ""}, + {Value: s.Filter, Omit: len(s.Filter) == 0}, + {Value: &s.DefaultOutcome, Omit: s.DefaultOutcome == nil}, + {Value: &s.Outcomes, Omit: len(s.Outcomes) == 0}, + {Value: &s.Capabilities, Omit: len(s.Capabilities) == 0}, + }) +} + +func (s *Source) Unmarshal(r *buffer.Buffer) error { + return 
encoding.UnmarshalComposite(r, encoding.TypeCodeSource, []encoding.UnmarshalField{ + {Field: &s.Address}, + {Field: &s.Durable}, + {Field: &s.ExpiryPolicy, HandleNull: func() error { s.ExpiryPolicy = encoding.ExpirySessionEnd; return nil }}, + {Field: &s.Timeout}, + {Field: &s.Dynamic}, + {Field: &s.DynamicNodeProperties}, + {Field: &s.DistributionMode}, + {Field: &s.Filter}, + {Field: &s.DefaultOutcome}, + {Field: &s.Outcomes}, + {Field: &s.Capabilities}, + }...) +} + +func (s Source) String() string { + return fmt.Sprintf("source{Address: %s, Durable: %d, ExpiryPolicy: %s, Timeout: %d, "+ + "Dynamic: %t, DynamicNodeProperties: %v, DistributionMode: %s, Filter: %v, DefaultOutcome: %v "+ + "Outcomes: %v, Capabilities: %v}", + s.Address, + s.Durable, + s.ExpiryPolicy, + s.Timeout, + s.Dynamic, + s.DynamicNodeProperties, + s.DistributionMode, + s.Filter, + s.DefaultOutcome, + s.Outcomes, + s.Capabilities, + ) +} + +/* + + + + + + + + + + + + +*/ +type Target struct { + // the address of the target + // + // The address of the target MUST NOT be set when sent on a attach frame sent by + // the sending link endpoint where the dynamic flag is set to true (that is where + // the sender is requesting the receiver to create an addressable node). + // + // The address of the source MUST be set when sent on a attach frame sent by the + // receiving link endpoint where the dynamic flag is set to true (that is where + // the receiver has created an addressable node at the request of the sender and + // is now communicating the address of that created node). The generated name of + // the address SHOULD include the link name and the container-id of the remote + // container to allow for ease of identification. + Address string + + // indicates the durability of the terminus + // + // Indicates what state of the terminus will be retained durably: the state of durable + // messages, only existence and configuration of the terminus, or no state at all. 
+ // + // 0: none + // 1: configuration + // 2: unsettled-state + Durable encoding.Durability + + // the expiry policy of the target + // + // link-detach: The expiry timer starts when terminus is detached. + // session-end: The expiry timer starts when the most recently associated session is + // ended. + // connection-close: The expiry timer starts when most recently associated connection + // is closed. + // never: The terminus never expires. + ExpiryPolicy encoding.ExpiryPolicy + + // duration that an expiring target will be retained + // + // The target starts expiring as indicated by the expiry-policy. + Timeout uint32 // seconds + + // request dynamic creation of a remote node + // + // When set to true by the sending link endpoint, this field constitutes a request + // for the receiving peer to dynamically create a node at the target. In this case + // the address field MUST NOT be set. + // + // When set to true by the receiving link endpoint this field indicates creation of + // a dynamically created node. In this case the address field will contain the + // address of the created node. The generated address SHOULD include the link name + // and other available information on the initiator of the request (such as the + // remote container-id) in some recognizable form for ease of traceability. + Dynamic bool + + // properties of the dynamically created node + // + // If the dynamic field is not set to true this field MUST be left unset. + // + // When set by the sending link endpoint, this field contains the desired + // properties of the node the sender wishes to be created. When set by the + // receiving link endpoint this field contains the actual properties of the + // dynamically created node. See subsection 3.5.9 for standard node properties. + // http://www.amqp.org/specification/1.0/node-properties + // + // lifetime-policy: The lifetime of a dynamically generated node. 
+ // Definitionally, the lifetime will never be less than the lifetime + // of the link which caused its creation, however it is possible to + // extend the lifetime of dynamically created node using a lifetime + // policy. The value of this entry MUST be of a type which provides + // the lifetime-policy archetype. The following standard + // lifetime-policies are defined below: delete-on-close, + // delete-on-no-links, delete-on-no-messages or + // delete-on-no-links-or-messages. + // supported-dist-modes: The distribution modes that the node supports. + // The value of this entry MUST be one or more symbols which are valid + // distribution-modes. That is, the value MUST be of the same type as + // would be valid in a field defined with the following attributes: + // type="symbol" multiple="true" requires="distribution-mode" + DynamicNodeProperties map[encoding.Symbol]any // TODO: implement custom type with validation + + // the extension capabilities the sender supports/desires + // + // http://www.amqp.org/specification/1.0/target-capabilities + Capabilities encoding.MultiSymbol +} + +func (t *Target) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeTarget, []encoding.MarshalField{ + {Value: &t.Address, Omit: t.Address == ""}, + {Value: &t.Durable, Omit: t.Durable == encoding.DurabilityNone}, + {Value: &t.ExpiryPolicy, Omit: t.ExpiryPolicy == "" || t.ExpiryPolicy == encoding.ExpirySessionEnd}, + {Value: &t.Timeout, Omit: t.Timeout == 0}, + {Value: &t.Dynamic, Omit: !t.Dynamic}, + {Value: t.DynamicNodeProperties, Omit: len(t.DynamicNodeProperties) == 0}, + {Value: &t.Capabilities, Omit: len(t.Capabilities) == 0}, + }) +} + +func (t *Target) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeTarget, []encoding.UnmarshalField{ + {Field: &t.Address}, + {Field: &t.Durable}, + {Field: &t.ExpiryPolicy, HandleNull: func() error { t.ExpiryPolicy = encoding.ExpirySessionEnd; return nil 
}}, + {Field: &t.Timeout}, + {Field: &t.Dynamic}, + {Field: &t.DynamicNodeProperties}, + {Field: &t.Capabilities}, + }...) +} + +func (t Target) String() string { + return fmt.Sprintf("source{Address: %s, Durable: %d, ExpiryPolicy: %s, Timeout: %d, "+ + "Dynamic: %t, DynamicNodeProperties: %v, Capabilities: %v}", + t.Address, + t.Durable, + t.ExpiryPolicy, + t.Timeout, + t.Dynamic, + t.DynamicNodeProperties, + t.Capabilities, + ) +} + +// frame is the decoded representation of a frame +type Frame struct { + Type Type // AMQP/SASL + Channel uint16 // channel this frame is for + Body FrameBody // body of the frame +} + +// String implements the fmt.Stringer interface for type Frame. +func (f Frame) String() string { + return fmt.Sprintf("Frame{Type: %s, Channel: %d, Body: %s}", f.Type, f.Channel, f.Body) +} + +// frameBody adds some type safety to frame encoding +type FrameBody interface { + frameBody() +} + +/* + + + + + + + + + + + + + +*/ + +type PerformOpen struct { + ContainerID string // required + Hostname string + MaxFrameSize uint32 // default: 4294967295 + ChannelMax uint16 // default: 65535 + IdleTimeout time.Duration // from milliseconds + OutgoingLocales encoding.MultiSymbol + IncomingLocales encoding.MultiSymbol + OfferedCapabilities encoding.MultiSymbol + DesiredCapabilities encoding.MultiSymbol + Properties map[encoding.Symbol]any +} + +func (o *PerformOpen) frameBody() {} + +func (o *PerformOpen) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeOpen, []encoding.MarshalField{ + {Value: &o.ContainerID, Omit: false}, + {Value: &o.Hostname, Omit: o.Hostname == ""}, + {Value: &o.MaxFrameSize, Omit: o.MaxFrameSize == 4294967295}, + {Value: &o.ChannelMax, Omit: o.ChannelMax == 65535}, + {Value: (*encoding.Milliseconds)(&o.IdleTimeout), Omit: o.IdleTimeout == 0}, + {Value: &o.OutgoingLocales, Omit: len(o.OutgoingLocales) == 0}, + {Value: &o.IncomingLocales, Omit: len(o.IncomingLocales) == 0}, + {Value: 
&o.OfferedCapabilities, Omit: len(o.OfferedCapabilities) == 0}, + {Value: &o.DesiredCapabilities, Omit: len(o.DesiredCapabilities) == 0}, + {Value: o.Properties, Omit: len(o.Properties) == 0}, + }) +} + +func (o *PerformOpen) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeOpen, []encoding.UnmarshalField{ + {Field: &o.ContainerID, HandleNull: func() error { return errors.New("Open.ContainerID is required") }}, + {Field: &o.Hostname}, + {Field: &o.MaxFrameSize, HandleNull: func() error { o.MaxFrameSize = 4294967295; return nil }}, + {Field: &o.ChannelMax, HandleNull: func() error { o.ChannelMax = 65535; return nil }}, + {Field: (*encoding.Milliseconds)(&o.IdleTimeout)}, + {Field: &o.OutgoingLocales}, + {Field: &o.IncomingLocales}, + {Field: &o.OfferedCapabilities}, + {Field: &o.DesiredCapabilities}, + {Field: &o.Properties}, + }...) +} + +func (o *PerformOpen) String() string { + return fmt.Sprintf("Open{ContainerID : %s, Hostname: %s, MaxFrameSize: %d, "+ + "ChannelMax: %d, IdleTimeout: %v, "+ + "OutgoingLocales: %v, IncomingLocales: %v, "+ + "OfferedCapabilities: %v, DesiredCapabilities: %v, "+ + "Properties: %v}", + o.ContainerID, + o.Hostname, + o.MaxFrameSize, + o.ChannelMax, + o.IdleTimeout, + o.OutgoingLocales, + o.IncomingLocales, + o.OfferedCapabilities, + o.DesiredCapabilities, + o.Properties, + ) +} + +/* + + + + + + + + + + + + + +*/ +type PerformBegin struct { + // the remote channel for this session + // If a session is locally initiated, the remote-channel MUST NOT be set. + // When an endpoint responds to a remotely initiated session, the remote-channel + // MUST be set to the channel on which the remote session sent the begin. 
+ RemoteChannel *uint16 + + // the transfer-id of the first transfer id the sender will send + NextOutgoingID uint32 // required, sequence number http://www.ietf.org/rfc/rfc1982.txt + + // the initial incoming-window of the sender + IncomingWindow uint32 // required + + // the initial outgoing-window of the sender + OutgoingWindow uint32 // required + + // the maximum handle value that can be used on the session + // The handle-max value is the highest handle value that can be + // used on the session. A peer MUST NOT attempt to attach a link + // using a handle value outside the range that its partner can handle. + // A peer that receives a handle outside the supported range MUST + // close the connection with the framing-error error-code. + HandleMax uint32 // default 4294967295 + + // the extension capabilities the sender supports + // http://www.amqp.org/specification/1.0/session-capabilities + OfferedCapabilities encoding.MultiSymbol + + // the extension capabilities the sender can use if the receiver supports them + // The sender MUST NOT attempt to use any capability other than those it + // has declared in desired-capabilities field. 
+ DesiredCapabilities encoding.MultiSymbol + + // session properties + // http://www.amqp.org/specification/1.0/session-properties + Properties map[encoding.Symbol]any +} + +func (b *PerformBegin) frameBody() {} + +func (b *PerformBegin) String() string { + return fmt.Sprintf("Begin{RemoteChannel: %v, NextOutgoingID: %d, IncomingWindow: %d, "+ + "OutgoingWindow: %d, HandleMax: %d, OfferedCapabilities: %v, DesiredCapabilities: %v, "+ + "Properties: %v}", + formatUint16Ptr(b.RemoteChannel), + b.NextOutgoingID, + b.IncomingWindow, + b.OutgoingWindow, + b.HandleMax, + b.OfferedCapabilities, + b.DesiredCapabilities, + b.Properties, + ) +} + +func formatUint16Ptr(p *uint16) string { + if p == nil { + return "" + } + return strconv.FormatUint(uint64(*p), 10) +} + +func (b *PerformBegin) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeBegin, []encoding.MarshalField{ + {Value: b.RemoteChannel, Omit: b.RemoteChannel == nil}, + {Value: &b.NextOutgoingID, Omit: false}, + {Value: &b.IncomingWindow, Omit: false}, + {Value: &b.OutgoingWindow, Omit: false}, + {Value: &b.HandleMax, Omit: b.HandleMax == 4294967295}, + {Value: &b.OfferedCapabilities, Omit: len(b.OfferedCapabilities) == 0}, + {Value: &b.DesiredCapabilities, Omit: len(b.DesiredCapabilities) == 0}, + {Value: b.Properties, Omit: b.Properties == nil}, + }) +} + +func (b *PerformBegin) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeBegin, []encoding.UnmarshalField{ + {Field: &b.RemoteChannel}, + {Field: &b.NextOutgoingID, HandleNull: func() error { return errors.New("Begin.NextOutgoingID is required") }}, + {Field: &b.IncomingWindow, HandleNull: func() error { return errors.New("Begin.IncomingWindow is required") }}, + {Field: &b.OutgoingWindow, HandleNull: func() error { return errors.New("Begin.OutgoingWindow is required") }}, + {Field: &b.HandleMax, HandleNull: func() error { b.HandleMax = 4294967295; return nil }}, + {Field: 
&b.OfferedCapabilities}, + {Field: &b.DesiredCapabilities}, + {Field: &b.Properties}, + }...) +} + +/* + + + + + + + + + + + + + + + + + + + +*/ +type PerformAttach struct { + // the name of the link + // + // This name uniquely identifies the link from the container of the source + // to the container of the target node, e.g., if the container of the source + // node is A, and the container of the target node is B, the link MAY be + // globally identified by the (ordered) tuple (A,B,). + Name string // required + + // the handle for the link while attached + // + // The numeric handle assigned by the the peer as a shorthand to refer to the + // link in all performatives that reference the link until the it is detached. + // + // The handle MUST NOT be used for other open links. An attempt to attach using + // a handle which is already associated with a link MUST be responded to with + // an immediate close carrying a handle-in-use session-error. + // + // To make it easier to monitor AMQP link attach frames, it is RECOMMENDED that + // implementations always assign the lowest available handle to this field. + // + // The two endpoints MAY potentially use different handles to refer to the same link. + // Link handles MAY be reused once a link is closed for both send and receive. + Handle uint32 // required + + // role of the link endpoint + // + // The role being played by the peer, i.e., whether the peer is the sender or the + // receiver of messages on the link. + Role encoding.Role + + // settlement policy for the sender + // + // The delivery settlement policy for the sender. When set at the receiver this + // indicates the desired value for the settlement mode at the sender. When set + // at the sender this indicates the actual settlement mode in use. The sender + // SHOULD respect the receiver's desired settlement mode if the receiver initiates + // the attach exchange and the sender supports the desired mode. 
	//
	// 0: unsettled - The sender will send all deliveries initially unsettled to the receiver.
	// 1: settled - The sender will send all deliveries settled to the receiver.
	// 2: mixed - The sender MAY send a mixture of settled and unsettled deliveries to the receiver.
	SenderSettleMode *encoding.SenderSettleMode

	// the settlement policy of the receiver
	//
	// The delivery settlement policy for the receiver. When set at the sender this
	// indicates the desired value for the settlement mode at the receiver.
	// When set at the receiver this indicates the actual settlement mode in use.
	// The receiver SHOULD respect the sender's desired settlement mode if the sender
	// initiates the attach exchange and the receiver supports the desired mode.
	//
	// 0: first - The receiver will spontaneously settle all incoming transfers.
	// 1: second - The receiver will only settle after sending the disposition to
	//             the sender and receiving a disposition indicating settlement of
	//             the delivery from the sender.
	ReceiverSettleMode *encoding.ReceiverSettleMode

	// the source for messages
	//
	// If no source is specified on an outgoing link, then there is no source currently
	// attached to the link. A link with no source will never produce outgoing messages.
	Source *Source

	// the target for messages
	//
	// If no target is specified on an incoming link, then there is no target currently
	// attached to the link. A link with no target will never permit incoming messages.
	Target *Target

	// unsettled delivery state
	//
	// This is used to indicate any unsettled delivery states when a suspended link is
	// resumed. The map is keyed by delivery-tag with values indicating the delivery state.
	// The local and remote delivery states for a given delivery-tag MUST be compared to
	// resolve any in-doubt deliveries. If necessary, deliveries MAY be resent, or resumed
	// based on the outcome of this comparison. See subsection 2.6.13.
	//
	// If the local unsettled map is too large to be encoded within a frame of the agreed
	// maximum frame size then the session MAY be ended with the frame-size-too-small error.
	// The endpoint SHOULD make use of the ability to send an incomplete unsettled map
	// (see below) to avoid sending an error.
	//
	// The unsettled map MUST NOT contain null valued keys.
	//
	// When reattaching (as opposed to resuming), the unsettled map MUST be null.
	Unsettled encoding.Unsettled

	// If set to true this field indicates that the unsettled map provided is not complete.
	// When the map is incomplete the recipient of the map cannot take the absence of a
	// delivery tag from the map as evidence of settlement. On receipt of an incomplete
	// unsettled map a sending endpoint MUST NOT send any new deliveries (i.e. deliveries
	// where resume is not set to true) to its partner (and a receiving endpoint which sent
	// an incomplete unsettled map MUST detach with an error on receiving a transfer which
	// does not have the resume flag set to true).
	//
	// Note that if this flag is set to true then the endpoints MUST detach and reattach at
	// least once in order to send new deliveries. This flag can be useful when there are
	// too many entries in the unsettled map to fit within a single frame. An endpoint can
	// attach, resume, settle, and detach until enough unsettled state has been cleared for
	// an attach where this flag is set to false.
	IncompleteUnsettled bool // default: false

	// the sender's initial value for delivery-count
	//
	// This MUST NOT be null if role is sender, and it is ignored if the role is receiver.
	InitialDeliveryCount uint32 // sequence number

	// the maximum message size supported by the link endpoint
	//
	// This field indicates the maximum message size supported by the link endpoint.
	// Any attempt to deliver a message larger than this results in a message-size-exceeded
	// link-error. If this field is zero or unset, there is no maximum size imposed by the
	// link endpoint.
	MaxMessageSize uint64

	// the extension capabilities the sender supports
	// http://www.amqp.org/specification/1.0/link-capabilities
	OfferedCapabilities encoding.MultiSymbol

	// the extension capabilities the sender can use if the receiver supports them
	//
	// The sender MUST NOT attempt to use any capability other than those it
	// has declared in desired-capabilities field.
	DesiredCapabilities encoding.MultiSymbol

	// link properties
	// http://www.amqp.org/specification/1.0/link-properties
	Properties map[encoding.Symbol]any
}

func (a *PerformAttach) frameBody() {}

// String implements the fmt.Stringer interface for type PerformAttach.
func (a PerformAttach) String() string {
	return fmt.Sprintf("Attach{Name: %s, Handle: %d, Role: %s, SenderSettleMode: %s, ReceiverSettleMode: %s, "+
		"Source: %v, Target: %v, Unsettled: %v, IncompleteUnsettled: %t, InitialDeliveryCount: %d, MaxMessageSize: %d, "+
		"OfferedCapabilities: %v, DesiredCapabilities: %v, Properties: %v}",
		a.Name,
		a.Handle,
		a.Role,
		a.SenderSettleMode,
		a.ReceiverSettleMode,
		a.Source,
		a.Target,
		a.Unsettled,
		a.IncompleteUnsettled,
		a.InitialDeliveryCount,
		a.MaxMessageSize,
		a.OfferedCapabilities,
		a.DesiredCapabilities,
		a.Properties,
	)
}

// Marshal encodes the attach performative as an AMQP composite value.
// NOTE: the field order below is mandated by the AMQP 1.0 spec and must not
// be reordered. InitialDeliveryCount is omitted for receivers because the
// spec declares it ignored when the role is receiver.
func (a *PerformAttach) Marshal(wr *buffer.Buffer) error {
	return encoding.MarshalComposite(wr, encoding.TypeCodeAttach, []encoding.MarshalField{
		{Value: &a.Name, Omit: false},
		{Value: &a.Handle, Omit: false},
		{Value: &a.Role, Omit: false},
		{Value: a.SenderSettleMode, Omit: a.SenderSettleMode == nil},
		{Value: a.ReceiverSettleMode, Omit: a.ReceiverSettleMode == nil},
		{Value: a.Source, Omit: a.Source == nil},
		{Value: a.Target, Omit: a.Target == nil},
		{Value: a.Unsettled, Omit: len(a.Unsettled) == 0},
		{Value: &a.IncompleteUnsettled, Omit: !a.IncompleteUnsettled},
		{Value: &a.InitialDeliveryCount, Omit: a.Role == encoding.RoleReceiver},
		{Value: &a.MaxMessageSize, Omit: a.MaxMessageSize == 0},
		{Value: &a.OfferedCapabilities, Omit: len(a.OfferedCapabilities) == 0},
		{Value: &a.DesiredCapabilities, Omit: len(a.DesiredCapabilities) == 0},
		{Value: a.Properties, Omit: len(a.Properties) == 0},
	})
}

// Unmarshal decodes an attach performative. Name, Handle and Role are
// required fields and produce an error when null.
func (a *PerformAttach) Unmarshal(r *buffer.Buffer) error {
	return encoding.UnmarshalComposite(r, encoding.TypeCodeAttach, []encoding.UnmarshalField{
		{Field: &a.Name, HandleNull: func() error { return errors.New("Attach.Name is required") }},
		{Field: &a.Handle, HandleNull: func() error { return errors.New("Attach.Handle is required") }},
		{Field: &a.Role, HandleNull: func() error { return errors.New("Attach.Role is required") }},
		{Field: &a.SenderSettleMode},
		{Field: &a.ReceiverSettleMode},
		{Field: &a.Source},
		{Field: &a.Target},
		{Field: &a.Unsettled},
		{Field: &a.IncompleteUnsettled},
		{Field: &a.InitialDeliveryCount},
		{Field: &a.MaxMessageSize},
		{Field: &a.OfferedCapabilities},
		{Field: &a.DesiredCapabilities},
		{Field: &a.Properties},
	}...)
}

// PerformFlow is the flow performative: updates session- and (optionally)
// link-level flow-control state.
type PerformFlow struct {
	// Identifies the expected transfer-id of the next incoming transfer frame.
	// This value MUST be set if the peer has received the begin frame for the
	// session, and MUST NOT be set if it has not. See subsection 2.5.6 for more details.
	NextIncomingID *uint32 // sequence number

	// Defines the maximum number of incoming transfer frames that the endpoint
	// can currently receive. See subsection 2.5.6 for more details.
	IncomingWindow uint32 // required

	// The transfer-id that will be assigned to the next outgoing transfer frame.
	// See subsection 2.5.6 for more details.
	NextOutgoingID uint32 // sequence number

	// Defines the maximum number of outgoing transfer frames that the endpoint
	// could potentially currently send, if it was not constrained by restrictions
	// imposed by its peer's incoming-window. See subsection 2.5.6 for more details.
	OutgoingWindow uint32

	// If set, indicates that the flow frame carries flow state information for the local
	// link endpoint associated with the given handle. If not set, the flow frame is
	// carrying only information pertaining to the session endpoint.
	//
	// If set to a handle that is not currently associated with an attached link,
	// the recipient MUST respond by ending the session with an unattached-handle
	// session error.
	Handle *uint32

	// The delivery-count is initialized by the sender when a link endpoint is created,
	// and is incremented whenever a message is sent. Only the sender MAY independently
	// modify this field. The receiver's value is calculated based on the last known
	// value from the sender and any subsequent messages received on the link. Note that,
	// despite its name, the delivery-count is not a count but a sequence number
	// initialized at an arbitrary point by the sender.
	//
	// When the handle field is not set, this field MUST NOT be set.
	//
	// When the handle identifies that the flow state is being sent from the sender link
	// endpoint to receiver link endpoint this field MUST be set to the current
	// delivery-count of the link endpoint.
	//
	// When the flow state is being sent from the receiver endpoint to the sender endpoint
	// this field MUST be set to the last known value of the corresponding sending endpoint.
	// In the event that the receiving link endpoint has not yet seen the initial attach
	// frame from the sender this field MUST NOT be set.
	DeliveryCount *uint32 // sequence number

	// the current maximum number of messages that can be received
	//
	// The current maximum number of messages that can be handled at the receiver endpoint
	// of the link. Only the receiver endpoint can independently set this value. The sender
	// endpoint sets this to the last known value seen from the receiver.
	// See subsection 2.6.7 for more details.
	//
	// When the handle field is not set, this field MUST NOT be set.
	LinkCredit *uint32

	// the number of available messages
	//
	// The number of messages awaiting credit at the link sender endpoint. Only the sender
	// can independently set this value. The receiver sets this to the last known value seen
	// from the sender. See subsection 2.6.7 for more details.
	//
	// When the handle field is not set, this field MUST NOT be set.
	Available *uint32

	// indicates drain mode
	//
	// When flow state is sent from the sender to the receiver, this field contains the
	// actual drain mode of the sender. When flow state is sent from the receiver to the
	// sender, this field contains the desired drain mode of the receiver.
	// See subsection 2.6.7 for more details.
	//
	// When the handle field is not set, this field MUST NOT be set.
	Drain bool

	// request state from partner
	//
	// If set to true then the receiver SHOULD send its state at the earliest convenient
	// opportunity.
	//
	// If set to true, and the handle field is not set, then the sender only requires
	// session endpoint state to be echoed, however, the receiver MAY fulfil this requirement
	// by sending a flow performative carrying link-specific state (since any such flow also
	// carries session state).
	//
	// If a sender makes multiple requests for the same state before the receiver can reply,
	// the receiver MAY send only one flow in return.
	//
	// Note that if a peer responds to echo requests with flows which themselves have the
	// echo field set to true, an infinite loop could result if its partner adopts the same
	// policy (therefore such a policy SHOULD be avoided).
	Echo bool

	// link state properties
	// http://www.amqp.org/specification/1.0/link-state-properties
	Properties map[encoding.Symbol]any
}

func (f *PerformFlow) frameBody() {}

// String implements the fmt.Stringer interface for type PerformFlow.
// Optional (pointer) fields are rendered via formatUint32Ptr, which shows
// nil as the empty string.
func (f *PerformFlow) String() string {
	return fmt.Sprintf("Flow{NextIncomingID: %s, IncomingWindow: %d, NextOutgoingID: %d, OutgoingWindow: %d, "+
		"Handle: %s, DeliveryCount: %s, LinkCredit: %s, Available: %s, Drain: %t, Echo: %t, Properties: %+v}",
		formatUint32Ptr(f.NextIncomingID),
		f.IncomingWindow,
		f.NextOutgoingID,
		f.OutgoingWindow,
		formatUint32Ptr(f.Handle),
		formatUint32Ptr(f.DeliveryCount),
		formatUint32Ptr(f.LinkCredit),
		formatUint32Ptr(f.Available),
		f.Drain,
		f.Echo,
		f.Properties,
	)
}

// formatUint32Ptr renders a possibly-nil *uint32 for debug output; nil is
// rendered as the empty string.
func formatUint32Ptr(p *uint32) string {
	if p == nil {
		return ""
	}
	return strconv.FormatUint(uint64(*p), 10)
}

// Marshal encodes the flow performative. NOTE: field order is mandated by
// the AMQP 1.0 spec; optional (pointer) fields are omitted when nil.
func (f *PerformFlow) Marshal(wr *buffer.Buffer) error {
	return encoding.MarshalComposite(wr, encoding.TypeCodeFlow, []encoding.MarshalField{
		{Value: f.NextIncomingID, Omit: f.NextIncomingID == nil},
		{Value: &f.IncomingWindow, Omit: false},
		{Value: &f.NextOutgoingID, Omit: false},
		{Value: &f.OutgoingWindow, Omit: false},
		{Value: f.Handle, Omit: f.Handle == nil},
		{Value: f.DeliveryCount, Omit: f.DeliveryCount == nil},
		{Value: f.LinkCredit, Omit: f.LinkCredit == nil},
		{Value: f.Available, Omit: f.Available == nil},
		{Value: &f.Drain, Omit: !f.Drain},
		{Value: &f.Echo, Omit: !f.Echo},
		{Value: f.Properties, Omit: len(f.Properties) == 0},
	})
}

// Unmarshal decodes a flow performative. The window/id fields are required
// and produce an error when null.
func (f *PerformFlow) Unmarshal(r *buffer.Buffer) error {
	return encoding.UnmarshalComposite(r, encoding.TypeCodeFlow, []encoding.UnmarshalField{
		{Field: &f.NextIncomingID},
		{Field: &f.IncomingWindow, HandleNull: func() error { return errors.New("Flow.IncomingWindow is required") }},
		{Field: &f.NextOutgoingID, HandleNull: func() error { return errors.New("Flow.NextOutgoingID is required") }},
		{Field: &f.OutgoingWindow, HandleNull: func() error { return errors.New("Flow.OutgoingWindow is required") }},
		{Field: &f.Handle},
		{Field: &f.DeliveryCount},
		{Field: &f.LinkCredit},
		{Field: &f.Available},
		{Field: &f.Drain},
		{Field: &f.Echo},
		{Field: &f.Properties},
	}...)
}

// PerformTransfer is the transfer performative: carries (a fragment of) a
// message across a link.
type PerformTransfer struct {
	// Specifies the link on which the message is transferred.
	Handle uint32 // required

	// The delivery-id MUST be supplied on the first transfer of a multi-transfer
	// delivery. On continuation transfers the delivery-id MAY be omitted. It is
	// an error if the delivery-id on a continuation transfer differs from the
	// delivery-id on the first transfer of a delivery.
	DeliveryID *uint32 // sequence number

	// Uniquely identifies the delivery attempt for a given message on this link.
	// This field MUST be specified for the first transfer of a multi-transfer
	// message and can only be omitted for continuation transfers. It is an error
	// if the delivery-tag on a continuation transfer differs from the delivery-tag
	// on the first transfer of a delivery.
	DeliveryTag []byte // up to 32 bytes

	// This field MUST be specified for the first transfer of a multi-transfer message
	// and can only be omitted for continuation transfers. It is an error if the
	// message-format on a continuation transfer differs from the message-format on
	// the first transfer of a delivery.
	//
	// The upper three octets of a message format code identify a particular message
	// format. The lowest octet indicates the version of said message format. Any given
	// version of a format is forwards compatible with all higher versions.
	MessageFormat *uint32

	// If not set on the first (or only) transfer for a (multi-transfer) delivery,
	// then the settled flag MUST be interpreted as being false.
For subsequent + // transfers in a multi-transfer delivery if the settled flag is left unset then + // it MUST be interpreted as true if and only if the value of the settled flag on + // any of the preceding transfers was true; if no preceding transfer was sent with + // settled being true then the value when unset MUST be taken as false. + // + // If the negotiated value for snd-settle-mode at attachment is settled, then this + // field MUST be true on at least one transfer frame for a delivery (i.e., the + // delivery MUST be settled at the sender at the point the delivery has been + // completely transferred). + // + // If the negotiated value for snd-settle-mode at attachment is unsettled, then this + // field MUST be false (or unset) on every transfer frame for a delivery (unless the + // delivery is aborted). + Settled bool + + // indicates that the message has more content + // + // Note that if both the more and aborted fields are set to true, the aborted flag + // takes precedence. That is, a receiver SHOULD ignore the value of the more field + // if the transfer is marked as aborted. A sender SHOULD NOT set the more flag to + // true if it also sets the aborted flag to true. + More bool + + // If first, this indicates that the receiver MUST settle the delivery once it has + // arrived without waiting for the sender to settle first. + // + // If second, this indicates that the receiver MUST NOT settle until sending its + // disposition to the sender and receiving a settled disposition from the sender. + // + // If not set, this value is defaulted to the value negotiated on link attach. + // + // If the negotiated link value is first, then it is illegal to set this field + // to second. + // + // If the message is being sent settled by the sender, the value of this field + // is ignored. + // + // The (implicit or explicit) value of this field does not form part of the + // transfer state, and is not retained if a link is suspended and subsequently resumed. 
+ // + // 0: first - The receiver will spontaneously settle all incoming transfers. + // 1: second - The receiver will only settle after sending the disposition to + // the sender and receiving a disposition indicating settlement of + // the delivery from the sender. + ReceiverSettleMode *encoding.ReceiverSettleMode + + // the state of the delivery at the sender + // + // When set this informs the receiver of the state of the delivery at the sender. + // This is particularly useful when transfers of unsettled deliveries are resumed + // after resuming a link. Setting the state on the transfer can be thought of as + // being equivalent to sending a disposition immediately before the transfer + // performative, i.e., it is the state of the delivery (not the transfer) that + // existed at the point the frame was sent. + // + // Note that if the transfer performative (or an earlier disposition performative + // referring to the delivery) indicates that the delivery has attained a terminal + // state, then no future transfer or disposition sent by the sender can alter that + // terminal state. + State encoding.DeliveryState + + // indicates a resumed delivery + // + // If true, the resume flag indicates that the transfer is being used to reassociate + // an unsettled delivery from a dissociated link endpoint. See subsection 2.6.13 + // for more details. + // + // The receiver MUST ignore resumed deliveries that are not in its local unsettled map. + // The sender MUST NOT send resumed transfers for deliveries not in its local + // unsettled map. + // + // If a resumed delivery spans more than one transfer performative, then the resume + // flag MUST be set to true on the first transfer of the resumed delivery. For + // subsequent transfers for the same delivery the resume flag MAY be set to true, + // or MAY be omitted. 
+ // + // In the case where the exchange of unsettled maps makes clear that all message + // data has been successfully transferred to the receiver, and that only the final + // state (and potentially settlement) at the sender needs to be conveyed, then a + // resumed delivery MAY carry no payload and instead act solely as a vehicle for + // carrying the terminal state of the delivery at the sender. + Resume bool + + // indicates that the message is aborted + // + // Aborted messages SHOULD be discarded by the recipient (any payload within the + // frame carrying the performative MUST be ignored). An aborted message is + // implicitly settled. + Aborted bool + + // batchable hint + // + // If true, then the issuer is hinting that there is no need for the peer to urgently + // communicate updated delivery state. This hint MAY be used to artificially increase + // the amount of batching an implementation uses when communicating delivery states, + // and thereby save bandwidth. + // + // If the message being delivered is too large to fit within a single frame, then the + // setting of batchable to true on any of the transfer performatives for the delivery + // is equivalent to setting batchable to true for all the transfer performatives for + // the delivery. + // + // The batchable value does not form part of the transfer state, and is not retained + // if a link is suspended and subsequently resumed. + Batchable bool + + Payload []byte + + // optional channel to indicate to sender that transfer has completed + // + // Settled=true: closed when the transferred on network. + // Settled=false: closed when the receiver has confirmed settlement. 
+ Done chan encoding.DeliveryState +} + +func (t *PerformTransfer) frameBody() {} + +func (t PerformTransfer) String() string { + deliveryTag := "" + if t.DeliveryTag != nil { + deliveryTag = fmt.Sprintf("%X", t.DeliveryTag) + } + + return fmt.Sprintf("Transfer{Handle: %d, DeliveryID: %s, DeliveryTag: %s, MessageFormat: %s, "+ + "Settled: %t, More: %t, ReceiverSettleMode: %s, State: %v, Resume: %t, Aborted: %t, "+ + "Batchable: %t, Payload [size]: %d}", + t.Handle, + formatUint32Ptr(t.DeliveryID), + deliveryTag, + formatUint32Ptr(t.MessageFormat), + t.Settled, + t.More, + t.ReceiverSettleMode, + t.State, + t.Resume, + t.Aborted, + t.Batchable, + len(t.Payload), + ) +} + +func (t *PerformTransfer) Marshal(wr *buffer.Buffer) error { + err := encoding.MarshalComposite(wr, encoding.TypeCodeTransfer, []encoding.MarshalField{ + {Value: &t.Handle}, + {Value: t.DeliveryID, Omit: t.DeliveryID == nil}, + {Value: &t.DeliveryTag, Omit: len(t.DeliveryTag) == 0}, + {Value: t.MessageFormat, Omit: t.MessageFormat == nil}, + {Value: &t.Settled, Omit: !t.Settled}, + {Value: &t.More, Omit: !t.More}, + {Value: t.ReceiverSettleMode, Omit: t.ReceiverSettleMode == nil}, + {Value: t.State, Omit: t.State == nil}, + {Value: &t.Resume, Omit: !t.Resume}, + {Value: &t.Aborted, Omit: !t.Aborted}, + {Value: &t.Batchable, Omit: !t.Batchable}, + }) + if err != nil { + return err + } + + wr.Append(t.Payload) + return nil +} + +func (t *PerformTransfer) Unmarshal(r *buffer.Buffer) error { + err := encoding.UnmarshalComposite(r, encoding.TypeCodeTransfer, []encoding.UnmarshalField{ + {Field: &t.Handle, HandleNull: func() error { return errors.New("Transfer.Handle is required") }}, + {Field: &t.DeliveryID}, + {Field: &t.DeliveryTag}, + {Field: &t.MessageFormat}, + {Field: &t.Settled}, + {Field: &t.More}, + {Field: &t.ReceiverSettleMode}, + {Field: &t.State}, + {Field: &t.Resume}, + {Field: &t.Aborted}, + {Field: &t.Batchable}, + }...) 
+ if err != nil { + return err + } + + t.Payload = append([]byte(nil), r.Bytes()...) + + return err +} + +/* + + + + + + + + + + + +*/ +type PerformDisposition struct { + // directionality of disposition + // + // The role identifies whether the disposition frame contains information about + // sending link endpoints or receiving link endpoints. + Role encoding.Role + + // lower bound of deliveries + // + // Identifies the lower bound of delivery-ids for the deliveries in this set. + First uint32 // required, sequence number + + // upper bound of deliveries + // + // Identifies the upper bound of delivery-ids for the deliveries in this set. + // If not set, this is taken to be the same as first. + Last *uint32 // sequence number + + // indicates deliveries are settled + // + // If true, indicates that the referenced deliveries are considered settled by + // the issuing endpoint. + Settled bool + + // indicates state of deliveries + // + // Communicates the state of all the deliveries referenced by this disposition. + State encoding.DeliveryState + + // batchable hint + // + // If true, then the issuer is hinting that there is no need for the peer to + // urgently communicate the impact of the updated delivery states. This hint + // MAY be used to artificially increase the amount of batching an implementation + // uses when communicating delivery states, and thereby save bandwidth. 
+ Batchable bool +} + +func (d *PerformDisposition) frameBody() {} + +func (d PerformDisposition) String() string { + return fmt.Sprintf("Disposition{Role: %s, First: %d, Last: %s, Settled: %t, State: %v, Batchable: %t}", + d.Role, + d.First, + formatUint32Ptr(d.Last), + d.Settled, + d.State, + d.Batchable, + ) +} + +func (d *PerformDisposition) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeDisposition, []encoding.MarshalField{ + {Value: &d.Role, Omit: false}, + {Value: &d.First, Omit: false}, + {Value: d.Last, Omit: d.Last == nil}, + {Value: &d.Settled, Omit: !d.Settled}, + {Value: d.State, Omit: d.State == nil}, + {Value: &d.Batchable, Omit: !d.Batchable}, + }) +} + +func (d *PerformDisposition) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeDisposition, []encoding.UnmarshalField{ + {Field: &d.Role, HandleNull: func() error { return errors.New("Disposition.Role is required") }}, + {Field: &d.First, HandleNull: func() error { return errors.New("Disposition.Handle is required") }}, + {Field: &d.Last}, + {Field: &d.Settled}, + {Field: &d.State}, + {Field: &d.Batchable}, + }...) +} + +/* + + + + + + + + +*/ +type PerformDetach struct { + // the local handle of the link to be detached + Handle uint32 //required + + // if true then the sender has closed the link + Closed bool + + // error causing the detach + // + // If set, this field indicates that the link is being detached due to an error + // condition. The value of the field SHOULD contain details on the cause of the error. 
+ Error *encoding.Error +} + +func (d *PerformDetach) frameBody() {} + +func (d PerformDetach) String() string { + return fmt.Sprintf("Detach{Handle: %d, Closed: %t, Error: %v}", + d.Handle, + d.Closed, + d.Error, + ) +} + +func (d *PerformDetach) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeDetach, []encoding.MarshalField{ + {Value: &d.Handle, Omit: false}, + {Value: &d.Closed, Omit: !d.Closed}, + {Value: d.Error, Omit: d.Error == nil}, + }) +} + +func (d *PerformDetach) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeDetach, []encoding.UnmarshalField{ + {Field: &d.Handle, HandleNull: func() error { return errors.New("Detach.Handle is required") }}, + {Field: &d.Closed}, + {Field: &d.Error}, + }...) +} + +/* + + + + + + +*/ +type PerformEnd struct { + // error causing the end + // + // If set, this field indicates that the session is being ended due to an error + // condition. The value of the field SHOULD contain details on the cause of the error. + Error *encoding.Error +} + +func (e *PerformEnd) frameBody() {} + +func (d PerformEnd) String() string { + return fmt.Sprintf("End{Error: %v}", d.Error) +} + +func (e *PerformEnd) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeEnd, []encoding.MarshalField{ + {Value: e.Error, Omit: e.Error == nil}, + }) +} + +func (e *PerformEnd) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeEnd, + encoding.UnmarshalField{Field: &e.Error}, + ) +} + +/* + + + + + + +*/ +type PerformClose struct { + // error causing the close + // + // If set, this field indicates that the session is being closed due to an error + // condition. The value of the field SHOULD contain details on the cause of the error. 
+ Error *encoding.Error +} + +func (c *PerformClose) frameBody() {} + +func (c *PerformClose) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeClose, []encoding.MarshalField{ + {Value: c.Error, Omit: c.Error == nil}, + }) +} + +func (c *PerformClose) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeClose, + encoding.UnmarshalField{Field: &c.Error}, + ) +} + +func (c *PerformClose) String() string { + return fmt.Sprintf("Close{Error: %s}", c.Error) +} + +/* + + + + + + +*/ + +type SASLInit struct { + Mechanism encoding.Symbol + InitialResponse []byte + Hostname string +} + +func (si *SASLInit) frameBody() {} + +func (si *SASLInit) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSASLInit, []encoding.MarshalField{ + {Value: &si.Mechanism, Omit: false}, + {Value: &si.InitialResponse, Omit: false}, + {Value: &si.Hostname, Omit: len(si.Hostname) == 0}, + }) +} + +func (si *SASLInit) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeSASLInit, []encoding.UnmarshalField{ + {Field: &si.Mechanism, HandleNull: func() error { return errors.New("saslInit.Mechanism is required") }}, + {Field: &si.InitialResponse}, + {Field: &si.Hostname}, + }...) +} + +func (si *SASLInit) String() string { + // Elide the InitialResponse as it may contain a plain text secret. 
+ return fmt.Sprintf("SaslInit{Mechanism : %s, InitialResponse: ********, Hostname: %s}", + si.Mechanism, + si.Hostname, + ) +} + +/* + + + + +*/ + +type SASLMechanisms struct { + Mechanisms encoding.MultiSymbol +} + +func (sm *SASLMechanisms) frameBody() {} + +func (sm *SASLMechanisms) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSASLMechanism, []encoding.MarshalField{ + {Value: &sm.Mechanisms, Omit: false}, + }) +} + +func (sm *SASLMechanisms) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeSASLMechanism, + encoding.UnmarshalField{Field: &sm.Mechanisms, HandleNull: func() error { return errors.New("saslMechanisms.Mechanisms is required") }}, + ) +} + +func (sm *SASLMechanisms) String() string { + return fmt.Sprintf("SaslMechanisms{Mechanisms : %v}", + sm.Mechanisms, + ) +} + +/* + + + + +*/ + +type SASLChallenge struct { + Challenge []byte +} + +func (sc *SASLChallenge) String() string { + return "Challenge{Challenge: ********}" +} + +func (sc *SASLChallenge) frameBody() {} + +func (sc *SASLChallenge) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSASLChallenge, []encoding.MarshalField{ + {Value: &sc.Challenge, Omit: false}, + }) +} + +func (sc *SASLChallenge) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeSASLChallenge, []encoding.UnmarshalField{ + {Field: &sc.Challenge, HandleNull: func() error { return errors.New("saslChallenge.Challenge is required") }}, + }...) 
+} + +/* + + + + +*/ + +type SASLResponse struct { + Response []byte +} + +func (sr *SASLResponse) String() string { + return "Response{Response: ********}" +} + +func (sr *SASLResponse) frameBody() {} + +func (sr *SASLResponse) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSASLResponse, []encoding.MarshalField{ + {Value: &sr.Response, Omit: false}, + }) +} + +func (sr *SASLResponse) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeSASLResponse, []encoding.UnmarshalField{ + {Field: &sr.Response, HandleNull: func() error { return errors.New("saslResponse.Response is required") }}, + }...) +} + +/* + + + + + +*/ + +type SASLOutcome struct { + Code encoding.SASLCode + AdditionalData []byte +} + +func (so *SASLOutcome) frameBody() {} + +func (so *SASLOutcome) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeSASLOutcome, []encoding.MarshalField{ + {Value: &so.Code, Omit: false}, + {Value: &so.AdditionalData, Omit: len(so.AdditionalData) == 0}, + }) +} + +func (so *SASLOutcome) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeSASLOutcome, []encoding.UnmarshalField{ + {Field: &so.Code, HandleNull: func() error { return errors.New("saslOutcome.AdditionalData is required") }}, + {Field: &so.AdditionalData}, + }...) 
+} + +func (so *SASLOutcome) String() string { + return fmt.Sprintf("SaslOutcome{Code : %v, AdditionalData: %v}", + so.Code, + so.AdditionalData, + ) +} diff --git a/vendor/github.com/Azure/go-amqp/internal/frames/parsing.go b/vendor/github.com/Azure/go-amqp/internal/frames/parsing.go new file mode 100644 index 00000000000..c95a1aaa9f1 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/frames/parsing.go @@ -0,0 +1,159 @@ +package frames + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/encoding" +) + +const HeaderSize = 8 + +// Frame structure: +// +// header (8 bytes) +// 0-3: SIZE (total size, at least 8 bytes for header, uint32) +// 4: DOFF (data offset,at least 2, count of 4 bytes words, uint8) +// 5: TYPE (frame type) +// 0x0: AMQP +// 0x1: SASL +// 6-7: type dependent (channel for AMQP) +// extended header (opt) +// body (opt) + +// Header in a structure appropriate for use with binary.Read() +type Header struct { + // size: an unsigned 32-bit integer that MUST contain the total frame size of the frame header, + // extended header, and frame body. The frame is malformed if the size is less than the size of + // the frame header (8 bytes). + Size uint32 + // doff: gives the position of the body within the frame. The value of the data offset is an + // unsigned, 8-bit integer specifying a count of 4-byte words. Due to the mandatory 8-byte + // frame header, the frame is malformed if the value is less than 2. + DataOffset uint8 + FrameType uint8 + Channel uint16 +} + +// ParseHeader reads the header from r and returns the result. +// +// No validation is done. 
+func ParseHeader(r *buffer.Buffer) (Header, error) { + buf, ok := r.Next(8) + if !ok { + return Header{}, errors.New("invalid frameHeader") + } + _ = buf[7] + + fh := Header{ + Size: binary.BigEndian.Uint32(buf[0:4]), + DataOffset: buf[4], + FrameType: buf[5], + Channel: binary.BigEndian.Uint16(buf[6:8]), + } + + if fh.Size < HeaderSize { + return fh, fmt.Errorf("received frame header with invalid size %d", fh.Size) + } + + if fh.DataOffset < 2 { + return fh, fmt.Errorf("received frame header with invalid data offset %d", fh.DataOffset) + } + + return fh, nil +} + +// ParseBody reads and unmarshals an AMQP frame. +func ParseBody(r *buffer.Buffer) (FrameBody, error) { + payload := r.Bytes() + + if r.Len() < 3 || payload[0] != 0 || encoding.AMQPType(payload[1]) != encoding.TypeCodeSmallUlong { + return nil, errors.New("invalid frame body header") + } + + switch pType := encoding.AMQPType(payload[2]); pType { + case encoding.TypeCodeOpen: + t := new(PerformOpen) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeBegin: + t := new(PerformBegin) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeAttach: + t := new(PerformAttach) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeFlow: + t := new(PerformFlow) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeTransfer: + t := new(PerformTransfer) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeDisposition: + t := new(PerformDisposition) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeDetach: + t := new(PerformDetach) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeEnd: + t := new(PerformEnd) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeClose: + t := new(PerformClose) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeSASLMechanism: + t := new(SASLMechanisms) + err := t.Unmarshal(r) + return t, err + case encoding.TypeCodeSASLChallenge: + t := new(SASLChallenge) + err := t.Unmarshal(r) + 
return t, err + case encoding.TypeCodeSASLOutcome: + t := new(SASLOutcome) + err := t.Unmarshal(r) + return t, err + default: + return nil, fmt.Errorf("unknown performative type %02x", pType) + } +} + +// Write encodes fr into buf. +// split out from conn.WriteFrame for testing purposes. +func Write(buf *buffer.Buffer, fr Frame) error { + // write header + buf.Append([]byte{ + 0, 0, 0, 0, // size, overwrite later + 2, // doff, see frameHeader.DataOffset comment + uint8(fr.Type), // frame type + }) + buf.AppendUint16(fr.Channel) // channel + + // write AMQP frame body + err := encoding.Marshal(buf, fr.Body) + if err != nil { + return err + } + + // validate size + if uint(buf.Len()) > math.MaxUint32 { + return errors.New("frame too large") + } + + // retrieve raw bytes + bufBytes := buf.Bytes() + + // write correct size + binary.BigEndian.PutUint32(bufBytes, uint32(len(bufBytes))) + return nil +} diff --git a/vendor/github.com/Azure/go-amqp/internal/queue/queue.go b/vendor/github.com/Azure/go-amqp/internal/queue/queue.go new file mode 100644 index 00000000000..8c45b5d9abf --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/queue/queue.go @@ -0,0 +1,162 @@ +package queue + +import ( + "container/ring" +) + +// Holder provides synchronized access to a *Queue[T]. +type Holder[T any] struct { + // these channels work in tandem to provide exclusive access to the underlying *Queue[T]. + // each channel is created with a buffer size of one. + // empty behaves like a mutex when there's one or more messages in the queue. + // populated is like a semaphore when the queue is empty. + // the *Queue[T] is only ever in one channel. which channel depends on if it contains any items. + // the initial state is for empty to contain an empty queue. + empty chan *Queue[T] + populated chan *Queue[T] +} + +// NewHolder creates a new Holder[T] that contains the provided *Queue[T]. 
+func NewHolder[T any](q *Queue[T]) *Holder[T] { + h := &Holder[T]{ + empty: make(chan *Queue[T], 1), + populated: make(chan *Queue[T], 1), + } + h.Release(q) + return h +} + +// Acquire attempts to acquire the *Queue[T]. If the *Queue[T] has already been acquired the call blocks. +// When the *Queue[T] is no longer required, you MUST call Release() to relinquish acquisition. +func (h *Holder[T]) Acquire() *Queue[T] { + // the queue will be in only one of the channels, it doesn't matter which one + var q *Queue[T] + select { + case q = <-h.empty: + // empty queue + case q = <-h.populated: + // populated queue + } + return q +} + +// Wait returns a channel that's signaled when the *Queue[T] contains at least one item. +// When the *Queue[T] is no longer required, you MUST call Release() to relinquish acquisition. +func (h *Holder[T]) Wait() <-chan *Queue[T] { + return h.populated +} + +// Release returns the *Queue[T] back to the Holder[T]. +// Once the *Queue[T] has been released, it is no longer safe to call its methods. +func (h *Holder[T]) Release(q *Queue[T]) { + if q.Len() == 0 { + h.empty <- q + } else { + h.populated <- q + } +} + +// Len returns the length of the *Queue[T]. +func (h *Holder[T]) Len() int { + msgLen := 0 + select { + case q := <-h.empty: + h.empty <- q + case q := <-h.populated: + msgLen = q.Len() + h.populated <- q + } + return msgLen +} + +// Queue[T] is a segmented FIFO queue of Ts. +type Queue[T any] struct { + head *ring.Ring + tail *ring.Ring + size int +} + +// New creates a new instance of Queue[T]. +// - size is the size of each Queue segment +func New[T any](size int) *Queue[T] { + r := &ring.Ring{ + Value: &segment[T]{ + items: make([]*T, size), + }, + } + return &Queue[T]{ + head: r, + tail: r, + } +} + +// Enqueue adds the specified item to the end of the queue. +// If the current segment is full, a new segment is created. 
+func (q *Queue[T]) Enqueue(item T) { + for { + r := q.tail + seg := r.Value.(*segment[T]) + + if seg.tail < len(seg.items) { + seg.items[seg.tail] = &item + seg.tail++ + q.size++ + return + } + + // segment is full, can we advance? + if next := r.Next(); next != q.head { + q.tail = next + continue + } + + // no, add a new ring + r.Link(&ring.Ring{ + Value: &segment[T]{ + items: make([]*T, len(seg.items)), + }, + }) + + q.tail = r.Next() + } +} + +// Dequeue removes and returns the item from the front of the queue. +func (q *Queue[T]) Dequeue() *T { + r := q.head + seg := r.Value.(*segment[T]) + + if seg.tail == 0 { + // queue is empty + return nil + } + + // remove first item + item := seg.items[seg.head] + seg.items[seg.head] = nil + seg.head++ + q.size-- + + if seg.head == seg.tail { + // segment is now empty, reset indices + seg.head, seg.tail = 0, 0 + + // if we're not at the last ring, advance head to the next one + if q.head != q.tail { + q.head = r.Next() + } + } + + return item +} + +// Len returns the total count of enqueued items. +func (q *Queue[T]) Len() int { + return q.size +} + +type segment[T any] struct { + items []*T + head int + tail int +} diff --git a/vendor/github.com/Azure/go-amqp/internal/shared/shared.go b/vendor/github.com/Azure/go-amqp/internal/shared/shared.go new file mode 100644 index 00000000000..efd859cfbb8 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/internal/shared/shared.go @@ -0,0 +1,34 @@ +package shared + +import ( + "encoding/base64" + "math/rand" + "sync" + "time" +) + +// lockedRand provides a rand source that is safe for concurrent use. +type lockedRand struct { + mu sync.Mutex + src *rand.Rand +} + +func (r *lockedRand) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.src.Read(p) +} + +// package scoped rand source to avoid any issues with seeding +// of the global source. 
+var pkgRand = &lockedRand{ + src: rand.New(rand.NewSource(time.Now().UnixNano())), +} + +// RandString returns a base64 encoded string of n bytes. +func RandString(n int) string { + b := make([]byte, n) + // from math/rand, cannot fail + _, _ = pkgRand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/vendor/github.com/Azure/go-amqp/link.go b/vendor/github.com/Azure/go-amqp/link.go new file mode 100644 index 00000000000..2a3721df754 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/link.go @@ -0,0 +1,393 @@ +package amqp + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" + "github.com/Azure/go-amqp/internal/queue" + "github.com/Azure/go-amqp/internal/shared" +) + +// linkKey uniquely identifies a link on a connection by name and direction. +// +// A link can be identified uniquely by the ordered tuple +// +// (source-container-id, target-container-id, name) +// +// On a single connection the container ID pairs can be abbreviated +// to a boolean flag indicating the direction of the link. 
+type linkKey struct { + name string + role encoding.Role // Local role: sender/receiver +} + +// link contains the common state and methods for sending and receiving links +type link struct { + key linkKey // Name and direction + + // NOTE: outputHandle and inputHandle might not have the same value + + // our handle + outputHandle uint32 + + // remote's handle + inputHandle uint32 + + // frames destined for this link are added to this queue by Session.muxFrameToLink + rxQ *queue.Holder[frames.FrameBody] + + // used for gracefully closing link + close chan struct{} // signals a link's mux to shut down; DO NOT use this to check if a link has terminated, use done instead + closeOnce *sync.Once // closeOnce protects close from being closed multiple times + + done chan struct{} // closed when the link has terminated (mux exited); DO NOT wait on this from within a link's mux() as it will never trigger! + doneErr error // contains the mux error state; ONLY written to by the mux and MUST only be read from after done is closed! + closeErr error // contains the error state returned from closeLink(); ONLY closeLink() reads/writes this! + + session *Session // parent session + source *frames.Source // used for Receiver links + target *frames.Target // used for Sender links + properties map[encoding.Symbol]any // additional properties sent upon link attach + + // "The delivery-count is initialized by the sender when a link endpoint is created, + // and is incremented whenever a message is sent. Only the sender MAY independently + // modify this field. The receiver's value is calculated based on the last known + // value from the sender and any subsequent messages received on the link. Note that, + // despite its name, the delivery-count is not a count but a sequence number + // initialized at an arbitrary point by the sender." + deliveryCount uint32 + + // The current maximum number of messages that can be handled at the receiver endpoint of the link. 
Only the receiver endpoint + // can independently set this value. The sender endpoint sets this to the last known value seen from the receiver. + linkCredit uint32 + + senderSettleMode *SenderSettleMode + receiverSettleMode *ReceiverSettleMode + maxMessageSize uint64 + + closeInProgress bool // indicates that the detach performative has been sent + dynamicAddr bool // request a dynamic link address from the server +} + +func newLink(s *Session, r encoding.Role) link { + l := link{ + key: linkKey{shared.RandString(40), r}, + session: s, + close: make(chan struct{}), + closeOnce: &sync.Once{}, + done: make(chan struct{}), + } + + // set the segment size relative to respective window + var segmentSize int + if r == encoding.RoleReceiver { + segmentSize = int(s.incomingWindow) + } else { + segmentSize = int(s.outgoingWindow) + } + + l.rxQ = queue.NewHolder(queue.New[frames.FrameBody](segmentSize)) + return l +} + +// waitForFrame waits for an incoming frame to be queued. +// it returns the next frame from the queue, or an error. +// the error is either from the context or session.doneErr. +// not meant for consumption outside of link.go. +func (l *link) waitForFrame(ctx context.Context) (frames.FrameBody, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.session.done: + // session has terminated, no need to deallocate in this case + return nil, l.session.doneErr + case q := <-l.rxQ.Wait(): + // frame received + fr := q.Dequeue() + l.rxQ.Release(q) + return *fr, nil + } +} + +// attach sends the Attach performative to establish the link with its parent session. +// this is automatically called by the new*Link constructors. 
+func (l *link) attach(ctx context.Context, beforeAttach func(*frames.PerformAttach), afterAttach func(*frames.PerformAttach)) error { + if err := l.session.freeAbandonedLinks(ctx); err != nil { + return err + } + + // once the abandoned links have been cleaned up we can create our link + if err := l.session.allocateHandle(ctx, l); err != nil { + return err + } + + attach := &frames.PerformAttach{ + Name: l.key.name, + Handle: l.outputHandle, + ReceiverSettleMode: l.receiverSettleMode, + SenderSettleMode: l.senderSettleMode, + MaxMessageSize: l.maxMessageSize, + Source: l.source, + Target: l.target, + Properties: l.properties, + } + + // link-specific configuration of the attach frame + beforeAttach(attach) + + if err := l.txFrameAndWait(ctx, attach); err != nil { + return err + } + + // wait for response + fr, err := l.waitForFrame(ctx) + if err != nil { + l.session.abandonLink(l) + return err + } + + resp, ok := fr.(*frames.PerformAttach) + if !ok { + debug.Log(1, "RX (link %p): unexpected attach response frame %T", l, fr) + if err := l.session.conn.Close(); err != nil { + return err + } + return &ConnError{inner: fmt.Errorf("unexpected attach response: %#v", fr)} + } + + // If the remote encounters an error during the attach it returns an Attach + // with no Source or Target. The remote then sends a Detach with an error. + // + // Note that if the application chooses not to create a terminus, the session + // endpoint will still create a link endpoint and issue an attach indicating + // that the link endpoint has no associated local terminus. In this case, the + // session endpoint MUST immediately detach the newly created link endpoint. + // + // http://docs.oasis-open.org/amqp/core/v1.0/csprd01/amqp-core-transport-v1.0-csprd01.html#doc-idp386144 + if resp.Source == nil && resp.Target == nil { + // wait for detach + fr, err := l.waitForFrame(ctx) + if err != nil { + // we timed out waiting for the peer to close the link, this really isn't an abandoned link. 
+ // however, we still need to send the detach performative to ack the peer. + l.session.abandonLink(l) + return err + } + + detach, ok := fr.(*frames.PerformDetach) + if !ok { + if err := l.session.conn.Close(); err != nil { + return err + } + return &ConnError{inner: fmt.Errorf("unexpected frame while waiting for detach: %#v", fr)} + } + + // send return detach + fr = &frames.PerformDetach{ + Handle: l.outputHandle, + Closed: true, + } + if err := l.txFrameAndWait(ctx, fr); err != nil { + return err + } + + if detach.Error == nil { + return fmt.Errorf("received detach with no error specified") + } + return detach.Error + } + + if l.maxMessageSize == 0 || resp.MaxMessageSize < l.maxMessageSize { + l.maxMessageSize = resp.MaxMessageSize + } + + // link-specific configuration post attach + afterAttach(resp) + + if err := l.setSettleModes(resp); err != nil { + // close the link as there's a mismatch on requested/supported settlement modes + dr := &frames.PerformDetach{ + Handle: l.outputHandle, + Closed: true, + } + if err := l.txFrameAndWait(ctx, dr); err != nil { + return err + } + return err + } + + return nil +} + +// setSettleModes sets the settlement modes based on the resp frames.PerformAttach. +// +// If a settlement mode has been explicitly set locally and it was not honored by the +// server an error is returned. 
+func (l *link) setSettleModes(resp *frames.PerformAttach) error {
+	var (
+		localRecvSettle = receiverSettleModeValue(l.receiverSettleMode)
+		respRecvSettle  = receiverSettleModeValue(resp.ReceiverSettleMode)
+	)
+	// a nil local mode means "accept whatever the server chose"; only an
+	// explicitly requested mode that the server didn't honor is an error.
+	if l.receiverSettleMode != nil && localRecvSettle != respRecvSettle {
+		return fmt.Errorf("amqp: receiver settlement mode %q requested, received %q from server", l.receiverSettleMode, &respRecvSettle)
+	}
+	l.receiverSettleMode = &respRecvSettle
+
+	var (
+		localSendSettle = senderSettleModeValue(l.senderSettleMode)
+		respSendSettle  = senderSettleModeValue(resp.SenderSettleMode)
+	)
+	if l.senderSettleMode != nil && localSendSettle != respSendSettle {
+		return fmt.Errorf("amqp: sender settlement mode %q requested, received %q from server", l.senderSettleMode, &respSendSettle)
+	}
+	l.senderSettleMode = &respSendSettle
+
+	return nil
+}
+
+// muxHandleFrame processes fr based on type.
+// A non-nil error return terminates the link's mux.
+func (l *link) muxHandleFrame(fr frames.FrameBody) error {
+	switch fr := fr.(type) {
+	case *frames.PerformDetach:
+		if !fr.Closed {
+			l.closeWithError(ErrCondNotImplemented, fmt.Sprintf("non-closing detach not supported: %+v", fr))
+			return nil
+		}
+
+		// there are two possibilities:
+		// - this is the ack to a client-side Close()
+		// - the peer is closing the link so we must ack
+
+		if l.closeInProgress {
+			// if the client-side close was initiated due to an error (l.closeWithError)
+			// then l.doneErr will already be set. in this case, return that error instead
+			// of an empty LinkError which indicates a clean client-side close.
+			if l.doneErr != nil {
+				return l.doneErr
+			}
+			return &LinkError{}
+		}
+
+		// peer-initiated close: ack with our own closing detach.
+		dr := &frames.PerformDetach{
+			Handle: l.outputHandle,
+			Closed: true,
+		}
+		l.txFrame(&frameContext{Ctx: context.Background()}, dr)
+		return &LinkError{RemoteErr: fr.Error}
+
+	default:
+		debug.Log(1, "RX (link %p): unexpected frame: %s", l, fr)
+		l.closeWithError(ErrCondInternalError, fmt.Sprintf("link received unexpected frame %T", fr))
+		return nil
+	}
+}
+
+// closeLink initiates a client-side close of the AMQP link and waits for the
+// mux to receive the ack'ing detach, ctx to be cancelled, or its deadline to
+// expire. It is used by both Sender.Close and Receiver.Close.
+func (l *link) closeLink(ctx context.Context) error {
+	var ctxErr error
+	l.closeOnce.Do(func() {
+		close(l.close)
+
+		// once the mux has received the ack'ing detach performative, the mux will
+		// exit which deletes the link and closes l.done.
+		select {
+		case <-l.done:
+			l.closeErr = l.doneErr
+		case <-ctx.Done():
+			// notify the caller that the close timed out/was cancelled.
+			// the mux will remain running and once the ack is received it will terminate.
+			ctxErr = ctx.Err()
+
+			// record that the close timed out/was cancelled.
+			// subsequent calls to closeLink() will return this
+			debug.Log(1, "TX (link %p) closing %s: %v", l, l.key.name, ctxErr)
+			l.closeErr = &LinkError{inner: ctxErr}
+		}
+	})
+
+	if ctxErr != nil {
+		return ctxErr
+	}
+
+	var linkErr *LinkError
+	if errors.As(l.closeErr, &linkErr) && linkErr.RemoteErr == nil && linkErr.inner == nil {
+		// an empty LinkError means the link was cleanly closed by the caller
+		return nil
+	}
+	return l.closeErr
+}
+
+// closeWithError initiates closing the link with the specified AMQP error.
+// the mux must continue to run until the ack'ing detach is received.
+// l.doneErr is populated with a &LinkError{} containing an inner error constructed from the specified values
+// - cnd is the AMQP error condition
+// - desc is the error description
+func (l *link) closeWithError(cnd ErrCond, desc string) {
+	amqpErr := &Error{Condition: cnd, Description: desc}
+	if l.closeInProgress {
+		debug.Log(3, "TX (link %p) close error already pending, discarding %v", l, amqpErr)
+		return
+	}
+
+	dr := &frames.PerformDetach{
+		Handle: l.outputHandle,
+		Closed: true,
+		Error:  amqpErr,
+	}
+	l.closeInProgress = true
+	l.doneErr = &LinkError{inner: fmt.Errorf("%s: %s", cnd, desc)}
+	l.txFrame(&frameContext{Ctx: context.Background()}, dr)
+}
+
+// txFrame sends the specified frame via the link's session.
+// you MUST call this instead of session.txFrame() to ensure
+// that frames are not sent during session shutdown.
+func (l *link) txFrame(frameCtx *frameContext, fr frames.FrameBody) {
+	// NOTE: there is no need to select on l.done as this is either
+	// called from a link's mux or before the mux has even started.
+	select {
+	case <-l.session.done:
+		// the link's session has terminated, let that propagate to the link's mux
+	case <-l.session.endSent:
+		// we swallow this to prevent the link's mux from terminating.
+		// l.session.done will soon close so this is temporary.
+	case l.session.tx <- frameBodyEnvelope{FrameCtx: frameCtx, FrameBody: fr}:
+		debug.Log(2, "TX (link %p): mux frame to Session (%p): %s", l, l.session, fr)
+	}
+}
+
+// txFrameAndWait sends the specified frame via the link's session and
+// blocks until the frame has been written to the network, or until the
+// session terminates or ctx is cancelled.
+// you MUST call this instead of session.txFrame() to ensure
+// that frames are not sent during session shutdown.
+func (l *link) txFrameAndWait(ctx context.Context, fr frames.FrameBody) error {
+	frameCtx := frameContext{
+		Ctx:  ctx,
+		Done: make(chan struct{}),
+	}
+
+	// NOTE: there is no need to select on l.done as this is either
+	// called from a link's mux or before the mux has even started.
+ + select { + case <-l.session.done: + return l.session.doneErr + case <-l.session.endSent: + // we swallow this to prevent the link's mux from terminating. + // l.session.done will soon close so this is temporary. + return nil + case l.session.tx <- frameBodyEnvelope{FrameCtx: &frameCtx, FrameBody: fr}: + debug.Log(2, "TX (link %p): mux frame to Session (%p): %s", l, l.session, fr) + } + + select { + case <-frameCtx.Done: + return frameCtx.Err + case <-l.session.done: + return l.session.doneErr + } +} diff --git a/vendor/github.com/Azure/go-amqp/link_options.go b/vendor/github.com/Azure/go-amqp/link_options.go new file mode 100644 index 00000000000..e0447c5183a --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/link_options.go @@ -0,0 +1,238 @@ +package amqp + +import ( + "github.com/Azure/go-amqp/internal/encoding" +) + +type SenderOptions struct { + // Capabilities is the list of extension capabilities the sender supports. + Capabilities []string + + // Durability indicates what state of the sender will be retained durably. + // + // Default: DurabilityNone. + Durability Durability + + // DynamicAddress indicates a dynamic address is to be used. + // Any specified address will be ignored. + // + // Default: false. + DynamicAddress bool + + // ExpiryPolicy determines when the expiry timer of the sender starts counting + // down from the timeout value. If the link is subsequently re-attached before + // the timeout is reached, the count down is aborted. + // + // Default: ExpirySessionEnd. + ExpiryPolicy ExpiryPolicy + + // ExpiryTimeout is the duration in seconds that the sender will be retained. + // + // Default: 0. + ExpiryTimeout uint32 + + // Name sets the name of the link. + // + // Link names must be unique per-connection and direction. + // + // Default: randomly generated. + Name string + + // Properties sets an entry in the link properties map sent to the server. 
+ Properties map[string]any + + // RequestedReceiverSettleMode sets the requested receiver settlement mode. + // + // If a settlement mode is explicitly set and the server does not + // honor it an error will be returned during link attachment. + // + // Default: Accept the settlement mode set by the server, commonly ModeFirst. + RequestedReceiverSettleMode *ReceiverSettleMode + + // SettlementMode sets the settlement mode in use by this sender. + // + // Default: ModeMixed. + SettlementMode *SenderSettleMode + + // SourceAddress specifies the source address for this sender. + SourceAddress string + + // TargetCapabilities is the list of extension capabilities the sender desires. + TargetCapabilities []string + + // TargetDurability indicates what state of the peer will be retained durably. + // + // Default: DurabilityNone. + TargetDurability Durability + + // TargetExpiryPolicy determines when the expiry timer of the peer starts counting + // down from the timeout value. If the link is subsequently re-attached before + // the timeout is reached, the count down is aborted. + // + // Default: ExpirySessionEnd. + TargetExpiryPolicy ExpiryPolicy + + // TargetExpiryTimeout is the duration in seconds that the peer will be retained. + // + // Default: 0. + TargetExpiryTimeout uint32 +} + +type ReceiverOptions struct { + // Capabilities is the list of extension capabilities the receiver supports. + Capabilities []string + + // Credit specifies the maximum number of unacknowledged messages + // the sender can transmit. Once this limit is reached, no more messages + // will arrive until messages are acknowledged and settled. + // + // As messages are settled, any available credit will automatically be issued. + // + // Setting this to -1 requires manual management of link credit. + // Credits can be added with IssueCredit(), and links can also be + // drained with DrainCredit(). 
+ // This should only be enabled when complete control of the link's + // flow control is required. + // + // Default: 1. + Credit int32 + + // Durability indicates what state of the receiver will be retained durably. + // + // Default: DurabilityNone. + Durability Durability + + // DynamicAddress indicates a dynamic address is to be used. + // Any specified address will be ignored. + // + // Default: false. + DynamicAddress bool + + // ExpiryPolicy determines when the expiry timer of the sender starts counting + // down from the timeout value. If the link is subsequently re-attached before + // the timeout is reached, the count down is aborted. + // + // Default: ExpirySessionEnd. + ExpiryPolicy ExpiryPolicy + + // ExpiryTimeout is the duration in seconds that the sender will be retained. + // + // Default: 0. + ExpiryTimeout uint32 + + // Filters contains the desired filters for this receiver. + // If the peer cannot fulfill the filters the link will be detached. + Filters []LinkFilter + + // MaxMessageSize sets the maximum message size that can + // be received on the link. + // + // A size of zero indicates no limit. + // + // Default: 0. + MaxMessageSize uint64 + + // Name sets the name of the link. + // + // Link names must be unique per-connection and direction. + // + // Default: randomly generated. + Name string + + // Properties sets an entry in the link properties map sent to the server. + Properties map[string]any + + // RequestedSenderSettleMode sets the requested sender settlement mode. + // + // If a settlement mode is explicitly set and the server does not + // honor it an error will be returned during link attachment. + // + // Default: Accept the settlement mode set by the server, commonly ModeMixed. + RequestedSenderSettleMode *SenderSettleMode + + // SettlementMode sets the settlement mode in use by this receiver. + // + // Default: ModeFirst. 
+	SettlementMode *ReceiverSettleMode
+
+	// TargetAddress specifies the target address for this receiver.
+	TargetAddress string
+
+	// SourceCapabilities is the list of extension capabilities the receiver desires.
+	SourceCapabilities []string
+
+	// SourceDurability indicates what state of the peer will be retained durably.
+	//
+	// Default: DurabilityNone.
+	SourceDurability Durability
+
+	// SourceExpiryPolicy determines when the expiry timer of the peer starts counting
+	// down from the timeout value. If the link is subsequently re-attached before
+	// the timeout is reached, the count down is aborted.
+	//
+	// Default: ExpirySessionEnd.
+	SourceExpiryPolicy ExpiryPolicy
+
+	// SourceExpiryTimeout is the duration in seconds that the peer will be retained.
+	//
+	// Default: 0.
+	SourceExpiryTimeout uint32
+}
+
+// LinkFilter is an advanced API for setting non-standard source filters.
+// Please file an issue or open a PR if a standard filter is missing from this
+// library.
+//
+// The name is the key for the filter map. It will be encoded as an AMQP symbol type.
+//
+// The code is the descriptor of the described type value. The domain-id and descriptor-id
+// should be concatenated together. If 0 is passed as the code, the name will be used as
+// the descriptor.
+//
+// The value is the value of the described types. Acceptable types for value are specific
+// to the filter.
+//
+// Example:
+//
+// The standard selector-filter is defined as:
+//
+//	<descriptor name="apache.org:selector-filter:string" code="0x0000468C:0x00000004"/>
+//
+// In this case the name is "apache.org:selector-filter:string" and the code is
+// 0x0000468C00000004.
+// +// LinkSourceFilter("apache.org:selector-filter:string", 0x0000468C00000004, exampleValue) +// +// References: +// +// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-filter-set +// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#section-descriptor-values +type LinkFilter func(encoding.Filter) + +// NewLinkFilter creates a new LinkFilter with the specified values. +// Any preexisting link filter with the same name will be updated with the new code and value. +func NewLinkFilter(name string, code uint64, value any) LinkFilter { + return func(f encoding.Filter) { + var descriptor any + if code != 0 { + descriptor = code + } else { + descriptor = encoding.Symbol(name) + } + f[encoding.Symbol(name)] = &encoding.DescribedType{ + Descriptor: descriptor, + Value: value, + } + } +} + +// NewSelectorFilter creates a new selector filter (apache.org:selector-filter:string) with the specified filter value. +// Any preexisting selector filter will be updated with the new filter value. +func NewSelectorFilter(filter string) LinkFilter { + return NewLinkFilter(selectorFilter, selectorFilterCode, filter) +} + +const ( + selectorFilter = "apache.org:selector-filter:string" + selectorFilterCode = uint64(0x0000468C00000004) +) diff --git a/vendor/github.com/Azure/go-amqp/message.go b/vendor/github.com/Azure/go-amqp/message.go new file mode 100644 index 00000000000..2fcb9d635b8 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/message.go @@ -0,0 +1,500 @@ +package amqp + +import ( + "fmt" + "time" + + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/encoding" +) + +// Message is an AMQP message. +type Message struct { + // Message format code. + // + // The upper three octets of a message format code identify a particular message + // format. The lowest octet indicates the version of said message format. Any + // given version of a format is forwards compatible with all higher versions. 
+ Format uint32 + + // The DeliveryTag can be up to 32 octets of binary data. + // Note that when mode one is enabled there will be no delivery tag. + DeliveryTag []byte + + // The header section carries standard delivery details about the transfer + // of a message through the AMQP network. + Header *MessageHeader + // If the header section is omitted the receiver MUST assume the appropriate + // default values (or the meaning implied by no value being set) for the + // fields within the header unless other target or node specific defaults + // have otherwise been set. + + // The delivery-annotations section is used for delivery-specific non-standard + // properties at the head of the message. Delivery annotations convey information + // from the sending peer to the receiving peer. + DeliveryAnnotations Annotations + // If the recipient does not understand the annotation it cannot be acted upon + // and its effects (such as any implied propagation) cannot be acted upon. + // Annotations might be specific to one implementation, or common to multiple + // implementations. The capabilities negotiated on link attach and on the source + // and target SHOULD be used to establish which annotations a peer supports. A + // registry of defined annotations and their meanings is maintained [AMQPDELANN]. + // The symbolic key "rejected" is reserved for the use of communicating error + // information regarding rejected messages. Any values associated with the + // "rejected" key MUST be of type error. + // + // If the delivery-annotations section is omitted, it is equivalent to a + // delivery-annotations section containing an empty map of annotations. + + // The message-annotations section is used for properties of the message which + // are aimed at the infrastructure. + Annotations Annotations + // The message-annotations section is used for properties of the message which + // are aimed at the infrastructure and SHOULD be propagated across every + // delivery step. 
Message annotations convey information about the message. + // Intermediaries MUST propagate the annotations unless the annotations are + // explicitly augmented or modified (e.g., by the use of the modified outcome). + // + // The capabilities negotiated on link attach and on the source and target can + // be used to establish which annotations a peer understands; however, in a + // network of AMQP intermediaries it might not be possible to know if every + // intermediary will understand the annotation. Note that for some annotations + // it might not be necessary for the intermediary to understand their purpose, + // i.e., they could be used purely as an attribute which can be filtered on. + // + // A registry of defined annotations and their meanings is maintained [AMQPMESSANN]. + // + // If the message-annotations section is omitted, it is equivalent to a + // message-annotations section containing an empty map of annotations. + + // The properties section is used for a defined set of standard properties of + // the message. + Properties *MessageProperties + // The properties section is part of the bare message; therefore, + // if retransmitted by an intermediary, it MUST remain unaltered. + + // The application-properties section is a part of the bare message used for + // structured application data. Intermediaries can use the data within this + // structure for the purposes of filtering or routing. + ApplicationProperties map[string]any + // The keys of this map are restricted to be of type string (which excludes + // the possibility of a null key) and the values are restricted to be of + // simple types only, that is, excluding map, list, and array types. + + // NOTE: the Data, Value, and Sequence fields are mutually exclusive. + + // Data payloads. + // A data section contains opaque binary data. + Data [][]byte + + // Value payload. + // An amqp-value section contains a single AMQP value. 
+ Value any + + // Sequence will contain AMQP sequence sections from the body of the message. + // An amqp-sequence section contains an AMQP sequence. + Sequence [][]any + + // The footer section is used for details about the message or delivery which + // can only be calculated or evaluated once the whole bare message has been + // constructed or seen (for example message hashes, HMACs, signatures and + // encryption details). + Footer Annotations + + deliveryID uint32 // used when sending disposition + settled bool // whether transfer was settled by sender + rcv *Receiver // used to settle message on the corresponding Receiver (nil if settled == true) +} + +// NewMessage returns a *Message with data as the first payload in the Data field. +// +// This constructor is intended as a helper for basic Messages with a +// single data payload. It is valid to construct a Message directly for +// more complex usages. +// +// To create a Message using the Value or Sequence fields, don't use this +// constructor, create a new Message instead. +func NewMessage(data []byte) *Message { + return &Message{ + Data: [][]byte{data}, + } +} + +// GetData returns the first []byte from the Data field +// or nil if Data is empty. +func (m *Message) GetData() []byte { + if len(m.Data) < 1 { + return nil + } + return m.Data[0] +} + +// MarshalBinary encodes the message into binary form. 
+func (m *Message) MarshalBinary() ([]byte, error) {
+	buf := &buffer.Buffer{}
+	err := m.Marshal(buf)
+	return buf.Detach(), err
+}
+
+// Marshal encodes the message's sections into wr in their standard order
+// (header, delivery-annotations, message-annotations, properties,
+// application-properties, body, footer). Nil/empty sections are omitted.
+func (m *Message) Marshal(wr *buffer.Buffer) error {
+	if m.Header != nil {
+		err := m.Header.Marshal(wr)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.DeliveryAnnotations != nil {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeDeliveryAnnotations)
+		err := encoding.Marshal(wr, m.DeliveryAnnotations)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.Annotations != nil {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeMessageAnnotations)
+		err := encoding.Marshal(wr, m.Annotations)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.Properties != nil {
+		err := encoding.Marshal(wr, m.Properties)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.ApplicationProperties != nil {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeApplicationProperties)
+		err := encoding.Marshal(wr, m.ApplicationProperties)
+		if err != nil {
+			return err
+		}
+	}
+
+	for _, data := range m.Data {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeApplicationData)
+		err := encoding.WriteBinary(wr, data)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.Value != nil {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeAMQPValue)
+		err := encoding.Marshal(wr, m.Value)
+		if err != nil {
+			return err
+		}
+	}
+
+	if m.Sequence != nil {
+		// the body can basically be one of three different types (value, data or sequence).
+		// When it's sequence it's actually _several_ sequence sections, one for each sub-array.
+		for _, v := range m.Sequence {
+			encoding.WriteDescriptor(wr, encoding.TypeCodeAMQPSequence)
+			err := encoding.Marshal(wr, v)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	if m.Footer != nil {
+		encoding.WriteDescriptor(wr, encoding.TypeCodeFooter)
+		err := encoding.Marshal(wr, m.Footer)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// UnmarshalBinary decodes the message from binary form.
+func (m *Message) UnmarshalBinary(data []byte) error { + buf := buffer.New(data) + return m.Unmarshal(buf) +} + +func (m *Message) Unmarshal(r *buffer.Buffer) error { + // loop, decoding sections until bytes have been consumed + for r.Len() > 0 { + // determine type + type_, headerLength, err := encoding.PeekMessageType(r.Bytes()) + if err != nil { + return err + } + + var ( + section any + // section header is read from r before + // unmarshaling section is set to true + discardHeader = true + ) + switch encoding.AMQPType(type_) { + + case encoding.TypeCodeMessageHeader: + discardHeader = false + section = &m.Header + + case encoding.TypeCodeDeliveryAnnotations: + section = &m.DeliveryAnnotations + + case encoding.TypeCodeMessageAnnotations: + section = &m.Annotations + + case encoding.TypeCodeMessageProperties: + discardHeader = false + section = &m.Properties + + case encoding.TypeCodeApplicationProperties: + section = &m.ApplicationProperties + + case encoding.TypeCodeApplicationData: + r.Skip(int(headerLength)) + + var data []byte + err = encoding.Unmarshal(r, &data) + if err != nil { + return err + } + + m.Data = append(m.Data, data) + continue + + case encoding.TypeCodeAMQPSequence: + r.Skip(int(headerLength)) + + var data []any + err = encoding.Unmarshal(r, &data) + if err != nil { + return err + } + + m.Sequence = append(m.Sequence, data) + continue + + case encoding.TypeCodeFooter: + section = &m.Footer + + case encoding.TypeCodeAMQPValue: + section = &m.Value + + default: + return fmt.Errorf("unknown message section %#02x", type_) + } + + if discardHeader { + r.Skip(int(headerLength)) + } + + err = encoding.Unmarshal(r, section) + if err != nil { + return err + } + } + return nil +} + +func (m *Message) onSettlement() { + m.settled = true + m.rcv = nil +} + +/* + + + + + + + + +*/ + +// MessageHeader carries standard delivery details about the transfer +// of a message. 
+type MessageHeader struct { + Durable bool + Priority uint8 + TTL time.Duration // from milliseconds + FirstAcquirer bool + DeliveryCount uint32 +} + +func (h *MessageHeader) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeMessageHeader, []encoding.MarshalField{ + {Value: &h.Durable, Omit: !h.Durable}, + {Value: &h.Priority, Omit: h.Priority == 4}, + {Value: (*encoding.Milliseconds)(&h.TTL), Omit: h.TTL == 0}, + {Value: &h.FirstAcquirer, Omit: !h.FirstAcquirer}, + {Value: &h.DeliveryCount, Omit: h.DeliveryCount == 0}, + }) +} + +func (h *MessageHeader) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeMessageHeader, []encoding.UnmarshalField{ + {Field: &h.Durable}, + {Field: &h.Priority, HandleNull: func() error { h.Priority = 4; return nil }}, + {Field: (*encoding.Milliseconds)(&h.TTL)}, + {Field: &h.FirstAcquirer}, + {Field: &h.DeliveryCount}, + }...) +} + +/* + + + + + + + + + + + + + + + + +*/ + +// MessageProperties is the defined set of properties for AMQP messages. +type MessageProperties struct { + // Message-id, if set, uniquely identifies a message within the message system. + // The message producer is usually responsible for setting the message-id in + // such a way that it is assured to be globally unique. A broker MAY discard a + // message as a duplicate if the value of the message-id matches that of a + // previously received message sent to the same node. + // + // The value is restricted to the following types + // - uint64, UUID, []byte, or string + MessageID any + + // The identity of the user responsible for producing the message. + // The client sets this value, and it MAY be authenticated by intermediaries. + UserID []byte + + // The to field identifies the node that is the intended destination of the message. + // On any given transfer this might not be the node at the receiving end of the link. 
+ To *string + + // A common field for summary information about the message content and purpose. + Subject *string + + // The address of the node to send replies to. + ReplyTo *string + + // This is a client-specific id that can be used to mark or identify messages + // between clients. + // + // The value is restricted to the following types + // - uint64, UUID, []byte, or string + CorrelationID any + + // The RFC-2046 [RFC2046] MIME type for the message's application-data section + // (body). As per RFC-2046 [RFC2046] this can contain a charset parameter defining + // the character encoding used: e.g., 'text/plain; charset="utf-8"'. + // + // For clarity, as per section 7.2.1 of RFC-2616 [RFC2616], where the content type + // is unknown the content-type SHOULD NOT be set. This allows the recipient the + // opportunity to determine the actual type. Where the section is known to be truly + // opaque binary data, the content-type SHOULD be set to application/octet-stream. + // + // When using an application-data section with a section code other than data, + // content-type SHOULD NOT be set. + ContentType *string + + // The content-encoding property is used as a modifier to the content-type. + // When present, its value indicates what additional content encodings have been + // applied to the application-data, and thus what decoding mechanisms need to be + // applied in order to obtain the media-type referenced by the content-type header + // field. + // + // Content-encoding is primarily used to allow a document to be compressed without + // losing the identity of its underlying content type. + // + // Content-encodings are to be interpreted as per section 3.5 of RFC 2616 [RFC2616]. + // Valid content-encodings are registered at IANA [IANAHTTPPARAMS]. + // + // The content-encoding MUST NOT be set when the application-data section is other + // than data. 
The binary representation of all other application-data section types + // is defined completely in terms of the AMQP type system. + // + // Implementations MUST NOT use the identity encoding. Instead, implementations + // SHOULD NOT set this property. Implementations SHOULD NOT use the compress encoding, + // except as to remain compatible with messages originally sent with other protocols, + // e.g. HTTP or SMTP. + // + // Implementations SHOULD NOT specify multiple content-encoding values except as to + // be compatible with messages originally sent with other protocols, e.g. HTTP or SMTP. + ContentEncoding *string + + // An absolute time when this message is considered to be expired. + AbsoluteExpiryTime *time.Time + + // An absolute time when this message was created. + CreationTime *time.Time + + // Identifies the group the message belongs to. + GroupID *string + + // The relative position of this message within its group. + // + // The value is defined as a RFC-1982 sequence number + GroupSequence *uint32 + + // This is a client-specific id that is used so that client can send replies to this + // message to a specific group. 
+ ReplyToGroupID *string +} + +func (p *MessageProperties) Marshal(wr *buffer.Buffer) error { + return encoding.MarshalComposite(wr, encoding.TypeCodeMessageProperties, []encoding.MarshalField{ + {Value: p.MessageID, Omit: p.MessageID == nil}, + {Value: &p.UserID, Omit: len(p.UserID) == 0}, + {Value: p.To, Omit: p.To == nil}, + {Value: p.Subject, Omit: p.Subject == nil}, + {Value: p.ReplyTo, Omit: p.ReplyTo == nil}, + {Value: p.CorrelationID, Omit: p.CorrelationID == nil}, + {Value: (*encoding.Symbol)(p.ContentType), Omit: p.ContentType == nil}, + {Value: (*encoding.Symbol)(p.ContentEncoding), Omit: p.ContentEncoding == nil}, + {Value: p.AbsoluteExpiryTime, Omit: p.AbsoluteExpiryTime == nil}, + {Value: p.CreationTime, Omit: p.CreationTime == nil}, + {Value: p.GroupID, Omit: p.GroupID == nil}, + {Value: p.GroupSequence, Omit: p.GroupSequence == nil}, + {Value: p.ReplyToGroupID, Omit: p.ReplyToGroupID == nil}, + }) +} + +func (p *MessageProperties) Unmarshal(r *buffer.Buffer) error { + return encoding.UnmarshalComposite(r, encoding.TypeCodeMessageProperties, []encoding.UnmarshalField{ + {Field: &p.MessageID}, + {Field: &p.UserID}, + {Field: &p.To}, + {Field: &p.Subject}, + {Field: &p.ReplyTo}, + {Field: &p.CorrelationID}, + {Field: &p.ContentType}, + {Field: &p.ContentEncoding}, + {Field: &p.AbsoluteExpiryTime}, + {Field: &p.CreationTime}, + {Field: &p.GroupID}, + {Field: &p.GroupSequence}, + {Field: &p.ReplyToGroupID}, + }...) +} + +// Annotations keys must be of type string, int, or int64. +// +// String keys are encoded as AMQP Symbols. +type Annotations = encoding.Annotations + +// UUID is a 128 bit identifier as defined in RFC 4122. 
+type UUID = encoding.UUID diff --git a/vendor/github.com/Azure/go-amqp/receiver.go b/vendor/github.com/Azure/go-amqp/receiver.go new file mode 100644 index 00000000000..d472614d4c3 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/receiver.go @@ -0,0 +1,909 @@ +package amqp + +import ( + "bytes" + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" + "github.com/Azure/go-amqp/internal/queue" +) + +// Default link options +const ( + defaultLinkCredit = 1 +) + +// Receiver receives messages on a single AMQP link. +type Receiver struct { + l link + + // message receiving + receiverReady chan struct{} // receiver sends on this when mux is paused to indicate it can handle more messages + messagesQ *queue.Holder[Message] // used to send completed messages to receiver + txDisposition chan frameBodyEnvelope // used to funnel disposition frames through the mux + + // NOTE: this will need to be retooled if/when we need to support resuming links. + // at present, this is only used for debug tracing purposes so it's safe to change it to a count. + unsettledMessages int32 // count of unsettled messages for this receiver; MUST be atomically accessed + + msgBuf buffer.Buffer // buffered bytes for current message + more bool // if true, buf contains a partial message + msg Message // current message being decoded + + settlementCount uint32 // the count of settled messages + settlementCountMu sync.Mutex // must be held when accessing settlementCount + + autoSendFlow bool // automatically send flow frames as credit becomes available + inFlight inFlight // used to track message disposition when rcv-settle-mode == second + creditor creditor // manages credits via calls to IssueCredit/DrainCredit +} + +// IssueCredit adds credits to be requested in the next flow request. 
+// Attempting to issue more credit than the receiver's max credit as +// specified in ReceiverOptions.MaxCredit will result in an error. +func (r *Receiver) IssueCredit(credit uint32) error { + if r.autoSendFlow { + return errors.New("issueCredit can only be used with receiver links using manual credit management") + } + + if err := r.creditor.IssueCredit(credit); err != nil { + return err + } + + // cause mux() to check our flow conditions. + select { + case r.receiverReady <- struct{}{}: + default: + } + + return nil +} + +// Prefetched returns the next message that is stored in the Receiver's +// prefetch cache. It does NOT wait for the remote sender to send messages +// and returns immediately if the prefetch cache is empty. To receive from the +// prefetch and wait for messages from the remote Sender use `Receive`. +// +// Once a message is received, and if the sender is configured in any mode other +// than SenderSettleModeSettled, you *must* take an action on the message by calling +// one of the following: AcceptMessage, RejectMessage, ReleaseMessage, ModifyMessage. +func (r *Receiver) Prefetched() *Message { + select { + case r.receiverReady <- struct{}{}: + default: + } + + // non-blocking receive to ensure buffered messages are + // delivered regardless of whether the link has been closed. + q := r.messagesQ.Acquire() + msg := q.Dequeue() + r.messagesQ.Release(q) + + if msg == nil { + return nil + } + + debug.Log(3, "RX (Receiver %p): prefetched delivery ID %d", r, msg.deliveryID) + + if msg.settled { + r.onSettlement(1) + } + + return msg +} + +// ReceiveOptions contains any optional values for the Receiver.Receive method. +type ReceiveOptions struct { + // for future expansion +} + +// Receive returns the next message from the sender. +// Blocks until a message is received, ctx completes, or an error occurs. 
+// +// Once a message is received, and if the sender is configured in any mode other +// than SenderSettleModeSettled, you *must* take an action on the message by calling +// one of the following: AcceptMessage, RejectMessage, ReleaseMessage, ModifyMessage. +func (r *Receiver) Receive(ctx context.Context, opts *ReceiveOptions) (*Message, error) { + if msg := r.Prefetched(); msg != nil { + return msg, nil + } + + // wait for the next message + select { + case q := <-r.messagesQ.Wait(): + msg := q.Dequeue() + debug.Assert(msg != nil) + debug.Log(3, "RX (Receiver %p): received delivery ID %d", r, msg.deliveryID) + r.messagesQ.Release(q) + if msg.settled { + r.onSettlement(1) + } + return msg, nil + case <-r.l.done: + // if the link receives messages and is then closed between the above call to r.Prefetched() + // and this select statement, the order of selecting r.messages and r.l.done is undefined. + // however, once r.l.done is closed the link cannot receive any more messages. so be sure to + // drain any that might have trickled in within this window. + if msg := r.Prefetched(); msg != nil { + return msg, nil + } + return nil, r.l.doneErr + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Accept notifies the server that the message has been accepted and does not require redelivery. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to accept +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. +func (r *Receiver) AcceptMessage(ctx context.Context, msg *Message) error { + return msg.rcv.messageDisposition(ctx, msg, &encoding.StateAccepted{}) +} + +// Reject notifies the server that the message is invalid. 
+// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to reject +// - e is an optional rejection error +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. +func (r *Receiver) RejectMessage(ctx context.Context, msg *Message, e *Error) error { + return msg.rcv.messageDisposition(ctx, msg, &encoding.StateRejected{Error: e}) +} + +// Release releases the message back to the server. The message may be redelivered to this or another consumer. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to release +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. +func (r *Receiver) ReleaseMessage(ctx context.Context, msg *Message) error { + return msg.rcv.messageDisposition(ctx, msg, &encoding.StateReleased{}) +} + +// Modify notifies the server that the message was not acted upon and should be modifed. +// - ctx controls waiting for the peer to acknowledge the disposition +// - msg is the message to modify +// - options contains the optional settings to modify +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message's disposition is in an unknown state. +func (r *Receiver) ModifyMessage(ctx context.Context, msg *Message, options *ModifyMessageOptions) error { + if options == nil { + options = &ModifyMessageOptions{} + } + return msg.rcv.messageDisposition(ctx, + msg, &encoding.StateModified{ + DeliveryFailed: options.DeliveryFailed, + UndeliverableHere: options.UndeliverableHere, + MessageAnnotations: options.Annotations, + }) +} + +// ModifyMessageOptions contains the optional parameters to ModifyMessage. 
+type ModifyMessageOptions struct { + // DeliveryFailed indicates that the server must consider this an + // unsuccessful delivery attempt and increment the delivery count. + DeliveryFailed bool + + // UndeliverableHere indicates that the server must not redeliver + // the message to this link. + UndeliverableHere bool + + // Annotations is an optional annotation map to be merged + // with the existing message annotations, overwriting existing keys + // if necessary. + Annotations Annotations +} + +// Address returns the link's address. +func (r *Receiver) Address() string { + if r.l.source == nil { + return "" + } + return r.l.source.Address +} + +// LinkName returns associated link name or an empty string if link is not defined. +func (r *Receiver) LinkName() string { + return r.l.key.name +} + +// LinkSourceFilterValue retrieves the specified link source filter value or nil if it doesn't exist. +func (r *Receiver) LinkSourceFilterValue(name string) any { + if r.l.source == nil { + return nil + } + filter, ok := r.l.source.Filter[encoding.Symbol(name)] + if !ok { + return nil + } + return filter.Value +} + +// Close closes the Receiver and AMQP link. +// - ctx controls waiting for the peer to acknowledge the close +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. However, the operation will continue to +// execute in the background. Subsequent calls will return a *LinkError +// that contains the context's error message. 
+func (r *Receiver) Close(ctx context.Context) error { + return r.l.closeLink(ctx) +} + +// sendDisposition sends a disposition frame to the peer +func (r *Receiver) sendDisposition(ctx context.Context, first uint32, last *uint32, state encoding.DeliveryState) error { + fr := &frames.PerformDisposition{ + Role: encoding.RoleReceiver, + First: first, + Last: last, + Settled: r.l.receiverSettleMode == nil || *r.l.receiverSettleMode == ReceiverSettleModeFirst, + State: state, + } + + frameCtx := frameContext{ + Ctx: ctx, + Done: make(chan struct{}), + } + + select { + case r.txDisposition <- frameBodyEnvelope{FrameCtx: &frameCtx, FrameBody: fr}: + debug.Log(2, "TX (Receiver %p): mux txDisposition %s", r, fr) + case <-r.l.done: + return r.l.doneErr + } + + select { + case <-frameCtx.Done: + return frameCtx.Err + case <-r.l.done: + return r.l.doneErr + } +} + +// messageDisposition is called via the *Receiver associated with a *Message. +// this allows messages to be settled across Receiver instances. +// note that only unsettled messsages will have their rcv field set. +func (r *Receiver) messageDisposition(ctx context.Context, msg *Message, state encoding.DeliveryState) error { + // settling a message that's already settled (sender-settled or otherwise) will have a nil rcv. + // which means that r will be nil. you MUST NOT dereference r if msg.settled == true + if msg.settled { + return nil + } + + debug.Assert(r != nil) + + // NOTE: we MUST add to the in-flight map before sending the disposition. if not, it's possible + // to receive the ack'ing disposition frame *before* the in-flight map has been updated which + // will cause the below <-wait to never trigger. 
+ + var wait chan error + if r.l.receiverSettleMode != nil && *r.l.receiverSettleMode == ReceiverSettleModeSecond { + debug.Log(3, "TX (Receiver %p): delivery ID %d is in flight", r, msg.deliveryID) + wait = r.inFlight.add(msg) + } + + if err := r.sendDisposition(ctx, msg.deliveryID, nil, state); err != nil { + return err + } + + if wait == nil { + // mode first, there will be no settlement ack + msg.onSettlement() + r.deleteUnsettled() + r.onSettlement(1) + return nil + } + + select { + case err := <-wait: + // err has three possibilities + // - nil, meaning the peer acknowledged the settlement + // - an *Error, meaning the peer rejected the message with a provided error + // - a non-AMQP error. this comes from calls to inFlight.clear() during mux unwind. + // only for the first two cases is the message considered settled + + if amqpErr := (&Error{}); err == nil || errors.As(err, &amqpErr) { + debug.Log(3, "RX (Receiver %p): delivery ID %d has been settled", r, msg.deliveryID) + // we've received confirmation of disposition + return err + } + + debug.Log(3, "RX (Receiver %p): error settling delivery ID %d: %v", r, msg.deliveryID, err) + return err + + case <-ctx.Done(): + // didn't receive the ack in the time allotted, leave message as unsettled + // TODO: if the ack arrives later, we need to remove the message from the unsettled map and reclaim the credit + return ctx.Err() + } +} + +// onSettlement is to be called after message settlement. +// - count is the number of messages that were settled +func (r *Receiver) onSettlement(count uint32) { + if !r.autoSendFlow { + return + } + + r.settlementCountMu.Lock() + r.settlementCount += count + r.settlementCountMu.Unlock() + + select { + case r.receiverReady <- struct{}{}: + // woke up + default: + // wake pending + } +} + +// increments the count of unsettled messages. +// this is only called from our mux. 
+func (r *Receiver) addUnsettled() { + atomic.AddInt32(&r.unsettledMessages, 1) +} + +// decrements the count of unsettled messages. +// this is called inside _or_ outside the mux. +// it's called outside when RSM is mode first. +func (r *Receiver) deleteUnsettled() { + atomic.AddInt32(&r.unsettledMessages, -1) +} + +// returns the count of unsettled messages. +// this is only called from our mux for diagnostic purposes. +func (r *Receiver) countUnsettled() int32 { + return atomic.LoadInt32(&r.unsettledMessages) +} + +func newReceiver(source string, session *Session, opts *ReceiverOptions) (*Receiver, error) { + l := newLink(session, encoding.RoleReceiver) + l.source = &frames.Source{Address: source} + l.target = new(frames.Target) + l.linkCredit = defaultLinkCredit + r := &Receiver{ + l: l, + autoSendFlow: true, + receiverReady: make(chan struct{}, 1), + txDisposition: make(chan frameBodyEnvelope), + } + + r.messagesQ = queue.NewHolder(queue.New[Message](int(session.incomingWindow))) + + if opts == nil { + return r, nil + } + + for _, v := range opts.Capabilities { + r.l.target.Capabilities = append(r.l.target.Capabilities, encoding.Symbol(v)) + } + if opts.Credit > 0 { + r.l.linkCredit = uint32(opts.Credit) + } else if opts.Credit < 0 { + r.l.linkCredit = 0 + r.autoSendFlow = false + } + if opts.Durability > DurabilityUnsettledState { + return nil, fmt.Errorf("invalid Durability %d", opts.Durability) + } + r.l.target.Durable = opts.Durability + if opts.DynamicAddress { + r.l.source.Address = "" + r.l.dynamicAddr = opts.DynamicAddress + } + if opts.ExpiryPolicy != "" { + if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { + return nil, err + } + r.l.target.ExpiryPolicy = opts.ExpiryPolicy + } + r.l.target.Timeout = opts.ExpiryTimeout + if opts.Filters != nil { + r.l.source.Filter = make(encoding.Filter) + for _, f := range opts.Filters { + f(r.l.source.Filter) + } + } + if opts.MaxMessageSize > 0 { + r.l.maxMessageSize = opts.MaxMessageSize + 
} + if opts.Name != "" { + r.l.key.name = opts.Name + } + if opts.Properties != nil { + r.l.properties = make(map[encoding.Symbol]any) + for k, v := range opts.Properties { + if k == "" { + return nil, errors.New("link property key must not be empty") + } + r.l.properties[encoding.Symbol(k)] = v + } + } + if opts.RequestedSenderSettleMode != nil { + if rsm := *opts.RequestedSenderSettleMode; rsm > SenderSettleModeMixed { + return nil, fmt.Errorf("invalid RequestedSenderSettleMode %d", rsm) + } + r.l.senderSettleMode = opts.RequestedSenderSettleMode + } + if opts.SettlementMode != nil { + if rsm := *opts.SettlementMode; rsm > ReceiverSettleModeSecond { + return nil, fmt.Errorf("invalid SettlementMode %d", rsm) + } + r.l.receiverSettleMode = opts.SettlementMode + } + r.l.target.Address = opts.TargetAddress + for _, v := range opts.SourceCapabilities { + r.l.source.Capabilities = append(r.l.source.Capabilities, encoding.Symbol(v)) + } + if opts.SourceDurability != DurabilityNone { + r.l.source.Durable = opts.SourceDurability + } + if opts.SourceExpiryPolicy != ExpiryPolicySessionEnd { + r.l.source.ExpiryPolicy = opts.SourceExpiryPolicy + } + if opts.SourceExpiryTimeout != 0 { + r.l.source.Timeout = opts.SourceExpiryTimeout + } + return r, nil +} + +// attach sends the Attach performative to establish the link with its parent session. +// this is automatically called by the new*Link constructors. 
+func (r *Receiver) attach(ctx context.Context) error { + if err := r.l.attach(ctx, func(pa *frames.PerformAttach) { + pa.Role = encoding.RoleReceiver + if pa.Source == nil { + pa.Source = new(frames.Source) + } + pa.Source.Dynamic = r.l.dynamicAddr + }, func(pa *frames.PerformAttach) { + if r.l.source == nil { + r.l.source = new(frames.Source) + } + // if dynamic address requested, copy assigned name to address + if r.l.dynamicAddr && pa.Source != nil { + r.l.source.Address = pa.Source.Address + } + // deliveryCount is a sequence number, must initialize to sender's initial sequence number + r.l.deliveryCount = pa.InitialDeliveryCount + // copy the received filter values + if pa.Source != nil { + r.l.source.Filter = pa.Source.Filter + } + }); err != nil { + return err + } + + return nil +} + +func nopHook() {} + +type receiverTestHooks struct { + MuxStart func() + MuxSelect func() +} + +func (r *Receiver) mux(hooks receiverTestHooks) { + if hooks.MuxSelect == nil { + hooks.MuxSelect = nopHook + } + if hooks.MuxStart == nil { + hooks.MuxStart = nopHook + } + + defer func() { + // unblock any in flight message dispositions + r.inFlight.clear(r.l.doneErr) + + if !r.autoSendFlow { + // unblock any pending drain requests + r.creditor.EndDrain() + } + + close(r.l.done) + }() + + hooks.MuxStart() + + if r.autoSendFlow { + r.l.doneErr = r.muxFlow(r.l.linkCredit, false) + } + + for { + msgLen := r.messagesQ.Len() + + r.settlementCountMu.Lock() + // counter that accumulates the settled delivery count. + // once the threshold has been reached, the counter is + // reset and a flow frame is sent. + previousSettlementCount := r.settlementCount + if previousSettlementCount >= r.l.linkCredit { + r.settlementCount = 0 + } + r.settlementCountMu.Unlock() + + // once we have pending credit equal to or greater than our available credit, reclaim it. + // we do this instead of settlementCount > 0 to prevent flow frames from being too chatty. 
+ // NOTE: we compare the settlementCount against the current link credit instead of some + // fixed threshold to ensure credit is reclaimed in cases where the number of unsettled + // messages remains high for whatever reason. + if r.autoSendFlow && previousSettlementCount > 0 && previousSettlementCount >= r.l.linkCredit { + debug.Log(1, "RX (Receiver %p) (auto): source: %q, inflight: %d, linkCredit: %d, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", + r, r.l.source.Address, r.inFlight.len(), r.l.linkCredit, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + r.l.doneErr = r.creditor.IssueCredit(previousSettlementCount) + } else if r.l.linkCredit == 0 { + debug.Log(1, "RX (Receiver %p) (pause): source: %q, inflight: %d, linkCredit: %d, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", + r, r.l.source.Address, r.inFlight.len(), r.l.linkCredit, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + } + + if r.l.doneErr != nil { + return + } + + drain, credits := r.creditor.FlowBits(r.l.linkCredit) + if drain || credits > 0 { + debug.Log(1, "RX (Receiver %p) (flow): source: %q, inflight: %d, curLinkCredit: %d, newLinkCredit: %d, drain: %v, deliveryCount: %d, messages: %d, unsettled: %d, settlementCount: %d, settleMode: %s", + r, r.l.source.Address, r.inFlight.len(), r.l.linkCredit, credits, drain, r.l.deliveryCount, msgLen, r.countUnsettled(), previousSettlementCount, r.l.receiverSettleMode.String()) + + // send a flow frame. + r.l.doneErr = r.muxFlow(credits, drain) + } + + if r.l.doneErr != nil { + return + } + + txDisposition := r.txDisposition + closed := r.l.close + if r.l.closeInProgress { + // swap out channel so it no longer triggers + closed = nil + + // disable sending of disposition frames once closing is in progress. 
+ // this is to prevent races between mux shutdown and clearing of + // any in-flight dispositions. + txDisposition = nil + } + + hooks.MuxSelect() + + select { + case q := <-r.l.rxQ.Wait(): + // populated queue + fr := *q.Dequeue() + r.l.rxQ.Release(q) + + // if muxHandleFrame returns an error it means the mux must terminate. + // note that in the case of a client-side close due to an error, nil + // is returned in order to keep the mux running to ack the detach frame. + if err := r.muxHandleFrame(fr); err != nil { + r.l.doneErr = err + return + } + + case env := <-txDisposition: + r.l.txFrame(env.FrameCtx, env.FrameBody) + + case <-r.receiverReady: + continue + + case <-closed: + if r.l.closeInProgress { + // a client-side close due to protocol error is in progress + continue + } + + // receiver is being closed by the client + r.l.closeInProgress = true + fr := &frames.PerformDetach{ + Handle: r.l.outputHandle, + Closed: true, + } + r.l.txFrame(&frameContext{Ctx: context.Background()}, fr) + + case <-r.l.session.done: + r.l.doneErr = r.l.session.doneErr + return + } + } +} + +// muxFlow sends tr to the session mux. +// l.linkCredit will also be updated to `linkCredit` +func (r *Receiver) muxFlow(linkCredit uint32, drain bool) error { + var ( + deliveryCount = r.l.deliveryCount + ) + + fr := &frames.PerformFlow{ + Handle: &r.l.outputHandle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages, + Drain: drain, + } + + // Update credit. This must happen before entering loop below + // because incoming messages handled while waiting to transmit + // flow increment deliveryCount. This causes the credit to become + // out of sync with the server. + + if !drain { + // if we're draining we don't want to touch our internal credit - we're not changing it so any issued credits + // are still valid until drain completes, at which point they will be naturally zeroed. 
+ r.l.linkCredit = linkCredit + } + + select { + case r.l.session.tx <- frameBodyEnvelope{FrameCtx: &frameContext{Ctx: context.Background()}, FrameBody: fr}: + debug.Log(2, "TX (Receiver %p): mux frame to Session (%p): %d, %s", r, r.l.session, r.l.session.channel, fr) + return nil + case <-r.l.close: + return nil + case <-r.l.session.done: + return r.l.session.doneErr + } +} + +// muxHandleFrame processes fr based on type. +func (r *Receiver) muxHandleFrame(fr frames.FrameBody) error { + debug.Log(2, "RX (Receiver %p): %s", r, fr) + switch fr := fr.(type) { + // message frame + case *frames.PerformTransfer: + r.muxReceive(*fr) + + // flow control frame + case *frames.PerformFlow: + if !fr.Echo { + // if the 'drain' flag has been set in the frame sent to the _receiver_ then + // we signal whomever is waiting (the service has seen and acknowledged our drain) + if fr.Drain && !r.autoSendFlow { + r.l.linkCredit = 0 // we have no active credits at this point. + r.creditor.EndDrain() + } + return nil + } + + var ( + // copy because sent by pointer below; prevent race + linkCredit = r.l.linkCredit + deliveryCount = r.l.deliveryCount + ) + + // send flow + resp := &frames.PerformFlow{ + Handle: &r.l.outputHandle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages + } + + select { + case r.l.session.tx <- frameBodyEnvelope{FrameCtx: &frameContext{Ctx: context.Background()}, FrameBody: resp}: + debug.Log(2, "TX (Receiver %p): mux frame to Session (%p): %d, %s", r, r.l.session, r.l.session.channel, resp) + case <-r.l.close: + return nil + case <-r.l.session.done: + return r.l.session.doneErr + } + + case *frames.PerformDisposition: + // Unblock receivers waiting for message disposition + // bubble disposition error up to the receiver + var dispositionError error + if state, ok := fr.State.(*encoding.StateRejected); ok { + // state.Error isn't required to be filled out. 
For instance if you dead letter a message + // you will get a rejected response that doesn't contain an error. + if state.Error != nil { + dispositionError = state.Error + } + } + // removal from the in-flight map will also remove the message from the unsettled map + count := r.inFlight.remove(fr.First, fr.Last, dispositionError, func(msg *Message) { + r.deleteUnsettled() + msg.onSettlement() + }) + r.onSettlement(count) + + default: + return r.l.muxHandleFrame(fr) + } + + return nil +} + +func (r *Receiver) muxReceive(fr frames.PerformTransfer) { + if !r.more { + // this is the first transfer of a message, + // record the delivery ID, message format, + // and delivery Tag + if fr.DeliveryID != nil { + r.msg.deliveryID = *fr.DeliveryID + } + if fr.MessageFormat != nil { + r.msg.Format = *fr.MessageFormat + } + r.msg.DeliveryTag = fr.DeliveryTag + + // these fields are required on first transfer of a message + if fr.DeliveryID == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a delivery-id") + return + } + if fr.MessageFormat == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a message-format") + return + } + if fr.DeliveryTag == nil { + r.l.closeWithError(ErrCondNotAllowed, "received message without a delivery-tag") + return + } + } else { + // this is a continuation of a multipart message + // some fields may be omitted on continuation transfers, + // but if they are included they must be consistent + // with the first. 
+ + if fr.DeliveryID != nil && *fr.DeliveryID != r.msg.deliveryID { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent delivery-id: %d != %d", + *fr.DeliveryID, r.msg.deliveryID, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + if fr.MessageFormat != nil && *fr.MessageFormat != r.msg.Format { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent message-format: %d != %d", + *fr.MessageFormat, r.msg.Format, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + if fr.DeliveryTag != nil && !bytes.Equal(fr.DeliveryTag, r.msg.DeliveryTag) { + msg := fmt.Sprintf( + "received continuation transfer with inconsistent delivery-tag: %q != %q", + fr.DeliveryTag, r.msg.DeliveryTag, + ) + r.l.closeWithError(ErrCondNotAllowed, msg) + return + } + } + + // discard message if it's been aborted + if fr.Aborted { + r.msgBuf.Reset() + r.msg = Message{} + r.more = false + return + } + + // ensure maxMessageSize will not be exceeded + if r.l.maxMessageSize != 0 && uint64(r.msgBuf.Len())+uint64(len(fr.Payload)) > r.l.maxMessageSize { + r.l.closeWithError(ErrCondMessageSizeExceeded, fmt.Sprintf("received message larger than max size of %d", r.l.maxMessageSize)) + return + } + + // add the payload the the buffer + r.msgBuf.Append(fr.Payload) + + // mark as settled if at least one frame is settled + r.msg.settled = r.msg.settled || fr.Settled + + // save in-progress status + r.more = fr.More + + if fr.More { + return + } + + // last frame in message + err := r.msg.Unmarshal(&r.msgBuf) + if err != nil { + r.l.closeWithError(ErrCondInternalError, err.Error()) + return + } + + // send to receiver + if !r.msg.settled { + r.addUnsettled() + r.msg.rcv = r + debug.Log(3, "RX (Receiver %p): add unsettled delivery ID %d", r, r.msg.deliveryID) + } + + q := r.messagesQ.Acquire() + q.Enqueue(r.msg) + msgLen := q.Len() + r.messagesQ.Release(q) + + // reset progress + r.msgBuf.Reset() + r.msg = Message{} + + // decrement link-credit 
after entire message received + r.l.deliveryCount++ + r.l.linkCredit-- + debug.Log(3, "RX (Receiver %p) link %s - deliveryCount: %d, linkCredit: %d, len(messages): %d", r, r.l.key.name, r.l.deliveryCount, r.l.linkCredit, msgLen) +} + +// inFlight tracks in-flight message dispositions allowing receivers +// to block waiting for the server to respond when an appropriate +// settlement mode is configured. +type inFlight struct { + mu sync.RWMutex + m map[uint32]inFlightInfo +} + +type inFlightInfo struct { + wait chan error + msg *Message +} + +func (f *inFlight) add(msg *Message) chan error { + wait := make(chan error, 1) + + f.mu.Lock() + if f.m == nil { + f.m = make(map[uint32]inFlightInfo) + } + + f.m[msg.deliveryID] = inFlightInfo{wait: wait, msg: msg} + f.mu.Unlock() + + return wait +} + +func (f *inFlight) remove(first uint32, last *uint32, err error, handler func(*Message)) uint32 { + f.mu.Lock() + + if f.m == nil { + f.mu.Unlock() + return 0 + } + + ll := first + if last != nil { + ll = *last + } + + count := uint32(0) + for i := first; i <= ll; i++ { + info, ok := f.m[i] + if ok { + handler(info.msg) + info.wait <- err + delete(f.m, i) + count++ + } + } + + f.mu.Unlock() + return count +} + +func (f *inFlight) clear(err error) { + f.mu.Lock() + for id, info := range f.m { + info.wait <- err + delete(f.m, id) + } + f.mu.Unlock() +} + +func (f *inFlight) len() int { + f.mu.RLock() + defer f.mu.RUnlock() + return len(f.m) +} diff --git a/vendor/github.com/Azure/go-amqp/sasl.go b/vendor/github.com/Azure/go-amqp/sasl.go new file mode 100644 index 00000000000..d7eea4add11 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/sasl.go @@ -0,0 +1,259 @@ +package amqp + +import ( + "context" + "fmt" + + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" +) + +// SASL Mechanisms +const ( + saslMechanismPLAIN encoding.Symbol = "PLAIN" + saslMechanismANONYMOUS encoding.Symbol = "ANONYMOUS" + 
saslMechanismEXTERNAL encoding.Symbol = "EXTERNAL" + saslMechanismXOAUTH2 encoding.Symbol = "XOAUTH2" +) + +// SASLType represents a SASL configuration to use during authentication. +type SASLType func(c *Conn) error + +// ConnSASLPlain enables SASL PLAIN authentication for the connection. +// +// SASL PLAIN transmits credentials in plain text and should only be used +// on TLS/SSL enabled connection. +func SASLTypePlain(username, password string) SASLType { + // TODO: how widely used is hostname? should it be supported + return func(c *Conn) error { + // make handlers map if no other mechanism has + if c.saslHandlers == nil { + c.saslHandlers = make(map[encoding.Symbol]stateFunc) + } + + // add the handler the the map + c.saslHandlers[saslMechanismPLAIN] = func(ctx context.Context) (stateFunc, error) { + // send saslInit with PLAIN payload + init := &frames.SASLInit{ + Mechanism: "PLAIN", + InitialResponse: []byte("\x00" + username + "\x00" + password), + Hostname: "", + } + fr := frames.Frame{ + Type: frames.TypeSASL, + Body: init, + } + debug.Log(1, "TX (ConnSASLPlain %p): %s", c, fr) + timeout, err := c.getWriteTimeout(ctx) + if err != nil { + return nil, err + } + if err = c.writeFrame(timeout, fr); err != nil { + return nil, err + } + + // go to c.saslOutcome to handle the server response + return c.saslOutcome, nil + } + return nil + } +} + +// ConnSASLAnonymous enables SASL ANONYMOUS authentication for the connection. 
+func SASLTypeAnonymous() SASLType { + return func(c *Conn) error { + // make handlers map if no other mechanism has + if c.saslHandlers == nil { + c.saslHandlers = make(map[encoding.Symbol]stateFunc) + } + + // add the handler the the map + c.saslHandlers[saslMechanismANONYMOUS] = func(ctx context.Context) (stateFunc, error) { + init := &frames.SASLInit{ + Mechanism: saslMechanismANONYMOUS, + InitialResponse: []byte("anonymous"), + } + fr := frames.Frame{ + Type: frames.TypeSASL, + Body: init, + } + debug.Log(1, "TX (ConnSASLAnonymous %p): %s", c, fr) + timeout, err := c.getWriteTimeout(ctx) + if err != nil { + return nil, err + } + if err = c.writeFrame(timeout, fr); err != nil { + return nil, err + } + + // go to c.saslOutcome to handle the server response + return c.saslOutcome, nil + } + return nil + } +} + +// ConnSASLExternal enables SASL EXTERNAL authentication for the connection. +// The value for resp is dependent on the type of authentication (empty string is common for TLS). +// See https://datatracker.ietf.org/doc/html/rfc4422#appendix-A for additional info. +func SASLTypeExternal(resp string) SASLType { + return func(c *Conn) error { + // make handlers map if no other mechanism has + if c.saslHandlers == nil { + c.saslHandlers = make(map[encoding.Symbol]stateFunc) + } + + // add the handler the the map + c.saslHandlers[saslMechanismEXTERNAL] = func(ctx context.Context) (stateFunc, error) { + init := &frames.SASLInit{ + Mechanism: saslMechanismEXTERNAL, + InitialResponse: []byte(resp), + } + fr := frames.Frame{ + Type: frames.TypeSASL, + Body: init, + } + debug.Log(1, "TX (ConnSASLExternal %p): %s", c, fr) + timeout, err := c.getWriteTimeout(ctx) + if err != nil { + return nil, err + } + if err = c.writeFrame(timeout, fr); err != nil { + return nil, err + } + + // go to c.saslOutcome to handle the server response + return c.saslOutcome, nil + } + return nil + } +} + +// ConnSASLXOAUTH2 enables SASL XOAUTH2 authentication for the connection. 
+// +// The saslMaxFrameSizeOverride parameter allows the limit that governs the maximum frame size this client will allow +// itself to generate to be raised for the sasl-init frame only. Set this when the size of the size of the SASL XOAUTH2 +// initial client response (which contains the username and bearer token) would otherwise breach the 512 byte min-max-frame-size +// (http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-transport-v1.0-os.html#definition-MIN-MAX-FRAME-SIZE). Pass -1 +// to keep the default. +// +// SASL XOAUTH2 transmits the bearer in plain text and should only be used +// on TLS/SSL enabled connection. +func SASLTypeXOAUTH2(username, bearer string, saslMaxFrameSizeOverride uint32) SASLType { + return func(c *Conn) error { + // make handlers map if no other mechanism has + if c.saslHandlers == nil { + c.saslHandlers = make(map[encoding.Symbol]stateFunc) + } + + response, err := saslXOAUTH2InitialResponse(username, bearer) + if err != nil { + return err + } + + handler := saslXOAUTH2Handler{ + conn: c, + maxFrameSizeOverride: saslMaxFrameSizeOverride, + response: response, + } + // add the handler the the map + c.saslHandlers[saslMechanismXOAUTH2] = handler.init + return nil + } +} + +type saslXOAUTH2Handler struct { + conn *Conn + maxFrameSizeOverride uint32 + response []byte + errorResponse []byte // https://developers.google.com/gmail/imap/xoauth2-protocol#error_response +} + +func (s saslXOAUTH2Handler) init(ctx context.Context) (stateFunc, error) { + originalPeerMaxFrameSize := s.conn.peerMaxFrameSize + if s.maxFrameSizeOverride > s.conn.peerMaxFrameSize { + s.conn.peerMaxFrameSize = s.maxFrameSizeOverride + } + timeout, err := s.conn.getWriteTimeout(ctx) + if err != nil { + return nil, err + } + err = s.conn.writeFrame(timeout, frames.Frame{ + Type: frames.TypeSASL, + Body: &frames.SASLInit{ + Mechanism: saslMechanismXOAUTH2, + InitialResponse: s.response, + }, + }) + s.conn.peerMaxFrameSize = originalPeerMaxFrameSize + if err != nil 
{ + return nil, err + } + + return s.step, nil +} + +func (s saslXOAUTH2Handler) step(ctx context.Context) (stateFunc, error) { + // read challenge or outcome frame + fr, err := s.conn.readFrame() + if err != nil { + return nil, err + } + + switch v := fr.Body.(type) { + case *frames.SASLOutcome: + // check if auth succeeded + if v.Code != encoding.CodeSASLOK { + return nil, fmt.Errorf("SASL XOAUTH2 auth failed with code %#00x: %s : %s", + v.Code, v.AdditionalData, s.errorResponse) + } + + // return to c.negotiateProto + s.conn.saslComplete = true + return s.conn.negotiateProto, nil + case *frames.SASLChallenge: + if s.errorResponse == nil { + s.errorResponse = v.Challenge + + timeout, err := s.conn.getWriteTimeout(ctx) + if err != nil { + return nil, err + } + + // The SASL protocol requires clients to send an empty response to this challenge. + err = s.conn.writeFrame(timeout, frames.Frame{ + Type: frames.TypeSASL, + Body: &frames.SASLResponse{ + Response: []byte{}, + }, + }) + if err != nil { + return nil, err + } + return s.step, nil + } else { + return nil, fmt.Errorf("SASL XOAUTH2 unexpected additional error response received during "+ + "exchange. 
Initial error response: %s, additional response: %s", s.errorResponse, v.Challenge) + } + default: + return nil, fmt.Errorf("sasl: unexpected frame type %T", fr.Body) + } +} + +func saslXOAUTH2InitialResponse(username string, bearer string) ([]byte, error) { + if len(bearer) == 0 { + return []byte{}, fmt.Errorf("unacceptable bearer token") + } + for _, char := range bearer { + if char < '\x20' || char > '\x7E' { + return []byte{}, fmt.Errorf("unacceptable bearer token") + } + } + for _, char := range username { + if char == '\x01' { + return []byte{}, fmt.Errorf("unacceptable username") + } + } + return []byte("user=" + username + "\x01auth=Bearer " + bearer + "\x01\x01"), nil +} diff --git a/vendor/github.com/Azure/go-amqp/sender.go b/vendor/github.com/Azure/go-amqp/sender.go new file mode 100644 index 00000000000..8287f5e9d85 --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/sender.go @@ -0,0 +1,505 @@ +package amqp + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + + "github.com/Azure/go-amqp/internal/buffer" + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" +) + +// Sender sends messages on a single AMQP link. +type Sender struct { + l link + transfers chan transferEnvelope // sender uses to send transfer frames + + mu sync.Mutex // protects buf and nextDeliveryTag + buf buffer.Buffer + nextDeliveryTag uint64 + rollback chan struct{} +} + +// LinkName() is the name of the link used for this Sender. +func (s *Sender) LinkName() string { + return s.l.key.name +} + +// MaxMessageSize is the maximum size of a single message. +func (s *Sender) MaxMessageSize() uint64 { + return s.l.maxMessageSize +} + +// SendOptions contains any optional values for the Sender.Send method. +type SendOptions struct { + // Indicates the message is to be sent as settled when settlement mode is SenderSettleModeMixed. 
+ // If the settlement mode is SenderSettleModeUnsettled and Settled is true, an error is returned. + Settled bool +} + +// Send sends a Message. +// +// Blocks until the message is sent or an error occurs. If the peer is +// configured for receiver settlement mode second, the call also blocks +// until the peer confirms message settlement. +// +// - ctx controls waiting for the message to be sent and possibly confirmed +// - msg is the message to send +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, the message is in an unknown state of transmission. +// +// Send is safe for concurrent use. Since only a single message can be +// sent on a link at a time, this is most useful when settlement confirmation +// has been requested (receiver settle mode is second). In this case, +// additional messages can be sent while the current goroutine is waiting +// for the confirmation. +func (s *Sender) Send(ctx context.Context, msg *Message, opts *SendOptions) error { + // check if the link is dead. while it's safe to call s.send + // in this case, this will avoid some allocations etc. + select { + case <-s.l.done: + return s.l.doneErr + default: + // link is still active + } + done, err := s.send(ctx, msg, opts) + if err != nil { + return err + } + + // wait for transfer to be confirmed + select { + case state := <-done: + if state, ok := state.(*encoding.StateRejected); ok { + if state.Error != nil { + return state.Error + } + return errors.New("the peer rejected the message without specifying an error") + } + return nil + case <-s.l.done: + return s.l.doneErr + case <-ctx.Done(): + // TODO: if the message is not settled and we never received a disposition, how can we consider the message as sent? + return ctx.Err() + } +} + +// send is separated from Send so that the mutex unlock can be deferred without +// locking the transfer confirmation that happens in Send. 
+func (s *Sender) send(ctx context.Context, msg *Message, opts *SendOptions) (chan encoding.DeliveryState, error) { + const ( + maxDeliveryTagLength = 32 + maxTransferFrameHeader = 66 // determined by calcMaxTransferFrameHeader + ) + if len(msg.DeliveryTag) > maxDeliveryTagLength { + return nil, &Error{ + Condition: ErrCondMessageSizeExceeded, + Description: fmt.Sprintf("delivery tag is over the allowed %v bytes, len: %v", maxDeliveryTagLength, len(msg.DeliveryTag)), + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.buf.Reset() + err := msg.Marshal(&s.buf) + if err != nil { + return nil, err + } + + if s.l.maxMessageSize != 0 && uint64(s.buf.Len()) > s.l.maxMessageSize { + return nil, &Error{ + Condition: ErrCondMessageSizeExceeded, + Description: fmt.Sprintf("encoded message size exceeds max of %d", s.l.maxMessageSize), + } + } + + senderSettled := senderSettleModeValue(s.l.senderSettleMode) == SenderSettleModeSettled + if opts != nil { + if opts.Settled && senderSettleModeValue(s.l.senderSettleMode) == SenderSettleModeUnsettled { + return nil, errors.New("can't send message as settled when sender settlement mode is unsettled") + } else if opts.Settled { + senderSettled = true + } + } + + var ( + maxPayloadSize = int64(s.l.session.conn.peerMaxFrameSize) - maxTransferFrameHeader + ) + + deliveryTag := msg.DeliveryTag + if len(deliveryTag) == 0 { + // use uint64 encoded as []byte as deliveryTag + deliveryTag = make([]byte, 8) + binary.BigEndian.PutUint64(deliveryTag, s.nextDeliveryTag) + s.nextDeliveryTag++ + } + + fr := frames.PerformTransfer{ + Handle: s.l.outputHandle, + DeliveryID: &needsDeliveryID, + DeliveryTag: deliveryTag, + MessageFormat: &msg.Format, + More: s.buf.Len() > 0, + } + + for fr.More { + buf, _ := s.buf.Next(maxPayloadSize) + fr.Payload = append([]byte(nil), buf...) + fr.More = s.buf.Len() > 0 + if !fr.More { + // SSM=settled: overrides RSM; no acks. 
+ // SSM=unsettled: sender should wait for receiver to ack + // RSM=first: receiver considers it settled immediately, but must still send ack (SSM=unsettled only) + // RSM=second: receiver sends ack and waits for return ack from sender (SSM=unsettled only) + + // mark final transfer as settled when sender mode is settled + fr.Settled = senderSettled + + // set done on last frame + fr.Done = make(chan encoding.DeliveryState, 1) + } + + // NOTE: we MUST send a copy of fr here since we modify it post send + + frameCtx := frameContext{ + Ctx: ctx, + Done: make(chan struct{}), + } + + select { + case s.transfers <- transferEnvelope{FrameCtx: &frameCtx, InputHandle: s.l.inputHandle, Frame: fr}: + // frame was sent to our mux + case <-s.l.done: + return nil, s.l.doneErr + case <-ctx.Done(): + return nil, &Error{Condition: ErrCondTransferLimitExceeded, Description: fmt.Sprintf("credit limit exceeded for sending link %s", s.l.key.name)} + } + + select { + case <-frameCtx.Done: + if frameCtx.Err != nil { + if !fr.More { + select { + case s.rollback <- struct{}{}: + // the write never happened so signal the mux to roll back the delivery count and link credit + case <-s.l.close: + // the link is going down + } + } + return nil, frameCtx.Err + } + // frame was written to the network + case <-s.l.done: + return nil, s.l.doneErr + } + + // clear values that are only required on first message + fr.DeliveryID = nil + fr.DeliveryTag = nil + fr.MessageFormat = nil + } + + return fr.Done, nil +} + +// Address returns the link's address. +func (s *Sender) Address() string { + if s.l.target == nil { + return "" + } + return s.l.target.Address +} + +// Close closes the Sender and AMQP link. +// - ctx controls waiting for the peer to acknowledge the close +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. However, the operation will continue to +// execute in the background. 
Subsequent calls will return a *LinkError +// that contains the context's error message. +func (s *Sender) Close(ctx context.Context) error { + return s.l.closeLink(ctx) +} + +// newSendingLink creates a new sending link and attaches it to the session +func newSender(target string, session *Session, opts *SenderOptions) (*Sender, error) { + l := newLink(session, encoding.RoleSender) + l.target = &frames.Target{Address: target} + l.source = new(frames.Source) + s := &Sender{ + l: l, + rollback: make(chan struct{}), + } + + if opts == nil { + return s, nil + } + + for _, v := range opts.Capabilities { + s.l.source.Capabilities = append(s.l.source.Capabilities, encoding.Symbol(v)) + } + if opts.Durability > DurabilityUnsettledState { + return nil, fmt.Errorf("invalid Durability %d", opts.Durability) + } + s.l.source.Durable = opts.Durability + if opts.DynamicAddress { + s.l.target.Address = "" + s.l.dynamicAddr = opts.DynamicAddress + } + if opts.ExpiryPolicy != "" { + if err := encoding.ValidateExpiryPolicy(opts.ExpiryPolicy); err != nil { + return nil, err + } + s.l.source.ExpiryPolicy = opts.ExpiryPolicy + } + s.l.source.Timeout = opts.ExpiryTimeout + if opts.Name != "" { + s.l.key.name = opts.Name + } + if opts.Properties != nil { + s.l.properties = make(map[encoding.Symbol]any) + for k, v := range opts.Properties { + if k == "" { + return nil, errors.New("link property key must not be empty") + } + s.l.properties[encoding.Symbol(k)] = v + } + } + if opts.RequestedReceiverSettleMode != nil { + if rsm := *opts.RequestedReceiverSettleMode; rsm > ReceiverSettleModeSecond { + return nil, fmt.Errorf("invalid RequestedReceiverSettleMode %d", rsm) + } + s.l.receiverSettleMode = opts.RequestedReceiverSettleMode + } + if opts.SettlementMode != nil { + if ssm := *opts.SettlementMode; ssm > SenderSettleModeMixed { + return nil, fmt.Errorf("invalid SettlementMode %d", ssm) + } + s.l.senderSettleMode = opts.SettlementMode + } + s.l.source.Address = opts.SourceAddress + for _, 
v := range opts.TargetCapabilities { + s.l.target.Capabilities = append(s.l.target.Capabilities, encoding.Symbol(v)) + } + if opts.TargetDurability != DurabilityNone { + s.l.target.Durable = opts.TargetDurability + } + if opts.TargetExpiryPolicy != ExpiryPolicySessionEnd { + s.l.target.ExpiryPolicy = opts.TargetExpiryPolicy + } + if opts.TargetExpiryTimeout != 0 { + s.l.target.Timeout = opts.TargetExpiryTimeout + } + return s, nil +} + +func (s *Sender) attach(ctx context.Context) error { + if err := s.l.attach(ctx, func(pa *frames.PerformAttach) { + pa.Role = encoding.RoleSender + if pa.Target == nil { + pa.Target = new(frames.Target) + } + pa.Target.Dynamic = s.l.dynamicAddr + }, func(pa *frames.PerformAttach) { + if s.l.target == nil { + s.l.target = new(frames.Target) + } + + // if dynamic address requested, copy assigned name to address + if s.l.dynamicAddr && pa.Target != nil { + s.l.target.Address = pa.Target.Address + } + }); err != nil { + return err + } + + s.transfers = make(chan transferEnvelope) + + return nil +} + +type senderTestHooks struct { + MuxSelect func() + MuxTransfer func() +} + +func (s *Sender) mux(hooks senderTestHooks) { + if hooks.MuxSelect == nil { + hooks.MuxSelect = nopHook + } + if hooks.MuxTransfer == nil { + hooks.MuxTransfer = nopHook + } + + defer func() { + close(s.l.done) + }() + +Loop: + for { + var outgoingTransfers chan transferEnvelope + if s.l.linkCredit > 0 { + debug.Log(1, "TX (Sender %p) (enable): target: %q, link credit: %d, deliveryCount: %d", s, s.l.target.Address, s.l.linkCredit, s.l.deliveryCount) + outgoingTransfers = s.transfers + } else { + debug.Log(1, "TX (Sender %p) (pause): target: %q, link credit: %d, deliveryCount: %d", s, s.l.target.Address, s.l.linkCredit, s.l.deliveryCount) + } + + closed := s.l.close + if s.l.closeInProgress { + // swap out channel so it no longer triggers + closed = nil + + // disable sending once closing is in progress. 
+ // this prevents races with mux shutdown and + // the peer sending disposition frames. + outgoingTransfers = nil + } + + hooks.MuxSelect() + + select { + // received frame + case q := <-s.l.rxQ.Wait(): + // populated queue + fr := *q.Dequeue() + s.l.rxQ.Release(q) + + // if muxHandleFrame returns an error it means the mux must terminate. + // note that in the case of a client-side close due to an error, nil + // is returned in order to keep the mux running to ack the detach frame. + if err := s.muxHandleFrame(fr); err != nil { + s.l.doneErr = err + return + } + + // send data + case env := <-outgoingTransfers: + hooks.MuxTransfer() + select { + case s.l.session.txTransfer <- env: + debug.Log(2, "TX (Sender %p): mux transfer to Session: %d, %s", s, s.l.session.channel, env.Frame) + // decrement link-credit after entire message transferred + if !env.Frame.More { + s.l.deliveryCount++ + s.l.linkCredit-- + // we are the sender and we keep track of the peer's link credit + debug.Log(3, "TX (Sender %p): link: %s, link credit: %d", s, s.l.key.name, s.l.linkCredit) + } + continue Loop + case <-s.l.close: + continue Loop + case <-s.l.session.done: + continue Loop + } + + case <-closed: + if s.l.closeInProgress { + // a client-side close due to protocol error is in progress + continue + } + + // sender is being closed by the client + s.l.closeInProgress = true + fr := &frames.PerformDetach{ + Handle: s.l.outputHandle, + Closed: true, + } + s.l.txFrame(&frameContext{Ctx: context.Background()}, fr) + + case <-s.l.session.done: + s.l.doneErr = s.l.session.doneErr + return + + case <-s.rollback: + s.l.deliveryCount-- + s.l.linkCredit++ + debug.Log(3, "TX (Sender %p): rollback link: %s, link credit: %d", s, s.l.key.name, s.l.linkCredit) + } + } +} + +// muxHandleFrame processes fr based on type. 
+// depending on the peer's RSM, it might return a disposition frame for sending +func (s *Sender) muxHandleFrame(fr frames.FrameBody) error { + debug.Log(2, "RX (Sender %p): %s", s, fr) + switch fr := fr.(type) { + // flow control frame + case *frames.PerformFlow: + // the sender's link-credit variable MUST be set according to this formula when flow information is given by the receiver: + // link-credit(snd) := delivery-count(rcv) + link-credit(rcv) - delivery-count(snd) + linkCredit := *fr.LinkCredit - s.l.deliveryCount + if fr.DeliveryCount != nil { + // DeliveryCount can be nil if the receiver hasn't processed + // the attach. That shouldn't be the case here, but it's + // what ActiveMQ does. + linkCredit += *fr.DeliveryCount + } + + s.l.linkCredit = linkCredit + + if !fr.Echo { + return nil + } + + var ( + // copy because sent by pointer below; prevent race + deliveryCount = s.l.deliveryCount + ) + + // send flow + resp := &frames.PerformFlow{ + Handle: &s.l.outputHandle, + DeliveryCount: &deliveryCount, + LinkCredit: &linkCredit, // max number of messages + } + + select { + case s.l.session.tx <- frameBodyEnvelope{FrameCtx: &frameContext{Ctx: context.Background()}, FrameBody: resp}: + debug.Log(2, "TX (Sender %p): mux frame to Session (%p): %d, %s", s, s.l.session, s.l.session.channel, resp) + case <-s.l.close: + return nil + case <-s.l.session.done: + return s.l.session.doneErr + } + + case *frames.PerformDisposition: + if fr.Settled { + return nil + } + + // peer is in mode second, so we must send confirmation of disposition. + // NOTE: the ack must be sent through the session so it can close out + // the in-flight disposition. 
+ dr := &frames.PerformDisposition{ + Role: encoding.RoleSender, + First: fr.First, + Last: fr.Last, + Settled: true, + } + + select { + case s.l.session.tx <- frameBodyEnvelope{FrameCtx: &frameContext{Ctx: context.Background()}, FrameBody: dr}: + debug.Log(2, "TX (Sender %p): mux frame to Session (%p): %d, %s", s, s.l.session, s.l.session.channel, dr) + case <-s.l.close: + return nil + case <-s.l.session.done: + return s.l.session.doneErr + } + + return nil + + default: + return s.l.muxHandleFrame(fr) + } + + return nil +} diff --git a/vendor/github.com/Azure/go-amqp/session.go b/vendor/github.com/Azure/go-amqp/session.go new file mode 100644 index 00000000000..4ca6c747e1f --- /dev/null +++ b/vendor/github.com/Azure/go-amqp/session.go @@ -0,0 +1,822 @@ +package amqp + +import ( + "context" + "errors" + "fmt" + "math" + "sync" + + "github.com/Azure/go-amqp/internal/bitmap" + "github.com/Azure/go-amqp/internal/debug" + "github.com/Azure/go-amqp/internal/encoding" + "github.com/Azure/go-amqp/internal/frames" + "github.com/Azure/go-amqp/internal/queue" +) + +// Default session options +const ( + defaultWindow = 5000 +) + +// SessionOptions contains the optional settings for configuring an AMQP session. +type SessionOptions struct { + // MaxLinks sets the maximum number of links (Senders/Receivers) + // allowed on the session. + // + // Minimum: 1. + // Default: 4294967295. + MaxLinks uint32 +} + +// Session is an AMQP session. +// +// A session multiplexes Receivers. 
+type Session struct { + channel uint16 // session's local channel + remoteChannel uint16 // session's remote channel, owned by conn.connReader + conn *Conn // underlying conn + tx chan frameBodyEnvelope // non-transfer frames to be sent; session must track disposition + txTransfer chan transferEnvelope // transfer frames to be sent; session must track disposition + + // frames destined for this session are added to this queue by conn.connReader + rxQ *queue.Holder[frames.FrameBody] + + // flow control + incomingWindow uint32 + outgoingWindow uint32 + needFlowCount uint32 + + handleMax uint32 + + // link management + linksMu sync.RWMutex // used to synchronize link handle allocation + linksByKey map[linkKey]*link // mapping of name+role link + outputHandles *bitmap.Bitmap // allocated output handles + + abandonedLinksMu sync.Mutex + abandonedLinks []*link + + // used for gracefully closing session + close chan struct{} // closed by calling Close(). it signals that the end performative should be sent + closeOnce sync.Once + + // part of internal public surface area + done chan struct{} // closed when the session has terminated (mux exited); DO NOT wait on this from within Session.mux() as it will never trigger! + endSent chan struct{} // closed when the end performative has been sent; once this is closed, links MUST NOT send any frames! + doneErr error // contains the mux error state; ONLY written to by the mux and MUST only be read from after done is closed! + closeErr error // contains the error state returned from Close(); ONLY Close() reads/writes this! 
+} + +func newSession(c *Conn, channel uint16, opts *SessionOptions) *Session { + s := &Session{ + conn: c, + channel: channel, + tx: make(chan frameBodyEnvelope), + txTransfer: make(chan transferEnvelope), + incomingWindow: defaultWindow, + outgoingWindow: defaultWindow, + handleMax: math.MaxUint32 - 1, + linksMu: sync.RWMutex{}, + linksByKey: make(map[linkKey]*link), + close: make(chan struct{}), + done: make(chan struct{}), + endSent: make(chan struct{}), + } + + if opts != nil { + if opts.MaxLinks != 0 { + // MaxLinks is the number of total links. + // handleMax is the max handle ID which starts + // at zero. so we decrement by one + s.handleMax = opts.MaxLinks - 1 + } + } + + // create output handle map after options have been applied + s.outputHandles = bitmap.New(s.handleMax) + + s.rxQ = queue.NewHolder(queue.New[frames.FrameBody](int(s.incomingWindow))) + + return s +} + +// waitForFrame waits for an incoming frame to be queued. +// it returns the next frame from the queue, or an error. +// the error is either from the context or conn.doneErr. +// not meant for consumption outside of session.go. +func (s *Session) waitForFrame(ctx context.Context) (frames.FrameBody, error) { + var q *queue.Queue[frames.FrameBody] + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-s.conn.done: + return nil, s.conn.doneErr + case q = <-s.rxQ.Wait(): + // populated queue + } + + fr := q.Dequeue() + s.rxQ.Release(q) + + return *fr, nil +} + +func (s *Session) begin(ctx context.Context) error { + // send Begin to server + begin := &frames.PerformBegin{ + NextOutgoingID: 0, + IncomingWindow: s.incomingWindow, + OutgoingWindow: s.outgoingWindow, + HandleMax: s.handleMax, + } + + if err := s.txFrameAndWait(ctx, begin); err != nil { + return err + } + + // wait for response + fr, err := s.waitForFrame(ctx) + if err != nil { + // if we exit before receiving the ack, our caller will clean up the channel. 
+ // however, it does mean that the peer will now have assigned an outgoing + // channel ID that's not in use. + return err + } + + begin, ok := fr.(*frames.PerformBegin) + if !ok { + // this codepath is hard to hit (impossible?). if the response isn't a PerformBegin and we've not + // yet seen the remote channel number, the default clause in conn.connReader will protect us from that. + // if we have seen the remote channel number then it's likely the session.mux for that channel will + // either swallow the frame or blow up in some other way, both causing this call to hang. + // deallocate session on error. we can't call + // s.Close() as the session mux hasn't started yet. + debug.Log(1, "RX (Session %p): unexpected begin response frame %T", s, fr) + s.conn.deleteSession(s) + if err := s.conn.Close(); err != nil { + return err + } + return &ConnError{inner: fmt.Errorf("unexpected begin response: %#v", fr)} + } + + // start Session multiplexor + go s.mux(begin) + + return nil +} + +// Close closes the session. +// - ctx controls waiting for the peer to acknowledge the session is closed +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. However, the operation will continue to +// execute in the background. Subsequent calls will return a *SessionError +// that contains the context's error message. +func (s *Session) Close(ctx context.Context) error { + var ctxErr error + s.closeOnce.Do(func() { + close(s.close) + + // once the mux has received the ack'ing end performative, the mux will + // exit which deletes the session and closes s.done. + select { + case <-s.done: + s.closeErr = s.doneErr + + case <-ctx.Done(): + // notify the caller that the close timed out/was cancelled. + // the mux will remain running and once the ack is received it will terminate. + ctxErr = ctx.Err() + + // record that the close timed out/was cancelled. 
+ // subsequent calls to Close() will return this + debug.Log(1, "TX (Session %p) channel %d: %v", s, s.channel, ctxErr) + s.closeErr = &SessionError{inner: ctxErr} + } + }) + + if ctxErr != nil { + return ctxErr + } + + var sessionErr *SessionError + if errors.As(s.closeErr, &sessionErr) && sessionErr.RemoteErr == nil && sessionErr.inner == nil { + // an empty SessionError means the session was cleanly closed by the caller + return nil + } + return s.closeErr +} + +// txFrame sends a frame to the connWriter. +// - ctx is used to provide the write deadline +// - fr is the frame to write to net.Conn +func (s *Session) txFrame(frameCtx *frameContext, fr frames.FrameBody) { + debug.Log(2, "TX (Session %p) mux frame to Conn (%p): %s", s, s.conn, fr) + s.conn.sendFrame(frameEnvelope{ + FrameCtx: frameCtx, + Frame: frames.Frame{ + Type: frames.TypeAMQP, + Channel: s.channel, + Body: fr, + }, + }) +} + +// txFrameAndWait sends a frame to the connWriter and waits for the write to complete +// - ctx is used to provide the write deadline +// - fr is the frame to write to net.Conn +func (s *Session) txFrameAndWait(ctx context.Context, fr frames.FrameBody) error { + frameCtx := frameContext{ + Ctx: ctx, + Done: make(chan struct{}), + } + + s.txFrame(&frameCtx, fr) + + select { + case <-frameCtx.Done: + return frameCtx.Err + case <-s.conn.done: + return s.conn.doneErr + case <-s.done: + return s.doneErr + } +} + +// NewReceiver opens a new receiver link on the session. +// - ctx controls waiting for the peer to create a sending terminus +// - source is the name of the peer's sending terminus +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. If the Receiver was successfully +// created, it will be cleaned up in future calls to NewReceiver. 
+func (s *Session) NewReceiver(ctx context.Context, source string, opts *ReceiverOptions) (*Receiver, error) { + return newReceiverForSession(ctx, s, source, opts, receiverTestHooks{}) +} + +// split out so tests can add hooks +func newReceiverForSession(ctx context.Context, s *Session, source string, opts *ReceiverOptions, hooks receiverTestHooks) (*Receiver, error) { + r, err := newReceiver(source, s, opts) + if err != nil { + return nil, err + } + if err = r.attach(ctx); err != nil { + return nil, err + } + + go r.mux(hooks) + + return r, nil +} + +// NewSender opens a new sender link on the session. +// - ctx controls waiting for the peer to create a receiver terminus +// - target is the name of the peer's receiver terminus +// - opts contains optional values, pass nil to accept the defaults +// +// If the context's deadline expires or is cancelled before the operation +// completes, an error is returned. If the Sender was successfully +// created, it will be cleaned up in future calls to NewSender. 
+func (s *Session) NewSender(ctx context.Context, target string, opts *SenderOptions) (*Sender, error) { + return newSenderForSession(ctx, s, target, opts, senderTestHooks{}) +} + +// split out so tests can add hooks +func newSenderForSession(ctx context.Context, s *Session, target string, opts *SenderOptions, hooks senderTestHooks) (*Sender, error) { + l, err := newSender(target, s, opts) + if err != nil { + return nil, err + } + if err = l.attach(ctx); err != nil { + return nil, err + } + + go l.mux(hooks) + + return l, nil +} + +func (s *Session) mux(remoteBegin *frames.PerformBegin) { + defer func() { + if s.doneErr == nil { + s.doneErr = &SessionError{} + } else if connErr := (&ConnError{}); !errors.As(s.doneErr, &connErr) { + // only wrap non-ConnError error types + var amqpErr *Error + if errors.As(s.doneErr, &amqpErr) { + s.doneErr = &SessionError{RemoteErr: amqpErr} + } else { + s.doneErr = &SessionError{inner: s.doneErr} + } + } + // Signal goroutines waiting on the session. + close(s.done) + }() + + var ( + // maps input (remote) handles to links + linkFromInputHandle = make(map[uint32]*link) + + // maps local delivery IDs (sending transfers) to input (remote) handles + inputHandleFromDeliveryID = make(map[uint32]uint32) + + // maps remote delivery IDs (receiving transfers) to input (remote) handles + inputHandleFromRemoteDeliveryID = make(map[uint32]uint32) + + // maps delivery IDs to output (our) handles. 
used for multi-frame transfers + deliveryIDFromOutputHandle = make(map[uint32]uint32) + + // maps delivery IDs to the settlement state channel + settlementFromDeliveryID = make(map[uint32]chan encoding.DeliveryState) + + // tracks the next delivery ID for outgoing transfers + nextDeliveryID uint32 + + // flow control values + nextOutgoingID uint32 + nextIncomingID = remoteBegin.NextOutgoingID + remoteIncomingWindow = remoteBegin.IncomingWindow + remoteOutgoingWindow = remoteBegin.OutgoingWindow + + closeInProgress bool // indicates the end performative has been sent + ) + + closeWithError := func(e1 *Error, e2 error) { + if closeInProgress { + debug.Log(3, "TX (Session %p): close already pending, discarding %v", s, e1) + return + } + + closeInProgress = true + s.doneErr = e2 + s.txFrame(&frameContext{Ctx: context.Background()}, &frames.PerformEnd{Error: e1}) + close(s.endSent) + } + + for { + txTransfer := s.txTransfer + // disable txTransfer if flow control windows have been exceeded + if remoteIncomingWindow == 0 || s.outgoingWindow == 0 { + debug.Log(1, "TX (Session %p): disabling txTransfer - window exceeded. remoteIncomingWindow: %d outgoingWindow: %d", + s, remoteIncomingWindow, s.outgoingWindow) + txTransfer = nil + } + + tx := s.tx + closed := s.close + if closeInProgress { + // swap out channel so it no longer triggers + closed = nil + + // once the end performative is sent, we're not allowed to send any frames + tx = nil + txTransfer = nil + } + + // notes on client-side closing session + // when session is closed, we must keep the mux running until the ack'ing end performative + // has been received. during this window, the session is allowed to receive frames but cannot + // send them. + // client-side close happens either by user calling Session.Close() or due to mux initiated + // close due to a violation of some invariant (see sending &Error{} to s.close). 
in the case + // that both code paths have been triggered, we must be careful to preserve the error that + // triggered the mux initiated close so it can be surfaced to the caller. + + select { + // conn has completed, exit + case <-s.conn.done: + s.doneErr = s.conn.doneErr + return + + case <-closed: + if closeInProgress { + // a client-side close due to protocol error is in progress + continue + } + // session is being closed by the client + closeInProgress = true + s.txFrame(&frameContext{Ctx: context.Background()}, &frames.PerformEnd{}) + close(s.endSent) + + // incoming frame + case q := <-s.rxQ.Wait(): + fr := *q.Dequeue() + s.rxQ.Release(q) + debug.Log(2, "RX (Session %p): %s", s, fr) + + switch body := fr.(type) { + // Disposition frames can reference transfers from more than one + // link. Send this frame to all of them. + case *frames.PerformDisposition: + start := body.First + end := start + if body.Last != nil { + end = *body.Last + } + for deliveryID := start; deliveryID <= end; deliveryID++ { + // find the input (remote) handle for this delivery ID. + // default to the map for local delivery IDs. + handles := inputHandleFromDeliveryID + if body.Role == encoding.RoleSender { + // the disposition frame is meant for a receiver + // so look in the map for remote delivery IDs. 
+ handles = inputHandleFromRemoteDeliveryID + } + + inputHandle, ok := handles[deliveryID] + if !ok { + debug.Log(2, "RX (Session %p): role %s: didn't find deliveryID %d in inputHandlesByDeliveryID map", s, body.Role, deliveryID) + continue + } + delete(handles, deliveryID) + + if body.Settled && body.Role == encoding.RoleReceiver { + // check if settlement confirmation was requested, if so + // confirm by closing channel (RSM == ModeFirst) + if done, ok := settlementFromDeliveryID[deliveryID]; ok { + delete(settlementFromDeliveryID, deliveryID) + select { + case done <- body.State: + default: + } + close(done) + } + } + + // now find the *link for this input (remote) handle + link, ok := linkFromInputHandle[inputHandle] + if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received disposition frame referencing a handle that's not in use", + }, fmt.Errorf("received disposition frame with unknown link input handle %d", inputHandle)) + continue + } + + s.muxFrameToLink(link, fr) + } + continue + case *frames.PerformFlow: + if body.NextIncomingID == nil { + // This is a protocol error: + // "[...] MUST be set if the peer has received + // the begin frame for the session" + closeWithError(&Error{ + Condition: ErrCondNotAllowed, + Description: "next-incoming-id not set after session established", + }, errors.New("protocol error: received flow without next-incoming-id after session established")) + continue + } + + // "When the endpoint receives a flow frame from its peer, + // it MUST update the next-incoming-id directly from the + // next-outgoing-id of the frame, and it MUST update the + // remote-outgoing-window directly from the outgoing-window + // of the frame." 
+ nextIncomingID = body.NextOutgoingID + remoteOutgoingWindow = body.OutgoingWindow + + // "The remote-incoming-window is computed as follows: + // + // next-incoming-id(flow) + incoming-window(flow) - next-outgoing-id(endpoint) + // + // If the next-incoming-id field of the flow frame is not set, then remote-incoming-window is computed as follows: + // + // initial-outgoing-id(endpoint) + incoming-window(flow) - next-outgoing-id(endpoint)" + remoteIncomingWindow = body.IncomingWindow - nextOutgoingID + remoteIncomingWindow += *body.NextIncomingID + debug.Log(3, "RX (Session %p): flow - remoteOutgoingWindow: %d remoteIncomingWindow: %d nextOutgoingID: %d", s, remoteOutgoingWindow, remoteIncomingWindow, nextOutgoingID) + + // Send to link if handle is set + if body.Handle != nil { + link, ok := linkFromInputHandle[*body.Handle] + if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received flow frame referencing a handle that's not in use", + }, fmt.Errorf("received flow frame with unknown link handle %d", body.Handle)) + continue + } + + s.muxFrameToLink(link, fr) + continue + } + + if body.Echo && !closeInProgress { + niID := nextIncomingID + resp := &frames.PerformFlow{ + NextIncomingID: &niID, + IncomingWindow: s.incomingWindow, + NextOutgoingID: nextOutgoingID, + OutgoingWindow: s.outgoingWindow, + } + s.txFrame(&frameContext{Ctx: context.Background()}, resp) + } + + case *frames.PerformAttach: + // On Attach response link should be looked up by name, then added + // to the links map with the remote's handle contained in this + // attach frame. + // + // Note body.Role is the remote peer's role, we reverse for the local key. 
+ s.linksMu.RLock() + link, linkOk := s.linksByKey[linkKey{name: body.Name, role: !body.Role}] + s.linksMu.RUnlock() + if !linkOk { + closeWithError(&Error{ + Condition: ErrCondNotAllowed, + Description: "received mismatched attach frame", + }, fmt.Errorf("protocol error: received mismatched attach frame %+v", body)) + continue + } + + // track the input (remote) handle number for this link. + // note that it might be a different value than our output handle. + link.inputHandle = body.Handle + linkFromInputHandle[link.inputHandle] = link + + s.muxFrameToLink(link, fr) + + debug.Log(1, "RX (Session %p): link %s attached, input handle %d, output handle %d", s, link.key.name, link.inputHandle, link.outputHandle) + + case *frames.PerformTransfer: + s.needFlowCount++ + // "Upon receiving a transfer, the receiving endpoint will + // increment the next-incoming-id to match the implicit + // transfer-id of the incoming transfer plus one, as well + // as decrementing the remote-outgoing-window, and MAY + // (depending on policy) decrement its incoming-window." + nextIncomingID++ + // don't loop to intmax + if remoteOutgoingWindow > 0 { + remoteOutgoingWindow-- + } + link, ok := linkFromInputHandle[body.Handle] + if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received transfer frame referencing a handle that's not in use", + }, fmt.Errorf("received transfer frame with unknown link handle %d", body.Handle)) + continue + } + + s.muxFrameToLink(link, fr) + + // if this message is received unsettled and link rcv-settle-mode == second, add to handlesByRemoteDeliveryID + if !body.Settled && body.DeliveryID != nil && link.receiverSettleMode != nil && *link.receiverSettleMode == ReceiverSettleModeSecond { + debug.Log(1, "RX (Session %p): adding handle %d to inputHandleFromRemoteDeliveryID. 
remote delivery ID: %d", s, body.Handle, *body.DeliveryID) + inputHandleFromRemoteDeliveryID[*body.DeliveryID] = body.Handle + } + + // Update peer's outgoing window if half has been consumed. + if s.needFlowCount >= s.incomingWindow/2 && !closeInProgress { + debug.Log(3, "RX (Session %p): channel %d: flow - s.needFlowCount(%d) >= s.incomingWindow(%d)/2\n", s, s.channel, s.needFlowCount, s.incomingWindow) + s.needFlowCount = 0 + nID := nextIncomingID + flow := &frames.PerformFlow{ + NextIncomingID: &nID, + IncomingWindow: s.incomingWindow, + NextOutgoingID: nextOutgoingID, + OutgoingWindow: s.outgoingWindow, + } + s.txFrame(&frameContext{Ctx: context.Background()}, flow) + } + + case *frames.PerformDetach: + link, ok := linkFromInputHandle[body.Handle] + if !ok { + closeWithError(&Error{ + Condition: ErrCondUnattachedHandle, + Description: "received detach frame referencing a handle that's not in use", + }, fmt.Errorf("received detach frame with unknown link handle %d", body.Handle)) + continue + } + s.muxFrameToLink(link, fr) + + // we received a detach frame and sent it to the link. + // this was either the response to a client-side initiated + // detach or our peer detached us. either way, now that + // the link has processed the frame it's detached so we + // are safe to clean up its state. 
+ delete(linkFromInputHandle, link.inputHandle) + delete(deliveryIDFromOutputHandle, link.outputHandle) + s.deallocateHandle(link) + + case *frames.PerformEnd: + // there are two possibilities: + // - this is the ack to a client-side Close() + // - the peer is ending the session so we must ack + + if closeInProgress { + return + } + + // peer detached us with an error, save it and send the ack + if body.Error != nil { + s.doneErr = body.Error + } + + fr := frames.PerformEnd{} + s.txFrame(&frameContext{Ctx: context.Background()}, &fr) + + // per spec, when end is received, we're no longer allowed to receive frames + return + + default: + debug.Log(1, "RX (Session %p): unexpected frame: %s\n", s, body) + closeWithError(&Error{ + Condition: ErrCondInternalError, + Description: "session received unexpected frame", + }, fmt.Errorf("internal error: unexpected frame %T", body)) + } + + case env := <-txTransfer: + fr := &env.Frame + // record current delivery ID + var deliveryID uint32 + if fr.DeliveryID == &needsDeliveryID { + deliveryID = nextDeliveryID + fr.DeliveryID = &deliveryID + nextDeliveryID++ + deliveryIDFromOutputHandle[fr.Handle] = deliveryID + + if !fr.Settled { + inputHandleFromDeliveryID[deliveryID] = env.InputHandle + } + } else { + // if fr.DeliveryID is nil it must have been added + // to deliveryIDByHandle already (multi-frame transfer) + deliveryID = deliveryIDFromOutputHandle[fr.Handle] + } + + // log after the delivery ID has been assigned + debug.Log(2, "TX (Session %p): %d, %s", s, s.channel, fr) + + // frame has been sender-settled, remove from map. + // this should only come into play for multi-frame transfers. 
+ if fr.Settled { + delete(inputHandleFromDeliveryID, deliveryID) + } + + s.txFrame(env.FrameCtx, fr) + + select { + case <-env.FrameCtx.Done: + if env.FrameCtx.Err != nil { + // transfer wasn't sent, don't update state + continue + } + // transfer was written to the network + case <-s.conn.done: + // the write failed, Conn is going down + continue + } + + // if not settled, add done chan to map + if !fr.Settled && fr.Done != nil { + settlementFromDeliveryID[deliveryID] = fr.Done + } else if fr.Done != nil { + // sender-settled, close done now that the transfer has been sent + close(fr.Done) + } + + // "Upon sending a transfer, the sending endpoint will increment + // its next-outgoing-id, decrement its remote-incoming-window, + // and MAY (depending on policy) decrement its outgoing-window." + nextOutgoingID++ + // don't decrement if we're at 0 or we could loop to int max + if remoteIncomingWindow != 0 { + remoteIncomingWindow-- + } + + case env := <-tx: + fr := env.FrameBody + debug.Log(2, "TX (Session %p): %d, %s", s, s.channel, fr) + switch fr := env.FrameBody.(type) { + case *frames.PerformDisposition: + if fr.Settled && fr.Role == encoding.RoleSender { + // sender with a peer that's in mode second; sending confirmation of disposition. + // disposition frames can reference a range of delivery IDs, although it's highly + // likely in this case there will only be one. 
+ start := fr.First + end := start + if fr.Last != nil { + end = *fr.Last + } + for deliveryID := start; deliveryID <= end; deliveryID++ { + // send delivery state to the channel and close it to signal + // that the delivery has completed (RSM == ModeSecond) + if done, ok := settlementFromDeliveryID[deliveryID]; ok { + delete(settlementFromDeliveryID, deliveryID) + select { + case done <- fr.State: + default: + } + close(done) + } + } + } + s.txFrame(env.FrameCtx, fr) + case *frames.PerformFlow: + niID := nextIncomingID + fr.NextIncomingID = &niID + fr.IncomingWindow = s.incomingWindow + fr.NextOutgoingID = nextOutgoingID + fr.OutgoingWindow = s.outgoingWindow + s.txFrame(env.FrameCtx, fr) + case *frames.PerformTransfer: + panic("transfer frames must use txTransfer") + default: + s.txFrame(env.FrameCtx, fr) + } + } + } +} + +func (s *Session) allocateHandle(ctx context.Context, l *link) error { + s.linksMu.Lock() + defer s.linksMu.Unlock() + + // Check if link name already exists, if so then an error should be returned + existing := s.linksByKey[l.key] + if existing != nil { + return fmt.Errorf("link with name '%v' already exists", l.key.name) + } + + next, ok := s.outputHandles.Next() + if !ok { + if err := s.Close(ctx); err != nil { + return err + } + // handle numbers are zero-based, report the actual count + return &SessionError{inner: fmt.Errorf("reached session handle max (%d)", s.handleMax+1)} + } + + l.outputHandle = next // allocate handle to the link + s.linksByKey[l.key] = l // add to mapping + + return nil +} + +func (s *Session) deallocateHandle(l *link) { + s.linksMu.Lock() + defer s.linksMu.Unlock() + + delete(s.linksByKey, l.key) + s.outputHandles.Remove(l.outputHandle) +} + +func (s *Session) abandonLink(l *link) { + s.abandonedLinksMu.Lock() + defer s.abandonedLinksMu.Unlock() + s.abandonedLinks = append(s.abandonedLinks, l) +} + +func (s *Session) freeAbandonedLinks(ctx context.Context) error { + s.abandonedLinksMu.Lock() + defer 
s.abandonedLinksMu.Unlock() + + debug.Log(3, "TX (Session %p): cleaning up %d abandoned links", s, len(s.abandonedLinks)) + + for _, l := range s.abandonedLinks { + dr := &frames.PerformDetach{ + Handle: l.outputHandle, + Closed: true, + } + if err := s.txFrameAndWait(ctx, dr); err != nil { + return err + } + } + + s.abandonedLinks = nil + return nil +} + +func (s *Session) muxFrameToLink(l *link, fr frames.FrameBody) { + q := l.rxQ.Acquire() + q.Enqueue(fr) + l.rxQ.Release(q) + debug.Log(2, "RX (Session %p): mux frame to link (%p): %s, %s", s, l, l.key.name, fr) +} + +// transferEnvelope is used by senders to send transfer frames +type transferEnvelope struct { + FrameCtx *frameContext + + // the link's remote handle + InputHandle uint32 + + Frame frames.PerformTransfer +} + +// frameBodyEnvelope is used by senders and receivers to send frames. +type frameBodyEnvelope struct { + FrameCtx *frameContext + FrameBody frames.FrameBody +} + +// the address of this var is a sentinel value indicating +// that a transfer frame is in need of a delivery ID +var needsDeliveryID uint32 diff --git a/vendor/modules.txt b/vendor/modules.txt index 1304d4292c0..5c2a5acd41c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -33,6 +33,56 @@ cloud.google.com/go/trace/internal # github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 ## explicit; go 1.20 github.com/AdaLogics/go-fuzz-headers +# github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 +## explicit; go 1.18 +github.com/Azure/azure-sdk-for-go/sdk/azcore +github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake 
+github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op +github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared +github.com/Azure/azure-sdk-for-go/sdk/azcore/log +github.com/Azure/azure-sdk-for-go/sdk/azcore/policy +github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime +github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming +github.com/Azure/azure-sdk-for-go/sdk/azcore/to +github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing +# github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 +## explicit; go 1.18 +github.com/Azure/azure-sdk-for-go/sdk/internal/diag +github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo +github.com/Azure/azure-sdk-for-go/sdk/internal/exported +github.com/Azure/azure-sdk-for-go/sdk/internal/log +github.com/Azure/azure-sdk-for-go/sdk/internal/poller +github.com/Azure/azure-sdk-for-go/sdk/internal/telemetry +github.com/Azure/azure-sdk-for-go/sdk/internal/temporal +github.com/Azure/azure-sdk-for-go/sdk/internal/uuid +# github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1 +## explicit; go 1.18 +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sas +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/sbauth +github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/utils +# github.com/Azure/go-amqp v1.0.5 +## explicit; go 1.18 +github.com/Azure/go-amqp +github.com/Azure/go-amqp/internal/bitmap +github.com/Azure/go-amqp/internal/buffer +github.com/Azure/go-amqp/internal/debug 
+github.com/Azure/go-amqp/internal/encoding +github.com/Azure/go-amqp/internal/frames +github.com/Azure/go-amqp/internal/queue +github.com/Azure/go-amqp/internal/shared # github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.20.0 ## explicit; go 1.20 github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp