diff --git a/.github/workflows/nightly-merge.yml b/.github/workflows/nightly-merge.yml deleted file mode 100644 index ede08f6df31..00000000000 --- a/.github/workflows/nightly-merge.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: 'Nightly Merge (master to iiot)' - -on: - schedule: - - cron: '0 0 * * *' - -jobs: - nightly-merge: - - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v1 - - - name: Nightly Merge - uses: robotology/gh-action-nightly-merge@v1.2.0 - with: - stable_branch: 'master' - development_branch: 'iiot' - allow_ff: false - push_token: 'IIOT_NIGHTLY_MERGE_PUSH_TOKEN' - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - IIOT_NIGHTLY_MERGE_PUSH_TOKEN: ${{ secrets.IIOT_NIGHTLY_MERGE_PUSH_TOKEN }} \ No newline at end of file diff --git a/builds/checkin/api-proxy.yaml b/builds/checkin/api-proxy.yaml index 1a7de11f12a..f7a79563744 100644 --- a/builds/checkin/api-proxy.yaml +++ b/builds/checkin/api-proxy.yaml @@ -55,7 +55,7 @@ jobs: displayName: Modify path - bash: scripts/linux/generic-rust/install.sh --project-root "edge-modules/api-proxy-module" displayName: Install Rust - - bash: scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch arm32v7 --build-path edge-modules/api-proxy-module + - bash: scripts/linux/cross-platform-rust-build.sh --os alpine --arch arm32v7 --build-path edge-modules/api-proxy-module displayName: build # No arm platform specific test. Our cross build tool does not work for dependencies outside the compile repo so another script is used to compile. # That script can only build, but can't run tests for now. @@ -75,7 +75,7 @@ jobs: displayName: Modify path - bash: scripts/linux/generic-rust/install.sh --project-root "edge-modules/api-proxy-module" displayName: Install Rust - - bash: scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch aarch64 --build-path edge-modules/api-proxy-module + - bash: scripts/linux/cross-platform-rust-build.sh --os alpine --arch aarch64 --build-path edge-modules/api-proxy-module displayName: build # No arm platform specific test. Our cross build tool does not work for dependencies outside the compile repo so another script is used to compile. # That script can only build, but can't run tests for now. \ No newline at end of file diff --git a/builds/checkin/mqtt.yaml b/builds/checkin/mqtt.yaml index cf1d28b18f9..78038352856 100644 --- a/builds/checkin/mqtt.yaml +++ b/builds/checkin/mqtt.yaml @@ -5,14 +5,16 @@ pr: - master - release/* jobs: - ################################################################################ - job: check_run_pipeline - ################################################################################ + ################################################################################ displayName: Check pipeline preconditions (changes ARE in builds or mqtt) pool: vmImage: "ubuntu-16.04" steps: + - checkout: self + submodules: false + fetchDepth: 3 - bash: | git log -m -1 --name-only --first-parent --pretty="" | egrep -i '^(builds|mqtt)' if [[ $?
== 0 ]]; then @@ -24,13 +26,16 @@ jobs: ################################################################################ - job: linux_amd64 - ################################################################################ + ################################################################################ displayName: Linux amd64 dependsOn: check_run_pipeline condition: eq(dependencies.check_run_pipeline.outputs['check_files.RUN_PIPELINE'], 'true') pool: vmImage: "ubuntu-16.04" steps: + - checkout: self + submodules: false # mqtt broker does not use submodules + fetchDepth: 3 - script: echo "##vso[task.setvariable variable=RUST_BACKTRACE;]1" displayName: Set env variables - bash: scripts/linux/generic-rust/install.sh --project-root "mqtt" @@ -39,18 +44,28 @@ jobs: displayName: Build with no default features - bash: scripts/linux/generic-rust/build.sh --project-root "mqtt" --packages "mqttd/Cargo.toml" --manifest-path displayName: Build with default features - - bash: mqtt/build/linux/test.sh + - bash: mqtt/build/linux/test.sh --report test-results.xml displayName: Test + - task: PublishTestResults@2 + displayName: Publish test results + inputs: + testResultsFormat: "JUnit" + testResultsFiles: "**/test-results.xml" + failTaskOnFailedTests: true + condition: succeededOrFailed() ################################################################################ - job: style_check - ################################################################################ + ################################################################################ displayName: Style Check dependsOn: check_run_pipeline condition: eq(dependencies.check_run_pipeline.outputs['check_files.RUN_PIPELINE'], 'true') pool: vmImage: "ubuntu-16.04" steps: + - checkout: self + submodules: false # mqtt broker does not use submodules + fetchDepth: 3 - bash: scripts/linux/generic-rust/install.sh --project-root "mqtt" displayName: Install Rust - bash: scripts/linux/generic-rust/format.sh --project-root "mqtt" diff --git a/builds/ci/mqtt.yaml b/builds/ci/mqtt.yaml index 7976c761cc5..b85f201aae0 100644 --- a/builds/ci/mqtt.yaml +++ b/builds/ci/mqtt.yaml @@ -13,6 +13,9 @@ jobs: pool: vmImage: "ubuntu-16.04" steps: + - checkout: self + submodules: false + fetchDepth: 3 - bash: | git log -m -1 --name-only --first-parent --pretty="" | egrep -i '^(builds|mqtt)' if [[ $? 
== 0 ]]; then @@ -31,6 +34,9 @@ jobs: pool: vmImage: 'ubuntu-16.04' steps: + - checkout: self + submodules: false # mqtt broker does not use submodules + fetchDepth: 3 - task: Bash@3 displayName: Install Rust inputs: @@ -50,6 +56,14 @@ jobs: displayName: Test inputs: filePath: mqtt/build/linux/test.sh + arguments: --report test-results.xml + - task: PublishTestResults@2 + displayName: Publish test results + inputs: + testResultsFormat: "JUnit" + testResultsFiles: "**/test-results.xml" + failTaskOnFailedTests: true + condition: succeededOrFailed() ################################################################################ - job: linux_arm32v7 @@ -60,6 +74,9 @@ jobs: pool: vmImage: 'ubuntu-16.04' steps: + - checkout: self + submodules: false # mqtt broker does not use submodules + fetchDepth: 3 - script: | echo "##vso[task.setvariable variable=RUSTUP_HOME;]$(Agent.WorkFolder)/rustup" echo "##vso[task.setvariable variable=CARGO_HOME;]$(Agent.WorkFolder)/cargo" @@ -73,12 +90,12 @@ jobs: - script: cargo install cross --version 0.1.16 displayName: Install cross - task: Bash@3 - displayName: Build + displayName: Build with no default features inputs: filePath: scripts/linux/generic-rust/build.sh arguments: --project-root "mqtt" --packages "mqttd/Cargo.toml" --manifest-path --no-default-features --features "generic" --target armv7-unknown-linux-gnueabihf --cargo cross - task: Bash@3 - displayName: Build + displayName: Build with default features inputs: filePath: scripts/linux/generic-rust/build.sh arguments: --project-root "mqtt" --packages "mqttd/Cargo.toml" --manifest-path --target armv7-unknown-linux-gnueabihf --cargo cross @@ -86,28 +103,11 @@ jobs: displayName: Test inputs: filePath: mqtt/build/linux/test.sh - arguments: --target armv7-unknown-linux-gnueabihf --cargo cross - -################################################################################ - - job: style_check -################################################################################ - displayName: Style Check - dependsOn: check_run_pipeline - condition: eq(dependencies.check_run_pipeline.outputs['check_files.RUN_PIPELINE'], 'true') - pool: - vmImage: 'ubuntu-16.04' - steps: - - task: Bash@3 - displayName: Install Rust - inputs: - filePath: scripts/linux/generic-rust/install.sh - arguments: --project-root "mqtt" - - task: Bash@3 - displayName: Format Code - inputs: - filePath: scripts/linux/generic-rust/format.sh - arguments: --project-root "mqtt" - - task: Bash@3 - displayName: Clippy + arguments: --target armv7-unknown-linux-gnueabihf --cargo cross --report test-results.xml + - task: PublishTestResults@2 + displayName: Publish test results inputs: - filePath: mqtt/build/linux/clippy.sh \ No newline at end of file + testResultsFormat: "JUnit" + testResultsFiles: "**/test-results.xml" + failTaskOnFailedTests: true + condition: succeededOrFailed() \ No newline at end of file diff --git a/builds/e2e/Runner.ps1 b/builds/e2e/Runner.ps1 deleted file mode 100644 index 27b89cefc08..00000000000 --- a/builds/e2e/Runner.ps1 +++ /dev/null @@ -1,153 +0,0 @@ -New-Module -ScriptBlock { - - <# - # Completes initialization of a Windows VM (version 1809 or later) that is - # installed behind an HTTPS proxy server for end-to-end testing of Azure - # IoT Edge. It installs/enables sshd so that the test agent can SSH into - # this VM and Linux VMs using the same commands. 
It also configures system- - # wide proxy settings (so that sshd and related components can be - # downloaded) and sets the default SSH shell to a newer version of - # PowerShell (so that the agent can make use of the NoProxy parameter - # on Invoke-WebRequest rather than reverting system-wide proxy settings). - #> - - #requires -Version 5 - #requires -RunAsAdministrator - - Set-Variable OpenSshUtilsManifest -Option Constant -Value 'https://raw.githubusercontent.com/PowerShell/openssh-portable/68ad673db4bf971b5c087cef19bb32953fd9db75/contrib/win32/openssh/OpenSSHUtils.psd1' - Set-Variable OpenSshUtilsModule -Option Constant -Value 'https://raw.githubusercontent.com/PowerShell/openssh-portable/68ad673db4bf971b5c087cef19bb32953fd9db75/contrib/win32/openssh/OpenSSHUtils.psm1' - - function Get-WebResource { - param ( - [String] $proxyUri, - [String] $sourceUri, - [String] $destinationFile - ) - - if (-not (Test-Path $destinationFile -PathType Leaf)) { - Invoke-WebRequest -UseBasicParsing $sourceUri -Proxy $proxyUri -OutFile $destinationFile - } - } - - function Start-TestPath { - param ( - [String] $path, - [Int32] $timeoutSecs - ) - - $action = { - while (-not (Test-Path $args[0])) { - Start-Sleep 5 - } - - Test-Path $args[0] - } - - $job = Start-Job $action -ArgumentList $path - $result = $job | Wait-Job -Timeout $timeoutSecs | Receive-Job - $job | Stop-Job -PassThru | Remove-Job - return $result - } - - function Initialize-WindowsVM { - [CmdletBinding()] - param ( - [ValidateNotNullOrEmpty()] - [String] $ProxyHostname, - - [ValidateNotNullOrEmpty()] - [String] $SshPublicKeyBase64 - ) - - Set-StrictMode -Version "Latest" - $ErrorActionPreference = "Stop" - - $sshPublicKey = [System.Text.Encoding]::Utf8.GetString([System.Convert]::FromBase64String($SshPublicKeyBase64)) - $proxyUri = "http://${ProxyHostname}:3128" - - Write-Host 'Setting wininet proxy' - - Install-PackageProvider -Name NuGet -Force -Proxy $proxyUri - Register-PSRepository -Default -Proxy $proxyUri - Set-PSRepository PSGallery -InstallationPolicy Trusted -Proxy $proxyUri - Install-Module NetworkingDsc -MinimumVersion 6.3 -AllowClobber -Force -Proxy $proxyUri - Invoke-DscResource ProxySettings -Method Set -ModuleName NetworkingDsc -Property @{ - IsSingleInstance = "Yes" - EnableManualProxy = $true - ProxyServerBypassLocal = $true - ProxyServer = $proxyUri - } - # output settings to log - Invoke-DscResource ProxySettings -Method Get -ModuleName NetworkingDsc -Property @{ IsSingleInstance = "Yes" } - - Write-Host 'Setting winhttp proxy' - - netsh winhttp set proxy "${ProxyHostname}:3128" "" - - # iotedge-moby needs this variable for `docker pull` - Write-Host 'Setting HTTPS_PROXY in environment' - [Environment]::SetEnvironmentVariable("HTTPS_PROXY", $proxyUri, [EnvironmentVariableTarget]::Machine) - - # Add public key so agent can SSH into this runner - $authorizedKeys = Join-Path ${env:UserProfile} (Join-Path ".ssh" "authorized_keys") - Write-Host "Adding public key to $authorizedKeys" - - New-Item (Split-Path $authorizedKeys -Parent) -ItemType Directory -Force | Out-Null - Add-Content "$sshPublicKey" -Path $authorizedKeys - - # Fix up authorized_keys file permissions - $openSshUtils = ".\$(Split-Path "$OpenSshUtilsManifest" -Leaf)" - Get-WebResource $proxyUri $OpenSshUtilsManifest $openSshUtils - Get-WebResource $proxyUri $OpenSshUtilsModule ".\$(Split-Path "$OpenSshUtilsModule" -Leaf)" - Import-Module $openSshUtils -Force - Repair-AuthorizedKeyPermission $authorizedKeys - - Write-Host 'Installing sshd' - - Add-WindowsCapability 
-Online -Name OpenSSH.Client~~~~0.0.1.0 - Add-WindowsCapability -Online -Name OpenSSH.Server~~~~0.0.1.0 - - Set-Service -Name ssh-agent -StartupType Automatic - Set-Service -Name sshd -StartupType Automatic - - Write-Host 'Making PowerShell the default shell for ssh' - - New-Item 'HKLM:\SOFTWARE\OpenSSH' -Force | ` - New-ItemProperty -Name "DefaultShell" -Force -Value "$env:SystemRoot\system32\WindowsPowerShell\v1.0\powershell.exe" - - Write-Host 'Starting sshd' - - Start-Service ssh-agent - Start-Service sshd - - # Update sshd_config to look in the right place for authorized_keys - Write-Host 'Updating sshd_config' - - $sshdConfig = "$env:ProgramData\ssh\sshd_config" - $exists = Start-TestPath $sshdConfig -timeoutSecs 30 - if (-not $exists) { - Write-Error "Could not find $sshdConfig, exiting..." - return 1 - } - - $findLine = '^(\s*AuthorizedKeysFile\s+)\.ssh/authorized_keys$' - $replaceLine = "`$1$authorizedKeys" - (Get-Content "$sshdConfig") -replace "$findLine", "$replaceLine" | Out-File -Encoding Utf8 "$sshdConfig" - - $findLine = '^(\s*AuthorizedKeysFile\s+__PROGRAMDATA__/ssh/administrators_authorized_keys)$' - $replaceLine = '#$1' - (Get-Content "$sshdConfig") -replace "$findLine", "$replaceLine" | Out-File -Encoding Utf8 "$sshdConfig" - - Write-Host 'Restarting sshd' - - Restart-Service ssh-agent - Restart-Service sshd - - # Output the host key so it can be added to the agent's known_hosts file - Write-Host -NoNewline '#DATA#' - Get-Content -Encoding Utf8 "$env:ProgramData\ssh\ssh_host_rsa_key.pub" | ForEach-Object { Write-Host -NoNewline $_.Split()[0,1] } - Write-Host -NoNewline '#DATA#' - } - - Export-ModuleMember -Function Initialize-WindowsVM -} diff --git a/builds/e2e/agent_final.sh b/builds/e2e/agent_final.sh deleted file mode 100644 index 8c15b79888f..00000000000 --- a/builds/e2e/agent_final.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash -# usage: ./agent_final.sh [ ]... 
-set -euo pipefail - -# Update system-wide known_hosts file so agent can connect to runners - -user="$1" -hosts=( ) -host_key_pair=( ) -id_rsa="$(eval echo ~$user)/.ssh/id_rsa" -suffix="$(grep -Po '^search \K.*' /etc/resolv.conf)" - -touch /etc/ssh/ssh_known_hosts - -for val in "${@:2}"; do - host_key_pair=( "${host_key_pair[@]}" "$val" ) - if [ ${#host_key_pair[@]} -eq 2 ]; then - - set -- "${host_key_pair[@]}" - host_key_pair=( ) - hosts=( "${hosts[@]}" "$1" ) - ipaddr="$(getent hosts "$1" | awk '{ print $1 }')" - - # Remove pre-existing entries for this host - ssh-keygen -R "$1" -f /etc/ssh/ssh_known_hosts - ssh-keygen -R "$1.$suffix" -f /etc/ssh/ssh_known_hosts - - # Append host key to known_hosts - cat <<-EOF >> /etc/ssh/ssh_known_hosts -$1,$ipaddr $2 -$1.$suffix $2 -EOF - - fi -done - -# Test that we really can: -# (1) SSH into the runners, and -# (2) make HTTP/S requests through the proxy - -agent_name=$(hostname) - -for host in "${hosts[@]}"; do - # Linux or Windows - ssh -i "$id_rsa" "$user@$host" uname && os='linux' || os='windows' - if [ "$os" == 'linux' ]; then - echo "Testing Linux runner '$host'" - - # Verify runner can use the proxy - ssh -i "$id_rsa" "$user@$host" curl -x "http://$agent_name:3128" -L 'http://www.microsoft.com' - ssh -i "$id_rsa" "$user@$host" curl -x "http://$agent_name:3128" -L 'https://www.microsoft.com' - - # Verify runner can't skirt the proxy (should time out after 5s) - ssh -i "$id_rsa" "$user@$host" timeout 5 curl -L 'http://www.microsoft.com' && exit 1 || : - ssh -i "$id_rsa" "$user@$host" timeout 5 curl -L 'https://www.microsoft.com' && exit 1 || : - - echo "Linux runner verified." - else # windows - echo "Testing Windows runner '$host'" - - # Verify runner can use the proxy (should succeed) - # **SSH terminal doesn't like the progress bar that Invoke-WebRequest - # tries to display, so use $ProgressPreference='SilentlyContinue' to - # suppress it. - ssh -i "$id_rsa" "$user@$host" "\$ProgressPreference='SilentlyContinue'; Invoke-WebRequest -UseBasicParsing -Proxy 'http://$agent_name:3128' 'http://www.microsoft.com'" - ssh -i "$id_rsa" "$user@$host" "\$ProgressPreference='SilentlyContinue'; Invoke-WebRequest -UseBasicParsing -Proxy 'http://$agent_name:3128' 'https://www.microsoft.com'" - - # Verify runner can't skirt the proxy (should time out after 5s) - ssh -i "$id_rsa" "$user@$host" "Invoke-WebRequest -UseBasicParsing -TimeoutSec 5 'http://www.microsoft.com'" && exit 1 || : - ssh -i "$id_rsa" "$user@$host" "Invoke-WebRequest -UseBasicParsing -TimeoutSec 5 'https://www.microsoft.com'" && exit 1 || : - - echo "Windows runner verified." - fi -done diff --git a/builds/e2e/create-windows-vm-template.json b/builds/e2e/create-windows-vm-template.json deleted file mode 100644 index 491ae3eeba5..00000000000 --- a/builds/e2e/create-windows-vm-template.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", - "contentVersion": "1.0.0.0", - "parameters": { - "admin_password": { - "type": "securestring" - }, - "admin_user": { - "type": "string" - }, - "extension_command": { - "type": "securestring" - }, - "nic_name": { - "type": "string" - }, - "nsg_id": { - "type": "string" - }, - "ssh_public_key": { - "type": "string" - }, - "vm_name": { - "type": "string" - }, - "vm_size": { - "type": "string" - }, - "vnet_subnet_id": { - "type": "string" - }, - "ip_addr_name": { - "defaultValue": "!", - "metadata": { - "description": "The name of the Public IP Address resource to create.
The default value is the string '!', which is an invalid Public IP Address name and signals that a Public IP Address resource will NOT be created." - }, - "type": "string" - } - }, - "variables": { - "ip_addr_id": { - "id": "[resourceId('Microsoft.Network/publicIPAddresses', parameters('ip_addr_name'))]" - } - }, - "resources": [{ - "condition": "[not(equals(parameters('ip_addr_name'), '!'))]", - "type": "Microsoft.Network/publicIPAddresses", - "name": "[parameters('ip_addr_name')]", - "apiVersion": "2018-02-01", - "sku": { - "name": "Basic", - "tier": "Regional" - }, - "location": "[resourceGroup().location]", - "properties": { - "publicIPAddressVersion": "IPv4", - "publicIPAllocationMethod": "Dynamic" - }, - "dependsOn": [] - }, { - "type": "Microsoft.Network/networkInterfaces", - "name": "[parameters('nic_name')]", - "apiVersion": "2018-02-01", - "location": "[resourceGroup().location]", - "properties": { - "ipConfigurations": [{ - "name": "ipconfig1", - "properties": { - "subnet": { - "id": "[parameters('vnet_subnet_id')]" - }, - "privateIPAllocationMethod": "Dynamic", - "publicIPAddress": "[if(equals(parameters('ip_addr_name'), '!'), json('null'), variables('ip_addr_id'))]" - } - }], - "networkSecurityGroup": { - "id": "[parameters('nsg_id')]" - }, - "primary": true - }, - "dependsOn": [ - "[parameters('ip_addr_name')]" - ] - }, { - "type": "Microsoft.Compute/virtualMachines", - "name": "[parameters('vm_name')]", - "apiVersion": "2017-12-01", - "location": "[resourceGroup().location]", - "properties": { - "osProfile": { - "computerName": "[parameters('vm_name')]", - "adminUsername": "[parameters('admin_user')]", - "adminPassword": "[parameters('admin_password')]", - "windowsConfiguration": { - "enableAutomaticUpdates": true - } - }, - "hardwareProfile": { - "vmSize": "[parameters('vm_size')]" - }, - "storageProfile": { - "imageReference": { - "publisher": "MicrosoftWindowsServer", - "offer": "WindowsServer", - "sku": "2019-Datacenter-Core", - "version": "latest" - }, - "osDisk": { - "osType": "Windows", - "createOption": "FromImage", - "managedDisk": { - "storageAccountType": "Standard_LRS" - } - }, - "dataDisks": [] - }, - "networkProfile": { - "networkInterfaces": [{ - "id": "[resourceId('Microsoft.Network/networkInterfaces', parameters('nic_name'))]", - "properties": { - "primary": true - } - }] - } - }, - "dependsOn": [ - "[resourceId('Microsoft.Network/networkInterfaces', parameters('nic_name'))]" - ] - }, { - "type": "Microsoft.Compute/virtualMachines/extensions", - "name": "[concat(parameters('vm_name'), '/', 'setup')]", - "apiVersion": "2018-06-01", - "location": "[resourceGroup().location]", - "properties": { - "publisher": "Microsoft.Compute", - "type": "CustomScriptExtension", - "typeHandlerVersion": "1.9", - "autoUpgradeMinorVersion": true, - "settings": {}, - "protectedSettings": { - "commandToExecute": "[parameters('extension_command')]" - } - }, - "dependsOn": [ - "[resourceId('Microsoft.Compute/virtualMachines/', parameters('vm_name'))]" - ] - }], - "outputs": { - "hostkey": { - "type": "array", - "value": "[take(skip(split(string(reference(resourceId('Microsoft.Compute/virtualMachines/extensions', parameters('vm_name'), 'setup')).instanceView.substatuses[0].message), '#DATA#'), 1), 1)]" - } - } -} \ No newline at end of file diff --git a/builds/e2e/e2e.yaml b/builds/e2e/e2e.yaml index 8e4125ac270..55692e10eef 100644 --- a/builds/e2e/e2e.yaml +++ b/builds/e2e/e2e.yaml @@ -63,26 +63,50 @@ jobs: - template: templates/e2e-setup.yaml - template: templates/e2e-run.yaml 
-################################################################################ - - job: centos7_amd64 -################################################################################ - displayName: CentOs7 amd64 - - pool: - name: $(pool.name) - demands: - - Agent.OS -equals Linux - - Agent.OSArchitecture -equals X64 - - run-new-e2e-tests -equals true - - variables: - os: linux - arch: amd64 - artifactName: iotedged-centos7-amd64 - - steps: - - template: templates/e2e-clean-directory.yaml - - template: templates/e2e-setup.yaml - - template: templates/e2e-clear-docker-cached-images.yaml +################################################################################ + - job: centos7_amd64 +################################################################################ + displayName: CentOs7 amd64 + + pool: + name: $(pool.name) + demands: + - Agent.OS -equals Linux + - Agent.OSArchitecture -equals X64 + - run-new-e2e-tests -equals true + + variables: + os: linux + arch: amd64 + artifactName: iotedged-centos7-amd64 + + steps: + - template: templates/e2e-clean-directory.yaml + - template: templates/e2e-setup.yaml + - template: templates/e2e-clear-docker-cached-images.yaml - template: templates/e2e-run.yaml +################################################################################ + - job: linux_amd64_proxy +################################################################################ + displayName: Linux amd64 behind a proxy + + pool: + name: $(pool.name) + demands: new-e2e-proxy + + variables: + os: linux + arch: amd64 + artifactName: iotedged-ubuntu18.04-amd64 + # workaround, see https://github.com/Microsoft/azure-pipelines-agent/issues/2138#issuecomment-470166671 + 'agent.disablelogplugin.testfilepublisherplugin': true + 'agent.disablelogplugin.testresultlogplugin': true + + timeoutInMinutes: 120 + + steps: + - template: templates/e2e-clean-directory.yaml + - template: templates/e2e-setup.yaml + - template: templates/e2e-clear-docker-cached-images.yaml + - template: templates/e2e-run.yaml diff --git a/builds/e2e/finalize-agent-template.json b/builds/e2e/finalize-agent-template.json deleted file mode 100644 index 818c981c041..00000000000 --- a/builds/e2e/finalize-agent-template.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", - "contentVersion": "1.0.0.0", - "parameters": { - "extension_command": { - "type": "securestring" - }, - "vm_name": { - "type": "string" - } - }, - "resources": [{ - "type": "Microsoft.Compute/virtualMachines/extensions", - "name": "[concat(parameters('vm_name'), '/', 'setup')]", - "apiVersion": "2018-10-01", - "location": "[resourceGroup().location]", - "properties": { - "publisher": "Microsoft.Azure.Extensions", - "type": "CustomScript", - "typeHandlerVersion": "2.0", - "autoUpgradeMinorVersion": true, - "settings": {}, - "protectedSettings": { - "commandToExecute": "[parameters('extension_command')]" - } - } - }] -} \ No newline at end of file diff --git a/builds/e2e/proxy.md b/builds/e2e/proxy.md deleted file mode 100644 index e9c7354d717..00000000000 --- a/builds/e2e/proxy.md +++ /dev/null @@ -1,155 +0,0 @@ -This file documents how to set up a proxy environment in Azure for our E2E tests. - -The overall setup includes three VMs: - -- The "agent VM" has full network connectivity and runs the SSH tasks defined in the E2E proxy VSTS job. It also runs an HTTP proxy (squid). - -- A Linux "runner VM" runs the proxy tests themselves, on Linux. 
It has no network connectivity except to talk to the agent VM, thus all its interactions with Azure IoT Hub need to occur through the squid proxy on the agent VM. - -- A Windows "runner VM" serves the same purpose as the Linux runner, but on Windows. - -Follow the steps below to deploy the three VMs and set them up. The steps are in bash, but there are notes at the bottom about doing the same thing in PowerShell. In both cases, the Azure CLI `az` is required. If the deployment completes successfully, that means the environment is set up, the agent can reach the runners via SSH, and the runners can't reach the internet without the proxy. - -```sh -cd ./builds/e2e/ - -# ---------- -# Parameters - - -# Name of Azure subscription -subscription_name='<>' - -# Location of the resource group -location='<>' - -# Name of the resource group -resource_group_name='<>' - -# Name of the key vault to store secrets for this deployment -key_vault_name='<>' - -# AAD Object ID for a user or group who will be given access to the secrets in this key vault -key_vault_access_objectid='<>' - -# Name of the Azure virtual network to which all VMs will attach. -vms_vnet_name='<>' - -# The address prefix (in CIDR notation) of the virtual network/subnet -vms_vnet_address_prefix='<>' - -# Name of the user for the VMs -vms_username='vsts' - -# Name of the subnet within the virtual network -vms_vnet_subnet_name='default' - -# Names of the agent and runner VMs. Used to resolve them via DNS for the tests. -vsts_agent_vm_name='e2eproxyvstsagent' -vsts_runner1_vm_name='e2eproxyvstsrunner1' -# This will be a windows machine name so, e.g., must be <= 15 chars -vsts_runner2_vm_name='e2eproxyrunner2' - -# Name of the Windows VM admin password secret in key vault -key_vault_secret_name='windows-vm-admin-password' - -# Windows VM admin password -windows_vm_password="$(openssl rand -base64 32)" - -# ------- -# Execute - - -# Create SSH key for the VMs -keyfile="$(realpath ./id_rsa)" -ssh-keygen -t rsa -b 4096 -N '' -f "$keyfile" - - -# Create an SSH service connection in VSTS using $vsts_agent_vm_name and $keyfile - - -# Log in to Azure subscription -az login -az account set -s "$subscription_name" - - -# If the resource group doesn't already exist, create it -az group create -l "$location" -n "$resource_group_name" - - -# Deploy the VMs -az group deployment create --resource-group "$resource_group_name" --name 'e2e-proxy' --template-file ./proxy-deployment-template.json --parameters "$( - jq -n \ - --arg key_vault_access_objectid "$key_vault_access_objectid" \ - --arg key_vault_name "$key_vault_name" \ - --arg key_vault_secret_name "$key_vault_secret_name" \ - --arg vms_ssh_key_encoded "$(base64 -w 0 $keyfile)" \ - --arg vms_ssh_public_key "$(cat $keyfile.pub)" \ - --arg vms_username "$vms_username" \ - --arg vms_vnet_address_prefix "$vms_vnet_address_prefix" \ - --arg vms_vnet_name "$vms_vnet_name" \ - --arg vms_vnet_subnet_name "$vms_vnet_subnet_name" \ - --arg vsts_agent_vm_name "$vsts_agent_vm_name" \ - --arg vsts_runner1_vm_name "$vsts_runner1_vm_name" \ - --arg vsts_runner2_vm_name "$vsts_runner2_vm_name" \ - --arg windows_vm_password "$windows_vm_password" \ - '{ - "key_vault_access_objectid": { "value": $key_vault_access_objectid }, - "key_vault_name": { "value" : $key_vault_name }, - "key_vault_secret_name": { "value": $key_vault_secret_name }, - "vms_ssh_key_encoded": { "value": $vms_ssh_key_encoded }, - "vms_ssh_public_key": { "value": $vms_ssh_public_key }, - "vms_username": { "value": $vms_username }, - 
"vms_vnet_address_prefix": { "value": $vms_vnet_address_prefix }, - "vms_vnet_name": { "value": $vms_vnet_name }, - "vms_vnet_subnet_name": { "value": $vms_vnet_subnet_name }, - "vsts_agent_vm_name": { "value": $vsts_agent_vm_name }, - "vsts_runner1_vm_name": { "value": $vsts_runner1_vm_name }, - "vsts_runner2_vm_name": { "value": $vsts_runner2_vm_name }, - "windows_vm_password": { "value": $windows_vm_password } - }' -)" -``` - -## PowerShell notes: - -Variable assignments are the same, except that the variable names should be prefixed with '$', e.g.: - -```PowerShell -# Name of Azure subscription -$subscription_name='<>' -# ... -``` - -To create the Windows VM administrator password without openssl, use the following call: - -```PowerShell -Add-Type -AssemblyName System.Web -$windows_vm_password=$([Convert]::ToBase64String([System.Web.Security.Membership]::GeneratePassword(32, 3).ToCharArray(), 0)) -``` - -The commands to create an SSH key for the VMs are a little different. On Windows 1809 or later, install the ssh-agent feature first; more information [here](https://docs.microsoft.com/en-us/windows-server/administration/openssh/openssh_install_firstuse). - -```PowerShell -$keyfile=$(Join-Path (pwd).Path id_rsa) -ssh-keygen -t rsa -b 4096 -f "$keyfile" --% -N "" -``` - -The command to deploy the VMs is different. It doesn't use jq, and the base64-encoding command is PowerShell-specific: - -```PowerShell -az group deployment create --resource-group "$resource_group_name" --name 'e2e-proxy' --template-file ./proxy-deployment-template.json --parameters ` - key_vault_access_objectid="$key_vault_access_objectid" ` - key_vault_name="$key_vault_name" ` - key_vault_secret_name="$key_vault_secret_name" ` - vms_ssh_key_encoded="$([System.Convert]::ToBase64String([System.Text.Encoding]::Utf8.GetBytes($(Get-Content "$keyfile" -Raw))))" ` - vms_ssh_public_key="$(cat "$keyfile.pub")" ` - vms_username="$vms_username" ` - vms_vnet_address_prefix="$vms_vnet_address_prefix" ` - vms_vnet_name="$vms_vnet_name" ` - vms_vnet_subnet_name="$vms_vnet_subnet_name" ` - vsts_agent_vm_name="$vsts_agent_vm_name" ` - vsts_runner1_vm_name="$vsts_runner1_vm_name" ` - vsts_runner2_vm_name="$vsts_runner2_vm_name" ` - windows_vm_password=$windows_vm_password -``` diff --git a/builds/e2e/agent.sh b/builds/e2e/proxy/configure_proxy.sh old mode 100644 new mode 100755 similarity index 72% rename from builds/e2e/agent.sh rename to builds/e2e/proxy/configure_proxy.sh index 51b8acee24e..7444d7b75f1 --- a/builds/e2e/agent.sh +++ b/builds/e2e/proxy/configure_proxy.sh @@ -3,26 +3,18 @@ set -euo pipefail user="$1" -encoded_key="$2" -subnet_address_prefix="$3" +subnet_address_prefix="$2" -# set up SSH private key -echo "Creating SSH private key for user '$user'" +echo 'Installing squid' -home="$(eval echo ~$user)" -mkdir -p "$home/.ssh" -echo -e "$encoded_key" | base64 -d > "$home/.ssh/id_rsa" -chown -R "$user:$user" "$home/.ssh" -chmod 700 "$home/.ssh" -chmod 600 "$home/.ssh/id_rsa" +for i in `seq 3` +do + sleep 5 + apt-get update || continue + apt-get install -y jq squid && break +done -# install/configure squid -echo "Installing squid" - -apt-get update -apt-get install -y jq squid - -echo "Configuring squid" +echo 'Configuring squid' > ~/squid.conf cat <<-EOF acl localnet src $subnet_address_prefix diff --git a/builds/e2e/proxy/configure_runner.sh b/builds/e2e/proxy/configure_runner.sh new file mode 100755 index 00000000000..4890f4d6b99 --- /dev/null +++ b/builds/e2e/proxy/configure_runner.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +set 
-euo pipefail + +proxy="http://${1}:3128" + +export http_proxy=$proxy +export https_proxy=$proxy + +echo 'Installing PowerShell Core and .NET Core 3.1' + +apt-get update +apt-get install -y curl git wget apt-transport-https +wget -q 'https://packages.microsoft.com/config/ubuntu/18.04/packages-microsoft-prod.deb' +dpkg -i packages-microsoft-prod.deb +apt-get update +add-apt-repository universe +apt-get install -y powershell dotnet-sdk-3.1 + +echo 'Installing Moby engine' + +curl -x $proxy 'https://packages.microsoft.com/config/ubuntu/18.04/multiarch/prod.list' > microsoft-prod.list +mv microsoft-prod.list /etc/apt/sources.list.d/ + +curl -x $proxy 'https://packages.microsoft.com/keys/microsoft.asc' | gpg --dearmor > microsoft.gpg +mv microsoft.gpg /etc/apt/trusted.gpg.d/ + +apt-get update +apt-get install -y moby-engine + +> ~/proxy-env.override.conf cat <<-EOF +[Service] +Environment="http_proxy=$proxy" +Environment="https_proxy=$proxy" +EOF +mkdir -p /etc/systemd/system/docker.service.d/ +cp ~/proxy-env.override.conf /etc/systemd/system/docker.service.d/ + +systemctl daemon-reload +systemctl restart docker + +# add iotedged's proxy settings (even though iotedged isn't installed--the tests do that later) +mkdir -p /etc/systemd/system/iotedge.service.d +cp ~/proxy-env.override.conf /etc/systemd/system/iotedge.service.d/ + +echo 'Verifying VM behavior behind proxy server' + +# Verify runner can't skirt the proxy (should time out after 5s) +unset http_proxy https_proxy +timeout 5s curl -L 'http://www.microsoft.com' && exit 1 || : +timeout 5s curl -L 'https://www.microsoft.com' && exit 1 || : diff --git a/builds/e2e/create-linux-vm-template.json b/builds/e2e/proxy/create-linux-vm-template.json similarity index 88% rename from builds/e2e/create-linux-vm-template.json rename to builds/e2e/proxy/create-linux-vm-template.json index 4c75a6cdc41..a6f6153317c 100644 --- a/builds/e2e/create-linux-vm-template.json +++ b/builds/e2e/proxy/create-linux-vm-template.json @@ -1,12 +1,12 @@ { - "$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", + "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", "contentVersion": "1.0.0.0", "parameters": { "admin_user": { "type": "string" }, "extension_command": { - "type": "securestring" + "type": "string" }, "nic_name": { "type": "string" @@ -43,7 +43,7 @@ "condition": "[not(equals(parameters('ip_addr_name'), '!'))]", "type": "Microsoft.Network/publicIPAddresses", "name": "[parameters('ip_addr_name')]", - "apiVersion": "2018-02-01", + "apiVersion": "2020-05-01", "sku": { "name": "Basic", "tier": "Regional" @@ -57,7 +57,7 @@ }, { "type": "Microsoft.Network/networkInterfaces", "name": "[parameters('nic_name')]", - "apiVersion": "2018-02-01", + "apiVersion": "2020-05-01", "location": "[resourceGroup().location]", "properties": { "ipConfigurations": [{ @@ -81,7 +81,7 @@ }, { "type": "Microsoft.Compute/virtualMachines", "name": "[parameters('vm_name')]", - "apiVersion": "2017-12-01", + "apiVersion": "2020-06-01", "location": "[resourceGroup().location]", "properties": { "osProfile": { @@ -130,7 +130,7 @@ }, { "type": "Microsoft.Compute/virtualMachines/extensions", "name": "[concat(parameters('vm_name'), '/', 'setup')]", - "apiVersion": "2018-10-01", + "apiVersion": "2020-06-01", "location": "[resourceGroup().location]", "properties": { "publisher": "Microsoft.Azure.Extensions", @@ -145,11 +145,5 @@ "dependsOn": [ "[resourceId('Microsoft.Compute/virtualMachines/', parameters('vm_name'))]" ] - 
}], - "outputs": { - "hostkey": { - "type": "array", - "value": "[take(skip(split(string(reference(resourceId('Microsoft.Compute/virtualMachines/extensions', parameters('vm_name'), 'setup')).instanceView.statuses[0].message), '#DATA#'), 1), 1)]" - } - } + }] } \ No newline at end of file diff --git a/builds/e2e/proxy/create_ssh_keys.sh b/builds/e2e/proxy/create_ssh_keys.sh new file mode 100755 index 00000000000..de7fd1c7f4f --- /dev/null +++ b/builds/e2e/proxy/create_ssh_keys.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +[ -n "$1" ] || exit 1 +(($1 >= 1)) || exit 1 + +num_keypairs=$1 + +for i in `seq $num_keypairs` +do + ssh-keygen -t rsa -b 4096 -N '' -f "id_rsa$i" || exit 1 +done + +comma=',' +echo '{"keyinfo": [' > $AZ_SCRIPTS_OUTPUT_PATH +for i in `seq $num_keypairs` +do + if ((i == $num_keypairs)); then comma=''; fi + json="{\"privateKey\":\"$(cat id_rsa$i)\",\"publicKey\":\"$(cat id_rsa$i.pub)\"}$comma" + echo "$json" >> $AZ_SCRIPTS_OUTPUT_PATH +done +echo ']}' >> $AZ_SCRIPTS_OUTPUT_PATH diff --git a/builds/e2e/proxy-deployment-template.json b/builds/e2e/proxy/proxy-deployment-template.json similarity index 50% rename from builds/e2e/proxy-deployment-template.json rename to builds/e2e/proxy/proxy-deployment-template.json index 25b436522f2..e74c3452205 100644 --- a/builds/e2e/proxy-deployment-template.json +++ b/builds/e2e/proxy/proxy-deployment-template.json @@ -1,96 +1,66 @@ { - "$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentTemplate.json#", + "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", "contentVersion": "1.0.0.0", "parameters": { - "key_vault_access_objectid": { - "type": "string" - }, - "key_vault_name": { - "type": "string" - }, - "key_vault_secret_name": { - "type": "string" - }, - "vms_ssh_key_encoded": { - "type": "securestring" - }, - "vms_ssh_public_key": { - "type": "string" - }, - "vms_username": { - "type": "string" - }, - "vms_vnet_address_prefix": { - "type": "string" - }, - "vms_vnet_name": { - "type": "string" - }, - "vms_vnet_subnet_name": { - "type": "string" - }, - "vsts_agent_vm_name": { - "type": "string" - }, - "vsts_runner1_vm_name": { - "type": "string" + "resource_prefix": { + "type": "string", + "defaultValue": "[concat('e2e-', uniqueString(resourceGroup().id), '-')]" }, - "vsts_runner2_vm_name": { - "type": "string" - }, - "windows_vm_password": { - "type": "securestring" + "runner_count": { + "type": "int", + "defaultValue": 1 }, - "vsts_agent_vm_network_interface_name": { - "defaultValue": "e2eproxyvstsagent", - "type": "string" - }, - "vsts_agent_vm_nsg_name": { - "defaultValue": "e2eproxyvstsagent", - "type": "string" - }, - "vsts_agent_vm_public_nsg_name": { - "defaultValue": "e2eproxyvstsagentpublic", + "key_vault_access_objectid": { "type": "string" }, - "vsts_agent_vm_size": { + "proxy_vm_size": { "defaultValue": "Standard_DS1_v2", "type": "string" }, - "vsts_runner1_vm_network_interface_name": { - "defaultValue": "e2eproxyvstsrunner1", + "runner_vm_size": { + "defaultValue": "Standard_D2s_v3", "type": "string" }, - "vsts_runner1_vm_size": { - "defaultValue": "Standard_D2s_v3", + "create_runner_public_ip": { + "defaultValue": "false", + "type": "bool" + }, + "linux_vm_creation_template_uri": { + "defaultValue": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/proxy/create-linux-vm-template.json", "type": "string" }, - "vsts_runner2_vm_network_interface_name": { - "defaultValue": "e2eproxyvstsrunner2", + "proxy_config_script_uri": { + "defaultValue": 
"https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/proxy/configure_proxy.sh", "type": "string" }, - "vsts_runner2_vm_size": { - "defaultValue": "Standard_D2s_v3", + "runner_config_script_uri": { + "defaultValue": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/proxy/configure_runner.sh", "type": "string" }, - "vsts_runner_vms_nsg_name": { - "defaultValue": "e2eproxyvstsrunners", + "create_ssh_keys_script_uri": { + "defaultValue": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/proxy/create_ssh_keys.sh", "type": "string" } }, "variables": { - "agent_prep1_script_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/agent.sh", - "agent_prep2_script_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/agent_final.sh", - "create_linux_vm_template_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/create-linux-vm-template.json", - "create_windows_vm_template_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/create-windows-vm-template.json", - "finalize_agent_template_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/finalize-agent-template.json", - "runner1_prep_script_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/runner.sh", - "runner2_prep_script_uri": "https://raw.githubusercontent.com/Azure/iotedge/master/builds/e2e/Runner.ps1" + "key_vault_name": "[concat(parameters('resource_prefix'), 'kv')]", + "contributor_id": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]", + "keygen_identity_name": "[concat(parameters('resource_prefix'), 'uaid')]", + "keygen_script_name": "create_ssh_keys", + "role_definition_name": "[guid(variables('keygen_identity_name'), variables('contributor_id'))]", + "proxy_vm_name": "[concat(parameters('resource_prefix'), 'proxy-vm')]", + "proxy_nsg_name": "[concat(parameters('resource_prefix'), 'proxy-nsg')]", + "runner_prefix": "[concat(parameters('resource_prefix'), 'runner')]", + "runner_nsg_name": "[concat(parameters('resource_prefix'), 'runner-nsg')]", + "vnet_name": "[concat(parameters('resource_prefix'), 'vnet')]", + "vnet_address_prefix": "10.0.0.0/24", + "subnet_name": "default", + "username": "azureuser" }, "resources": [{ "type": "Microsoft.KeyVault/vaults", - "name": "[parameters('key_vault_name')]", - "apiVersion": "2018-02-14", + "name": "[variables('key_vault_name')]", + "apiVersion": "2019-09-01", "location": "[resourceGroup().location]", "properties": { "sku": { @@ -115,31 +85,92 @@ }, "dependsOn": [] }, { - "type": "Microsoft.KeyVault/vaults/secrets", - "name": "[concat(parameters('key_vault_name'), '/', parameters('key_vault_secret_name'))]", - "apiVersion": "2018-02-14", + "type": "Microsoft.ManagedIdentity/userAssignedIdentities", + "name": "[variables('keygen_identity_name')]", + "apiVersion": "2018-11-30", + "location": "[resourceGroup().location]", + "dependsOn": [] + }, { + "type": "Microsoft.Authorization/roleAssignments", + "name": "[variables('role_definition_name')]", + "apiVersion": "2020-04-01-preview", + "properties": { + "roleDefinitionId": "[variables('contributor_id')]", + "principalId": "[reference(variables('keygen_identity_name')).principalId]", + "principalType": "ServicePrincipal", + "scope": "[resourceGroup().id]" + }, + "dependsOn": [ + "[variables('keygen_identity_name')]" + ] + }, { + "type": "Microsoft.Resources/deploymentScripts", + "name": "[variables('keygen_script_name')]", + 
"apiVersion": "2019-10-01-preview", + "location": "[resourceGroup().location]", + "identity": { + "type": "UserAssigned", + "userAssignedIdentities": { + "[resourceId('Microsoft.ManagedIdentity/userAssignedIdentities', variables('keygen_identity_name'))]": {} + } + }, + "kind": "AzureCLI", + "properties": { + "azCliVersion": "2.9.1", + "cleanupPreference": "OnSuccess", + "arguments": "[add(parameters('runner_count'), 1)]", + "primaryScriptUri": "[parameters('create_ssh_keys_script_uri')]", + "timeout": "PT30M", + "retentionInterval": "P1D" + }, + "dependsOn": [ + "[variables('role_definition_name')]" + ] + }, { + "type": "Microsoft.Resources/deployments", + "name": "[concat('store_ssh_keys', copyIndex())]", + "apiVersion": "2020-06-01", + "copy": { + "name": "store_ssh_keys_copy", + "count": "[add(parameters('runner_count'), 1)]" + }, "properties": { - "value": "[parameters('windows_vm_password')]" + "mode": "Incremental", + "template": { + "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + "contentVersion": "1.0.0.0", + "resources": [ + { + "type": "Microsoft.KeyVault/vaults/secrets", + "name": "[if(equals(copyIndex(), 0), concat(variables('key_vault_name'), '/', variables('proxy_vm_name'), '-ssh-private-key'), concat(variables('key_vault_name'), '/', variables('runner_prefix'), copyIndex(), '-ssh-private-key'))]", + "apiVersion": "2019-09-01", + "properties": { + "value": "[reference(variables('keygen_script_name')).outputs.keyinfo[copyIndex()].privateKey]" + } + } + ] + } }, "dependsOn": [ - "[concat('Microsoft.KeyVault/vaults', '/', parameters('key_vault_name'))]" + "[variables('key_vault_name')]", + "[variables('keygen_script_name')]" ] - },{ + }, { "type": "Microsoft.Network/virtualNetworks", - "name": "[parameters('vms_vnet_name')]", - "apiVersion": "2018-10-01", + "name": "[variables('vnet_name')]", + "apiVersion": "2020-05-01", "location": "[resourceGroup().location]", "properties": { "addressSpace": { "addressPrefixes": [ - "[parameters('vms_vnet_address_prefix')]" + "[variables('vnet_address_prefix')]" ] }, "subnets": [ { - "name": "[parameters('vms_vnet_subnet_name')]", + "name": "[variables('subnet_name')]", "properties": { - "addressPrefix": "[parameters('vms_vnet_address_prefix')]" + "addressPrefix": "[variables('vnet_address_prefix')]" } } ] @@ -147,8 +178,8 @@ "dependsOn": [] }, { "type": "Microsoft.Network/networkSecurityGroups", - "name": "[parameters('vsts_agent_vm_nsg_name')]", - "apiVersion": "2018-02-01", + "name": "[variables('proxy_nsg_name')]", + "apiVersion": "2020-05-01", "location": "[resourceGroup().location]", "scale": null, "properties": { @@ -235,8 +266,8 @@ "dependsOn": [] }, { "type": "Microsoft.Network/networkSecurityGroups", - "name": "[parameters('vsts_runner_vms_nsg_name')]", - "apiVersion": "2018-02-01", + "name": "[variables('runner_nsg_name')]", + "apiVersion": "2020-05-01", "location": "[resourceGroup().location]", "scale": null, "properties": { @@ -308,151 +339,93 @@ "dependsOn": [] }, { "type": "Microsoft.Resources/deployments", - "name": "create_agent_vm", - "apiVersion": "2018-05-01", + "name": "create_proxy_vm", + "apiVersion": "2020-06-01", "properties": { "mode": "Incremental", "templateLink": { - "uri": "[variables('create_linux_vm_template_uri')]" + "uri": "[parameters('linux_vm_creation_template_uri')]" }, "parameters": { "admin_user": { - "value": "[parameters('vms_username')]" + "value": "[variables('username')]" }, "extension_command": { - "value": "[concat('/bin/bash -c \"set -euo pipefail 
&& curl ', variables('agent_prep1_script_uri'), ' | sudo bash -s -- ', parameters('vms_username'), ' ', parameters('vms_ssh_key_encoded'), ' ', reference(resourceId('Microsoft.Network/virtualNetworks/subnets', parameters('vms_vnet_name'), parameters('vms_vnet_subnet_name')), '2018-08-01').addressPrefix, '\"')]" + "value": "[concat('/bin/bash -c \"set -euo pipefail && curl ', parameters('proxy_config_script_uri'), ' | sudo bash -s -- ', variables('username'), ' ', reference(resourceId('Microsoft.Network/virtualNetworks/subnets', variables('vnet_name'), variables('subnet_name')), '2018-08-01').addressPrefix, '\"')]" }, "nic_name": { - "value": "[parameters('vsts_agent_vm_network_interface_name')]" + "value": "[concat(parameters('resource_prefix'), 'proxy-nic')]" }, "nsg_id": { - "value": "[resourceId('Microsoft.Network/networkSecurityGroups', parameters('vsts_agent_vm_nsg_name'))]" + "value": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('proxy_nsg_name'))]" }, "ssh_public_key": { - "value": "[parameters('vms_ssh_public_key')]" + "value": "[reference(variables('keygen_script_name')).outputs.keyinfo[0].publicKey]" }, "vm_name": { - "value": "[parameters('vsts_agent_vm_name')]" + "value": "[variables('proxy_vm_name')]" }, "vm_size": { - "value": "[parameters('vsts_agent_vm_size')]" + "value": "[parameters('proxy_vm_size')]" }, "vnet_subnet_id": { - "value": "[resourceId('Microsoft.Network/virtualNetworks/subnets', parameters('vms_vnet_name'), parameters('vms_vnet_subnet_name'))]" + "value": "[resourceId('Microsoft.Network/virtualNetworks/subnets', variables('vnet_name'), variables('subnet_name'))]" } } }, "dependsOn": [ - "[parameters('vsts_agent_vm_nsg_name')]", - "[parameters('vms_vnet_name')]" + "[variables('proxy_nsg_name')]", + "[variables('vnet_name')]", + "[variables('keygen_script_name')]" ] }, { "type": "Microsoft.Resources/deployments", - "name": "create_runner1_vm", - "apiVersion": "2018-05-01", + "name": "[concat('create_runner_vm', copyIndex(1))]", + "apiVersion": "2020-06-01", "properties": { "mode": "Incremental", "templateLink": { - "uri": "[variables('create_linux_vm_template_uri')]" + "uri": "[parameters('linux_vm_creation_template_uri')]" }, "parameters": { "admin_user": { - "value": "[parameters('vms_username')]" + "value": "[variables('username')]" }, "extension_command": { - "value": "[concat('/bin/bash -c \"set -euo pipefail && curl -x ', parameters('vsts_agent_vm_name'), ':3128 ', variables('runner1_prep_script_uri'), ' | sudo bash -s -- ', parameters('vsts_agent_vm_name'), ' ', parameters('vms_username'), '\"')]" + "value": "[concat('/bin/bash -c \"set -euo pipefail && curl -x ', variables('proxy_vm_name'), ':3128 ', parameters('runner_config_script_uri'), ' | sudo bash -s -- ', variables('proxy_vm_name'), '\"')]" }, "nic_name": { - "value": "[parameters('vsts_runner1_vm_network_interface_name')]" + "value": "[concat(variables('runner_prefix'), copyIndex(1), '-nic')]" }, "nsg_id": { - "value": "[resourceId('Microsoft.Network/networkSecurityGroups', parameters('vsts_runner_vms_nsg_name'))]" + "value": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('runner_nsg_name'))]" }, "ssh_public_key": { - "value": "[parameters('vms_ssh_public_key')]" + "value": "[reference(variables('keygen_script_name')).outputs.keyinfo[copyIndex(1)].publicKey]" }, "vm_name": { - "value": "[parameters('vsts_runner1_vm_name')]" + "value": "[concat(variables('runner_prefix'), copyIndex(1), '-vm')]" }, "vm_size": { - "value": 
"[parameters('vsts_runner1_vm_size')]" + "value": "[parameters('runner_vm_size')]" }, "vnet_subnet_id": { - "value": "[resourceId('Microsoft.Network/virtualNetworks/subnets', parameters('vms_vnet_name'), parameters('vms_vnet_subnet_name'))]" - } - } - }, - "dependsOn": [ - "[parameters('vsts_runner_vms_nsg_name')]", - "[parameters('vms_vnet_name')]", - "create_agent_vm" - ] - }, { - "type": "Microsoft.Resources/deployments", - "name": "create_runner2_vm", - "apiVersion": "2018-05-01", - "properties": { - "mode": "Incremental", - "templateLink": { - "uri": "[variables('create_windows_vm_template_uri')]" - }, - "parameters": { - "admin_password": { - "value": "[parameters('windows_vm_password')]" - }, - "admin_user": { - "value": "[parameters('vms_username')]" - }, - "extension_command": { - "value": "[concat('powershell -ExecutionPolicy Unrestricted -Command \"& { . { Invoke-WebRequest -UseBasicParsing -Proxy http://', parameters('vsts_agent_vm_name'), ':3128 -Uri ', variables('runner2_prep_script_uri'), ' } | Invoke-Expression; Initialize-WindowsVM -ProxyHostname ', parameters('vsts_agent_vm_name'), ' -SshPublicKeyBase64 ', base64(parameters('vms_ssh_public_key')), ' }\"')]" + "value": "[resourceId('Microsoft.Network/virtualNetworks/subnets', variables('vnet_name'), variables('subnet_name'))]" }, - "nic_name": { - "value": "[parameters('vsts_runner2_vm_network_interface_name')]" - }, - "nsg_id": { - "value": "[resourceId('Microsoft.Network/networkSecurityGroups', parameters('vsts_runner_vms_nsg_name'))]" - }, - "ssh_public_key": { - "value": "[parameters('vms_ssh_public_key')]" - }, - "vm_name": { - "value": "[parameters('vsts_runner2_vm_name')]" - }, - "vm_size": { - "value": "[parameters('vsts_runner2_vm_size')]" - }, - "vnet_subnet_id": { - "value": "[resourceId('Microsoft.Network/virtualNetworks/subnets', parameters('vms_vnet_name'), parameters('vms_vnet_subnet_name'))]" + "ip_addr_name": { + "value": "[if(parameters('create_runner_public_ip'), concat(variables('runner_prefix'), copyIndex(1), '-ip'), '!')]" } } }, - "dependsOn": [ - "[parameters('vsts_runner_vms_nsg_name')]", - "[parameters('vms_vnet_name')]", - "create_agent_vm" - ] - }, { - "type": "Microsoft.Resources/deployments", - "name": "finalize_agent_deployment", - "apiVersion": "2018-05-01", - "properties": { - "mode": "Incremental", - "templateLink": { - "uri": "[variables('finalize_agent_template_uri')]" - }, - "parameters": { - "extension_command": { - "value": "[concat('/bin/bash -c \"set -euo pipefail && curl ', variables('agent_prep2_script_uri'), ' | sudo bash -s -- ', parameters('vms_username'), ' ', parameters('vsts_runner1_vm_name'), ' ''', trim(reference('create_runner1_vm').outputs.hostkey.value[0]), ''' ', parameters('vsts_runner2_vm_name'), ' ''', trim(reference('create_runner2_vm').outputs.hostkey.value[0]), '''\"')]" - }, - "vm_name": { - "value": "[parameters('vsts_agent_vm_name')]" - } - } + "copy": { + "name": "runner_vm_copy", + "count": "[parameters('runner_count')]" }, "dependsOn": [ - "create_agent_vm", - "create_runner1_vm", - "create_runner2_vm" + "[variables('runner_nsg_name')]", + "[variables('vnet_name')]", + "[variables('keygen_script_name')]", + "create_proxy_vm" ] }] } diff --git a/builds/e2e/proxy/proxy.md b/builds/e2e/proxy/proxy.md new file mode 100644 index 00000000000..cc81ffbff84 --- /dev/null +++ b/builds/e2e/proxy/proxy.md @@ -0,0 +1,84 @@ +This file documents how to set up a proxy environment in Azure for our E2E tests. 
+ +The environment includes: +- A proxy server VM - full network connectivity, runs an HTTP proxy server (squid). +- One or more proxy client VMs (aka "runners") - no internet-bound network connectivity except through the proxy server. +- A Key Vault that contains the private keys used to SSH into the VMs. + +After installing the Azure CLI, enter the following commands to deploy and configure the VMs: + +```sh +cd builds/e2e/proxy/ + +# ---------- +# Parameters + +# Name of Azure subscription +subscription_name='<>' + +# Location of the resource group +location='<>' + +# Name of the resource group +resource_group_name='<>' + +# Prefix used when creating Azure resources. If not given, defaults to 'e2e-<13 char hash>-'. +resource_prefix='<>' + +# The number of runner VMs to create +runner_count=2 + +# AAD Object ID for a user or group who will be given access to the secrets in the key vault +key_vault_access_objectid='<>' + +# ------- +# Execute + +# Log in to Azure subscription +az login +az account set -s "$subscription_name" + +# If the resource group doesn't already exist, create it +az group create -l "$location" -n "$resource_group_name" + +# Deploy the VMs +az deployment group create --resource-group "$resource_group_name" --name 'e2e-proxy' --template-file ./proxy-deployment-template.json --parameters "$( + jq -n \ + --arg resource_prefix $resource_prefix \ + --argjson runner_count $runner_count \ + --arg key_vault_access_objectid "$key_vault_access_objectid" \ + '{ + "resource_prefix": { "value": $resource_prefix }, + "runner_count": { "value": $runner_count }, + "key_vault_access_objectid": { "value": $key_vault_access_objectid }, + "create_runner_public_ip": { "value": true } + }' +)" +``` + +Once the deployment has completed, SSH into each runner VM to install and configure the Azure Pipelines agent. To SSH into the runner VMs, you must first download their private keys from Key Vault. Find the name of the key vault from your deployment, then list the secret URLs for the private keys: + +```sh +az keyvault secret list --vault-name '<>' -o tsv --query "[].id|[?contains(@, 'runner')]" +``` + +With a secret URL and an IP address, you can SSH into a runner VM like this: + +```sh +az keyvault secret show --id '<>' -o tsv --query value > ~/.ssh/id_rsa.runner +chmod 600 ~/.ssh/id_rsa.runner +ssh -i ~/.ssh/id_rsa.runner azureuser@ +``` + +To install and configure Azure Pipelines agent, see [Self-hosted Linux Agents](https://docs.microsoft.com/en-us/azure/devops/pipelines/agents/v2-linux?view=azure-devops) and [Run a self-hosted agent behind a web proxy](https://docs.microsoft.com/en-us/azure/devops/pipelines/agents/proxy?view=azure-devops&tabs=unix). + +> Note that the proxy URL required for most operations on the runner VMs is simply the hostname of the proxy server VM, e.g. `http://e2e-piaj2z37enpb4-proxy-vm:3128`. However, operations inside Docker containers on the runner VMs need either: +> - The _fully-qualified_ name of the proxy VM, e.g. `http://e2e-piaj2z37enpb4-proxy-vm.e0gkjhpfr5quzatbjwfoss05vh.xx.internal.cloudapp.net:3128`, or +> - The private IP address of the proxy VM, e.g. `http://10.0.0.4:3128` +> +> The end-to-end tests get the proxy URL from the agent (via the predefined variable `$(Agent.ProxyUrl)`). Therefore, when you configure the agent you must give it one of the two proxy URLs described above (using either the fully-qualified name or the IP address). 
For example, to pass the fully-qualified name during agent installation on a runner VM: > ``` > proxy_hostname='<>' > proxy_fqdn="http://$proxy_hostname.$(grep -Po '^search \K.*' /etc/resolv.conf):3128" > ./config.sh --proxyurl $proxy_fqdn > ``` \ No newline at end of file diff --git a/builds/e2e/runner.sh b/builds/e2e/runner.sh deleted file mode 100644 index 86e08541ce4..00000000000 --- a/builds/e2e/runner.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash - -set -euo pipefail - -proxy_hostname="$1" -user="$2" - -curl -x "http://$proxy_hostname:3128" 'https://packages.microsoft.com/config/ubuntu/18.04/prod.list' > ./microsoft-prod.list -mv ./microsoft-prod.list /etc/apt/sources.list.d/ - -curl -x "http://$proxy_hostname:3128" 'https://packages.microsoft.com/keys/microsoft.asc' | gpg --dearmor > microsoft.gpg -mv ./microsoft.gpg /etc/apt/trusted.gpg.d/ - -http_proxy="http://$proxy_hostname:3128" https_proxy="http://$proxy_hostname:3128" apt-get update -http_proxy="http://$proxy_hostname:3128" https_proxy="http://$proxy_hostname:3128" apt-get install -y moby-cli moby-engine - -> ~/proxy-env.override.conf cat <<-EOF -[Service] -Environment="http_proxy=http://$proxy_hostname:3128" -Environment="https_proxy=http://$proxy_hostname:3128" -EOF -mkdir -p /etc/systemd/system/docker.service.d/ -cp ~/proxy-env.override.conf /etc/systemd/system/docker.service.d/ - -# Make proxy-env.override.conf available in $user's home directory so tests can -# apply the same proxy settings to the iotedge service -home="$(eval echo ~$user)" -cp ~/proxy-env.override.conf "$home/" -chown -R "$user:$user" "$home/proxy-env.override.conf" - -systemctl daemon-reload -systemctl restart docker - -# Output the host key so it can be added to the agent's known_hosts file -echo -n '#DATA#' -cat /etc/ssh/ssh_host_rsa_key.pub | awk '{ printf "%s %s", $1, $2 }' -echo -n '#DATA#' diff --git a/builds/e2e/templates/e2e-clear-docker-cached-images.yaml b/builds/e2e/templates/e2e-clear-docker-cached-images.yaml index 2b7c02c50c0..4403cd70d80 100644 --- a/builds/e2e/templates/e2e-clear-docker-cached-images.yaml +++ b/builds/e2e/templates/e2e-clear-docker-cached-images.yaml @@ -25,7 +25,7 @@ steps: # Remove old images $remove = sudo docker images --format '{{.Repository}}:{{.Tag}}' ` | where { $images -notcontains $_ } - sudo docker rm -f $(docker ps -a -q) + sudo docker rm -f $(sudo docker ps -a -q) $remove | foreach { sudo docker rmi $_ } # Delete everything else diff --git a/builds/e2e/templates/e2e-run.yaml b/builds/e2e/templates/e2e-run.yaml index 18f57d98a89..e4eba89c8fa 100644 --- a/builds/e2e/templates/e2e-run.yaml +++ b/builds/e2e/templates/e2e-run.yaml @@ -5,7 +5,7 @@ steps: # Unfortunately CentOS has some failing tests that need to be worked on.
if ('$(Agent.Name)'.Contains('centos')) { - sudo --preserve-env dotnet test $testFile --logger:trx --filter "TestCategory=CentOsSafe" + sudo --preserve-env dotnet vstest $testFile --logger:trx --testcasefilter:'Category=CentOsSafe' } else { @@ -20,16 +20,21 @@ steps: E2E_PREVIEW_IOT_HUB_CONNECTION_STRING: $(TestPreviewIotHubConnectionString) E2E_REGISTRIES__0__PASSWORD: $(TestContainerRegistryPassword) E2E_ROOT_CA_PASSWORD: $(TestRootCaPassword) + E2E_BLOB_STORE_SAS: $(TestBlobStoreSas) + http_proxy: $(Agent.ProxyUrl) + https_proxy: $(Agent.ProxyUrl) - task: PublishTestResults@2 displayName: Publish test results inputs: - testResultsFormat: vstest + testRunner: vstest testResultsFiles: '**/*.trx' searchFolder: $(Build.SourcesDirectory)/TestResults - testRunTitle: End-to-end tests ($(Build.BuildNumber) $(os) $(arch)) + testRunTitle: End-to-end tests ($(Build.BuildNumber) $(System.JobDisplayName)) buildPlatform: $(arch) - condition: succeededOrFailed() + # This task takes 15 min when behind a proxy, so disable it + # see https://github.com/microsoft/azure-pipelines-tasks/issues/11831 + condition: and(succeededOrFailed(), not(variables['Agent.ProxyUrl'])) - pwsh: | $logDir = '$(Build.ArtifactStagingDirectory)/logs' @@ -48,5 +53,5 @@ steps: displayName: Publish logs inputs: PathtoPublish: $(Build.ArtifactStagingDirectory)/logs - ArtifactName: logs-end-to-end-$(Build.BuildNumber)-$(os)-$(arch)-$(artifactName) + ArtifactName: logs-end-to-end-$(Build.BuildNumber)-$(System.PhaseName) condition: succeededOrFailed() diff --git a/builds/e2e/templates/e2e-setup.yaml b/builds/e2e/templates/e2e-setup.yaml index 409a08a2cbf..5aa1ccc6df7 100644 --- a/builds/e2e/templates/e2e-setup.yaml +++ b/builds/e2e/templates/e2e-setup.yaml @@ -18,7 +18,8 @@ steps: TestPreviewIotHubConnectionString, TestRootCaCertificate, TestRootCaKey, - TestRootCaPassword + TestRootCaPassword, + TestBlobStoreSas - pwsh: | $imageBuildId = $(resources.pipeline.images.runID) @@ -81,6 +82,9 @@ steps: $binDir = Convert-Path "$testDir/bin/Debug/netcoreapp3.1" Write-Output "##vso[task.setvariable variable=binDir]$binDir" displayName: Build tests + env: + http_proxy: $(Agent.ProxyUrl) + https_proxy: $(Agent.ProxyUrl) - pwsh: | $imagePrefix = '$(cr.address)/$(cr.labelPrefix)azureiotedge' @@ -121,10 +125,15 @@ steps: if ('$(arch)' -eq 'arm32v7' -Or '$(arch)' -eq 'arm64v8') { $context['optimizeForPerformance'] = 'false' - $context['setupTimeoutMinutes'] = 6 + $context['setupTimeoutMinutes'] = 10 $context['teardownTimeoutMinutes'] = 5 $context['testTimeoutMinutes'] = 6 } + if ($env:AGENT_PROXYURL) + { + $context['proxy'] = $env:AGENT_PROXYURL + } + $context | ConvertTo-Json | Out-File -Encoding Utf8 '$(binDir)/context.json' displayName: Create test arguments file (context.json) diff --git a/builds/misc/images-mqtt.yaml b/builds/misc/images-mqtt.yaml index 9af6a4e7870..3a68689ba22 100644 --- a/builds/misc/images-mqtt.yaml +++ b/builds/misc/images-mqtt.yaml @@ -29,7 +29,7 @@ jobs: displayName: Build MQTT Broker - amd64 inputs: filePath: scripts/linux/cross-platform-rust-build.sh - arguments: --os ubuntu18.04 --arch amd64 --build-path mqtt/mqttd --cargo-flags '--no-default-features --features="generic"' + arguments: --os alpine --arch amd64 --build-path mqtt/mqttd --cargo-flags '--no-default-features --features="generic"' - template: templates/move-rust-artifacts.yaml parameters: diff --git a/builds/misc/templates/build-broker-watchdog.yaml b/builds/misc/templates/build-broker-watchdog.yaml index be1064b511a..2bb8363ca0f 100644 --- 
a/builds/misc/templates/build-broker-watchdog.yaml +++ b/builds/misc/templates/build-broker-watchdog.yaml @@ -4,7 +4,7 @@ steps: displayName: Build watchdog - amd64 inputs: filePath: scripts/linux/cross-platform-rust-build.sh - arguments: --os ubuntu18.04 --arch amd64 --build-path edge-hub/watchdog + arguments: --os alpine --arch amd64 --build-path edge-hub/watchdog - task: Bash@3 displayName: Build watchdog - arm32 inputs: @@ -20,7 +20,7 @@ steps: displayName: Build MQTT Broker - amd64 inputs: filePath: scripts/linux/cross-platform-rust-build.sh - arguments: --os ubuntu18.04 --arch amd64 --build-path mqtt/mqttd + arguments: --os alpine --arch amd64 --build-path mqtt/mqttd - task: Bash@3 displayName: Build MQTT Broker - arm32 inputs: diff --git a/builds/misc/templates/build-broker.yaml b/builds/misc/templates/build-broker.yaml index 20c40f42dcf..fed3872bd08 100644 --- a/builds/misc/templates/build-broker.yaml +++ b/builds/misc/templates/build-broker.yaml @@ -3,7 +3,7 @@ steps: displayName: Build MQTT Broker - amd64 inputs: filePath: scripts/linux/cross-platform-rust-build.sh - arguments: --os ubuntu18.04 --arch amd64 --build-path mqtt/mqttd + arguments: --os alpine --arch amd64 --build-path mqtt/mqttd - task: Bash@3 displayName: Build MQTT Broker - arm32 inputs: diff --git a/builds/misc/templates/build-watchdog.yaml b/builds/misc/templates/build-watchdog.yaml index 009f1d97575..0e4e9046a94 100644 --- a/builds/misc/templates/build-watchdog.yaml +++ b/builds/misc/templates/build-watchdog.yaml @@ -3,7 +3,7 @@ steps: displayName: Build watchdog - amd64 inputs: filePath: scripts/linux/cross-platform-rust-build.sh - arguments: --os ubuntu18.04 --arch amd64 --build-path edge-hub/watchdog + arguments: --os alpine --arch amd64 --build-path edge-hub/watchdog - task: Bash@3 displayName: Build watchdog - arm32 inputs: diff --git a/doc/ModuleStartupOrder.md b/doc/ModuleStartupOrder.md index 7fe4cb7e206..023ddf26430 100644 --- a/doc/ModuleStartupOrder.md +++ b/doc/ModuleStartupOrder.md @@ -1,128 +1,130 @@ -# How to configure module startup order - -By default, IoT Edge does not impose an ordering in the sequence in which modules are started, updated or stopped. Edge Agent by default is the first module that gets started and based on the edge deployment specification, it figures out which modules need to be started, updated or stopped and executes those operations in a non-deterministic order. - -The processing order of modules can be controlled by specifying the value of a module-specific property called `startupOrder` in the IoT Edge deployment. Modules that have been assigned a lower integer value as the startup order will be processed before modules that have been assigned a higher value. - -## __Use case__ - -Customers who have an array of modules of which some are 'critical' or 'foundation' modules that are required by other modules in the ecosystem might want these modules to be started before other modules. This is so as to achieve a better end user experience where other modules don't have to wait for these 'critical' or 'foundation' modules to be started, so as to initialize themselves. - -As an example, some customers want the Edge Hub module to be started before any other non-system modules in the ecosystem are started. This is so that other modules don't spend unnecessary cycles waiting for Edge Hub to come up before they can start sending messages to other modules or upstream to IoT Hub. 
- -**That being said, module owners should design their modules to withstand any failures of these 'critical' or 'foundation' modules, that they are dependent upon, as they could go down at any arbitrary time and an arbitrary number of times.** - -## __Configuration__ - -Customers can optionally specify a `startupOrder` value for each module in their IoT Edge deployment. This can be used to achieve module boot ordering. Modules with startup order of '1' are created and processed before those with a value greater than '1'. The maximum value of this property will be 4294967295. Only after an attempt has been made to start those with a lower value will those with a higher value be created and started. Startup order does not imply that a given module that starts before another will *complete* its startup before the other. Also, modules where the desired state is NOT configured to be 'Running' are skipped. - -The value of `startupOrder` must be positive and zero-based (i.e. a value of '0' means start this module first). Modules that possess the same startupOrder will be created at the same time and will have no deterministic startup order imposed amongst themselves. - -**It must be noted that the Edge Agent module does not support the `startupOrder` property. It always starts first.** - -Modules without a specified `startupOrder` value are started in a non-deterministic order. They are assigned the maximum startupOrder of 4294967295 indicating that they should be created and started after all other modules with specified values. - -**Please note that Kubernetes mode of IoT Edge does not support module startup ordering.** - -## __Example__ - -### __How to set startup order of Edge modules__ - -Here's an example of how to set the startupOrder of IoT Edge modules through Az CLI: - -Create a deployment manifest `deployment.json` JSON file that has your IoT Edge deployment specification. Please refer to [Learn how to deploy modules and establish routes in IoT Edge][1] for more information about the IoT Edge deployment manifest. 
- -The following sample deployment manifest illustrates how startupOrder values of modules can be set: - -```JSON -{ - "modulesContent": { - "$edgeAgent": { - "properties.desired": { - "schemaVersion": "1.0", - "runtime": { - "type": "docker", - "settings": { - "minDockerVersion": "v1.25", - "loggingOptions": "", - "registryCredentials": { - "ContosoRegistry": { - "username": "myacr", - "password": "", - "address": "myacr.azurecr.io" - } - } - } - }, - "systemModules": { - "edgeAgent": { - "type": "docker", - "settings": { - "image": "mcr.microsoft.com/azureiotedge-agent:1.0", - "createOptions": "" - } - }, - "edgeHub": { - "type": "docker", - "status": "running", - "restartPolicy": "always", - "settings": { - "image": "mcr.microsoft.com/azureiotedge-hub:1.0", - "createOptions": "" - }, - "startupOrder": 0 - } - }, - "modules": { - "SimulatedTemperatureSensor": { - "version": "1.0", - "type": "docker", - "status": "running", - "restartPolicy": "always", - "settings": { - "image": "mcr.microsoft.com/azureiotedge-simulated-temperature-sensor:1.0", - "createOptions": "{}" - }, - "startupOrder": 1 - }, - "filtermodule": { - "version": "1.0", - "type": "docker", - "status": "running", - "restartPolicy": "always", - "settings": { - "image": "myacr.azurecr.io/filtermodule:latest", - "createOptions": "{}" - } - } - } - } - }, - "$edgeHub": { - "properties.desired": { - "schemaVersion": "1.0", - "routes": { - "sensorToFilter": "FROM /messages/modules/SimulatedTemperatureSensor/outputs/temperatureOutput INTO BrokeredEndpoint(\"/modules/filtermodule/inputs/input1\")", - "filterToIoTHub": "FROM /messages/modules/filtermodule/outputs/output1 INTO $upstream" - }, - "storeAndForwardConfiguration": { - "timeToLiveSecs": 10 - } - } - } - } -} -``` - -In the sample deployment manifest shown above: - -* The `$edgeHub` module has been assigned a `startupOrder` value of 0. -* The `SimulatedTemperatureSensor` module has been assigned a `startupOrder` value of 1. -* The `filtermodule` module has not been assigned any `startupOrder` value which means that it will by default assume the value of 4294967295. It will be created and started after all others. - -When this deployment manifest is deployed to a device that does not have any modules running, `$edgeHub` is the first module that will be started followed by the `SimulatedTemperatureSensor` module and then the `filtermodule`. - -Please refer to [Deploy Azure IoT Edge modules with Azure CLI][2] for steps on how to deploy the deployment.json file to your device. - -[1]: https://docs.microsoft.com/azure/iot-edge/module-composition -[2]: https://docs.microsoft.com/en-us/azure/iot-edge/how-to-deploy-modules-cli +# How to configure module startup order + +By default, IoT Edge does not impose an ordering in the sequence in which modules are started, updated or stopped. Edge Agent by default is the first module that gets started and based on the edge deployment specification, it figures out which modules need to be started, updated or stopped and executes those operations in a non-deterministic order. + +The processing order of modules can be controlled by specifying the value of a module-specific property called `startupOrder` in the IoT Edge deployment. Modules that have been assigned a lower integer value as the startup order will be processed before modules that have been assigned a higher value. 
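+
+For orientation, here is a minimal sketch of where the property sits in a module definition (a fragment only; the values are illustrative, and a complete manifest appears in the example below):
+
+```JSON
+"edgeHub": {
+    "type": "docker",
+    "status": "running",
+    "restartPolicy": "always",
+    "settings": {
+        "image": "mcr.microsoft.com/azureiotedge-hub:1.0",
+        "createOptions": ""
+    },
+    "startupOrder": 0
+}
+```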
+
+## __Use case__
+
+Customers whose deployments include 'critical' or 'foundation' modules that other modules in the ecosystem depend on might want those modules to be started before the rest. This provides a better end-user experience, since dependent modules don't have to wait for the 'critical' or 'foundation' modules to come up before they can initialize themselves.
+
+As an example, some customers want the Edge Hub module to be started before any other non-system modules in the ecosystem are started. This is so that other modules don't spend unnecessary cycles waiting for Edge Hub to come up before they can start sending messages to other modules or upstream to IoT Hub.
+
+**That said, module owners should design their modules to withstand failures of the 'critical' or 'foundation' modules they depend on, as those modules could go down at any time and an arbitrary number of times.**
+
+## __Configuration__
+
+Customers can optionally specify a `startupOrder` value for each module in their IoT Edge deployment. This can be used to achieve module boot ordering. Modules with a startup order of '1' are created and processed before those with a value greater than '1'. The maximum value of this property is 4294967295. Only after an attempt has been made to start those with a lower value will those with a higher value be created and started. Startup order does not imply that a given module that starts before another will *complete* its startup before the other. Also, modules where the desired state is NOT configured to be 'Running' are skipped.
+
+The value of `startupOrder` must be a non-negative integer, and ordering is zero-based (i.e. a value of '0' means start this module first). Modules that possess the same `startupOrder` will be created at the same time and will have no deterministic startup order imposed amongst themselves.
+
+**It must be noted that the Edge Agent module does not support the `startupOrder` property. It always starts first.**
+
+Modules without a specified `startupOrder` value are started in a non-deterministic order. They are assigned the maximum startupOrder of 4294967295, indicating that they should be created and started after all other modules with specified values.
+
+**Please note that the Kubernetes mode of IoT Edge does not support module startup ordering.**
+
+## __Example__
+
+### __How to set startup order of Edge modules__
+
+Here's an example of how to set the `startupOrder` of IoT Edge modules through the Azure CLI:
+
+Create a deployment manifest file `deployment.json` that contains your IoT Edge deployment specification. Please refer to [Learn how to deploy modules and establish routes in IoT Edge][1] for more information about the IoT Edge deployment manifest. 
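+
+Once you have authored `deployment.json` (a sample follows below), it can be pushed to a device with the Azure CLI (a sketch; `<device id>` and `<hub name>` are placeholders, and the `azure-iot` CLI extension is assumed to be installed):
+
+```sh
+az iot edge set-modules --device-id '<device id>' --hub-name '<hub name>' --content ./deployment.json
+```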
+ +The following sample deployment manifest illustrates how startupOrder values of modules can be set: + +```JSON +{ + "modulesContent": { + "$edgeAgent": { + "properties.desired": { + "schemaVersion": "1.1", + "runtime": { + "type": "docker", + "settings": { + "minDockerVersion": "v1.25", + "loggingOptions": "", + "registryCredentials": { + "ContosoRegistry": { + "username": "myacr", + "password": "", + "address": "myacr.azurecr.io" + } + } + } + }, + "systemModules": { + "edgeAgent": { + "type": "docker", + "settings": { + "image": "mcr.microsoft.com/azureiotedge-agent:1.0", + "createOptions": "" + } + }, + "edgeHub": { + "type": "docker", + "status": "running", + "restartPolicy": "always", + "settings": { + "image": "mcr.microsoft.com/azureiotedge-hub:1.0", + "createOptions": "" + }, + "startupOrder": 0 + } + }, + "modules": { + "SimulatedTemperatureSensor": { + "version": "1.0", + "type": "docker", + "status": "running", + "restartPolicy": "always", + "settings": { + "image": "mcr.microsoft.com/azureiotedge-simulated-temperature-sensor:1.0", + "createOptions": "{}" + }, + "startupOrder": 1 + }, + "filtermodule": { + "version": "1.0", + "type": "docker", + "status": "running", + "restartPolicy": "always", + "settings": { + "image": "myacr.azurecr.io/filtermodule:latest", + "createOptions": "{}" + } + } + } + } + }, + "$edgeHub": { + "properties.desired": { + "schemaVersion": "1.0", + "routes": { + "sensorToFilter": "FROM /messages/modules/SimulatedTemperatureSensor/outputs/temperatureOutput INTO BrokeredEndpoint(\"/modules/filtermodule/inputs/input1\")", + "filterToIoTHub": "FROM /messages/modules/filtermodule/outputs/output1 INTO $upstream" + }, + "storeAndForwardConfiguration": { + "timeToLiveSecs": 10 + } + } + } + } +} +``` + +In the sample deployment manifest shown above: + +* The `$edgeAgent` schemaVersion has been set to 1.1 (or later). +* The `edgeAgent` module always starts first. It does not support the `startupOrder` property. +* The `edgeHub` module has been assigned a `startupOrder` value of 0. +* The `SimulatedTemperatureSensor` module has been assigned a `startupOrder` value of 1. +* The `filtermodule` module has not been assigned any `startupOrder` value which means that it will by default assume the value of 4294967295. It will be created and started after all others. + +When this deployment manifest is deployed to a device that does not have any modules running, `$edgeHub` is the first module that will be started followed by the `SimulatedTemperatureSensor` module and then the `filtermodule`. + +Please refer to [Deploy Azure IoT Edge modules with Azure CLI][2] for steps on how to deploy the deployment.json file to your device. + +[1]: https://docs.microsoft.com/azure/iot-edge/module-composition +[2]: https://docs.microsoft.com/en-us/azure/iot-edge/how-to-deploy-modules-cli diff --git a/edge-agent/docker/linux/arm32v7/Dockerfile b/edge-agent/docker/linux/arm32v7/Dockerfile index d5eca1504b6..bb97476e21e 100644 --- a/edge-agent/docker/linux/arm32v7/Dockerfile +++ b/edge-agent/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-agent-base:${base_tag} ARG EXE_DIR=. 
diff --git a/edge-agent/docker/linux/arm32v7/base/Dockerfile b/edge-agent/docker/linux/arm32v7/base/Dockerfile index 40340fffef2..c1783725833 100644 --- a/edge-agent/docker/linux/arm32v7/base/Dockerfile +++ b/edge-agent/docker/linux/arm32v7/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm32v7 +ARG base_tag=3.1.10-bionic-arm32v7 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag} RUN apt-get update && \ diff --git a/edge-agent/docker/linux/arm64v8/Dockerfile b/edge-agent/docker/linux/arm64v8/Dockerfile index 6f01328c41a..50a3c7626da 100644 --- a/edge-agent/docker/linux/arm64v8/Dockerfile +++ b/edge-agent/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-agent-base:${base_tag} diff --git a/edge-agent/docker/linux/arm64v8/base/Dockerfile b/edge-agent/docker/linux/arm64v8/base/Dockerfile index bcf981a21c5..5d42088b01b 100644 --- a/edge-agent/docker/linux/arm64v8/base/Dockerfile +++ b/edge-agent/docker/linux/arm64v8/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm64v8 +ARG base_tag=3.1.10-bionic-arm64v8 ARG num_procs=4 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag} diff --git a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Edgelet/versioning/ModuleManagementHttpClientVersioned.cs b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Edgelet/versioning/ModuleManagementHttpClientVersioned.cs index 46869f4488c..14ed7e2fc00 100644 --- a/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Edgelet/versioning/ModuleManagementHttpClientVersioned.cs +++ b/edge-agent/src/Microsoft.Azure.Devices.Edge.Agent.Edgelet/versioning/ModuleManagementHttpClientVersioned.cs @@ -88,7 +88,7 @@ public virtual async Task GetModuleLogs(string module, bool follow, Opti { using (HttpClient httpClient = HttpClientHelper.GetHttpClient(this.ManagementUri)) { - string baseUrl = HttpClientHelper.GetBaseUrl(this.ManagementUri); + string baseUrl = HttpClientHelper.GetBaseUrl(this.ManagementUri).TrimEnd('/'); var logsUrl = new StringBuilder(); logsUrl.AppendFormat(CultureInfo.InvariantCulture, LogsUrlTemplate, baseUrl, module, this.Version.Name, follow.ToString().ToLowerInvariant()); tail.ForEach(t => logsUrl.AppendFormat($"&{LogsUrlTailParameter}={t}")); diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs index acfb79fbcc1..8db332dc782 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs @@ -158,6 +158,8 @@ public DeviceProxy(ClientConnectionHandler clientConnectionHandler, IIdentity id public bool IsActive => this.isActive; + public bool IsDirectClient => true; + public IIdentity Identity { get; } public Task CloseAsync(Exception ex) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs index 9134a27bca7..8b8e6eaf8a2 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs @@ -137,7 +137,7 @@ internal bool ValidateAudience(string audience, IIdentity identity) } if 
(string.IsNullOrWhiteSpace(hostName) ||
-                !(this.iothubHostName.Equals(hostName) || this.edgeHubHostName.Equals(hostName)))
+                !(string.Equals(this.iothubHostName, hostName, StringComparison.OrdinalIgnoreCase) || string.Equals(this.edgeHubHostName, hostName, StringComparison.OrdinalIgnoreCase)))
             {
                 Events.InvalidHostName(identity.Id, hostName, this.iothubHostName, this.edgeHubHostName);
                 return false;
diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/Authenticator.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/Authenticator.cs
index 856d7fe0679..cb162497ceb 100644
--- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/Authenticator.cs
+++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/Authenticator.cs
@@ -35,7 +35,17 @@ async Task AuthenticateAsync(IClientCredentials clientCredentials, bool re
             Preconditions.CheckNotNull(clientCredentials);

             bool isAuthenticated;
-            if (clientCredentials.AuthenticationType == AuthenticationType.X509Cert)
+            if (clientCredentials.AuthenticationType == AuthenticationType.Implicit)
+            {
+                // Implicit authentication happens when, in a nested scenario, a parent edge device captures
+                // an IotHub message on an mqtt broker topic belonging to a device it has never seen before. In this case the
+                // child edge device has already authenticated the connecting device, and authorization continuously monitors
+                // whether the device is publishing on allowed topics. So when a message arrives on a topic belonging to
+                // the device, it is certain that the device has been authenticated/authorized before. Just create an entry
+                // for it without further checks
+                isAuthenticated = true;
+            }
+            else if (clientCredentials.AuthenticationType == AuthenticationType.X509Cert)
             {
                 isAuthenticated = await (reAuthenticating
                     ? this.certificateAuthenticator.ReauthenticateAsync(clientCredentials)
diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs
index 96c608cc4b0..c1955b0e188 100644
--- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs
+++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs
@@ -134,8 +134,20 @@ public async Task ReauthenticateAsync(IClientCredentials clientCredentials
             Option<string> authChain = await this.deviceScopeIdentitiesCache.GetAuthChain(authTarget);
             if (!authChain.HasValue)
             {
-                Events.NoAuthChain(authTarget);
-                return (false, false);
+                // The auth-target might be a new device that was recently added, and our
+                // cache might not have it yet. Try refreshing the target identity to see
+                // if we can get it from upstream. 
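+                // If the refresh pulls the identity in from upstream, the re-query of the auth chain below will find it.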
+ Events.NoAuthChainResyncing(authTarget, actorDeviceId); + await this.deviceScopeIdentitiesCache.RefreshServiceIdentityOnBehalfOf(authTarget, actorDeviceId); + authChain = await this.deviceScopeIdentitiesCache.GetAuthChain(authTarget); + + if (!authChain.HasValue) + { + // Still don't have a valid auth-chain for the target, it must be + // out of scope, so we're done here + Events.NoAuthChain(authTarget); + return (false, false); + } } // Check that the actor is authorized to connect OnBehalfOf of the target @@ -148,7 +160,7 @@ public async Task ReauthenticateAsync(IClientCredentials clientCredentials // Check credentials against the acting EdgeHub string actorEdgeHubId = actorDeviceId + $"/{Constants.EdgeHubModuleId}"; - return await this.AuthenticateWithServiceIdentity(credentials, actorEdgeHubId, syncServiceIdentity); + return await this.AuthenticateWithServiceIdentity(credentials, actorEdgeHubId, true); } async Task<(bool isAuthenticated, bool serviceIdentityFound)> AuthenticateWithServiceIdentity(T credentials, string serviceIdentityId, bool syncServiceIdentity) @@ -158,7 +170,7 @@ public async Task ReauthenticateAsync(IClientCredentials clientCredentials if (!isAuthenticated && (!serviceIdentityFound || syncServiceIdentity)) { - Events.ResyncingServiceIdentity(credentials.Identity, serviceIdentityId); + Events.ResyncingServiceIdentity(credentials.Identity, serviceIdentityId, serviceIdentityFound); await this.deviceScopeIdentitiesCache.RefreshServiceIdentity(serviceIdentityId); serviceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(serviceIdentityId); @@ -193,6 +205,7 @@ enum EventIds AuthenticatedInScope, InputCredentialsNotValid, ResyncingServiceIdentity, + NoAuthChainResyncing, AuthenticatingWithDeviceIdentity } @@ -234,9 +247,14 @@ public static void InputCredentialsNotValid(IIdentity identity) Log.LogInformation((int)EventIds.InputCredentialsNotValid, $"Credentials for client {identity.Id} are not valid."); } - public static void ResyncingServiceIdentity(IIdentity identity, string serviceIdentityId) + public static void ResyncingServiceIdentity(IIdentity identity, string serviceIdentityId, bool identityFound) + { + Log.LogInformation((int)EventIds.ResyncingServiceIdentity, $"Unable to authenticate client {identity.Id} with cached service identity {serviceIdentityId} (Found: {identityFound}). Resyncing service identity..."); + } + + public static void NoAuthChainResyncing(string authTarget, string actorDevice) { - Log.LogInformation((int)EventIds.ResyncingServiceIdentity, $"Unable to authenticate client {identity.Id} with cached service identity {serviceIdentityId}. Resyncing service identity..."); + Log.LogInformation((int)EventIds.NoAuthChainResyncing, $"No cached auth-chain when authenticating {actorDevice} OnBehalfOf {authTarget}. 
Resyncing service identity..."); } public static void AuthenticatingWithDeviceIdentity(IModuleIdentity moduleIdentity) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/EdgeHubConnection.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/EdgeHubConnection.cs index 6074c414fd8..2a9b839507d 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/EdgeHubConnection.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/EdgeHubConnection.cs @@ -299,6 +299,8 @@ public EdgeHubDeviceProxy(EdgeHubConnection edgeHubConnection) public bool IsActive => true; + public bool IsDirectClient => true; + public IIdentity Identity => this.edgeHubConnection.edgeHubIdentity; public Task CloseAsync(Exception ex) => Task.CompletedTask; diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityDictionary.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityDictionary.cs index eccdd555f09..e3fbc947f93 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityDictionary.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityDictionary.cs @@ -25,7 +25,26 @@ public ServiceIdentityDictionary(string actorDeviceId) public string GetActorDeviceId() => this.actorDeviceId; - public Task> GetAuthChain(string id) => Task.FromResult(Option.None()); + public Task> GetAuthChain(string id) + { + Option authChain = Option.None(); + + if (this.identities.TryGetValue(id, out ServiceIdentity identity)) + { + if (identity.Id != this.actorDeviceId) + { + // All identities are immediate children of the actor Edge + authChain = Option.Some($"{identity.Id};{this.actorDeviceId}"); + } + else + { + // Special case for the self-identity of the Edge device + authChain = Option.Some(this.actorDeviceId); + } + } + + return Task.FromResult(authChain); + } public Task> GetEdgeAuthChain(string id) => throw new NotImplementedException("Nested Edge not enabled"); diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityTree.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityTree.cs index 085379664d7..81388c6ffb3 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityTree.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/ServiceIdentityTree.cs @@ -36,11 +36,13 @@ public async Task InsertOrUpdate(ServiceIdentity identity) using (await this.nodesLock.LockAsync()) { + bool isUpdate = false; + if (this.nodes.ContainsKey(identity.Id)) { // Update case - this is just remove + re-insert + isUpdate = true; this.RemoveSingleNode(identity.Id); - Events.NodeRemoved(identity.Id); } // Insert case @@ -53,7 +55,14 @@ public async Task InsertOrUpdate(ServiceIdentity identity) this.InsertDeviceIdentity(identity); } - Events.NodeAdded(identity.Id); + if (isUpdate) + { + Events.NodeUpdated(identity.Id); + } + else + { + Events.NodeAdded(identity.Id); + } } } @@ -398,6 +407,7 @@ enum EventIds { NodeAdded = IdStart, NodeRemoved, + NodeUpdated, AuthChainAdded, AuthChainRemoved, MaxDepthExceeded, @@ -406,10 +416,13 @@ enum EventIds } public static void NodeAdded(string id) => - Log.LogDebug((int)EventIds.NodeAdded, $"Add node: {id}"); + Log.LogInformation((int)EventIds.NodeAdded, $"Add node: {id}"); public static void NodeRemoved(string id) => - Log.LogDebug((int)EventIds.NodeRemoved, $"Removed node: {id}"); + Log.LogInformation((int)EventIds.NodeRemoved, $"Removed node: {id}"); + + public static void NodeUpdated(string id) => + 
Log.LogInformation((int)EventIds.NodeUpdated, $"Updated node: {id}"); public static void AuthChainAdded(string id, string authChain, int depth) => Log.LogDebug((int)EventIds.AuthChainAdded, $"Auth-chain added for: {id}, at depth: {depth}, {authChain}"); diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationConfig.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationConfig.cs index d3d155c2a58..3a63166c662 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationConfig.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationConfig.cs @@ -4,7 +4,10 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config using System; using System.Collections.Generic; using System.Linq; + using System.Runtime.Serialization; using Microsoft.Azure.Devices.Edge.Util; + using Newtonsoft.Json; + using Newtonsoft.Json.Converters; /// /// Domain object that represents Authorization configuration for Edge Hub Module (MQTT Broker). @@ -19,6 +22,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config /// public class AuthorizationConfig : IEquatable { + [JsonProperty("statements")] public IList Statements { get; } public AuthorizationConfig(IList statements) @@ -64,17 +68,21 @@ public Statement( IList resources) { this.Effect = effect; - this.Identities = identities; - this.Operations = operations; - this.Resources = resources; + this.Identities = Preconditions.CheckNotNull(identities, nameof(identities)); + this.Operations = Preconditions.CheckNotNull(operations, nameof(operations)); + this.Resources = Preconditions.CheckNotNull(resources, nameof(resources)); } + [JsonProperty("effect")] public Effect Effect { get; } + [JsonProperty("identities")] public IList Identities { get; } + [JsonProperty("operations")] public IList Operations { get; } + [JsonProperty("resources")] public IList Resources { get; } public bool Equals(Statement other) @@ -110,9 +118,13 @@ public override int GetHashCode() } } + [JsonConverter(typeof(StringEnumConverter))] public enum Effect { + [EnumMember(Value = "allow")] Allow = 0, + + [EnumMember(Value = "deny")] Deny = 1, } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationProperties.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationProperties.cs index 5efe085267d..26bcde7e8d0 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationProperties.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/AuthorizationProperties.cs @@ -18,11 +18,11 @@ public class Statement public Statement(IList identities, IList allow, IList deny) { this.Identities = identities; - this.Allow = allow; - this.Deny = deny; + this.Allow = allow ?? new List(); + this.Deny = deny ?? new List(); } - [JsonProperty(PropertyName = "identities")] + [JsonProperty(PropertyName = "identities", Required = Required.Always)] public IList Identities { get; } [JsonProperty(PropertyName = "allow")] @@ -37,8 +37,8 @@ public class Rule [JsonConstructor] public Rule(IList operations, IList resources) { - this.Operations = operations; - this.Resources = resources; + this.Operations = operations ?? new List(); + this.Resources = resources ?? 
new List(); } [JsonProperty(PropertyName = "operations")] diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BridgeConfig.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BridgeConfig.cs index 9de67938480..8fa04a38af4 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BridgeConfig.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BridgeConfig.cs @@ -1,7 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config { + using System; using System.Collections.Generic; + using System.Linq; + using System.Runtime.Serialization; + using Newtonsoft.Json; + using Newtonsoft.Json.Converters; /// /// Domain object that represents Bridge configuration for MQTT Broker. @@ -9,7 +14,151 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config /// This object is being constructed from the EdgeHub twin's desired properties. /// See for DTO. /// - public class BridgeConfig : List + public class BridgeConfig : List, IEquatable { + public bool Equals(BridgeConfig other) + { + if (ReferenceEquals(null, other)) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return Enumerable.SequenceEqual(this, other); + } + + public override bool Equals(object obj) + => this.Equals(obj as BridgeConfig); + + public override int GetHashCode() + { + unchecked + { + int hash = 17; + hash = this.Aggregate(hash, (acc, item) => (acc * 31 + item.GetHashCode())); + return hash; + } + } + } + + public class Bridge : IEquatable + { + [JsonConstructor] + public Bridge(string endpoint, IList settings) + { + this.Endpoint = endpoint; + this.Settings = settings; + } + + [JsonProperty("endpoint", Required = Required.Always)] + public string Endpoint { get; } + + [JsonProperty("settings", Required = Required.Always)] + public IList Settings { get; } + + public bool Equals(Bridge other) + { + if (ReferenceEquals(null, other)) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return this.Endpoint.Equals(other.Endpoint) + && Enumerable.SequenceEqual(this.Settings, other.Settings); + } + + public override bool Equals(object obj) + => this.Equals(obj as Bridge); + + public override int GetHashCode() + { + unchecked + { + int hashCode = this.Endpoint?.GetHashCode() ?? 0; + hashCode = this.Settings.Aggregate(hashCode, (acc, item) => (acc * 31 + item.GetHashCode())); + return hashCode; + } + } + } + + public class Settings : IEquatable + { + [JsonConstructor] + public Settings( + Direction direction, + string topic, + string inPrefix, + string outPrefix) + { + this.Direction = direction; + this.Topic = topic; + this.InPrefix = inPrefix ?? string.Empty; + this.OutPrefix = outPrefix ?? 
string.Empty; + } + + [JsonProperty("direction", Required = Required.Always)] + public Direction Direction { get; } + + [JsonProperty("topic", Required = Required.Always)] + public string Topic { get; } + + [JsonProperty("inPrefix")] + public string InPrefix { get; } + + [JsonProperty("outPrefix")] + public string OutPrefix { get; } + + public bool Equals(Settings other) + { + if (ReferenceEquals(null, other)) + { + return false; + } + + if (ReferenceEquals(this, other)) + { + return true; + } + + return this.Direction.Equals(other.Direction) + && this.Topic.Equals(other.Topic) + && this.InPrefix.Equals(other.InPrefix) + && this.OutPrefix.Equals(other.OutPrefix); + } + + public override bool Equals(object obj) + => this.Equals(obj as Settings); + + public override int GetHashCode() + { + unchecked + { + int hashCode = this.Direction.GetHashCode(); + hashCode = (hashCode * 397) ^ (this.Topic?.GetHashCode() ?? 0); + hashCode = (hashCode * 397) ^ (this.InPrefix?.GetHashCode() ?? 0); + hashCode = (hashCode * 397) ^ (this.OutPrefix?.GetHashCode() ?? 0); + return hashCode; + } + } + } + + [JsonConverter(typeof(StringEnumConverter))] + public enum Direction + { + [EnumMember(Value = "in")] + In, + [EnumMember(Value = "out")] + Out, + [EnumMember(Value = "both")] + Both } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerConfigValidator.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerConfigValidator.cs index ba932a11abc..e9e51944ff1 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerConfigValidator.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerConfigValidator.cs @@ -44,12 +44,12 @@ public virtual IList ValidateAuthorizationConfig(AuthorizationProperties foreach (var rule in statement.Allow) { - ValidateRule(rule, order, errors); + ValidateRule(rule, order, errors, "Allow"); } foreach (var rule in statement.Deny) { - ValidateRule(rule, order, errors); + ValidateRule(rule, order, errors, "Deny"); } order++; @@ -58,23 +58,77 @@ public virtual IList ValidateAuthorizationConfig(AuthorizationProperties return errors; } - private static void ValidateRule(AuthorizationProperties.Rule rule, int order, List errors) + /// + /// Important!: Validation logic should be in sync with validation logic in the Broker. + /// + /// Validates bridge config and returns a list of errors (if any). 
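+        /// Errors are collected and returned rather than thrown, so that all problems can be reported at once.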
+ /// + public virtual IList ValidateBridgeConfig(BridgeConfig properties) + { + Preconditions.CheckNotNull(properties, nameof(properties)); + + var order = 0; + var errors = new List(); + foreach (var bridge in properties) + { + ValidateBridge(bridge, order, errors); + order++; + } + + return errors; + } + + static void ValidateBridge(Bridge bridge, int order, List errors) + { + if (string.IsNullOrEmpty(bridge.Endpoint)) + { + errors.Add($"Bridge {order}: Endpoint must not be empty"); + } + + if (bridge.Settings.Count == 0) + { + errors.Add($"Bridge {order}: Settings must not be empty"); + } + + foreach (var setting in bridge.Settings) + { + if (setting.Topic != null + && !IsValidTopicFilter(setting.Topic)) + { + errors.Add($"Bridge {order}: Topic is invalid: {setting.Topic}"); + } + + if (setting.InPrefix.Contains("+") + || setting.InPrefix.Contains("#")) + { + errors.Add($"Bridge {order}: InPrefix must not contain wildcards (+, #)"); + } + + if (setting.OutPrefix.Contains("+") + || setting.OutPrefix.Contains("#")) + { + errors.Add($"Bridge {order}: OutPrefix must not contain wildcards (+, #)"); + } + } + } + + static void ValidateRule(AuthorizationProperties.Rule rule, int order, List errors, string source) { if (rule.Operations.Count == 0) { - errors.Add($"Statement {order}: Allow: Operations list must not be empty"); + errors.Add($"Statement {order}: {source}: Operations list must not be empty"); } - if (rule.Resources.Count == 0) + if (rule.Resources.Count == 0 && !IsConnectOperation(rule)) { - errors.Add($"Statement {order}: Allow: Resources list must not be empty"); + errors.Add($"Statement {order}: {source}: Resources list must not be empty"); } foreach (var operation in rule.Operations) { if (!validOperations.Contains(operation)) { - errors.Add($"Statement {order}: Unknown mqtt operation: {operation}. List of supported operations: mqtt:publish, mqtt:subscribe, mqtt:connect"); + errors.Add($"Statement {order}: {source}: Unknown mqtt operation: {operation}. List of supported operations: mqtt:publish, mqtt:subscribe, mqtt:connect"); } ValidateVariables(operation, order, errors); @@ -85,13 +139,18 @@ private static void ValidateRule(AuthorizationProperties.Rule rule, int order, L if (string.IsNullOrEmpty(resource) || !IsValidTopicFilter(resource)) { - errors.Add($"Statement {order}: Resource (topic filter) is invalid: {resource}"); + errors.Add($"Statement {order}: {source}: Resource (topic filter) is invalid: {resource}"); } ValidateVariables(resource, order, errors); } } + private static bool IsConnectOperation(AuthorizationProperties.Rule rule) + { + return rule.Operations.Count == 1 && rule.Operations[0] == "mqtt:connect"; + } + static void ValidateVariables(string value, int order, List errors) { foreach (var variable in ExtractVariable(value)) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerProperties.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerProperties.cs index 6a11dafafc8..341465899fc 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerProperties.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/BrokerProperties.cs @@ -1,10 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config { - using System; - using System.Linq; - using Microsoft.Azure.Devices.Edge.Util; - using Microsoft.Azure.Devices.Edge.Util.Json; using Newtonsoft.Json; /// @@ -16,8 +12,8 @@ public class BrokerProperties [JsonConstructor] public BrokerProperties(BridgeConfig bridges, AuthorizationProperties authorizations) { - this.Bridges = bridges; - this.Authorizations = authorizations; + this.Bridges = bridges ?? new BridgeConfig(); + this.Authorizations = authorizations ?? new AuthorizationProperties(); } [JsonProperty(PropertyName = "bridges")] diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/EdgeHubConfigParser.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/EdgeHubConfigParser.cs index f4722f3b044..aef1aad9354 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/EdgeHubConfigParser.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/config/EdgeHubConfigParser.cs @@ -9,7 +9,8 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Core.Config using Microsoft.Extensions.Logging; /// - /// Creates EdgeHubConfig out of EdgeHubDesiredProperties. + /// Creates EdgeHubConfig out of EdgeHubDesiredProperties. Also validates the + /// desired properties. Throws an exception if validation failed. /// public class EdgeHubConfigParser { @@ -82,6 +83,11 @@ Option ParseBrokerConfig(BrokerProperties properties) /// Option ParseAuthorizationConfig(BrokerProperties properties) { + if (properties.Authorizations.Count == 0) + { + return Option.None(); + } + IList errors = this.validator.ValidateAuthorizationConfig(properties.Authorizations); if (errors.Count > 0) { @@ -118,7 +124,19 @@ Option ParseAuthorizationConfig(BrokerProperties properties Option ParseBridgeConfig(BrokerProperties properties) { - return Option.None(); + if (properties.Bridges.Count == 0) + { + return Option.None(); + } + + IList errors = this.validator.ValidateBridgeConfig(properties.Bridges); + if (errors.Count > 0) + { + string message = string.Join("; ", errors); + throw new InvalidOperationException($"Error validating bridge configuration: {message}"); + } + + return Option.Some(properties.Bridges); } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/DeviceMessageHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/DeviceMessageHandler.cs index f765a192f0a..175d7990975 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/DeviceMessageHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/DeviceMessageHandler.cs @@ -42,6 +42,8 @@ public DeviceMessageHandler(IIdentity identity, IEdgeHub edgeHub, IConnectionMan public IIdentity Identity { get; } + public bool IsDirectClient => this.underlyingProxy.IsDirectClient; + public Task ProcessMethodResponseAsync(IMessage message) { Preconditions.CheckNotNull(message, nameof(message)); diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/IDeviceProxy.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/IDeviceProxy.cs index 96346b50519..9d8a216e2f1 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/IDeviceProxy.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/device/IDeviceProxy.cs @@ -38,5 +38,7 @@ public interface IDeviceProxy void SetInactive(); Task> GetUpdatedIdentity(); + + bool IsDirectClient { get; } } } diff --git 
a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/AuthenticationType.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/AuthenticationType.cs index 41fe0b75aa9..8b4277f846f 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/AuthenticationType.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/AuthenticationType.cs @@ -7,6 +7,7 @@ public enum AuthenticationType SasKey, Token, X509Cert, - IoTEdged + IoTEdged, + Implicit } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/ImplicitCredentials.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/ImplicitCredentials.cs new file mode 100644 index 00000000000..4abbe49ea86 --- /dev/null +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Core/identity/ImplicitCredentials.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace Microsoft.Azure.Devices.Edge.Hub.Core.Identity +{ + using Microsoft.Azure.Devices.Edge.Util; + + public class ImplicitCredentials : IClientCredentials + { + public ImplicitCredentials(IIdentity identity, string productInfo, Option modelId) + { + this.Identity = identity; + this.AuthenticationType = AuthenticationType.Implicit; + this.ProductInfo = productInfo; + this.ModelId = modelId; + this.AuthChain = Option.None(); + } + + public IIdentity Identity { get; } + + public AuthenticationType AuthenticationType { get; } + + public string ProductInfo { get; } + + public Option ModelId { get; } + + public Option AuthChain { get; } + } +} diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/RegistryController.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/RegistryController.cs index 9221c36189c..fbd1aaf47ab 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/RegistryController.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/RegistryController.cs @@ -2,6 +2,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Http.Controllers { using System; + using System.Collections.Generic; using System.Net; using System.Text; using System.Threading.Tasks; @@ -148,6 +149,14 @@ public async Task ListModulesAsync( { Events.ReceivedRequest(nameof(this.ListModulesAsync), deviceId); + if (!this.HttpContext.Request.Query.ContainsKey("api-version")) + { + Dictionary headers = new Dictionary(); + headers.Add("iothub-errorcode", "InvalidProtocolVersion"); + await this.SendResponseAsync(HttpStatusCode.BadRequest, headers, string.Empty); + return; + } + try { deviceId = WebUtility.UrlDecode(Preconditions.CheckNonWhiteSpace(deviceId, nameof(deviceId))); @@ -514,9 +523,19 @@ async Task AuthenticateAsync(string deviceId, Option moduleId, Opt } async Task SendResponseAsync(HttpStatusCode status, string jsonContent = "") + { + await this.SendResponseAsync(status, new Dictionary(), jsonContent); + } + + async Task SendResponseAsync(HttpStatusCode status, Dictionary headers, string jsonContent = "") { this.Response.StatusCode = (int)status; + foreach (var header in headers) + { + this.Response.Headers.Add(header.Key, header.Value); + } + if (!string.IsNullOrEmpty(jsonContent)) { var resultUtf8Bytes = Encoding.UTF8.GetBytes(jsonContent); @@ -624,7 +643,7 @@ public static void CompleteRequest(string source, string deviceId, string authCh { Log.LogInformation( (int)EventIds.Authenticated, - $"CompleteRequest in {source}: deviceId={deviceId}, authChain={authChain} {Environment.NewLine} 
{result.StatusCode}:{result.JsonContent}"); + $"CompleteRequest in {source}: deviceId={deviceId}, authChain={authChain}, status={result.StatusCode}"); } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/DeviceProxy.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/DeviceProxy.cs index a8023de9a17..4a9b7bf5332 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/DeviceProxy.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/DeviceProxy.cs @@ -41,6 +41,8 @@ public DeviceProxy( public bool IsActive => this.isActive.Get(); + public bool IsDirectClient => true; + public Task CloseAsync(Exception ex) { if (this.isActive.GetAndSet(false)) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/DeviceProxy.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/DeviceProxy.cs index 1b81c020362..ff4191d1964 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/DeviceProxy.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/DeviceProxy.cs @@ -18,10 +18,11 @@ public class DeviceProxy : IDeviceProxy readonly IModuleToModuleMessageHandler moduleToModuleMessageHandler; readonly ICloud2DeviceMessageHandler cloud2DeviceMessageHandler; readonly IDirectMethodHandler directMethodHandler; - readonly bool isDirectClient; public delegate DeviceProxy Factory(IIdentity identity, bool isDirectClient); + public bool IsDirectClient { get; } + public DeviceProxy( IIdentity identity, bool isDirectClient, @@ -38,7 +39,17 @@ public DeviceProxy( this.cloud2DeviceMessageHandler = Preconditions.CheckNotNull(cloud2DeviceMessageHandler); this.directMethodHandler = Preconditions.CheckNotNull(directMethodHandler); this.isActive = new AtomicBoolean(true); - this.isDirectClient = isDirectClient; + this.IsDirectClient = isDirectClient; + + // when a child edge device connects, it uses $edgeHub identity. 
+ // Although it is a direct client, it uses the indirect topics + if (identity is ModuleIdentity moduleIdentity) + { + if (string.Equals(moduleIdentity.ModuleId, Constants.EdgeHubModuleId)) + { + this.IsDirectClient = false; + } + } Events.Created(this.Identity); } @@ -66,31 +77,31 @@ public Task> GetUpdatedIdentity() public Task InvokeMethodAsync(DirectMethodRequest request) { Events.SendingDirectMethod(this.Identity); - return this.directMethodHandler.CallDirectMethodAsync(request, this.Identity, this.isDirectClient); + return this.directMethodHandler.CallDirectMethodAsync(request, this.Identity, this.IsDirectClient); } public Task OnDesiredPropertyUpdates(IMessage desiredProperties) { Events.SendingDesiredPropertyUpdate(this.Identity); - return this.twinHandler.SendDesiredPropertiesUpdate(desiredProperties, this.Identity, this.isDirectClient); + return this.twinHandler.SendDesiredPropertiesUpdate(desiredProperties, this.Identity, this.IsDirectClient); } public Task SendC2DMessageAsync(IMessage message) { Events.SendingC2DMessage(this.Identity); - return this.cloud2DeviceMessageHandler.SendC2DMessageAsync(message, this.Identity, this.isDirectClient); + return this.cloud2DeviceMessageHandler.SendC2DMessageAsync(message, this.Identity, this.IsDirectClient); } public Task SendMessageAsync(IMessage message, string input) { Events.SendingModuleToModuleMessage(this.Identity); - return this.moduleToModuleMessageHandler.SendModuleToModuleMessageAsync(message, input, this.Identity, this.isDirectClient); + return this.moduleToModuleMessageHandler.SendModuleToModuleMessageAsync(message, input, this.Identity, this.IsDirectClient); } public Task SendTwinUpdate(IMessage twin) { Events.SendingTwinUpdate(this.Identity); - return this.twinHandler.SendTwinUpdate(twin, this.Identity, this.isDirectClient); + return this.twinHandler.SendTwinUpdate(twin, this.Identity, this.IsDirectClient); } public void SetInactive() diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/ModuleToModuleResponseTimeout.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/ModuleToModuleResponseTimeout.cs new file mode 100644 index 00000000000..0d27906422b --- /dev/null +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/ModuleToModuleResponseTimeout.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. 
+namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter
+{
+    using System;
+
+    // The role of this class is to help inject a timeout value without activators
+    public class ModuleToModuleResponseTimeout
+    {
+        TimeSpan timeout;
+
+        public ModuleToModuleResponseTimeout(TimeSpan timeout)
+        {
+            this.timeout = timeout;
+        }
+
+        public static implicit operator TimeSpan(ModuleToModuleResponseTimeout t) => t.timeout;
+    }
+}
diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/authentication/AuthAgentController.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/authentication/AuthAgentController.cs
index a3cb05fd91e..b153609bb1d 100644
--- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/authentication/AuthAgentController.cs
+++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/authentication/AuthAgentController.cs
@@ -149,7 +149,7 @@ async Task AuthenticateAsync(IClientCredentials credentials)
         static object GetAuthResult(bool isAuthenticated, Option<IClientCredentials> credentials)
         {
             // note, that if authenticated, then these values are present, and defaults never apply
-            var id = credentials.Map(c => c.Identity.Id).GetOrElse("anonymous");
+            var id = credentials.Map(c => $"{c.Identity.IotHubHostname}/{c.Identity.Id}").GetOrElse("anonymous");

             if (isAuthenticated)
             {
diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/IMqttBrokerConnector.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/IMqttBrokerConnector.cs
index bd470991714..5ddb682a427 100644
--- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/IMqttBrokerConnector.cs
+++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/IMqttBrokerConnector.cs
@@ -9,7 +9,7 @@ public interface IMqttBrokerConnector
         Task ConnectAsync(string serverAddress, int port);
         Task DisconnectAsync();
-        Task SendAsync(string topic, byte[] payload);
+        Task SendAsync(string topic, byte[] payload, bool retain = false);

         Task EnsureConnected { get; }
     }
diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/MqttBrokerConnector.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/MqttBrokerConnector.cs
index 527b945fa25..b8071e06074 100644
--- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/MqttBrokerConnector.cs
+++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/brokerConnection/MqttBrokerConnector.cs
@@ -166,7 +166,7 @@ await clientToStop.ForEachAsync(
             }
         }

-        public async Task SendAsync(string topic, byte[] payload)
+        public async Task SendAsync(string topic, byte[] payload, bool retain = false)
        {
            var client = this.mqttClient.Expect(() => new Exception("No mqtt-bridge connector instance found to send messages."));

@@ -177,7 +177,7 @@ public async Task SendAsync(string topic, byte[] payload)
            // put into the dictionary next line, causing the ACK being unknown. 
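            // i.e. the publish and the bookkeeping of its pending ACK must happen atomically, hence the lock below.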
lock (this.guard) { - var messageId = client.Publish(topic, payload, 1, false); + var messageId = client.Publish(topic, payload, 1, retain); added = this.pendingAcks.TryAdd(messageId, tcs); } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/PolicyUpdateHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/BrokerConfigUpdateHandler.cs similarity index 65% rename from edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/PolicyUpdateHandler.cs rename to edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/BrokerConfigUpdateHandler.cs index 07d6e9ace7b..d8720d0501d 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/PolicyUpdateHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/BrokerConfigUpdateHandler.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter { using System; @@ -12,18 +12,22 @@ namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter /// /// Responsible for listening for EdgeHub config updates (twin updates) - /// and pushing new authorization policy definition from EdgeHub config to the Mqtt Broker. + /// and pushing broker config (authorization policy definition and bridge config) + /// from EdgeHub to the Mqtt Broker. /// - public class PolicyUpdateHandler : IMessageProducer + public class BrokerConfigUpdateHandler : IMessageProducer { // !Important: please keep in sync with mqtt-edgehub::command::POLICY_UPDATE_TOPIC - const string Topic = "$internal/authorization/policy"; + const string PolicyUpdateTopic = "$internal/authorization/policy"; + + // !Important: please keep in sync with mqtt-edgehub::command::BRIDGE_UPDATE_TOPIC + const string BridgeUpdateTopic = "$internal/bridge/settings"; readonly Task configSource; IMqttBrokerConnector connector; - public PolicyUpdateHandler(Task configSource) + public BrokerConfigUpdateHandler(Task configSource) { this.configSource = configSource; } @@ -51,12 +55,16 @@ async Task ConfigUpdateHandler(EdgeHubConfig config) { try { - PolicyUpdate update = ConfigToPolicyUpdate(config); + PolicyUpdate policyUpdate = ConfigToPolicyUpdate(config); + BridgeConfig bridgeUpdate = ConfigToBridgeUpdate(config); + + Events.PublishPolicyUpdate(policyUpdate); - Events.PublishPolicyUpdate(update); + var policyPayload = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(policyUpdate)); + await this.connector.SendAsync(PolicyUpdateTopic, policyPayload, retain: true); - var payload = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(update)); - await this.connector.SendAsync(Topic, payload); + var bridgePayload = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(bridgeUpdate)); + await this.connector.SendAsync(BridgeUpdateTopic, bridgePayload, retain: true); } catch (Exception ex) { @@ -74,18 +82,30 @@ static PolicyUpdate ConfigToPolicyUpdate(EdgeHubConfig config) () => GetEmptyPolicy()); } + static BridgeConfig ConfigToBridgeUpdate(EdgeHubConfig config) + { + Option maybeBridgeConfig = config.BrokerConfiguration.FlatMap( + config => config.Bridges); + + return maybeBridgeConfig.Match( + config => config, + () => GetEmptyBridgeConfig()); + } + static PolicyUpdate GetEmptyPolicy() { - return new PolicyUpdate(@" - { - 'statements': [ ] - }"); + return new PolicyUpdate(@"{""statements"": [ ] }"); + } + + static BridgeConfig 
GetEmptyBridgeConfig() + { + return new BridgeConfig(); + } static class Events { const int IdStart = HubCoreEventIds.PolicyUpdateHandler; - static readonly ILogger Log = Logger.Factory.CreateLogger(); + static readonly ILogger Log = Logger.Factory.CreateLogger(); enum EventIds { @@ -95,7 +115,7 @@ enum EventIds internal static void PublishPolicyUpdate(PolicyUpdate update) { - Log.LogDebug((int)EventIds.PublishPolicy, $"Publishing ```{update.Definition}``` to mqtt broker on topic: {Topic}"); + Log.LogDebug((int)EventIds.PublishPolicy, $"Publishing ```{update.Definition}``` to mqtt broker on topic: {PolicyUpdateTopic}"); } internal static void ErrorUpdatingPolicy(Exception ex) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/Cloud2DeviceMessageHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/Cloud2DeviceMessageHandler.cs index 88b9afeae30..d06315dccf6 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/Cloud2DeviceMessageHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/Cloud2DeviceMessageHandler.cs @@ -29,6 +29,12 @@ public Cloud2DeviceMessageHandler(IConnectionRegistry connectionRegistry) public async Task SendC2DMessageAsync(IMessage message, IIdentity identity, bool isDirectClient) { + if (!message.SystemProperties.TryGetValue(SystemProperties.LockToken, out var lockToken)) + { + Events.NoLockToken(identity.Id); + throw new Exception("Cannot send C2D message without lock token"); + } + bool result; try { @@ -51,7 +57,7 @@ public async Task SendC2DMessageAsync(IMessage message, IIdentity identity, bool // TODO: confirming back the message based on the fact that the MQTT broker ACK-ed it. It doesn't mean that the // C2D message has been delivered.
Going forward it is a broker responsibility to deliver the message, however if // it crashes, the message will be lost - await this.ConfirmMessageAsync(message, identity); + await this.ConfirmMessageAsync(lockToken, identity); } else { @@ -89,6 +95,7 @@ enum EventIds CouldToDeviceMessageFailed, BadIdentityFormat, CannotSendC2DToModule, + NoLockToken } public static void BadPayloadFormat(Exception e) => Log.LogError((int)EventIds.BadPayloadFormat, e, "Bad payload format: cannot deserialize subscription update"); @@ -97,6 +104,7 @@ enum EventIds public static void CouldToDeviceMessageFailed(string id, int messageLen) => Log.LogError((int)EventIds.CouldToDeviceMessageFailed, $"Failed to send Cloud to Device message to client: {id}, msg len: {messageLen}"); public static void BadIdentityFormat(string identity) => Log.LogError((int)EventIds.BadIdentityFormat, $"Bad identity format: {identity}"); public static void CannotSendC2DToModule(string id) => Log.LogError((int)EventIds.CannotSendC2DToModule, $"Cannot send C2D message to module {id}"); + public static void NoLockToken(string identity) => Log.LogError((int)EventIds.NoLockToken, $"Cannot send C2D message for {identity} because it does not have lock token in its system properties"); } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ConnectionHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ConnectionHandler.cs index b295c502698..d1e2e363ce4 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ConnectionHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ConnectionHandler.cs @@ -23,6 +23,7 @@ public class ConnectionHandler : IConnectionRegistry, IMessageConsumer, IMessage static readonly string[] subscriptions = new[] { TopicDeviceConnected }; readonly Task connectionProviderGetter; + readonly Task authenticatorGetter; readonly IIdentityProvider identityProvider; readonly ISystemComponentIdProvider systemComponentIdProvider; readonly DeviceProxy.Factory deviceProxyFactory; @@ -37,9 +38,10 @@ public class ConnectionHandler : IConnectionRegistry, IMessageConsumer, IMessage // this class is auto-registered so no way to implement an async activator. 
// hence this one needs to get a Task which is suboptimal, but that is the way // IConnectionProvider is registered - public ConnectionHandler(Task connectionProviderGetter, IIdentityProvider identityProvider, ISystemComponentIdProvider systemComponentIdProvider, DeviceProxy.Factory deviceProxyFactory) + public ConnectionHandler(Task connectionProviderGetter, Task authenticatorGetter, IIdentityProvider identityProvider, ISystemComponentIdProvider systemComponentIdProvider, DeviceProxy.Factory deviceProxyFactory) { this.connectionProviderGetter = Preconditions.CheckNotNull(connectionProviderGetter); + this.authenticatorGetter = Preconditions.CheckNotNull(authenticatorGetter); this.identityProvider = Preconditions.CheckNotNull(identityProvider); this.systemComponentIdProvider = Preconditions.CheckNotNull(systemComponentIdProvider); this.deviceProxyFactory = Preconditions.CheckNotNull(deviceProxyFactory); @@ -148,6 +150,20 @@ async Task RemoveConnectionsAsync(HashSet identitiesRemoved) { foreach (var identity in identitiesRemoved) { + if (this.knownConnections.TryGetValue(identity, out IDeviceListener container)) + { + if (container is IDeviceProxy proxy) + { + // Clients connected indirectly (through a child edge device) will not be reported + // by broker events and appear in the 'identitiesRemoved' list as missing identities. + // Ignore those: + if (!proxy.IsDirectClient) + { + continue; + } + } + } + if (this.knownConnections.TryRemove(identity, out var deviceListener)) { await deviceListener.CloseAsync(); @@ -263,6 +279,13 @@ async Task> CreateDeviceListenerAsync(IIdentity identity return Option.None(); } + if (!directOnCreation) + { + var clientCredentials = new ImplicitCredentials(identity, string.Empty, Option.None()); // TODO obtain prod info/model id + var authenticator = await this.authenticatorGetter; + await authenticator.AuthenticateAsync(clientCredentials); + } + var deviceListener = await this.AddConnectionAsync(identity, directOnCreation, connectionProvider); return Option.Some(deviceListener); } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/MessageConfirmingHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/MessageConfirmingHandler.cs index 153925bd93b..1f2f38f114c 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/MessageConfirmingHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/MessageConfirmingHandler.cs @@ -17,7 +17,7 @@ public class MessageConfirmingHandler public MessageConfirmingHandler(IConnectionRegistry connectionRegistry) => this.connectionRegistry = connectionRegistry; - protected async Task ConfirmMessageAsync(IMessage message, IIdentity identity) + protected async Task ConfirmMessageAsync(string lockToken, IIdentity identity) { var listener = default(IDeviceListener); try @@ -30,10 +30,8 @@ protected async Task ConfirmMessageAsync(IMessage message, IIdentity identity) return; } - var lockToken = "Unknown"; try { - lockToken = message.SystemProperties[SystemProperties.LockToken]; await listener.ProcessMessageFeedbackAsync(lockToken, FeedbackStatus.Complete); } catch (Exception ex) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ModuleToModuleMessageHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ModuleToModuleMessageHandler.cs index 4e0194f507f..8213bc75ca7 100644 --- 
a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ModuleToModuleMessageHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ModuleToModuleMessageHandler.cs @@ -2,7 +2,12 @@ namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter { using System; + using System.Collections.Concurrent; using System.Collections.Generic; + using System.Linq; + using System.Text; + using System.Text.RegularExpressions; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; @@ -10,32 +15,103 @@ namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; - public class ModuleToModuleMessageHandler : MessageConfirmingHandler, IModuleToModuleMessageHandler, IMessageProducer + public class ModuleToModuleMessageHandler : MessageConfirmingHandler, IModuleToModuleMessageHandler, IMessageProducer, IMessageConsumer, IDisposable { - const string ModuleToModleSubscriptionPattern = @"^((?(\$edgehub)|(\$iothub)))/(?[^/\+\#]+)/(?[^/\+\#]+)/inputs/\#$"; - const string ModuleToModleTopicTemplate = @"{0}/{1}/{2}/inputs/{3}/{4}"; + const string MessageDelivered = "$edgehub/delivered"; + const string MessageDeliveredSubscription = MessageDelivered + "/#"; + const string ModuleToModleSubscriptionPattern = @"^((?(\$edgehub)|(\$iothub)))/(?[^/\+\#]+)/(?[^/\+\#]+)/\+/inputs/\#$"; + const string FeedbackMessagePattern = @"^\""\$edgehub/(?[^/\+\#]+)/(?[^/\+\#]+)/(?[^/\+\#]+)/inputs/"; + const string ModuleToModleTopicTemplate = @"{0}/{1}/{2}/{3}/inputs/{4}/{5}"; static readonly SubscriptionPattern[] subscriptionPatterns = new SubscriptionPattern[] { new SubscriptionPattern(ModuleToModleSubscriptionPattern, DeviceSubscription.ModuleMessages) }; + readonly Timer timer; + readonly TimeSpan tokenCleanupPeriod; + IMqttBrokerConnector connector; + IIdentityProvider identityProvider; + ConcurrentDictionary pendingMessages = new ConcurrentDictionary(); - public ModuleToModuleMessageHandler(IConnectionRegistry connectionRegistry) + public ModuleToModuleMessageHandler(IConnectionRegistry connectionRegistry, IIdentityProvider identityProvider, ModuleToModuleResponseTimeout responseTimeout) : base(connectionRegistry) { + this.identityProvider = Preconditions.CheckNotNull(identityProvider); + this.tokenCleanupPeriod = responseTimeout; + this.timer = new Timer(this.CleanTokens, null, responseTimeout, responseTimeout); } public void SetConnector(IMqttBrokerConnector connector) => this.connector = connector; public IReadOnlyCollection WatchedSubscriptions => subscriptionPatterns; + public IReadOnlyCollection Subscriptions => new[] { MessageDeliveredSubscription }; + + public async Task HandleAsync(MqttPublishInfo publishInfo) + { + if (publishInfo.Topic.Equals(MessageDelivered)) + { + try + { + var originalTopic = Encoding.UTF8.GetString(publishInfo.Payload); + var match = Regex.Match(originalTopic, FeedbackMessagePattern); + if (match.Success) + { + var id1 = match.Groups["id1"]; + var id2 = match.Groups["id2"]; + var lockToken = match.Groups["token"].Value; + + var identity = this.identityProvider.Create(id1.Value, id2.Value); + + if (this.pendingMessages.TryRemove(lockToken, out var _)) + { + await this.ConfirmMessageAsync(lockToken, identity); + } + else + { + Events.CannotFindMessageToConfirm(identity.Id); + } + } + } + catch (Exception ex) + { + Events.CannotDecodeConfirmation(ex); + } + + return 
true; + } + else + { + return false; + } + } + public async Task SendModuleToModuleMessageAsync(IMessage message, string input, IIdentity identity, bool isDirectClient) { + if (!message.SystemProperties.TryGetValue(SystemProperties.LockToken, out var currentLockToken)) + { + Events.NoLockToken(identity.Id); + throw new ArgumentException("Cannot send M2M message without lock token"); + } + bool result; try { + var currentTime = DateTime.UtcNow; + var overwrittenLockTokenDate = Option.None(); + this.pendingMessages.AddOrUpdate( + currentLockToken, + currentTime, + (i, t) => + { + overwrittenLockTokenDate = Option.Some(t); + return currentTime; + }); + + overwrittenLockTokenDate.ForEach(t => Events.OverwritingPendingMessage(identity.Id, currentLockToken, t)); + var topicPrefix = isDirectClient ? MqttBrokerAdapterConstants.DirectTopicPrefix : MqttBrokerAdapterConstants.IndirectTopicPrefix; var propertyBag = GetPropertyBag(message); result = await this.connector.SendAsync( - GetMessageToMessageTopic(identity, input, propertyBag, topicPrefix), + GetMessageToMessageTopic(identity, input, propertyBag, topicPrefix, currentLockToken), message.Body); } catch (Exception e) @@ -47,32 +123,51 @@ public async Task SendModuleToModuleMessageAsync(IMessage message, string input, if (result) { Events.ModuleToModuleMessage(identity.Id, message.Body.Length); - - // TODO: confirming back the message based on the fact that the MQTT broker ACK-ed it. It doesn't mean that the - // M2M message has been delivered. Going forward it is a broker responsibility to deliver the message, however if - // it crashes, the message will be lost - await this.ConfirmMessageAsync(message, identity); } else { + this.pendingMessages.TryRemove(currentLockToken, out var _); Events.ModuleToModuleMessageFailed(identity.Id, message.Body.Length); } } - static string GetMessageToMessageTopic(IIdentity identity, string input, string propertyBag, string topicPrefix) + public void Dispose() + { + this.timer.Dispose(); + } + + void CleanTokens(object _) + { + var now = DateTime.UtcNow; + var keys = this.pendingMessages.Keys.ToArray(); + + foreach (var key in keys) + { + if (this.pendingMessages.TryGetValue(key, out DateTime issued)) + { + if (now - issued > this.tokenCleanupPeriod) + { + Events.RemovingExpiredToken(key); + this.pendingMessages.TryRemove(key, out var _); + } + } + } + } + + static string GetMessageToMessageTopic(IIdentity identity, string input, string propertyBag, string topicPrefix, string lockToken) { switch (identity) { case IDeviceIdentity deviceIdentity: Events.CannotSendM2MToDevice(identity.Id); - throw new Exception($"Cannot send Module To Module message to {identity.Id}, because it is not a module but a device"); + throw new ArgumentException($"Cannot send Module To Module message to {identity.Id}, because it is not a module but a device"); case IModuleIdentity moduleIdentity: - return string.Format(ModuleToModleTopicTemplate, topicPrefix, moduleIdentity.DeviceId, moduleIdentity.ModuleId, input, propertyBag); + return string.Format(ModuleToModleTopicTemplate, topicPrefix, moduleIdentity.DeviceId, moduleIdentity.ModuleId, lockToken, input, propertyBag); default: Events.BadIdentityFormat(identity.Id); - throw new Exception($"cannot decode identity {identity.Id}"); + throw new ArgumentException($"cannot decode identity {identity.Id}"); } } @@ -89,6 +184,11 @@ enum EventIds ModuleToModuleMessageFailed, BadIdentityFormat, CannotSendM2MToDevice, + CannotDecodeConfirmation, + OverwritingPendingMessage, + 
CannotFindMessageToConfirm, + NoLockToken, + RemovingExpiredToken } public static void BadPayloadFormat(Exception e) => Log.LogError((int)EventIds.BadPayloadFormat, e, "Bad payload format: cannot deserialize subscription update"); @@ -97,6 +197,11 @@ enum EventIds public static void ModuleToModuleMessageFailed(string id, int messageLen) => Log.LogError((int)EventIds.ModuleToModuleMessageFailed, $"Failed to send Module to Module message to client: {id}, msg len: {messageLen}"); public static void BadIdentityFormat(string identity) => Log.LogError((int)EventIds.BadIdentityFormat, $"Bad identity format: {identity}"); public static void CannotSendM2MToDevice(string id) => Log.LogError((int)EventIds.CannotSendM2MToDevice, $"Cannot send Module to Module message to device {id}"); + public static void CannotDecodeConfirmation(Exception e) => Log.LogError((int)EventIds.CannotDecodeConfirmation, e, $"Cannot decode Module to Module message confirmation"); + public static void OverwritingPendingMessage(string identity, string messageId, DateTime time) => Log.LogWarning((int)EventIds.OverwritingPendingMessage, $"New M2M message is being sent for {identity} with msg id {messageId}, but it has been sent already with the same id at {time}"); + public static void CannotFindMessageToConfirm(string identity) => Log.LogWarning((int)EventIds.CannotFindMessageToConfirm, $"M2M confirmation has been received for {identity} but no message can be found"); + public static void NoLockToken(string identity) => Log.LogError((int)EventIds.NoLockToken, $"Cannot send M2M message for {identity} because it does not have lock token in its system properties"); + public static void RemovingExpiredToken(string token) => Log.LogWarning((int)EventIds.RemovingExpiredToken, $"M2M confirmation has not been received for token {token}, removing"); } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ScopeIdentitiesHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ScopeIdentitiesHandler.cs index 2fdc85768ef..45c7f50e5e2 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ScopeIdentitiesHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/ScopeIdentitiesHandler.cs @@ -88,7 +88,7 @@ async Task PublishBrokerServiceIdentities(IList brokerSer try { var payload = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(brokerServiceIdentities)); - await this.connector.SendAsync(Topic, payload); + await this.connector.SendAsync(Topic, payload, retain: true); } catch (Exception ex) { diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/SubscriptionChangeHandler.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/SubscriptionChangeHandler.cs index f797b851a1f..29166eceec6 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/SubscriptionChangeHandler.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/handlers/SubscriptionChangeHandler.cs @@ -120,8 +120,6 @@ public async Task HandleAsync(MqttPublishInfo publishInfo) { Events.CouldNotObtainListener(subscriptionPattern.Subscrition.ToString(), identity.Id); } - - break; } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudConnectionProvider.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudConnectionProvider.cs
index 779b25e465d..50ddcf43687 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudConnectionProvider.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudConnectionProvider.cs @@ -2,8 +2,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter { using System; - using System.Collections.Generic; - using System.Text; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Cloud; @@ -31,8 +29,7 @@ public Task> Connect(IClientCredentials clientCredentials, public Task> Connect(IIdentity identity, Action connectionStatusChangedHandler) { - // TODO: the connectionStatusChangeHandler is not wired - var cloudProxy = new BrokeredCloudProxy(identity, this.cloudProxyDispatcher); + var cloudProxy = new BrokeredCloudProxy(identity, this.cloudProxyDispatcher, connectionStatusChangedHandler); return Task.FromResult(new Try(new BrokeredCloudConnection(cloudProxy))); } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxy.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxy.cs index 1718eb23856..ced11ca99c2 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxy.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxy.cs @@ -15,14 +15,18 @@ public class BrokeredCloudProxy : ICloudProxy { BrokeredCloudProxyDispatcher cloudProxyDispatcher; IIdentity identity; + Action connectionStatusChangedHandler; AtomicBoolean isActive = new AtomicBoolean(true); AtomicBoolean twinNeedsSubscribe = new AtomicBoolean(true); - public BrokeredCloudProxy(IIdentity identity, BrokeredCloudProxyDispatcher cloudProxyDispatcher) + public BrokeredCloudProxy(IIdentity identity, BrokeredCloudProxyDispatcher cloudProxyDispatcher, Action connectionStatusChangedHandler) { this.identity = Preconditions.CheckNotNull(identity); this.cloudProxyDispatcher = Preconditions.CheckNotNull(cloudProxyDispatcher); + + this.connectionStatusChangedHandler = connectionStatusChangedHandler; + this.cloudProxyDispatcher.ConnectionStatusChangedEvent += this.ConnectionChangedEventHandler; } public bool IsActive => this.isActive; @@ -30,6 +34,8 @@ public BrokeredCloudProxy(IIdentity identity, BrokeredCloudProxyDispatcher cloud public Task CloseAsync() { this.isActive.Set(false); + this.cloudProxyDispatcher.ConnectionStatusChangedEvent -= this.ConnectionChangedEventHandler; + return Task.FromResult(true); } @@ -52,5 +58,10 @@ public Task GetTwinAsync() { return this.cloudProxyDispatcher.GetTwinAsync(this.identity, this.twinNeedsSubscribe.GetAndSet(false)); } + + void ConnectionChangedEventHandler(CloudConnectionStatus cloudConnectionStatus) + { + this.connectionStatusChangedHandler(this.identity.Id, cloudConnectionStatus); + } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxyDispatcher.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxyDispatcher.cs index 531815ae0cb..a679d42b112 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxyDispatcher.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter/upstream/BrokeredCloudProxyDispatcher.cs @@ -4,12 +4,15 @@ namespace 
Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter using System; using System.Collections.Concurrent; using System.Collections.Generic; + using System.Dynamic; using System.IO; + using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Devices.Client.Common; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Cloud; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Concurrency; @@ -26,17 +29,20 @@ public class BrokeredCloudProxyDispatcher : IMessageConsumer, IMessageProducer const string GetTwinTemplate = "$iothub/{0}/twin/get/?$rid={1}"; const string UpdateReportedTemplate = "$iothub/{0}/twin/reported/?$rid={1}"; const string DirectMethodResponseTemplate = "$iothub/{0}/methods/res/{1}/?$rid={2}"; - const string RpcTopicTemplate = "$edgehub/rpc/{0}"; + const string RpcTopicTemplate = "$upstream/rpc/{0}"; const string RpcVersion = "v1"; const string RpcCmdSub = "sub"; const string RpcCmdUnsub = "unsub"; const string RpcCmdPub = "pub"; - const string RpcAckPattern = @"^\$downstream/rpc/ack/(?[^/\+\#]+)/(?[^/\+\#]+)"; + const string RpcAckPattern = @"^\$downstream/rpc/ack/(?[^/\+\#]+)"; const string TwinGetResponsePattern = @"^\$downstream/(?[^/\+\#]+)(/(?[^/\+\#]+))?/twin/res/(?.+)/\?\$rid=(?.+)"; const string TwinSubscriptionForPatchPattern = @"^\$downstream/(?[^/\+\#]+)(/(?[^/\+\#]+))?/twin/desired/\?\$version=(?.+)"; - const string MethodCallPattern = @"^\$downstream/(?[^/\+\#]+)/methods/post/(?[^/\+\#]+)/\?\$rid=(?.+)"; + const string MethodCallPattern = @"^\$downstream/(?[^/\+\#]+)(/(?[^/\+\#]+))?/methods/post/(?[^/\+\#]+)/\?\$rid=(?.+)"; + + const string DownstreamTopic = "$downstream/#"; + const string ConnectivityTopic = "$internal/connectivity"; readonly TimeSpan responseTimeout = TimeSpan.FromSeconds(30); // TODO should come from configuration readonly byte[] emptyArray = new byte[0]; @@ -49,7 +55,9 @@ public class BrokeredCloudProxyDispatcher : IMessageConsumer, IMessageProducer ConcurrentDictionary> pendingRpcs = new ConcurrentDictionary>(); ConcurrentDictionary> pendingTwinRequests = new ConcurrentDictionary>(); - public IReadOnlyCollection Subscriptions => new string[] { "$downstream/#" }; + public event Action ConnectionStatusChangedEvent; + + public IReadOnlyCollection Subscriptions => new string[] { DownstreamTopic, ConnectivityTopic }; public void BindEdgeHub(IEdgeHub edgeHub) { @@ -65,7 +73,7 @@ public Task HandleAsync(MqttPublishInfo publishInfo) var match = Regex.Match(publishInfo.Topic, RpcAckPattern); if (match.Success) { - this.HandleRpcAck(match.Groups["guid"].Value, match.Groups["cmd"].Value); + this.HandleRpcAck(match.Groups["guid"].Value); return Task.FromResult(true); } @@ -89,6 +97,12 @@ public Task HandleAsync(MqttPublishInfo publishInfo) this.HandleDirectMethodCall(this.GetIdFromMatch(match), match.Groups["mname"].Value, match.Groups["rid"].Value, publishInfo.Payload); return Task.FromResult(true); } + + if (ConnectivityTopic.Equals(publishInfo.Topic)) + { + this.HandleConnectivityEvent(publishInfo.Payload); + return Task.FromResult(true); + } } catch (Exception e) { @@ -210,35 +224,21 @@ public async Task GetTwinAsync(IIdentity identity, bool needSubscribe) return taskCompletion.Task.Result; } - void HandleRpcAck(string ackedGuid, string cmd) + void HandleRpcAck(string ackedGuid) { - switch (cmd) + if (!Guid.TryParse(ackedGuid, out var guid)) { - case RpcCmdPub: 
- case RpcCmdSub: - case RpcCmdUnsub: - { - if (!Guid.TryParse(ackedGuid, out var guid)) - { - Events.CannotParseGuid(ackedGuid); - return; - } - - if (this.pendingRpcs.TryRemove(guid, out var tsc)) - { - tsc.SetResult(true); - } - else - { - Events.CannotFindGuid(ackedGuid); - } - } - - break; + Events.CannotParseGuid(ackedGuid); + return; + } - default: - Events.UnknownAckType(cmd, ackedGuid); - break; + if (this.pendingRpcs.TryRemove(guid, out var tsc)) + { + tsc.SetResult(true); + } + else + { + Events.CannotFindGuid(ackedGuid); } } @@ -403,6 +403,59 @@ byte[] GetRpcPayload(string command, string topic, byte[] payload) return stream.ToArray(); } + void HandleConnectivityEvent(byte[] payload) + { + try + { + var payloadAsString = Encoding.UTF8.GetString(payload); + var connectivityEvent = JsonConvert.DeserializeObject(payloadAsString) as IDictionary; + + var status = default(object); + if (connectivityEvent.TryGetValue("status", out status)) + { + var statusAsString = status as string; + if (statusAsString != null) + { + switch (statusAsString) + { + case "Connected": + this.CallConnectivityHandlers(true); + break; + + case "Disconnected": + this.CallConnectivityHandlers(false); + break; + + default: + break; + } + } + } + } + catch (Exception ex) + { + Events.ErrorParsingConnectivityEvent(ex); + } + } + + void CallConnectivityHandlers(bool isConnected) + { + var currentHandlers = this.ConnectionStatusChangedEvent.GetInvocationList(); + + foreach (var handler in currentHandlers) + { + try + { + handler.DynamicInvoke(isConnected ? CloudConnectionStatus.ConnectionEstablished : CloudConnectionStatus.Disconnected); + } + catch (Exception ex) + { + // ignore and go on + Events.ErrorDispatchingConnectivityEvent(ex); + } + } + } + long GetRid() => Interlocked.Increment(ref this.lastRid); static string GetPropertyBag(IMessage message) @@ -453,14 +506,15 @@ enum EventIds ErrorHandlingDownstreamMessage, CannotParseGuid, CannotFindGuid, - UnknownAckType, SendingOpenNotificationForClient, SentOpenNotificationForClient, ErrorSendingOpenNotificationForClient, SendingCloseNotificationForClient, SentCloseNotificationForClient, - ErrorSendingCloseNotificationForClient + ErrorSendingCloseNotificationForClient, + ErrorParsingConnectivityEvent, + ErrorDispatchingConnectivityEvent } public static void GettingTwin(string id, long rid) => Log.LogDebug((int)EventIds.GettingTwin, $"Getting twin for client: {id} with request id: {rid}"); @@ -492,6 +546,7 @@ enum EventIds public static void ErrorHandlingDownstreamMessage(string topic, Exception e) => Log.LogError((int)EventIds.ErrorHandlingDownstreamMessage, e, $"Error handling downstream message on topic: {topic}"); public static void CannotParseGuid(string guid) => Log.LogError((int)EventIds.CannotParseGuid, $"Cannot parse guid: {guid}"); public static void CannotFindGuid(string guid) => Log.LogError((int)EventIds.CannotFindGuid, $"Cannot find guid to ACK: {guid}"); - public static void UnknownAckType(string cmd, string guid) => Log.LogError((int)EventIds.UnknownAckType, $"Unknown ack type: {cmd} with guid: {guid}"); + public static void ErrorParsingConnectivityEvent(Exception ex) => Log.LogError((int)EventIds.ErrorParsingConnectivityEvent, ex, "Error parsing connectivity event"); + public static void ErrorDispatchingConnectivityEvent(Exception ex) => Log.LogError((int)EventIds.ErrorDispatchingConnectivityEvent, ex, "Error dispatching connectivity event"); } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/DependencyManager.cs 
b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/DependencyManager.cs index 18199557090..44b1cd7821d 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/DependencyManager.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/DependencyManager.cs @@ -209,7 +209,11 @@ void RegisterRoutingModule( TimeSpan configUpdateFrequency = TimeSpan.FromSeconds(configUpdateFrequencySecs); bool checkEntireQueueOnCleanup = this.configuration.GetValue("CheckEntireQueueOnCleanup", false); bool closeCloudConnectionOnDeviceDisconnect = this.configuration.GetValue("CloseCloudConnectionOnDeviceDisconnect", true); - bool isLegacyUpstream = this.configuration.GetValue("mqttBrokerSettings:legacyUpstream", true); + + bool isLegacyUpstream = !experimentalFeatures.Enabled + || !experimentalFeatures.EnableMqttBroker + || !experimentalFeatures.EnableNestedEdge + || !this.GetConfigurationValueIfExists(Constants.ConfigKey.GatewayHostname).HasValue; builder.RegisterModule( new RoutingModule( diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/Program.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/Program.cs index 8735fae1e58..20b96ec6818 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/Program.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/Program.cs @@ -97,7 +97,10 @@ static async Task MainAsync(IConfigurationRoot configuration) logger.LogInformation("Initializing configuration"); IConfigSource configSource = await container.Resolve>(); ConfigUpdater configUpdater = await container.Resolve>(); - await configUpdater.Init(configSource); + var configUpdaterStartupFailed = new TaskCompletionSource(); + _ = configUpdater.Init(configSource).ContinueWith( + _ => configUpdaterStartupFailed.SetResult(false), + TaskContinuationOptions.OnlyOnFaulted); if (!Enum.TryParse(configuration.GetValue("AuthenticationMode", string.Empty), true, out AuthenticationMode authenticationMode) || authenticationMode != AuthenticationMode.Cloud) @@ -115,7 +118,7 @@ static async Task MainAsync(IConfigurationRoot configuration) try { await protocolHead.StartAsync(); - await Task.WhenAny(cts.Token.WhenCanceled(), renewal.Token.WhenCanceled()); + await Task.WhenAny(cts.Token.WhenCanceled(), renewal.Token.WhenCanceled(), configUpdaterStartupFailed.Task); } catch (Exception ex) { diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/appsettings_hub.json b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/appsettings_hub.json index 8529398fe17..4643745cb7a 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/appsettings_hub.json +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/appsettings_hub.json @@ -21,8 +21,6 @@ "TableQos2StatePersistenceProvider.StorageTableName": "mqttqos2" }, "mqttBrokerSettings": { - "legacyUpstream": true, - "enabled": true, "port": 1882, "url": "127.0.0.1" }, diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/MqttBrokerModule.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/MqttBrokerModule.cs index b905d7168d0..b358d4e8572 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/MqttBrokerModule.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/MqttBrokerModule.cs @@ -74,6 +74,13 @@ protected override void Load(ContainerBuilder builder) .As() .SingleInstance(); + // The purpose of this setting is to set up a cleanup timer that throws away unanswered message tokens to + prevent a memory leak. The multiplier is big enough to avoid deleting tokens still in use, but also + small enough not to spin the clean cycle too much, even if the timeout value is short + var ackTimeout = Math.Max(this.config.GetValue("MessageAckTimeoutSecs", 30), 30); + builder.RegisterInstance(new ModuleToModuleResponseTimeout(TimeSpan.FromSeconds(ackTimeout * 10))) + .SingleInstance(); + base.Load(builder); } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/BrokerPropertiesValidatorTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/BrokerPropertiesValidatorTest.cs index 9a3acb0d346..87a90f35293 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/BrokerPropertiesValidatorTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/BrokerPropertiesValidatorTest.cs @@ -18,7 +18,7 @@ public void ValidateAuthorizationConfig_ValidInput() var authzProperties = properties.BrokerConfiguration.Authorizations; - IList errors = validator.ValidateAuthorizationConfig(properties.BrokerConfiguration.Authorizations); + IList errors = validator.ValidateAuthorizationConfig(authzProperties); Assert.Equal(0, errors.Count); } @@ -34,14 +34,17 @@ public void ValidateAuthorizationConfig_EmptyElements() // arrange some errors authzProperties[0].Identities[0] = string.Empty; - authzProperties[1].Allow[0].Operations.RemoveAt(0); - authzProperties[1].Allow[0].Operations.RemoveAt(0); + authzProperties[0].Deny[0].Resources.Clear(); + authzProperties[1].Identities.Clear(); + authzProperties[1].Allow[0].Operations.Clear(); - IList errors = validator.ValidateAuthorizationConfig(properties.BrokerConfiguration.Authorizations); + IList errors = validator.ValidateAuthorizationConfig(authzProperties); - Assert.Equal(2, errors.Count); + Assert.Equal(4, errors.Count); Assert.Equal("Statement 0: Identity name is invalid: ", errors[0]); - Assert.Equal("Statement 1: Allow: Operations list must not be empty", errors[1]); + Assert.Equal("Statement 0: Deny: Resources list must not be empty", errors[1]); + Assert.Equal("Statement 1: Identities list must not be empty", errors[2]); + Assert.Equal("Statement 1: Allow: Operations list must not be empty", errors[3]); } [Fact] @@ -56,11 +59,11 @@ public void ValidateAuthorizationConfig_InvalidOperation() // arrange some errors authzProperties[0].Deny[0].Operations[0] = "invalid"; - IList errors = validator.ValidateAuthorizationConfig(properties.BrokerConfiguration.Authorizations); + IList errors = validator.ValidateAuthorizationConfig(authzProperties); Assert.Equal(1, errors.Count); Assert.Equal( - "Statement 0: Unknown mqtt operation: invalid. " + "Statement 0: Deny: Unknown mqtt operation: invalid. 
" + "List of supported operations: mqtt:publish, mqtt:subscribe, mqtt:connect", errors[0]); } @@ -78,11 +81,11 @@ public void ValidateAuthorizationConfig_InvalidTopicFilters() authzProperties[0].Deny[0].Resources[0] = "topic/#/"; authzProperties[1].Allow[0].Resources[0] = "topic+"; - IList errors = validator.ValidateAuthorizationConfig(properties.BrokerConfiguration.Authorizations); + IList errors = validator.ValidateAuthorizationConfig(authzProperties); Assert.Equal(2, errors.Count); - Assert.Equal("Statement 0: Resource (topic filter) is invalid: topic/#/", errors[0]); - Assert.Equal("Statement 1: Resource (topic filter) is invalid: topic+", errors[1]); + Assert.Equal("Statement 0: Deny: Resource (topic filter) is invalid: topic/#/", errors[0]); + Assert.Equal("Statement 1: Allow: Resource (topic filter) is invalid: topic+", errors[1]); } [Fact] @@ -98,12 +101,85 @@ public void ValidateAuthorizationConfig_InvalidVariableNames() authzProperties[0].Identities[0] = "{{anywhat}}"; authzProperties[1].Allow[0].Resources[0] = "topic/{{invalid}}/{{myothervar}}"; - IList errors = validator.ValidateAuthorizationConfig(properties.BrokerConfiguration.Authorizations); + IList errors = validator.ValidateAuthorizationConfig(authzProperties); Assert.Equal(3, errors.Count); Assert.Equal("Statement 0: Invalid variable name: {{anywhat}}", errors[0]); Assert.Equal("Statement 1: Invalid variable name: {{invalid}}", errors[1]); Assert.Equal("Statement 1: Invalid variable name: {{myothervar}}", errors[2]); } + + [Fact] + public void ValidateAuthorizationConfig_EmptyResourceAllowedForConnectOperation() + { + var validator = new BrokerPropertiesValidator(); + + EdgeHubDesiredProperties properties = ConfigTestData.GetTestData(); + + var authzProperties = properties.BrokerConfiguration.Authorizations; + + // arrange connect op with no resources. 
+ authzProperties[0].Deny[0].Operations.Clear(); + authzProperties[0].Deny[0].Operations.Insert(0, "mqtt:connect"); + authzProperties[0].Deny[0].Resources.Clear(); + + IList errors = validator.ValidateAuthorizationConfig(authzProperties); + + Assert.Equal(0, errors.Count); + } + + [Fact] + public void ValidateBridgeConfig_ValidInput() + { + var validator = new BrokerPropertiesValidator(); + + EdgeHubDesiredProperties properties = ConfigTestData.GetTestData(); + + IList errors = validator.ValidateBridgeConfig(properties.BrokerConfiguration.Bridges); + + Assert.Equal(0, errors.Count); + } + + [Fact] + public void ValidateBridgeConfig_EmptyElements() + { + var validator = new BrokerPropertiesValidator(); + + var bridgeConfig = new BridgeConfig + { + new Bridge(string.Empty, new List + { + new Settings(Direction.In, string.Empty, string.Empty, string.Empty) + }), + new Bridge("floor2", new List { }) + }; + + IList errors = validator.ValidateBridgeConfig(bridgeConfig); + + Assert.Equal(2, errors.Count); + Assert.Equal("Bridge 0: Endpoint must not be empty", errors[0]); + Assert.Equal("Bridge 1: Settings must not be empty", errors[1]); + } + + [Fact] + public void ValidateBridgeConfig_InvalidTopicOrPrefix() + { + var validator = new BrokerPropertiesValidator(); + + var bridgeConfig = new BridgeConfig + { + new Bridge("$upstream", new List + { + new Settings(Direction.In, "topic/#/a", "local/#", "remote/+/") + }) + }; + + IList errors = validator.ValidateBridgeConfig(bridgeConfig); + + Assert.Equal(3, errors.Count); + Assert.Equal("Bridge 0: Topic is invalid: topic/#/a", errors[0]); + Assert.Equal("Bridge 0: InPrefix must not contain wildcards (+, #)", errors[1]); + Assert.Equal("Bridge 0: OutPrefix must not contain wildcards (+, #)", errors[2]); + } } } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/ConfigTestData.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/ConfigTestData.cs index 7a7cdae552f..5d8c39f5507 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/ConfigTestData.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/ConfigTestData.cs @@ -73,7 +73,19 @@ public static EdgeHubDesiredProperties GetTestData() var authzProperties = new AuthorizationProperties { statement1, statement2 }; - var brokerProperties = new BrokerProperties(new BridgeConfig(), authzProperties); + var bridgeConfig = new BridgeConfig + { + new Bridge("$upstream", new List + { + new Settings(Direction.In, "topic/a", "local/", "remote/") + }), + new Bridge("floor2", new List + { + new Settings(Direction.Out, "/topic/b", "local", "remote") + }) + }; + + var brokerProperties = new BrokerProperties(bridgeConfig, authzProperties); var properties = new EdgeHubDesiredProperties( "1.2.0", new Dictionary(), diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigParserTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigParserTest.cs index a51b3ea8b80..e3634dc3549 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigParserTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigParserTest.cs @@ -20,6 +20,9 @@ public void GetEdgeHubConfig_ValidInput_MappingIsCorrect() validator .Setup(v => v.ValidateAuthorizationConfig(It.IsAny())) .Returns(new List()); + validator + .Setup(v => v.ValidateBridgeConfig(It.IsAny())) + .Returns(new List()); var routeFactory = 
new EdgeRouteFactory(new Mock().Object); var configParser = new EdgeHubConfigParser(routeFactory, validator.Object); @@ -77,17 +80,59 @@ public void GetEdgeHubConfig_ValidInput_MappingIsCorrect() } [Fact] - public void GetEdgeHubConfig_ValidatorReturnsError_ExpectedException() + public void GetEdgeHubConfig_AuthorizationValidatorReturnsError_ExpectedException() + { + var validator = new Mock(); + validator + .Setup(v => v.ValidateAuthorizationConfig(It.IsAny())) + .Returns(new List { "Validation error has occurred" }); + + var routeFactory = new EdgeRouteFactory(new Mock().Object); + var configParser = new EdgeHubConfigParser(routeFactory, validator.Object); + + var authzProperties = new AuthorizationProperties + { + new AuthorizationProperties.Statement( + identities: new List + { + "device_1", + "device_3" + }, + allow: new List(), + deny: new List()) + }; + + var brokerProperties = new BrokerProperties(new BridgeConfig(), authzProperties); + var properties = new EdgeHubDesiredProperties( + "1.2.0", + new Dictionary(), + new StoreAndForwardConfiguration(100), + brokerProperties); + + // assert + Assert.Throws(() => configParser.GetEdgeHubConfig(properties)); + } + + [Fact] + public void GetEdgeHubConfig_BridgeValidatorReturnsError_ExpectedException() { var validator = new Mock(); validator .Setup(v => v.ValidateAuthorizationConfig(It.IsAny())) + .Returns(new List()); + validator + .Setup(v => v.ValidateBridgeConfig(It.IsAny())) .Returns(new List { "Validation error has occurred" }); var routeFactory = new EdgeRouteFactory(new Mock().Object); var configParser = new EdgeHubConfigParser(routeFactory, validator.Object); - var brokerProperties = new BrokerProperties(new BridgeConfig(), new AuthorizationProperties()); + var bridgeConfig = new BridgeConfig + { + new Bridge("floor2", new List { }) + }; + + var brokerProperties = new BrokerProperties(bridgeConfig, new AuthorizationProperties()); var properties = new EdgeHubDesiredProperties( "1.2.0", new Dictionary(), diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigTest.cs index 21ce1dc7e07..cd4d83e899d 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubConfigTest.cs @@ -59,6 +59,30 @@ public static IEnumerable GetEdgeHubConfigData() var storeAndForwardConfig5 = new StoreAndForwardConfiguration(3600, s2); var storeAndForwardConfig6 = new StoreAndForwardConfiguration(3600); + var bridge1 = new BridgeConfig + { + new Bridge("endpoint1", new List + { + new Settings(Direction.In, "topic/a", "local/", "remote/") + }) + }; + + var bridge2 = new BridgeConfig + { + new Bridge("endpoint1", new List + { + new Settings(Direction.Out, "/topic/b", "local", "remote") + }) + }; + + var bridge3 = new BridgeConfig + { + new Bridge("endpoint1", new List + { + new Settings(Direction.In, "topic/a", "local/", "remote/") + }) + }; + var statement1 = new Statement( effect: Effect.Allow, identities: new List @@ -110,13 +134,13 @@ public static IEnumerable GetEdgeHubConfigData() }); var brokerConfig1 = new BrokerConfig( - Option.None(), + Option.Some(bridge1), Option.Some(new AuthorizationConfig(new List { statement1 }))); var brokerConfig2 = new BrokerConfig( - Option.None(), + Option.Some(bridge2), Option.Some(new AuthorizationConfig(new List { statement2 }))); var brokerConfig3 = new BrokerConfig( - 
Option.None(), + Option.Some(bridge3), Option.Some(new AuthorizationConfig(new List { statement3 }))); string version = "1.0"; diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubDesiredPropertiesTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubDesiredPropertiesTest.cs index c4114f6c24b..3baedde45bc 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubDesiredPropertiesTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Core.Test/config/EdgeHubDesiredPropertiesTest.cs @@ -469,5 +469,52 @@ public void AuthorizationsTest() Assert.Single(authConfig[0].Deny[0].Resources); Assert.Equal("/alert/#", authConfig[0].Deny[0].Resources[0]); } + + [Fact] + public void BridgeTest() + { + string properties = + @"{ + 'schemaVersion': '1.2.0', + 'routes': {}, + 'storeAndForwardConfiguration': {}, + 'mqttBroker': { + 'bridges': [ + { + 'endpoint': '$upstream', + 'settings': [ + { + 'direction': 'in', + 'topic': 'telemetry/#', + 'outPrefix': '/local/topic', + 'inPrefix': '/remote/topic' + }, + { + 'direction': 'out', + 'topic': '', + 'inPrefix': '/local/telemetry', + 'outPrefix': '/remote/messages' + } + ] + } + ], + }, + '$version': 2 + }"; + + var props = JsonConvert.DeserializeObject(properties); + var bridges = props.BrokerConfiguration.Bridges; + Assert.Single(bridges); + Assert.Equal("$upstream", bridges[0].Endpoint); + Assert.Equal(2, bridges[0].Settings.Count); + Assert.Equal(Direction.In, bridges[0].Settings[0].Direction); + Assert.Equal("telemetry/#", bridges[0].Settings[0].Topic); + Assert.Equal("/local/topic", bridges[0].Settings[0].OutPrefix); + Assert.Equal("/remote/topic", bridges[0].Settings[0].InPrefix); + Assert.Equal(Direction.Out, bridges[0].Settings[1].Direction); + Assert.Equal(string.Empty, bridges[0].Settings[1].Topic); + Assert.Equal("/local/telemetry", bridges[0].Settings[1].InPrefix); + Assert.Equal("/remote/messages", bridges[0].Settings[1].OutPrefix); + } } } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.E2E.Test/DependencyManager.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.E2E.Test/DependencyManager.cs index ce2c2f3caa1..9c6e87cd1f8 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.E2E.Test/DependencyManager.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.E2E.Test/DependencyManager.cs @@ -206,7 +206,7 @@ public void Register(ContainerBuilder builder) experimentalFeatures, true, false, - false)); + true)); builder.RegisterModule(new HttpModule("Edge1")); builder.RegisterModule(new MqttModule(mqttSettingsConfiguration.Object, topics, this.serverCertificate, false, false, false, this.sslProtocols)); diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/RegistryControllerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/RegistryControllerTest.cs index fd99b210a1f..fa186fe26e0 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/RegistryControllerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/RegistryControllerTest.cs @@ -395,6 +395,8 @@ public static IEnumerable GetControllerMethodDelegates() void SetupControllerContext(Controller controller) { var httpContext = new DefaultHttpContext(); + httpContext.Request.QueryString = new QueryString("?api-version=2017-10-20"); + var httpResponse = new DefaultHttpResponse(httpContext); httpResponse.Body = new MemoryStream(); var controllerContext = new ControllerContext(); 
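Stepping back, the bridge-related tests above pin down the shape of the new bridge settings and the validator's contract. A condensed sketch of the happy path, using only types exercised in this diff (the using directive is an assumption based on the test project layout, and the argument order of Settings follows ConfigTestData above):

using System.Collections.Generic;
using Microsoft.Azure.Devices.Edge.Hub.Core.Config; // assumed namespace

// One upstream bridge with a single well-formed rule.
var bridges = new BridgeConfig
{
    new Bridge("$upstream", new List<Settings>
    {
        new Settings(Direction.In, "telemetry/#", "local/", "remote/")
    })
};

// A non-empty endpoint, a valid topic filter, and wildcard-free prefixes
// should produce no validation errors.
var validator = new BrokerPropertiesValidator();
IList<string> errors = validator.ValidateBridgeConfig(bridges);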
diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/Cloud2DeviceMessageHandlerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/Cloud2DeviceMessageHandlerTest.cs index 4eae8ccb6a5..d4412c8c204 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/Cloud2DeviceMessageHandlerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/Cloud2DeviceMessageHandlerTest.cs @@ -16,11 +16,12 @@ public class Cloud2DeviceMessageHandlerTest : MessageConfirmingTestBase public async Task EncodesDeviceNameInTopic() { var capture = new SendCapture(); - var connector = GetConnector(capture); + var connector = GetConnector(capture); var connectionRegistry = GetConnectionRegistry(); var identity = new DeviceIdentity("hub", "device_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); var sut = new Cloud2DeviceMessageHandler(connectionRegistry); @@ -41,7 +42,7 @@ public async Task EncodesPropertiesInTopic() var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) .SetProperties(new Dictionary() { ["prop1"] = "value1", ["prop2"] = "value2" }) - .SetSystemProperties(new Dictionary() { ["userId"] = "userid", ["cid"] = "corrid" }) + .SetSystemProperties(new Dictionary() { ["userId"] = "userid", ["cid"] = "corrid", [SystemProperties.LockToken] = "12345" }) .Build(); var sut = new Cloud2DeviceMessageHandler(connectionRegistry); @@ -65,6 +66,7 @@ public async Task SendsMessageDataAsPayload() var identity = new DeviceIdentity("hub", "device_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); var sut = new Cloud2DeviceMessageHandler(connectionRegistry); @@ -113,6 +115,7 @@ public async Task DoesNotSendToModule() var identity = new ModuleIdentity("hub", "device_id", "module_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); var sut = new Cloud2DeviceMessageHandler(connectionRegistry); @@ -121,7 +124,7 @@ public async Task DoesNotSendToModule() await sut.SendC2DMessageAsync(message, identity, true); Mock.Get(connector) - .Verify(c => c.SendAsync(It.IsAny(), It.IsAny()), Times.Never()); + .Verify(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); } } } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ConnectionHandlerTests.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ConnectionHandlerTests.cs index 03abdccc40a..8da576049a1 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ConnectionHandlerTests.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ConnectionHandlerTests.cs @@ -25,7 +25,14 @@ public async Task CloseSendsDisconnectSignal() It.IsAny>())) .Returns(Task.FromResult(Mock.Of())); + var authenticator = Mock.Of(); + Mock.Get(authenticator) + .Setup(p => p.AuthenticateAsync(It.IsAny())) + .Returns(Task.FromResult(true)); + var connectionProviderGetter = Task.FromResult(connectionProvider); + var authenticatorGetter = Task.FromResult(authenticator); + var identityProvider = new IdentityProvider("hub"); var systemComponentIdProvider = new 
SystemComponentIdProvider( new TokenCredentials( @@ -35,14 +42,14 @@ public async Task CloseSendsDisconnectSignal() var brokerConnector = Mock.Of(); Mock.Get(brokerConnector) - .Setup(p => p.SendAsync(It.IsAny(), It.IsAny())) + .Setup(p => p.SendAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => Task.FromResult(true)); var sut = default(ConnectionHandler); DeviceProxy.Factory deviceProxyFactory = GetProxy; - sut = new ConnectionHandler(connectionProviderGetter, identityProvider, systemComponentIdProvider, deviceProxyFactory); + sut = new ConnectionHandler(connectionProviderGetter, authenticatorGetter, identityProvider, systemComponentIdProvider, deviceProxyFactory); sut.SetConnector(brokerConnector); await sut.HandleAsync(new MqttPublishInfo("$edgehub/connected", Encoding.UTF8.GetBytes("[\"device_test\"]"))); @@ -51,7 +58,8 @@ public async Task CloseSendsDisconnectSignal() Mock.Get(brokerConnector).Verify( c => c.SendAsync( It.Is(topic => topic.Equals("$edgehub/disconnect")), - It.Is(payload => Encoding.UTF8.GetString(payload).Equals($"\"device_test\""))), + It.Is(payload => Encoding.UTF8.GetString(payload).Equals($"\"device_test\"")), + false), Times.Once); DeviceProxy GetProxy(IIdentity identity, bool isDirectClient = true) diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DeviceProxyTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DeviceProxyTest.cs index 3ceb2b3af26..21e459caf12 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DeviceProxyTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DeviceProxyTest.cs @@ -167,5 +167,28 @@ public void SetInactiveMakesInactive() Assert.False(sut.IsActive); } + + [Fact] + public async Task EdgeHubIsIndirect() + { + var connectionHandler = Mock.Of(); + var twinHandler = Mock.Of(); + var m2mHandler = Mock.Of(); + var c2dHandler = Mock.Of(); + var directMethodHandler = Mock.Of(); + var identity = new ModuleIdentity("hub", "device_id", "$edgeHub"); + + var twin = new EdgeMessage.Builder(new byte[0]).Build(); + + Mock.Get(twinHandler) + .Setup(h => h.SendTwinUpdate(It.IsAny(), It.Is(i => i == identity), It.Is(d => d == false))) + .Returns(Task.CompletedTask); + + var sut = new DeviceProxy(identity, true, connectionHandler, twinHandler, m2mHandler, c2dHandler, directMethodHandler); + + await sut.SendTwinUpdate(twin); + + Mock.Get(twinHandler).VerifyAll(); + } } } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DirectMethodHandlerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DirectMethodHandlerTest.cs index 700c1f78d6a..9f1e31e1e97 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DirectMethodHandlerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/DirectMethodHandlerTest.cs @@ -202,8 +202,8 @@ protected static IMqttBrokerConnector GetConnector(SendCapture sendCapture = nul { var connector = Mock.Of(); Mock.Get(connector) - .Setup(c => c.SendAsync(It.IsAny(), It.IsAny())) - .Returns((string topic, byte[] content) => + .Setup(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((string topic, byte[] content, bool retain) => { sendCapture?.Caputre(topic, content); return Task.FromResult(true); @@ -243,14 +243,14 @@ public TestDeviceListener(IIdentity identity) public Task AddDesiredPropertyUpdatesSubscription(string correlationId) => 
Task.CompletedTask; public Task AddSubscription(DeviceSubscription subscription) => Task.CompletedTask; public Task CloseAsync() => Task.CompletedTask; - public Task ProcessDeviceMessageBatchAsync(IEnumerable message) => Task.CompletedTask; + public Task ProcessDeviceMessageBatchAsync(IEnumerable message) => Task.CompletedTask; public Task RemoveDesiredPropertyUpdatesSubscription(string correlationId) => Task.CompletedTask; public Task RemoveSubscription(DeviceSubscription subscription) => Task.CompletedTask; public Task SendGetTwinRequest(string correlationId) => Task.CompletedTask; public Task UpdateReportedPropertiesAsync(IMessage reportedPropertiesMessage, string correlationId) => Task.CompletedTask; public Task ProcessDeviceMessageAsync(IMessage message) => Task.CompletedTask; public Task ProcessMessageFeedbackAsync(string messageId, FeedbackStatus feedbackStatus) => Task.CompletedTask; - + public Task ProcessMethodResponseAsync(IMessage message) { this.CapturedMessage = message; diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/MessageConfirmingBase.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/MessageConfirmingBase.cs index 033119073cf..8813b270c88 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/MessageConfirmingBase.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/MessageConfirmingBase.cs @@ -32,8 +32,8 @@ protected static IMqttBrokerConnector GetConnector(SendCapture sendCapture = nul { var connector = Mock.Of(); Mock.Get(connector) - .Setup(c => c.SendAsync(It.IsAny(), It.IsAny())) - .Returns((string topic, byte[] content) => + .Setup(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((string topic, byte[] content, bool retain) => { sendCapture?.Caputre(topic, content); return Task.FromResult(true); diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ModuleToModuleMessageHandlerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ModuleToModuleMessageHandlerTest.cs index 29cba564eed..588d166cdc3 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ModuleToModuleMessageHandlerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ModuleToModuleMessageHandlerTest.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. 
namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test { + using System; using System.Collections.Generic; + using System.Text; using System.Threading.Tasks; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; @@ -18,17 +20,19 @@ public async Task EncodesModuleNameInTopic() var capture = new SendCapture(); var connector = GetConnector(capture); var connectionRegistry = GetConnectionRegistry(); + var identityProvider = new IdentityProvider("hub"); var identity = new ModuleIdentity("hub", "device_id", "module_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); - var sut = new ModuleToModuleMessageHandler(connectionRegistry); + using var sut = new ModuleToModuleMessageHandler(connectionRegistry, identityProvider, GetAckTimeout()); sut.SetConnector(connector); await sut.SendModuleToModuleMessageAsync(message, "some_input", identity, true); - Assert.Equal("$edgehub/device_id/module_id/inputs/some_input/", capture.Topic); + Assert.Equal("$edgehub/device_id/module_id/12345/inputs/some_input/", capture.Topic); } [Fact] @@ -37,19 +41,20 @@ public async Task EncodesPropertiesInTopic() var capture = new SendCapture(); var connector = GetConnector(capture); var connectionRegistry = GetConnectionRegistry(); + var identityProvider = new IdentityProvider("hub"); var identity = new ModuleIdentity("hub", "device_id", "module_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) .SetProperties(new Dictionary() { ["prop1"] = "value1", ["prop2"] = "value2" }) - .SetSystemProperties(new Dictionary() { ["userId"] = "userid", ["cid"] = "corrid" }) + .SetSystemProperties(new Dictionary() { ["userId"] = "userid", ["cid"] = "corrid", [SystemProperties.LockToken] = "12345" }) .Build(); - var sut = new ModuleToModuleMessageHandler(connectionRegistry); + using var sut = new ModuleToModuleMessageHandler(connectionRegistry, identityProvider, GetAckTimeout()); sut.SetConnector(connector); await sut.SendModuleToModuleMessageAsync(message, "some_input", identity, true); - Assert.StartsWith("$edgehub/device_id/module_id/inputs/some_input/", capture.Topic); + Assert.StartsWith("$edgehub/device_id/module_id/12345/inputs/some_input/", capture.Topic); Assert.Contains("prop1=value1", capture.Topic); Assert.Contains("prop2=value2", capture.Topic); Assert.Contains("%24.uid=userid", capture.Topic); @@ -62,12 +67,14 @@ public async Task SendsMessageDataAsPayload() var capture = new SendCapture(); var connector = GetConnector(capture); var connectionRegistry = GetConnectionRegistry(); + var identityProvider = new IdentityProvider("hub"); var identity = new ModuleIdentity("hub", "device_id", "module_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); - var sut = new ModuleToModuleMessageHandler(connectionRegistry); + using var sut = new ModuleToModuleMessageHandler(connectionRegistry, identityProvider, GetAckTimeout()); sut.SetConnector(connector); await sut.SendModuleToModuleMessageAsync(message, "some_input", identity, true); @@ -77,7 +84,7 @@ public async Task SendsMessageDataAsPayload() [Fact] public async Task ConfirmsMessageAfterSent() - { + { var capturedLockId = default(string); var capturedStatus = (FeedbackStatus)(-1); @@ -88,6 +95,7 @@ public async Task ConfirmsMessageAfterSent() capturedStatus = status; }); 
+ var identityProvider = new IdentityProvider("hub"); var connector = GetConnector(); var identity = new ModuleIdentity("hub", "device_id", "module_id"); @@ -96,10 +104,11 @@ public async Task ConfirmsMessageAfterSent() .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); - var sut = new ModuleToModuleMessageHandler(connectionRegistry); + using var sut = new ModuleToModuleMessageHandler(connectionRegistry, identityProvider, GetAckTimeout()); sut.SetConnector(connector); await sut.SendModuleToModuleMessageAsync(message, "some_input", identity, true); + await sut.HandleAsync(new MqttPublishInfo("$edgehub/delivered", Encoding.UTF8.GetBytes(@"""$edgehub/device_id/module_id/12345/inputs/"""))); Assert.Equal("12345", capturedLockId); Assert.Equal(FeedbackStatus.Complete, capturedStatus); @@ -113,15 +122,20 @@ public async Task DoesNotSendToDevice() var identity = new DeviceIdentity("hub", "device_id"); var message = new EdgeMessage .Builder(new byte[] { 0x01, 0x02, 0x03 }) + .SetSystemProperties(new Dictionary() { [SystemProperties.LockToken] = "12345" }) .Build(); - var sut = new ModuleToModuleMessageHandler(connectionRegistry); + var identityProvider = new IdentityProvider("hub"); + + using var sut = new ModuleToModuleMessageHandler(connectionRegistry, identityProvider, GetAckTimeout()); sut.SetConnector(connector); await sut.SendModuleToModuleMessageAsync(message, "some_input", identity, true); Mock.Get(connector) - .Verify(c => c.SendAsync(It.IsAny(), It.IsAny()), Times.Never()); + .Verify(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); } + + ModuleToModuleResponseTimeout GetAckTimeout() => new ModuleToModuleResponseTimeout(TimeSpan.FromSeconds(30)); } } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ScopeIdentitiesHandlerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ScopeIdentitiesHandlerTest.cs index 3db2657c6d8..300d633088b 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ScopeIdentitiesHandlerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/ScopeIdentitiesHandlerTest.cs @@ -110,8 +110,8 @@ protected static Mock GetConnector(SendCapture sendCapture { var connector = new Mock(); connector - .Setup(c => c.SendAsync(It.IsAny(), It.IsAny())) - .Returns((string topic, byte[] content) => + .Setup(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((string topic, byte[] content, bool retain) => { sendCapture?.Capture(topic, content); return Task.FromResult(true); diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/TwinHandlerTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/TwinHandlerTest.cs index eff1f2320da..2c3f2ae2666 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/TwinHandlerTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test/TwinHandlerTest.cs @@ -137,7 +137,7 @@ public async Task CapturesTwinReportedContentFromBody() _ = await sut.HandleAsync(publishInfo); await milestone.WaitAsync(); - Assert.Equal(new byte [] { 1, 2, 3 }, listenerCapture.Captured.CapturedMessage.Body); + Assert.Equal(new byte[] { 1, 2, 3 }, listenerCapture.Captured.CapturedMessage.Body); } [Fact] @@ -157,7 +157,7 @@ public async Task TwinUpdateRequiresMessageStatusCode() await sut.SendTwinUpdate(twin, identity, true); Mock.Get(connector) - 
.Verify(c => c.SendAsync(It.IsAny(), It.IsAny()), Times.Never()); + .Verify(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); } [Fact] @@ -177,7 +177,7 @@ public async Task TwinUpdateRequiresMessageCorrelationId() await sut.SendTwinUpdate(twin, identity, true); Mock.Get(connector) - .Verify(c => c.SendAsync(It.IsAny(), It.IsAny()), Times.Never()); + .Verify(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); } [Fact] @@ -243,7 +243,7 @@ public async Task DesiredUpdateRequiresVersion() await sut.SendDesiredPropertiesUpdate(twin, identity, true); Mock.Get(connector) - .Verify(c => c.SendAsync(It.IsAny(), It.IsAny()), Times.Never()); + .Verify(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); } [Fact] @@ -257,7 +257,7 @@ public async Task DesiredUpdateEncodesVersionAndIdentityInTopic() var twin = new EdgeMessage.Builder(new byte[] { 1, 2, 3 }) .SetSystemProperties(new Dictionary() { - [SystemProperties.Version] = "123" + [SystemProperties.Version] = "123" }) .Build(); @@ -352,8 +352,8 @@ static IMqttBrokerConnector GetConnector(SendCapture sendCapture = null) { var connector = Mock.Of(); Mock.Get(connector) - .Setup(c => c.SendAsync(It.IsAny(), It.IsAny())) - .Returns((string topic, byte[] content) => + .Setup(c => c.SendAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((string topic, byte[] content, bool retain) => { sendCapture?.Caputre(topic, content); return Task.FromResult(true); @@ -377,13 +377,13 @@ public TestDeviceListener(IIdentity identity, SemaphoreSlim milestone) } public IIdentity Identity { get; } - + public Task AddDesiredPropertyUpdatesSubscription(string correlationId) => Task.CompletedTask; public Task AddSubscription(DeviceSubscription subscription) => Task.CompletedTask; public Task CloseAsync() => Task.CompletedTask; public Task ProcessDeviceMessageBatchAsync(IEnumerable message) => Task.CompletedTask; public Task RemoveDesiredPropertyUpdatesSubscription(string correlationId) => Task.CompletedTask; - public Task RemoveSubscription(DeviceSubscription subscription) => Task.CompletedTask; + public Task RemoveSubscription(DeviceSubscription subscription) => Task.CompletedTask; public Task ProcessDeviceMessageAsync(IMessage message) => Task.CompletedTask; public Task ProcessMessageFeedbackAsync(string messageId, FeedbackStatus feedbackStatus) => Task.CompletedTask; public Task ProcessMethodResponseAsync(IMessage message) => Task.CompletedTask; diff --git a/edge-hub/docker/linux/arm32v7/Dockerfile b/edge-hub/docker/linux/arm32v7/Dockerfile index a2c1f04ab0f..925a6ec0e85 100644 --- a/edge-hub/docker/linux/arm32v7/Dockerfile +++ b/edge-hub/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-hub-base:${base_tag} ADD ./watchdog/target/armv7-unknown-linux-gnueabihf/release/watchdog /usr/local/bin/watchdog diff --git a/edge-hub/docker/linux/arm32v7/base/Dockerfile b/edge-hub/docker/linux/arm32v7/base/Dockerfile index 73e0b0fe38b..46aed4aa27a 100644 --- a/edge-hub/docker/linux/arm32v7/base/Dockerfile +++ b/edge-hub/docker/linux/arm32v7/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm32v7 +ARG base_tag=3.1.10-bionic-arm32v7 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag} # Add an unprivileged user account for running Edge Hub diff --git a/edge-hub/docker/linux/arm64v8/Dockerfile b/edge-hub/docker/linux/arm64v8/Dockerfile index b283fca1218..bf93da16303 100644 --- a/edge-hub/docker/linux/arm64v8/Dockerfile 
+++ b/edge-hub/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-hub-base:${base_tag} ADD ./watchdog/target/aarch64-unknown-linux-gnu/release/watchdog /usr/local/bin/watchdog diff --git a/edge-hub/docker/linux/arm64v8/base/Dockerfile b/edge-hub/docker/linux/arm64v8/base/Dockerfile index 84cc861f9d2..d554bc4928a 100644 --- a/edge-hub/docker/linux/arm64v8/base/Dockerfile +++ b/edge-hub/docker/linux/arm64v8/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm64v8 +ARG base_tag=3.1.10-bionic-arm64v8 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag} # Add an unprivileged user account for running Edge Hub diff --git a/edge-hub/watchdog/src/main.rs b/edge-hub/watchdog/src/main.rs index 368fe482ec1..86d146dbd69 100644 --- a/edge-hub/watchdog/src/main.rs +++ b/edge-hub/watchdog/src/main.rs @@ -30,11 +30,11 @@ fn main() -> Result<()> { init_logging(); info!("Starting Watchdog"); - let experimental_features_enabled = std::env::var("experimentalFeatures:enabled") + let experimental_features_enabled = std::env::var("experimentalFeatures__enabled") .unwrap_or_else(|_| "false".to_string()) == "true"; - let mqtt_broker_enabled = std::env::var("experimentalFeatures:mqttBrokerEnabled") + let mqtt_broker_enabled = std::env::var("experimentalFeatures__mqttBrokerEnabled") .unwrap_or_else(|_| "false".to_string()) == "true"; diff --git a/edge-modules/MetricsCollector/docker/linux/arm32v7/Dockerfile b/edge-modules/MetricsCollector/docker/linux/arm32v7/Dockerfile index b5c8fee5a9e..191bf47b27e 100644 --- a/edge-modules/MetricsCollector/docker/linux/arm32v7/Dockerfile +++ b/edge-modules/MetricsCollector/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/edge-modules/MetricsCollector/docker/linux/arm64v8/Dockerfile b/edge-modules/MetricsCollector/docker/linux/arm64v8/Dockerfile index 6fdd41de3ad..5a008e318a2 100644 --- a/edge-modules/MetricsCollector/docker/linux/arm64v8/Dockerfile +++ b/edge-modules/MetricsCollector/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/Dockerfile b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/Dockerfile index e8d807f346e..b37affdf3c3 100644 --- a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/Dockerfile +++ b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. 
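Stepping back to the watchdog change above (`experimentalFeatures:enabled` to `experimentalFeatures__enabled`): a plausible reading, not stated in the diff itself, is that `:` is not a portable character in POSIX environment variable names, while .NET Core's configuration binder already treats `__` as the `:` section separator, so the double-underscore form works for both the Rust watchdog and the .NET edgeHub. A minimal sketch of the lookup as the new watchdog code performs it:

```rust
// Sketch of the watchdog's feature-flag lookup after the rename.
// "__" stands in for the ":" section separator, which is not portable
// in environment variable names.
fn env_flag(name: &str) -> bool {
    std::env::var(name)
        .unwrap_or_else(|_| "false".to_string())
        == "true"
}

fn main() {
    let experimental_features_enabled = env_flag("experimentalFeatures__enabled");
    let mqtt_broker_enabled = env_flag("experimentalFeatures__mqttBrokerEnabled");
    println!(
        "experimental: {}, mqtt broker: {}",
        experimental_features_enabled, mqtt_broker_enabled
    );
}
```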
diff --git a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/base/Dockerfile b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/base/Dockerfile index 5a4db87ce88..17f5da2ae50 100644 --- a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/base/Dockerfile +++ b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm32v7/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm32v7 +ARG base_tag=3.1.10-bionic-arm32v7 FROM mcr.microsoft.com/dotnet/core/runtime:${base_tag} # Add an unprivileged user account for running the module diff --git a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/Dockerfile b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/Dockerfile index f4c5d75ed8f..25c12626e9a 100644 --- a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/Dockerfile +++ b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/base/Dockerfile b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/base/Dockerfile index 991e2bc93c5..d715e7bb3d4 100644 --- a/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/base/Dockerfile +++ b/edge-modules/SimulatedTemperatureSensor/docker/linux/arm64v8/base/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=3.1.7-bionic-arm64v8 +ARG base_tag=3.1.10-bionic-arm64v8 FROM mcr.microsoft.com/dotnet/core/runtime:${base_tag} # Add an unprivileged user account for running the module diff --git a/edge-modules/api-proxy-module/build.sh b/edge-modules/api-proxy-module/build.sh index c3774cf3920..f4e8545b30f 100755 --- a/edge-modules/api-proxy-module/build.sh +++ b/edge-modules/api-proxy-module/build.sh @@ -88,7 +88,7 @@ echo ${PROJECT_ROOT} if [[ "$ARCH" == "amd64" ]]; then set +e -../../scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch $ARCH --build-path edge-modules/api-proxy-module +../../scripts/linux/cross-platform-rust-build.sh --os alpine --arch $ARCH --build-path edge-modules/api-proxy-module set -e cp -r ./templates/ ./docker/linux/amd64 @@ -102,7 +102,7 @@ cp -r ./target/armv7-unknown-linux-musleabihf/release/api-proxy-module ./docker/ docker build . -t azureiotedge-api-proxy -f docker/linux/arm32v7/Dockerfile elif [[ "$ARCH" == "aarch64" ]]; then set +e -../../scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch $ARCH --build-path edge-modules/api-proxy-module +../../scripts/linux/cross-platform-rust-build.sh --os alpine --arch $ARCH --build-path edge-modules/api-proxy-module set -e cp -r ./templates/ ./docker/linux/arm64v8 diff --git a/edge-modules/api-proxy-module/docker/linux/amd64/Dockerfile b/edge-modules/api-proxy-module/docker/linux/amd64/Dockerfile index 0f57086e4b3..abf09d7bc5e 100644 --- a/edge-modules/api-proxy-module/docker/linux/amd64/Dockerfile +++ b/edge-modules/api-proxy-module/docker/linux/amd64/Dockerfile @@ -12,6 +12,7 @@ COPY ./docker/linux/amd64/templates . 
RUN apk update && \ apk add nginx && \ mkdir /run/nginx + #expose ports EXPOSE 443/tcp EXPOSE 80/tcp @@ -21,5 +22,4 @@ EXPOSE 5000/tcp EXPOSE 11002/tcp #use for custom defining ports EXPOSE 7000-8000/tcp -ENTRYPOINT ["/app/api-proxy-module"] - \ No newline at end of file +ENTRYPOINT ["/app/api-proxy-module"] \ No newline at end of file diff --git a/edge-modules/api-proxy-module/readme.md b/edge-modules/api-proxy-module/readme.md index c36b98423fb..698efcbcd52 100644 --- a/edge-modules/api-proxy-module/readme.md +++ b/edge-modules/api-proxy-module/readme.md @@ -107,16 +107,17 @@ First, list all the environment variables that you want to update in the `PROXY_ | Environment variable | comments | | ------------- | ------------- | -| PROXY_CONFIG_ENV_VAR_LIST | List all the variable to be replaced. By default it contains: NGINX_DEFAULT_PORT,BLOB_UPLOAD_ROUTE_ADDRESS,DOCKER_REQUEST_ROUTE_ADDRESS,IOTEDGE_PARENTHOSTNAME | +| PROXY_CONFIG_ENV_VAR_LIST | Lists all the variables to be replaced. By default it contains: NGINX_DEFAULT_PORT,BLOB_UPLOAD_ROUTE_ADDRESS,DOCKER_REQUEST_ROUTE_ADDRESS,IOTEDGE_PARENTHOSTNAME, IOTEDGE_PARENTAPIPROXYNAME | Next, set each environment variable's value by listing them directly. | Environment variable | comments | | ------------- | ------------- | -| NGINX_DEFAULT_PORT | Changes the port Nginx listens too. If you change this option, make sure that the port you select is exposed in the dockerfile. Default is 443 | +| NGINX_DEFAULT_PORT | Changes the port Nginx listens on. If you update this environment variable, make sure the port you select is also exposed in the module dockerfile and the port binding. Default is 443. | | DOCKER_REQUEST_ROUTE_ADDRESS | Address to route docker requests. By default it points to the parent. | | BLOB_UPLOAD_ROUTE_ADDRESS| Address to route blob registry requests. By default it points to the parent. | -| IOTEDGE_PARENTHOSTNAME | Parent hostname | +| IOTEDGE_PARENTHOSTNAME | Read-only variable. Do not assign it; its value is automatically set to the parent hostname when the container starts. | +| IOTEDGE_PARENTAPIPROXYNAME | Set the name of the parent API proxy module, as specified in the Azure portal. This is used for certificate authentication. When omitted, the parent name defaults to the child API proxy module's name. | ### Update the proxy configuration dynamically @@ -136,7 +137,7 @@ To update the default configuration when the module starts, replace the configur Lastly, the configuration of the module should match the configuration of the proxy. For instance, the module image should have the same listening port open as defined in the proxy. -To minimize the number of open ports, the API Proxy should relay all HTTPS traffic (e.g. port 443), including traffic targeting the edgeHub. To avoid port binding conflicts, the edgeHub settings thus needs to be modified to not port bind on port 443. The API Proxy module should bind its port 443. The API Proxy should also be configured to route the edgeHub traffic by turning on the `ROUTE_EDGEHUB_TRAFFIC` to true (see example below). +To minimize the number of open ports, the API Proxy should relay all HTTPS traffic (e.g. port 443), including traffic targeting the edgeHub. To avoid port-binding conflicts, the edgeHub settings thus need to be modified so that edgeHub does not bind port 443. The API Proxy module should bind its port 443. The API Proxy is configured by default to re-route all edgeHub traffic on port 443. 
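As a rough illustration of the substitution scheme the readme table above describes, the sketch below replaces `${VAR}` placeholders for each name in the variable list. It is a hedged simplification: the real parser in `config_monitor.rs` (`get_parsed_config`) also handles the `#if_tag`/`#endif_tag` sections, which this sketch ignores.

```rust
// Hedged sketch of ${VAR} substitution over an nginx config template.
// Only variables named in `var_list` (comma-separated, in the style of
// PROXY_CONFIG_ENV_VAR_LIST) are replaced; unset variables are left as-is.
fn substitute_env_vars(template: &str, var_list: &str) -> String {
    let mut config = template.to_string();
    for name in var_list.split(',').map(str::trim) {
        if let Ok(value) = std::env::var(name) {
            config = config.replace(&format!("${{{}}}", name), &value);
        }
    }
    config
}

fn main() {
    std::env::set_var("NGINX_DEFAULT_PORT", "8000");
    let parsed = substitute_env_vars(
        "listen ${NGINX_DEFAULT_PORT} ssl default_server;",
        "NGINX_DEFAULT_PORT",
    );
    assert_eq!(parsed, "listen 8000 ssl default_server;");
}
```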
If minimizing the number of open ports is not a concern, the API Proxy can listen on a port other than 443 and let edgeHub use port 443. The API Proxy can, for instance, listen on port 8000 by setting the environment variable `NGINX_DEFAULT_PORT` to `8000` and binding port 8000 of the API Proxy module. This is the default configuration of the API Proxy module. diff --git a/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/device.rs b/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/device.rs index bb821668b2e..87860c6bac6 100644 --- a/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/device.rs +++ b/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/device.rs @@ -153,9 +153,7 @@ impl futures_core::Stream for Client { State::WaitingForSubscriptions { reset_session, acked } => if *reset_session { match std::pin::Pin::new(&mut this.inner).poll_next(cx) { - std::task::Poll::Ready(Some(Ok(mqtt3::Event::NewConnection { .. }))) => (), - - std::task::Poll::Ready(Some(Ok(mqtt3::Event::Disconnected(_)))) => (), + std::task::Poll::Ready(Some(Ok(mqtt3::Event::NewConnection { .. }))) | std::task::Poll::Ready(Some(Ok(mqtt3::Event::Disconnected(_)))) => (), std::task::Poll::Ready(Some(Ok(mqtt3::Event::Publication(publication)))) => match InternalMessage::parse(publication, &this.c2d_prefix) { Ok(InternalMessage::CloudToDevice(message)) => diff --git a/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/module.rs b/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/module.rs index 201c864f460..d9dcb4d96c6 100644 --- a/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/module.rs +++ b/edge-modules/api-proxy-module/rust-sdk/azure-iot-mqtt/src/module.rs @@ -215,7 +215,7 @@ impl futures_core::Stream for Client { State::WaitingForSubscriptions { reset_session, acked } => if *reset_session { match std::pin::Pin::new(&mut this.inner).poll_next(cx) { - std::task::Poll::Ready(Some(Ok(mqtt3::Event::NewConnection { .. 
}))) | std::task::Poll::Ready(Some(Ok(mqtt3::Event::Disconnected(_)))) => (), std::task::Poll::Ready(Some(Ok(mqtt3::Event::Publication(publication)))) => match InternalMessage::parse(publication) { Ok(InternalMessage::DirectMethod { name, payload, request_id }) => @@ -237,7 +237,7 @@ impl futures_core::Stream for Client { } }, - std::task::Poll::Ready(Some(Ok(mqtt3::Event::Disconnected(_)))) => (), + std::task::Poll::Ready(Some(Err(err))) => return std::task::Poll::Ready(Some(Err(err))), diff --git a/edge-modules/api-proxy-module/src/main.rs b/edge-modules/api-proxy-module/src/main.rs index b4d338d35e1..79505e81a4a 100644 --- a/edge-modules/api-proxy-module/src/main.rs +++ b/edge-modules/api-proxy-module/src/main.rs @@ -26,7 +26,8 @@ use shutdown_handle::ShutdownHandle; async fn main() -> Result<()> { env_logger::builder().filter_level(LevelFilter::Info).init(); - let notify_need_reload_api_proxy = Arc::new(Notify::new()); + let notify_config_reload_api_proxy = Arc::new(Notify::new()); + let notify_cert_reload_api_proxy = Arc::new(Notify::new()); let client = config_monitor::get_sdk_client()?; let mut shutdown_sdk = client @@ -34,18 +35,14 @@ async fn main() -> Result<()> { .shutdown_handle() .context("Could not create Shutdown handle")?; - let report_twin_state_handle = client.report_twin_state_handle(); - - let (twin_state_poll_handle, twin_state_poll_shutdown_handle) = - config_monitor::report_twin_state(report_twin_state_handle); let (config_monitor_handle, config_monitor_shutdown_handle) = - config_monitor::start(client, notify_need_reload_api_proxy.clone()) + config_monitor::start(client, notify_config_reload_api_proxy.clone()) .context("Failed running config monitor")?; let (cert_monitor_handle, cert_monitor_shutdown_handle) = - certs_monitor::start(notify_need_reload_api_proxy.clone()) + certs_monitor::start(notify_cert_reload_api_proxy.clone()) .context("Failed running certificates monitor")?; let (nginx_controller_handle, nginx_controller_shutdown_handle) = - nginx_controller_start(notify_need_reload_api_proxy) + nginx_controller_start(notify_config_reload_api_proxy, notify_cert_reload_api_proxy) .context("Failed running nginx controller")?; //If one task closes, clean up everything @@ -60,7 +57,6 @@ async fn main() -> Result<()> { .context("Fatal, could not shut down SDK")?; cert_monitor_shutdown_handle.shutdown().await; - twin_state_poll_shutdown_handle.shutdown().await; config_monitor_shutdown_handle.shutdown().await; nginx_controller_shutdown_handle.shutdown().await; @@ -70,16 +66,14 @@ async fn main() -> Result<()> { if let Err(e) = config_monitor_handle.await { error!("error on finishing config monitor: {}", e); } - if let Err(e) = twin_state_poll_handle.await { - error!("error on finishing twin state monitor: {}", e); - } info!("Api proxy stopped"); Ok(()) } pub fn nginx_controller_start( - notify_need_reload_api_proxy: Arc, + notify_config_reload_api_proxy: Arc, + notify_cert_reload_api_proxy: Arc, ) -> Result<(JoinHandle>, ShutdownHandle), Error> { let program_path = "/usr/sbin/nginx"; let args = vec![ @@ -97,10 +91,10 @@ pub fn nginx_controller_start( let shutdown_handle = ShutdownHandle(shutdown_signal.clone()); let monitor_loop: JoinHandle> = tokio::spawn(async move { - //Wait for certificate rotation //This is just to avoid error at the beginning when nginx tries to start - //but configuration is not ready - notify_need_reload_api_proxy.notified().await; + //Wait for configuration and for certs to be ready. 
+ notify_config_reload_api_proxy.notified().await; + notify_cert_reload_api_proxy.notified().await; loop { //Make sure proxy is stopped by sending stop command. Otherwise port will be blocked @@ -121,18 +115,19 @@ pub fn nginx_controller_start( .with_context(|| format!("Failed to start {} process.", name)) .context("Cannot start proxy!")?; - let signal_restart_nginx = notify_need_reload_api_proxy.notified(); + // Restart nginx on new config, new cert or crash. + let cert_reload = notify_cert_reload_api_proxy.notified(); + let config_reload = notify_config_reload_api_proxy.notified(); + futures::pin_mut!(cert_reload, config_reload); + let signal_restart_nginx = future::select(cert_reload, config_reload); futures::pin_mut!(child_nginx, signal_restart_nginx); - - //Wait for: either a signal to restart(cert rotation, new config) or the child to crash. let restart_proxy = future::select(child_nginx, signal_restart_nginx); - //Shutdown on ctrl+c or on signal + //Shutdown on ctrl+c or on signal let wait_shutdown_ctrl_c = shutdown::shutdown(); futures::pin_mut!(wait_shutdown_ctrl_c); let wait_shutdown_signal = shutdown_signal.notified(); futures::pin_mut!(wait_shutdown_signal); - let wait_shutdown = future::select(wait_shutdown_ctrl_c, wait_shutdown_signal); match future::select(wait_shutdown, restart_proxy).await { diff --git a/edge-modules/api-proxy-module/src/monitors/certs_monitor.rs b/edge-modules/api-proxy-module/src/monitors/certs_monitor.rs index c3fac5ca8dd..0a81440cbb6 100644 --- a/edge-modules/api-proxy-module/src/monitors/certs_monitor.rs +++ b/edge-modules/api-proxy-module/src/monitors/certs_monitor.rs @@ -14,8 +14,6 @@ use shutdown_handle::ShutdownHandle; const PROXY_SERVER_TRUSTED_CA_PATH: &str = "/app/trustedCA.crt"; const PROXY_SERVER_CERT_PATH: &str = "/app/server.crt"; const PROXY_SERVER_PRIVATE_KEY_PATH: &str = "/app/private_key_server.pem"; -const PROXY_IDENTITY_CERT_PATH: &str = "/app/identity.crt"; -const PROXY_IDENTITY_PRIVATE_KEY_PATH: &str = "/app/private_key_identity.pem"; const PROXY_SERVER_VALIDITY_DAYS: i64 = 90; @@ -46,22 +44,23 @@ pub fn start( .context("Could not create cert monitor client")?; let monitor_loop: JoinHandle> = tokio::spawn(async move { - loop { + let mut new_trust_bundle = false; + + //Loop until trust bundle is received. + while !new_trust_bundle { let wait_shutdown = shutdown_signal.notified(); futures::pin_mut!(wait_shutdown); - match futures::future::select(time::delay_for(CERTIFICATE_POLL_INTERVAL), wait_shutdown) - .await + if let Either::Right(_) = + futures::future::select(time::delay_for(CERTIFICATE_POLL_INTERVAL), wait_shutdown) + .await { - Either::Right(_) => { - warn!("Shutting down certs monitor!"); - return Ok(()); - } - Either::Left(_) => {} - }; + warn!("Shutting down certs monitor!"); + return Ok(()); + } //Check for rotation. If rotated then we notify. - let new_trust_bundle = match cert_monitor.get_new_trust_bundle().await { + new_trust_bundle = match cert_monitor.get_new_trust_bundle().await { Ok(Some(trust_bundle)) => { //If we have a new cert, we need to write it in file system file::write_binary_to_file( @@ -76,6 +75,25 @@ pub fn start( false } }; + } + + //Trust bundle just received. Request for a reset of the API proxy. + notify_certs_rotated.notify(); + + //Loop to check if server certificate expired. + //It is implemented as a polling instead of a delay until certificate expiry, because clocks are unreliable. 
+ //If the system clock gets readjusted while the task is sleeping, the system might wake up after the certificate expiry. + loop { + let wait_shutdown = shutdown_signal.notified(); + futures::pin_mut!(wait_shutdown); + + if let Either::Right(_) = + futures::future::select(time::delay_for(CERTIFICATE_POLL_INTERVAL), wait_shutdown) + .await + { + warn!("Shutting down certs monitor!"); + return Ok(()); + } //Same thing as above but for private key and server cert let new_server_cert = match cert_monitor.need_to_rotate_server_cert(Utc::now()).await { @@ -98,30 +116,7 @@ pub fn start( } }; - //Same thing as above but for private key and identity cert - let new_identity_cert = match cert_monitor - .need_to_rotate_identity_cert(Utc::now()) - .await - { - Ok(Some((identity_cert, private_key))) => { - //If we have a new cert, we need to write it in file system - file::write_binary_to_file(identity_cert.as_bytes(), PROXY_IDENTITY_CERT_PATH)?; - - //If we have a new cert, we need to write it in file system - file::write_binary_to_file( - private_key.as_bytes(), - PROXY_IDENTITY_PRIVATE_KEY_PATH, - )?; - true - } - Ok(None) => false, - Err(err) => { - error!("Error while trying to get server cert {}", err); - false - } - }; - - if new_trust_bundle | new_identity_cert | new_server_cert { + if new_server_cert { notify_certs_rotated.notify(); } } @@ -137,7 +132,6 @@ struct CertificateMonitor { bundle_of_trust_hash: String, work_load_api_client: edgelet_client::WorkloadClient, server_cert_expiration_date: Option>, - identity_cert_expiration_date: Option>, validity_days: Duration, } @@ -151,7 +145,6 @@ impl CertificateMonitor { ) -> Result { //Create expiry date in the past so cert has to be rotated now. let server_cert_expiration_date = None; - let identity_cert_expiration_date = None; let work_load_api_client = edgelet_client::workload(&workload_url).context("Could not get workload client")?; @@ -163,7 +156,6 @@ impl CertificateMonitor { bundle_of_trust_hash: String::default(), work_load_api_client, server_cert_expiration_date, - identity_cert_expiration_date, validity_days, }) } @@ -200,32 +192,6 @@ impl CertificateMonitor { Ok(Some(certificates)) } - async fn need_to_rotate_identity_cert( - &mut self, - current_date: DateTime, - ) -> Result, anyhow::Error> { - //If certificates are not expired, we don't need to make a query - if let Some(expiration_date) = self.identity_cert_expiration_date { - if current_date < expiration_date { - return Ok(None); - } - } - let new_expiration_date = Utc::now() - .checked_add_signed(self.validity_days) - .context("Could not compute new expiration date for certificate")?; - - let resp = self - .work_load_api_client - .create_identity_cert(&self.module_id, new_expiration_date) - .await?; - - let (certificates, expiration_date) = - unwrap_certificate_response(&resp).context("could not extract server certificates")?; - self.identity_cert_expiration_date = Some(expiration_date); - - Ok(Some(certificates)) - } - async fn get_new_trust_bundle(&mut self) -> Result, anyhow::Error> { let resp = self.work_load_api_client.trust_bundle().await?; @@ -318,58 +284,6 @@ mod tests { assert!(result.is_none()); } - #[tokio::test] - async fn test_get_identity_certs() { - let expiration = Utc::now() + Duration::days(PROXY_SERVER_VALIDITY_DAYS); - let res = json!( - { - "privateKey": { "type": "key", "bytes": "PRIVATE KEY" }, - "certificate": "CERTIFICATE", - "expiration": expiration.to_rfc3339() - } - ); - - let module_id = String::from("api_proxy"); - let generation_id = 
String::from("0000"); - let gateway_hostname = String::from("dummy"); - let workload_url = mockito::server_url(); - - let mut client = CertificateMonitor::new( - module_id, - generation_id, - gateway_hostname, - &workload_url, - Duration::days(PROXY_SERVER_VALIDITY_DAYS), - ) - .unwrap(); - - let current_date = Utc::now(); - - let _m = mock( - "POST", - "/modules/api_proxy/certificate/identity?api-version=2019-01-30", - ) - .with_status(201) - .with_body(serde_json::to_string(&res).unwrap()) - .create(); - let (identity_cert, private_key) = client - .need_to_rotate_identity_cert(current_date) - .await - .unwrap() - .unwrap(); - - assert_eq!(identity_cert, "CERTIFICATE"); - assert_eq!(private_key, "PRIVATE KEY"); - - //Try again, certificate should be rotated in memory. - let result = client - .need_to_rotate_identity_cert(current_date) - .await - .unwrap(); - - assert!(result.is_none()); - } - #[tokio::test] async fn test_get_bundle_of_trust() { let res = json!( { "certificate": "CERTIFICATE" } ); diff --git a/edge-modules/api-proxy-module/src/monitors/config_monitor.rs b/edge-modules/api-proxy-module/src/monitors/config_monitor.rs index 9bd1230e367..2d89172e88b 100644 --- a/edge-modules/api-proxy-module/src/monitors/config_monitor.rs +++ b/edge-modules/api-proxy-module/src/monitors/config_monitor.rs @@ -1,7 +1,6 @@ -use std::{sync::Arc, time::Duration}; +use std::{env, sync::Arc, time::Duration}; use anyhow::{Context, Error, Result}; -use chrono::Utc; use futures_util::future::Either; use log::{error, warn}; use regex::Regex; @@ -9,22 +8,22 @@ use tokio::{sync::Notify, task::JoinHandle}; use super::file; use super::shutdown_handle; -use azure_iot_mqtt::{ - module::Client, ReportTwinStateHandle, ReportTwinStateRequest, Transport::Tcp, TwinProperties, -}; +use azure_iot_mqtt::{module::Client, Transport::Tcp, TwinProperties}; use shutdown_handle::ShutdownHandle; const PROXY_CONFIG_TAG: &str = "proxy_config"; const PROXY_CONFIG_PATH_RAW: &str = "/app/nginx_default_config.conf"; const PROXY_CONFIG_PATH_PARSED: &str = "/app/nginx_config.conf"; const PROXY_CONFIG_ENV_VAR_LIST: &str = "NGINX_CONFIG_ENV_VAR_LIST"; -const PROXY_CONFIG_DEFAULT_VARS_LIST:&str = "NGINX_DEFAULT_PORT,BLOB_UPLOAD_ROUTE_ADDRESS,DOCKER_REQUEST_ROUTE_ADDRESS,IOTEDGE_PARENTHOSTNAME"; +const PROXY_CONFIG_DEFAULT_VARS_LIST:&str = "NGINX_DEFAULT_PORT,BLOB_UPLOAD_ROUTE_ADDRESS,DOCKER_REQUEST_ROUTE_ADDRESS,IOTEDGE_PARENTHOSTNAME,IOTEDGE_PARENTAPIPROXYNAME"; -const PROXY_CONFIG_DEFAULT_VALUES: &[(&str, &str)] = &[("NGINX_DEFAULT_PORT", "443")]; +const PROXY_CONFIG_DEFAULT_VALUES: &[(&str, &str)] = &[ + ("NGINX_DEFAULT_PORT", "443"), + ("IOTEDGE_PARENTAPIPROXYNAME", "IOTEDGE_MODULEID"), +]; -const TWIN_STATE_POLL_INTERVAL: Duration = Duration::from_secs(5); const TWIN_CONFIG_MAX_BACK_OFF: Duration = Duration::from_secs(30); -const TWIN_CONFIG_KEEP_ALIVE: Duration = Duration::from_secs(5); +const TWIN_CONFIG_KEEP_ALIVE: Duration = Duration::from_secs(300); pub fn get_sdk_client() -> Result { let client = match Client::new_for_edge_module( @@ -54,6 +53,9 @@ pub fn start( //Allow one level of indirection, when one env var references another env var. dereference_env_variable(); + //Special handling of some of the environment variables + specific_handling_env_var(); + //Parse default config and notify to reboot nginx if it has already started //If the config is incorrect, return error because otherwise nginx doesn't have any config. 
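Note how the new default pair maps `IOTEDGE_PARENTAPIPROXYNAME` to the literal string `IOTEDGE_MODULEID`: it relies on the one-level indirection mentioned in the comment above, where `dereference_env_variable()` (whose body is not shown in this diff) replaces a variable whose value names another set variable with that variable's value. A hedged sketch of that behavior:

```rust
// Hedged sketch of one-level env var indirection: if the value of `name`
// is itself the name of a set environment variable, substitute its value.
// The real dereference_env_variable() is not shown in this diff.
fn dereference_one_level(name: &str) {
    if let Ok(value) = std::env::var(name) {
        if let Ok(target) = std::env::var(&value) {
            std::env::set_var(name, target);
        }
    }
}

fn main() {
    // Mirrors the defaulting behavior exercised by the tests further down:
    // the proxy name falls back to the module id when not set explicitly.
    std::env::set_var("IOTEDGE_MODULEID", "apiproxy");
    std::env::set_var("IOTEDGE_PARENTAPIPROXYNAME", "IOTEDGE_MODULEID");
    dereference_one_level("IOTEDGE_PARENTAPIPROXYNAME");
    assert_eq!(std::env::var("IOTEDGE_PARENTAPIPROXYNAME").unwrap(), "apiproxy");
}
```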
@@ -143,6 +145,12 @@ fn dereference_env_variable() { } } +fn specific_handling_env_var() { + if let Ok(moduleid) = env::var("IOTEDGE_PARENTAPIPROXYNAME") { + std::env::set_var("IOTEDGE_PARENTAPIPROXYNAME", sanitize_dns_label(&moduleid)); + } +} + fn save_raw_config(twin: &TwinProperties) -> Result<()> { let json = twin.properties.get_key_value(PROXY_CONFIG_TAG); @@ -190,6 +198,24 @@ fn get_var_list() -> String { } } +const ALLOWED_CHAR_DNS: char = '-'; +const DNS_MAX_SIZE: usize = 63; + +// The name returned from here must conform to following rules (as per RFC 1035): +// - length must be <= 63 characters +// - must be all lower case alphanumeric characters or '-' +// - must start with an alphabet +// - must end with an alphanumeric character +pub fn sanitize_dns_label(name: &str) -> String { + name.trim_start_matches(|c: char| !c.is_ascii_alphabetic()) + .trim_end_matches(|c: char| !c.is_ascii_alphanumeric()) + .to_lowercase() + .chars() + .filter(|c| c.is_ascii_alphanumeric() || c == &ALLOWED_CHAR_DNS) + .take(DNS_MAX_SIZE) + .collect::() +} + //Check readme for details of how parsing is done. //First all the environment variables are replaced by their value. //Only environment variables in the list NGINX_CONFIG_ENV_VAR_LIST are replaced. @@ -225,55 +251,6 @@ fn get_parsed_config(str: &str) -> Result { Ok(str) } -pub fn report_twin_state( - mut report_twin_state_handle: ReportTwinStateHandle, -) -> (JoinHandle>, ShutdownHandle) { - use futures_util::StreamExt; - - let shutdown_signal = Arc::new(tokio::sync::Notify::new()); - let shutdown_handle = ShutdownHandle(shutdown_signal.clone()); - - let mut interval = tokio::time::interval(TWIN_STATE_POLL_INTERVAL); - let monitor_loop: JoinHandle> = tokio::spawn(async move { - report_twin_state_handle - .report_twin_state(ReportTwinStateRequest::Replace( - vec![("start-time".to_string(), Utc::now().to_string().into())] - .into_iter() - .collect(), - )) - .await - .context("couldn't report initial twin state")?; - - loop { - let wait_shutdown = shutdown_signal.notified(); - futures::pin_mut!(wait_shutdown); - match futures::future::select(wait_shutdown, interval.next()).await { - Either::Left(_) => { - warn!("Shutting down twin state polling!"); - return Ok(()); - } - Either::Right((result, _)) => { - if result.is_some() { - report_twin_state_handle - .report_twin_state(ReportTwinStateRequest::Patch( - vec![("current-time".to_string(), Utc::now().to_string().into())] - .into_iter() - .collect(), - )) - .await - .context("couldn't report twin state patch")?; - } else { - warn!("Shutting down twin state polling!"); - //Should send a ctrl c event here? 
- return Ok(()); - } - } - }; - } - }); - (monitor_loop, shutdown_handle) -} - #[cfg(test)] mod tests { const RAW_CONFIG_BASE64:&str = "ZXZlbnRzIHsgfQ0KDQoNCmh0dHAgew0KICAgIHByb3h5X2J1ZmZlcnMgMzIgMTYwazsgIA0KICAgIHByb3h5X2J1ZmZlcl9zaXplIDE2MGs7DQogICAgcHJveHlfcmVhZF90aW1lb3V0IDM2MDA7DQogICAgZXJyb3JfbG9nIC9kZXYvc3Rkb3V0IGluZm87DQogICAgYWNjZXNzX2xvZyAvZGV2L3N0ZG91dDsNCg0KICAgIHNlcnZlciB7DQogICAgICAgIGxpc3RlbiAke05HSU5YX0RFRkFVTFRfUE9SVH0gc3NsIGRlZmF1bHRfc2VydmVyOw0KDQogICAgICAgIGNodW5rZWRfdHJhbnNmZXJfZW5jb2Rpbmcgb247DQoNCiAgICAgICAgc3NsX2NlcnRpZmljYXRlICAgICAgICBzZXJ2ZXIuY3J0Ow0KICAgICAgICBzc2xfY2VydGlmaWNhdGVfa2V5ICAgIHByaXZhdGVfa2V5LnBlbTsgDQogICAgICAgIHNzbF9jbGllbnRfY2VydGlmaWNhdGUgdHJ1c3RlZENBLmNydDsNCiAgICAgICAgc3NsX3ZlcmlmeV9jbGllbnQgb247DQoNCg0KICAgICAgICAjaWZfdGFnICR7TkdJTlhfSEFTX0JMT0JfTU9EVUxFfQ0KICAgICAgICBpZiAoJGh0dHBfeF9tc19ibG9iX3R5cGUgPSBCbG9ja0Jsb2IpDQogICAgICAgIHsNCiAgICAgICAgICAgIHJld3JpdGUgXiguKikkIC9zdG9yYWdlJDEgbGFzdDsNCiAgICAgICAgfSANCiAgICAgICAgI2VuZGlmX3RhZyAke05HSU5YX0hBU19CTE9CX01PRFVMRX0NCg0KICAgICAgICAjaWZfdGFnICR7RE9DS0VSX1JFUVVFU1RfUk9VVEVfQUREUkVTU30NCiAgICAgICAgbG9jYXRpb24gL3YyIHsNCiAgICAgICAgICAgIHByb3h5X2h0dHBfdmVyc2lvbiAxLjE7DQogICAgICAgICAgICByZXNvbHZlciAxMjcuMC4wLjExOw0KICAgICAgICAgICAgc2V0ICRiYWNrZW5kICJodHRwOi8vJHtET0NLRVJfUkVRVUVTVF9ST1VURV9BRERSRVNTfSI7DQogICAgICAgICAgICBwcm94eV9wYXNzICAgICAgICAgICRiYWNrZW5kOw0KICAgICAgICB9DQogICAgICAgI2VuZGlmX3RhZyAke0RPQ0tFUl9SRVFVRVNUX1JPVVRFX0FERFJFU1N9DQoNCiAgICAgICAgI2lmX3RhZyAke05HSU5YX0hBU19CTE9CX01PRFVMRX0NCiAgICAgICAgbG9jYXRpb24gfl4vc3RvcmFnZS8oLiopew0KICAgICAgICAgICAgcHJveHlfaHR0cF92ZXJzaW9uIDEuMTsNCiAgICAgICAgICAgIHJlc29sdmVyIDEyNy4wLjAuMTE7DQogICAgICAgICAgICBzZXQgJGJhY2tlbmQgImh0dHA6Ly8ke05HSU5YX0JMT0JfTU9EVUxFX05BTUVfQUREUkVTU30iOw0KICAgICAgICAgICAgcHJveHlfcGFzcyAgICAgICAgICAkYmFja2VuZC8kMSRpc19hcmdzJGFyZ3M7DQogICAgICAgIH0NCiAgICAgICAgI2VuZGlmX3RhZyAke05HSU5YX0hBU19CTE9CX01PRFVMRX0NCg0KICAgICAgICAjaWZfdGFnICR7TkdJTlhfTk9UX1JPT1R9ICAgICAgDQogICAgICAgIGxvY2F0aW9uIC97DQogICAgICAgICAgICBwcm94eV9odHRwX3ZlcnNpb24gMS4xOw0KICAgICAgICAgICAgcmVzb2x2ZXIgMTI3LjAuMC4xMTsNCiAgICAgICAgICAgIHNldCAkYmFja2VuZCAiaHR0cHM6Ly8ke0dBVEVXQVlfSE9TVE5BTUV9OjQ0MyI7DQogICAgICAgICAgICBwcm94eV9wYXNzICAgICAgICAgICRiYWNrZW5kLyQxJGlzX2FyZ3MkYXJnczsNCiAgICAgICAgfQ0KICAgICAgICAjZW5kaWZfdGFnICR7TkdJTlhfTk9UX1JPT1R9DQogICAgfQ0KfQ=="; @@ -284,7 +261,7 @@ mod tests { #[test] fn env_var_tests() { //unset all variables - std::env::set_var(PROXY_CONFIG_ENV_VAR_LIST, "NGINX_DEFAULT_PORT,DOCKER_REQUEST_ROUTE_ADDRESS,NGINX_HAS_BLOB_MODULE,GATEWAY_HOSTNAME,NGINX_NOT_ROOT"); + std::env::set_var(PROXY_CONFIG_ENV_VAR_LIST, "NGINX_DEFAULT_PORT,DOCKER_REQUEST_ROUTE_ADDRESS,NGINX_HAS_BLOB_MODULE,GATEWAY_HOSTNAME,NGINX_NOT_ROOT,IOTEDGE_PARENTAPIPROXYNAME"); let vars_list = PROXY_CONFIG_DEFAULT_VARS_LIST.split(','); for key in vars_list { std::env::remove_var(key); @@ -355,5 +332,42 @@ mod tests { let config = get_parsed_config(dummy_config).unwrap(); assert_eq!("\r\n#if_tag IOTEDGE_PARENTHOSTNAME\r\nshould not be removed\r\n#endif_tag IOTEDGE_PARENTHOSTNAME", config); + + //*************************** Check IOTEDGE_PARENTAPIPROXYNAME defaults to module id if omitted ******************* + let vars_list = PROXY_CONFIG_DEFAULT_VARS_LIST.split(','); + for key in vars_list { + std::env::remove_var(key); + } + std::env::set_var("IOTEDGE_MODULEID", "apiproxy"); + + set_default_env_vars(); + //Check variable has been assigned the module id env var + let var = std::env::var("IOTEDGE_PARENTAPIPROXYNAME").unwrap(); + assert_eq!("IOTEDGE_MODULEID", var); + + 
dereference_env_variable(); + + specific_handling_env_var(); + + let dummy_config = "${IOTEDGE_PARENTAPIPROXYNAME}"; + + let config = get_parsed_config(dummy_config).unwrap(); + + assert_eq!("apiproxy", config); + + //*************************** Check IOTEDGE_PARENTAPIPROXYNAME get sanitized ******************* + let vars_list = PROXY_CONFIG_DEFAULT_VARS_LIST.split(','); + for key in vars_list { + std::env::remove_var(key); + } + std::env::set_var("IOTEDGE_PARENTAPIPROXYNAME", "iotedge_api_proxy"); + set_default_env_vars(); + dereference_env_variable(); + specific_handling_env_var(); + let dummy_config = "${IOTEDGE_PARENTAPIPROXYNAME}"; + + let config = get_parsed_config(dummy_config).unwrap(); + + assert_eq!("iotedgeapiproxy", config); } } diff --git a/edge-modules/api-proxy-module/templates/nginx_default_config.conf b/edge-modules/api-proxy-module/templates/nginx_default_config.conf index a057c65eaac..695f9e3f3f2 100644 --- a/edge-modules/api-proxy-module/templates/nginx_default_config.conf +++ b/edge-modules/api-proxy-module/templates/nginx_default_config.conf @@ -49,7 +49,8 @@ http { location ~^/registry/(.*) { proxy_http_version 1.1; resolver 127.0.0.11; - proxy_pass http://${DOCKER_REQUEST_ROUTE_ADDRESS}/$1$is_args$args; + set $upstream_endpoint http://${DOCKER_REQUEST_ROUTE_ADDRESS}/$1$is_args$args; + proxy_pass $upstream_endpoint; } #endif_tag ${DOCKER_REQUEST_ROUTE_ADDRESS} @@ -57,7 +58,8 @@ http { location ~^/storage/(.*){ resolver 127.0.0.11; proxy_http_version 1.1; - proxy_pass http://${BLOB_UPLOAD_ROUTE_ADDRESS}/$1$is_args$args; + set $upstream_endpoint http://${BLOB_UPLOAD_ROUTE_ADDRESS}/$1$is_args$args; + proxy_pass $upstream_endpoint; } #endif_tag ${BLOB_UPLOAD_ROUTE_ADDRESS} @@ -67,6 +69,8 @@ http { resolver 127.0.0.11; #proxy_ssl_certificate identity.crt; #proxy_ssl_certificate_key private_key_identity.pem; + proxy_ssl_server_name on; + proxy_ssl_name ${IOTEDGE_PARENTAPIPROXYNAME}; proxy_ssl_trusted_certificate trustedCA.crt; proxy_ssl_verify_depth 7; proxy_ssl_verify on; @@ -75,10 +79,12 @@ http { #endif_tag ${IOTEDGE_PARENTHOSTNAME} location ~^/devices|twins/ { + resolver 127.0.0.11; proxy_http_version 1.1; proxy_ssl_verify off; - proxy_set_header x-ms-edge-clientcert $ssl_client_escaped_cert; - proxy_pass https://edgeHub; + proxy_set_header x-ms-edge-clientcert $ssl_client_escaped_cert; + set $upstream_endpoint https://edgeHub; + proxy_pass $upstream_endpoint; } } } \ No newline at end of file diff --git a/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm32v7/Dockerfile b/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm32v7/Dockerfile index 1e94f407286..d3d793075e8 100644 --- a/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm32v7/Dockerfile +++ b/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm64v8/Dockerfile b/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm64v8/Dockerfile index a0922082a43..7423308fd8c 100644 --- a/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm64v8/Dockerfile +++ b/edge-modules/iotedge-diagnostics-dotnet/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. 
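For reference, the sanitizer exercised by the config_monitor tests above is small enough to reproduce standalone. This copy of `sanitize_dns_label` restores the generic parameter on `collect` as `String` (an assumption inferred from the function's return type, since the diff text drops generics) and demonstrates the RFC 1035 behavior the sanitization test asserts:

```rust
const ALLOWED_CHAR_DNS: char = '-';
const DNS_MAX_SIZE: usize = 63;

// Standalone copy of sanitize_dns_label from config_monitor.rs: lowercase,
// keep only ASCII alphanumerics and '-', start with a letter, end with an
// alphanumeric, and cap the label at 63 characters (RFC 1035).
fn sanitize_dns_label(name: &str) -> String {
    name.trim_start_matches(|c: char| !c.is_ascii_alphabetic())
        .trim_end_matches(|c: char| !c.is_ascii_alphanumeric())
        .to_lowercase()
        .chars()
        .filter(|c| c.is_ascii_alphanumeric() || c == &ALLOWED_CHAR_DNS)
        .take(DNS_MAX_SIZE)
        .collect::<String>()
}

fn main() {
    // Matches the sanitization test above: '_' is dropped, case is folded.
    assert_eq!(sanitize_dns_label("iotedge_api_proxy"), "iotedgeapiproxy");
    // Leading non-letters and trailing non-alphanumerics are trimmed.
    assert_eq!(sanitize_dns_label("9-MyProxy-"), "myproxy");
}
```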
diff --git a/edge-modules/iotedge-diagnostics-dotnet/src/Program.cs b/edge-modules/iotedge-diagnostics-dotnet/src/Program.cs index 851619fb322..fba86d6a62b 100644 --- a/edge-modules/iotedge-diagnostics-dotnet/src/Program.cs +++ b/edge-modules/iotedge-diagnostics-dotnet/src/Program.cs @@ -2,24 +2,14 @@ namespace Diagnostics { using System; - using System.IO; using System.Linq; using System.Net; using System.Net.Http; using System.Net.Sockets; - using System.Text; - using System.Threading; using System.Threading.Tasks; - using Microsoft.Azure.Devices.Client; - using Microsoft.Azure.Devices.Client.Transport.Mqtt; using Microsoft.Azure.Devices.Edge.Util; - using Microsoft.Azure.Devices.Edge.Util.Concurrency; - using Microsoft.Azure.Devices.Edge.Util.TransientFaultHandling; - using Microsoft.Azure.Devices.Shared; using Microsoft.Extensions.Configuration; - using Newtonsoft.Json; using ProxyLib.Proxy; - using ProxyLib.Proxy.Exceptions; class Program { @@ -62,47 +52,75 @@ static async Task MainAsync(string[] args) static async Task EdgeAgent(string managementUri) { - string modules; if (managementUri.EndsWith(".sock")) { - modules = GetSocket.GetSocketResponse(managementUri, "/modules/?api-version=2018-06-28"); + string response = GetSocket.GetSocketResponse(managementUri.TrimEnd('/'), "/modules/?api-version=2018-06-28"); + + if (!response.StartsWith("HTTP/1.1 200 OK")) + { + throw new Exception($"Got bad response: {response}"); + } } else { using (var http = new HttpClient()) - using (var response = await http.GetAsync(managementUri + "/modules/?api-version=2018-06-28")) + using (var response = await http.GetAsync(managementUri.TrimEnd('/') + "/modules/?api-version=2018-06-28")) { response.EnsureSuccessStatusCode(); - modules = await response.Content.ReadAsStringAsync(); } } - - if (!modules.StartsWith("HTTP/1.1 200 OK")) - { - throw new Exception($"Got bad response: {modules}"); - } } static async Task Upstream(string hostname, string port, string proxy) { - if (proxy != null) + if (port == "443") { - Uri proxyUri = new Uri(proxy); - IProxyClient proxyClient = MakeProxy(proxyUri); + var httpClientHandler = new HttpClientHandler(); + httpClientHandler.ServerCertificateCustomValidationCallback = (message, cert, chain, sslPolicyErrors) => + { + return true; // Is valid + }; + + if (proxy != null) + { + Environment.SetEnvironmentVariable("https_proxy", proxy); + } - // Setup timeouts - proxyClient.ReceiveTimeout = (int)TimeSpan.FromSeconds(60).TotalMilliseconds; - proxyClient.SendTimeout = (int)TimeSpan.FromSeconds(60).TotalMilliseconds; + var httpClient = new HttpClient(httpClientHandler); + var logsUrl = string.Format("https://{0}/devices/0000/modules", hostname); + var httpRequest = new HttpRequestMessage(HttpMethod.Get, logsUrl); + HttpResponseMessage httpResponseMessage = await httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead); - // Get TcpClient to futher work - var client = proxyClient.CreateConnection(hostname, int.Parse(port)); - client.GetStream(); + var keys = httpResponseMessage.Headers.GetValues("iothub-errorcode"); + if (!keys.Contains("InvalidProtocolVersion")) + { + throw new Exception($"Wrong value for iothub-errorcode header"); + } } else { - TcpClient client = new TcpClient(); - await client.ConnectAsync(hostname, int.Parse(port)); - client.GetStream(); + // The current Rust code never passes the proxy parameter when the port is not 443. + // So the code below is never exercised. 
It was put there to avoid silently ignoring the proxy + // if the Rust code changes. + if (proxy != null) + { + Uri proxyUri = new Uri(proxy); + IProxyClient proxyClient = MakeProxy(proxyUri); + + // Setup timeouts + proxyClient.ReceiveTimeout = (int)TimeSpan.FromSeconds(60).TotalMilliseconds; + proxyClient.SendTimeout = (int)TimeSpan.FromSeconds(60).TotalMilliseconds; + + // Get a TcpClient for further work + var client = proxyClient.CreateConnection(hostname, int.Parse(port)); + client.GetStream(); + } + else + { + TcpClient client = new TcpClient(); + await client.ConnectAsync(hostname, int.Parse(port)); + client.GetStream(); + } } } diff --git a/edgelet/iotedge/src/check/checks/connect_management_uri.rs b/edgelet/iotedge/src/check/checks/connect_management_uri.rs index 7ef043a4e38..3ba5faa09a2 100644 --- a/edgelet/iotedge/src/check/checks/connect_management_uri.rs +++ b/edgelet/iotedge/src/check/checks/connect_management_uri.rs @@ -52,7 +52,7 @@ impl ConnectManagementUri { |upstream_hostname| upstream_hostname.to_string() + &check.diagnostics_image_name, ) } else { - return Ok(CheckResult::Skipped); + check.diagnostics_image_name.clone() }; let connect_management_uri = settings.connect().management_uri(); diff --git a/edgelet/iotedge/src/check/checks/container_connect_upstream.rs b/edgelet/iotedge/src/check/checks/container_connect_upstream.rs index a5413bf201c..971e4596001 100644 --- a/edgelet/iotedge/src/check/checks/container_connect_upstream.rs +++ b/edgelet/iotedge/src/check/checks/container_connect_upstream.rs @@ -101,7 +101,7 @@ impl ContainerConnectUpstream { |upstream_hostname| upstream_hostname.to_string() + &check.diagnostics_image_name, ) } else { - return Ok(CheckResult::Skipped); + check.diagnostics_image_name.clone() }; let parent_hostname: String; @@ -139,14 +139,16 @@ impl ContainerConnectUpstream { &port, ]); - let proxy = settings - .agent() - .env() - .get("https_proxy") - .map(std::string::String::as_str); - self.proxy = proxy.map(ToOwned::to_owned); - if let Some(proxy) = proxy { - args.extend(&["--proxy", proxy]); + if &port == "443" { + let proxy = settings + .agent() + .env() + .get("https_proxy") + .map(std::string::String::as_str); + self.proxy = proxy.map(ToOwned::to_owned); + if let Some(proxy) = proxy { + args.extend(&["--proxy", proxy]); + } } if let Err((_, err)) = super::docker(docker_host_arg, args) { diff --git a/edgelet/iotedge/src/check/checks/container_local_time.rs b/edgelet/iotedge/src/check/checks/container_local_time.rs index 089b04f7f0a..fafd0796f60 100644 --- a/edgelet/iotedge/src/check/checks/container_local_time.rs +++ b/edgelet/iotedge/src/check/checks/container_local_time.rs @@ -50,7 +50,7 @@ impl ContainerLocalTime { |upstream_hostname| upstream_hostname.to_string() + &check.diagnostics_image_name, ) } else { - return Ok(CheckResult::Skipped); + check.diagnostics_image_name.clone() }; let output = super::docker( diff --git a/edgelet/iotedge/src/check/checks/container_resolve_parent_hostname.rs b/edgelet/iotedge/src/check/checks/container_resolve_parent_hostname.rs index 9e8fe84daaa..8cc44814dea 100644 --- a/edgelet/iotedge/src/check/checks/container_resolve_parent_hostname.rs +++ b/edgelet/iotedge/src/check/checks/container_resolve_parent_hostname.rs @@ -52,7 +52,7 @@ impl ContainerResolveParentHostname { |upstream_hostname| upstream_hostname.to_string() + &check.diagnostics_image_name, ) } else { - return Ok(CheckResult::Skipped); + check.diagnostics_image_name.clone() }; let docker_host_arg = if let Some(docker_host_arg) = 
&check.docker_host_arg { diff --git a/edgelet/iotedge/src/check/checks/host_connect_upstream.rs b/edgelet/iotedge/src/check/checks/host_connect_upstream.rs index f58ac40e781..8b1b14bd6e7 100644 --- a/edgelet/iotedge/src/check/checks/host_connect_upstream.rs +++ b/edgelet/iotedge/src/check/checks/host_connect_upstream.rs @@ -75,10 +75,14 @@ impl HostConnectUpstream { self.upstream_hostname = Some(upstream_hostname.clone()); - self.proxy = check - .settings - .as_ref() - .and_then(|settings| settings.agent().env().get("https_proxy").cloned()); + self.proxy = if self.port_number == 443 { + check + .settings + .as_ref() + .and_then(|settings| settings.agent().env().get("https_proxy").cloned()) + } else { + None + }; if let Some(proxy) = &self.proxy { runtime.block_on( diff --git a/edgelet/iotedge/src/check/checks/pull_agent_from_upstream.rs b/edgelet/iotedge/src/check/checks/pull_agent_from_upstream.rs index a87120debff..6f0af1e30f8 100644 --- a/edgelet/iotedge/src/check/checks/pull_agent_from_upstream.rs +++ b/edgelet/iotedge/src/check/checks/pull_agent_from_upstream.rs @@ -36,6 +36,38 @@ impl PullAgentFromUpstream { return Ok(CheckResult::Skipped); }; + if let (Some(username), Some(password), Some(server_address)) = ( + &settings + .agent() + .config() + .auth() + .and_then(docker::models::AuthConfig::username), + &settings + .agent() + .config() + .auth() + .and_then(docker::models::AuthConfig::password), + &settings + .agent() + .config() + .auth() + .and_then(docker::models::AuthConfig::serveraddress), + ) { + super::docker( + docker_host_arg, + vec![ + "login", + server_address, + "-p", + password, + "--username", + username, + ], + ) + .map_err(|(_, err)| err) + .context(format!("Failed to login to {}", server_address))?; + } + super::docker( docker_host_arg, vec!["pull", &settings.agent().config().image()], diff --git a/mqtt/Cargo.lock b/mqtt/Cargo.lock index dd9e92d2d69..130d2bdaf99 100644 --- a/mqtt/Cargo.lock +++ b/mqtt/Cargo.lock @@ -8,9 +8,9 @@ checksum = "ee2a4ec343196209d6594e19543ae87a39f96d5534d7174822a3ad825dd6ed7e" [[package]] name = "aho-corasick" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86" +checksum = "b476ce7103678b0c6d3d395dbbae31d48ff910bd28be979ba5d48c6351131d0d" dependencies = [ "memchr", ] @@ -26,9 +26,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b602bfe940d21c130f3895acd65221e8a61270debe89d628b9cb4e3ccb8569b" +checksum = "a1fd36ffbb1fb7c834eac128ea8d0e310c5aeb635548f9d58861e1308d46e71c" [[package]] name = "arc-swap" @@ -55,15 +55,15 @@ dependencies = [ [[package]] name = "assert_matches" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7deb0a829ca7bcfaf5da70b073a8d128619259a7be8216a355e23f00763059e5" +checksum = "695579f0f2520f3774bb40461e5adb066459d4e0af4d59d20175484fb8e9edf1" [[package]] name = "async-trait" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "687c230d85c0a52504709705fc8a53e4a692b83a2184f03dae73e38e1e93a783" +checksum = "b246867b8b3b6ae56035f1eb1ed557c1d8eae97f0d53696138a50fa0e3a3b8c0" dependencies = [ "proc-macro2", "quote", @@ -161,9 +161,9 @@ dependencies = [ [[package]] name = "bstr" -version = "0.2.13" +version = "0.2.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "31accafdb70df7871592c058eca3985b71104e15ac32f64706022c58867da931" +checksum = "473fc6b38233f9af7baa94fb5852dca389e3d95b8e21c8e3719301462c5d9faf" dependencies = [ "lazy_static", "memchr", @@ -200,9 +200,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef611cc68ff783f18535d77ddd080185275713d852c4f5cbb6122c462a7a825c" +checksum = "ed67cbde08356238e75fc4656be4749481eeffb09e19f320a25237d5221c985d" [[package]] name = "cfg-if" @@ -210,15 +210,23 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "chrono" -version = "0.4.15" +version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942f72db697d8767c22d46a598e01f2d3b475501ea43d0db4f16d90259182d0b" +checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" dependencies = [ + "libc", "num-integer", "num-traits", "time", + "winapi 0.3.9", ] [[package]] @@ -277,6 +285,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "const_fn" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce90df4c658c62f12d78f7508cf92f9173e5184a539c10bfe54a3107b3ffd0f2" + [[package]] name = "core-foundation" version = "0.7.0" @@ -295,11 +309,11 @@ checksum = "b3a71ab494c0b5b860bdc8407ae08978052417070c2ced38573a9157ad75b8ac" [[package]] name = "crc32fast" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba125de2af0df55319f41944744ad91c71113bf74a4646efff39afe1f6842db1" +checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", ] [[package]] @@ -340,48 +354,48 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.4.4" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87" +checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775" dependencies = [ + "cfg-if 1.0.0", "crossbeam-utils", - "maybe-uninit", ] [[package]] name = "crossbeam-deque" -version = "0.7.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" dependencies = [ + "cfg-if 1.0.0", "crossbeam-epoch", "crossbeam-utils", - "maybe-uninit", ] [[package]] name = "crossbeam-epoch" -version = "0.8.2" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace" +checksum = "ec0f606a85340376eef0d6d8fec399e6d4a544d648386c6645eb6d0653b27d9f" dependencies = [ - "autocfg 1.0.1", - "cfg-if", + "cfg-if 1.0.0", + "const_fn", "crossbeam-utils", "lazy_static", - "maybe-uninit", "memoffset", "scopeguard", ] [[package]] name = "crossbeam-utils" -version = "0.7.2" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8" +checksum = "ec91540d98355f690a86367e566ecad2e9e579f230230eb7c21398372be73ea5" dependencies = [ "autocfg 1.0.1", - "cfg-if", + "cfg-if 1.0.0", + "const_fn", "lazy_static", ] @@ -495,11 +509,11 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "766d0e77a2c1502169d4a93ff3b8c15a71fd946cd0126309752104e5f3c46d94" +checksum = "da80be589a72651dcda34d8b35bcdc9b7254ad06325611074d9cc0fbb19f60ee" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "crc32fast", "libc", "miniz_oxide", @@ -565,9 +579,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" [[package]] name = "futures" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e05b85ec287aac0dc34db7d4a569323df697f9c55b99b15d6b4ef8cde49f613" +checksum = "5d8e3078b7b2a8a671cb7a3d17b4760e4181ea243227776ba83fd043b4ca034e" dependencies = [ "futures-channel", "futures-core", @@ -580,9 +594,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f366ad74c28cca6ba456d95e6422883cfb4b252a83bed929c83abfdbbf2967d5" +checksum = "a7a4d35f7401e948629c9c3d6638fb9bf94e0b2121e96c3b428cc4e631f3eb74" dependencies = [ "futures-core", "futures-sink", @@ -590,15 +604,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59f5fff90fd5d971f936ad674802482ba441b6f09ba5e15fd8b39145582ca399" +checksum = "d674eaa0056896d5ada519900dbf97ead2e46a7b6621e8160d79e2f2e1e2784b" [[package]] name = "futures-executor" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10d6bb888be1153d3abeb9006b11b02cf5e9b209fda28693c31ae1e4e012e314" +checksum = "cc709ca1da6f66143b8c9bec8e6260181869893714e9b5a490b169b0414144ab" dependencies = [ "futures-core", "futures-task", @@ -607,15 +621,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de27142b013a8e869c14957e6d2edeef89e97c289e69d042ee3a49acd8b51789" +checksum = "5fc94b64bb39543b4e432f1790b6bf18e3ee3b74653c5449f63310e9a74b123c" [[package]] name = "futures-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0b5a30a4328ab5473878237c447333c093297bded83a4983d10f4deea240d39" +checksum = "f57ed14da4603b2554682e9f2ff3c65d7567b53188db96cb71538217fc64581b" dependencies = [ "proc-macro-hack", "proc-macro2", @@ -625,24 +639,24 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f2032893cb734c7a05d85ce0cc8b8c4075278e93b24b66f9de99d6eb0fa8acc" +checksum = "0d8764258ed64ebc5d9ed185cf86a95db5cac810269c5d20ececb32e0088abbd" [[package]] name = "futures-task" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb66b5f09e22019b1ab0830f7785bcea8e7a42148683f99214f73f8ec21a626" +checksum = "4dd26820a9f3637f1302da8bceba3ff33adbe53464b54ca24d4e2d4f1db30f94" dependencies = [ "once_cell", ] [[package]] name = "futures-util" -version = "0.3.5" 
+version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8764574ff08b701a084482c3c7031349104b07ac897393010494beaa18ce32c6" +checksum = "8a894a0acddba51a2d49a6f4263b1e64b8c579ece8af50fa86503d52cd1eea34" dependencies = [ "futures-channel", "futures-core", @@ -664,7 +678,7 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "wasi 0.9.0+wasi-snapshot-preview1", ] @@ -696,9 +710,9 @@ checksum = "d36fab90f82edc3c747f9d438e06cf0a491055896f2a279638bb5beed6c40177" [[package]] name = "hashbrown" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00d63df3d41950fb462ed38308eea019113ad1508da725bbedcd0fa5a85ef5f7" +checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" [[package]] name = "heck" @@ -711,9 +725,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c30f6d0bc6b00693347368a67d41b58f2fb851215ff1da49e90fe2c5c667151" +checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8" dependencies = [ "libc", ] @@ -846,7 +860,7 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63312a18f7ea8760cdd0a7c5aac1a619752a246b833545e3e36d1f81f7cd9e66" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", ] [[package]] @@ -906,16 +920,16 @@ checksum = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616" dependencies = [ "arrayvec", "bitflags", - "cfg-if", + "cfg-if 0.1.10", "ryu", "static_assertions", ] [[package]] name = "libc" -version = "0.2.77" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f96b10ec2560088a8e76961b00d47107b3a625fecb76dedb29ee7ccbf98235" +checksum = "2448f6066e80e3bfc792e9c98bf705b4b0fc6e8ef5b43e5889aff0eaa9c58743" [[package]] name = "linked-hash-map" @@ -947,7 +961,7 @@ version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", ] [[package]] @@ -988,9 +1002,9 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c60c0dfe32c10b43a144bad8fc83538c52f58302c92300ea7ec7bf7b38d5a7b9" +checksum = "0f2d26ec3309788e423cfbf68ad1800f061638098d76a83681af979dc4eda19d" dependencies = [ "adler", "autocfg 1.0.1", @@ -1002,7 +1016,7 @@ version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fce347092656428bc8eaf6201042cb551b8d67855af7374542a92a0fbfcac430" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "fuchsia-zircon", "fuchsia-zircon-sys", "iovec", @@ -1040,11 +1054,11 @@ dependencies = [ [[package]] name = "mockall" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf52bd480d59ec342893c9c64ace644b4bbb1e184f8217312f0282107a372e4d" +checksum = "41cabea45a7fc0e37093f4f30a5e2b62602253f91791c057d5f0470c63260c3d" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "downcast", "fragile", "lazy_static", @@ -1055,11 +1069,11 @@ dependencies = [ [[package]] name = "mockall_derive" -version = "0.8.2" +version = "0.8.3" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f060c7e8d81fa8cf7bfd4a2cc183402d1066c9cba56998e2807b109d8c0ec" +checksum = "7c461918bf7f59eefb1459252756bf2351a995d6bd510d0b2061bd86bcdabfa6" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "proc-macro2", "quote", "syn", @@ -1100,6 +1114,7 @@ dependencies = [ "humantime-serde", "lazy_static", "matches", + "mockall", "mqtt-broker", "mqtt-broker-tests-util", "mqtt3", @@ -1110,6 +1125,7 @@ dependencies = [ "serde", "serde_bytes", "serde_derive", + "serde_json", "serial_test", "tempfile", "test-case", @@ -1238,6 +1254,7 @@ dependencies = [ "mqtt-broker", "mqtt3", "policy", + "proptest", "test-case", "thiserror", "tracing", @@ -1254,7 +1271,6 @@ dependencies = [ "futures-sink", "futures-util", "log", - "mockall", "serde", "structopt", "tokio", @@ -1266,7 +1282,8 @@ name = "mqttd" version = "0.1.0" dependencies = [ "anyhow", - "atty", + "async-trait", + "cfg-if 1.0.0", "chrono", "clap", "edgelet-client", @@ -1280,6 +1297,7 @@ dependencies = [ "thiserror", "tokio", "tracing", + "tracing-log", "tracing-subscriber", ] @@ -1307,7 +1325,7 @@ version = "0.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ebc3ec692ed7c9a255596c67808dee269f64655d8baf7b4f0638e51ba1d6853" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "winapi 0.3.9", ] @@ -1377,7 +1395,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d575eff3665419f9b83678ff2815858ad9d11567e082f5ac1814baba4e2bcb4" dependencies = [ "bitflags", - "cfg-if", + "cfg-if 0.1.10", "foreign-types", "lazy_static", "libc", @@ -1439,7 +1457,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d58c7c768d4ba344e3e8d72518ac13e259d7c7ade24167003b8488e10b6740a3" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "cloudabi 0.0.3", "libc", "redox_syscall", @@ -1453,7 +1471,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c361aa727dd08437f2f1447be8b59a33b0edd15e0fcee698f935613d9efbca9b" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "cloudabi 0.1.0", "instant", "libc", @@ -1476,18 +1494,18 @@ checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" [[package]] name = "pin-project" -version = "0.4.23" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca4433fff2ae79342e497d9f8ee990d174071408f28f726d6d83af93e58e48aa" +checksum = "2ffbc8e94b38ea3d2d8ba92aea2983b503cd75d0888d75b86bb37970b5698e15" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "0.4.23" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c0e815c3ee9a031fdf5af21c10aa17c573c9c6a566328d99e3936c34e36461f" +checksum = "65ad2ae56b6abe3a1ee25f15ee605bacadb9a764edaba9c2bf4103800d4a1895" dependencies = [ "proc-macro2", "quote", @@ -1496,9 +1514,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.1.7" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282adbf10f2698a7a77f8e983a74b2d18176c19a7fd32a45446139ae7b02b715" +checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b" [[package]] name = "pin-utils" @@ -1508,9 +1526,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.18" +version = "0.3.19" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d36492546b6af1463394d46f0c834346f31548646f6ba10849802c9c9a27ac33" +checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "plotters" @@ -1528,8 +1546,10 @@ dependencies = [ name = "policy" version = "0.1.0" dependencies = [ + "itertools", "lazy_static", "matches", + "proptest", "regex", "serde", "serde_json", @@ -1609,9 +1629,9 @@ checksum = "eba180dafb9038b050a4c280019bbedf9f2467b61e5d892dcad585bb57aadc5a" [[package]] name = "proc-macro2" -version = "1.0.21" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36e28516df94f3dd551a587da5357459d9b36d945a7c37c3557928c1c2ff2a2c" +checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" dependencies = [ "unicode-xid", ] @@ -1800,9 +1820,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfd016f0c045ad38b5251be2c9c0ab806917f82da4d36b2a327e5166adad9270" +checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674" dependencies = [ "autocfg 1.0.1", "crossbeam-deque", @@ -1812,9 +1832,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c4fec834fb6e6d2dd5eece3c7b432a52f0ba887cf40e595190c4107edc08bf" +checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -1840,9 +1860,9 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "regex" -version = "1.3.9" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3780fcf44b193bc4d09f36d2a3c87b251da4a046c87795a0d35f4f927ad8e6" +checksum = "8963b85b8ce3074fecffde43b4b0dded83ce2f367dc8d363afc56679f3ee820b" dependencies = [ "aho-corasick", "memchr", @@ -1862,9 +1882,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" +checksum = "8cab7a364d15cde1e505267766a2d3c4e22a843e1a601f0fa7564c0f82ced11c" [[package]] name = "remove_dir_all" @@ -1967,9 +1987,9 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.116" +version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96fe57af81d28386a513cbc6858332abc6117cfdb5999647c6444b8f43a370a5" +checksum = "b88fa983de7720629c9387e9f517353ed404164b1e482c970a90c1a4aaf7dc1a" dependencies = [ "serde_derive", ] @@ -1995,9 +2015,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.116" +version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8" +checksum = "cbd1ae72adb44aab48f325a02444a5fc079349a8d804c1fc922aed3f7454c74e" dependencies = [ "proc-macro2", "quote", @@ -2006,9 +2026,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.57" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c" +checksum = 
"dcac07dbffa1c65e7f816ab9eba78eb142c6d44410f4eeba1e26e4f5dfa56b95" dependencies = [ "indexmap", "itoa", @@ -2098,7 +2118,7 @@ version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1fa70dc5c8104ec096f4fe7ede7a221d35ae13dcd19ba1ad9a81d2cab9a1c44" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "redox_syscall", "winapi 0.3.9", @@ -2124,9 +2144,9 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "structopt" -version = "0.3.17" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6cc388d94ffabf39b5ed5fadddc40147cb21e605f53db6f8f36a625d27489ac5" +checksum = "126d630294ec449fae0b16f964e35bf3c74f940da9dca17ee9b905f7b3112eb8" dependencies = [ "clap", "lazy_static", @@ -2135,9 +2155,9 @@ dependencies = [ [[package]] name = "structopt-derive" -version = "0.4.10" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e2513111825077552a6751dfad9e11ce0fba07d7276a3943a037d7e93e64c5f" +checksum = "65e51c492f9e23a220534971ff5afc14037289de430e3c83f9daf6a1b6ae91e8" dependencies = [ "heck", "proc-macro-error", @@ -2148,9 +2168,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.41" +version = "1.0.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6690e3e9f692504b941dc6c3b188fd28df054f7fb8469ab40680df52fdcc842b" +checksum = "5ad5de3220ea04da322618ded2c42233d02baca219d6f160a3e9c87cda16c942" dependencies = [ "proc-macro2", "quote", @@ -2163,7 +2183,7 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "libc", "rand 0.7.3", "redox_syscall", @@ -2203,18 +2223,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dfdd070ccd8ccb78f4ad66bf1982dc37f620ef696c6b5028fe2ed83dd3d0d08" +checksum = "318234ffa22e0920fe9a40d7b8369b5f649d490980cf7aadcf1eb91594869b42" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd80fc12f73063ac132ac92aceea36734f04a1d93c1240c6944e23a3b8841793" +checksum = "cae2447b6282786c3493999f40a9be2a6ad20cb8bd268b0a0dbf5a065535c0ab" dependencies = [ "proc-macro2", "quote", @@ -2357,12 +2377,13 @@ checksum = "e987b6bf443f4b5b3b6f38704195592cca41c5bb7aedd3c3693c7081f8289860" [[package]] name = "tracing" -version = "0.1.19" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d79ca061b032d6ce30c660fded31189ca0b9922bf483cd70759f13a2d86786c" +checksum = "b0987850db3733619253fe60e17cb59b82d37c7e6c0236bb81e4d6b87c879f27" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "log", + "pin-project-lite", "tracing-attributes", "tracing-core", ] @@ -2380,9 +2401,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bcf46c1f1f06aeea2d6b81f3c863d0930a596c86ad1920d4e5bad6dd1d7119a" +checksum = "f50de3927f93d202783f4513cda820ab47ef17f624b03c096e86ef00c67e6b5f" dependencies = [ "lazy_static", ] @@ -2559,7 +2580,7 @@ version = "0.2.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"1ac64ead5ea5f05873d7c12b545865ca2b8d28adfc50a49b84770a3a97265d42" dependencies = [ - "cfg-if", + "cfg-if 0.1.10", "wasm-bindgen-macro", ] diff --git a/mqtt/build/linux/test.sh b/mqtt/build/linux/test.sh index e708309a408..bcc3bed71d6 100755 --- a/mqtt/build/linux/test.sh +++ b/mqtt/build/linux/test.sh @@ -28,7 +28,10 @@ usage() echo "" echo "options" echo " -h, --help Print this help and exit." + echo " -t, --target Target architecture." echo " -r, --release Release build? (flag, default: false)" + echo " -c, --cargo Path of cargo installation." + echo " --report Optional. Generates the xml test report with specified name." exit 1; } @@ -44,17 +47,18 @@ process_args() TARGET="$arg" save_next_arg=0 elif [ $save_next_arg -eq 2 ]; then - RELEASE="true" + CARGO="$arg" save_next_arg=0 elif [ $save_next_arg -eq 3 ]; then - CARGO="$arg" + REPORT="$arg" save_next_arg=0 else case "$arg" in "-h" | "--help" ) usage;; "-t" | "--target" ) save_next_arg=1;; - "-r" | "--release" ) save_next_arg=2;; - "-c" | "--cargo" ) save_next_arg=3;; + "-r" | "--release" ) RELEASE="--release";; + "-c" | "--cargo" ) save_next_arg=2;; + "--report" ) save_next_arg=3;; * ) usage;; esac fi @@ -63,8 +67,18 @@ process_args() process_args "$@" -if [[ -z ${RELEASE} ]]; then - cd "$PROJECT_ROOT" && $CARGO test --workspace --all-features --target "$TARGET" +if [[ -z ${REPORT} ]]; then + echo $CARGO test --no-fail-fast --workspace --all-features --target "$TARGET" "$RELEASE" + cd "$PROJECT_ROOT" && $CARGO test --no-fail-fast --workspace --all-features \ + --target "$TARGET" "$RELEASE" else - cd "$PROJECT_ROOT" && $CARGO test --workspace --all-features --release --target "$TARGET" + # Get cargo2junit to report test results to Azure Pipelines + $CARGO install cargo2junit + + cd "$PROJECT_ROOT" && $CARGO test --no-fail-fast --workspace --all-features \ + --target "$TARGET" "$RELEASE" \ + -- -Z unstable-options --format json | tee test-result.json + + # Convert test results to junit format. 
+ cat test-result.json | cargo2junit > "$REPORT" fi diff --git a/mqtt/edgelet-client/src/lib.rs b/mqtt/edgelet-client/src/lib.rs index be78531e71a..1daef5721e0 100644 --- a/mqtt/edgelet-client/src/lib.rs +++ b/mqtt/edgelet-client/src/lib.rs @@ -69,13 +69,13 @@ pub(crate) enum Scheme { #[derive(Debug, thiserror::Error)] pub enum ApiError { - #[error("could not construct request URL")] + #[error("could not construct URL")] ConstructRequestUrl(#[source] Box<dyn std::error::Error + Send + Sync>), #[error("could not construct request")] ConstructRequest(#[source] http::Error), - #[error("could not construct request")] + #[error("could not make HTTP request")] ExecuteRequest(#[source] hyper::Error), #[error("response has status code {0} and body {1}")] diff --git a/mqtt/mqtt-bridge/Cargo.toml b/mqtt/mqtt-bridge/Cargo.toml index 0b6407c2f07..aff3e1d9906 100644 --- a/mqtt/mqtt-bridge/Cargo.toml +++ b/mqtt/mqtt-bridge/Cargo.toml @@ -16,6 +16,7 @@ futures-util = "0.3" humantime = "2.0" humantime-serde = "1.0" lazy_static = "1.4" +mockall = "0.8" openssl = "0.10" parking_lot = "0.11" percent-encoding = "1.0" @@ -23,6 +24,7 @@ regex = "1" serde = { version = "1.0", features = ["derive", "rc"] } serde_bytes = "0.11" serde_derive = "1.0" +serde_json = "1.0" serial_test = "0.4" thiserror = "1.0" tokio = { version = "0.2", features = ["sync", "rt-util"] } diff --git a/mqtt/mqtt-bridge/src/bridge.rs b/mqtt/mqtt-bridge/src/bridge.rs index 6e901695ba0..1b65d12ad15 100644 --- a/mqtt/mqtt-bridge/src/bridge.rs +++ b/mqtt/mqtt-bridge/src/bridge.rs @@ -1,111 +1,183 @@ -use futures_util::{future::select, future::Either, pin_mut}; +use futures_util::{ + future::{self, Either}, + pin_mut, +}; use mqtt3::ShutdownError; -use tokio::sync::{mpsc::error::SendError, oneshot, oneshot::Sender}; use tracing::{debug, error, info, info_span}; use tracing_futures::Instrument; use crate::{ - client::ClientError, - persist::PersistError, - pump::{PumpMessage, PumpPair}, - rpc::RpcError, - settings::ConnectionSettings, + client::{ClientError, MqttClientConfig}, + config_update::BridgeDiff, + persist::{PersistError, PublicationStore, StreamWakeableState, WakingMemoryStore}, + pump::{Builder, Pump, PumpError, PumpHandle, PumpMessage}, + settings::{ConnectionSettings, Credentials}, + upstream::{ + ConnectivityError, LocalUpstreamMqttEventHandler, LocalUpstreamPumpEvent, + LocalUpstreamPumpEventHandler, RemoteUpstreamMqttEventHandler, RemoteUpstreamPumpEvent, + RemoteUpstreamPumpEventHandler, RpcError, + }, }; -#[derive(Debug)] -pub struct BridgeShutdownHandle { - local_shutdown: Sender<()>, - remote_shutdown: Sender<()>, +pub struct BridgeHandle { + local_pump_handle: PumpHandle<LocalUpstreamPumpEvent>, + remote_pump_handle: PumpHandle<RemoteUpstreamPumpEvent>, } -impl BridgeShutdownHandle { - // TODO: Remove when we implement bridge controller shutdown - #![allow(dead_code)] - pub async fn shutdown(self) -> Result<(), BridgeError> { - self.local_shutdown - .send(()) - .map_err(BridgeError::ShutdownBridge)?; - self.remote_shutdown - .send(()) - .map_err(BridgeError::ShutdownBridge)?; +impl BridgeHandle { + pub fn new( + local_pump_handle: PumpHandle<LocalUpstreamPumpEvent>, + remote_pump_handle: PumpHandle<RemoteUpstreamPumpEvent>, + ) -> Self { + Self { + local_pump_handle, + remote_pump_handle, + } + } + + pub async fn send_update(&mut self, message: BridgeDiff) -> Result<(), BridgeError> { + let (local_updates, remote_updates) = message.into_parts(); + + if local_updates.has_updates() { + debug!("sending update to local pump {:?}", local_updates); + self.local_pump_handle + .send(PumpMessage::ConfigurationUpdate(local_updates)) + .await?; + } + + if
remote_updates.has_updates() { + debug!("sending update to remote pump {:?}", remote_updates); + self.remote_pump_handle + .send(PumpMessage::ConfigurationUpdate(remote_updates)) + .await?; + } + Ok(()) } + + pub async fn shutdown(mut self) { + if let Err(e) = self.local_pump_handle.send(PumpMessage::Shutdown).await { + error!(error = %e, "unable to request shutdown for local pump"); + } + + if let Err(e) = self.remote_pump_handle.send(PumpMessage::Shutdown).await { + error!(error = %e, "unable to request shutdown for remote pump"); + } + } } /// Bridge implementation that connects to local broker and remote broker and handles message flow -pub struct Bridge { - pumps: PumpPair, - connection_settings: ConnectionSettings, +pub struct Bridge<S> { + local_pump: Pump<S, LocalUpstreamMqttEventHandler<S>, LocalUpstreamPumpEventHandler>, + remote_pump: Pump<S, RemoteUpstreamMqttEventHandler<S>, RemoteUpstreamPumpEventHandler>, } -impl Bridge { - pub async fn new( - system_address: String, - device_id: String, - connection_settings: ConnectionSettings, +impl Bridge<WakingMemoryStore> { + pub fn new_upstream( + system_address: &str, + device_id: &str, + settings: &ConnectionSettings, ) -> Result<Self, BridgeError> { - debug!("creating bridge {}...", connection_settings.name()); - - let mut pumps = PumpPair::new(&connection_settings, &system_address, &device_id)?; + const BATCH_SIZE: usize = 10; + + debug!("creating bridge {}...", settings.name()); + + let (local_pump, remote_pump) = Builder::default() + .with_local(|pump| { + pump.with_config(MqttClientConfig::new( + system_address, + settings.keep_alive(), + settings.clean_session(), + Credentials::Anonymous(format!("{}/{}/$bridge", device_id, settings.name())), + )) + .with_rules(settings.forwards()); + }) + .with_remote(|pump| { + pump.with_config(MqttClientConfig::new( + settings.address(), + settings.keep_alive(), + settings.clean_session(), + settings.credentials().clone(), + )) + .with_rules(settings.subscriptions()); + }) + .with_store(|| PublicationStore::new_memory(BATCH_SIZE)) + .build()?; + + debug!("created bridge {}...", settings.name()); - pumps - .local_pump - .subscribe() - .instrument(info_span!("pump", name = "local")) - .await?; - - pumps - .remote_pump - .subscribe() - .instrument(info_span!("pump", name = "remote")) - .await?; - - debug!("created {} bridge...", connection_settings.name()); Ok(Bridge { - pumps, - connection_settings, + local_pump, + remote_pump, }) } +} - pub async fn run(mut self) -> Result<(), BridgeError> { - info!("starting {} bridge...", self.connection_settings.name()); - - let (local_shutdown, local_shutdown_listener) = oneshot::channel::<()>(); - let (remote_shutdown, remote_shutdown_listener) = oneshot::channel::<()>(); - let shutdown_handle = BridgeShutdownHandle { - local_shutdown, - remote_shutdown, - }; +impl<S> Bridge<S> +where + S: StreamWakeableState + Send, +{ + pub async fn run(self) -> Result<(), BridgeError> { + info!("starting bridge..."); + let shutdown_local_pump = self.local_pump.handle(); let local_pump = self - .pumps .local_pump - .run(local_shutdown_listener) + .run() .instrument(info_span!("pump", name = "local")); + let shutdown_remote_pump = self.remote_pump.handle(); let remote_pump = self - .pumps .remote_pump - .run(remote_shutdown_listener) + .run() .instrument(info_span!("pump", name = "remote")); + + debug!("starting pumps..."); + pin_mut!(local_pump, remote_pump); - debug!( - "starting pumps for {} bridge...", - self.connection_settings.name() - ); - match select(local_pump, remote_pump).await { - Either::Left(_) => { - shutdown_handle.shutdown().await?; + match future::select(local_pump,
remote_pump).await { + Either::Left((local_pump, remote_pump)) => { + if let Err(e) = local_pump { + error!(error = %e, "local pump exited with error"); + } else { + info!("local pump exited"); + } + + debug!("shutting down remote pump..."); + shutdown_remote_pump.shutdown().await; + + if let Err(e) = remote_pump.await { + error!(error = %e, "remote pump exited with error"); + } else { + info!("remote pump exited"); + } } - Either::Right(_) => { - shutdown_handle.shutdown().await?; + Either::Right((remote_pump, local_pump)) => { + if let Err(e) = remote_pump { + error!(error = %e, "remote pump exited with error"); + } else { + info!("remote pump exited"); + } + + debug!("shutting down local pump..."); + shutdown_local_pump.shutdown().await; + + if let Err(e) = local_pump.await { + error!(error = %e, "local pump exited with error"); + } else { + info!("local pump exited"); + } } } - debug!("bridge {} stopped...", self.connection_settings.name()); + info!("bridge stopped"); Ok(()) } + + pub fn handle(&self) -> BridgeHandle { + BridgeHandle::new(self.local_pump.handle(), self.remote_pump.handle()) + } } /// Bridge error. @@ -123,18 +195,30 @@ pub enum BridgeError { #[error("failed to load settings.")] LoadingSettings(#[from] config::ConfigError), - #[error("failed to get send pump message.")] - SenderToPump(#[from] SendError<PumpMessage>), + #[error("failed to send message to pump.")] + SendToPump, + + #[error("failed to send bridge update to pump: {0}")] + SendBridgeUpdate(#[from] PumpError), #[error("failed to execute RPC command")] Rpc(#[from] RpcError), + #[error("failed to execute connectivity event")] + Connectivity(#[from] ConnectivityError), + #[error("failed to signal bridge shutdown.")] ShutdownBridge(()), #[error("failed to get publish handle from client.")] PublishHandle(#[source] ClientError), + #[error("failed to validate client settings: {0}")] + ValidationError(#[source] ClientError), + + #[error("failed to get subscribe handle from client.")] + UpdateSubscriptionHandle(#[source] ClientError), + #[error("failed to shutdown client.")] ClientShutdown(#[from] ShutdownError), } diff --git a/mqtt/mqtt-bridge/src/client.rs b/mqtt/mqtt-bridge/src/client.rs index 37ab3b9153b..715eeaca3c6 100644 --- a/mqtt/mqtt-bridge/src/client.rs +++ b/mqtt/mqtt-bridge/src/client.rs @@ -1,18 +1,18 @@ -#![allow(dead_code)] // TODO remove when ready -use std::{ - collections::HashSet, fmt::Display, io::Error, io::ErrorKind, pin::Pin, str, time::Duration, -}; +#![allow(clippy::default_trait_access)] // Needed because mock!
macro violates +use core::{convert::TryInto, num::TryFromIntError}; +use std::{fmt::Display, io::Error, io::ErrorKind, pin::Pin, str, time::Duration}; use async_trait::async_trait; use chrono::Utc; use futures_util::future::{self, BoxFuture}; +use mockall::automock; use openssl::{ssl::SslConnector, ssl::SslMethod, x509::X509}; use tokio::{io::AsyncRead, io::AsyncWrite, net::TcpStream, stream::StreamExt}; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info}; use mqtt3::{ - proto, Client, Event, IoSource, PublishHandle, ShutdownError, SubscriptionUpdateEvent, - UpdateSubscriptionError, + proto::{self, Publication, SubscribeTo}, + Client, Event, IoSource, ShutdownError, UpdateSubscriptionError, }; use crate::{ @@ -21,24 +21,12 @@ use crate::{ }; const DEFAULT_TOKEN_DURATION_MINS: i64 = 60; -const DEFAULT_MAX_RECONNECT: Duration = Duration::from_secs(5); +const DEFAULT_MAX_RECONNECT: Duration = Duration::from_secs(60); // TODO: get QOS from topic settings const DEFAULT_QOS: proto::QoS = proto::QoS::AtLeastOnce; +// TODO: read from env var const API_VERSION: &str = "2010-01-01"; -#[derive(Debug, Clone)] -pub struct ClientShutdownHandle(mqtt3::ShutdownHandle); - -impl ClientShutdownHandle { - pub async fn shutdown(&mut self) -> Result<(), ClientError> { - self.0 - .shutdown() - .await - .map_err(ClientError::ShutdownClient)?; - Ok(()) - } -} - #[derive(Clone)] enum BridgeIoSource { Tcp(TcpConnection), @@ -67,12 +55,12 @@ where T: TokenSource + Clone + Send + Sync + 'static, { pub fn new( - address: String, + address: impl Into, token_source: Option, trust_bundle_source: Option, ) -> Self { Self { - address, + address: address.into(), token_source, trust_bundle_source, } @@ -116,6 +104,13 @@ impl BridgeIoSource { Error::new(ErrorKind::Other, format!("failed to connect: {}", err)) })?; + if let Some(pass) = password.as_ref() { + validate_length(pass).map_err(|_| { + error!("password too long"); + ErrorKind::InvalidInput + })?; + } + let stream: Pin> = Box::pin(io); Ok((stream, password)) }) @@ -152,6 +147,13 @@ impl BridgeIoSource { Error::new(ErrorKind::Other, format!("failed to connect: {}", err)) })?; + if let Some(pass) = password.as_ref() { + validate_length(pass).map_err(|_| { + error!("password too long"); + ErrorKind::InvalidInput + })?; + } + let config = SslConnector::builder(SslMethod::tls()) .map(|mut builder| { if let Some(trust_bundle) = server_root_certificate { @@ -184,58 +186,65 @@ impl BridgeIoSource { } } -/// This is a wrapper over mqtt3 client -pub struct MqttClient -where - H: EventHandler, -{ - client_id: Option, - username: Option, - io_source: BridgeIoSource, +pub struct MqttClientConfig { + addr: String, keep_alive: Duration, - client: Client, - event_handler: H, + clean_session: bool, + credentials: Credentials, } -impl MqttClient { - pub fn tcp( - address: &str, +impl MqttClientConfig { + pub fn new( + addr: impl Into, keep_alive: Duration, clean_session: bool, - event_handler: H, - connection_credentials: &Credentials, + credentials: Credentials, ) -> Self { - let token_source = Self::token_source(&connection_credentials); - let tcp_connection = TcpConnection::new(address.to_owned(), token_source, None); + Self { + addr: addr.into(), + keep_alive, + clean_session, + credentials, + } + } +} + +/// This is a wrapper over mqtt3 client +pub struct MqttClient { + client: Client, + event_handler: H, +} + +impl MqttClient +where + H: MqttEventHandler, +{ + pub fn tcp(config: MqttClientConfig, event_handler: H) -> Result { + let token_source = 
Self::token_source(&config.credentials); + let tcp_connection = TcpConnection::new(config.addr, token_source, None); let io_source = BridgeIoSource::Tcp(tcp_connection); Self::new( - keep_alive, - clean_session, + config.keep_alive, + config.clean_session, event_handler, - connection_credentials, + &config.credentials, io_source, ) } - pub fn tls( - address: &str, - keep_alive: Duration, - clean_session: bool, - event_handler: H, - connection_credentials: &Credentials, - ) -> Self { - let trust_bundle = Some(TrustBundleSource::new(connection_credentials.clone())); + pub fn tls(config: MqttClientConfig, event_handler: H) -> Result { + let trust_bundle = Some(TrustBundleSource::new(config.credentials.clone())); - let token_source = Self::token_source(&connection_credentials); - let tcp_connection = TcpConnection::new(address.to_owned(), token_source, trust_bundle); + let token_source = Self::token_source(&config.credentials); + let tcp_connection = TcpConnection::new(config.addr, token_source, trust_bundle); let io_source = BridgeIoSource::Tls(tcp_connection); Self::new( - keep_alive, - clean_session, + config.keep_alive, + config.clean_session, event_handler, - connection_credentials, + &config.credentials, io_source, ) } @@ -246,7 +255,7 @@ impl MqttClient { event_handler: H, connection_credentials: &Credentials, io_source: BridgeIoSource, - ) -> Self { + ) -> Result { let (client_id, username) = match connection_credentials { Credentials::Provider(provider_settings) => ( format!( @@ -272,23 +281,42 @@ impl MqttClient { let client_id = if clean_session { None } else { Some(client_id) }; + Self::validate(client_id.as_ref(), username.as_ref(), &keep_alive)?; + let client = Client::new( - client_id.clone(), - username.clone(), + client_id, + username, None, - io_source.clone(), + io_source, DEFAULT_MAX_RECONNECT, keep_alive, ); - Self { - client_id, - username, - io_source, - keep_alive, + Ok(Self { client, event_handler, + }) + } + + fn validate( + client_id: Option<&String>, + username: Option<&String>, + keep_alive: &Duration, + ) -> Result<(), ClientError> { + if let Some(id) = client_id { + validate_length(id)?; + }; + + if let Some(name) = username { + validate_length(name)?; } + + let _: u16 = keep_alive + .as_secs() + .try_into() + .map_err(ClientError::StringTooLarge)?; + + Ok(()) } fn token_source(connection_credentials: &Credentials) -> Option { @@ -300,116 +328,206 @@ impl MqttClient { } } - pub fn shutdown_handle(&self) -> Result { - self.client - .shutdown_handle() - .map_or(Err(ShutdownError::ClientDoesNotExist), |shutdown_handle| { - Ok(ClientShutdownHandle(shutdown_handle)) - }) - } - - pub fn publish_handle(&self) -> Result { - let publish_handle = self - .client - .publish_handle() - .map_err(ClientError::PublishHandle)?; + pub async fn run(&mut self) -> Result<(), ClientError> { + let default_topics = self.event_handler.subscriptions(); + self.subscribe(default_topics).await?; - Ok(publish_handle) + self.handle_events().await; + Ok(()) } - pub async fn handle_events(&mut self) { - debug!("polling bridge client"); + async fn handle_events(&mut self) { + debug!("polling bridge client..."); while let Some(event) = self.client.try_next().await.unwrap_or_else(|e| { // TODO: handle the error by recreating the connection - error!(error=%e, "failed to poll events"); + error!(error = %e, "failed to poll events"); None }) { debug!("handling event {:?}", event); - if let Err(e) = self.event_handler.handle(&event).await { - error!(err = %e, "error processing event {:?}", event); + if 
let Err(e) = self.event_handler.handle(event).await { + error!(error = %e, "error processing event"); } } } - pub async fn subscribe(&mut self, topics: &[String]) -> Result<(), ClientError> { - info!("subscribing to topics"); - let subscriptions = topics.iter().map(|topic| proto::SubscribeTo { - topic_filter: topic.to_string(), + async fn subscribe(&mut self, topics: Vec<String>) -> Result<(), ClientError> { + info!("subscribing to topics {:?}...", topics); + let subscriptions = topics.into_iter().map(|topic| proto::SubscribeTo { + topic_filter: topic, qos: DEFAULT_QOS, }); for subscription in subscriptions { - debug!("subscribing to topic {}", subscription.topic_filter); self.client .subscribe(subscription) .map_err(ClientError::Subscribe)?; } - let mut subacks: HashSet<_> = topics.iter().collect(); - if subacks.is_empty() { - info!("has no topics to subscribe to"); - return Ok(()); - } + Ok(()) + } +} + +fn validate_length(id: &str) -> Result<(), ClientError> { + let _: u16 = id.len().try_into().map_err(ClientError::StringTooLarge)?; + + Ok(()) +} + +/// A trait extending `MqttClient` with additional handle functionality. +pub trait MqttClientExt { + /// Publish handle type. + type PublishHandle; + + /// Returns an instance of the publish handle. + fn publish_handle(&self) -> Result<Self::PublishHandle, ClientError>; + + /// Update subscription handle type. + type UpdateSubscriptionHandle; - // TODO: Don't wait for subscription updates before starting the bridge. - // We should move this logic to the handle events. - // - // This is fine for now when dealing with only the upstream edge device. - // But when remote brokers are introduced this will be an issue. - while let Some(event) = self + /// Returns an instance of the update subscription handle. + fn update_subscription_handle(&self) -> Result<Self::UpdateSubscriptionHandle, ClientError>; + + /// Client shutdown handle type. + type ShutdownHandle; + + /// Returns an instance of the shutdown handle. + fn shutdown_handle(&self) -> Result<Self::ShutdownHandle, ClientError>; +} + +#[cfg(not(test))] +/// Implements `MqttClientExt` for production code. +impl<H> MqttClientExt for MqttClient<H> { + type PublishHandle = PublishHandle; + + fn publish_handle(&self) -> Result<Self::PublishHandle, ClientError> { + let publish_handle = self .client - .try_next() - .await - .map_err(ClientError::PollClient)?
- { - if let Event::SubscriptionUpdates(subscriptions) = event { - for subscription in subscriptions { - match subscription { - SubscriptionUpdateEvent::Subscribe(sub) => { - subacks.remove(&sub.topic_filter); - debug!("successfully subscribed to topic {}", &sub.topic_filter); - } - SubscriptionUpdateEvent::RejectedByServer(topic_filter) => { - subacks.remove(&topic_filter); - error!("subscription rejected by server {}", topic_filter); - } - SubscriptionUpdateEvent::Unsubscribe(topic_filter) => { - warn!("unsubscribed to {}", topic_filter); - } - } - } + .publish_handle() + .map_err(ClientError::PublishHandle)?; + Ok(PublishHandle(publish_handle)) + } - info!("stopped waiting for subscriptions"); - break; - } - } + type UpdateSubscriptionHandle = UpdateSubscriptionHandle; - if subacks.is_empty() { - info!("successfully subscribed to topics"); - } else { - error!( - "failed to receive expected subacks for topics: {:?}", - subacks.iter().map(ToString::to_string).collect::<String>(), - ); - } + fn update_subscription_handle(&self) -> Result<Self::UpdateSubscriptionHandle, ClientError> { + let update_subscription_handle = self.client.update_subscription_handle()?; + Ok(UpdateSubscriptionHandle(update_subscription_handle)) + } + + type ShutdownHandle = ShutdownHandle; + + fn shutdown_handle(&self) -> Result<Self::ShutdownHandle, ClientError> { + let shutdown_handle = self.client.shutdown_handle()?; + Ok(ShutdownHandle(shutdown_handle)) + } +} + +#[cfg(test)] +/// Implements `MqttClientExt` for tests. +impl<H> MqttClientExt for MqttClient<H> { + type PublishHandle = MockPublishHandle; + + fn publish_handle(&self) -> Result<Self::PublishHandle, ClientError> { + Ok(MockPublishHandle::new()) + } + type UpdateSubscriptionHandle = MockUpdateSubscriptionHandle; + + fn update_subscription_handle(&self) -> Result<Self::UpdateSubscriptionHandle, ClientError> { + Ok(MockUpdateSubscriptionHandle::new()) + } + + type ShutdownHandle = MockShutdownHandle; + + fn shutdown_handle(&self) -> Result<Self::ShutdownHandle, ClientError> { + Ok(MockShutdownHandle::new()) + } +} + +/// A client shutdown handle. +#[derive(Debug, Clone)] +pub struct ShutdownHandle(mqtt3::ShutdownHandle); + +#[automock] +impl ShutdownHandle { + pub async fn shutdown(&mut self) -> Result<(), ClientError> { + self.0.shutdown().await?; Ok(()) } } +/// A client publish handle. +#[derive(Debug, Clone)] +pub struct PublishHandle(mqtt3::PublishHandle); + +impl PublishHandle { + pub async fn publish(&mut self, publication: Publication) -> Result<(), ClientError> { + self.0 + .publish(publication) + .await + .map_err(ClientError::PublishError) + } +} + +mockall::mock! { + pub PublishHandle { + async fn publish(&mut self, publication: Publication) -> Result<(), ClientError>; + } + + pub trait Clone { + fn clone(&self) -> Self; + } +} + +/// A client subscription update handle. +pub struct UpdateSubscriptionHandle(mqtt3::UpdateSubscriptionHandle); + +#[automock] +impl UpdateSubscriptionHandle { + pub async fn subscribe(&mut self, subscribe_to: SubscribeTo) -> Result<(), ClientError> { + self.0 + .subscribe(subscribe_to) + .await + .map_err(ClientError::UpdateSubscriptionError) + } + + pub async fn unsubscribe(&mut self, unsubscribe_from: String) -> Result<(), ClientError> { + self.0 + .unsubscribe(unsubscribe_from) + .await + .map_err(ClientError::UpdateSubscriptionError) + } +} + +/// A trait which every MQTT client event handler implements. #[async_trait] -pub trait EventHandler { +pub trait MqttEventHandler { type Error: Display; - async fn handle(&mut self, event: &Event) -> Result<Handled, Self::Error>; + /// Returns a list of subscriptions for an MQTT client to subscribe to + /// when the client starts.
+ fn subscriptions(&self) -> Vec<String> { + vec![] + } + + /// Handles an MQTT client event and returns a marker which determines whether + /// the event handler fully handled the event. + async fn handle(&mut self, event: Event) -> Result<Handled, Self::Error>; } +/// An `MqttEventHandler::handle` method result. #[derive(Debug, PartialEq)] pub enum Handled { + /// MQTT client event is fully handled. Fully, - Partially, - Skipped, + + /// MQTT client event is partially handled. It contains the modified event. + Partially(Event), + + /// Unknown MQTT client event, so the event handler skipped it. + /// It contains the unmodified event. + Skipped(Event), } #[derive(Debug, thiserror::Error)] @@ -423,9 +541,130 @@ pub enum ClientError { #[error("failed to shutdown custom mqtt client: {0}")] ShutdownClient(#[from] mqtt3::ShutdownError), - #[error("failed to shutdown custom mqtt client: {0}")] - PublishHandle(#[from] mqtt3::PublishError), + #[error("failed to obtain publish handle: {0}")] + PublishHandle(#[source] mqtt3::PublishError), + + #[error("failed to publish event: {0}")] + PublishError(#[source] mqtt3::PublishError), + + #[error("failed to obtain subscribe handle: {0}")] + UpdateSubscriptionHandle(#[source] mqtt3::UpdateSubscriptionError), + + #[error("failed to send update subscription: {0}")] + UpdateSubscriptionError(#[source] mqtt3::UpdateSubscriptionError), #[error("failed to connect")] SslHandshake, + + #[error("string too large: {0}")] + StringTooLarge(#[from] TryFromIntError), +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::settings::AuthenticationSettings; + + #[derive(Default)] + struct EventHandler {} + + #[async_trait] + impl MqttEventHandler for EventHandler { + type Error = ClientError; + + async fn handle(&mut self, _event: Event) -> Result<Handled, Self::Error> { + Ok(Handled::Fully) + } + } + + #[test] + fn new_validates_valid_input() { + let addr = "addr".to_owned(); + let secs: u64 = (u16::MAX).into(); + let keep_alive = Duration::from_secs(secs); + let clean_session = true; + let connection_credentials = Credentials::Anonymous("user".to_owned()); + + let event_handler = EventHandler::default(); + + let client = MqttClient::new( + keep_alive, + clean_session, + event_handler, + &connection_credentials, + BridgeIoSource::Tcp(TcpConnection::new(addr, None, None)), + ); + + assert!(client.is_ok()); + } + + #[test] + fn new_validates_invalid_keep_alive() { + let addr = "addr".to_owned(); + let secs: u64 = (u16::MAX).into(); + let keep_alive = Duration::from_secs(secs + 1); + let clean_session = true; + let connection_credentials = Credentials::Anonymous("user".to_owned()); + + let event_handler = EventHandler::default(); + let client = MqttClient::new( + keep_alive, + clean_session, + event_handler, + &connection_credentials, + BridgeIoSource::Tcp(TcpConnection::new(addr, None, None)), + ); + + assert!(client.is_err()); + } + + #[test] + fn new_validates_invalid_client_id() { + let addr = "addr".to_owned(); + let keep_alive = Duration::from_secs(120); + let clean_session = false; + let repeat: usize = (u16::MAX).into(); + let connection_credentials = Credentials::PlainText(AuthenticationSettings::new( + "c".repeat(repeat + 1), + "username".into(), + "pass".into(), + )); + + let event_handler = EventHandler::default(); + + let client = MqttClient::new( + keep_alive, + clean_session, + event_handler, + &connection_credentials, + BridgeIoSource::Tcp(TcpConnection::new(addr, None, None)), + ); + + assert!(client.is_err()); + } + + #[test] + fn new_validates_invalid_username() { + let addr =
"addr".to_owned(); + let keep_alive = Duration::from_secs(120); + let clean_session = false; + let repeat: usize = (u16::MAX).into(); + let connection_credentials = Credentials::PlainText(AuthenticationSettings::new( + "user".into(), + "u".repeat(repeat + 1), + "pass".into(), + )); + + let event_handler = EventHandler::default(); + + let client = MqttClient::new( + keep_alive, + clean_session, + event_handler, + &connection_credentials, + BridgeIoSource::Tcp(TcpConnection::new(addr, None, None)), + ); + + assert_eq!(client.is_err(), true); + } } diff --git a/mqtt/mqtt-bridge/src/config_update.rs b/mqtt/mqtt-bridge/src/config_update.rs new file mode 100644 index 00000000000..80ad9fb912f --- /dev/null +++ b/mqtt/mqtt-bridge/src/config_update.rs @@ -0,0 +1,989 @@ +use std::collections::HashMap; + +use serde::Deserialize; +use tracing::debug; + +use crate::{bridge::BridgeHandle, controller::Error, settings::Direction, settings::TopicRule}; + +/// Keeps the current subscriptions and forwards and calculates the diff with an `BridgeUpdate` +/// It is used to send a diff to a `BridgeHandle` and update itself with latest configuration +pub struct ConfigUpdater { + bridge_handle: BridgeHandle, + current_subscriptions: HashMap, + current_forwards: HashMap, +} + +impl ConfigUpdater { + pub fn new(bridge_handle: BridgeHandle) -> Self { + Self { + bridge_handle, + current_subscriptions: HashMap::new(), + current_forwards: HashMap::new(), + } + } + + pub async fn send_update(&mut self, bridge_update: BridgeUpdate) -> Result<(), Error> { + let diff = self.diff(bridge_update); + + debug!("sending diff {:?}", diff); + + self.bridge_handle.send_update(diff.clone()).await?; + + self.update(diff); + + Ok(()) + } + + fn diff(&self, bridge_update: BridgeUpdate) -> BridgeDiff { + let (forwards, subscriptions) = bridge_update.into_parts(); + + let local_diff = diff_topic_rules(forwards, &self.current_forwards); + + let remote_diff = diff_topic_rules(subscriptions, &self.current_subscriptions); + + BridgeDiff::default() + .with_local_diff(local_diff) + .with_remote_diff(remote_diff) + } + + fn update(&mut self, bridge_diff: BridgeDiff) { + let (local_updates, remote_updates) = bridge_diff.into_parts(); + + update_pump(local_updates, &mut self.current_forwards); + + update_pump(remote_updates, &mut self.current_subscriptions) + } +} + +fn diff_topic_rules(updated: Vec, current: &HashMap) -> PumpDiff { + let mut added = vec![]; + let mut removed = vec![]; + + let subs_map = updated + .iter() + .map(|sub| (sub.subscribe_to(), sub.clone())) + .collect::>(); + + for sub in updated { + if !current.contains_key(&sub.subscribe_to()) + || current + .get(&sub.subscribe_to()) + .filter(|curr| curr.to_owned().eq(&sub)) + == None + { + added.push(sub); + } + } + + for sub in current.keys() { + if !subs_map.contains_key(sub) { + if let Some(curr) = current.get(sub) { + removed.push(curr.to_owned()) + } + } + } + + PumpDiff::default().with_added(added).with_removed(removed) +} + +fn update_pump(pump_diff: PumpDiff, current: &mut HashMap) { + let (added, removed) = pump_diff.into_parts(); + + added.into_iter().for_each(|added| { + current.insert(added.subscribe_to(), added); + }); + + removed.iter().for_each(|updated| { + current.remove(&updated.subscribe_to()); + }); +} + +#[derive(Debug, Deserialize)] +pub struct BridgeControllerUpdate(Vec); + +impl BridgeControllerUpdate { + pub fn from_bridge_topic_rules(name: &str, subs: &[TopicRule], forwards: &[TopicRule]) -> Self { + let subscriptions = subs + .iter() + .map(|s| 
Direction::In(s.to_owned())) + .chain(forwards.iter().map(|s| Direction::Out(s.to_owned()))) + .collect(); + + let bridge_update = BridgeUpdate { + endpoint: name.to_owned(), + subscriptions, + }; + Self(vec![bridge_update]) + } + + pub fn into_inner(self) -> Vec<BridgeUpdate> { + self.0 + } +} + +#[derive(Clone, Debug, PartialEq, Deserialize)] +pub struct BridgeUpdate { + endpoint: String, + #[serde(rename = "settings")] + subscriptions: Vec<Direction>, +} + +impl BridgeUpdate { + pub fn new(name: impl Into<String>, subs: Vec<TopicRule>, forwards: Vec<TopicRule>) -> Self { + let subscriptions = subs + .into_iter() + .map(Direction::In) + .chain(forwards.into_iter().map(Direction::Out)) + .collect(); + + Self { + endpoint: name.into(), + subscriptions, + } + } + + // TODO update should have name + pub fn name(&self) -> &str { + &self.endpoint + } + + pub fn endpoint(&self) -> &str { + &self.endpoint + } + + pub fn into_parts(self) -> (Vec<TopicRule>, Vec<TopicRule>) { + let forwards = self + .subscriptions + .iter() + .filter_map(|sub| match sub { + Direction::Out(topic) | Direction::Both(topic) => Some(topic.clone()), + _ => None, + }) + .collect(); + let subscriptions = self + .subscriptions + .iter() + .filter_map(|sub| match sub { + Direction::In(topic) | Direction::Both(topic) => Some(topic.clone()), + _ => None, + }) + .collect(); + + (forwards, subscriptions) + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct BridgeDiff { + local_pump_diff: PumpDiff, + remote_pump_diff: PumpDiff, +} + +impl BridgeDiff { + pub fn with_local_diff(mut self, diff: PumpDiff) -> Self { + self.local_pump_diff = diff; + self + } + + pub fn with_remote_diff(mut self, diff: PumpDiff) -> Self { + self.remote_pump_diff = diff; + self + } + + pub fn into_parts(self) -> (PumpDiff, PumpDiff) { + (self.local_pump_diff, self.remote_pump_diff) + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PumpDiff { + added: Vec<TopicRule>, + removed: Vec<TopicRule>, +} + +impl PumpDiff { + pub fn with_added(mut self, added: Vec<TopicRule>) -> Self { + self.added = added; + self + } + + pub fn with_removed(mut self, removed: Vec<TopicRule>) -> Self { + self.removed = removed; + self + } + + pub fn added(&self) -> Vec<&TopicRule> { + self.added.iter().collect() + } + + pub fn removed(&self) -> Vec<&TopicRule> { + self.removed.iter().collect() + } + + pub fn has_updates(&self) -> bool { + !(self.added.is_empty() && self.removed.is_empty()) + } + + pub fn into_parts(self) -> (Vec<TopicRule>, Vec<TopicRule>) { + (self.added, self.removed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn diff_with_empty_current_topics_and_empty_update() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let config_updater = ConfigUpdater::new(handler); + + let bridge_update = BridgeUpdate { + endpoint: "$upstream".to_owned(), + subscriptions: Vec::new(), + }; + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, PumpDiff::default()); + assert_eq!(remote_updates, PumpDiff::default()); + } + + #[test] + fn diff_with_empty_current_topics_and_remote_pump_update() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let config_updater = ConfigUpdater::new(handler); + + let update = r#" + { + "endpoint": "$upstream", + "settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote"
} + ] + }"#; + + let topic_rule = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let expected = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule).unwrap()]); + + let bridge_update: BridgeUpdate = serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, PumpDiff::default()); + assert_eq!(remote_updates, expected); + } + + #[test] + fn diff_with_empty_current_topics_and_local_pump_update() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let config_updater = ConfigUpdater::new(handler); + + let update = r#" + { + "endpoint": "$upstream", + "settings": [ + { + "direction": "out", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + } + ] + }"#; + + let topic_rule = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let expected_topic_rule: TopicRule = serde_json::from_str(topic_rule).unwrap(); + let expected = PumpDiff::default().with_added(vec![expected_topic_rule]); + + let bridge_update: BridgeUpdate = serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, expected); + assert_eq!(remote_updates, PumpDiff::default()); + } + + #[test] + fn diff_with_empty_current_topics_and_both_pump_update() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let config_updater = ConfigUpdater::new(handler); + + let update = r#" + { + "endpoint": "$upstream", + "settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "out", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "both", + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + } + ] + }"#; + + let topic_rule1 = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + let topic_rule2 = r#"{ + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let expected = PumpDiff::default().with_added(vec![ + serde_json::from_str(topic_rule1).unwrap(), + serde_json::from_str(topic_rule2).unwrap(), + ]); + + let bridge_update: BridgeUpdate = serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, expected); + assert_eq!(remote_updates, expected); + } + + #[test] + fn diff_with_current_topics_and_both_pump_update_added() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule); + + let update = r#" + { + "endpoint": "$upstream", + 
"settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "out", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "both", + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + } + ] + }"#; + + let topic_rule1 = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let expected = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule1).unwrap()]); + + let bridge_update: BridgeUpdate = serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, expected); + assert_eq!(remote_updates, expected); + } + + #[test] + fn diff_with_current_topics_and_both_pump_update_outprefix_updated() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/test/#".to_owned(), existing_rule); + + let update = r#" + { + "endpoint": "$upstream", + "settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/updated" + } + ] + }"#; + + let topic_rule1 = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/updated" + }"#; + + let expected = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule1).unwrap()]); + + let bridge_update: BridgeUpdate = serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, PumpDiff::default()); + assert_eq!(remote_updates, expected); + } + + #[test] + fn diff_with_current_topics_and_both_pump_update_added_and_removed() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + + let update = r#" + { + "endpoint": "$upstream", + "settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "out", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "both", + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + } + ] + }"#; + + let topic_rule1 = r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + let topic_rule2 = r#"{ + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let expected = PumpDiff::default() + .with_added(vec![ + serde_json::from_str(topic_rule1).unwrap(), + serde_json::from_str(topic_rule2).unwrap(), + ]) + .with_removed(vec![existing_rule]); + + let bridge_update: BridgeUpdate = 
serde_json::from_str(update).unwrap(); + + let diff = config_updater.diff(bridge_update); + + let (local_updates, remote_updates) = diff.into_parts(); + assert_eq!(local_updates, expected); + assert_eq!(remote_updates, expected); + } + + #[test] + fn update_config_from_diff_added() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + + let topic_rule1 = r#"{ + "topic": "forward/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let topic_rule2 = r#"{ + "topic": "sub/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let forwards_diff = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule1).unwrap()]); + + let subs_diff = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule2).unwrap()]); + + config_updater.update( + BridgeDiff::default() + .with_local_diff(forwards_diff) + .with_remote_diff(subs_diff), + ); + + let expected_forward_rule = serde_json::from_str(topic_rule1).unwrap(); + let expected_subs_rule = serde_json::from_str(topic_rule2).unwrap(); + assert_eq!( + config_updater + .current_forwards + .get("/local/existing/#") + .unwrap(), + &existing_rule + ); + assert_eq!( + config_updater + .current_forwards + .get("/local/forward/#") + .unwrap(), + &expected_forward_rule + ); + assert_eq!( + config_updater + .current_forwards + .get("/local/subs/#") + .is_none(), + true + ); + assert_eq!( + config_updater + .current_subscriptions + .get("/local/existing/#") + .unwrap(), + &existing_rule + ); + assert_eq!( + config_updater + .current_subscriptions + .get("/local/sub/#") + .unwrap(), + &expected_subs_rule + ); + assert_eq!( + config_updater + .current_subscriptions + .get("/local/forward/#") + .is_none(), + true + ); + } + + #[test] + fn update_config_from_diff_updated() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule); + + let topic_rule1 = r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/forward-remote" + }"#; + + let topic_rule2 = r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/sub-remote" + }"#; + + let forwards_diff = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule1).unwrap()]); + + let subs_diff = + PumpDiff::default().with_added(vec![serde_json::from_str(topic_rule2).unwrap()]); + + config_updater.update( + BridgeDiff::default() + .with_local_diff(forwards_diff) + .with_remote_diff(subs_diff), + ); + + let expected_forward_rule = serde_json::from_str(topic_rule1).unwrap(); + let 
expected_subs_rule = serde_json::from_str(topic_rule2).unwrap(); + assert_eq!( + config_updater + .current_forwards + .get("/local/existing/#") + .unwrap(), + &expected_forward_rule + ); + assert_eq!( + config_updater + .current_forwards + .get("/local/subs/#") + .is_none(), + true + ); + assert_eq!( + config_updater + .current_subscriptions + .get("/local/existing/#") + .unwrap(), + &expected_subs_rule + ); + assert_eq!( + config_updater + .current_subscriptions + .get("/local/forward/#") + .is_none(), + true + ); + } + + #[test] + fn update_config_from_diff_removed_forward() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule); + + let topic_rule1 = r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let forwards_diff = + PumpDiff::default().with_removed(vec![serde_json::from_str(topic_rule1).unwrap()]); + + config_updater.update(BridgeDiff::default().with_local_diff(forwards_diff)); + + assert_eq!( + config_updater + .current_forwards + .get("/local/existing/#") + .is_none(), + true + ); + + assert_eq!( + config_updater + .current_subscriptions + .get("/local/existing/#") + .is_some(), + true + ); + } + + #[test] + fn update_config_from_diff_removed_sub() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + let existing_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + config_updater + .current_subscriptions + .insert("/local/existing/#".to_owned(), existing_rule.clone()); + config_updater + .current_forwards + .insert("/local/existing/#".to_owned(), existing_rule); + + let topic_rule1 = r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let forwards_diff = + PumpDiff::default().with_removed(vec![serde_json::from_str(topic_rule1).unwrap()]); + + config_updater.update(BridgeDiff::default().with_remote_diff(forwards_diff)); + + assert_eq!( + config_updater + .current_forwards + .get("/local/existing/#") + .is_some(), + true + ); + + assert_eq!( + config_updater + .current_subscriptions + .get("/local/existing/#") + .is_none(), + true + ); + } + + #[test] + fn update_config_from_diff_removed_sub_when_not_in_current() { + let (local_handle, _) = crate::pump::channel(); + let (remote_handle, _) = crate::pump::channel(); + + let handler = BridgeHandle::new(local_handle, remote_handle); + + let mut config_updater = ConfigUpdater::new(handler); + + let topic_rule1 = r#"{ + "topic": "existing/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#; + + let forwards_diff = + PumpDiff::default().with_removed(vec![serde_json::from_str(topic_rule1).unwrap()]); + + config_updater.update(BridgeDiff::default().with_remote_diff(forwards_diff)); + + assert_eq!( + config_updater + .current_forwards + .get("/local/existing/#") + 
.is_none(), + true + ); + + assert_eq!( + config_updater + .current_subscriptions + .get("/local/existing/#") + .is_none(), + true + ); + } + + #[test] + fn deserialize_bridge_controller_update() { + let update = r#"[{ + "endpoint": "$upstream", + "settings": [ + { + "direction": "in", + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }, + { + "direction": "out", + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + } + ] + }]"#; + + let bridge_controller_update: BridgeControllerUpdate = + serde_json::from_str(update).unwrap(); + + let updates = bridge_controller_update.into_inner(); + let bridge_update = updates.first().take().unwrap(); + + let sub_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "test/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + let forward_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "test2/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + assert_eq!(bridge_update.clone().endpoint(), "$upstream"); + let (forwards, subscriptions) = bridge_update.to_owned().into_parts(); + assert_eq!(subscriptions, vec![sub_rule]); + assert_eq!(forwards, vec![forward_rule]); + } + + #[test] + fn bridge_controller_from_bridge() { + let sub_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "sub/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + let forward_rule: TopicRule = serde_json::from_str( + r#"{ + "topic": "forward/#", + "inPrefix": "/local", + "outPrefix": "/remote" + }"#, + ) + .unwrap(); + + let bridge_controller_update = BridgeControllerUpdate::from_bridge_topic_rules( + "$upstream", + vec![sub_rule.clone()].as_slice(), + vec![forward_rule.clone()].as_slice(), + ); + + let updates = bridge_controller_update.into_inner(); + let bridge_update = updates.first().take().unwrap(); + + assert_eq!(bridge_update.clone().endpoint(), "$upstream"); + let (forwards, subscriptions) = bridge_update.to_owned().into_parts(); + assert_eq!(subscriptions, vec![sub_rule]); + assert_eq!(forwards, vec![forward_rule]); + } +} diff --git a/mqtt/mqtt-bridge/src/controller.rs b/mqtt/mqtt-bridge/src/controller.rs deleted file mode 100644 index e45c742e35e..00000000000 --- a/mqtt/mqtt-bridge/src/controller.rs +++ /dev/null @@ -1,96 +0,0 @@ -use futures_util::future::{self, join_all}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio::sync::mpsc::{self, UnboundedSender}; -use tracing::{error, info, info_span}; -use tracing_futures::Instrument; - -use crate::{bridge::Bridge, settings::BridgeSettings}; - -/// Controller that handles the settings and monitors changes, spawns new Bridges and monitors shutdown signal. 
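// [editor's note] The `config_update` types and tests above supersede the
// empty `BridgeControllerUpdate` placeholder in the controller module that
// is deleted below. A minimal sketch of building an update in code rather
// than from JSON, using only APIs visible in this diff and assuming the
// rule collections are `Vec<TopicRule>` (generic parameters are elided by
// this rendering of the diff; local names are illustrative):
//
//     let update = BridgeControllerUpdate::from_bridge_topic_rules(
//         "$upstream",
//         &subs,     // wrapped as Direction::In subscriptions
//         &forwards, // wrapped as Direction::Out forwards
//     );
//     for bridge_update in update.into_inner() {
//         let (forwards, subscriptions) = bridge_update.into_parts();
//         // apply to the remote and local pumps respectively
//     }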
-pub struct BridgeController { - handle: BridgeControllerHandle, -} - -impl BridgeController { - pub fn new() -> Self { - let (sender, _updates) = mpsc::unbounded_channel(); - let handle = BridgeControllerHandle { sender }; - - Self { handle } - } - - pub fn handle(&self) -> BridgeControllerHandle { - self.handle.clone() - } - - pub async fn run(self, system_address: String, device_id: String, settings: BridgeSettings) { - info!("starting bridge controller..."); - - let mut bridge_handles = vec![]; - if let Some(upstream_settings) = settings.upstream() { - let upstream_settings = upstream_settings.clone(); - - let upstream_bridge = async move { - let bridge = - Bridge::new(system_address, device_id, upstream_settings.clone()).await; - - match bridge { - Ok(bridge) => { - if let Err(e) = bridge.run().await { - error!(err = %e, "failed running {} bridge", upstream_settings.name()); - } - } - Err(e) => { - error!(err = %e, "failed to create {} bridge", upstream_settings.name()); - } - }; - } - .instrument(info_span!("bridge", name = "upstream")); - - bridge_handles.push(upstream_bridge); - } else { - info!("No upstream settings detected. Not starting bridge controller.") - }; - - // join_all is fine because the bridge shouldn't fail and exit - // if a pump in the bridge fails, it should internally recreate it - // this means that if a bridge stops, then shutdown was triggered - join_all(bridge_handles).await; - - // TODO: bridge controller will eventually listen for updates via the twin - // until this is complete we need to wait here indefinitely - // if we stop the bridge controller, our startup/shutdown logic will shut eveything down - future::pending::<()>().await; - } -} - -impl Default for BridgeController { - fn default() -> Self { - Self::new() - } -} - -#[derive(Clone, Debug)] -pub struct BridgeControllerHandle { - sender: UnboundedSender, -} - -impl BridgeControllerHandle { - pub fn send(&mut self, message: BridgeControllerUpdate) -> Result<(), Error> { - self.sender - .send(message) - .map_err(Error::SendControllerMessage) - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct BridgeControllerUpdate { - // TODO: add settings -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("An error occurred sending a message to the controller.")] - SendControllerMessage(#[source] tokio::sync::mpsc::error::SendError), -} diff --git a/mqtt/mqtt-bridge/src/controller/bridges.rs b/mqtt/mqtt-bridge/src/controller/bridges.rs new file mode 100644 index 00000000000..a55a6446bef --- /dev/null +++ b/mqtt/mqtt-bridge/src/controller/bridges.rs @@ -0,0 +1,116 @@ +use std::{ + collections::HashMap, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ + future::{self, BoxFuture}, + stream::{FuturesUnordered, Stream}, + FusedStream, FutureExt, StreamExt, +}; +use tokio::task::JoinError; +use tracing::{debug, error, info_span, warn}; +use tracing_futures::Instrument; + +use crate::{ + bridge::{Bridge, BridgeError, BridgeHandle}, + config_update::{BridgeUpdate, ConfigUpdater}, + persist::StreamWakeableState, + settings::ConnectionSettings, +}; + +/// A type for a future that will be resolved to when `Bridge` exits. +type BridgeFuture = BoxFuture<'static, (String, Result, JoinError>)>; + +/// Encapsulates logic from `BridgeController` on how it manages with +/// `BridgeFuture`s. +/// +/// It represents a `FusedStream` of `Bridge` futures which resolves to a pair +/// bridge name and exit result. It stores shutdown handles for each `Bridge` +/// internally to request a stop when needed. 
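// [editor's note] Angle-bracketed generics are stripped throughout this
// rendering of the diff. Judging from how the task result is matched below
// (`Ok(Ok(_))`, `Ok(Err(e))`, `Err(e)`), the alias is presumably:
//
//     type BridgeFuture =
//         BoxFuture<'static, (String, Result<Result<(), BridgeError>, JoinError>)>;
//
// i.e. the outer `Result` carries the `JoinError` from `tokio::spawn` (a
// panicked bridge) and the inner one the bridge's own exit status.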
+#[derive(Default)] +pub(crate) struct Bridges { + bridge_handles: HashMap, + config_updaters: HashMap, + bridges: FuturesUnordered, +} + +impl Bridges { + pub(crate) async fn start_bridge(&mut self, bridge: Bridge, settings: &ConnectionSettings) + where + S: StreamWakeableState + Send + 'static, + { + let name = settings.name().to_owned(); + + // save bridge handle + let bridge_handle = bridge.handle(); + self.bridge_handles.insert(name.clone(), bridge_handle); + + // save config updater + let config_updater = ConfigUpdater::new(bridge.handle()); + self.config_updaters.insert(name.clone(), config_updater); + + // start bridge + let upstream_bridge = bridge.run().instrument(info_span!("bridge", name = %name)); + let task = tokio::spawn(upstream_bridge).map(|res| (name, res)); + self.bridges.push(Box::pin(task)); + } + + pub(crate) async fn send_update(&mut self, update: BridgeUpdate) { + if let Some(config) = self.config_updaters.get_mut(update.name()) { + if let Err(e) = config.send_update(update).await { + error!("error sending bridge update {:?}", e); + } + } + } + + pub(crate) async fn shutdown_all(&mut self) { + debug!("sending shutdown request to all bridges..."); + + // sending shutdown signal to each bridge + let shutdowns = self + .bridge_handles + .drain() + .map(|(_, handle)| handle.shutdown()); + future::join_all(shutdowns).await; + + debug!("waiting for all bridges to exit..."); + + // wait until all bridges finish + while let Some((name, bridge)) = self.bridges.next().await { + match bridge { + Ok(Ok(_)) => debug!("bridge {} exited", name), + Ok(Err(e)) => warn!(error = %e, "bridge {} exited with error", name), + Err(e) => warn!(error = %e, "bridge {} panicked ", name), + } + } + + debug!("all bridges exited"); + } +} + +impl Stream for Bridges { + type Item = ( + String, + Result, tokio::task::JoinError>, + ); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = self.bridges.poll_next_unpin(cx); + + // remove redundant handlers when bridge exits + if let Poll::Ready(Some((name, _))) = &poll { + self.bridge_handles.remove(name); + self.config_updaters.remove(name); + } + + poll + } +} + +impl FusedStream for Bridges { + fn is_terminated(&self) -> bool { + self.bridges.is_terminated() + } +} diff --git a/mqtt/mqtt-bridge/src/controller/mod.rs b/mqtt/mqtt-bridge/src/controller/mod.rs new file mode 100644 index 00000000000..d012b0992c7 --- /dev/null +++ b/mqtt/mqtt-bridge/src/controller/mod.rs @@ -0,0 +1,191 @@ +mod bridges; + +use bridges::Bridges; + +use async_trait::async_trait; +use futures_util::{ + future::{self, Either}, + stream::Fuse, + FusedStream, StreamExt, +}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tracing::{debug, error, info, warn}; + +use mqtt_broker::sidecar::{Sidecar, SidecarShutdownHandle, SidecarShutdownHandleError}; + +use crate::{ + bridge::{Bridge, BridgeError}, + config_update::BridgeControllerUpdate, + settings::BridgeSettings, +}; + +const UPSTREAM: &str = "$upstream"; + +/// `BridgeController` controls lifetime of bridges: start/stop and update +/// forwarding rules. +/// +/// Controller handles monitors settings updates and starts a new `Bridge` or +/// stops running `Bridge` if the number of bridges changes. In addition it +/// prepares changes in forwarding rules and applies them to `Bridge` if required. 
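// [editor's note] A hedged sketch of the controller lifecycle from the
// caller's side, using only the APIs visible in this diff (local names are
// illustrative; a tokio runtime is assumed):
//
//     let controller = BridgeController::new(system_address, device_id, settings);
//     let mut handle = controller.handle();
//     let task = tokio::spawn(Box::new(controller).run()); // Sidecar::run takes Box<Self>
//
//     handle.send_update(update)?;  // push new forwarding rules to $upstream
//     handle.clone().shutdown();    // shutdown() consumes the handle, so clone first
//     task.await?;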
+pub struct BridgeController { + system_address: String, + device_id: String, + settings: BridgeSettings, + handle: BridgeControllerHandle, + messages: Fuse>, +} + +impl BridgeController { + pub fn new(system_address: String, device_id: String, settings: BridgeSettings) -> Self { + let (sender, updates_receiver) = mpsc::unbounded_channel(); + let handle = BridgeControllerHandle { sender }; + + Self { + system_address, + device_id, + settings, + handle, + messages: updates_receiver.fuse(), + } + } + + pub fn handle(&self) -> BridgeControllerHandle { + self.handle.clone() + } +} + +#[async_trait] +impl Sidecar for BridgeController { + fn shutdown_handle(&self) -> Result { + let handle = self.handle.clone(); + Ok(SidecarShutdownHandle::new(async { handle.shutdown() })) + } + + async fn run(mut self: Box) { + info!("starting bridge controller..."); + + let mut bridges = Bridges::default(); + + if let Some(upstream_settings) = self.settings.upstream() { + match Bridge::new_upstream(&self.system_address, &self.device_id, upstream_settings) { + Ok(bridge) => { + bridges.start_bridge(bridge, upstream_settings).await; + } + Err(e) => { + error!(err = %e, "failed to create {} bridge", UPSTREAM); + } + } + } else { + info!("no upstream settings detected") + } + + loop { + let wait_bridge_or_pending = if bridges.is_terminated() { + // if no active bridges available, wait only for a new messages arrival + Either::Left(future::pending()) + } else { + // otherwise try to await both a new message arrival or any bridge exit + Either::Right(bridges.next()) + }; + + match future::select(self.messages.select_next_some(), wait_bridge_or_pending).await { + Either::Left((BridgeControllerMessage::BridgeControllerUpdate(update), _)) => { + process_update(update, &mut bridges).await + } + Either::Left((BridgeControllerMessage::Shutdown, _)) => { + info!("bridge controller shutdown requested"); + bridges.shutdown_all().await; + break; + } + Either::Right((Some((name, bridge)), _)) => { + match bridge { + Ok(Ok(_)) => debug!("bridge {} exited", name), + Ok(Err(e)) => warn!(error = %e, "bridge {} exited with error", name), + Err(e) => warn!(error = %e, "bridge {} panicked ", name), + } + + // always restart upstream bridge + if name == UPSTREAM { + info!("restarting bridge..."); + if let Some(upstream_settings) = self.settings.upstream() { + match Bridge::new_upstream( + &self.system_address, + &self.device_id, + upstream_settings, + ) { + Ok(bridge) => { + bridges.start_bridge(bridge, upstream_settings).await; + } + Err(e) => { + error!(err = %e, "failed to create {} bridge", name); + } + } + } + } + } + Either::Right((None, _)) => { + // first time we resolve bridge future it returns None + } + } + } + + info!("bridge controller stopped"); + } +} + +async fn process_update(update: BridgeControllerUpdate, bridges: &mut Bridges) { + debug!("received updated config: {:?}", update); + + for bridge_update in update.into_inner() { + // for now only supports upstream bridge. 
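// [editor's note] `name()` currently just returns the `endpoint` field
// (see the TODO on `BridgeUpdate::name`), so this guard only lets through
// updates whose endpoint is literally "$upstream"; anything else is logged
// and skipped.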
+ if bridge_update.name() != UPSTREAM { + warn!( + "updates for {} bridge is not supported", + bridge_update.name() + ); + continue; + } + + bridges.send_update(bridge_update).await; + } +} + +#[derive(Clone, Debug)] +pub struct BridgeControllerHandle { + sender: UnboundedSender, +} + +impl BridgeControllerHandle { + pub fn send_update(&mut self, update: BridgeControllerUpdate) -> Result<(), Error> { + self.send_message(BridgeControllerMessage::BridgeControllerUpdate(update)) + } + + pub fn shutdown(mut self) { + if let Err(e) = self.send_message(BridgeControllerMessage::Shutdown) { + error!(error = %e, "unable to request shutdown for bridge controller"); + } + } + + fn send_message(&mut self, message: BridgeControllerMessage) -> Result<(), Error> { + self.sender + .send(message) + .map_err(Error::SendControllerMessage) + } +} + +/// Control message for `BridgeController`. +#[derive(Debug)] +pub enum BridgeControllerMessage { + BridgeControllerUpdate(BridgeControllerUpdate), + Shutdown, +} + +/// Error for `BridgeController`. +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("An error occurred sending a message to the controller.")] + SendControllerMessage(#[source] tokio::sync::mpsc::error::SendError), + + #[error("An error occurred sending a message to the bridge.")] + SendBridgeMessage(#[from] BridgeError), +} diff --git a/mqtt/mqtt-bridge/src/lib.rs b/mqtt/mqtt-bridge/src/lib.rs index 1bacaf7ed11..3ddd619af03 100644 --- a/mqtt/mqtt-bridge/src/lib.rs +++ b/mqtt/mqtt-bridge/src/lib.rs @@ -13,15 +13,15 @@ mod bridge; pub mod client; -mod connectivity; +mod config_update; pub mod controller; mod messages; mod persist; -mod pump; -mod rpc; +pub mod pump; pub mod settings; mod token_source; +pub mod upstream; -pub use crate::controller::{ - BridgeController, BridgeControllerHandle, BridgeControllerUpdate, Error, -}; +pub use crate::controller::{BridgeController, BridgeControllerHandle, Error}; + +pub use crate::config_update::BridgeControllerUpdate; diff --git a/mqtt/mqtt-bridge/src/messages.rs b/mqtt/mqtt-bridge/src/messages.rs index b208587a19f..5a11bdc1134 100644 --- a/mqtt/mqtt-bridge/src/messages.rs +++ b/mqtt/mqtt-bridge/src/messages.rs @@ -1,24 +1,30 @@ -use std::convert::TryFrom; +use std::{collections::HashMap, convert::TryFrom}; use async_trait::async_trait; -use mqtt3::{proto::Publication, Event}; +use mqtt3::{proto::Publication, Event, SubscriptionUpdateEvent}; use mqtt_broker::TopicFilter; use tracing::{debug, warn}; use crate::{ bridge::BridgeError, - client::{EventHandler, Handled}, + client::{Handled, MqttEventHandler}, persist::{PublicationStore, StreamWakeableState}, - rpc::RpcHandler, + pump::TopicMapperUpdates, settings::TopicRule, }; -#[derive(Clone)] +#[derive(Default, Clone)] pub struct TopicMapper { topic_settings: TopicRule, topic_filter: TopicFilter, } +impl TopicMapper { + pub fn subscribe_to(&self) -> String { + self.topic_settings.subscribe_to() + } +} + impl TryFrom for TopicMapper { type Error = BridgeError; @@ -36,140 +42,227 @@ impl TryFrom for TopicMapper { } /// Handle events from client and saves them with the forward topic -pub struct MessageHandler { - topic_mappers: Vec, +pub struct StoreMqttEventHandler { + topic_mappers: HashMap, + topic_mappers_updates: TopicMapperUpdates, store: PublicationStore, } -impl MessageHandler { - pub fn new(store: PublicationStore, topic_mappers: Vec) -> Self { +impl StoreMqttEventHandler { + pub fn new(store: PublicationStore, topic_mappers_updates: TopicMapperUpdates) -> Self { Self { - topic_mappers, + 
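// start with no active mappers: entries are only inserted into
// `topic_mappers` once the broker acknowledges a subscription (see
// `update_subscribed` below) and are removed again on unsubscribe or
// server rejection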
topic_mappers: HashMap::new(), + topic_mappers_updates, store, } } fn transform(&self, topic_name: &str) -> Option { - self.topic_mappers.iter().find_map(|mapper| { + self.topic_mappers.values().find_map(|mapper| { mapper .topic_settings .in_prefix() // maps if local does not have a value it uses the topic that was received, // else it checks that the received topic starts with local prefix and removes the local prefix .map_or(Some(topic_name), |local_prefix| { - let prefix = format!("{}/", local_prefix); - topic_name.strip_prefix(&prefix) + if local_prefix.is_empty() { + topic_name.strip_prefix(local_prefix) + } else { + topic_name.strip_prefix(format!("{}/", local_prefix).as_str()) + } }) // match topic without local prefix with the topic filter pattern .filter(|stripped_topic| mapper.topic_filter.matches(stripped_topic)) - .map(|stripped_topic| { - if let Some(remote_prefix) = mapper.topic_settings.out_prefix() { - format!("{}/{}", remote_prefix, stripped_topic) - } else { - stripped_topic.to_string() + .map(|stripped_topic| match mapper.topic_settings.out_prefix() { + Some(remote_prefix) => { + if remote_prefix.is_empty() { + stripped_topic.to_string() + } else { + format!("{}/{}", remote_prefix, stripped_topic) + } } + None => stripped_topic.to_string(), }) }) } -} - -#[async_trait] -impl EventHandler for MessageHandler -where - S: StreamWakeableState + Send, -{ - type Error = BridgeError; - - async fn handle(&mut self, event: &Event) -> Result { - if let Event::Publication(publication) = event { - let forward_publication = - self.transform(&publication.topic_name) - .map(|topic_name| Publication { - topic_name, - qos: publication.qos, - retain: publication.retain, - payload: publication.payload.clone(), - }); - - if let Some(publication) = forward_publication { - debug!("saving message to store"); - self.store.push(publication).map_err(BridgeError::Store)?; - - return Ok(Handled::Fully); - } else { - warn!("no topic matched"); - } - } - Ok(Handled::Skipped) + fn update_subscribed(&mut self, sub: &str) { + if let Some(mapper) = self.topic_mappers_updates.get(sub) { + self.topic_mappers.insert(sub.to_owned(), mapper); + } else { + warn!("unexpected subscription ack for {}", sub); + }; } -} -pub struct UpstreamHandler { - messages: MessageHandler, - rpc: RpcHandler, + fn update_unsubscribed(&mut self, sub: &str) { + if self.topic_mappers.remove(sub).is_none() { + warn!("unexpected subscription/rejected ack for {}", sub); + }; + } } #[async_trait] -impl EventHandler for UpstreamHandler +impl MqttEventHandler for StoreMqttEventHandler where S: StreamWakeableState + Send, { type Error = BridgeError; - async fn handle(&mut self, event: &Event) -> Result { - // try to handle as RPC command first - if self.rpc.handle(&event).await? 
== Handled::Fully { - return Ok(Handled::Fully); + fn subscriptions(&self) -> Vec { + self.topic_mappers_updates.subscriptions() + } + + async fn handle(&mut self, event: Event) -> Result { + match &event { + Event::Publication(publication) => { + let forward_publication = + self.transform(&publication.topic_name) + .map(|topic_name| Publication { + topic_name, + qos: publication.qos, + retain: publication.retain, + payload: publication.payload.clone(), + }); + + if let Some(publication) = forward_publication { + debug!("saving message to store"); + self.store.push(publication).map_err(BridgeError::Store)?; + + return Ok(Handled::Fully); + } else { + warn!("no topic matched"); + } + } + Event::SubscriptionUpdates(sub_updates) => { + for update in sub_updates { + match update { + SubscriptionUpdateEvent::Subscribe(subscribe_to) => { + debug!("received subscribe: {:?}", subscribe_to); + self.update_subscribed(&subscribe_to.topic_filter); + } + SubscriptionUpdateEvent::Unsubscribe(unsubcribed_from) => { + debug!("received unsubscribe: {}", unsubcribed_from); + self.update_unsubscribed(&unsubcribed_from); + } + SubscriptionUpdateEvent::RejectedByServer(rejected) => { + debug!("received subscription rejected: {}", rejected); + self.update_unsubscribed(&rejected); + } + } + } + + return Ok(Handled::Fully); + } + Event::NewConnection { reset_session: _ } | Event::Disconnected(_) => {} } - // handle as an event for regular message handler - self.messages.handle(&event).await + Ok(Handled::Skipped(event)) } } #[cfg(test)] mod tests { use bytes::Bytes; - use futures_util::{stream::StreamExt, TryStreamExt}; - use std::str::FromStr; + use futures_util::{ + future::{self, Either}, + stream::StreamExt, + TryStreamExt, + }; + use std::{collections::HashMap, str::FromStr}; use mqtt3::{ - proto::{Publication, QoS}, - Event, ReceivedPublication, + proto::{Publication, QoS, SubscribeTo}, + Event, ReceivedPublication, SubscriptionUpdateEvent, }; use mqtt_broker::TopicFilter; - use super::{MessageHandler, TopicMapper}; + use super::{StoreMqttEventHandler, TopicMapper}; + use crate::{ - client::EventHandler, - persist::PublicationStore, - settings::{BridgeSettings, Direction}, + client::MqttEventHandler, persist::PublicationStore, pump::TopicMapperUpdates, + settings::BridgeSettings, }; + #[tokio::test] + async fn message_handler_updates_topic() { + let batch_size: usize = 5; + let settings = BridgeSettings::from_file("tests/config.json").unwrap(); + let connection_settings = settings.upstream().unwrap(); + + let topics: HashMap = connection_settings + .forwards() + .iter() + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { + topic_settings: sub.clone(), + topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), + }, + ) + }) + .collect(); + + let store = PublicationStore::new_memory(batch_size); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); + + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "local/floor/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + + let _topic_mapper = handler.topic_mappers.get("local/floor/#").unwrap(); + } + + #[tokio::test] + async fn message_handler_updates_topic_without_pending_update() { + let batch_size: usize = 5; + + let topics = HashMap::new(); + + let store = PublicationStore::new_memory(batch_size); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); + + handler + 
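// deliver a subscription ack for a filter that has no pending update in
// `TopicMapperUpdates`; `update_subscribed` is expected to log a warning
// and leave `topic_mappers` unchanged, which the assert below verifies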
.handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "local/floor/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + + assert_eq!(handler.topic_mappers.get("local/floor/#").is_none(), true); + } + #[tokio::test] async fn message_handler_saves_message_with_local_and_forward_topic() { let batch_size: usize = 5; let settings = BridgeSettings::from_file("tests/config.json").unwrap(); let connection_settings = settings.upstream().unwrap(); - let topics: Vec = connection_settings - .subscriptions() + let topics = connection_settings + .forwards() .iter() - .filter_map(|sub| { - if *sub.direction() == Direction::Out { - Some(TopicMapper { + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { topic_settings: sub.clone(), topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), - }) - } else { - None - } + }, + ) }) .collect(); let store = PublicationStore::new_memory(batch_size); - let mut handler = MessageHandler::new(store, topics); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); let pub1 = ReceivedPublication { topic_name: "local/floor/1".to_string(), @@ -186,7 +279,72 @@ mod tests { payload: Bytes::new(), }; - handler.handle(&Event::Publication(pub1)).await.unwrap(); + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "local/floor/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + handler.handle(Event::Publication(pub1)).await.unwrap(); + + let mut loader = handler.store.loader(); + + let extracted1 = loader.try_next().await.unwrap().unwrap(); + assert_eq!(extracted1.1, expected); + } + + #[tokio::test] + async fn message_handler_saves_message_with_empty_local_and_forward_topic() { + let batch_size: usize = 5; + let settings = BridgeSettings::from_file("tests/config.json").unwrap(); + let connection_settings = settings.upstream().unwrap(); + + let topics = connection_settings + .forwards() + .iter() + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { + topic_settings: sub.clone(), + topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), + }, + ) + }) + .collect(); + + let store = PublicationStore::new_memory(batch_size); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); + + let pub1 = ReceivedPublication { + topic_name: "floor2/1".to_string(), + qos: QoS::AtLeastOnce, + retain: true, + payload: Bytes::new(), + dup: false, + }; + + let expected = Publication { + topic_name: "floor2/1".to_string(), + qos: QoS::AtLeastOnce, + retain: true, + payload: Bytes::new(), + }; + + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "floor2/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + + handler.handle(Event::Publication(pub1)).await.unwrap(); let mut loader = handler.store.loader(); @@ -200,23 +358,22 @@ mod tests { let settings = BridgeSettings::from_file("tests/config.json").unwrap(); let connection_settings = settings.upstream().unwrap(); - let topics: Vec = connection_settings - .subscriptions() + let topics = connection_settings + .forwards() .iter() - .filter_map(|sub| { - if *sub.direction() == Direction::Out { - Some(TopicMapper { + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { topic_settings: sub.clone(), topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), - }) - } else { - None - } + }, + ) }) .collect(); let store = 
PublicationStore::new_memory(batch_size); - let mut handler = MessageHandler::new(store, topics); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); let pub1 = ReceivedPublication { topic_name: "temp/1".to_string(), @@ -233,7 +390,16 @@ mod tests { payload: Bytes::new(), }; - handler.handle(&Event::Publication(pub1)).await.unwrap(); + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "temp/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + handler.handle(Event::Publication(pub1)).await.unwrap(); let mut loader = handler.store.loader(); @@ -247,23 +413,22 @@ mod tests { let settings = BridgeSettings::from_file("tests/config.json").unwrap(); let connection_settings = settings.upstream().unwrap(); - let topics: Vec = connection_settings - .subscriptions() + let topics = connection_settings + .forwards() .iter() - .filter_map(|sub| { - if *sub.direction() == Direction::Out { - Some(TopicMapper { + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { topic_settings: sub.clone(), topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), - }) - } else { - None - } + }, + ) }) .collect(); let store = PublicationStore::new_memory(batch_size); - let mut handler = MessageHandler::new(store, topics); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); let pub1 = ReceivedPublication { topic_name: "pattern/p1".to_string(), @@ -280,7 +445,17 @@ mod tests { payload: Bytes::new(), }; - handler.handle(&Event::Publication(pub1)).await.unwrap(); + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "pattern/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + + handler.handle(Event::Publication(pub1)).await.unwrap(); let mut loader = handler.store.loader(); @@ -294,23 +469,22 @@ mod tests { let settings = BridgeSettings::from_file("tests/config.json").unwrap(); let connection_settings = settings.upstream().unwrap(); - let topics: Vec = connection_settings - .subscriptions() + let topics = connection_settings + .forwards() .iter() - .filter_map(|sub| { - if *sub.direction() == Direction::Out { - Some(TopicMapper { + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { topic_settings: sub.clone(), topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), - }) - } else { - None - } + }, + ) }) .collect(); let store = PublicationStore::new_memory(batch_size); - let mut handler = MessageHandler::new(store, topics); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); let pub1 = ReceivedPublication { topic_name: "local/temp/1".to_string(), @@ -320,11 +494,123 @@ mod tests { dup: false, }; - handler.handle(&Event::Publication(pub1)).await.unwrap(); + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "local/temp/#".to_string(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + handler.handle(Event::Publication(pub1)).await.unwrap(); + + let mut loader = handler.store.loader(); + + let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); + if let Either::Right(_) = future::select(interval.next(), loader.next()).await { + panic!("Should not reach here"); + } + } + + #[tokio::test] + async fn message_handler_with_local_and_forward_not_ack_topic() { + let batch_size: usize = 5; + let settings = 
BridgeSettings::from_file("tests/config.json").unwrap(); + let connection_settings = settings.upstream().unwrap(); + + let topics = connection_settings + .forwards() + .iter() + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { + topic_settings: sub.clone(), + topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), + }, + ) + }) + .collect(); + + let store = PublicationStore::new_memory(batch_size); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); + + let pub1 = ReceivedPublication { + topic_name: "pattern/p1".to_string(), + qos: QoS::AtLeastOnce, + retain: true, + payload: Bytes::new(), + dup: false, + }; + + handler.handle(Event::Publication(pub1)).await.unwrap(); let mut loader = handler.store.loader(); let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); - futures_util::future::select(interval.next(), loader.next()).await; + + if let Either::Right(_) = future::select(interval.next(), loader.next()).await { + panic!("Should not reach here"); + } + } + + #[tokio::test] + async fn message_handler_with_local_and_forward_unsubscribed_topic() { + let batch_size: usize = 5; + let settings = BridgeSettings::from_file("tests/config.json").unwrap(); + let connection_settings = settings.upstream().unwrap(); + + let topics = connection_settings + .forwards() + .iter() + .map(|sub| { + ( + sub.subscribe_to(), + TopicMapper { + topic_settings: sub.clone(), + topic_filter: TopicFilter::from_str(sub.topic()).unwrap(), + }, + ) + }) + .collect(); + + let store = PublicationStore::new_memory(batch_size); + let mut handler = StoreMqttEventHandler::new(store, TopicMapperUpdates::new(topics)); + + let pub1 = ReceivedPublication { + topic_name: "pattern/p1".to_string(), + qos: QoS::AtLeastOnce, + retain: true, + payload: Bytes::new(), + dup: false, + }; + + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "pattern/#".into(), + qos: QoS::AtLeastOnce, + }), + ])) + .await + .unwrap(); + + handler + .handle(Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Unsubscribe("pattern/#".into()), + ])) + .await + .unwrap(); + + handler.handle(Event::Publication(pub1)).await.unwrap(); + + let mut loader = handler.store.loader(); + + let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); + + if let Either::Right(_) = future::select(interval.next(), loader.next()).await { + panic!("Should not reach here"); + } } } diff --git a/mqtt/mqtt-bridge/src/persist/loader.rs b/mqtt/mqtt-bridge/src/persist/loader.rs index b0f79127b9d..990a8ae96e3 100644 --- a/mqtt/mqtt-bridge/src/persist/loader.rs +++ b/mqtt/mqtt-bridge/src/persist/loader.rs @@ -114,7 +114,7 @@ mod tests { #[test] fn smaller_batch_size_respected() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // setup data @@ -153,7 +153,7 @@ mod tests { #[test] fn larger_batch_size_respected() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // setup data @@ -194,7 +194,7 @@ mod tests { #[test] fn ordering_maintained_across_inserts() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // add many elements @@ -229,7 +229,7 @@ mod tests { #[tokio::test] async fn retrieve_elements() { // setup state - let state = WakingMemoryStore::new(); + 
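// [editor's note] On the `new()` -> `Default` refactor used throughout
// these tests: all of `WakingMemoryStore`'s fields (`VecDeque`, `HashSet`,
// `Option<Waker>`) implement `Default`, so the hand-written impl in
// memory.rs could presumably also be replaced with a derive:
//
//     #[derive(Default)]
//     pub struct WakingMemoryStore { /* ... */ }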
let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // setup data @@ -270,7 +270,7 @@ mod tests { #[tokio::test] async fn delete_and_retrieve_new_elements() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // setup data @@ -330,7 +330,7 @@ mod tests { #[tokio::test] async fn poll_stream_does_not_block_when_map_empty() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let state = Arc::new(Mutex::new(state)); // setup data diff --git a/mqtt/mqtt-bridge/src/persist/publication_store.rs b/mqtt/mqtt-bridge/src/persist/publication_store.rs index 3d8467faa21..72b1137b77f 100644 --- a/mqtt/mqtt-bridge/src/persist/publication_store.rs +++ b/mqtt/mqtt-bridge/src/persist/publication_store.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] // TODO remove when ready use std::sync::Arc; use anyhow::Result; @@ -22,7 +21,7 @@ pub struct PublicationStore(Arc>>); impl PublicationStore { pub fn new_memory(batch_size: usize) -> PublicationStore { - Self::new(WakingMemoryStore::new(), batch_size) + Self::new(WakingMemoryStore::default(), batch_size) } } @@ -99,7 +98,7 @@ mod tests { #[tokio::test] async fn insert() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let batch_size: usize = 5; let persistence = PublicationStore::new(state, batch_size); @@ -138,7 +137,7 @@ mod tests { #[tokio::test] async fn remove() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let batch_size: usize = 1; let persistence = PublicationStore::new(state, batch_size); @@ -176,7 +175,7 @@ mod tests { #[tokio::test] async fn remove_key_inserted_but_not_retrieved() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let batch_size: usize = 1; let persistence = PublicationStore::new(state, batch_size); @@ -198,7 +197,7 @@ mod tests { #[tokio::test] async fn remove_key_dne() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let batch_size: usize = 1; let persistence = PublicationStore::new(state, batch_size); @@ -213,7 +212,7 @@ mod tests { #[tokio::test] async fn get_loader() { // setup state - let state = WakingMemoryStore::new(); + let state = WakingMemoryStore::default(); let batch_size: usize = 1; let persistence = PublicationStore::new(state, batch_size); diff --git a/mqtt/mqtt-bridge/src/persist/waking_state/memory.rs b/mqtt/mqtt-bridge/src/persist/waking_state/memory.rs index 0164ea41ccb..a41bb913a7b 100644 --- a/mqtt/mqtt-bridge/src/persist/waking_state/memory.rs +++ b/mqtt/mqtt-bridge/src/persist/waking_state/memory.rs @@ -18,9 +18,9 @@ pub struct WakingMemoryStore { waker: Option, } -impl WakingMemoryStore { - pub fn new() -> Self { - WakingMemoryStore { +impl Default for WakingMemoryStore { + fn default() -> Self { + Self { queue: VecDeque::new(), loaded: HashSet::new(), waker: None, diff --git a/mqtt/mqtt-bridge/src/persist/waking_state/mod.rs b/mqtt/mqtt-bridge/src/persist/waking_state/mod.rs index 364f0b014f1..598946fb5f6 100644 --- a/mqtt/mqtt-bridge/src/persist/waking_state/mod.rs +++ b/mqtt/mqtt-bridge/src/persist/waking_state/mod.rs @@ -48,7 +48,7 @@ mod tests { loader::MessageLoader, waking_state::StreamWakeableState, Key, WakingMemoryStore, }; - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn insert(mut state: impl 
StreamWakeableState) { let key1 = Key { offset: 0 }; let pub1 = Publication { @@ -65,7 +65,7 @@ mod tests { assert_eq!(pub1, extracted_message); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn ordering_maintained_across_insert(mut state: impl StreamWakeableState) { // insert a bunch of elements let num_elements = 10 as usize; @@ -91,7 +91,7 @@ mod tests { } } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] async fn ordering_maintained_across_removal(mut state: impl StreamWakeableState) { // insert a bunch of elements let num_elements = 10 as usize; @@ -130,7 +130,7 @@ mod tests { } } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn larger_batch_size_respected(mut state: impl StreamWakeableState) { let key1 = Key { offset: 0 }; let pub1 = Publication { @@ -150,7 +150,7 @@ mod tests { assert_eq!(pub1, extracted_message); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn smaller_batch_size_respected(mut state: impl StreamWakeableState) { let key1 = Key { offset: 0 }; let pub1 = Publication { @@ -179,7 +179,7 @@ mod tests { assert_eq!(pub1, extracted_message); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] async fn remove_loaded(mut state: impl StreamWakeableState) { let key1 = Key { offset: 0 }; let pub1 = Publication { @@ -197,14 +197,14 @@ mod tests { assert_eq!(empty_batch.len(), 0); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn remove_loaded_dne(mut state: impl StreamWakeableState) { let key1 = Key { offset: 0 }; let bad_removal = state.remove(key1); assert_matches!(bad_removal, Err(_)); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] fn remove_loaded_inserted_but_not_yet_retrieved(mut state: impl StreamWakeableState) { let key1 = Key { offset: 0 }; let pub1 = Publication { @@ -219,7 +219,7 @@ mod tests { assert_matches!(bad_removal, Err(_)); } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] async fn remove_loaded_out_of_order(mut state: impl StreamWakeableState) { // setup data let key1 = Key { offset: 0 }; @@ -246,7 +246,7 @@ mod tests { assert_matches!(state.remove(key2), Ok(_)) } - #[test_case(WakingMemoryStore::new())] + #[test_case(WakingMemoryStore::default())] async fn insert_wakes_stream(state: impl StreamWakeableState + Send + 'static) { // setup data let state = Rc::new(RefCell::new(state)); diff --git a/mqtt/mqtt-bridge/src/pump.rs b/mqtt/mqtt-bridge/src/pump.rs deleted file mode 100644 index 8c611c79934..00000000000 --- a/mqtt/mqtt-bridge/src/pump.rs +++ /dev/null @@ -1,359 +0,0 @@ -#![allow(dead_code)] -use std::{ - collections::HashMap, - convert::TryInto, - fmt::{Display, Formatter, Result as FmtResult}, -}; - -use futures_util::{ - future::{select, Either, FutureExt}, - pin_mut, - stream::{StreamExt, TryStreamExt}, -}; -use tokio::sync::{mpsc::Sender, oneshot, oneshot::Receiver}; -use tracing::{debug, error, info}; - -use mqtt3::PublishHandle; - -use crate::{ - bridge::BridgeError, - client::{ClientShutdownHandle, MqttClient}, - messages::MessageHandler, - persist::{MessageLoader, PublicationStore, WakingMemoryStore}, - settings::{ConnectionSettings, Credentials, Direction, TopicRule}, -}; - -const BATCH_SIZE: usize = 10; - -#[derive(Debug, PartialEq)] -pub enum PumpMessage { - ConnectivityUpdate(ConnectivityState), - 
ConfigurationUpdate(ConnectionSettings), -} - -pub struct PumpHandle { - sender: Sender, -} - -impl PumpHandle { - pub fn new(sender: Sender) -> Self { - Self { sender } - } - - pub async fn send(&mut self, message: PumpMessage) -> Result<(), BridgeError> { - self.sender - .send(message) - .await - .map_err(BridgeError::SenderToPump) - } -} - -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum ConnectivityState { - Connected, - Disconnected, -} - -impl Display for ConnectivityState { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Self::Connected => write!(f, "Connected"), - Self::Disconnected => write!(f, "Disconnected"), - } - } -} - -/// Specifies if a pump is local or remote -#[derive(Debug, Clone)] -pub enum PumpType { - Local, - Remote, -} - -/// Pumps always exists in a pair -/// This abstraction provides a convenience function to create two pumps at once -pub struct PumpPair { - pub local_pump: Pump, - pub remote_pump: Pump, -} - -impl PumpPair { - pub fn new( - connection_settings: &ConnectionSettings, - system_address: &str, - device_id: &str, - ) -> Result { - let local_client_id = format!( - "{}/$edgeHub/$bridge/{}", - device_id, - connection_settings.name() - ); - - let forwards: HashMap = connection_settings - .subscriptions() - .iter() - .filter_map(|s| { - if *s.direction() == Direction::Out { - Some(Self::format_key_value(s)) - } else { - None - } - }) - .collect(); - - let subscriptions: HashMap = connection_settings - .subscriptions() - .iter() - .filter_map(|s| { - if *s.direction() == Direction::In { - Some(Self::format_key_value(s)) - } else { - None - } - }) - .collect(); - - let outgoing_persist = PublicationStore::new_memory(BATCH_SIZE); - let incoming_persist = PublicationStore::new_memory(BATCH_SIZE); - let outgoing_loader = outgoing_persist.loader(); - let incoming_loader = incoming_persist.loader(); - - let remote_pump = Self::prepare_pump( - outgoing_loader, - &incoming_persist, - &outgoing_persist, - connection_settings.address(), - connection_settings.credentials(), - PumpType::Remote, - subscriptions, - connection_settings, - )?; - - let local_pump = Self::prepare_pump( - incoming_loader, - &outgoing_persist, - &incoming_persist, - system_address, - &Credentials::Anonymous(local_client_id), - PumpType::Local, - forwards, - connection_settings, - )?; - - Ok(Self { - local_pump, - remote_pump, - }) - } - - // If this grows by even more we can find a workaround - #[allow(clippy::too_many_arguments)] - fn prepare_pump( - loader: MessageLoader, - ingress_store: &PublicationStore, - egress_store: &PublicationStore, - address: &str, - credentials: &Credentials, - pump_type: PumpType, - mut topic_mappings: HashMap, - connection_settings: &ConnectionSettings, - ) -> Result { - let (subscriptions, topic_rules): (Vec<_>, Vec<_>) = topic_mappings.drain().unzip(); - let topic_filters = topic_rules - .into_iter() - .map(|topic| topic.try_into()) - .collect::, _>>()?; - - let client = match pump_type { - PumpType::Local => MqttClient::tcp( - address, - connection_settings.keep_alive(), - connection_settings.clean_session(), - MessageHandler::new(ingress_store.clone(), topic_filters), - credentials, - ), - PumpType::Remote => MqttClient::tls( - address, - connection_settings.keep_alive(), - connection_settings.clean_session(), - MessageHandler::new(ingress_store.clone(), topic_filters), - credentials, - ), - }; - - Ok(Pump::new( - client, - subscriptions, - loader, - egress_store.clone(), - pump_type, - )?) 
- } - - fn format_key_value(subscription: &TopicRule) -> (String, TopicRule) { - let key = if let Some(local) = subscription.in_prefix() { - format!("{}/{}", local, subscription.topic().to_string()) - } else { - subscription.topic().into() - }; - (key, subscription.clone()) - } -} - -/// Pump used to connect to either local broker or remote brokers (including the upstream edge device) -/// It contains an mqtt client that connects to a local/remote broker -/// After connection there are two simultaneous processes: -/// 1) persist incoming messages into an ingress store to be used by another pump -/// 2) publish outgoing messages from an egress store to the local/remote broker -pub struct Pump { - client: MqttClient>, - client_shutdown: ClientShutdownHandle, - publish_handle: PublishHandle, - subscriptions: Vec, - loader: MessageLoader, - persist: PublicationStore, - pump_type: PumpType, -} - -impl Pump { - fn new( - client: MqttClient>, - subscriptions: Vec, - loader: MessageLoader, - persist: PublicationStore, - pump_type: PumpType, - ) -> Result { - let publish_handle = client - .publish_handle() - .map_err(BridgeError::PublishHandle)?; - let client_shutdown = client.shutdown_handle()?; - - Ok(Self { - client, - client_shutdown, - publish_handle, - subscriptions, - loader, - persist, - pump_type, - }) - } - - pub async fn subscribe(&mut self) -> Result<(), BridgeError> { - self.client - .subscribe(&self.subscriptions) - .await - .map_err(BridgeError::Subscribe)?; - - Ok(()) - } - - #[allow(clippy::too_many_lines)] - pub async fn run(&mut self, shutdown: Receiver<()>) { - debug!("starting pump"); - - let (loader_shutdown, loader_shutdown_rx) = oneshot::channel::<()>(); - let publish_handle = self.publish_handle.clone(); - let persist = self.persist.clone(); - let mut loader = self.loader.clone(); - let mut client_shutdown = self.client_shutdown.clone(); - - // egress loop - let egress_loop = async { - let mut receive_fut = loader_shutdown_rx.into_stream(); - - info!("starting egress publication processing"); - - loop { - let mut publish_handle = publish_handle.clone(); - match select(receive_fut.next(), loader.try_next()).await { - Either::Left((shutdown, _)) => { - info!("received shutdown signal for egress messages",); - if shutdown.is_none() { - error!("has unexpected behavior from shutdown signal while signaling bridge pump shutdown"); - } - - break; - } - Either::Right((loaded_element, _)) => { - debug!("extracted publication from store"); - - if let Ok(Some((key, publication))) = loaded_element { - debug!("publishing {:?}", key); - if let Err(e) = publish_handle.publish(publication).await { - error!(err = %e, "failed publish"); - } - - if let Err(e) = persist.remove(key) { - error!(err = %e, "failed removing publication from store"); - } - } - } - } - } - - info!("stopped sending egress messages"); - }; - - // ingress loop - let ingress_loop = async { - debug!("starting ingress publication processing"); - self.client.handle_events().await; - }; - - // run pumps - let egress_loop = egress_loop.fuse(); - let ingress_loop = ingress_loop.fuse(); - let shutdown = shutdown.fuse(); - pin_mut!(egress_loop, ingress_loop); - let pump_processes = select(egress_loop, ingress_loop); - - // wait for shutdown - match select(pump_processes, shutdown).await { - // early-stop error - Either::Left((pump_processes, _)) => { - error!("stopped early so will shut down"); - - match pump_processes { - Either::Left((_, ingress_loop)) => { - if let Err(e) = client_shutdown.shutdown().await { - error!(err = %e, 
"failed to shutdown ingress publication loop"); - } - - ingress_loop.await; - } - Either::Right((_, egress_loop)) => { - if let Err(e) = loader_shutdown.send(()) { - error!("failed to shutdown egress publication loop {:?}", e); - } - - egress_loop.await; - } - } - } - // shutdown was signaled - Either::Right((shutdown, pump_processes)) => { - if let Err(e) = shutdown { - error!(err = %e, "failed listening for shutdown"); - } - - if let Err(e) = client_shutdown.shutdown().await { - error!(err = %e, "failed to shutdown ingress publication loop"); - } - - if let Err(e) = loader_shutdown.send(()) { - error!("failed to shutdown egress publication loop {:?}", e); - } - - match pump_processes.await { - Either::Left((_, ingress_loop)) => { - ingress_loop.await; - } - Either::Right((_, egress_loop)) => { - egress_loop.await; - } - } - } - } - } -} diff --git a/mqtt/mqtt-bridge/src/pump/builder.rs b/mqtt/mqtt-bridge/src/pump/builder.rs new file mode 100644 index 00000000000..b8df529b0ca --- /dev/null +++ b/mqtt/mqtt-bridge/src/pump/builder.rs @@ -0,0 +1,205 @@ +use std::{collections::HashMap, convert::TryInto}; + +use tokio::sync::mpsc; + +use crate::{ + bridge::BridgeError, + client::{MqttClient, MqttClientConfig, MqttClientExt}, + messages::{StoreMqttEventHandler, TopicMapper}, + persist::{PublicationStore, StreamWakeableState, WakingMemoryStore}, + settings::TopicRule, + upstream::{ + ConnectivityMqttEventHandler, LocalRpcMqttEventHandler, LocalUpstreamMqttEventHandler, + LocalUpstreamPumpEventHandler, RemoteRpcMqttEventHandler, RemoteUpstreamMqttEventHandler, + RemoteUpstreamPumpEventHandler, RpcSubscriptions, + }, +}; + +use super::{MessagesProcessor, Pump, PumpHandle, TopicMapperUpdates}; + +pub type PumpPair = ( + Pump, LocalUpstreamPumpEventHandler>, + Pump, RemoteUpstreamPumpEventHandler>, +); + +/// Constructs a pair of bridge pumps: local and remote. +/// +/// Local pump connects to a local broker, subscribes to topics to receive +/// messages from local broker and put it in the store of the remote pump. +/// Also reads messages from a local store and publishes them to local broker. +/// +/// Remote pump connects to a remote broker, subscribes to topics to receive +/// messages from remote broker and put it in the store of the local pump. +/// Also reads messages from a remote store and publishes them to local broker. +pub struct Builder { + local: PumpBuilder, + remote: PumpBuilder, + store: Box PublicationStore>, +} + +impl Default for Builder { + fn default() -> Self { + Self { + local: PumpBuilder::default(), + remote: PumpBuilder::default(), + store: Box::new(|| PublicationStore::new_memory(0)), + } + } +} + +impl Builder +where + S: StreamWakeableState + Send, +{ + /// Apples parameters to create local pump. + pub fn with_local(mut self, mut apply: F) -> Self + where + F: FnMut(&mut PumpBuilder), + { + apply(&mut self.local); + self + } + + /// Applies parameters to create remote pump. + pub fn with_remote(mut self, mut apply: F) -> Self + where + F: FnMut(&mut PumpBuilder), + { + apply(&mut self.remote); + self + } + + /// Setups a factory to create publication store. + pub fn with_store(self, store: F) -> Builder + where + F: Fn() -> PublicationStore + 'static, + { + Builder { + local: self.local, + remote: self.remote, + store: Box::new(store), + } + } + + /// Creates a pair of local and remote pump. 
+ pub fn build(&mut self) -> Result, BridgeError> { + let remote_store = (self.store)(); + let local_store = (self.store)(); + + let (remote_messages_send, remote_messages_recv) = mpsc::channel(100); + let (local_messages_send, local_messages_recv) = mpsc::channel(100); + + // prepare local pump + let topic_filters = make_topics(&self.local.rules)?; + let local_topic_mappers_updates = TopicMapperUpdates::new(topic_filters); + + let rpc = LocalRpcMqttEventHandler::new(PumpHandle::new(remote_messages_send.clone())); + let messages = + StoreMqttEventHandler::new(remote_store.clone(), local_topic_mappers_updates.clone()); + let handler = LocalUpstreamMqttEventHandler::new(messages, rpc); + + let config = self.local.client.take().expect("local client config"); + let client = MqttClient::tcp(config, handler).map_err(BridgeError::ValidationError)?; + let local_pub_handle = client + .publish_handle() + .map_err(BridgeError::PublishHandle)?; + let subscription_handle = client + .update_subscription_handle() + .map_err(BridgeError::UpdateSubscriptionHandle)?; + + let handler = LocalUpstreamPumpEventHandler::new(local_pub_handle); + let pump_handle = PumpHandle::new(local_messages_send.clone()); + let messages = MessagesProcessor::new( + handler, + local_messages_recv, + pump_handle, + subscription_handle, + local_topic_mappers_updates, + ); + + let local_pump = Pump::new( + local_messages_send.clone(), + client, + local_store.clone(), + messages, + )?; + + // prepare remote pump + let topic_filters = make_topics(&self.remote.rules)?; + let remote_topic_mappers_updates = TopicMapperUpdates::new(topic_filters); + + let rpc_subscriptions = RpcSubscriptions::default(); + let rpc = RemoteRpcMqttEventHandler::new(rpc_subscriptions.clone(), local_pump.handle()); + let messages = + StoreMqttEventHandler::new(local_store, remote_topic_mappers_updates.clone()); + let connectivity = ConnectivityMqttEventHandler::new(PumpHandle::new(local_messages_send)); + let handler = RemoteUpstreamMqttEventHandler::new(messages, rpc, connectivity); + + let config = self.remote.client.take().expect("remote client config"); + let client = MqttClient::tls(config, handler).map_err(BridgeError::ValidationError)?; + let remote_pub_handle = client + .publish_handle() + .map_err(BridgeError::PublishHandle)?; + let remote_sub_handle = client + .update_subscription_handle() + .map_err(BridgeError::UpdateSubscriptionHandle)?; + + let handler = RemoteUpstreamPumpEventHandler::new( + remote_sub_handle, + remote_pub_handle, + local_pump.handle(), + rpc_subscriptions, + ); + let pump_handle = PumpHandle::new(remote_messages_send.clone()); + + let remote_sub_handle = client + .update_subscription_handle() + .map_err(BridgeError::UpdateSubscriptionHandle)?; + let messages = MessagesProcessor::new( + handler, + remote_messages_recv, + pump_handle, + remote_sub_handle, + remote_topic_mappers_updates, + ); + + let remote_pump = Pump::new(remote_messages_send, client, remote_store, messages)?; + + Ok((local_pump, remote_pump)) + } +} + +/// Collects parameters to construct `Pump`. +#[derive(Default)] +pub struct PumpBuilder { + client: Option, + rules: Vec, +} + +impl PumpBuilder { + /// Applies default topic translation rules. + pub fn with_rules(&mut self, rules: Vec) -> &mut Self { + self.rules = rules; + self + } + + /// Applies MQTT client settings. 
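+    ///
+    /// Both `PumpBuilder` setters return `&mut Self`, so they can be chained
+    /// inside the `Builder` closures (a sketch; `local_config` is a placeholder):
+    ///
+    /// ```ignore
+    /// let builder = Builder::default().with_local(|pump| {
+    ///     pump.with_config(local_config).with_rules(vec![]);
+    /// });
+    /// ```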
+ pub fn with_config(&mut self, config: MqttClientConfig) -> &mut Self { + self.client = Some(config); + self + } +} + +fn make_topics(rules: &[TopicRule]) -> Result, BridgeError> { + let topic_filters: Vec = rules + .iter() + .map(|topic| topic.to_owned().try_into()) + .collect::, _>>()?; + + let topic_filters = topic_filters + .iter() + .map(|topic| (topic.subscribe_to(), topic.clone())) + .collect::>(); + + Ok(topic_filters) +} diff --git a/mqtt/mqtt-bridge/src/pump/egress.rs b/mqtt/mqtt-bridge/src/pump/egress.rs new file mode 100644 index 00000000000..6c215fb42eb --- /dev/null +++ b/mqtt/mqtt-bridge/src/pump/egress.rs @@ -0,0 +1,128 @@ +use futures_util::{pin_mut, stream::StreamExt}; +use tokio::{select, sync::oneshot}; +use tracing::{debug, error, info}; + +use crate::persist::{Key, PublicationStore, StreamWakeableState}; + +// Import and use mocks when run tests, real implementation when otherwise +#[cfg(test)] +pub use crate::client::MockPublishHandle as PublishHandle; + +#[cfg(not(test))] +use crate::client::PublishHandle; + +use mqtt3::proto::Publication; + +const MAX_IN_FLIGHT: usize = 16; + +/// Handles the egress of received publications. +/// +/// It loads messages from the local store and publishes them as MQTT messages +/// to the broker. After acknowledgement is received from the broker it +/// deletes publication from the store. +pub(crate) struct Egress { + publish_handle: PublishHandle, + store: PublicationStore, + shutdown_send: Option>, + shutdown_recv: oneshot::Receiver<()>, +} + +impl Egress +where + S: StreamWakeableState, +{ + /// Creates a new instance of egress. + pub(crate) fn new(publish_handle: PublishHandle, store: PublicationStore) -> Egress { + let (shutdown_send, shutdown_recv) = oneshot::channel(); + + Self { + publish_handle, + store, + shutdown_send: Some(shutdown_send), + shutdown_recv, + } + } + + /// Returns a shutdown handle of egress. + pub(crate) fn handle(&mut self) -> EgressShutdownHandle { + EgressShutdownHandle(self.shutdown_send.take()) + } + + /// Runs egress processing. + pub(crate) async fn run(self) -> Result<(), EgressError> { + let Egress { + publish_handle, + store, + mut shutdown_recv, + .. + } = self; + + info!("starting egress publication processing..."); + + // Take the stream of loaded messages and convert to a stream of futures + // which publish. Then convert to buffered stream so that we can have + // multiple in-flight and also limit number of publications. + let publications = store + .loader() + .filter_map(|loaded| { + let publish_handle = publish_handle.clone(); + async { + match loaded { + Ok((key, publication)) => { + Some(try_publish(key, publication, publish_handle)) + } + Err(e) => { + error!(error = %e, "failed loading publication from store"); + None + } + } + } + }) + .buffer_unordered(MAX_IN_FLIGHT); + pin_mut!(publications); + + loop { + select! { + _ = &mut shutdown_recv => { + debug!("received shutdown signal for egress messages"); + break; + } + key = publications.select_next_some() => { + if let Err(e) = store.remove(key) { + error!(error = %e, "failed removing publication from store"); + } + } + } + } + + info!("egress publication processing stopped"); + Ok(()) + } +} + +async fn try_publish(key: Key, publication: Publication, mut publish_handle: PublishHandle) -> Key { + debug!("publishing {:?}", key); + if let Err(e) = publish_handle.publish(publication).await { + error!(error = %e, "failed publish"); + } + + key +} + +/// Egress shutdown handle. 
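+///
+/// The handle must be taken before `Egress::run` consumes the egress
+/// (a sketch; `publish_handle` and `store` are placeholders):
+///
+/// ```ignore
+/// let mut egress = Egress::new(publish_handle, store);
+/// let shutdown_handle = egress.handle();
+/// let task = tokio::spawn(egress.run());
+/// // ... later, to stop egress processing:
+/// shutdown_handle.shutdown().await;
+/// let _ = task.await;
+/// ```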
+pub(crate) struct EgressShutdownHandle(Option>);
+
+impl EgressShutdownHandle {
+    /// Sends a signal to shutdown egress.
+    pub(crate) async fn shutdown(mut self) {
+        if let Some(sender) = self.0.take() {
+            if sender.send(()).is_err() {
+                error!("unable to request shutdown for egress.");
+            }
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("egress error")]
+pub(crate) struct EgressError;
diff --git a/mqtt/mqtt-bridge/src/pump/ingress.rs b/mqtt/mqtt-bridge/src/pump/ingress.rs
new file mode 100644
index 00000000000..caa6a584e64
--- /dev/null
+++ b/mqtt/mqtt-bridge/src/pump/ingress.rs
@@ -0,0 +1,63 @@
+use tracing::{error, info};
+
+use crate::client::{MqttClient, MqttEventHandler};
+
+// Import and use mocks when run tests, real implementation when otherwise
+#[cfg(test)]
+pub use crate::client::MockShutdownHandle as ShutdownHandle;
+
+#[cfg(not(test))]
+use crate::client::ShutdownHandle;
+
+/// Handles incoming MQTT publications and puts them into the store.
+pub(crate) struct Ingress {
+    client: MqttClient,
+    shutdown_client: Option,
+}
+
+impl Ingress
+where
+    H: MqttEventHandler,
+{
+    /// Creates a new instance of ingress.
+    pub(crate) fn new(client: MqttClient, shutdown_client: ShutdownHandle) -> Self {
+        Self {
+            client,
+            shutdown_client: Some(shutdown_client),
+        }
+    }
+
+    /// Returns a shutdown handle of ingress.
+    pub(crate) fn handle(&mut self) -> IngressShutdownHandle {
+        IngressShutdownHandle(self.shutdown_client.take())
+    }
+
+    /// Runs ingress processing.
+    pub(crate) async fn run(mut self) -> Result<(), IngressError> {
+        info!("starting ingress publication processing...");
+        self.client.run().await?;
+        info!("ingress publication processing stopped");
+
+        Ok(())
+    }
+}
+
+/// Ingress shutdown handle.
+pub(crate) struct IngressShutdownHandle(Option);
+
+impl IngressShutdownHandle {
+    /// Sends a signal to shutdown ingress.
+    pub(crate) async fn shutdown(mut self) {
+        if let Some(mut sender) = self.0.take() {
+            if let Err(e) = sender.shutdown().await {
+                error!("unable to request shutdown for ingress. {}", e);
+            }
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum IngressError {
+    #[error("mqtt client error. {0}")]
+    MqttClient(#[from] crate::client::ClientError),
+}
diff --git a/mqtt/mqtt-bridge/src/pump/messages.rs b/mqtt/mqtt-bridge/src/pump/messages.rs
new file mode 100644
index 00000000000..2c7881908bd
--- /dev/null
+++ b/mqtt/mqtt-bridge/src/pump/messages.rs
@@ -0,0 +1,151 @@
+use std::convert::TryInto;
+
+use async_trait::async_trait;
+use futures_util::stream::StreamExt;
+use mqtt3::{proto::QoS, proto::SubscribeTo};
+use tokio::sync::mpsc;
+use tracing::{debug, error, info};
+
+use super::{PumpHandle, PumpMessage, TopicMapperUpdates};
+
+// Import and use mocks when run tests, real implementation when otherwise
+#[cfg(test)]
+pub use crate::client::MockUpdateSubscriptionHandle as UpdateSubscriptionHandle;
+
+#[cfg(not(test))]
+use crate::client::UpdateSubscriptionHandle;
+
+/// A trait for all custom pump event handlers.
+#[async_trait]
+pub trait PumpMessageHandler {
+    /// A custom pump message event type.
+    type Message;
+
+    /// Handles custom pump message event.
+    async fn handle(&mut self, message: Self::Message);
+}
+
+/// Handles incoming control messages for a pump.
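+///
+/// Control messages arrive as `PumpMessage`s through the paired `PumpHandle`;
+/// for example, a shutdown request (a sketch; the handle comes from the
+/// owning pump):
+///
+/// ```ignore
+/// let mut handle = pump.handle();
+/// handle.send(PumpMessage::Shutdown).await?;
+/// ```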
+pub(crate) struct MessagesProcessor +where + M: PumpMessageHandler, +{ + messages: mpsc::Receiver>, + pump_handle: Option>, + handler: M, + subscription_handle: UpdateSubscriptionHandle, + topic_mappers_updates: TopicMapperUpdates, +} + +impl MessagesProcessor +where + M: PumpMessageHandler, +{ + /// Creates a new instance of message processor. + pub(crate) fn new( + handler: M, + messages: mpsc::Receiver>, + pump_handle: PumpHandle, + subscription_handle: UpdateSubscriptionHandle, + topic_mappers_updates: TopicMapperUpdates, + ) -> Self { + Self { + messages, + pump_handle: Some(pump_handle), + handler, + subscription_handle, + topic_mappers_updates, + } + } + + /// Returns a shutdown handle of message processor. + pub(crate) fn handle(&mut self) -> MessagesProcessorShutdownHandle { + MessagesProcessorShutdownHandle(self.pump_handle.take()) + } + + /// Runs control messages processing. + pub(crate) async fn run(mut self) -> Result<(), MessageProcessorError> { + info!("starting pump messages processor..."); + while let Some(message) = self.messages.next().await { + match message { + PumpMessage::Event(event) => self.handler.handle(event).await, + PumpMessage::ConfigurationUpdate(update) => { + let (added, removed) = update.into_parts(); + debug!( + "received updates added: {:?}, removed: {:?}", + added, removed + ); + + for sub in removed { + let subscribe_to = sub.subscribe_to(); + let unsubscribe_result = self + .subscription_handle + .unsubscribe(subscribe_to.clone()) + .await; + + match unsubscribe_result { + Ok(_) => { + self.topic_mappers_updates.remove(&subscribe_to); + } + Err(e) => { + error!( + "Failed to send unsubscribe update for {}. {}", + subscribe_to, e + ); + } + } + } + + for sub in added { + let subscribe_to = sub.subscribe_to(); + + match sub.to_owned().try_into() { + Ok(mapper) => { + self.topic_mappers_updates.insert(&subscribe_to, mapper); + + if let Err(e) = self + .subscription_handle + .subscribe(SubscribeTo { + topic_filter: subscribe_to, + qos: QoS::AtLeastOnce, // TODO: get from config + }) + .await + { + error!("failed to send subscribe {}", e); + } + } + Err(e) => { + error!("topic rule could not be parsed {}. {}", subscribe_to, e) + } + } + } + } + PumpMessage::Shutdown => { + info!("stop requested"); + break; + } + } + } + + info!("pump messages processor stopped"); + Ok(()) + } +} + +/// Messages processor shutdown handle. +pub(crate) struct MessagesProcessorShutdownHandle(Option>); + +impl MessagesProcessorShutdownHandle { + /// Sends a signal to shutdown message processor. + pub(crate) async fn shutdown(mut self) { + if let Some(mut sender) = self.0.take() { + if let Err(e) = sender.send(PumpMessage::Shutdown).await { + error!("unable to request shutdown for message processor. 
{}", e); + } + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("pump messages processor error")] +pub(crate) struct MessageProcessorError; diff --git a/mqtt/mqtt-bridge/src/pump/mod.rs b/mqtt/mqtt-bridge/src/pump/mod.rs new file mode 100644 index 00000000000..ab05f805b20 --- /dev/null +++ b/mqtt/mqtt-bridge/src/pump/mod.rs @@ -0,0 +1,283 @@ +mod builder; +mod egress; +mod ingress; +mod messages; + +use std::{collections::HashMap, error::Error as StdError, future::Future, sync::Arc}; + +pub use builder::Builder; +use egress::{Egress, EgressError, EgressShutdownHandle}; +use ingress::{Ingress, IngressError, IngressShutdownHandle}; +pub use messages::PumpMessageHandler; +use messages::{MessageProcessorError, MessagesProcessor, MessagesProcessorShutdownHandle}; + +use futures_util::{ + future::{self, Either}, + join, pin_mut, +}; +use mockall::automock; +use parking_lot::Mutex; +use tokio::sync::mpsc; +use tracing::{debug, error, info, warn}; + +use crate::{ + bridge::BridgeError, + client::{MqttClient, MqttClientExt, MqttEventHandler}, + config_update::PumpDiff, + messages::TopicMapper, + persist::{PublicationStore, StreamWakeableState}, +}; + +#[cfg(test)] +pub fn channel() -> (PumpHandle, mpsc::Receiver>) { + let (tx, rx) = tokio::sync::mpsc::channel(10); + (PumpHandle::new(tx), rx) +} + +#[derive(Debug, thiserror::Error)] +pub enum PumpError { + #[error("unable to send command to pump. channel closed")] + Send, + + #[error("error ocurred when running pump. {0}")] + Run(Box), +} + +/// Pump is used to connect to either local broker or remote brokers +/// (including the upstream edge device) +/// +/// It contains several tasks running in parallel: ingress, egress and events processing. +/// +/// During `ingress` pump handles incoming MQTT publications and puts them +/// into the store. The opposite pump will read publications from a store +/// and forwards them to the corresponding broker. +/// +/// During `egress` pump reads pulications from its own store and sends them +/// to the broker MQTT client connected to. +/// +/// Messages processing is intended to control pump behavior: initiate pump +/// shutdown, handle configuration update or another specific event. +pub struct Pump +where + M: PumpMessageHandler, +{ + messages_send: mpsc::Sender>, + messages: MessagesProcessor, + egress: Egress, + ingress: Ingress, +} + +impl Pump +where + H: MqttEventHandler, + M: PumpMessageHandler, + M::Message: 'static, + S: StreamWakeableState, +{ + /// Creates a new instance of pump. + fn new( + messages_send: mpsc::Sender>, + client: MqttClient, + store: PublicationStore, + messages: MessagesProcessor, + ) -> Result { + let client_shutdown = client.shutdown_handle()?; + let publish_handle = client + .publish_handle() + .map_err(BridgeError::PublishHandle)?; + + let egress = Egress::new(publish_handle, store); + let ingress = Ingress::new(client, client_shutdown); + + Ok(Self { + messages_send, + messages, + egress, + ingress, + }) + } + + /// Returns a handle to send control messages to a pump. + pub fn handle(&self) -> PumpHandle { + PumpHandle::new(self.messages_send.clone()) + } + + /// Orchestrates starting of egress, ingress and controll messages + /// processing and waits for all of them to finish. + /// + /// Attempts to start all routines in the same task in parallel and + /// waits for any of them to finish. It sends shutdown to other ones + /// and waits until all of them stopped. 
+ pub async fn run(mut self) -> Result<(), PumpError> { + info!("starting pump..."); + + let shutdown_egress = self.egress.handle(); + let egress = self.egress.run(); + + let shutdown_ingress = self.ingress.handle(); + let ingress = self.ingress.run(); + + let shutdown_messages = self.messages.handle(); + let messages = self.messages.run(); + + pin_mut!(egress, ingress, messages); + + match future::select(&mut messages, future::select(&mut egress, &mut ingress)).await { + Either::Left((messages, _)) => { + if let Err(e) = &messages { + error!(error = %e, "pump messages processor exited with error"); + } else { + debug!("pump messages processor exited"); + } + + debug!("shutting down both ingress and egress..."); + join!( + stop_ingress(ingress, shutdown_ingress), + stop_egress(egress, shutdown_egress) + ); + + messages.map_err(|e| PumpError::Run(e.into())) + } + Either::Right((Either::Left((egress, ingress)), messages)) => { + if let Err(e) = &egress { + error!(error = %e, "egress processing exited with error"); + } else { + debug!("egress processing exited"); + } + + debug!("shutting down both ingress and messages processor..."); + join!( + stop_ingress(ingress, shutdown_ingress), + stop_messages(messages, shutdown_messages), + ); + + egress.map_err(|e| PumpError::Run(e.into())) + } + Either::Right((Either::Right((ingress, egress)), messages)) => { + if let Err(e) = &ingress { + error!(error = %e, "ingress processing exited with error"); + } else { + debug!("ingress processing exited"); + } + + debug!("shutting down both egress and messages processor..."); + join!( + stop_egress(egress, shutdown_egress), + stop_messages(messages, shutdown_messages) + ); + + ingress.map_err(|e| PumpError::Run(e.into())) + } + }?; + + info!("pump stopped"); + Ok(()) + } +} + +async fn stop_ingress(ingress: F, shutdown_handle: IngressShutdownHandle) +where + F: Future>, +{ + let (_, ingress) = join!(shutdown_handle.shutdown(), ingress); + + if let Err(e) = ingress { + error!(error = %e, "ingress processing exited with error"); + } else { + debug!("ingress processing exited"); + } +} + +async fn stop_egress(egress: F, shutdown_handle: EgressShutdownHandle) +where + F: Future>, +{ + let (_, egress) = join!(shutdown_handle.shutdown(), egress); + + if let Err(e) = egress { + error!(error = %e, "egress processing exited with error"); + } else { + debug!("egress processing exited"); + } +} + +async fn stop_messages(messages: F, shutdown_handle: MessagesProcessorShutdownHandle) +where + F: Future>, + M: 'static, +{ + let (_, messages) = join!(shutdown_handle.shutdown(), messages); + + if let Err(e) = messages { + error!(error = %e, "pump messages processor exited with error"); + } else { + debug!("pump messages processor exited"); + } +} + +/// A message to control pump behavior. +#[derive(Debug, PartialEq)] +pub enum PumpMessage { + Event(E), + ConfigurationUpdate(PumpDiff), + Shutdown, +} + +/// A handle to send control messages to the pump. +pub struct PumpHandle { + sender: mpsc::Sender>, +} + +#[automock] +impl PumpHandle { + /// Creates a new instance of pump handle. + fn new(sender: mpsc::Sender>) -> Self { + Self { sender } + } + + /// Sends a control message to a pump. + pub async fn send(&mut self, message: PumpMessage) -> Result<(), PumpError> { + self.sender.send(message).await.map_err(|_| PumpError::Send) + } + + /// Sends a shutdown control message to a pump. 
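+    ///
+    /// Note that it consumes the handle and never fails: if the message
+    /// cannot be delivered, a warning is logged instead (the pump is most
+    /// likely already stopping):
+    ///
+    /// ```ignore
+    /// pump.handle().shutdown().await;
+    /// ```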
+ pub async fn shutdown(mut self) { + if let Err(e) = self.send(PumpMessage::Shutdown).await { + warn!(error = %e, "unable to request shutdown for pump. probably the pump is about to exit"); + } + } +} + +/// Topic settings received as updates from twin or from initial configuration in the default config file +#[derive(Clone)] +pub struct TopicMapperUpdates(Arc>>); + +impl TopicMapperUpdates { + pub fn new(mappings: HashMap) -> Self { + Self(Arc::new(Mutex::new(mappings))) + } + + pub fn insert(&self, topic_filter: &str, mapper: TopicMapper) -> Option { + self.0.lock().insert(topic_filter.into(), mapper) + } + + pub fn remove(&self, topic_filter: &str) -> Option { + self.0.lock().remove(topic_filter) + } + + pub fn get(&self, topic_filter: &str) -> Option { + self.0.lock().get(topic_filter).cloned() + } + + pub fn contains_key(&self, topic_filter: &str) -> bool { + self.0.lock().contains_key(topic_filter) + } + + pub fn subscriptions(&self) -> Vec { + self.0 + .lock() + .values() + .map(TopicMapper::subscribe_to) + .collect() + } +} diff --git a/mqtt/mqtt-bridge/src/rpc.rs b/mqtt/mqtt-bridge/src/rpc.rs deleted file mode 100644 index 1f9fcc12444..00000000000 --- a/mqtt/mqtt-bridge/src/rpc.rs +++ /dev/null @@ -1,518 +0,0 @@ -use std::collections::HashMap; - -use async_trait::async_trait; -use bson::{doc, Document}; -use bytes::{buf::BufExt, Bytes}; -use lazy_static::lazy_static; -use regex::Regex; -use serde::{Deserialize, Serialize}; -use tracing::{error, warn}; - -// Import and use mocks when run tests, real implementation when otherwise -#[cfg(test)] -pub use mqtt3::{ - MockPublishHandle as PublishHandle, MockUpdateSubscriptionHandle as UpdateSubscriptionHandle, -}; -#[cfg(not(test))] -pub use mqtt3::{PublishHandle, UpdateSubscriptionHandle}; - -use mqtt3::{ - proto::Publication, proto::QoS, proto::SubscribeTo, Event, PublishError, - SubscriptionUpdateEvent, -}; - -use crate::client::{EventHandler, Handled}; - -/// MQTT client event handler to react on RPC commands that `EdgeHub` sends -/// to execute. -/// -/// The main purpose of this handler is to establish a communication channel -/// between `EdgeHub` and the upstream bridge. -/// `EdgeHub` will use low level commands SUB, UNSUB, PUB. In turn the bridge -/// sends corresponding MQTT packet to upstream broker and waits for an ack -/// from the upstream. After ack is received it sends a special publish to -/// downstream broker. 
-pub struct RpcHandler { - upstream_subs: UpdateSubscriptionHandle, - upstream_pubs: PublishHandle, - downstream_pubs: PublishHandle, - subscriptions: HashMap, -} - -impl RpcHandler { - #[allow(dead_code)] // TODO remove when used in the code - pub fn new( - upstream_subs: UpdateSubscriptionHandle, - upstream_pubs: PublishHandle, - downstream_pubs: PublishHandle, - ) -> Self { - Self { - upstream_subs, - upstream_pubs, - downstream_pubs, - subscriptions: HashMap::default(), - } - } - - async fn handle_command( - &mut self, - command_id: String, - command: RpcCommand, - ) -> Result<(), RpcError> { - match command { - RpcCommand::Subscribe { topic_filter } => { - self.handle_subscribe(command_id, topic_filter).await - } - RpcCommand::Unsubscribe { topic_filter } => { - self.handle_unsubscribe(command_id, topic_filter).await - } - RpcCommand::Publish { topic, payload } => { - self.handle_publish(command_id, topic, payload).await - } - } - } - - async fn handle_subscribe( - &mut self, - command_id: String, - topic_filter: String, - ) -> Result<(), RpcError> { - let subscribe_to = SubscribeTo { - topic_filter: topic_filter.clone(), - qos: QoS::AtLeastOnce, - }; - - match self.upstream_subs.subscribe(subscribe_to).await { - Ok(_) => { - if let Some(existing) = self.subscriptions.insert(topic_filter, command_id) { - warn!("duplicating sub request found for {}", existing); - } - } - Err(e) => { - let reason = format!("unable to subscribe to upstream {}. {}", topic_filter, e); - self.publish_nack(command_id, reason).await?; - } - } - - Ok(()) - } - - async fn handle_unsubscribe( - &mut self, - command_id: String, - topic_filter: String, - ) -> Result<(), RpcError> { - match self.upstream_subs.unsubscribe(topic_filter.clone()).await { - Ok(_) => { - if let Some(existing) = self.subscriptions.insert(topic_filter, command_id) { - warn!("duplicating unsub request found for {}", existing); - } - } - Err(e) => { - let reason = format!( - "unable to unsubscribe from upstream {}. {}", - topic_filter, e - ); - self.publish_nack(command_id, reason).await?; - } - } - - Ok(()) - } - - async fn handle_publish( - &mut self, - command_id: String, - topic_name: String, - payload: Vec, - ) -> Result<(), RpcError> { - let publication = Publication { - topic_name: topic_name.clone(), - qos: QoS::AtLeastOnce, - retain: false, - payload: payload.into(), - }; - - match self.upstream_pubs.publish(publication).await { - Ok(_) => self.publish_ack(command_id).await, - Err(e) => { - let reason = format!("unable to publish to upstream {}. 
{}", topic_name, e); - self.publish_nack(command_id, reason).await - } - } - } - - async fn handle_subcription_update( - &mut self, - subscription: &SubscriptionUpdateEvent, - ) -> Result { - match subscription { - SubscriptionUpdateEvent::Subscribe(sub) => { - if let Some(command_id) = self.subscriptions.remove(&sub.topic_filter) { - self.publish_ack(command_id).await?; - return Ok(true); - } - } - SubscriptionUpdateEvent::RejectedByServer(topic_filter) => { - if let Some(command_id) = self.subscriptions.remove(topic_filter) { - let reason = format!("subscription rejected by server {}", topic_filter); - self.publish_nack(command_id, reason).await?; - return Ok(true); - } - } - SubscriptionUpdateEvent::Unsubscribe(topic_filter) => { - if let Some(command_id) = self.subscriptions.remove(topic_filter) { - self.publish_ack(command_id).await?; - return Ok(true); - } - } - } - - Ok(false) - } - - async fn publish_nack(&mut self, command_id: String, reason: String) -> Result<(), RpcError> { - let mut payload = Vec::new(); - let doc = doc! { "reason": reason }; - doc.to_writer(&mut payload)?; - - let publication = Publication { - topic_name: format!("$edgehub/rpc/nack/{}", &command_id), - qos: QoS::AtLeastOnce, - retain: false, - payload: payload.into(), - }; - - self.downstream_pubs - .publish(publication) - .await - .map_err(|e| RpcError::SendNack(command_id, e)) - } - - async fn publish_ack(&mut self, command_id: String) -> Result<(), RpcError> { - let publication = Publication { - topic_name: format!("$edgehub/rpc/ack/{}", &command_id), - qos: QoS::AtLeastOnce, - retain: false, - payload: Bytes::default(), - }; - - self.downstream_pubs - .publish(publication) - .await - .map_err(|e| RpcError::SendAck(command_id, e)) - } -} - -#[async_trait] -impl EventHandler for RpcHandler { - type Error = RpcError; - - async fn handle(&mut self, event: &Event) -> Result { - match event { - Event::Publication(publication) => { - if let Some(command_id) = capture_command_id(&publication.topic_name) { - let doc = Document::from_reader(&mut publication.payload.clone().reader())?; - match bson::from_document(doc)? { - VersionedRpcCommand::V1(command) => { - self.handle_command(command_id, command).await?; - - return Ok(Handled::Fully); - } - } - } - } - Event::SubscriptionUpdates(subscriptions) => { - let mut handled = 0; - for subscription in subscriptions { - if self.handle_subcription_update(subscription).await? { - handled += 1; - } - } - - if handled == subscriptions.len() { - return Ok(Handled::Fully); - } else { - return Ok(Handled::Partially); - } - } - _ => {} - } - - Ok(Handled::Skipped) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum RpcError { - #[error("failed to deserialize command from received publication")] - DeserializeCommand(#[from] bson::de::Error), - - #[error("failed to serialize ack")] - SerializeAck(#[from] bson::ser::Error), - - #[error("unable to send nack for {0}. {1}")] - SendNack(String, #[source] PublishError), - - #[error("unable to send ack for {0}. {1}")] - SendAck(String, #[source] PublishError), -} - -fn capture_command_id(topic_name: &str) -> Option { - lazy_static! 
{ - static ref RPC_TOPIC_PATTERN: Regex = Regex::new("\\$edgehub/rpc/(?P[^/ ]+)$") - .expect("failed to create new Regex from pattern"); - } - - RPC_TOPIC_PATTERN - .captures(topic_name) - .and_then(|captures| captures.name("command_id")) - .map(|command_id| command_id.as_str().into()) -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase", tag = "cmd")] -enum RpcCommand { - #[serde(rename = "sub")] - Subscribe { - #[serde(rename = "topic")] - topic_filter: String, - }, - - #[serde(rename = "unsub")] - Unsubscribe { - #[serde(rename = "topic")] - topic_filter: String, - }, - - #[serde(rename = "pub")] - Publish { - topic: String, - - #[serde(with = "serde_bytes")] - payload: Vec, - }, -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase", tag = "version")] -enum VersionedRpcCommand { - V1(RpcCommand), -} - -#[cfg(test)] -mod tests { - use bson::{bson, doc, spec::BinarySubtype}; - use matches::assert_matches; - use mqtt3::{ - proto::QoS, Event, MockPublishHandle, MockUpdateSubscriptionHandle, ReceivedPublication, - }; - use test_case::test_case; - - use super::*; - - #[test] - fn it_deserializes_from_bson() { - let commands = vec![ - ( - bson!({ - "version": "v1", - "cmd": "sub", - "topic": "/foo", - }), - VersionedRpcCommand::V1(RpcCommand::Subscribe { - topic_filter: "/foo".into(), - }), - ), - ( - bson!({ - "version": "v1", - "cmd": "unsub", - "topic": "/foo", - }), - VersionedRpcCommand::V1(RpcCommand::Unsubscribe { - topic_filter: "/foo".into(), - }), - ), - ( - bson!({ - "version": "v1", - "cmd": "pub", - "topic": "/foo", - "payload": vec![100, 97, 116, 97] - }), - VersionedRpcCommand::V1(RpcCommand::Publish { - topic: "/foo".into(), - payload: b"data".to_vec(), - }), - ), - ]; - - for (command, expected) in commands { - let rpc: VersionedRpcCommand = bson::from_bson(command).unwrap(); - assert_eq!(rpc, expected); - } - } - - #[test_case(r"$edgehub/rpc/foo", Some("foo"); "when word")] - #[test_case(r"$edgehub/rpc/CA761232-ED42-11CE-BACD-00AA0057B223", Some("CA761232-ED42-11CE-BACD-00AA0057B223"); "when uuid")] - #[test_case(r"$edgehub/rpc/ack/CA761232-ED42-11CE-BACD-00AA0057B223", None; "when ack")] - #[test_case(r"$iothub/rpc/ack/CA761232-ED42-11CE-BACD-00AA0057B223", None; "when wrong topic")] - #[test_case(r"$iothub/rpc/ack/some id", None; "when spaces")] - fn it_captures_command_id(topic: &str, expected: Option<&str>) { - assert_eq!(capture_command_id(topic).as_deref(), expected) - } - - #[tokio::test] - async fn it_handles_sub_command() { - let upstream_pubs = MockPublishHandle::new(); - - let mut upstream_subs = MockUpdateSubscriptionHandle::new(); - upstream_subs - .expect_subscribe() - .withf(|subscribe_to| subscribe_to.topic_filter == "/foo") - .returning(|_| Ok(())); - - let mut downstream_pubs = MockPublishHandle::new(); - downstream_pubs - .expect_publish() - .once() - .withf(|publication| publication.topic_name == "$edgehub/rpc/ack/1") - .returning(|_| Ok(())); - - let mut handler = RpcHandler::new(upstream_subs, upstream_pubs, downstream_pubs); - - // send a command to subscribe to topic /foo - let event = command("1", "sub", "/foo", None); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); - - // emulate server response - let event = - Event::SubscriptionUpdates(vec![SubscriptionUpdateEvent::Subscribe(SubscribeTo { - topic_filter: "/foo".into(), - qos: QoS::AtLeastOnce, - })]); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); 
- } - - #[tokio::test] - async fn it_send_nack_when_sub_to_upstream_rejected() { - let upstream_pubs = MockPublishHandle::new(); - - let mut upstream_subs = MockUpdateSubscriptionHandle::new(); - upstream_subs - .expect_subscribe() - .withf(|subscribe_to| subscribe_to.topic_filter == "/foo") - .returning(|_| Ok(())); - - let mut downstream_pubs = MockPublishHandle::new(); - downstream_pubs - .expect_publish() - .once() - .withf(|publication| publication.topic_name == "$edgehub/rpc/nack/1") - .returning(|_| Ok(())); - - let mut handler = RpcHandler::new(upstream_subs, upstream_pubs, downstream_pubs); - - // send a command to subscribe to topic /foo - let event = command("1", "sub", "/foo", None); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); - - // emulate server response - let event = Event::SubscriptionUpdates(vec![SubscriptionUpdateEvent::RejectedByServer( - "/foo".into(), - )]); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); - } - - #[tokio::test] - async fn it_handles_unsub_command() { - let upstream_pubs = MockPublishHandle::new(); - - let mut upstream_subs = MockUpdateSubscriptionHandle::new(); - upstream_subs - .expect_unsubscribe() - .withf(|unsubscribe_from| unsubscribe_from == "/foo") - .returning(|_| Ok(())); - - let mut downstream_pubs = MockPublishHandle::new(); - downstream_pubs - .expect_publish() - .once() - .withf(|publication| publication.topic_name == "$edgehub/rpc/ack/1") - .returning(|_| Ok(())); - - let mut handler = RpcHandler::new(upstream_subs, upstream_pubs, downstream_pubs); - - // send a command to unsubscribe from topic /foo - let event = command("1", "unsub", "/foo", None); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); - - // emulate server response - let event = - Event::SubscriptionUpdates(vec![SubscriptionUpdateEvent::Unsubscribe("/foo".into())]); - let res = handler.handle(&event).await; - assert_matches!(res, Ok(Handled::Fully)); - } - - #[tokio::test] - async fn it_handles_pub_command() { - let mut upstream_pubs = MockPublishHandle::new(); - upstream_pubs - .expect_publish() - .once() - .withf(|publication| { - publication.topic_name == "/foo" && publication.payload == Bytes::from("hello") - }) - .returning(|_| Ok(())); - - let upstream_subs = MockUpdateSubscriptionHandle::new(); - - let mut downstream_pubs = MockPublishHandle::new(); - downstream_pubs - .expect_publish() - .once() - .withf(|publication| publication.topic_name == "$edgehub/rpc/ack/1") - .returning(|_| Ok(())); - - let mut handler = RpcHandler::new(upstream_subs, upstream_pubs, downstream_pubs); - - let event = command("1", "pub", "/foo", Some(b"hello".to_vec())); - let res = handler.handle(&event).await; - - assert_matches!(res, Ok(Handled::Fully)); - } - - fn command(id: &str, cmd: &str, topic: &str, payload: Option>) -> Event { - let mut command = doc! 
{ - "version": "v1", - "cmd": cmd, - "topic": topic - }; - if let Some(payload) = payload { - command.insert( - "payload", - bson::Binary { - subtype: BinarySubtype::Generic, - bytes: payload, - }, - ); - } - - let mut payload = Vec::new(); - command.to_writer(&mut payload).unwrap(); - - Event::Publication(ReceivedPublication { - topic_name: format!("$edgehub/rpc/{}", id), - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - payload: payload.into(), - }) - } -} diff --git a/mqtt/mqtt-bridge/src/settings.rs b/mqtt/mqtt-bridge/src/settings.rs index da97950a4a5..2f074fbc4a2 100644 --- a/mqtt/mqtt-bridge/src/settings.rs +++ b/mqtt/mqtt-bridge/src/settings.rs @@ -77,25 +77,17 @@ impl<'de> serde::Deserialize<'de> for BridgeSettings { messages, } = serde::Deserialize::deserialize(deserializer)?; - let upstream_connection_settings = nested_bridge - .filter(|nested_bridge| { - nested_bridge - .enable_upstream_bridge() - .unwrap_or("false") - .to_lowercase() - == "true" - }) - .map(|nested_bridge| ConnectionSettings { - name: "upstream".into(), - address: format!( - "{}:{}", - nested_bridge.gateway_hostname, DEFAULT_UPSTREAM_PORT - ), - subscriptions: upstream.subscriptions, - credentials: Credentials::Provider(nested_bridge), - clean_session: upstream.clean_session, - keep_alive: upstream.keep_alive, - }); + let upstream_connection_settings = nested_bridge.map(|nested_bridge| ConnectionSettings { + name: "$upstream".into(), + address: format!( + "{}:{}", + nested_bridge.gateway_hostname, DEFAULT_UPSTREAM_PORT + ), + subscriptions: upstream.subscriptions, + credentials: Credentials::Provider(nested_bridge), + clean_session: upstream.clean_session, + keep_alive: upstream.keep_alive, + }); Ok(BridgeSettings { upstream: upstream_connection_settings, @@ -114,7 +106,7 @@ pub struct ConnectionSettings { #[serde(flatten)] credentials: Credentials, - subscriptions: Vec, + subscriptions: Vec, #[serde(with = "humantime_serde")] keep_alive: Duration, @@ -135,8 +127,24 @@ impl ConnectionSettings { &self.credentials } - pub fn subscriptions(&self) -> &Vec { - &self.subscriptions + pub fn subscriptions(&self) -> Vec { + self.subscriptions + .iter() + .filter_map(|sub| match sub { + Direction::In(topic) | Direction::Both(topic) => Some(topic.clone()), + _ => None, + }) + .collect() + } + + pub fn forwards(&self) -> Vec { + self.subscriptions + .iter() + .filter_map(|sub| match sub { + Direction::Out(topic) | Direction::Both(topic) => Some(topic.clone()), + _ => None, + }) + .collect() } pub fn keep_alive(&self) -> Duration { @@ -166,6 +174,14 @@ pub struct AuthenticationSettings { } impl AuthenticationSettings { + pub fn new(client_id: String, username: String, password: String) -> Self { + Self { + client_id, + username, + password, + } + } + pub fn client_id(&self) -> &str { &self.client_id } @@ -181,9 +197,6 @@ impl AuthenticationSettings { #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct CredentialProviderSettings { - #[serde(rename = "enableupstreambridge")] - enable_upstream_bridge: Option, - #[serde(rename = "iotedge_iothubhostname")] iothub_hostname: String, @@ -204,10 +217,6 @@ pub struct CredentialProviderSettings { } impl CredentialProviderSettings { - pub fn enable_upstream_bridge(&self) -> Option<&str> { - self.enable_upstream_bridge.as_deref() - } - pub fn iothub_hostname(&self) -> &str { &self.iothub_hostname } @@ -233,11 +242,8 @@ impl CredentialProviderSettings { } } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, Deserialize)] pub struct 
TopicRule { - #[serde(flatten)] - direction: Direction, - topic: String, #[serde(rename = "outPrefix")] @@ -248,20 +254,29 @@ pub struct TopicRule { } impl TopicRule { - pub fn direction(&self) -> &Direction { - &self.direction - } - pub fn topic(&self) -> &str { &self.topic } pub fn out_prefix(&self) -> Option<&str> { - self.out_prefix.as_ref().map(AsRef::as_ref) + self.out_prefix.as_deref() } pub fn in_prefix(&self) -> Option<&str> { - self.in_prefix.as_ref().map(AsRef::as_ref) + self.in_prefix.as_deref() + } + + pub fn subscribe_to(&self) -> String { + match &self.in_prefix { + Some(local) => { + if local.is_empty() { + self.topic.clone() + } else { + format!("{}/{}", local, self.topic) + } + } + None => self.topic.clone(), + } } } @@ -269,9 +284,13 @@ impl TopicRule { #[serde(tag = "direction")] pub enum Direction { #[serde(rename = "in")] - In, + In(TopicRule), + #[serde(rename = "out")] - Out, + Out(TopicRule), + + #[serde(rename = "both")] + Both(TopicRule), } #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -284,7 +303,7 @@ struct UpstreamSettings { clean_session: bool, - subscriptions: Vec, + subscriptions: Vec, } #[cfg(test)] @@ -317,7 +336,7 @@ mod tests { let settings = BridgeSettings::from_file("tests/config.json").unwrap(); let upstream = settings.upstream().unwrap(); - assert_eq!(upstream.name(), "upstream"); + assert_eq!(upstream.name(), "$upstream"); assert_eq!(upstream.address(), "edge1:8883"); match upstream.credentials() { @@ -371,8 +390,7 @@ mod tests { #[test] #[serial(env_settings)] - fn from_env_no_upstream_protcol() { - let _gateway_hostname = env::set_var("IOTEDGE_GATEWAYHOSTNAME", "upstream"); + fn from_env_no_gateway_hostname() { let _device_id = env::set_var("IOTEDGE_DEVICEID", "device1"); let _module_id = env::set_var("IOTEDGE_MODULEID", "m1"); let _generation_id = env::set_var("IOTEDGE_MODULEGENERATIONID", "123"); @@ -394,12 +412,11 @@ mod tests { let _generation_id = env::set_var("IOTEDGE_MODULEGENERATIONID", "123"); let _workload_uri = env::set_var("IOTEDGE_WORKLOADURI", "workload"); let _iothub_hostname = env::set_var("IOTEDGE_IOTHUBHOSTNAME", "iothub"); - let _enable_bridge = env::set_var("enableupstreambridge", "true"); let settings = make_settings().unwrap(); let upstream = settings.upstream().unwrap(); - assert_eq!(upstream.name(), "upstream"); + assert_eq!(upstream.name(), "$upstream"); assert_eq!(upstream.address(), "upstream:8883"); match upstream.credentials() { diff --git a/mqtt/mqtt-bridge/src/token_source.rs b/mqtt/mqtt-bridge/src/token_source.rs index 47855a25838..d71bae326d0 100644 --- a/mqtt/mqtt-bridge/src/token_source.rs +++ b/mqtt/mqtt-bridge/src/token_source.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] // TODO remove when ready - use std::io::{Error, ErrorKind}; use async_trait::async_trait; diff --git a/mqtt/mqtt-bridge/src/connectivity.rs b/mqtt/mqtt-bridge/src/upstream/connectivity.rs similarity index 51% rename from mqtt/mqtt-bridge/src/connectivity.rs rename to mqtt/mqtt-bridge/src/upstream/connectivity.rs index 6815b8c5eae..6e3cb040912 100644 --- a/mqtt/mqtt-bridge/src/connectivity.rs +++ b/mqtt/mqtt-bridge/src/upstream/connectivity.rs @@ -1,25 +1,42 @@ -#![allow(dead_code)] // TODO remove when ready +use std::fmt::{Display, Formatter, Result as FmtResult}; use async_trait::async_trait; +use serde::Serialize; use tracing::{debug, info}; use mqtt3::Event; use crate::{ - bridge::BridgeError, - client::{EventHandler, Handled}, - pump::{ConnectivityState, PumpHandle, PumpMessage}, + client::{Handled, MqttEventHandler}, + pump::{PumpError, 
PumpHandle, PumpMessage}, }; +use super::LocalUpstreamPumpEvent; + +#[derive(Clone, Copy, Debug, PartialEq, Serialize)] +pub enum ConnectivityState { + Connected, + Disconnected, +} + +impl Display for ConnectivityState { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::Connected => write!(f, "Connected"), + Self::Disconnected => write!(f, "Disconnected"), + } + } +} + /// Handles connection and disconnection events and sends a notification when status changes -pub struct ConnectivityHandler { +pub struct ConnectivityMqttEventHandler { state: ConnectivityState, - sender: PumpHandle, + sender: PumpHandle, } -impl ConnectivityHandler { - pub fn new(sender: PumpHandle) -> Self { - ConnectivityHandler { +impl ConnectivityMqttEventHandler { + pub fn new(sender: PumpHandle) -> Self { + ConnectivityMqttEventHandler { state: ConnectivityState::Disconnected, sender, } @@ -27,25 +44,27 @@ impl ConnectivityHandler { } #[async_trait] -impl EventHandler for ConnectivityHandler { - type Error = BridgeError; +impl MqttEventHandler for ConnectivityMqttEventHandler { + type Error = ConnectivityError; - async fn handle(&mut self, event: &Event) -> Result { - match event { + async fn handle(&mut self, event: Event) -> Result { + let event = match event { Event::Disconnected(reason) => { - debug!("Received disconnected state {}", reason); + debug!("received disconnected state {}", reason); match self.state { ConnectivityState::Connected => { self.state = ConnectivityState::Disconnected; - self.sender - .send(PumpMessage::ConnectivityUpdate( - ConnectivityState::Disconnected, - )) - .await?; - info!("Sent disconnected state"); + + let event = LocalUpstreamPumpEvent::ConnectivityUpdate( + ConnectivityState::Disconnected, + ); + let msg = PumpMessage::Event(event); + self.sender.send(msg).await?; + + info!("sent disconnected state"); } ConnectivityState::Disconnected => { - debug!("Already disconnected"); + debug!("already disconnected"); } } @@ -55,52 +74,63 @@ impl EventHandler for ConnectivityHandler { Event::NewConnection { reset_session: _ } => { match self.state { ConnectivityState::Connected => { - debug!("Already connected"); + debug!("already connected"); } ConnectivityState::Disconnected => { self.state = ConnectivityState::Connected; - self.sender - .send(PumpMessage::ConnectivityUpdate( - ConnectivityState::Connected, - )) - .await?; - info!("Sent connected state") + + let event = LocalUpstreamPumpEvent::ConnectivityUpdate( + ConnectivityState::Connected, + ); + let msg = PumpMessage::Event(event); + self.sender.send(msg).await?; + + info!("sent connected state") } } return Ok(Handled::Fully); } - _ => {} - } + event => event, + }; - Ok(Handled::Skipped) + Ok(Handled::Skipped(event)) } } +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ConnectivityError(#[from] PumpError); + #[cfg(test)] mod tests { + use matches::assert_matches; use mqtt3::{proto::QoS, proto::SubscribeTo, ConnectionError, Event, SubscriptionUpdateEvent}; - use tokio::sync::{mpsc, mpsc::error::TryRecvError}; + use tokio::sync::mpsc::error::TryRecvError; - use crate::client::Handled; - use crate::pump::{ConnectivityState, PumpMessage}; + use crate::{ + client::Handled, + pump::{self, PumpMessage}, + }; use super::*; #[tokio::test] async fn sends_connected_state() { - let (sender, mut connectivity_receiver) = mpsc::channel::(1); + let (handle, mut connectivity_receiver) = pump::channel(); - let mut ch = ConnectivityHandler::new(PumpHandle::new(sender)); + let mut ch = 
ConnectivityMqttEventHandler::new(handle); let event = Event::NewConnection { reset_session: true, }; - let res = ch.handle(&event).await.unwrap(); + let res = ch.handle(event).await.unwrap(); let msg = connectivity_receiver.try_recv().unwrap(); assert_eq!( msg, - PumpMessage::ConnectivityUpdate(ConnectivityState::Connected) + PumpMessage::Event(LocalUpstreamPumpEvent::ConnectivityUpdate( + ConnectivityState::Connected + )) ); assert_eq!(ch.state, ConnectivityState::Connected); assert_eq!(res, Handled::Fully); @@ -108,12 +138,12 @@ mod tests { #[tokio::test] async fn sends_disconnected_state() { - let (sender, mut connectivity_receiver) = mpsc::channel::(1); + let (handle, mut connectivity_receiver) = pump::channel(); - let mut ch = ConnectivityHandler::new(PumpHandle::new(sender)); + let mut ch = ConnectivityMqttEventHandler::new(handle); let res_connected = ch - .handle(&Event::NewConnection { + .handle(Event::NewConnection { reset_session: true, }) .await @@ -121,9 +151,7 @@ mod tests { let _msg = connectivity_receiver.try_recv().unwrap(); let res_disconnected = ch - .handle(&Event::Disconnected( - ConnectionError::ServerClosedConnection, - )) + .handle(Event::Disconnected(ConnectionError::ServerClosedConnection)) .await .unwrap(); @@ -131,7 +159,9 @@ mod tests { assert_eq!( msg, - PumpMessage::ConnectivityUpdate(ConnectivityState::Disconnected) + PumpMessage::Event(LocalUpstreamPumpEvent::ConnectivityUpdate( + ConnectivityState::Disconnected + )) ); assert_eq!(ch.state, ConnectivityState::Disconnected); assert_eq!(res_connected, Handled::Fully); @@ -140,19 +170,19 @@ mod tests { #[tokio::test] async fn not_sends_connected_state_when_already_connected() { - let (sender, mut connectivity_receiver) = mpsc::channel::(1); + let (handle, mut connectivity_receiver) = pump::channel(); - let mut ch = ConnectivityHandler::new(PumpHandle::new(sender)); + let mut ch = ConnectivityMqttEventHandler::new(handle); let res_connected1 = ch - .handle(&Event::NewConnection { + .handle(Event::NewConnection { reset_session: true, }) .await .unwrap(); let res_connected2 = ch - .handle(&Event::NewConnection { + .handle(Event::NewConnection { reset_session: true, }) .await @@ -168,14 +198,12 @@ mod tests { #[tokio::test] async fn not_sends_disconnected_state_when_already_disconnected() { - let (sender, mut connectivity_receiver) = mpsc::channel::(1); + let (handle, mut connectivity_receiver) = pump::channel(); - let mut ch = ConnectivityHandler::new(PumpHandle::new(sender)); + let mut ch = ConnectivityMqttEventHandler::new(handle); let res_disconnected = ch - .handle(&Event::Disconnected( - ConnectionError::ServerClosedConnection, - )) + .handle(Event::Disconnected(ConnectionError::ServerClosedConnection)) .await .unwrap(); @@ -187,9 +215,9 @@ mod tests { #[tokio::test] async fn not_handles_other_events() { - let (sender, _) = mpsc::channel::(1); + let (handle, _) = pump::channel(); - let mut ch = ConnectivityHandler::new(PumpHandle::new(sender)); + let mut ch = ConnectivityMqttEventHandler::new(handle); let event = Event::SubscriptionUpdates(vec![SubscriptionUpdateEvent::Subscribe(SubscribeTo { @@ -197,16 +225,16 @@ mod tests { qos: QoS::AtLeastOnce, })]); - let res = ch.handle(&event).await.unwrap(); + let res = ch.handle(event).await.unwrap(); - assert_eq!(res, Handled::Skipped) + assert_matches!(res, Handled::Skipped(_)) } #[tokio::test] async fn default_disconnected_state() { - let (sender, _) = mpsc::channel::(1); + let (handle, _) = pump::channel(); - let ch = 
ConnectivityHandler::new(PumpHandle::new(sender)); + let ch = ConnectivityMqttEventHandler::new(handle); assert_eq!(ch.state, ConnectivityState::Disconnected); } diff --git a/mqtt/mqtt-bridge/src/upstream/events/local.rs b/mqtt/mqtt-bridge/src/upstream/events/local.rs new file mode 100644 index 00000000000..ec4036fdbee --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/events/local.rs @@ -0,0 +1,218 @@ +use async_trait::async_trait; +use bson::doc; +use bytes::Bytes; +use mqtt3::{proto::Publication, proto::QoS}; +use serde_json::json; +use tracing::{debug, error}; + +// Import and use mocks when run tests, real implementation when otherwise +#[cfg(test)] +pub use crate::client::MockPublishHandle as PublishHandle; + +#[cfg(not(test))] +use crate::client::PublishHandle; + +use crate::{ + pump::PumpMessageHandler, + upstream::{CommandId, ConnectivityState}, +}; + +const CONNECTIVITY_TOPIC: &str = "$internal/connectivity"; + +/// Pump control event for a local upstream bridge pump. +#[derive(Debug, PartialEq)] +pub enum LocalUpstreamPumpEvent { + /// Connectivity update event. + ConnectivityUpdate(ConnectivityState), + + /// RPC command acknowledgement event. + RpcAck(CommandId), + + /// RPC command negative acknowledgement event. + RpcNack(CommandId, String), + + /// Forward incoming upstream publication event. + Publication(Publication), +} + +/// Handles control event received by a local upstream bridge pump. +/// +/// It handles following events: +/// * connectivity update - emitted when the connection to remote broker changed +/// (connected/disconnected). It should publish corresponding MQTT message to the +/// local broker. +/// * RPC command acknowledgement - emitted when the RPC command executed with +/// success result. +/// * RPC command negative acknowledgement - emitted when the RPC command failed +/// to execute. +pub struct LocalUpstreamPumpEventHandler { + publish_handle: PublishHandle, +} + +impl LocalUpstreamPumpEventHandler { + pub fn new(publish_handle: PublishHandle) -> Self { + Self { publish_handle } + } +} + +#[async_trait] +impl PumpMessageHandler for LocalUpstreamPumpEventHandler { + type Message = LocalUpstreamPumpEvent; + + async fn handle(&mut self, message: Self::Message) { + let maybe_publication = match message { + LocalUpstreamPumpEvent::ConnectivityUpdate(status) => { + debug!("changed connectivity status to {}", status); + + let payload = json!({ "status": status }); + match serde_json::to_string(&payload) { + Ok(payload) => Some(Publication { + topic_name: CONNECTIVITY_TOPIC.to_owned(), + qos: QoS::AtLeastOnce, + retain: true, + payload: payload.into(), + }), + Err(e) => { + error!("unable to convert to JSON. {}", e); + None + } + } + } + LocalUpstreamPumpEvent::RpcAck(command_id) => { + debug!("sending rpc command ack {}", command_id); + + Some(Publication { + topic_name: format!("$downstream/rpc/ack/{}", command_id), + qos: QoS::AtLeastOnce, + retain: false, + payload: Bytes::default(), + }) + } + LocalUpstreamPumpEvent::RpcNack(command_id, reason) => { + debug!("sending rpc command nack {}", command_id); + + let mut payload = Vec::new(); + let doc = doc! { "reason": reason }; + match doc.to_writer(&mut payload) { + Ok(_) => Some(Publication { + topic_name: format!("$downstream/rpc/nack/{}", command_id), + qos: QoS::AtLeastOnce, + retain: false, + payload: payload.into(), + }), + Err(e) => { + error!("unable to convert to BSON. 
{}", e); + None + } + } + } + LocalUpstreamPumpEvent::Publication(publication) => { + debug!("sending incoming message on {}", publication.topic_name); + Some(publication) + } + }; + + if let Some(publication) = maybe_publication { + let topic = publication.topic_name.clone(); + if let Err(e) = self.publish_handle.publish(publication).await { + error!(error = %e, "failed to publish on topic {}", topic); + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::client::MockPublishHandle; + + use super::*; + + #[tokio::test] + async fn it_sends_connectivity_update_when_connected() { + it_sends_connectivity_update_when_changed(ConnectivityState::Connected).await; + } + + #[tokio::test] + async fn it_sends_connectivity_update_when_disconnected() { + it_sends_connectivity_update_when_changed(ConnectivityState::Disconnected).await; + } + + async fn it_sends_connectivity_update_when_changed(state: ConnectivityState) { + let payload = json!({ "status": state }); + let payload = serde_json::to_vec(&payload).unwrap(); + + let mut pub_handle = MockPublishHandle::new(); + pub_handle + .expect_publish() + .once() + .withf(move |publication| { + publication.topic_name == "$internal/connectivity" && publication.payload == payload + }) + .returning(|_| Ok(())); + + let mut handler = LocalUpstreamPumpEventHandler::new(pub_handle); + + let event = LocalUpstreamPumpEvent::ConnectivityUpdate(state); + handler.handle(event).await; + } + + #[tokio::test] + async fn it_sends_rpc_ack() { + let mut pub_handle = MockPublishHandle::new(); + pub_handle + .expect_publish() + .once() + .withf(move |publication| { + publication.topic_name == "$downstream/rpc/ack/1" && publication.payload.is_empty() + }) + .returning(|_| Ok(())); + + let mut handler = LocalUpstreamPumpEventHandler::new(pub_handle); + + let event = LocalUpstreamPumpEvent::RpcAck("1".into()); + handler.handle(event).await; + } + + #[tokio::test] + async fn it_sends_rpc_nack() { + let mut payload = Vec::new(); + let doc = doc! 
{ "reason": "error" }; + doc.to_writer(&mut payload).unwrap(); + + let mut pub_handle = MockPublishHandle::new(); + pub_handle + .expect_publish() + .once() + .withf(move |publication| { + publication.topic_name == "$downstream/rpc/nack/1" && publication.payload == payload + }) + .returning(|_| Ok(())); + + let mut handler = LocalUpstreamPumpEventHandler::new(pub_handle); + + let event = LocalUpstreamPumpEvent::RpcNack("1".into(), "error".into()); + handler.handle(event).await; + } + + #[tokio::test] + async fn it_sends_incoming_publication() { + let mut pub_handle = MockPublishHandle::new(); + pub_handle + .expect_publish() + .once() + .withf(move |publication| { + publication.topic_name == "$downstream/device_1/module_a/twin/res/200" + }) + .returning(|_| Ok(())); + + let mut handler = LocalUpstreamPumpEventHandler::new(pub_handle); + + let event = LocalUpstreamPumpEvent::Publication(Publication { + topic_name: "$downstream/device_1/module_a/twin/res/200".into(), + qos: QoS::AtLeastOnce, + retain: false, + payload: "hello".into(), + }); + handler.handle(event).await; + } +} diff --git a/mqtt/mqtt-bridge/src/upstream/events/mod.rs b/mqtt/mqtt-bridge/src/upstream/events/mod.rs new file mode 100644 index 00000000000..037553580f1 --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/events/mod.rs @@ -0,0 +1,5 @@ +mod local; +mod remote; + +pub use local::{LocalUpstreamPumpEvent, LocalUpstreamPumpEventHandler}; +pub use remote::{RemoteUpstreamPumpEvent, RemoteUpstreamPumpEventHandler}; diff --git a/mqtt/mqtt-bridge/src/upstream/events/remote.rs b/mqtt/mqtt-bridge/src/upstream/events/remote.rs new file mode 100644 index 00000000000..d7dd83000b5 --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/events/remote.rs @@ -0,0 +1,284 @@ +use async_trait::async_trait; +use mqtt3::proto::{Publication, QoS, SubscribeTo}; +use tracing::{error, warn}; + +use crate::{ + pump::{PumpHandle, PumpMessageHandler}, + upstream::{ + CommandId, LocalUpstreamPumpEvent, RpcCommand, RpcError, RpcPumpHandle, RpcSubscriptions, + }, +}; + +// Import and use mocks when run tests, real implementation when otherwise +#[cfg(test)] +use crate::client::{ + MockPublishHandle as PublishHandle, MockUpdateSubscriptionHandle as UpdateSubscriptionHandle, +}; +#[cfg(not(test))] +use crate::client::{PublishHandle, UpdateSubscriptionHandle}; + +/// Pump control event for a remote upstream bridge pump. +#[derive(Debug, PartialEq)] +pub enum RemoteUpstreamPumpEvent { + RpcCommand(CommandId, RpcCommand), +} + +/// Handles control event received by a remote upstream bridge pump. +/// +/// It handles following events: +/// * RPC command - emitted when `EdgeHub` requested RPC command to be executed +/// against remote broker. 
+pub struct RemoteUpstreamPumpEventHandler { + remote_sub_handle: UpdateSubscriptionHandle, + remote_pub_handle: PublishHandle, + local_pump: RpcPumpHandle, + subscriptions: RpcSubscriptions, +} + +impl RemoteUpstreamPumpEventHandler { + pub fn new( + remote_sub_handle: UpdateSubscriptionHandle, + remote_pub_handle: PublishHandle, + local_pump_handle: PumpHandle<LocalUpstreamPumpEvent>, + subscriptions: RpcSubscriptions, + ) -> Self { + Self { + remote_sub_handle, + remote_pub_handle, + local_pump: RpcPumpHandle::new(local_pump_handle), + subscriptions, + } + } + + async fn handle_command( + &mut self, + command_id: CommandId, + command: RpcCommand, + ) -> Result<(), RpcError> { + match command { + RpcCommand::Subscribe { topic_filter } => { + self.handle_subscribe(command_id, topic_filter).await + } + RpcCommand::Unsubscribe { topic_filter } => { + self.handle_unsubscribe(command_id, topic_filter).await + } + RpcCommand::Publish { topic, payload } => { + self.handle_publish(command_id, topic, payload).await + } + } + } + + async fn handle_subscribe( + &mut self, + command_id: CommandId, + topic_filter: String, + ) -> Result<(), RpcError> { + let subscribe_to = SubscribeTo { + topic_filter: topic_filter.clone(), + qos: QoS::AtLeastOnce, + }; + + match self.remote_sub_handle.subscribe(subscribe_to).await { + Ok(_) => { + if let Some(existing) = self.subscriptions.insert(&topic_filter, command_id) { + warn!("duplicate sub request found for {}", existing); + } + } + Err(e) => { + let reason = format!("unable to subscribe to upstream {}. {}", topic_filter, e); + self.local_pump.send_nack(command_id, reason).await?; + } + } + + Ok(()) + } + + async fn handle_unsubscribe( + &mut self, + command_id: CommandId, + topic_filter: String, + ) -> Result<(), RpcError> { + match self + .remote_sub_handle + .unsubscribe(topic_filter.clone()) + .await + { + Ok(_) => { + if let Some(existing) = self.subscriptions.insert(&topic_filter, command_id) { + warn!("duplicate unsub request found for {}", existing); + } + } + Err(e) => { + let reason = format!( + "unable to unsubscribe from upstream {}. {}", + topic_filter, e + ); + self.local_pump.send_nack(command_id, reason).await?; + } + } + + Ok(()) + } + + async fn handle_publish( + &mut self, + command_id: CommandId, + topic_name: String, + payload: Vec<u8>, + ) -> Result<(), RpcError> { + let publication = Publication { + topic_name: topic_name.clone(), + qos: QoS::AtLeastOnce, + retain: false, + payload: payload.into(), + }; + + match self.remote_pub_handle.publish(publication).await { + Ok(_) => self.local_pump.send_ack(command_id).await, + Err(e) => { + let reason = format!("unable to publish to upstream {}. {}", topic_name, e); + self.local_pump.send_nack(command_id, reason).await + } + } + } +} + +#[async_trait] +impl PumpMessageHandler for RemoteUpstreamPumpEventHandler { + type Message = RemoteUpstreamPumpEvent; + + async fn handle(&mut self, message: Self::Message) { + match message { + RemoteUpstreamPumpEvent::RpcCommand(command_id, command) => { + let cmd_string = command.to_string(); + if let Err(e) = self.handle_command(command_id.clone(), command).await { + error!( + "unable to handle rpc command {} {}. 
{}", + command_id, cmd_string, e + ); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use tokio::sync::mpsc::error::TryRecvError; + + use bytes::Bytes; + use matches::assert_matches; + + use crate::{ + client::{MockPublishHandle, MockUpdateSubscriptionHandle}, + pump::{self, PumpMessage}, + }; + + use super::*; + + #[tokio::test] + async fn it_handles_sub_command() { + let remote_pub_handle = MockPublishHandle::new(); + + let mut remote_sub_handle = MockUpdateSubscriptionHandle::new(); + remote_sub_handle + .expect_subscribe() + .withf(|subscribe_to| subscribe_to.topic_filter == "/foo") + .returning(|_| Ok(())); + + let (local_pump, mut rx) = pump::channel(); + + let rpc_subscriptions = RpcSubscriptions::default(); + let mut handler = RemoteUpstreamPumpEventHandler::new( + remote_sub_handle, + remote_pub_handle, + local_pump, + rpc_subscriptions.clone(), + ); + + // handle a command to subscribe to topic /foo + let command = RpcCommand::Subscribe { + topic_filter: "/foo".into(), + }; + let event = RemoteUpstreamPumpEvent::RpcCommand("1".into(), command); + handler.handle(event).await; + + // check no message which was sent to local pump + assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); + + // check subscriptions has requested topic + assert_matches!(rpc_subscriptions.remove("/foo"), Some(id) if id == "1".into()); + } + + #[tokio::test] + async fn it_handles_unsub_command() { + let remote_pub_handle = MockPublishHandle::new(); + + let mut remote_sub_handle = MockUpdateSubscriptionHandle::new(); + remote_sub_handle + .expect_unsubscribe() + .withf(|subscribe_from| subscribe_from == "/foo") + .returning(|_| Ok(())); + + let (local_pump, mut rx) = pump::channel(); + + let rpc_subscriptions = RpcSubscriptions::default(); + let mut handler = RemoteUpstreamPumpEventHandler::new( + remote_sub_handle, + remote_pub_handle, + local_pump, + rpc_subscriptions.clone(), + ); + + // handle a command to unsubscribe from topic /foo + let command = RpcCommand::Unsubscribe { + topic_filter: "/foo".into(), + }; + let event = RemoteUpstreamPumpEvent::RpcCommand("1".into(), command); + handler.handle(event).await; + + // check no message which was sent to local pump + assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); + + // check subscriptions has requested topic + assert_matches!(rpc_subscriptions.remove("/foo"), Some(id) if id == "1".into()); + } + + #[tokio::test] + async fn it_handles_pub_command() { + let mut remote_pub_handle = MockPublishHandle::new(); + remote_pub_handle + .expect_publish() + .once() + .withf(|publication| { + publication.topic_name == "/foo" && publication.payload == Bytes::from("hello") + }) + .returning(|_| Ok(())); + + let remote_sub_handle = MockUpdateSubscriptionHandle::new(); + + let (local_pump, mut rx) = pump::channel(); + + let rpc_subscriptions = RpcSubscriptions::default(); + let mut handler = RemoteUpstreamPumpEventHandler::new( + remote_sub_handle, + remote_pub_handle, + local_pump, + rpc_subscriptions, + ); + + // handle a command to publish on topic /foo + let command = RpcCommand::Publish { + topic: "/foo".into(), + payload: b"hello".to_vec(), + }; + let event = RemoteUpstreamPumpEvent::RpcCommand("1".into(), command); + handler.handle(event).await; + + // check message which was sent to local pump + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::RpcAck(id))) if id == "1".into() + ); + } +} diff --git a/mqtt/mqtt-bridge/src/upstream/mod.rs b/mqtt/mqtt-bridge/src/upstream/mod.rs new file mode 100644 index 
00000000000..0bd69c13e6f --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/mod.rs @@ -0,0 +1,121 @@ +//! This module contains code related to the upstream bridge. + +mod connectivity; +mod events; +mod rpc; + +pub use connectivity::{ConnectivityError, ConnectivityMqttEventHandler, ConnectivityState}; +pub use events::{ + LocalUpstreamPumpEvent, LocalUpstreamPumpEventHandler, RemoteUpstreamPumpEvent, + RemoteUpstreamPumpEventHandler, +}; +pub use rpc::{ + CommandId, LocalRpcMqttEventHandler, RemoteRpcMqttEventHandler, RpcCommand, RpcError, + RpcPumpHandle, RpcSubscriptions, +}; + +use async_trait::async_trait; +use mqtt3::Event; + +use crate::{ + bridge::BridgeError, + client::{Handled, MqttEventHandler}, + messages::StoreMqttEventHandler, + persist::StreamWakeableState, +}; + +/// Handles all events that local clients receive for the upstream bridge. +/// +/// Contains several event handlers to process RPC and regular MQTT events +/// in a chain. +pub struct LocalUpstreamMqttEventHandler<S> { + messages: StoreMqttEventHandler<S>, + rpc: LocalRpcMqttEventHandler, +} + +impl<S> LocalUpstreamMqttEventHandler<S> { + pub fn new(messages: StoreMqttEventHandler<S>, rpc: LocalRpcMqttEventHandler) -> Self { + Self { messages, rpc } + } +} + +#[async_trait] +impl<S> MqttEventHandler for LocalUpstreamMqttEventHandler<S> +where + S: StreamWakeableState + Send, +{ + type Error = BridgeError; + + fn subscriptions(&self) -> Vec<String> { + let mut subscriptions = self.rpc.subscriptions(); + subscriptions.extend(self.messages.subscriptions()); + subscriptions + } + + async fn handle(&mut self, event: Event) -> Result<Handled, Self::Error> { + // try to handle as RPC command first + match self.rpc.handle(event).await? { + Handled::Fully => Ok(Handled::Fully), + Handled::Partially(event) | Handled::Skipped(event) => { + // handle as an event for regular message handler + self.messages.handle(event).await + } + } + } +} + +/// Handles all events that remote clients receive for the upstream bridge. +/// +/// Contains several event handlers to process Connectivity, RPC and regular +/// MQTT events in a chain. +pub struct RemoteUpstreamMqttEventHandler<S> { + messages: StoreMqttEventHandler<S>, + rpc: RemoteRpcMqttEventHandler, + connectivity: ConnectivityMqttEventHandler, +} + +impl<S> RemoteUpstreamMqttEventHandler<S> { + pub fn new( + messages: StoreMqttEventHandler<S>, + rpc: RemoteRpcMqttEventHandler, + connectivity: ConnectivityMqttEventHandler, + ) -> Self { + Self { + messages, + rpc, + connectivity, + } + } +} + +#[async_trait] +impl<S> MqttEventHandler for RemoteUpstreamMqttEventHandler<S> +where + S: StreamWakeableState + Send, +{ + type Error = BridgeError; + + fn subscriptions(&self) -> Vec<String> { + let mut subscriptions = self.messages.subscriptions(); + subscriptions.extend(self.rpc.subscriptions()); + subscriptions.extend(self.connectivity.subscriptions()); + subscriptions + } + + async fn handle(&mut self, event: Event) -> Result<Handled, Self::Error> { + // try to handle incoming connectivity event + let event = match self.connectivity.handle(event).await? { + Handled::Fully => return Ok(Handled::Fully), + Handled::Partially(event) | Handled::Skipped(event) => event, + }; + + // try to handle incoming messages as RPC command + let event = match self.rpc.handle(event).await? 
{ + Handled::Fully => return Ok(Handled::Fully), + Handled::Partially(event) | Handled::Skipped(event) => event, + }; + + // handle as an event for regular message handler + self.messages.handle(event).await + } +} diff --git a/mqtt/mqtt-bridge/src/upstream/rpc/local.rs b/mqtt/mqtt-bridge/src/upstream/rpc/local.rs new file mode 100644 index 00000000000..570be589253 --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/rpc/local.rs @@ -0,0 +1,347 @@ +use std::collections::HashSet; + +use async_trait::async_trait; +use bson::{doc, Document}; +use bytes::buf::BufExt; +use lazy_static::lazy_static; +use mqtt3::{Event, ReceivedPublication, SubscriptionUpdateEvent}; +use regex::Regex; +use serde::{Deserialize, Serialize}; + +use crate::{ + client::{Handled, MqttEventHandler}, + pump::{PumpHandle, PumpMessage}, + upstream::{CommandId, RemoteUpstreamPumpEvent, RpcCommand}, +}; + +use super::RpcError; + +/// An RPC handler responsible for the part of the bridge which +/// connects to the local broker. +/// +/// It receives RPC commands on a special topic, converts each to an `RpcCommand` +/// and sends it to the remote pump as a `PumpMessage`. +pub struct LocalRpcMqttEventHandler { + remote_pump: PumpHandle<RemoteUpstreamPumpEvent>, + subscriptions: HashSet<String>, +} + +impl LocalRpcMqttEventHandler { + /// Creates a new instance of the local part of the RPC handler. + pub fn new(remote_pump: PumpHandle<RemoteUpstreamPumpEvent>) -> Self { + let mut subscriptions = HashSet::new(); + subscriptions.insert("$upstream/rpc/+".into()); + + Self { + remote_pump, + subscriptions, + } + } + + async fn handle_publication( + &mut self, + command_id: CommandId, + publication: &ReceivedPublication, + ) -> Result<bool, RpcError> { + let doc = Document::from_reader(&mut publication.payload.clone().reader())?; + match bson::from_document(doc)? { + VersionedRpcCommand::V1(command) => { + let event = RemoteUpstreamPumpEvent::RpcCommand(command_id.clone(), command); + let msg = PumpMessage::Event(event); + self.remote_pump + .send(msg) + .await + .map_err(|e| RpcError::SendToRemotePump(command_id, e))?; + + Ok(true) + } + } + } + + fn handle_subscriptions( + &mut self, + subscriptions: impl IntoIterator<Item = SubscriptionUpdateEvent>, + ) -> Vec<SubscriptionUpdateEvent> { + let mut skipped = vec![]; + for subscription in subscriptions { + if !self.handle_subscription_update(&subscription) { + skipped.push(subscription); + } + } + + skipped + } + + fn handle_subscription_update(&mut self, subscription: &SubscriptionUpdateEvent) -> bool { + let topic_filter = match subscription { + SubscriptionUpdateEvent::Subscribe(sub) => &sub.topic_filter, + SubscriptionUpdateEvent::RejectedByServer(topic_filter) => topic_filter, + SubscriptionUpdateEvent::Unsubscribe(topic_filter) => topic_filter, + }; + + self.subscriptions.contains(topic_filter) + } +} + +#[async_trait] +impl MqttEventHandler for LocalRpcMqttEventHandler { + type Error = RpcError; + + fn subscriptions(&self) -> Vec<String> { + self.subscriptions.iter().cloned().collect() + } + + async fn handle(&mut self, event: Event) -> Result<Handled, Self::Error> { + match event { + Event::Publication(publication) => { + if let Some(command_id) = capture_command_id(&publication.topic_name) { + if self.handle_publication(command_id, &publication).await? 
{ + return Ok(Handled::Fully); + } + } + + Ok(Handled::Skipped(Event::Publication(publication))) + } + Event::SubscriptionUpdates(subscriptions) => { + let subscriptions_len = subscriptions.len(); + + let skipped = self.handle_subscriptions(subscriptions); + if skipped.is_empty() { + Ok(Handled::Fully) + } else if skipped.len() == subscriptions_len { + let event = Event::SubscriptionUpdates(skipped); + Ok(Handled::Skipped(event)) + } else { + let event = Event::SubscriptionUpdates(skipped); + Ok(Handled::Partially(event)) + } + } + _ => Ok(Handled::Skipped(event)), + } + } +} + +fn capture_command_id(topic_name: &str) -> Option<CommandId> { + lazy_static! { + static ref RPC_TOPIC_PATTERN: Regex = Regex::new("\\$upstream/rpc/(?P<command_id>[^/ ]+)$") + .expect("failed to create new Regex from pattern"); + } + + RPC_TOPIC_PATTERN + .captures(topic_name) + .and_then(|captures| captures.name("command_id")) + .map(|command_id| command_id.as_str().into()) +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase", tag = "version")] +enum VersionedRpcCommand { + V1(RpcCommand), +} + +#[cfg(test)] +mod tests { + use bson::{bson, spec::BinarySubtype}; + use bytes::Bytes; + use matches::assert_matches; + use mqtt3::{ + proto::{QoS, SubscribeTo}, + ReceivedPublication, + }; + use test_case::test_case; + + use super::*; + + #[test] + fn it_deserializes_from_bson() { + let commands = vec![ + ( + bson!({ + "version": "v1", + "cmd": "sub", + "topic": "/foo", + }), + VersionedRpcCommand::V1(RpcCommand::Subscribe { + topic_filter: "/foo".into(), + }), + ), + ( + bson!({ + "version": "v1", + "cmd": "unsub", + "topic": "/foo", + }), + VersionedRpcCommand::V1(RpcCommand::Unsubscribe { + topic_filter: "/foo".into(), + }), + ), + ( + bson!({ + "version": "v1", + "cmd": "pub", + "topic": "/foo", + "payload": vec![100, 97, 116, 97] + }), + VersionedRpcCommand::V1(RpcCommand::Publish { + topic: "/foo".into(), + payload: b"data".to_vec(), + }), + ), + ]; + + for (command, expected) in commands { + let rpc: VersionedRpcCommand = bson::from_bson(command).unwrap(); + assert_eq!(rpc, expected); + } + } + + #[test_case(r"$upstream/rpc/foo", Some("foo".into()); "when word")] + #[test_case(r"$upstream/rpc/CA761232-ED42-11CE-BACD-00AA0057B223", Some("CA761232-ED42-11CE-BACD-00AA0057B223".into()); "when uuid")] + #[test_case(r"$downstream/rpc/ack/CA761232-ED42-11CE-BACD-00AA0057B223", None; "when ack")] + #[test_case(r"$iothub/rpc/ack/CA761232-ED42-11CE-BACD-00AA0057B223", None; "when wrong topic")] + #[test_case(r"$iothub/rpc/ack/some id", None; "when spaces")] + #[allow(clippy::needless_pass_by_value)] + fn it_captures_command_id(topic: &str, expected: Option<CommandId>) { + assert_eq!(capture_command_id(topic), expected) + } + + #[tokio::test] + async fn it_handles_rpc_commands() { + let (pump_handle, mut rx) = crate::pump::channel(); + let mut handler = LocalRpcMqttEventHandler::new(pump_handle); + + let event = command("1", "sub", "/foo", None); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + assert_matches!(rx.recv().await, Some(PumpMessage::Event(RemoteUpstreamPumpEvent::RpcCommand(id, RpcCommand::Subscribe{topic_filter}))) if topic_filter == "/foo" && id == "1".into()); + + let event = command("2", "unsub", "/foo", None); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + assert_matches!(rx.recv().await, Some(PumpMessage::Event(RemoteUpstreamPumpEvent::RpcCommand(id, RpcCommand::Unsubscribe{topic_filter}))) if topic_filter == "/foo" && id == 
"2".into()); + + let event = command("3", "pub", "/foo", Some(b"hello".to_vec())); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + assert_matches!(rx.recv().await, Some(PumpMessage::Event(RemoteUpstreamPumpEvent::RpcCommand(id, RpcCommand::Publish{topic, payload}))) if topic == "/foo" && payload == b"hello" && id == "3".into()); + } + + #[tokio::test] + async fn it_skips_when_not_rpc_command_pub() { + let (pump_handle, _) = crate::pump::channel(); + let mut handler = LocalRpcMqttEventHandler::new(pump_handle); + + let event = Event::Publication(ReceivedPublication { + topic_name: "$edgehub/twin/$edgeHub".into(), + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + payload: Bytes::default(), + }); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Skipped(_))); + } + + #[tokio::test] + async fn it_skips_when_not_rpc_topic_sub() { + let update_events = vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + qos: QoS::AtLeastOnce, + topic_filter: "/foo".into(), + }), + SubscriptionUpdateEvent::RejectedByServer("/foo".into()), + SubscriptionUpdateEvent::Unsubscribe("/foo".into()), + ]; + + let (pump_handle, _) = crate::pump::channel(); + let mut handler = LocalRpcMqttEventHandler::new(pump_handle); + + for update_event in update_events { + let event = Event::SubscriptionUpdates(vec![update_event]); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Skipped(_))); + } + } + + #[tokio::test] + async fn it_handles_fully_when_rpc_topic_sub() { + let update_events = vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + qos: QoS::AtLeastOnce, + topic_filter: "$upstream/rpc/+".into(), + }), + SubscriptionUpdateEvent::RejectedByServer("$upstream/rpc/+".into()), + SubscriptionUpdateEvent::Unsubscribe("$upstream/rpc/+".into()), + ]; + + let (pump_handle, _) = crate::pump::channel(); + let mut handler = LocalRpcMqttEventHandler::new(pump_handle); + + for update_event in update_events { + let event = Event::SubscriptionUpdates(vec![update_event]); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + } + } + + #[tokio::test] + async fn it_handles_partially_when_mixed_topics_sub() { + let expected_events = vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + qos: QoS::AtLeastOnce, + topic_filter: "/foo/bar".into(), + }), + SubscriptionUpdateEvent::RejectedByServer("/foo".into()), + SubscriptionUpdateEvent::Unsubscribe("/bar".into()), + ]; + + let rpc_events = vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + qos: QoS::AtLeastOnce, + topic_filter: "$upstream/rpc/+".into(), + }), + SubscriptionUpdateEvent::RejectedByServer("$upstream/rpc/+".into()), + SubscriptionUpdateEvent::Unsubscribe("$upstream/rpc/+".into()), + ]; + + let (pump_handle, _) = crate::pump::channel(); + let mut handler = LocalRpcMqttEventHandler::new(pump_handle); + + for rpc_event in rpc_events { + let mut update_events = expected_events.clone(); + update_events.push(rpc_event); + + let event = Event::SubscriptionUpdates(update_events); + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Partially(events)) if events == Event::SubscriptionUpdates(expected_events.clone())); + } + } + + fn command(id: &str, cmd: &str, topic: &str, payload: Option>) -> Event { + let mut command = doc! 
{ + "version": "v1", + "cmd": cmd, + "topic": topic + }; + if let Some(payload) = payload { + command.insert( + "payload", + bson::Binary { + subtype: BinarySubtype::Generic, + bytes: payload, + }, + ); + } + + let mut payload = Vec::new(); + command.to_writer(&mut payload).unwrap(); + + Event::Publication(ReceivedPublication { + topic_name: format!("$upstream/rpc/{}", id), + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + payload: payload.into(), + }) + } +} diff --git a/mqtt/mqtt-bridge/src/upstream/rpc/mod.rs b/mqtt/mqtt-bridge/src/upstream/rpc/mod.rs new file mode 100644 index 00000000000..84aaa58646b --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/rpc/mod.rs @@ -0,0 +1,126 @@ +//! Downstream MQTT client event handler to react on RPC commands that +//! `EdgeHub` sends to execute. +//! +//! The main purpose of this handler is to establish a communication channel +//! between `EdgeHub` and the upstream bridge. +//! `EdgeHub` will use low level commands SUB, UNSUB, PUB. In turn the bridge +//! sends corresponding MQTT packet to upstream broker and waits for an ack +//! from the upstream. After ack is received it sends a special publish to +//! downstream broker. + +mod local; +mod remote; + +pub use local::LocalRpcMqttEventHandler; +use parking_lot::Mutex; +pub use remote::{RemoteRpcMqttEventHandler, RpcPumpHandle}; + +use std::{ + collections::HashMap, fmt::Display, fmt::Formatter, fmt::Result as FmtResult, sync::Arc, +}; + +use bson::doc; +use serde::{Deserialize, Serialize}; +use tracing::error; + +use crate::pump::PumpError; + +/// RPC command unique identificator. +#[derive(Debug, Clone, PartialEq)] +pub struct CommandId(Arc); + +impl From for CommandId +where + C: Into, +{ + fn from(command_id: C) -> Self { + Self(Arc::new(command_id.into())) + } +} + +impl Display for CommandId { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", self.0) + } +} + +/// RPC command execution error. +#[derive(Debug, thiserror::Error)] +pub enum RpcError { + #[error("failed to deserialize command from received publication")] + DeserializeCommand(#[from] bson::de::Error), + + #[error("unable to send nack for {0}. {1}")] + SendNack(CommandId, #[source] PumpError), + + #[error("unable to send ack for {0}. {1}")] + SendAck(CommandId, #[source] PumpError), + + #[error("unable to send command for {0} to remote pump. {1}")] + SendToRemotePump(CommandId, #[source] PumpError), + + #[error("unable to send publication on {0} to remote pump. {1}")] + SendPublicationToLocalPump(String, #[source] PumpError), +} + +/// RPC command to be executed against upstream broker. +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase", tag = "cmd")] +pub enum RpcCommand { + /// A RPC command to subscribe to the topic. + #[serde(rename = "sub")] + Subscribe { + #[serde(rename = "topic")] + topic_filter: String, + }, + + /// A RPC command to unsubscribe from the topic. + #[serde(rename = "unsub")] + Unsubscribe { + #[serde(rename = "topic")] + topic_filter: String, + }, + + /// A RPC command to publish a message to a given topic. + #[serde(rename = "pub")] + Publish { + topic: String, + + #[serde(with = "serde_bytes")] + payload: Vec, + }, +} + +impl Display for RpcCommand { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::Subscribe { topic_filter } => write!(f, "SUB {}", topic_filter), + Self::Unsubscribe { topic_filter } => write!(f, "UNSUB {}", topic_filter), + Self::Publish { topic, .. 
} => write!(f, "PUB {}", topic), + } + } +} + +/// Represents a mapping of an RPC subscription topic filter to a unique command +/// identifier. It is shared between the remote pump event handler and the remote +/// pump message processor of the upstream bridge. +/// +/// It is shared due to `mqtt3::Client` implementation details: a +/// subscription is made with `UpdateSubscriptionHandle`, but the server +/// response comes back as an `mqtt3::Event` which is handled by the event +/// handler. +#[derive(Debug, Clone, Default)] +pub struct RpcSubscriptions(Arc<Mutex<HashMap<String, CommandId>>>); + +impl RpcSubscriptions { + /// Stores a topic filter to command identifier mapping. + pub fn insert(&self, topic_filter: &str, id: CommandId) -> Option<CommandId> { + self.0.lock().insert(topic_filter.into(), id) + } + + /// Removes the topic filter to command identifier mapping and returns the + /// `CommandId` if it exists. + pub fn remove(&self, topic_filter: &str) -> Option<CommandId> { + self.0.lock().remove(topic_filter) + } +} diff --git a/mqtt/mqtt-bridge/src/upstream/rpc/remote.rs b/mqtt/mqtt-bridge/src/upstream/rpc/remote.rs new file mode 100644 index 00000000000..d391e808fbe --- /dev/null +++ b/mqtt/mqtt-bridge/src/upstream/rpc/remote.rs @@ -0,0 +1,367 @@ +use async_trait::async_trait; +use lazy_static::lazy_static; +use regex::RegexSet; + +use mqtt3::{proto::Publication, Event, ReceivedPublication, SubscriptionUpdateEvent}; +use tracing::debug; + +use crate::{ + client::{Handled, MqttEventHandler}, + pump::PumpHandle, + pump::PumpMessage, + upstream::LocalUpstreamPumpEvent, +}; + +use super::{CommandId, RpcError, RpcSubscriptions}; + +/// An RPC handler responsible for the part of the bridge which +/// connects to the upstream broker. +/// +/// 1. It receives a subscription update, identifies the updates related to +/// requested RPC commands and sends an ACK or NACK to the local pump as a +/// `PumpMessage` for each update. +/// +/// 2. It receives a publication, identifies publications destined for `IoTHub` +/// topics, translates the topic and sends a special `PumpMessage` event to the +/// local pump. +pub struct RemoteRpcMqttEventHandler { + subscriptions: RpcSubscriptions, + local_pump: RpcPumpHandle, +} + +impl RemoteRpcMqttEventHandler { + pub fn new( + subscriptions: RpcSubscriptions, + local_pump: PumpHandle<LocalUpstreamPumpEvent>, + ) -> Self { + Self { + subscriptions, + local_pump: RpcPumpHandle::new(local_pump), + } + } + + async fn handle_publication( + &mut self, + publication: &ReceivedPublication, + ) -> Result<bool, RpcError> { + if let Some(topic_name) = translate(&publication.topic_name) { + debug!("forwarding incoming upstream publication to {}", topic_name); + + let publication = Publication { + topic_name, + qos: publication.qos, + retain: publication.retain, + payload: publication.payload.clone(), + }; + self.local_pump.send_pub(publication).await?; + Ok(true) + } else { + Ok(false) + } + } + + async fn handle_subscriptions( + &mut self, + subscriptions: impl IntoIterator<Item = SubscriptionUpdateEvent>, + ) -> Result<Option<Vec<SubscriptionUpdateEvent>>, RpcError> { + let mut skipped = vec![]; + for subscription in subscriptions { + if !self.handle_subscription_update(&subscription).await? 
{ + skipped.push(subscription); + } + } + + if skipped.is_empty() { + Ok(None) + } else { + Ok(Some(skipped)) + } + } + + async fn handle_subscription_update( + &mut self, + subscription: &SubscriptionUpdateEvent, + ) -> Result<bool, RpcError> { + match subscription { + SubscriptionUpdateEvent::Subscribe(sub) => { + if let Some(command_id) = self.subscriptions.remove(&sub.topic_filter) { + self.local_pump.send_ack(command_id).await?; + return Ok(true); + } + } + SubscriptionUpdateEvent::RejectedByServer(topic_filter) => { + if let Some(command_id) = self.subscriptions.remove(topic_filter) { + let reason = format!("subscription rejected by server {}", topic_filter); + self.local_pump.send_nack(command_id, reason).await?; + return Ok(true); + } + } + SubscriptionUpdateEvent::Unsubscribe(topic_filter) => { + if let Some(command_id) = self.subscriptions.remove(topic_filter) { + self.local_pump.send_ack(command_id).await?; + return Ok(true); + } + } + } + + Ok(false) + } +} + +#[async_trait] +impl MqttEventHandler for RemoteRpcMqttEventHandler { + type Error = RpcError; + + async fn handle(&mut self, event: Event) -> Result<Handled, Self::Error> { + let event = match event { + Event::Publication(publication) if self.handle_publication(&publication).await? => { + return Ok(Handled::Fully); + } + Event::SubscriptionUpdates(subscriptions) => { + let len = subscriptions.len(); + match self.handle_subscriptions(subscriptions).await? { + Some(skipped) if skipped.len() == len => { + let event = Event::SubscriptionUpdates(skipped); + return Ok(Handled::Skipped(event)); + } + Some(skipped) => { + let event = Event::SubscriptionUpdates(skipped); + return Ok(Handled::Partially(event)); + } + None => return Ok(Handled::Fully), + }; + } + event => event, + }; + + Ok(Handled::Skipped(event)) + } +} + +fn translate(topic_name: &str) -> Option<String> { + const DEVICE_OR_MODULE_ID: &str = r"(?P<device_id>[^/]+)(/(?P<module_id>[^/]+))?"; + + lazy_static! { + static ref UPSTREAM_TOPIC_PATTERNS: RegexSet = RegexSet::new(&[ + format!("\\$iothub/{}/twin/res/(?P<params>.*)", DEVICE_OR_MODULE_ID), + format!( + "\\$iothub/{}/twin/desired/(?P<params>.*)", + DEVICE_OR_MODULE_ID + ), + format!( + "\\$iothub/{}/methods/post/(?P<params>.*)", + DEVICE_OR_MODULE_ID + ) + ]) + .expect("upstream topic patterns"); + }; + + if UPSTREAM_TOPIC_PATTERNS.is_match(topic_name) { + Some(topic_name.replace("$iothub", "$downstream")) + } else { + None + } +} + +/// A convenient wrapper around the local pump's `PumpHandle` that encapsulates +/// sending RPC command Ack, Nack or Publish events. 
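// Illustration only: a minimal sketch of how the remote side of the bridge
// uses this wrapper once the upstream broker confirms (or rejects) a command.
// The free function and its wiring are hypothetical; the `send_ack`/`send_nack`
// calls match the real signatures below:
//
//     async fn confirm(local_pump: &mut RpcPumpHandle) -> Result<(), RpcError> {
//         // becomes `LocalUpstreamPumpEvent::RpcAck("42")` on the local pump
//         // and, in turn, an empty publication on `$downstream/rpc/ack/42`
//         local_pump.send_ack("42".into()).await?;
//
//         // a failure is reported on `$downstream/rpc/nack/43` with the
//         // reason encoded as a BSON `{ "reason": ... }` document
//         local_pump
//             .send_nack("43".into(), "subscription rejected".into())
//             .await
//     }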
+pub struct RpcPumpHandle(PumpHandle<LocalUpstreamPumpEvent>); + +impl RpcPumpHandle { + pub fn new(handle: PumpHandle<LocalUpstreamPumpEvent>) -> Self { + Self(handle) + } + + pub async fn send_ack(&mut self, command_id: CommandId) -> Result<(), RpcError> { + let event = LocalUpstreamPumpEvent::RpcAck(command_id.clone()); + self.0 + .send(PumpMessage::Event(event)) + .await + .map_err(|e| RpcError::SendAck(command_id, e)) + } + + pub async fn send_nack( + &mut self, + command_id: CommandId, + reason: String, + ) -> Result<(), RpcError> { + let event = LocalUpstreamPumpEvent::RpcNack(command_id.clone(), reason); + self.0 + .send(PumpMessage::Event(event)) + .await + .map_err(|e| RpcError::SendNack(command_id, e)) + } + + pub async fn send_pub(&mut self, publication: Publication) -> Result<(), RpcError> { + let topic_name = publication.topic_name.clone(); + let event = LocalUpstreamPumpEvent::Publication(publication); + self.0 + .send(PumpMessage::Event(event)) + .await + .map_err(|e| RpcError::SendPublicationToLocalPump(topic_name, e)) + } +} + +#[cfg(test)] +mod tests { + use matches::assert_matches; + use test_case::test_case; + + use mqtt3::{ + proto::{QoS, SubscribeTo}, + SubscriptionUpdateEvent, + }; + + use crate::pump::{self, PumpMessage}; + + use super::*; + + #[tokio::test] + async fn it_sends_event_when_subscription_update_received() { + let subscriptions = RpcSubscriptions::default(); + subscriptions.insert("/foo/subscribed", "1".into()); + subscriptions.insert("/foo/rejected", "2".into()); + subscriptions.insert("/foo/unsubscribed", "3".into()); + + let (local_pump, mut rx) = pump::channel(); + let mut handler = RemoteRpcMqttEventHandler::new(subscriptions, local_pump); + + let event = Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "/foo/subscribed".into(), + qos: QoS::AtLeastOnce, + }), + SubscriptionUpdateEvent::RejectedByServer("/foo/rejected".into()), + SubscriptionUpdateEvent::Unsubscribe("/foo/unsubscribed".into()), + ]); + + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::RpcAck(id))) if id == "1".into() + ); + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::RpcNack(id, _))) if id == "2".into() + ); + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::RpcAck(id))) if id == "3".into() + ); + } + + #[tokio::test] + async fn it_returns_partially_handled_when_has_non_rpc() { + let subscriptions = RpcSubscriptions::default(); + subscriptions.insert("/foo/rpc", "1".into()); + + let (local_pump, mut rx) = pump::channel(); + let mut handler = RemoteRpcMqttEventHandler::new(subscriptions, local_pump); + + let event = Event::SubscriptionUpdates(vec![ + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "/foo/rpc".into(), + qos: QoS::AtLeastOnce, + }), + SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "/bar".into(), + qos: QoS::AtLeastOnce, + }), + ]); + + let res = handler.handle(event).await; + let expected = + Event::SubscriptionUpdates(vec![SubscriptionUpdateEvent::Subscribe(SubscribeTo { + topic_filter: "/bar".into(), + qos: QoS::AtLeastOnce, + })]); + assert_matches!(res, Ok(Handled::Partially(event)) if event == expected); + + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::RpcAck(id))) if id == "1".into() + ); + } + + #[tokio::test] + async fn it_sends_publications_twin_response() { 
it_sends_publication_when_known_topic( + "$iothub/device_1/module_a/twin/res/?rid=1", + "$downstream/device_1/module_a/twin/res/?rid=1", + ) + .await; + } + + #[tokio::test] + async fn it_sends_publications_twin_desired() { + it_sends_publication_when_known_topic( + "$iothub/device_1/module_a/twin/desired/?rid=1", + "$downstream/device_1/module_a/twin/desired/?rid=1", + ) + .await; + } + + #[tokio::test] + async fn it_sends_publications_direct_method() { + it_sends_publication_when_known_topic( + "$iothub/device_1/module_a/methods/post/?rid=1", + "$downstream/device_1/module_a/methods/post/?rid=1", + ) + .await; + } + + async fn it_sends_publication_when_known_topic(topic_name: &str, translated_topic: &str) { + let subscriptions = RpcSubscriptions::default(); + + let (local_pump, mut rx) = pump::channel(); + let mut handler = RemoteRpcMqttEventHandler::new(subscriptions, local_pump); + + let event = Event::Publication(ReceivedPublication { + topic_name: topic_name.into(), + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + payload: "hello".into(), + }); + + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Fully)); + + assert_matches!( + rx.recv().await, + Some(PumpMessage::Event(LocalUpstreamPumpEvent::Publication(publication))) if publication.topic_name == translated_topic + ); + } + + #[tokio::test] + async fn it_skips_publication_when_unknown_topic() { + let subscriptions = RpcSubscriptions::default(); + + let (local_pump, _rx) = pump::channel(); + let mut handler = RemoteRpcMqttEventHandler::new(subscriptions, local_pump); + + let event = Event::Publication(ReceivedPublication { + topic_name: "/foo".into(), + dup: false, + qos: QoS::AtLeastOnce, + retain: false, + payload: "hello".into(), + }); + + let res = handler.handle(event).await; + assert_matches!(res, Ok(Handled::Skipped(_))); + } + + #[test_case("$iothub/device_1/module_a/twin/res/?rid=1", Some("$downstream/device_1/module_a/twin/res/?rid=1"); "twin module")] + #[test_case("$iothub/device_1/twin/res/?rid=1", Some("$downstream/device_1/twin/res/?rid=1"); "twin device")] + #[test_case("$iothub/device_1/module_a/twin/desired/?rid=1", Some("$downstream/device_1/module_a/twin/desired/?rid=1"); "desired twin module")] + #[test_case("$iothub/device_1/twin/desired/?rid=1", Some("$downstream/device_1/twin/desired/?rid=1"); "desired twin device")] + #[test_case("$iothub/device_1/module_a/methods/post/?rid=1", Some("$downstream/device_1/module_a/methods/post/?rid=1"); "direct method module")] + #[test_case("$iothub/device_1/methods/post/?rid=1", Some("$downstream/device_1/methods/post/?rid=1"); "direct method device")] + #[test_case("$edgehub/device_1/module_a/twin/res/?rid=1", None; "wrong prefix")] + fn it_translates_upstream_topic(topic_name: &str, expected: Option<&str>) { + assert_eq!(translate(topic_name).as_deref(), expected); + } +} diff --git a/mqtt/mqtt-bridge/tests/config.json b/mqtt/mqtt-bridge/tests/config.json index 6cd4ce28ac9..dbde80ac774 100644 --- a/mqtt/mqtt-bridge/tests/config.json +++ b/mqtt/mqtt-bridge/tests/config.json @@ -5,11 +5,10 @@ "iotedge_modulegenerationid": "321", "iotedge_workloaduri": "uri", "iotedge_iothubhostname": "iothub", - "enableupstreambridge": "true", "upstream": { "subscriptions": [ { - "direction": "in", + "direction": "both", "topic": "temp/#", "outPrefix": "floor/kitchen" }, @@ -21,12 +20,13 @@ }, { "direction": "out", - "topic": "temp/#", - "outPrefix": "floor/kitchen" + "topic": "pattern/#" }, { "direction": "out", - "topic": "pattern/#" + "topic": 
"floor2/#", + "inPrefix": "", + "outPrefix": "" } ] }, diff --git a/mqtt/mqtt-broker-tests-util/src/server.rs b/mqtt/mqtt-broker-tests-util/src/server.rs index dd688271c4b..e4fab1bdaa5 100644 --- a/mqtt/mqtt-broker-tests-util/src/server.rs +++ b/mqtt/mqtt-broker-tests-util/src/server.rs @@ -77,7 +77,7 @@ where P: MakePacketProcessor + Clone + Send + Sync + 'static, { lazy_static! { - static ref PORT: AtomicU32 = AtomicU32::new(5555); + static ref PORT: AtomicU32 = AtomicU32::new(8889); } let port = PORT.fetch_add(1, Ordering::SeqCst); diff --git a/mqtt/mqtt-broker/benches/persist_broker_state.rs b/mqtt/mqtt-broker/benches/persist_broker_state.rs index 9cc831214e3..374a2a1349c 100644 --- a/mqtt/mqtt-broker/benches/persist_broker_state.rs +++ b/mqtt/mqtt-broker/benches/persist_broker_state.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, iter::FromIterator}; +use std::{collections::HashMap, iter::FromIterator, net::SocketAddr}; use bytes::Bytes; use criterion::{ @@ -9,8 +9,8 @@ use tokio::runtime::Runtime; use mqtt3::proto::{Publication, QoS}; use mqtt_broker::{ - BrokerSnapshot, ClientId, FileFormat, FilePersistor, Persist, PersistError, SessionSnapshot, - VersionedFileFormat, + AuthId, BrokerSnapshot, ClientId, ClientInfo, FileFormat, FilePersistor, Persist, PersistError, + SessionSnapshot, VersionedFileFormat, }; fn test_write( @@ -119,8 +119,12 @@ fn make_fake_state( .chain(shared_messages.clone()) .collect(); + let id = format!("Session {}", i); + let client_id = ClientId::from(&id); + let auth_id = AuthId::from_identity(id); + SessionSnapshot::from_parts( - ClientId::from(format!("Session {}", i)), + ClientInfo::new(client_id, peer_addr(), auth_id), HashMap::new(), waiting_to_be_sent, ) @@ -130,6 +134,10 @@ fn make_fake_state( BrokerSnapshot::new(retained, sessions) } +fn peer_addr() -> SocketAddr { + "127.0.0.1:12345".parse().unwrap() +} + fn make_fake_publish(topic_name: String) -> Publication { Publication { topic_name, diff --git a/mqtt/mqtt-broker/src/auth/authorization.rs b/mqtt/mqtt-broker/src/auth/authorization.rs index 31d218170dd..7c02225c81c 100644 --- a/mqtt/mqtt-broker/src/auth/authorization.rs +++ b/mqtt/mqtt-broker/src/auth/authorization.rs @@ -1,4 +1,9 @@ -use std::{any::Any, convert::Infallible, error::Error as StdError}; +use std::{ + any::Any, + convert::Infallible, + error::Error as StdError, + fmt::{Display, Formatter, Result as FmtResult}, +}; use mqtt3::proto; @@ -103,18 +108,29 @@ impl Activity { } } +impl Display for Activity { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!( + f, + "client: {} operation: {}", + self.client_id(), + self.operation() + ) + } +} + /// Describes a client operation to be authorized. #[derive(Clone, Debug)] pub enum Operation { - Connect(Connect), + Connect, Publish(Publish), Subscribe(Subscribe), } impl Operation { /// Creates a new operation context for CONNECT request. - pub fn new_connect(connect: proto::Connect) -> Self { - Self::Connect(connect.into()) + pub fn new_connect() -> Self { + Self::Connect } /// Creates a new operation context for PUBLISH request. @@ -128,22 +144,17 @@ impl Operation { } } -/// Represents a client attempt to connect to the broker. 
-#[derive(Clone, Debug)] -pub struct Connect { - will: Option, -} - -impl Connect { - pub fn will(&self) -> Option<&Publication> { - self.will.as_ref() - } -} - -impl From for Connect { - fn from(connect: proto::Connect) -> Self { - Self { - will: connect.will.map(Into::into), +impl Display for Operation { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::Connect => write!(f, "CONNECT"), + Self::Publish(publish) => write!(f, "PUBLISH {}", publish.publication.topic_name), + Self::Subscribe(subscribe) => write!( + f, + "SUBSCRIBE {}; qos: {}", + subscribe.topic_filter, + u8::from(subscribe.qos) + ), } } } @@ -236,33 +247,19 @@ impl From for Subscribe { #[cfg(test)] mod tests { - use std::{net::SocketAddr, time::Duration}; + use std::net::SocketAddr; use matches::assert_matches; - use mqtt3::{proto, PROTOCOL_LEVEL, PROTOCOL_NAME}; - use super::{Activity, AllowAll, Authorization, Authorizer, DenyAll, Operation}; use crate::ClientInfo; - fn connect() -> proto::Connect { - proto::Connect { - username: None, - password: None, - will: None, - client_id: proto::ClientId::ServerGenerated, - keep_alive: Duration::from_secs(1), - protocol_name: PROTOCOL_NAME.to_string(), - protocol_level: PROTOCOL_LEVEL, - } - } - #[test] fn default_auth_always_deny_any_action() { let auth = DenyAll; let activity = Activity::new( ClientInfo::new("client-auth-id", peer_addr(), "client-id"), - Operation::new_connect(connect()), + Operation::new_connect(), ); let res = auth.authorize(&activity); @@ -275,7 +272,7 @@ mod tests { let auth = AllowAll; let activity = Activity::new( ClientInfo::new("client-auth-id", peer_addr(), "client-id"), - Operation::new_connect(connect()), + Operation::new_connect(), ); let res = auth.authorize(&activity); diff --git a/mqtt/mqtt-broker/src/auth/mod.rs b/mqtt/mqtt-broker/src/auth/mod.rs index 97540a44267..9177f436b69 100644 --- a/mqtt/mqtt-broker/src/auth/mod.rs +++ b/mqtt/mqtt-broker/src/auth/mod.rs @@ -6,7 +6,7 @@ pub use authentication::{ DynAuthenticator, }; pub use authorization::{ - authorize_fn_ok, Activity, AllowAll, Authorization, Authorizer, Connect, DenyAll, Operation, + authorize_fn_ok, Activity, AllowAll, Authorization, Authorizer, DenyAll, Operation, Publication, Publish, Subscribe, }; @@ -17,8 +17,6 @@ use std::{ use serde::{Deserialize, Serialize}; -use crate::ClientId; - /// Authenticated MQTT client identity. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum AuthId { @@ -79,8 +77,8 @@ impl Display for Identity { } } -impl PartialEq for Identity { - fn eq(&self, other: &ClientId) -> bool { - self.as_str() == other.as_str() +impl> PartialEq for Identity { + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() } } diff --git a/mqtt/mqtt-broker/src/broker.rs b/mqtt/mqtt-broker/src/broker.rs index 3aa376c2791..34d8d1d2941 100644 --- a/mqtt/mqtt-broker/src/broker.rs +++ b/mqtt/mqtt-broker/src/broker.rs @@ -58,6 +58,7 @@ where } } Message::System(event) => { + debug!("incoming system event: {:?}", event); let span = info_span!("broker", event = "system"); let _enter = span.enter(); match event { @@ -85,7 +86,7 @@ where // TODO return an error instead? 
break; } else { - self.reevaluate_subscriptions(); + self.reauthorize(); debug!("successfully updated authorization info"); } } @@ -105,53 +106,51 @@ where Ok(self.snapshot()) } - fn prepare_activities(client_id: &ClientId, session: &Session) -> Vec<(ClientId, Activity)> { - session + fn prepare_activities(session: &Session) -> Vec { + let client_info = session.client_info().clone(); + let connect = std::iter::once(Activity::new(client_info, Operation::new_connect())); + + let sessions = session .subscriptions() .into_iter() .flat_map(HashMap::values) .map(|sub| { - let operation = Operation::new_subscribe(proto::SubscribeTo { + let subscribe = Operation::new_subscribe(proto::SubscribeTo { topic_filter: sub.filter().to_string(), qos: *sub.max_qos(), }); let client_info = session.client_info().clone(); - let activity = Activity::new(client_info, operation); + Activity::new(client_info, subscribe) + }); - (client_id.clone(), activity) - }) - .collect() + connect.chain(sessions).collect() } - fn reevaluate_subscriptions(&mut self) { - let disconnecting: Vec = self + fn reauthorize(&mut self) { + let disconnecting = self .sessions .iter() - .flat_map(|(client_id, session)| Self::prepare_activities(client_id, session)) - .filter_map( - |(client_id, activity)| match self.authorizer.authorize(&activity) { - Ok(Authorization::Allowed) => None, - Ok(Authorization::Forbidden(reason)) => { - debug!( - "client {} not allowed to subscribe to topic. {}", - client_id, reason - ); - Some(client_id) - } - Err(e) => { - warn!(message="error authorizing client subscription: {}", error = %e); - Some(client_id) - } - }, - ) - .collect(); + .flat_map(|(_, session)| Self::prepare_activities(session)) + .filter(|activity| match self.authorizer.authorize(&activity) { + Ok(Authorization::Allowed) => false, + Ok(Authorization::Forbidden(reason)) => { + warn!("not authorized: {}; reason: {}", &activity, reason); + true + } + Err(e) => { + warn!(message="error authorizing client: {}", error = %e); + true + } + }) + .collect::>(); - for client_id in disconnecting { - if let Err(x) = self.process_drop_connection(&client_id) { + for activity in disconnecting { + if let Err(reason) = self.process_drop_connection(activity.client_id()) { warn!( "error dropping connection for client {}. Reason {}", - client_id, x + activity.client_id(), + reason ); } } @@ -194,7 +193,7 @@ where client_id: ClientId, event: ClientEvent, ) -> Result<(), Error> { - debug!("incoming: {:?}", event); + debug!("incoming client event: {:?}", event); let result = match event { ClientEvent::ConnReq(connreq) => self.process_connect(client_id, connreq), ClientEvent::ConnAck(_) => { @@ -341,14 +340,14 @@ where connreq.peer_addr(), auth_id.clone(), ); - let operation = Operation::new_connect(connreq.connect().clone()); + let operation = Operation::new_connect(); let activity = Activity::new(client_info, operation); match self.authorizer.authorize(&activity) { Ok(Authorization::Allowed) => { - debug!("client {} successfully authorized", client_id); + debug!("successfully authorized: {}", &activity); } Ok(Authorization::Forbidden(reason)) => { - warn!("client {} not allowed to connect. 
{}", client_id, reason); + warn!("not authorized: {}; reason: {}", &activity, reason); refuse_connection!(proto::ConnectionRefusedReason::NotAuthorized); return Ok(()); } @@ -565,7 +564,7 @@ where let activity = Activity::new(client_info, operation); match self.authorizer.authorize(&activity) { Ok(Authorization::Allowed) => { - debug!("client {} successfully authorized", client_id); + debug!("successfully authorized: {}", &activity); let (maybe_publication, maybe_event) = session.handle_publish(publish)?; if let Some(event) = maybe_event { @@ -577,10 +576,7 @@ where } } Ok(Authorization::Forbidden(reason)) => { - warn!( - "client {} not allowed to publish to topic {}. {}", - client_id, publish.topic_name, reason - ); + warn!("not authorized: {}; reason: {}", &activity, reason); self.drop_connection(&client_id)?; } Err(e) => { @@ -941,7 +937,6 @@ fn subscribe( where Z: Authorizer, { - let client_id = session.client_id().clone(); let client_info = session.client_info().clone(); let mut subscriptions = Vec::with_capacity(subscribe.subscribe_to.len()); @@ -951,12 +946,12 @@ where let operation = Operation::new_subscribe(subscribe_to.clone()); let activity = Activity::new(client_info.clone(), operation); let auth = authorizer.authorize(&activity); - auth.map(|auth| (auth, subscribe_to)) + auth.map(|auth| (auth, subscribe_to, activity)) }); for auth in auth_results { let ack_qos = match auth { - Ok((Authorization::Allowed, subscribe_to)) => { + Ok((Authorization::Allowed, subscribe_to, _)) => { match session.subscribe_to(subscribe_to) { Ok((qos, subscription)) => { if let Some(subscription) = subscription { @@ -970,14 +965,8 @@ where } } } - Ok((Authorization::Forbidden(reason), subscribe_to)) => { - warn!( - "client {} not allowed to subscribe to topic {} qos {}. {}", - client_id, - subscribe_to.topic_filter, - u8::from(subscribe_to.qos), - reason - ); + Ok((Authorization::Forbidden(reason), _, activity)) => { + warn!("not authorized: {}; reason: {}", &activity, reason); proto::SubAckQos::Failure } Err(e) => { @@ -2074,7 +2063,7 @@ pub(crate) mod tests { async fn test_publish_client_has_no_permissions() { let broker = BrokerBuilder::default() .with_authorizer(authorize_fn_ok(|activity| { - if matches!(activity.operation(), Operation::Connect(_)) { + if matches!(activity.operation(), Operation::Connect) { Authorization::Allowed } else { Authorization::Forbidden("not allowed".to_string()) @@ -2111,7 +2100,7 @@ pub(crate) mod tests { async fn test_subscribe_client_has_no_permissions() { let broker = BrokerBuilder::default() .with_authorizer(authorize_fn_ok(|activity| match activity.operation() { - Operation::Connect(_) => Authorization::Allowed, + Operation::Connect => Authorization::Allowed, Operation::Subscribe(subscribe) => match subscribe.topic_filter() { "/topic/denied" => Authorization::Forbidden("denied".to_string()), _ => Authorization::Allowed, diff --git a/mqtt/mqtt-broker/src/connection/mod.rs b/mqtt/mqtt-broker/src/connection/mod.rs index 893323f0ec2..64ad67e3760 100644 --- a/mqtt/mqtt-broker/src/connection/mod.rs +++ b/mqtt/mqtt-broker/src/connection/mod.rs @@ -197,7 +197,7 @@ where // incoming packet stream completed with an error // send a DropConnection request to the broker and wait for the outgoing // task to drain - debug!(message = "incoming_task finished with an error. sending drop connection request to broker", error=%e); + debug!(message = "incoming_task finished with an error. 
sending drop connection request to broker", error = %e); let msg = Message::Client(client_id.clone(), ClientEvent::DropConnection); broker_handle.send(msg)?; @@ -270,7 +270,7 @@ where } }, Err(e) => { - warn!(message="error occurred while reading from connection", error=%e); + warn!(message="error occurred while reading from connection", error = %e); return Err(e.into()); } } @@ -300,14 +300,14 @@ where PacketAction::Continue(Some((packet, message))) => { // send a packet to a client if let Err(e) = outgoing.send(packet).await { - warn!(message = "error occurred while writing to connection", error=%e); + warn!(message = "error occurred while writing to connection", error = %e); return Err((messages, e.into())); } // send a message back to broker if let Some(message) = message { if let Err(e) = broker.send(message) { - warn!(message = "error occurred while sending QoS ack to broker", error=%e); + warn!(message = "error occurred while sending QoS ack to broker", error = %e); return Err((messages, e)); } } diff --git a/mqtt/mqtt-broker/src/lib.rs b/mqtt/mqtt-broker/src/lib.rs index 8e9b00ff677..9dd3f058707 100644 --- a/mqtt/mqtt-broker/src/lib.rs +++ b/mqtt/mqtt-broker/src/lib.rs @@ -20,6 +20,7 @@ mod ready; mod server; mod session; pub mod settings; +pub mod sidecar; mod snapshot; mod state_change; mod stream; @@ -32,7 +33,7 @@ pub mod proptest; use std::{ any::Any, - fmt::{Display, Formatter, Result as FmtResult}, + fmt::{Debug, Display, Formatter, Result as FmtResult}, net::SocketAddr, sync::Arc, }; @@ -180,20 +181,28 @@ impl ConnReq { } } -#[derive(Debug)] pub enum Auth { Identity(AuthId), Unknown, Failure, } +impl Debug for Auth { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Auth::Identity(id) => f.write_fmt(format_args!("\"{}\"", id)), + Auth::Unknown => f.write_str("Unknown"), + Auth::Failure => f.write_str("Failure"), + } + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum Publish { QoS0(proto::PacketIdentifier, proto::Publish), QoS12(proto::PacketIdentifier, proto::Publish), } -#[derive(Debug)] pub enum ClientEvent { /// Connect request ConnReq(ConnReq), @@ -252,7 +261,115 @@ pub enum ClientEvent { PubComp(proto::PubComp), } -#[derive(Debug)] +impl Debug for ClientEvent { + #[allow(clippy::too_many_lines)] + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + ClientEvent::ConnReq(connreq) => f + .debug_struct("ConnReq") + .field("client_id", &connreq.client_id().as_str()) + .field("connect", &connreq.connect()) + .field("auth", &connreq.auth()) + .finish(), + ClientEvent::ConnAck(connack) => f + .debug_struct("ConnAck") + .field("session_present", &connack.session_present) + .field("return_code", &connack.return_code) + .finish(), + ClientEvent::Disconnect(_) => f.write_str("Disconnect"), + ClientEvent::DropConnection => f.write_str("DropConnection"), + ClientEvent::CloseSession => f.write_str("CloseSession"), + ClientEvent::PingReq(_) => f.write_str("PingReq"), + ClientEvent::PingResp(_) => f.write_str("PingResp"), + ClientEvent::Subscribe(sub) => f + .debug_struct("Subscribe") + .field("id", &sub.packet_identifier.get()) + .field("qos", &sub.subscribe_to) + .finish(), + ClientEvent::SubAck(suback) => f + .debug_struct("SubAck") + .field("id", &suback.packet_identifier.get()) + .field("qos", &suback.qos) + .finish(), + ClientEvent::Unsubscribe(unsub) => f + .debug_struct("Unsubscribe") + .field("id", &unsub.packet_identifier.get()) + .field("topic", &unsub.unsubscribe_from) + .finish(), + 
ClientEvent::UnsubAck(unsuback) => f + .debug_struct("UnsubAck") + .field("id", &unsuback.packet_identifier.get()) + .finish(), + ClientEvent::PublishFrom(publish, _) => { + let (qos, id, dup) = match publish.packet_identifier_dup_qos { + proto::PacketIdentifierDupQoS::AtMostOnce => { + (proto::QoS::AtMostOnce, None, false) + } + proto::PacketIdentifierDupQoS::AtLeastOnce(id, dup) => { + (proto::QoS::AtLeastOnce, Some(id.get()), dup) + } + proto::PacketIdentifierDupQoS::ExactlyOnce(id, dup) => { + (proto::QoS::ExactlyOnce, Some(id.get()), dup) + } + }; + f.debug_struct("PublishFrom") + .field("qos", &qos) + .field("id", &id) + .field("dup", &dup) + .field("retain", &publish.retain) + .field("topic_name", &publish.topic_name) + .field("payload", &publish.payload) + .finish() + } + ClientEvent::PublishTo(publish) => { + let publish = match publish { + Publish::QoS0(_, publish) => publish, + Publish::QoS12(_, publish) => publish, + }; + let (qos, id, dup) = match publish.packet_identifier_dup_qos { + proto::PacketIdentifierDupQoS::AtMostOnce => { + (proto::QoS::AtMostOnce, None, false) + } + proto::PacketIdentifierDupQoS::AtLeastOnce(id, dup) => { + (proto::QoS::AtLeastOnce, Some(id.get()), dup) + } + proto::PacketIdentifierDupQoS::ExactlyOnce(id, dup) => { + (proto::QoS::ExactlyOnce, Some(id.get()), dup) + } + }; + f.debug_struct("PublishTo") + .field("qos", &qos) + .field("id", &id) + .field("dup", &dup) + .field("retain", &publish.retain) + .field("topic_name", &publish.topic_name) + .field("payload", &publish.payload) + .finish() + } + ClientEvent::PubAck0(packet_identifier) => f + .debug_struct("PubAck0") + .field("id", &packet_identifier.get()) + .finish(), + ClientEvent::PubAck(puback) => f + .debug_struct("PubAck") + .field("id", &puback.packet_identifier.get()) + .finish(), + ClientEvent::PubRec(pubrec) => f + .debug_struct("PubRec") + .field("id", &pubrec.packet_identifier.get()) + .finish(), + ClientEvent::PubRel(pubrel) => f + .debug_struct("PubRel") + .field("id", &pubrel.packet_identifier.get()) + .finish(), + ClientEvent::PubComp(pubcomp) => f + .debug_struct("PubComp") + .field("id", &pubcomp.packet_identifier.get()) + .finish(), + } + } +} + pub enum SystemEvent { /// An event for a broker to stop processing incoming event and exit. Shutdown, @@ -270,6 +387,21 @@ pub enum SystemEvent { Publish(Publication), } +impl Debug for SystemEvent { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + SystemEvent::Shutdown => f.write_str("Shutdown"), + SystemEvent::StateSnapshot(_) => f.write_str("StateSnapshot"), + SystemEvent::AuthorizationUpdate(update) => { + f.debug_tuple("AuthorizationUpdate").field(&update).finish() + } + SystemEvent::Publish(publication) => { + f.debug_tuple("Publish").field(&publication).finish() + } + } + } +} + #[derive(Debug)] pub enum Message { Client(ClientId, ClientEvent), diff --git a/mqtt/mqtt-broker/src/persist.rs b/mqtt/mqtt-broker/src/persist.rs index 08598c45f9b..b2500ed377f 100644 --- a/mqtt/mqtt-broker/src/persist.rs +++ b/mqtt/mqtt-broker/src/persist.rs @@ -544,32 +544,14 @@ mod tests { use tempfile::TempDir; use crate::{ - persist::{ConsolidatedState, FileFormat, FilePersistor, Persist, VersionedFileFormat}, + persist::{FileFormat, FilePersistor, Persist, VersionedFileFormat}, proptest::arb_broker_snapshot, BrokerSnapshot, }; proptest! 
{ #[test] - fn consolidate_simple(state in arb_broker_snapshot()) { - let (expected_retained, expected_sessions) = state.clone().into_parts(); - - let consolidated: ConsolidatedState = state.into(); - prop_assert_eq!(expected_retained.len(), consolidated.retained.len()); - prop_assert_eq!(expected_sessions.len(), consolidated.sessions.len()); - - let state: BrokerSnapshot = consolidated.into(); - let (result_retained, result_sessions) = state.into_parts(); - - prop_assert_eq!(expected_retained, result_retained); - prop_assert_eq!(expected_sessions.len(), result_sessions.len()); - for i in 0..expected_sessions.len(){ - prop_assert_eq!(expected_sessions[i].clone().into_parts(), result_sessions[i].clone().into_parts()); - } - } - - #[test] - fn consolidate_roundtrip(state in arb_broker_snapshot()) { + fn broker_state_versioned_file_format_persistence_test(state in arb_broker_snapshot()) { let (expected_retained, expected_sessions) = state.clone().into_parts(); let format = VersionedFileFormat; let mut buffer = vec![0_u8; 10 * 1024 * 1024]; @@ -581,10 +563,7 @@ mod tests { let (result_retained, result_sessions) = state.into_parts(); prop_assert_eq!(expected_retained, result_retained); - prop_assert_eq!(expected_sessions.len(), result_sessions.len()); - for i in 0..expected_sessions.len(){ - prop_assert_eq!(expected_sessions[i].clone().into_parts(), result_sessions[i].clone().into_parts()); - } + prop_assert_eq!(expected_sessions, result_sessions); } } diff --git a/mqtt/mqtt-broker/src/proptest.rs b/mqtt/mqtt-broker/src/proptest.rs index 484af621eb9..ec61b434c55 100644 --- a/mqtt/mqtt-broker/src/proptest.rs +++ b/mqtt/mqtt-broker/src/proptest.rs @@ -1,57 +1,34 @@ -#[cfg(any(test, feature = "proptest"))] +#![cfg(any(test, feature = "proptest"))] use std::{net::IpAddr, net::SocketAddr, time::Duration}; use bytes::Bytes; use mqtt3::proto; use proptest::{ bool, - collection::{hash_map, hash_set, vec, vec_deque}, + collection::{hash_map, vec, vec_deque}, num, prelude::*, }; use crate::{ - session::identifiers::{IdentifiersInUse, PacketIdentifiers}, AuthId, BrokerSnapshot, ClientId, ClientInfo, Publish, Segment, SessionSnapshot, Subscription, TopicFilter, }; prop_compose! { pub fn arb_broker_snapshot()( - retained in hash_map(arb_topic(), arb_publication(), 0..20), - sessions in vec(arb_session_snapshot(), 0..10), + retained in hash_map(arb_topic(), arb_publication(), 0..5), + sessions in vec(arb_session_snapshot(), 0..5), ) -> BrokerSnapshot { BrokerSnapshot::new(retained, sessions) } } -prop_compose! { - pub(crate) fn arb_packet_identifiers()( - in_use in arb_identifiers_in_use(), - previous in arb_packet_identifier(), - ) -> PacketIdentifiers { - PacketIdentifiers::new(in_use, previous) - } -} - -pub(crate) fn arb_identifiers_in_use() -> impl Strategy<Value = IdentifiersInUse> { - vec(num::usize::ANY, PacketIdentifiers::SIZE).prop_map(|v| { - let mut array = [0; PacketIdentifiers::SIZE]; - let nums = &v[..array.len()]; - array.copy_from_slice(nums); - IdentifiersInUse(Box::new(array)) - }) -} - prop_compose!
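// prop_compose! stitches existing strategies into a new one; the ranges above
// were tightened from 0..10/0..20 down to 0..5, which keeps generated snapshots
// small (and property runs fast) while still covering the empty and multi-item shapes.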
{ pub fn arb_session_snapshot()( client_info in arb_client_info(), - subscriptions in hash_map(arb_topic(), arb_subscription(), 0..10), - _packet_identifiers in arb_packet_identifiers(), - waiting_to_be_sent in vec_deque(arb_publication(), 0..10), - _waiting_to_be_released in hash_map(arb_packet_identifier(), arb_proto_publish(), 0..10), - _waiting_to_be_acked in hash_map(arb_packet_identifier(), arb_publish(), 0..10), - _waiting_to_be_completed in hash_set(arb_packet_identifier(), 0..10), + subscriptions in hash_map(arb_topic(), arb_subscription(), 0..5), + waiting_to_be_sent in vec_deque(arb_publication(), 0..5), ) -> SessionSnapshot { SessionSnapshot::from_parts( client_info, @@ -81,7 +58,7 @@ prop_compose! { prop_compose! { pub fn arb_subscribe()( packet_identifier in arb_packet_identifier(), - subscribe_to in proptest::collection::vec(arb_subscribe_to(), 1..10) + subscribe_to in proptest::collection::vec(arb_subscribe_to(), 1..5) ) -> proto::Subscribe { proto::Subscribe { packet_identifier, @@ -105,7 +82,7 @@ prop_compose! { prop_compose! { pub fn arb_unsubscribe()( packet_identifier in arb_packet_identifier(), - unsubscribe_from in proptest::collection::vec(arb_topic_filter_weighted(), 1..10) + unsubscribe_from in proptest::collection::vec(arb_topic_filter_weighted(), 1..5) ) -> proto::Unsubscribe { proto::Unsubscribe { packet_identifier, @@ -194,7 +171,7 @@ pub fn arb_topic() -> impl Strategy { } pub fn arb_payload() -> impl Strategy { - vec(num::u8::ANY, 0..1024).prop_map(Bytes::from) + vec(num::u8::ANY, 0..128).prop_map(Bytes::from) } prop_compose! { diff --git a/mqtt/mqtt-broker/src/ready.rs b/mqtt/mqtt-broker/src/ready.rs index 5011cbef31b..dbbcd9e0bd3 100644 --- a/mqtt/mqtt-broker/src/ready.rs +++ b/mqtt/mqtt-broker/src/ready.rs @@ -107,6 +107,7 @@ impl AwaitingEvents for BrokerReadyEvent { fn awaiting() -> HashSet { let mut awaiting = HashSet::new(); awaiting.insert(Self::AuthorizerReady); + awaiting.insert(Self::PolicyReady); awaiting } diff --git a/mqtt/mqtt-broker/src/server.rs b/mqtt/mqtt-broker/src/server.rs index c9ca0bdbbe4..d3567879c16 100644 --- a/mqtt/mqtt-broker/src/server.rs +++ b/mqtt/mqtt-broker/src/server.rs @@ -37,6 +37,12 @@ where } } +impl Server { + pub fn listeners(&self) -> &Vec { + &self.listeners + } +} + impl Server where Z: Authorizer + Send + 'static, @@ -97,7 +103,7 @@ where pub async fn serve(self, shutdown_signal: F) -> Result where - F: Future + Unpin, + F: Future, { let Server { broker, @@ -108,7 +114,7 @@ where // prepare dispatcher in a separate task let broker_task = tokio::spawn(broker.run()); - pin_mut!(broker_task); + pin_mut!(broker_task, shutdown_signal); // prepare each transport listener let mut incoming_tasks = Vec::new(); @@ -229,7 +235,7 @@ where } } -struct Listener { +pub struct Listener { transport: Transport, authenticator: Arc<(dyn Authenticator> + Send + Sync)>, ready: Option, @@ -256,6 +262,10 @@ impl Listener { } } + pub fn transport(&self) -> &Transport { + &self.transport + } + async fn run(self, shutdown_signal: F, make_processor: P) -> Result<(), Error> where F: Future + Unpin, @@ -277,7 +287,7 @@ impl Listener { let ready = async { match ready { Some(ready) => { - info!("Waiting for broker to be ready to serve requests"); + info!("waiting for broker to be ready to serve requests"); ready.wait().await } None => Ok(()), @@ -292,7 +302,7 @@ impl Listener { let mut incoming = transport.incoming().await?; let addr = incoming.local_addr()?; - info!("Listening on address {}", addr); + info!("listening on address {}", addr); 
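// future::select only accepts Unpin futures, which is why Listener::run still
// requires `F: Future + Unpin` for the shutdown_signal used in the loop below.
// serve() above instead stack-pins its futures with pin_mut!(broker_task,
// shutdown_signal), turning them into Pin<&mut _> (always Unpin); that is what
// allows the `+ Unpin` bound to be dropped from serve()'s public signature.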
loop { match future::select(&mut shutdown_signal, incoming.next()).await { @@ -323,8 +333,8 @@ impl Listener { } Either::Right((Some(Err(e)), _)) => { warn!( - message = "accept loop exiting due to an error", - error =% DetailedErrorValue(&e) + error =% DetailedErrorValue(&e), + message = "accept loop exiting due to an error" ); break; } @@ -337,10 +347,13 @@ impl Listener { Ok(()) } Either::Left((Err(e), _)) => { - error!("error occurred when waiting for broker readiness. {}", e); + error!(error = %DetailedErrorValue(&e), "error occurred when waiting for broker readiness."); + Ok(()) + } + Either::Right((_, _)) => { + info!("shutdown signalled while waiting for broker to be ready"); Ok(()) } - Either::Right((_, _)) => Ok(()), } } .instrument(span) diff --git a/mqtt/mqtt-broker/src/session/identifiers.rs b/mqtt/mqtt-broker/src/session/identifiers.rs index 4380b42035e..b9cf574a1cf 100644 --- a/mqtt/mqtt-broker/src/session/identifiers.rs +++ b/mqtt/mqtt-broker/src/session/identifiers.rs @@ -90,11 +90,6 @@ impl PacketIdentifiers { /// We use a bitshift instead of `usize::pow` because the latter is not a const fn pub(crate) const SIZE: usize = (1 << 16) / (mem::size_of::<usize>() * 8); - #[cfg(any(test, feature = "proptest"))] - pub(crate) fn new(in_use: IdentifiersInUse, previous: proto::PacketIdentifier) -> Self { - Self { in_use, previous } - } - pub(crate) fn reserve(&mut self) -> Result<proto::PacketIdentifier, Error> { let start = self.previous; let mut current = start; diff --git a/mqtt/mqtt-broker/src/sidecar.rs b/mqtt/mqtt-broker/src/sidecar.rs new file mode 100644 index 00000000000..21f27e72ba2 --- /dev/null +++ b/mqtt/mqtt-broker/src/sidecar.rs @@ -0,0 +1,34 @@ +use std::{error::Error as StdError, future::Future, pin::Pin}; + +use async_trait::async_trait; + +/// A common trait for any additional routine that enriches MQTT broker behavior. +#[async_trait] +pub trait Sidecar { + /// Returns a new instance of a shutdown handle to be used to stop the sidecar. + fn shutdown_handle(&self) -> Result<SidecarShutdownHandle, SidecarShutdownHandleError>; + + /// Starts a routine. + async fn run(self: Box<Self>); +} + +/// Shutdown handle to request a sidecar to stop. +pub struct SidecarShutdownHandle(Pin<Box<dyn Future<Output = ()> + Send>>); + +impl SidecarShutdownHandle { + pub fn new<F>(shutdown: F) -> Self + where + F: Future<Output = ()> + Send + 'static, + { + Self(Box::pin(shutdown)) + } + + pub async fn shutdown(self) { + self.0.await + } +} + +/// This error is returned when it is impossible to obtain a shutdown handle. +#[derive(Debug, thiserror::Error)] +#[error("unable to obtain shutdown handle for sidecar. 
{0}")] +pub struct SidecarShutdownHandleError(#[source] pub Box); diff --git a/mqtt/mqtt-broker/src/snapshot.rs b/mqtt/mqtt-broker/src/snapshot.rs index df2a941c924..68d564722c8 100644 --- a/mqtt/mqtt-broker/src/snapshot.rs +++ b/mqtt/mqtt-broker/src/snapshot.rs @@ -124,7 +124,7 @@ where match event { Event::State(state) => { if let Err(e) = self.persistor.store(state).await { - warn!(message = "an error occurred persisting state snapshot.", error=%e); + warn!(message = "an error occurred persisting state snapshot.", error = %e); } } Event::Shutdown => { diff --git a/mqtt/mqtt-broker/src/subscription.rs b/mqtt/mqtt-broker/src/subscription.rs index a709bf57e35..8f2ecbb554c 100644 --- a/mqtt/mqtt-broker/src/subscription.rs +++ b/mqtt/mqtt-broker/src/subscription.rs @@ -33,7 +33,7 @@ impl Subscription { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct TopicFilter { segments: Vec, multilevel: bool, diff --git a/mqtt/mqtt-broker/src/transport.rs b/mqtt/mqtt-broker/src/transport.rs index 8e298e5bf2c..c6f657954f3 100644 --- a/mqtt/mqtt-broker/src/transport.rs +++ b/mqtt/mqtt-broker/src/transport.rs @@ -92,6 +92,14 @@ impl Transport { Protocol::Tls(addr, _) => addr, } } + + /// Returns a server certificate if any. + pub fn identity(&self) -> Option<&ServerCertificate> { + match &self.protocol { + Protocol::Tcp(_) => None, + Protocol::Tls(_, identity) => Some(identity), + } + } } enum Protocol { @@ -179,12 +187,12 @@ impl Stream for IncomingTcp { match self.listener.poll_accept(cx) { Poll::Ready(Ok((tcp, _))) => match tcp.set_nodelay(true) { Ok(()) => { - debug!("TCP: Accepted connection from client"); + debug!("accepted connection from client"); Poll::Ready(Some(Ok(StreamSelector::Tcp(tcp)))) } Err(err) => { warn!( - "TCP: Dropping client because failed to setup TCP properties: {}", + "dropping client because failed to setup TCP properties: {}", err ); Poll::Ready(Some(Err(err))) @@ -192,7 +200,7 @@ impl Stream for IncomingTcp { }, Poll::Ready(Err(err)) => { error!( - "TCP: Dropping client that failed to completely establish a TCP connection: {}", + "dropping client that failed to completely establish a TCP connection: {}", err ); Poll::Ready(Some(Err(err))) @@ -231,12 +239,12 @@ impl Stream for IncomingTls { .push(Box::pin(async move { accept(&acceptor, stream).await })); } Err(err) => warn!( - "TCP: Dropping client because failed to setup TCP properties: {}", + "dropping client because failed to setup TCP properties: {}", err ), }, Poll::Ready(Err(err)) => warn!( - "TCP: Dropping client that failed to completely establish a TCP connection: {}", + "dropping client that failed to completely establish a TCP connection: {}", err ), Poll::Pending => break, @@ -250,17 +258,17 @@ impl Stream for IncomingTls { match Pin::new(&mut self.connections).poll_next(cx) { Poll::Ready(Some(Ok(stream))) => { - debug!("TLS: Accepted connection from client"); + debug!("accepted connection from client"); return Poll::Ready(Some(Ok(StreamSelector::Tls(stream)))); } Poll::Ready(Some(Err(err))) => warn!( - "TLS: Dropping client that failed to complete a TLS handshake: {}", + "dropping client that failed to complete a TLS handshake: {}", err ), Poll::Ready(None) => { - debug!("TLS: Shutting down web server"); + debug!("shutting down web server"); return Poll::Ready(None); } diff --git a/mqtt/mqtt-edgehub/Cargo.toml b/mqtt/mqtt-edgehub/Cargo.toml index ee0f8af2685..9973a503f3e 100644 --- a/mqtt/mqtt-edgehub/Cargo.toml +++ 
b/mqtt/mqtt-edgehub/Cargo.toml @@ -16,6 +16,7 @@ futures-util = "0.3" lazy_static = "1.4" openssl = "0.10" parking_lot = "0.11" +proptest = { version = "0.9", optional = true } regex = "1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" @@ -35,12 +36,12 @@ assert_matches = "1.3" lazy_static = "1.4" matches = "0.1" mockito = "0.25" -proptest = "0.9" serial_test = "0.4" test-case = "1.0" tokio = { version = "0.2", features = ["macros"] } mqtt-broker-tests-util = { path = "../mqtt-broker-tests-util" } +mqtt-broker = { path = "../mqtt-broker", features = ["proptest"] } [[test]] name = "translation" diff --git a/mqtt/mqtt-edgehub/src/auth/authorization/edgehub.rs b/mqtt/mqtt-edgehub/src/auth/authorization/edgehub.rs index 60638671333..2234fc03e77 100644 --- a/mqtt/mqtt-edgehub/src/auth/authorization/edgehub.rs +++ b/mqtt/mqtt-edgehub/src/auth/authorization/edgehub.rs @@ -1,10 +1,12 @@ -use std::{any::Any, cell::RefCell, collections::HashMap, error::Error as StdError, fmt}; +use std::{any::Any, collections::HashMap, error::Error as StdError, fmt}; +use lazy_static::lazy_static; +use regex::Regex; use serde::{Deserialize, Serialize}; use tracing::info; use mqtt_broker::{ - auth::{Activity, Authorization, Authorizer, Connect, Operation, Publish, Subscribe}, + auth::{Activity, Authorization, Authorizer, Operation}, AuthId, BrokerReadyEvent, BrokerReadyHandle, ClientId, }; @@ -15,10 +17,11 @@ use mqtt_broker::{ /// /// For non-iothub-specific primitives it delegates the request to an inner authorizer (`PolicyAuthorizer`). pub struct EdgeHubAuthorizer { - iothub_allowed_topics: RefCell>>, identities_cache: HashMap, inner: Z, broker_ready: Option, + device_id: String, + iothub_id: String, } impl EdgeHubAuthorizer @@ -26,37 +29,49 @@ where Z: Authorizer, E: StdError, { - pub fn new(authorizer: Z, broker_ready: BrokerReadyHandle) -> Self { - Self::create(authorizer, Some(broker_ready)) + pub fn new( + authorizer: Z, + device_id: impl Into, + iothub_id: impl Into, + broker_ready: BrokerReadyHandle, + ) -> Self { + Self::create(authorizer, device_id, iothub_id, Some(broker_ready)) } - pub fn without_ready_handle(authorizer: Z) -> Self { - Self::create(authorizer, None) + pub fn without_ready_handle( + authorizer: Z, + device_id: impl Into, + iothub_id: impl Into, + ) -> Self { + Self::create(authorizer, device_id, iothub_id, None) } - fn create(authorizer: Z, broker_ready: Option) -> Self { + fn create( + authorizer: Z, + device_id: impl Into, + iothub_id: impl Into, + broker_ready: Option, + ) -> Self { Self { - iothub_allowed_topics: RefCell::default(), identities_cache: HashMap::default(), inner: authorizer, broker_ready, + device_id: device_id.into(), + iothub_id: iothub_id.into(), } } #[allow(clippy::unused_self)] - fn authorize_connect( - &self, - activity: &Activity, - _connect: &Connect, - ) -> Result { + fn authorize_connect(&self, activity: &Activity) -> Result { match activity.client_info().auth_id() { // forbid anonymous clients to connect to the broker AuthId::Anonymous => Ok(Authorization::Forbidden( - "Anonymous clients cannot connect to broker".to_string(), + "Anonymous clients cannot connect to the broker".to_string(), )), // allow only those clients whose auth_id and client_id identical AuthId::Identity(identity) => { - if identity == activity.client_id() { + let actor_id = format!("{}/{}", self.iothub_id, activity.client_id()); + if *identity == actor_id { // delegate to inner authorizer. 
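// e.g. with iothub_id = "myhub.azure-devices.net" and client_id = "edge-1/$edgeHub",
// actor_id is "myhub.azure-devices.net/edge-1/$edgeHub"; the connect is delegated
// only when the authenticated identity matches that string exactly (compare the
// connect test cases below).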
self.inner.authorize(activity) } else { @@ -70,44 +85,10 @@ where } } - fn authorize_publish( - &self, - activity: &Activity, - publish: &Publish, - ) -> Result<Authorization, E> { - let topic = publish.publication().topic_name(); - + fn authorize_topic(&self, activity: &Activity, topic: &str) -> Result<Authorization, E> { if is_iothub_topic(topic) { // run authorization rules for publication to IoTHub topic self.authorize_iothub_topic(activity, topic) - } else if is_forbidden_topic(topic) { - // forbid any clients to access restricted topics - Ok(Authorization::Forbidden(format!( - "{} is forbidden topic filter", - topic - ))) - } else { - // delegate to inner authorizer for to any non-iothub topics. - self.inner.authorize(activity) - } - } - - fn authorize_subscribe( - &self, - activity: &Activity, - subscribe: &Subscribe, - ) -> Result<Authorization, E> { - let topic = subscribe.topic_filter(); - - if is_iothub_topic(topic) { - // run authorization rules for subscription to IoTHub topic - self.authorize_iothub_topic(activity, topic) - } else if is_forbidden_topic_filter(topic) { - // forbid any clients to access restricted topics - Ok(Authorization::Forbidden(format!( - "{} is forbidden topic filter", - topic - ))) } else { // delegate to inner authorizer for to any non-iothub topics. self.inner.authorize(activity) @@ -121,10 +102,8 @@ where "Anonymous clients do not have access to IoTHub topics".to_string(), ), // allow authenticated clients with client_id == auth_id and accessing its own IoTHub topic - AuthId::Identity(identity) if identity == activity.client_id() => { - if self.is_iothub_topic_allowed(activity.client_id(), topic) - && self.check_authorized_cache(activity.client_id(), topic) - { + AuthId::Identity(_) => { + if self.is_iothub_operation_authorized(topic, activity.client_id()) { Authorization::Allowed } else { // check if iothub policy is overridden by a custom policy. @@ -137,75 +116,99 @@ where } } } - // forbid access otherwise - AuthId::Identity(_) => Authorization::Forbidden(format!( - "client_id {} must match registered iothub identity id to access IoTHub topic", - activity.client_id() - )), }) } - fn is_iothub_topic_allowed(&self, client_id: &ClientId, topic: &str) -> bool { - let mut iothub_allowed_topics = self.iothub_allowed_topics.borrow_mut(); - let allowed_topics = iothub_allowed_topics - .entry(client_id.clone()) - .or_insert_with(|| allowed_iothub_topic(&client_id)); - - allowed_topics - .iter() - .any(|allowed_topic| topic.starts_with(allowed_topic)) - } - - fn check_authorized_cache(&self, client_id: &ClientId, topic: &str) -> bool { - match get_on_behalf_of_id(topic) { - Some(on_behalf_of_id) if client_id == &on_behalf_of_id => { - self.identities_cache.contains_key(client_id) + fn is_iothub_operation_authorized(&self, topic: &str, client_id: &ClientId) -> bool { + // actor id is either id of a leaf/edge device or on-behalf-of id (when + // a child edge is acting on behalf of its own children). + match get_actor_id(topic) { + Some(actor_id) if actor_id == *client_id => { + // if actor = client, it means it is a regular leaf/edge device request. + // check that it is in the current edgehub auth chain. 
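// e.g. identity "leaf-1" with auth_chain "leaf-1;this_edge" is in scope when
// self.device_id == "this_edge", since the chain string contains the device id
// (values mirror the test fixtures below).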
+ // + // [edgehub] <- [actor = client (leaf or child edge)] + match self.identities_cache.get(&actor_id) { + Some(identity) => identity + .auth_chain() + .map_or(false, |auth_chain| auth_chain.contains(&self.device_id)), + None => false, + } } - Some(on_behalf_of_id) => self - .identities_cache - .get(&on_behalf_of_id) - .and_then(IdentityUpdate::auth_chain) - .map_or(false, |auth_chain| auth_chain.contains(client_id.as_str())), - None => { - // If there is no on_behalf_of_id, we are dealing with a legacy topic - // The client_id must still be in the identities cache - self.identities_cache.contains_key(client_id) + Some(actor_id) => { + // if actor != client, it means it is an on-behalf-of request. + // check that: + // - actor_id is in the auth chain for client_id (the client + // making the request can actually act on behalf of the actor) + // - actor_id is in the auth chain for the current edgehub. + // - client_id is in the auth chain for the current edgehub. + // + // [edgehub] <- [client (child edgehub)] <- [actor (grandchild)] + + let parent_ok = match self.identities_cache.get(&actor_id) { + Some(identity) => identity.auth_chain().map_or(false, |auth_chain| { + auth_chain.contains(&client_id.as_str().replace("/$edgeHub", "")) + }), + None => false, + }; + + let actor_ok = match self.identities_cache.get(&actor_id) { + Some(identity) => identity + .auth_chain() + .map_or(false, |auth_chain| auth_chain.contains(&self.device_id)), + None => false, + }; + + let client_ok = match self.identities_cache.get(client_id) { + Some(identity) => identity + .auth_chain() + .map_or(false, |auth_chain| auth_chain.contains(&self.device_id)), + None => false, + }; + + parent_ok && actor_ok && client_ok } + // If there is no actor_id, we are dealing with a legacy topic/unknown format. + // Delegate to the inner authorizer. + None => false, } } } -fn get_on_behalf_of_id(topic: &str) -> Option<ClientId> { - // topics without the new topic format cannot have on_behalf_of_ids - if !topic.starts_with("$iothub/clients") { - return None; +fn get_actor_id(topic: &str) -> Option<ClientId> { + lazy_static! { + static ref TOPIC_PATTERN: Regex = Regex::new( + // this regex tries to capture all possible iothub/edgehub-specific topic formats. + // we need this + // - to validate that this is a correct iothub/edgehub topic. + // - to extract device_id and module_id. + // + // format! is for ease of reading only. 
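// e.g. "$edgehub/edge-1/module-a/twin/get" yields device_id = "edge-1" and
// module_id = "module-a", so the actor id is "edge-1/module-a";
// "$iothub/leaf-1/messages/events" yields only device_id = "leaf-1".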
+ &format!(r"^(\$edgehub|\$iothub)/(?P[^/\+\#]+)(/(?P[^/\+\#]+))?/({}|{}|{}|{}|{}|{}|{}|{}|{})", + "messages/events", + "messages/c2d/post", + "twin/desired", + "twin/reported", + "twin/get", + "twin/res", + "methods/post", + "methods/res", + "\\+/inputs") + ).expect("failed to create new Regex from pattern"); } - let topic_parts = topic.split('/').collect::>(); - let device_id = topic_parts.get(2); - let module_id = match topic_parts.get(3) { - Some(s) if *s == "modules" => topic_parts.get(4), - _ => None, - }; - match (device_id, module_id) { - (Some(device_id), Some(module_id)) => Some(format!("{}/{}", device_id, module_id).into()), - (Some(device_id), None) => Some((*device_id).into()), - _ => None, + match TOPIC_PATTERN.captures(topic) { + Some(captures) => match (captures.name("device_id"), captures.name("module_id")) { + (Some(device_id), None) => Some(device_id.as_str().into()), + (Some(device_id), Some(module_id)) => { + Some(format!("{}/{}", device_id.as_str(), module_id.as_str()).into()) + } + (_, _) => None, + }, + None => None, } } -const FORBIDDEN_TOPIC_FILTER_PREFIXES: [&str; 2] = ["#", "$"]; - -fn is_forbidden_topic_filter(topic_filter: &str) -> bool { - FORBIDDEN_TOPIC_FILTER_PREFIXES - .iter() - .any(|prefix| topic_filter.starts_with(prefix)) -} - -fn is_forbidden_topic(topic_filter: &str) -> bool { - topic_filter.starts_with('$') -} - const IOTHUB_TOPICS_PREFIX: [&str; 2] = ["$edgehub/", "$iothub/"]; fn is_iothub_topic(topic: &str) -> bool { @@ -214,37 +217,6 @@ fn is_iothub_topic(topic: &str) -> bool { .any(|prefix| topic.starts_with(prefix)) } -fn allowed_iothub_topic(client_id: &ClientId) -> Vec { - let client_id_parts = client_id.as_str().split('/').collect::>(); - let x = match client_id_parts.len() { - 1 => client_id_parts[0].to_string(), - 2 => format!("{}/modules/{}", client_id_parts[0], client_id_parts[1]), - _ => { - panic!("ClientId cannot have more than deviceId and moduleId"); - } - }; - vec![ - format!("$edgehub/{}/messages/events", client_id), - format!("$edgehub/{}/messages/c2d/post", client_id), - format!("$edgehub/{}/twin/desired", client_id), - format!("$edgehub/{}/twin/reported", client_id), - format!("$edgehub/{}/twin/get", client_id), - format!("$edgehub/{}/twin/res", client_id), - format!("$edgehub/{}/methods/post", client_id), - format!("$edgehub/{}/methods/res", client_id), - format!("$edgehub/{}/inputs", client_id), - format!("$edgehub/{}/outputs", client_id), - format!("$iothub/clients/{}/messages/events", x), - format!("$iothub/clients/{}/messages/c2d/post", x), - format!("$iothub/clients/{}/twin/patch/properties/desired", x), - format!("$iothub/clients/{}/twin/patch/properties/reported", x), - format!("$iothub/clients/{}/twin/get", x), - format!("$iothub/clients/{}/twin/res", x), - format!("$iothub/clients/{}/methods/post", x), - format!("$iothub/clients/{}/methods/res", x), - ] -} - impl Authorizer for EdgeHubAuthorizer where Z: Authorizer, @@ -254,9 +226,13 @@ where fn authorize(&self, activity: &Activity) -> Result { match activity.operation() { - Operation::Connect(connect) => self.authorize_connect(activity, &connect), - Operation::Publish(publish) => self.authorize_publish(activity, &publish), - Operation::Subscribe(subscribe) => self.authorize_subscribe(activity, &subscribe), + Operation::Connect => self.authorize_connect(activity), + Operation::Publish(publish) => { + self.authorize_topic(activity, &publish.publication().topic_name()) + } + Operation::Subscribe(subscribe) => { + self.authorize_topic(activity, 
&subscribe.topic_filter()) + } } } @@ -344,74 +320,153 @@ mod tests { use super::{AuthorizerUpdate, EdgeHubAuthorizer, IdentityUpdate}; - #[test_case(&tests::connect_activity("device-1", AuthId::Anonymous); "anonymous clients")] - #[test_case(&tests::connect_activity("device-1", "device-2"); "different auth_id and client_id")] + #[test_case(&tests::connect_activity("leaf-1", AuthId::Anonymous); "anonymous clients")] + #[test_case(&tests::connect_activity("leaf-1", "leaf-2"); "different auth_id and client_id")] fn it_forbids_to_connect(activity: &Activity) { - let authorizer = authorizer(AllowAll); + let authorizer = authorizer(AllowAll, vec![]); let auth = authorizer.authorize(&activity); assert_matches!(auth, Ok(Authorization::Forbidden(_))); } - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/messages/events"); "device events")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/messages/events"); "edge module events")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/messages/c2d/post"); "device C2D messages")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/messages/c2d/post"); "edge module C2D messages")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/twin/desired"); "device update desired properties")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/desired"); "edge module update desired properties")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/twin/reported"); "device update reported properties")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/reported"); "edge module update reported properties")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/twin/get"); "device twin request")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/get"); "edge module twin request")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/twin/res"); "device twin response")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/inputs/route1"); "edge module access M2M inputs")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/outputs/route1"); "edge module access M2M outputs")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/messages/events"); "iothub telemetry")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/messages/events"); "iothub telemetry with moduleId")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/messages/c2d/post"); "iothub c2d messages")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/messages/c2d/post"); "iothub c2d messages with moduleId")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/twin/patch/properties/desired"); "iothub update desired properties")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", 
"$iothub/clients/device-1/modules/module-a/twin/patch/properties/desired"); "iothub update desired properties with moduleId")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/twin/patch/properties/reported"); "iothub update reported properties")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/patch/properties/reported"); "iothub update reported properties with moduleId")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/twin/get"); "iothub device twin request")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/get"); "iothub module twin request")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/twin/res"); "iothub device twin response")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/res"); "iothub module twin response")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/methods/post"); "iothub device DM request")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/methods/post"); "iothub module DM request")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/clients/device-1/methods/res"); "iothub device DM response")] - #[test_case(&tests::subscribe_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/methods/res"); "iothub module DM response")] - fn it_allows_to_subscribe_to(activity: &Activity) { - let authorizer = authorizer(DenyAll); + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/events"); "device events")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/c2d/post"); "device C2D messages")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/desired"); "device update desired properties")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/reported"); "device update reported properties")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/get"); "device twin request")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/res"); "device twin response")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/methods/post"); "device DM request")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/methods/res"); "device DM response")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/messages/events"); "iothub telemetry")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/messages/c2d/post"); "iothub c2d messages")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/desired"); "iothub update desired properties")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/reported"); "iothub update reported properties")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/get"); "iothub device twin request")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/res"); "iothub device twin response")] + 
#[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/methods/post"); "iothub device DM request")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-1/methods/res"); "iothub device DM response")] + fn it_allows_to_subscribe_for_leaf(activity: &Activity) { + let identities = vec![IdentityUpdate { + identity: "leaf-1".to_string(), + auth_chain: Some("leaf-1;this_edge".to_string()), + }]; + let authorizer = authorizer(DenyAll, identities); let auth = authorizer.authorize(&activity); assert_matches!(auth, Ok(Authorization::Allowed)); } - #[test_case(&tests::subscribe_activity("device-1", "device-1", "#"); "everything")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$SYS/connected"); "SYS topics")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$CUSTOM/topic"); "any special topics")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$upstream/#"); "everything with upstream prefixed")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$downstream/#"); "everything with downstream prefixed")] - #[test_case(&tests::subscribe_activity("device-1", "device-2", "$edgehub/device-1/twin/get"); "twin request for another client")] - #[test_case(&tests::subscribe_activity("device-1", AuthId::Anonymous, "$edgehub/device-1/twin/get"); "twin request by anonymous client")] - fn it_forbids_to_subscribe_to(activity: &Activity) { - let authorizer = authorizer(AllowAll); + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/messages/events"); "device events")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/messages/c2d/post"); "device C2D messages")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/desired"); "device update desired properties")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/reported"); "device update reported properties")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/get"); "device twin request")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/res"); "device twin response")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/methods/post"); "device DM request")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/methods/res"); "device DM response")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/+/inputs/route1"); "edge module access M2M inputs")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/messages/events"); "iothub telemetry")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/messages/c2d/post"); "iothub c2d messages")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/desired"); "iothub update desired properties")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/reported"); "iothub update reported properties")] + 
#[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/get"); "iothub device twin request")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/res"); "iothub device twin response")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/methods/post"); "iothub device DM request")] + #[test_case(&tests::subscribe_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/methods/res"); "iothub device DM response")] + fn it_allows_to_subscribe_for_module(activity: &Activity) { + let identities = vec![IdentityUpdate { + identity: "this_edge/module-a".to_string(), + auth_chain: Some("this_edge/module-a;this_edge".to_string()), + }]; + let authorizer = authorizer(DenyAll, identities); let auth = authorizer.authorize(&activity); - assert_matches!(auth, Ok(Authorization::Forbidden(_))); + assert_matches!(auth, Ok(Authorization::Allowed)); } - #[test_case(&tests::connect_activity("device-1", "device-1"); "identical auth_id and client_id")] - #[test_case(&tests::publish_activity("device-1", "device-1", "topic"); "generic MQTT topic publish")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "topic"); "generic MQTT topic subscribe")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/messages/events"); "edge module events")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/messages/c2d/post"); "edge module C2D messages")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/desired"); "edge module update desired properties")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/reported"); "edge module update reported properties")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/get"); "edge module twin request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/methods/post"); "module DM request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/methods/res"); "module DM response")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/messages/events"); "iothub telemetry with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/messages/c2d/post"); "iothub c2d messages with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/desired"); "iothub update desired properties with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/reported"); "iothub update reported properties with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/get"); "iothub module twin request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/res"); "iothub module twin response")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/methods/post"); "iothub module DM request")] + 
#[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/methods/res"); "iothub module DM response")] + fn it_allows_to_subscribe_for_on_behalf_of_module(activity: &Activity) { + let identities = vec![ + // child edgehub + IdentityUpdate { + identity: "edge-1/$edgeHub".to_string(), + auth_chain: Some("edge-1/$edgeHub;this_edge".to_string()), + }, + // grandchild module + IdentityUpdate { + identity: "edge-1/module-a".to_string(), + auth_chain: Some("edge-1/module-a;edge-1;this_edge".to_string()), + }, + ]; + let authorizer = authorizer(DenyAll, identities); + + let auth = authorizer.authorize(&activity); + + assert_matches!(auth, Ok(Authorization::Allowed)); + } + + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/messages/events"); "edge module events")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/messages/c2d/post"); "edge module C2D messages")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/desired"); "edge module update desired properties")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/reported"); "edge module update reported properties")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/get"); "edge module twin request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/methods/post"); "module DM request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/methods/res"); "module DM response")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/messages/events"); "iothub telemetry with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/messages/c2d/post"); "iothub c2d messages with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/desired"); "iothub update desired properties with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/reported"); "iothub update reported properties with moduleId")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/get"); "iothub module twin request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/res"); "iothub module twin response")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/methods/post"); "iothub module DM request")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/methods/res"); "iothub module DM response")] + fn it_allows_to_subscribe_for_on_behalf_of_leaf(activity: &Activity) { + let identities = vec![ + // child edgehub + IdentityUpdate { + identity: "edge-1/$edgeHub".to_string(), + auth_chain: Some("edge-1/$edgeHub;this_edge".to_string()), + }, + // grandchild leaf + IdentityUpdate { + identity: "leaf-2".to_string(), + auth_chain: Some("leaf-2;edge-1;this_edge".to_string()), + }, + ]; + let authorizer = authorizer(DenyAll, identities); + + let auth = authorizer.authorize(&activity); + + assert_matches!(auth, Ok(Authorization::Allowed)); + } + + #[test_case(&tests::connect_activity("edge-1/$edgeHub", "myhub.azure-devices.net/edge-1/$edgeHub"); 
"module identical auth_id and client_id")] + #[test_case(&tests::connect_activity("edge-1", "myhub.azure-devices.net/edge-1"); "leaf identical auth_id and client_id")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "topic"); "module generic MQTT topic publish")] + #[test_case(&tests::publish_activity("edge-1", "edge-1", "topic"); "leaf generic MQTT topic publish")] + #[test_case(&tests::subscribe_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "topic"); "module generic MQTT topic subscribe")] + #[test_case(&tests::subscribe_activity("edge-1", "edge-1", "topic"); "leaf generic MQTT topic subscribe")] fn it_delegates_to_inner(activity: &Activity) { let inner = authorize_fn_ok(|_| Authorization::Forbidden("not allowed inner".to_string())); - let authorizer = EdgeHubAuthorizer::without_ready_handle(inner); + let authorizer = + EdgeHubAuthorizer::without_ready_handle(inner, "edgehub_id", "myhub.azure-devices.net"); let auth = authorizer.authorize(&activity); @@ -419,92 +474,185 @@ mod tests { assert_matches!(auth, Ok(auth) if auth == Authorization::Forbidden("not allowed inner".to_string())); } - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/some/topic"); "arbitrary edgehub prefixed topic")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/some/topic"); "arbitrary iothub prefixed topic")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/#"); "everything with edgehub prefixed")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$iothub/#"); "everything with iothub prefixed")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-2/twin/get"); "twin request for another device")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/+/twin/get"); "twin request for arbitrary device")] - #[test_case(&tests::subscribe_activity("device-1", "device-1", "$edgehub/device-1/twin/+"); "both twin operations")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/some/topic"); "arbitrary edgehub prefixed topic")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/some/topic"); "arbitrary iothub prefixed topic")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/#"); "everything with edgehub prefixed")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/#"); "everything with iothub prefixed")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-2/twin/get"); "twin request for another device")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/+/twin/get"); "twin request for arbitrary device")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/edge-1/twin/+"); "both twin operations")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/leaf-2/twin/get"); "iothub twin request for another device")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/+/twin/get"); "iothub twin request for arbitrary device")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$iothub/edge-1/twin/+"); "iothub both twin operations")] fn iothub_primitives_overridden_by_inner(activity: &Activity) { // these primitives must be denied, but overridden by AllowAll - let authorizer = authorizer(AllowAll); + let authorizer = authorizer(AllowAll, vec![]); let auth = authorizer.authorize(&activity); assert_matches!(auth, Ok(Authorization::Allowed)); } - 
#[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/messages/events"); "device events")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/messages/events"); "edge module events")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/messages/c2d/post"); "device C2D messages")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/messages/c2d/post"); "edge module C2D messages")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/twin/desired"); "device update desired properties")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/desired"); "edge module update desired properties")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/twin/reported"); "device update reported properties")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/reported"); "edge module update reported properties")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/twin/get"); "device twin request")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/get"); "edge module twin request")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$edgehub/device-1/twin/res"); "device twin response")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/twin/res"); "edge module twin response")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/outputs/route1"); "edge module access M2M outputs")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$edgehub/device-1/module-a/inputs/route1"); "edge module access M2M inputs")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/messages/events"); "iothub telemetry")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/messages/events"); "iothub telemetry with moduleId")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/messages/c2d/post"); "iothub c2d messages")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/messages/c2d/post"); "iothub c2d messages with moduleId")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/twin/patch/properties/desired"); "iothub update desired properties")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/patch/properties/desired"); "iothub update desired properties with moduleId")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/twin/patch/properties/reported"); "iothub update reported properties")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/patch/properties/reported"); "iothub update reported properties with moduleId")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/twin/get"); "iothub device twin request")] - 
#[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/get"); "iothub module twin request")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/twin/res"); "iothub device twin response")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/twin/res"); "iothub module twin response")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/methods/post"); "iothub device DM request")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/methods/post"); "iothub module DM request")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$iothub/clients/device-1/methods/res"); "iothub device DM response")] - #[test_case(&tests::publish_activity("device-1/module-a", "device-1/module-a", "$iothub/clients/device-1/modules/module-a/methods/res"); "iothub module DM response")] - fn it_allows_to_publish_to(activity: &Activity) { - let authorizer = authorizer(DenyAll); + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/events"); "device events")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/c2d/post"); "device C2D messages")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/desired"); "device update desired properties")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/reported"); "device update reported properties")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/get"); "device twin request")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/twin/res"); "device twin response")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/methods/post"); "device DM request")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/methods/res"); "device DM response")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/messages/events"); "iothub telemetry")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/messages/c2d/post"); "iothub c2d messages")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/desired"); "iothub update desired properties")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/reported"); "iothub update reported properties")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/get"); "iothub device twin request")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/twin/res"); "iothub device twin response")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/methods/post"); "iothub device DM request")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$iothub/leaf-1/methods/res"); "iothub device DM response")] + fn it_allows_to_publish_for_leaf(activity: &Activity) { + let identities = vec![IdentityUpdate { + identity: "leaf-1".to_string(), + auth_chain: Some("leaf-1;this_edge".to_string()), + }]; + let authorizer = authorizer(DenyAll, identities); let auth = authorizer.authorize(&activity); assert_matches!(auth, Ok(Authorization::Allowed)); } - #[test_case(&tests::publish_activity("device-1", "device-1", "$downstream/some/topic"); "any downstream prefixed topics")] - 
#[test_case(&tests::publish_activity("device-1", "device-1", "$upstream/some/topic"); "any upstream prefixed topics")] - #[test_case(&tests::publish_activity("device-1", "device-2", "$edgehub/device-1/twin/get"); "twin request for another client")] - #[test_case(&tests::publish_activity("device-1", AuthId::Anonymous, "$edgehub/device-1/twin/get"); "twin request by anonymous client")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$SYS/foo"); "any system topic")] - #[test_case(&tests::publish_activity("device-1", "device-1", "$CUSTOM/foo"); "any special topic")] - fn it_forbids_to_publish_to(activity: &Activity) { - let authorizer = authorizer(AllowAll); + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/messages/events"); "device events")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/messages/c2d/post"); "device C2D messages")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/desired"); "device update desired properties")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/reported"); "device update reported properties")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/get"); "device twin request")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/twin/res"); "device twin response")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/methods/post"); "device DM request")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$edgehub/this_edge/module-a/methods/res"); "device DM response")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/messages/events"); "iothub telemetry")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/messages/c2d/post"); "iothub c2d messages")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/desired"); "iothub update desired properties")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/reported"); "iothub update reported properties")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/get"); "iothub device twin request")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/twin/res"); "iothub device twin response")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/methods/post"); "iothub device DM request")] + #[test_case(&tests::publish_activity("this_edge/module-a", "this_edge/module-a", "$iothub/this_edge/module-a/methods/res"); "iothub device DM response")] + fn it_allows_to_publish_for_module(activity: &Activity) { + let identities = vec![IdentityUpdate { + identity: "this_edge/module-a".to_string(), + auth_chain: Some("this_edge/module-a;this_edge".to_string()), + }]; + let authorizer = authorizer(DenyAll, identities); + + let auth = authorizer.authorize(&activity); + + assert_matches!(auth, Ok(Authorization::Allowed)); + } + + 
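For reference, a stripped-down, runnable sketch of the named-capture extraction that get_actor_id performs (the pattern below is deliberately simplified to a single twin/get suffix and assumes the regex 1.x crate; the real TOPIC_PATTERN alternates over the full operation list):

use regex::Regex;

fn actor_id(topic: &str) -> Option<String> {
    // simplified: only twin/get; device id is mandatory, module id is optional
    let re = Regex::new(
        r"^\$(?:edgehub|iothub)/(?P<device_id>[^/+#]+)(?:/(?P<module_id>[^/+#]+))?/twin/get$",
    )
    .ok()?;
    let caps = re.captures(topic)?;
    match (caps.name("device_id"), caps.name("module_id")) {
        (Some(d), Some(m)) => Some(format!("{}/{}", d.as_str(), m.as_str())),
        (Some(d), None) => Some(d.as_str().to_string()),
        _ => None,
    }
}

fn main() {
    assert_eq!(actor_id("$edgehub/edge-1/module-a/twin/get").as_deref(), Some("edge-1/module-a"));
    assert_eq!(actor_id("$iothub/leaf-1/twin/get").as_deref(), Some("leaf-1"));
    assert_eq!(actor_id("not/an/iothub/topic"), None);
}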
#[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/messages/events"); "edge module events")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/messages/c2d/post"); "edge module C2D messages")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/desired"); "edge module update desired properties")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/reported"); "edge module update reported properties")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/get"); "edge module twin request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/twin/res"); "edge module twin response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/methods/post"); "module DM request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/leaf-2/methods/res"); "module DM response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/messages/events"); "iothub telemetry with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/messages/c2d/post"); "iothub c2d messages with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/desired"); "iothub update desired properties with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/reported"); "iothub update reported properties with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/get"); "iothub module twin request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/twin/res"); "iothub module twin response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/methods/post"); "iothub module DM request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/leaf-2/methods/res"); "iothub module DM response")] + fn it_allows_to_publish_for_on_behalf_of_leaf(activity: &Activity) { + let identities = vec![ + // child edgehub + IdentityUpdate { + identity: "edge-1/$edgeHub".to_string(), + auth_chain: Some("edge-1/$edgeHub;this_edge".to_string()), + }, + // grandchild leaf + IdentityUpdate { + identity: "leaf-2".to_string(), + auth_chain: Some("leaf-2;edge-1;this_edge".to_string()), + }, + ]; + let authorizer = authorizer(DenyAll, identities); + + let auth = authorizer.authorize(&activity); + + assert_matches!(auth, Ok(Authorization::Allowed)); + } + + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/messages/events"); "edge module events")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/messages/c2d/post"); "edge module C2D messages")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/desired"); "edge module update desired properties")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/reported"); "edge module update reported properties")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", 
"$edgehub/edge-1/module-a/twin/get"); "edge module twin request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/twin/res"); "edge module twin response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/methods/post"); "module DM request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$edgehub/edge-1/module-a/methods/res"); "module DM response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/messages/events"); "iothub telemetry with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/messages/c2d/post"); "iothub c2d messages with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/desired"); "iothub update desired properties with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/reported"); "iothub update reported properties with moduleId")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/get"); "iothub module twin request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/twin/res"); "iothub module twin response")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/methods/post"); "iothub module DM request")] + #[test_case(&tests::publish_activity("edge-1/$edgeHub", "edge-1/$edgeHub", "$iothub/edge-1/module-a/methods/res"); "iothub module DM response")] + fn it_allows_to_publish_for_on_behalf_of_module(activity: &Activity) { + let identities = vec![ + // child edgehub + IdentityUpdate { + identity: "edge-1/$edgeHub".to_string(), + auth_chain: Some("edge-1/$edgeHub;this_edge".to_string()), + }, + // grandchild module + IdentityUpdate { + identity: "edge-1/module-a".to_string(), + auth_chain: Some("edge-1/module-a;edge-1;this_edge".to_string()), + }, + ]; + let authorizer = authorizer(DenyAll, identities); + + let auth = authorizer.authorize(&activity); + + assert_matches!(auth, Ok(Authorization::Allowed)); + } + + #[test_case(&tests::subscribe_activity("edge-1/module-a", "edge-1/module-a", "$edgehub/edge-1/module-a/messages/events"); "module events sub")] + #[test_case(&tests::publish_activity("edge-1/module-a", "edge-1/module-a", "$edgehub/edge-1/module-a/messages/events"); "module events pub")] + #[test_case(&tests::subscribe_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/events"); "leaf events sub")] + #[test_case(&tests::publish_activity("leaf-1", "leaf-1", "$edgehub/leaf-1/messages/events"); "leaf events pub")] + fn it_forbids_operations_for_not_in_scope_identities(activity: &Activity) { + let identities = vec![ + // leaf + IdentityUpdate { + identity: "another-leaf".to_string(), + auth_chain: Some("another-leaf;this_edge".to_string()), + }, + // module + IdentityUpdate { + identity: "this_edge/another-module".to_string(), + auth_chain: Some("edge-1/another-module;this_edge".to_string()), + }, + ]; + let authorizer = authorizer(DenyAll, identities); let auth = authorizer.authorize(&activity); assert_matches!(auth, Ok(Authorization::Forbidden(_))); } - fn authorizer(inner: Z) -> EdgeHubAuthorizer + fn authorizer(inner: Z, identities: Vec) -> EdgeHubAuthorizer where Z: Authorizer, { - let mut authorizer = 
EdgeHubAuthorizer::without_ready_handle(inner); + let mut authorizer = + EdgeHubAuthorizer::without_ready_handle(inner, "this_edge", "myhub.azure-devices.net"); - let service_identity = IdentityUpdate { - identity: "device-1".to_string(), - auth_chain: Some("edgeB;device-1;".to_string()), - }; - let service_identity2 = IdentityUpdate { - identity: "device-1/module-a".to_string(), - auth_chain: Some("edgeB;device-1/module-a;".to_string()), - }; - let _ = authorizer.update(Box::new(AuthorizerUpdate(vec![ - service_identity, - service_identity2, - ]))); + let _ = authorizer.update(Box::new(AuthorizerUpdate(identities))); authorizer } } diff --git a/mqtt/mqtt-edgehub/src/auth/authorization/local.rs b/mqtt/mqtt-edgehub/src/auth/authorization/local.rs index 75dad02b21f..e9993623665 100644 --- a/mqtt/mqtt-edgehub/src/auth/authorization/local.rs +++ b/mqtt/mqtt-edgehub/src/auth/authorization/local.rs @@ -44,8 +44,6 @@ where #[cfg(test)] mod tests { - use std::time::Duration; - use matches::assert_matches; use test_case::test_case; @@ -81,18 +79,7 @@ mod tests { } fn connect_activity(peer_addr: &str) -> Activity { - let client_id = proto::ClientId::IdWithCleanSession("local-client".into()); - let connect = proto::Connect { - username: None, - password: None, - will: None, - client_id, - keep_alive: Duration::from_secs(1), - protocol_name: mqtt3::PROTOCOL_NAME.to_string(), - protocol_level: mqtt3::PROTOCOL_LEVEL, - }; - - let operation = Operation::new_connect(connect); + let operation = Operation::new_connect(); activity("client_id".into(), operation, peer_addr) } diff --git a/mqtt/mqtt-edgehub/src/auth/authorization/mod.rs b/mqtt/mqtt-edgehub/src/auth/authorization/mod.rs index 77c0bdfe20d..4f11333401a 100644 --- a/mqtt/mqtt-edgehub/src/auth/authorization/mod.rs +++ b/mqtt/mqtt-edgehub/src/auth/authorization/mod.rs @@ -8,23 +8,11 @@ pub use local::LocalAuthorizer; #[cfg(test)] mod tests { - use std::time::Duration; - use mqtt3::proto; use mqtt_broker::{auth::Activity, auth::Operation, AuthId, ClientInfo}; pub(crate) fn connect_activity(client_id: &str, auth_id: impl Into) -> Activity { - let connect = proto::Connect { - username: None, - password: None, - will: None, - client_id: proto::ClientId::IdWithCleanSession(client_id.into()), - keep_alive: Duration::from_secs(1), - protocol_name: mqtt3::PROTOCOL_NAME.to_string(), - protocol_level: mqtt3::PROTOCOL_LEVEL, - }; - - let operation = Operation::new_connect(connect); + let operation = Operation::new_connect(); activity(operation, client_id, auth_id) } diff --git a/mqtt/mqtt-edgehub/src/auth/authorization/policy.rs b/mqtt/mqtt-edgehub/src/auth/authorization/policy.rs index a3c3bca8ea8..9544490ddb7 100644 --- a/mqtt/mqtt-edgehub/src/auth/authorization/policy.rs +++ b/mqtt/mqtt-edgehub/src/auth/authorization/policy.rs @@ -86,6 +86,14 @@ pub struct PolicyUpdate { definition: String, } +impl PolicyUpdate { + pub fn new(definition: impl Into) -> Self { + Self { + definition: definition.into(), + } + } +} + #[derive(Debug, Error)] pub enum Error { #[error("An error occurred authorizing the request: {0}")] @@ -117,7 +125,7 @@ fn identity(activity: &Activity) -> &str { fn operation(activity: &Activity) -> &str { match activity.operation() { - Operation::Connect(_) => "mqtt:connect", + Operation::Connect => "mqtt:connect", Operation::Publish(_) => "mqtt:publish", Operation::Subscribe(_) => "mqtt:subscribe", } @@ -126,7 +134,7 @@ fn operation(activity: &Activity) -> &str { fn resource(activity: &Activity) -> &str { match activity.operation() { // this is 
intentional. mqtt:connect should have empty resource. - Operation::Connect(_) => "", + Operation::Connect => "", Operation::Publish(publish) => publish.publication().topic_name(), Operation::Subscribe(subscribe) => subscribe.topic_filter(), } diff --git a/mqtt/mqtt-edgehub/src/command/bridge_update.rs b/mqtt/mqtt-edgehub/src/command/bridge_update.rs index 0be2fadf4c4..c40ee663c5a 100644 --- a/mqtt/mqtt-edgehub/src/command/bridge_update.rs +++ b/mqtt/mqtt-edgehub/src/command/bridge_update.rs @@ -15,10 +15,8 @@ pub struct BridgeUpdateCommand { } impl BridgeUpdateCommand { - pub fn new(controller_handle: &BridgeControllerHandle) -> Self { - Self { - controller_handle: controller_handle.clone(), - } + pub fn new(controller_handle: BridgeControllerHandle) -> Self { + Self { controller_handle } } } @@ -35,7 +33,7 @@ impl Command for BridgeUpdateCommand { serde_json::from_slice(&publication.payload).map_err(Error::ParseBridgeUpdate)?; self.controller_handle - .send(update) + .send_update(update) .map_err(Error::SendBridgeUpdate)?; Ok(()) } diff --git a/mqtt/mqtt-edgehub/src/command/handler.rs b/mqtt/mqtt-edgehub/src/command/handler.rs index 7c3f9af637d..4395368de93 100644 --- a/mqtt/mqtt-edgehub/src/command/handler.rs +++ b/mqtt/mqtt-edgehub/src/command/handler.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, collections::HashSet, error::Error as StdError, time::Duration}; +use async_trait::async_trait; use futures_util::future::BoxFuture; use tokio::{net::TcpStream, stream::StreamExt}; use tracing::{debug, error, info}; @@ -7,6 +8,7 @@ use tracing::{debug, error, info}; use mqtt3::{ proto, Client, Event, IoSource, ShutdownError, SubscriptionUpdateEvent, UpdateSubscriptionError, }; +use mqtt_broker::sidecar::{Sidecar, SidecarShutdownHandle, SidecarShutdownHandleError}; use crate::command::{Command, DynCommand}; @@ -86,21 +88,46 @@ impl CommandHandler { } } - // TODO refactor and move it inside the [`run`] method - pub async fn init(&mut self) -> Result<(), CommandHandlerError> { - info!("initializing command handler..."); - let topics: Vec<_> = self.commands.keys().map(String::as_str).collect(); - subscribe(&mut self.client, &topics).await?; - Ok(()) - } - pub fn shutdown_handle(&self) -> Result { Ok(ShutdownHandle { client_shutdown: self.client.shutdown_handle()?, }) } - pub async fn run(mut self) { + async fn handle_event(&mut self, event: Event) -> Result<(), Box> { + if let Event::Publication(publication) = event { + if let Some(command) = self.commands.get_mut(&publication.topic_name) { + command.handle(&publication)?; + } + } + Ok(()) + } +} + +#[async_trait] +impl Sidecar for CommandHandler { + fn shutdown_handle(&self) -> Result { + let mut handle = self + .client + .shutdown_handle() + .map_err(|e| SidecarShutdownHandleError(Box::new(e)))?; + + let shutdown = async move { + if let Err(e) = handle.shutdown().await { + error!(error = %e, "unable to request shutdown for command handler"); + } + }; + + Ok(SidecarShutdownHandle::new(shutdown)) + } + + async fn run(mut self: Box) { + let topics = self.commands.keys().map(Clone::clone).collect(); + // TODO percolate error instead + if let Err(e) = subscribe(&mut self.client, topics).await { + error!(error = %e, "unable to subscribe to all required topics"); + } + info!("starting command handler..."); loop { @@ -122,29 +149,19 @@ impl CommandHandler { debug!("command handler stopped"); } - - async fn handle_event(&mut self, event: Event) -> Result<(), Box> { - if let Event::Publication(publication) = event { - if let Some(command) = 
self.commands.get_mut(&publication.topic_name) { - command.handle(&publication)?; - } - } - Ok(()) - } } async fn subscribe( - client: &mut mqtt3::Client, - topics: &[&str], + client: &mut Client, + topics: Vec, ) -> Result<(), CommandHandlerError> { - debug!( - "command handler subscribing to topics: {}", - topics.join(", ") - ); + debug!("command handler subscribing to topics: {:?}", topics); + + let mut subacks: HashSet<_> = topics.iter().map(ToString::to_string).collect(); for topic in topics { let subscription = proto::SubscribeTo { - topic_filter: (*topic).to_string(), + topic_filter: topic, qos: proto::QoS::AtLeastOnce, }; @@ -153,8 +170,6 @@ async fn subscribe( .map_err(CommandHandlerError::SubscribeFailure)?; } - let mut subacks: HashSet<_> = topics.iter().map(ToString::to_string).collect(); - while let Some(event) = client .try_next() .await @@ -180,7 +195,10 @@ async fn subscribe( } } - error!("command handler failed to subscribe to disconnect topic"); + error!( + "command handler failed to subscribe to the following topics {:?}", + subacks + ); Err(CommandHandlerError::MissingSubacks( subacks.into_iter().collect::>(), )) diff --git a/mqtt/mqtt-edgehub/src/command/mod.rs b/mqtt/mqtt-edgehub/src/command/mod.rs index a5002e0c007..f1090e2fb40 100644 --- a/mqtt/mqtt-edgehub/src/command/mod.rs +++ b/mqtt/mqtt-edgehub/src/command/mod.rs @@ -15,6 +15,7 @@ use std::error::Error as StdError; use mqtt3::ReceivedPublication; pub const AUTHORIZED_IDENTITIES_TOPIC: &str = "$internal/identities"; +pub const POLICY_UPDATE_TOPIC: &str = "$internal/authorization/policy"; pub const DISCONNECT_TOPIC: &str = "$edgehub/disconnect"; /// A command trait to be implemented and used with `CommandHandler`. diff --git a/mqtt/mqtt-edgehub/src/command/policy_update.rs b/mqtt/mqtt-edgehub/src/command/policy_update.rs index 6c772b0ec07..fd609e512c1 100644 --- a/mqtt/mqtt-edgehub/src/command/policy_update.rs +++ b/mqtt/mqtt-edgehub/src/command/policy_update.rs @@ -5,8 +5,6 @@ use mqtt_broker::{BrokerHandle, Message, SystemEvent}; use crate::{auth::PolicyUpdate, command::Command}; -const POLICY_UPDATE_TOPIC: &str = "$internal/authorization/policy"; - /// `PolicyUpdateCommand` is executed when `EdgeHub` sends a special packet /// to notify the broker that the customer authorization policy has changed, /// and that we need to update the authorizer in the broker. @@ -26,7 +24,7 @@ impl Command for PolicyUpdateCommand { type Error = Error; fn topic(&self) -> &str { - POLICY_UPDATE_TOPIC + super::POLICY_UPDATE_TOPIC } fn handle(&mut self, publication: &ReceivedPublication) -> Result<(), Self::Error> { diff --git a/mqtt/mqtt-edgehub/src/connection/delivery.rs b/mqtt/mqtt-edgehub/src/connection/delivery.rs index b8aaa4afe09..9182cc603ea 100644 --- a/mqtt/mqtt-edgehub/src/connection/delivery.rs +++ b/mqtt/mqtt-edgehub/src/connection/delivery.rs @@ -69,11 +69,11 @@ impl

<P> PublicationDelivery<P>

{ } fn match_m2m_publish(packet: &Packet) -> Option<(proto::PacketIdentifier, String)> { - const ANYTHING_BUT_NOT_SLASH: &str = r"[^/]+"; + const ANYTHING_BUT_SLASH: &str = r"[^/]+"; lazy_static! { static ref M2M_PUBLISH_PATTERN: Regex = Regex::new(&format!( - "\\$edgehub/{}/{}/inputs/.+", - ANYTHING_BUT_NOT_SLASH, ANYTHING_BUT_NOT_SLASH + "\\$edgehub/{}/{}/{}/inputs/.+", + ANYTHING_BUT_SLASH, ANYTHING_BUT_SLASH, ANYTHING_BUT_SLASH )) .expect("failed to create new Regex from pattern"); } diff --git a/mqtt/mqtt-edgehub/src/settings.rs b/mqtt/mqtt-edgehub/src/settings.rs index b60d9bf55d3..1a8e2198f44 100644 --- a/mqtt/mqtt-edgehub/src/settings.rs +++ b/mqtt/mqtt-edgehub/src/settings.rs @@ -1,4 +1,7 @@ -use std::path::{Path, PathBuf}; +use std::{ + env, + path::{Path, PathBuf}, +}; use config::{Config, ConfigError, Environment, File, FileFormat}; use lazy_static::lazy_static; @@ -33,6 +36,8 @@ pub struct Settings { impl Settings { pub fn new() -> Result { + convert_to_old_env_variable(); + let mut config = Config::new(); config.merge(File::from_str(DEFAULTS, FileFormat::Json))?; config.merge(Environment::new().separator("__"))?; @@ -44,6 +49,8 @@ impl Settings { where P: AsRef, { + convert_to_old_env_variable(); + let mut config = Config::new(); config.merge(File::from_str(DEFAULTS, FileFormat::Json))?; config.merge(File::from(path.as_ref()))?; @@ -71,6 +78,24 @@ impl Default for Settings { } } +fn convert_to_old_env_variable() { + if let Ok(val) = env::var("mqttBroker__max_queued_messages") { + env::set_var("BROKER__SESSION__max_queued_messages", val); + } + + if let Ok(val) = env::var("mqttBroker__max_queued_bytes") { + env::set_var("BROKER__SESSION__max_queued_size", val); + } + + if let Ok(val) = env::var("mqttBroker__max_inflight_messages") { + env::set_var("BROKER__SESSION__max_inflight_messages", val); + } + + if let Ok(val) = env::var("mqttBroker__when_full") { + env::set_var("BROKER__SESSION__when_full", val); + } +} + #[derive(Debug, Clone, PartialEq, Deserialize)] #[serde(rename_all = "snake_case")] pub struct ListenerConfig { @@ -216,6 +241,29 @@ mod tests { const DAYS: u64 = 24 * 60 * 60; + #[test] + #[serial(env_settings)] + fn check_env_var_name_override() { + let _max_inflight_messages = env::set_var("mqttBroker__max_inflight_messages", "17"); + let _max_queued_messages = env::set_var("mqttBroker__max_queued_messages", "1001"); + let _max_queued_bytes = env::set_var("mqttBroker__max_queued_bytes", "1"); + let _when_full = env::set_var("mqttBroker__when_full", "drop_old"); + + let settings = Settings::new().unwrap(); + + assert_eq!( + settings.broker().session(), + &SessionConfig::new( + Duration::from_secs(60 * DAYS), + Some(HumanSize::new_kilobytes(256).expect("256kb")), + 17, + 1001, + Some(HumanSize::new_bytes(1)), + QueueFullAction::DropOld, + ) + ); + } + #[test] fn it_loads_defaults() { let settings = Settings::default(); diff --git a/mqtt/mqtt-edgehub/src/topic/translation.rs b/mqtt/mqtt-edgehub/src/topic/translation.rs index 1980179dcda..2d64bd5a49c 100644 --- a/mqtt/mqtt-edgehub/src/topic/translation.rs +++ b/mqtt/mqtt-edgehub/src/topic/translation.rs @@ -262,11 +262,11 @@ translate_c2d! 
{ // Module-to-Module inputs module_to_module_inputs { to_internal { - format!("devices/{}/modules/{}/(?P.+)", DEVICE_ID, MODULE_ID), - {|captures: regex::Captures<'_>, _| format!("$edgehub/{}/{}/inputs/{}", &captures["device_id"], &captures["module_id"], &captures["path"])} + format!("devices/{}/modules/{}/#", DEVICE_ID, MODULE_ID), + {|captures: regex::Captures<'_>, _| format!("$edgehub/{}/{}/+/inputs/#", &captures["device_id"], &captures["module_id"])} }, to_external { - format!("\\$edgehub/{}/{}/inputs/(?P.+)", DEVICE_ID, MODULE_ID), + format!("\\$edgehub/{}/{}/[^/]+/inputs/(?P.+)", DEVICE_ID, MODULE_ID), {|captures: regex::Captures<'_>| format!("devices/{}/modules/{}/inputs/{}", &captures["device_id"], &captures["module_id"], &captures["path"])} } } @@ -508,20 +508,13 @@ mod tests { // M2M subscription assert_eq!( c2d.to_internal("devices/device_1/modules/module_a/#", &client_id), - Some("$edgehub/device_1/module_a/inputs/#".to_owned()) - ); - assert_eq!( - c2d.to_internal( - "devices/device_1/modules/module_a/telemetry/?rid=1", - &client_id - ), - Some("$edgehub/device_1/module_a/inputs/telemetry/?rid=1".to_owned()) + Some("$edgehub/device_1/module_a/+/inputs/#".to_owned()) ); // M2M incoming assert_eq!( c2d.to_external( - "$edgehub/device_1/module_a/inputs/route_1/%24.cdid=device_1&%24.cmid=module_a" + "$edgehub/device_1/module_a/b9aa0940-dcf2-457f-83a4-45f4c7ceecf9/inputs/route_1/%24.cdid=device_1&%24.cmid=module_a" ), Some( "devices/device_1/modules/module_a/inputs/route_1/%24.cdid=device_1&%24.cmid=module_a" diff --git a/mqtt/mqtt-edgehub/tests/common/mod.rs b/mqtt/mqtt-edgehub/tests/common/mod.rs index edfed786a0e..b5c2c7b7677 100644 --- a/mqtt/mqtt-edgehub/tests/common/mod.rs +++ b/mqtt/mqtt-edgehub/tests/common/mod.rs @@ -1,25 +1,50 @@ +#![allow(dead_code)] use std::{any::Any, error::Error as StdError}; -use tokio::task::JoinHandle; +use tokio::{ + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, +}; use mqtt3::ShutdownError; -use mqtt_broker::auth::{Activity, Authorization, Authorizer}; +use mqtt_broker::{ + auth::{Activity, Authorization, Authorizer}, + sidecar::Sidecar, +}; use mqtt_edgehub::command::{Command, CommandHandler, ShutdownHandle}; pub const LOCAL_BROKER_SUFFIX: &str = "$edgeHub/$broker"; -// We need a Dummy Authorizer to authorize the command handler and $edgehub -// LocalAuthorizer currently wraps EdgeHubAuthorizer in production code, +// We need a `DummyAuthorizer` to authorize the command handler and $edgehub. +// +// LocalAuthorizer currently wraps `EdgeHubAuthorizer` and `PolicyAuthorizer` in production code, // but LocalAuthorizer would authorize everything in the case of an integ test. -pub struct DummyAuthorizer(Z); +// +// In addition, `DummyAuthorizer` provides a way to signal when authorizer receives updates. +pub struct DummyAuthorizer { + inner: Z, + receiver: Option>, + sender: UnboundedSender<()>, +} impl DummyAuthorizer where Z: Authorizer, { - #![allow(dead_code)] - pub fn new(authorizer: Z) -> Self { - Self(authorizer) + pub fn new(inner: Z) -> Self { + let (sender, receiver) = mpsc::unbounded_channel(); + Self { + inner, + receiver: Some(receiver), + sender, + } + } + + /// A receiver that signals when authorizer update has happened. 
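/// Note: the receiver can be taken only once; a second call to the
/// method below panics (see the `expect` message).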
+ pub fn update_signal(&mut self) -> UnboundedReceiver<()> { + self.receiver + .take() + .expect("You can get only one receiver instance") } } @@ -39,12 +64,14 @@ where { Ok(Authorization::Allowed) } else { - self.0.authorize(activity) + self.inner.authorize(activity) } } fn update(&mut self, update: Box) -> Result<(), Self::Error> { - self.0.update(update) + self.inner.update(update)?; + self.sender.send(()).expect("unable to send update signal"); + Ok(()) } } @@ -56,11 +83,9 @@ where C: Command + Send + 'static, E: StdError + 'static, { - let mut command_handler = CommandHandler::new(system_address, "test-device"); + let mut command_handler = Box::new(CommandHandler::new(system_address, "test-device")); command_handler.add_command(command); - command_handler.init().await.unwrap(); - let shutdown_handle: ShutdownHandle = command_handler.shutdown_handle().unwrap(); let join_handle = tokio::spawn(command_handler.run()); diff --git a/mqtt/mqtt-edgehub/tests/delivery.rs b/mqtt/mqtt-edgehub/tests/delivery.rs index 272286429ac..a38d88907ec 100644 --- a/mqtt/mqtt-edgehub/tests/delivery.rs +++ b/mqtt/mqtt-edgehub/tests/delivery.rs @@ -38,7 +38,7 @@ async fn it_sends_delivery_confirmation_for_m2m_messages() { .build(); // subscribe to module inputs - let inputs = "devices/device-1/modules/module-1/telemetry/#"; + let inputs = "devices/device-1/modules/module-1/#"; module.subscribe(inputs, QoS::AtLeastOnce).await; let mut edgehub = TestClientBuilder::new(server_handle.address()) @@ -50,13 +50,14 @@ async fn it_sends_delivery_confirmation_for_m2m_messages() { edgehub.subscribe(confirmation, QoS::AtLeastOnce).await; // publish a message to module-1 - let inputs = "$edgehub/device-1/module-1/inputs/telemetry/?rid=1"; + let inputs = "$edgehub/device-1/module-1/c1906616-e64f-4cf0-96eb-33a40a2535c3/inputs/telemetry/%24.uid=something"; edgehub.publish_qos1(inputs, "message", false).await; assert_eq!( module.publications().next().await, Some(ReceivedPublication { - topic_name: "devices/device-1/modules/module-1/inputs/telemetry/?rid=1".into(), + topic_name: "devices/device-1/modules/module-1/inputs/telemetry/%24.uid=something" + .into(), dup: false, qos: QoS::AtLeastOnce, retain: false, @@ -71,7 +72,7 @@ async fn it_sends_delivery_confirmation_for_m2m_messages() { dup: false, qos: QoS::AtLeastOnce, retain: false, - payload: "\"$edgehub/device-1/module-1/inputs/telemetry/?rid=1\"".into() + payload: "\"$edgehub/device-1/module-1/c1906616-e64f-4cf0-96eb-33a40a2535c3/inputs/telemetry/%24.uid=something\"".into() }) ); diff --git a/mqtt/mqtt-edgehub/tests/authorization.rs b/mqtt/mqtt-edgehub/tests/edgehub_authorizer.rs similarity index 63% rename from mqtt/mqtt-edgehub/tests/authorization.rs rename to mqtt/mqtt-edgehub/tests/edgehub_authorizer.rs index 09961848f0b..d2b1a006a7b 100644 --- a/mqtt/mqtt-edgehub/tests/authorization.rs +++ b/mqtt/mqtt-edgehub/tests/edgehub_authorizer.rs @@ -1,12 +1,15 @@ use assert_matches::assert_matches; +use bytes::Bytes; use futures_util::StreamExt; -use mqtt_broker::{auth::authorize_fn_ok, BrokerReady}; -use mqtt3::{ - proto::ClientId, proto::ConnectReturnCode, proto::Packet, proto::PacketIdentifier, proto::QoS, - proto::SubAckQos, proto::Subscribe, proto::SubscribeTo, +use mqtt3::proto::{ + ClientId, ConnectReturnCode, Packet, PacketIdentifier, PacketIdentifierDupQoS, Publish, QoS, + SubAckQos, Subscribe, SubscribeTo, +}; +use mqtt_broker::{ + auth::{authorize_fn_ok, Authorization, Authorizer, Operation}, + BrokerBuilder, }; -use mqtt_broker::{auth::Authorization, 
auth::Operation, BrokerBuilder}; use mqtt_broker_tests_util::{ client::TestClientBuilder, packet_stream::PacketStream, @@ -24,33 +27,18 @@ use common::DummyAuthorizer; /// create broker /// create command handler /// connect authorized client -/// verify client has not been connected, since authorized identities haven't been sent +/// verify client can't publish or subscribe, since identities haven't been sent #[tokio::test] -async fn publish_not_allowed_identity_not_in_cache() { +async fn pub_sub_not_allowed_identity_not_in_cache() { // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, // but otherwise passes authorization along to EdgeHubAuthorizer let broker = BrokerBuilder::default() - .with_authorizer(DummyAuthorizer::new(EdgeHubAuthorizer::new( - authorize_fn_ok(|activity| { - if matches!(activity.operation(), Operation::Connect(_)) { - Authorization::Allowed - } else { - Authorization::Forbidden("not allowed".to_string()) - } - }), - BrokerReady::new().handle(), - ))) + .with_authorizer(authorizer()) .build(); - let broker_handle = broker.handle(); - - let server_handle = start_server(broker, DummyAuthenticator::with_id("device-1")); - - // start command handler with AuthorizedIdentitiesCommand - let command = AuthorizedIdentitiesCommand::new(&broker_handle); - let (command_handler_shutdown_handle, join_handle) = - common::start_command_handler(server_handle.address(), command) - .await - .expect("could not start command handler"); + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); let mut device_client = PacketStream::connect( ClientId::IdWithCleanSession("device-1".into()), @@ -63,29 +51,39 @@ async fn publish_not_allowed_identity_not_in_cache() { // We should be able to connect because inner authorizer allows connects assert_matches!(device_client.next().await, Some(Packet::ConnAck(c)) if c.return_code == ConnectReturnCode::Accepted); - let s = Subscribe { - packet_identifier: PacketIdentifier::new(1).unwrap(), - subscribe_to: vec![SubscribeTo { - // We need to use a post-translation topic here - topic_filter: "$edgehub/device-1/inputs/telemetry/#".into(), - qos: QoS::AtLeastOnce, - }], - }; - - device_client.send_subscribe(s).await; // client subscribes to topic + // client subscribes to topic + device_client + .send_subscribe(Subscribe { + packet_identifier: PacketIdentifier::new(1).unwrap(), + subscribe_to: vec![SubscribeTo { + // We need to use a post-translation topic here + topic_filter: "$edgehub/device-1/twin/res/#".into(), + qos: QoS::AtLeastOnce, + }], + }) + .await; - // assert device_client couldn't subscribe because it was refused + // assert device_client couldn't subscribe because it is not in the list of allowed identities. assert_matches!( device_client.next().await, Some(Packet::SubAck(x)) if matches!(x.qos.get(0), Some(SubAckQos::Failure)) ); - command_handler_shutdown_handle - .shutdown() - .await - .expect("failed to stop command handler client"); + // client publishes to a topic. + device_client + .send_publish(Publish { + packet_identifier_dup_qos: PacketIdentifierDupQoS::AtLeastOnce( + PacketIdentifier::new(1).unwrap(), + false, + ), + retain: false, + topic_name: "$edgehub/device-1/twin/get?rid=42".into(), + payload: Bytes::from("qos 1"), + }) + .await; - join_handle.await.unwrap(); + // Verify client has been disconnected after unauthorized pub. 
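// (next() returns None only when the broker has closed the connection,
// so this asserts a forced disconnect rather than a negative ack.)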
+ assert_matches!(device_client.next().await, None); } /// Scenario: @@ -99,21 +97,15 @@ async fn publish_not_allowed_identity_not_in_cache() { async fn auth_update_happy_case() { // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, // but otherwise passes authorization along to EdgeHubAuthorizer - let broker = BrokerBuilder::default() - .with_authorizer(DummyAuthorizer::new(EdgeHubAuthorizer::new( - authorize_fn_ok(|activity| { - if matches!(activity.operation(), Operation::Connect(_)) { - Authorization::Allowed - } else { - Authorization::Forbidden("not allowed".to_string()) - } - }), - BrokerReady::new().handle(), - ))) - .build(); + let mut authorizer = authorizer(); + let mut identities_ready = authorizer.update_signal(); + let broker = BrokerBuilder::default().with_authorizer(authorizer).build(); let broker_handle = broker.handle(); - let server_handle = start_server(broker, DummyAuthenticator::with_id("device-1")); + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); // start command handler with AuthorizedIdentitiesCommand let command = AuthorizedIdentitiesCommand::new(&broker_handle); @@ -127,7 +119,7 @@ async fn auth_update_happy_case() { .build(); let service_identity1 = - IdentityUpdate::new("device-1".into(), Some("device-1;$edgehub".into())); + IdentityUpdate::new("device-1".into(), Some("device-1;this_edgehub_id".into())); let identities = vec![service_identity1]; // EdgeHub sends authorized identities + auth chains to broker @@ -135,15 +127,18 @@ async fn auth_update_happy_case() { .publish_qos1( AUTHORIZED_IDENTITIES_TOPIC, serde_json::to_string(&identities).expect("unable to serialize identities"), - false, + true, ) .await; + // let authorizer update sink in... 
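// (DummyAuthorizer sends on this channel right after its update() is
// applied, so waiting here keeps the test from racing the broker.)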
+ identities_ready.recv().await; + let s = Subscribe { packet_identifier: PacketIdentifier::new(1).unwrap(), subscribe_to: vec![SubscribeTo { // We need to use a post-translation topic here - topic_filter: "$edgehub/device-1/inputs/telemetry/#".into(), + topic_filter: "$edgehub/device-1/twin/res/#".into(), qos: QoS::AtLeastOnce, }], }; @@ -164,14 +159,10 @@ async fn auth_update_happy_case() { assert_matches!(device_client.next().await, Some(Packet::SubAck(_))); edgehub_client - .publish_qos1( - "$edgehub/device-1/inputs/telemetry/#", - "test_payload", - false, - ) + .publish_qos1("$edgehub/device-1/twin/res/#", "test_payload", true) .await; - assert_matches!(device_client.next().await, Some(Packet::Publish(_))); + assert_matches!(device_client.next().await, Some(Packet::Publish(p)) if p.payload == Bytes::from("test_payload")); command_handler_shutdown_handle .shutdown() @@ -191,23 +182,18 @@ async fn auth_update_happy_case() { /// publish authorization update with client removed /// verify client has disconnected #[tokio::test] -async fn disconnect_client_on_auth_update_reevaluates_subscriptions() { +async fn authorization_update_reevaluates_sessions() { // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, // but otherwise passes authorization along to EdgeHubAuthorizer - let broker = BrokerBuilder::default() - .with_authorizer(DummyAuthorizer::new( - EdgeHubAuthorizer::without_ready_handle(authorize_fn_ok(|activity| { - if matches!(activity.operation(), Operation::Connect(_)) { - Authorization::Allowed - } else { - Authorization::Forbidden("not allowed".to_string()) - } - })), - )) - .build(); + let mut authorizer = authorizer(); + let mut identities_ready = authorizer.update_signal(); + let broker = BrokerBuilder::default().with_authorizer(authorizer).build(); let broker_handle = broker.handle(); - let server_handle = start_server(broker, DummyAuthenticator::with_id("device-1")); + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); // start command handler with AuthorizedIdentitiesCommand let command = AuthorizedIdentitiesCommand::new(&broker_handle); @@ -221,7 +207,7 @@ async fn disconnect_client_on_auth_update_reevaluates_subscriptions() { .build(); let service_identity1 = - IdentityUpdate::new("device-1".into(), Some("device-1;$edgehub".into())); + IdentityUpdate::new("device-1".into(), Some("device-1;this_edgehub_id".into())); let identities = vec![service_identity1]; // EdgeHub sends authorized identities + auth chains to broker @@ -229,15 +215,18 @@ async fn disconnect_client_on_auth_update_reevaluates_subscriptions() { .publish_qos1( AUTHORIZED_IDENTITIES_TOPIC, serde_json::to_string(&identities).expect("unable to serialize identities"), - false, + true, ) .await; + // let authorizer update sink in... + identities_ready.recv().await; + let s = Subscribe { packet_identifier: PacketIdentifier::new(1).unwrap(), subscribe_to: vec![SubscribeTo { // We need to use a post-translation topic here - topic_filter: "$edgehub/device-1/inputs/telemetry/#".into(), + topic_filter: "$edgehub/device-1/+/inputs/#".into(), qos: QoS::AtLeastOnce, }], }; @@ -264,10 +253,13 @@ async fn disconnect_client_on_auth_update_reevaluates_subscriptions() { .publish_qos1( AUTHORIZED_IDENTITIES_TOPIC, serde_json::to_string(&identities).expect("unable to serialize identities"), - false, + true, ) .await; + // let authorizer update sink in... 
+ identities_ready.recv().await; + // next() will return None only if the client is disconnected, so this // asserts that the subscription has been re-evaluated and disconnected by broker. assert_eq!(device_client.next().await, None); @@ -281,3 +273,17 @@ async fn disconnect_client_on_auth_update_reevaluates_subscriptions() { edgehub_client.shutdown().await; } + +fn authorizer() -> DummyAuthorizer { + DummyAuthorizer::new(EdgeHubAuthorizer::without_ready_handle( + authorize_fn_ok(|activity| { + if matches!(activity.operation(), Operation::Connect) { + Authorization::Allowed + } else { + Authorization::Forbidden("not allowed".to_string()) + } + }), + "this_edgehub_id".to_string(), + "myhub.azure-devices.net".to_string(), + )) +} diff --git a/mqtt/mqtt-edgehub/tests/policy_authorizer.rs b/mqtt/mqtt-edgehub/tests/policy_authorizer.rs new file mode 100644 index 00000000000..1f74dba9d49 --- /dev/null +++ b/mqtt/mqtt-edgehub/tests/policy_authorizer.rs @@ -0,0 +1,297 @@ +use assert_matches::assert_matches; +use bytes::Bytes; +use futures_util::StreamExt; + +use mqtt3::proto::{ + ClientId, ConnectReturnCode, ConnectionRefusedReason, Packet, PacketIdentifier, + PacketIdentifierDupQoS, Publish, QoS, SubAckQos, Subscribe, SubscribeTo, +}; +use mqtt_broker::BrokerBuilder; +use mqtt_broker_tests_util::{ + client::TestClientBuilder, + packet_stream::PacketStream, + server::{start_server, DummyAuthenticator}, +}; +use mqtt_edgehub::{ + auth::{PolicyAuthorizer, PolicyUpdate}, + command::{PolicyUpdateCommand, POLICY_UPDATE_TOPIC}, +}; + +mod common; +use common::DummyAuthorizer; + +/// Scenario: +/// create broker +/// create command handler +/// connect a client +/// verify client can't connect, since policy haven't been sent. +#[tokio::test] +async fn connect_not_allowed_policy_not_set() { + // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, + // but otherwise passes authorization along to PolicyAuthorizer + let broker = BrokerBuilder::default() + .with_authorizer(DummyAuthorizer::new( + PolicyAuthorizer::without_ready_handle("this_edgehub_id".to_string()), + )) + .build(); + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); + + let mut device_client = PacketStream::connect( + ClientId::IdWithCleanSession("device-1".into()), + server_handle.address(), + None, + None, + ) + .await; + + // Verify client cannot connect because policy is not set. 
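// (with no policy installed yet, PolicyAuthorizer denies the connect,
// which the broker surfaces as a refused CONNACK below)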
+ assert_matches!( + device_client.next().await, + Some(Packet::ConnAck(ack)) if ack.return_code == ConnectReturnCode::Refused(ConnectionRefusedReason::ServerUnavailable) + ); +} + +/// Scenario: +/// create broker +/// create command handler +/// publish policy update from edgehub +/// connect authorized client +/// verify client can connect, subscribe and publish +#[tokio::test] +async fn auth_policy_happy_case() { + // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, + // but otherwise passes authorization along to PolicyAuthorizer + let mut authorizer = DummyAuthorizer::new(PolicyAuthorizer::without_ready_handle( + "this_edgehub_id".to_string(), + )); + let mut policy_ready = authorizer.update_signal(); + let broker = BrokerBuilder::default().with_authorizer(authorizer).build(); + let broker_handle = broker.handle(); + + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); + + // start command handler with PolicyUpdateCommand + let command = PolicyUpdateCommand::new(&broker_handle); + let (command_handler_shutdown_handle, join_handle) = + common::start_command_handler(server_handle.address(), command) + .await + .expect("could not start command handler"); + + let mut edgehub_client = TestClientBuilder::new(server_handle.address()) + .with_client_id(ClientId::IdWithCleanSession("$edgehub".into())) + .build(); + + let policy = PolicyUpdate::new( + r###"{ + "statements": [ + { + "effect": "allow", + "identities": [ + "myhub.azure-devices.net/device-1" + ], + "operations": [ + "mqtt:connect", + "mqtt:publish", + "mqtt:subscribe" + ], + "resources": [ + "#" + ] + } + ] + }"###, + ); + + // EdgeHub sends authorization policy to the broker + edgehub_client + .publish_qos1( + POLICY_UPDATE_TOPIC, + serde_json::to_string(&policy).expect("unable to serialize policy"), + true, + ) + .await; + + // let policy update sink in... 
+ policy_ready.recv().await; + + let mut device_client = PacketStream::connect( + ClientId::IdWithCleanSession("device-1".into()), + server_handle.address(), + None, + None, + ) + .await; + + // assert connack + assert_matches!( + device_client.next().await, + Some(Packet::ConnAck(ack)) if ack.return_code == ConnectReturnCode::Accepted + ); + + // client subscribes to a topic + device_client + .send_subscribe(Subscribe { + packet_identifier: PacketIdentifier::new(1).unwrap(), + subscribe_to: vec![SubscribeTo { + topic_filter: "custom/topic".into(), + qos: QoS::AtLeastOnce, + }], + }) + .await; + + // assert suback + assert_matches!( + device_client.next().await, + Some(Packet::SubAck(ack)) if ack.qos[0] == SubAckQos::Success(QoS::AtLeastOnce) + ); + + // client publishes to a topic + device_client + .send_publish(Publish { + packet_identifier_dup_qos: PacketIdentifierDupQoS::AtLeastOnce( + PacketIdentifier::new(1).unwrap(), + false, + ), + retain: false, + topic_name: "custom/topic".into(), + payload: Bytes::from("qos 1"), + }) + .await; + + // assert puback + assert_matches!(device_client.next().await, Some(Packet::PubAck(_))); + + command_handler_shutdown_handle + .shutdown() + .await + .expect("failed to stop command handler client"); + + join_handle.await.unwrap(); + + edgehub_client.shutdown().await; +} + +/// Scenario: +/// create broker +/// create command handler +/// publish policy update from edgehub +/// connect authorized client +/// publish policy update with client access removed +/// verify client has disconnected +#[tokio::test] +async fn policy_update_reevaluates_sessions() { + mqtt_broker_tests_util::init_logging(); + // Start broker with DummyAuthorizer that allows everything from CommandHandler and $edgeHub, + // but otherwise passes authorization along to PolicyAuthorizer + let mut authorizer = DummyAuthorizer::new(PolicyAuthorizer::without_ready_handle( + "this_edgehub_id".to_string(), + )); + let mut policy_update_signal = authorizer.update_signal(); + let broker = BrokerBuilder::default().with_authorizer(authorizer).build(); + let broker_handle = broker.handle(); + + let server_handle = start_server( + broker, + DummyAuthenticator::with_id("myhub.azure-devices.net/device-1"), + ); + + // start command handler with PolicyUpdateCommand + let command = PolicyUpdateCommand::new(&broker_handle); + let (command_handler_shutdown_handle, join_handle) = + common::start_command_handler(server_handle.address(), command) + .await + .expect("could not start command handler"); + + let mut edgehub_client = TestClientBuilder::new(server_handle.address()) + .with_client_id(ClientId::IdWithCleanSession("$edgehub".into())) + .build(); + + // EdgeHub sends authorization policy to the broker + let policy = PolicyUpdate::new( + r###"{ + "statements": [ + { + "effect": "allow", + "identities": [ + "myhub.azure-devices.net/device-1" + ], + "operations": [ + "mqtt:connect" + ] + } + ] + }"###, + ); + + edgehub_client + .publish_qos1( + POLICY_UPDATE_TOPIC, + serde_json::to_string(&policy).expect("unable to serialize policy"), + true, + ) + .await; + + // let policy update sink in... 
+ policy_update_signal.recv().await; + + let mut device_client = PacketStream::connect( + ClientId::IdWithCleanSession("device-1".into()), + server_handle.address(), + None, + None, + ) + .await; + + // assert connack + assert_matches!( + device_client.next().await, + Some(Packet::ConnAck(ack)) if ack.return_code == ConnectReturnCode::Accepted + ); + + // EdgeHub sends updated authorization policy to the broker + // where client no longer allowed to connect + let policy = PolicyUpdate::new( + r###"{ + "statements": [ + { + "effect": "deny", + "identities": [ + "myhub.azure-devices.net/device-1" + ], + "operations": [ + "mqtt:connect" + ] + } + ] + }"###, + ); + + edgehub_client + .publish_qos1( + POLICY_UPDATE_TOPIC, + serde_json::to_string(&policy).expect("unable to serialize policy"), + true, + ) + .await; + + // let policy update sink in... + policy_update_signal.recv().await; + + // assert client disconnected + assert_matches!(device_client.next().await, None); + + command_handler_shutdown_handle + .shutdown() + .await + .expect("failed to stop command handler client"); + + join_handle.await.unwrap(); + + edgehub_client.shutdown().await; +} diff --git a/mqtt/mqtt-edgehub/tests/translation.rs b/mqtt/mqtt-edgehub/tests/translation.rs index c3b6d110863..ded661e2575 100644 --- a/mqtt/mqtt-edgehub/tests/translation.rs +++ b/mqtt/mqtt-edgehub/tests/translation.rs @@ -1,32 +1,26 @@ -use std::{ - error::Error as StdError, - sync::atomic::{AtomicU32, Ordering}, -}; - -use lazy_static::lazy_static; use matches::assert_matches; +use mqtt_edgehub::connection::MakeEdgeHubPacketProcessor; use proptest::prelude::*; -use tokio::{runtime::Runtime, sync::oneshot}; +use tokio::runtime::Runtime; -use futures_util::FutureExt; use mqtt3::{ proto::{ClientId, QoS}, ReceivedPublication, }; use mqtt_broker::{ - auth::{AllowAll, Authenticator, Authorizer}, - proptest::arb_clientid, - Broker, BrokerBuilder, Server, + auth::AllowAll, auth::Authorizer, Broker, BrokerBuilder, MakeMqttPacketProcessor, Server, +}; +use mqtt_broker_tests_util::{ + client::{TestClient, TestClientBuilder}, + server::{self, DummyAuthenticator, ServerHandle}, }; -use mqtt_broker_tests_util::{DummyAuthenticator, ServerHandle, TestClient, TestClientBuilder}; -use mqtt_edgehub::connection::MakeEdgeHubPacketProcessor; // https://docs.microsoft.com/en-us/azure/iot-hub/iot-hub-mqtt-support#retrieving-a-device-twins-properties #[tokio::test] async fn translation_twin_retrieve() { let broker = BrokerBuilder::default().with_authorizer(AllowAll).build(); - let server_handle = start_server(broker, DummyAuthenticator::anonymous()); + let server_handle = start_server(broker); let mut edge_hub_core = TestClientBuilder::new(server_handle.address()) .with_client_id(ClientId::IdWithCleanSession("edge_hub_core".into())) @@ -71,7 +65,7 @@ async fn translation_twin_retrieve() { async fn translation_twin_update() { let broker = BrokerBuilder::default().with_authorizer(AllowAll).build(); - let server_handle = start_server(broker, DummyAuthenticator::anonymous()); + let server_handle = start_server(broker); let mut edge_hub_core = TestClientBuilder::new(server_handle.address()) .with_client_id(ClientId::IdWithCleanSession("edge_hub_core".into())) @@ -120,7 +114,7 @@ async fn translation_twin_update() { async fn translation_twin_receive() { let broker = BrokerBuilder::default().with_authorizer(AllowAll).build(); - let server_handle = start_server(broker, DummyAuthenticator::anonymous()); + let server_handle = start_server(broker); let mut edge_hub_core = 
TestClientBuilder::new(server_handle.address()) .with_client_id(ClientId::IdWithCleanSession("edge_hub_core".into())) @@ -160,7 +154,7 @@ async fn translation_twin_receive() { async fn translation_direct_method_response() { let broker = BrokerBuilder::default().with_authorizer(AllowAll).build(); - let server_handle = start_server(broker, DummyAuthenticator::anonymous()); + let server_handle = start_server(broker); let mut edge_hub_core = TestClientBuilder::new(server_handle.address()) .with_client_id(ClientId::IdWithCleanSession("edge_hub_core".into())) @@ -225,7 +219,7 @@ async fn translation_twin_notify_with_wildcards() { proptest! { #[test] - fn translate_clientid_proptest(client_id in arb_clientid()) { + fn translate_clientid_proptest(client_id in mqtt_broker::proptest::arb_clientid()) { let mut rt = Runtime::new().unwrap(); rt.block_on(test_twin_with_client_id(client_id.as_str())); } @@ -234,7 +228,7 @@ proptest! { async fn test_twin_with_client_id(client_id: &str) { let broker = BrokerBuilder::default().with_authorizer(AllowAll).build(); - let server_handle = start_server(broker, DummyAuthenticator::anonymous()); + let server_handle = start_server(broker); let mut edge_hub_core = TestClientBuilder::new(server_handle.address()) .with_client_id(ClientId::IdWithCleanSession("edge_hub_core".into())) @@ -331,28 +325,21 @@ async fn ensure_subscribed(client: &mut TestClient) { client.subscriptions().recv().await; } -fn start_server(broker: Broker, authenticator: N) -> ServerHandle +fn start_server(broker: Broker) -> ServerHandle where - N: Authenticator> + Send + Sync + 'static, Z: Authorizer + Send + Sync + 'static, { - lazy_static! { - static ref PORT: AtomicU32 = AtomicU32::new(5555); - } + let make_server = |addr| { + let broker_handle = broker.handle(); - let port = PORT.fetch_add(1, Ordering::SeqCst); - let address = format!("localhost:{}", port); + let mut server = Server::from_broker(broker).with_packet_processor( + MakeEdgeHubPacketProcessor::new(broker_handle, MakeMqttPacketProcessor), + ); - let mut server = Server::from_broker(broker) - .with_packet_processor(MakeEdgeHubPacketProcessor::default()) - .with_tcp(&address, authenticator, None); + let authenticator = DummyAuthenticator::anonymous(); + server.with_tcp(&addr, authenticator, None).unwrap(); - let (shutdown, rx) = oneshot::channel::<()>(); - let task = tokio::spawn(server.serve(rx.map(drop))); - - ServerHandle { - address, - shutdown: Some(shutdown), - task: Some(task), - } + server + }; + server::run(make_server) } diff --git a/mqtt/mqtt-policy/Cargo.toml b/mqtt/mqtt-policy/Cargo.toml index 89174fb1f94..990fad1ed84 100644 --- a/mqtt/mqtt-policy/Cargo.toml +++ b/mqtt/mqtt-policy/Cargo.toml @@ -17,4 +17,5 @@ policy = { path = "../policy" } [dev-dependencies] assert_matches = "1.3" bytes = "0.5" +proptest = "0.9" test-case = "1.0" \ No newline at end of file diff --git a/mqtt/mqtt-policy/src/lib.rs b/mqtt/mqtt-policy/src/lib.rs index 028429e2197..5f055000449 100644 --- a/mqtt/mqtt-policy/src/lib.rs +++ b/mqtt/mqtt-policy/src/lib.rs @@ -28,8 +28,6 @@ pub(crate) const EDGEHUB_ID_VAR: &str = "{{iot:this_device_id}}"; #[cfg(test)] mod tests { - use std::time::Duration; - use bytes::Bytes; use mqtt3::proto; use mqtt_broker::{auth::Activity, auth::Operation, AuthId, ClientId, ClientInfo}; @@ -40,16 +38,8 @@ mod tests { ) -> Activity { let client_id = client_id.into(); Activity::new( - ClientInfo::new(client_id.clone(), "127.0.0.1:80".parse().unwrap(), auth_id), - Operation::new_connect(proto::Connect { - username: None, - 
password: None, - will: None, - client_id: proto::ClientId::IdWithExistingSession(client_id.to_string()), - keep_alive: Duration::default(), - protocol_name: mqtt3::PROTOCOL_NAME.into(), - protocol_level: mqtt3::PROTOCOL_LEVEL, - }), + ClientInfo::new(client_id, "127.0.0.1:80".parse().unwrap(), auth_id), + Operation::new_connect(), ) } diff --git a/mqtt/mqtt-policy/src/matcher.rs b/mqtt/mqtt-policy/src/matcher.rs index 7270cbce1df..7386927649e 100644 --- a/mqtt/mqtt-policy/src/matcher.rs +++ b/mqtt/mqtt-policy/src/matcher.rs @@ -31,7 +31,7 @@ impl ResourceMatcher for MqttTopicFilterMatcher { Some(context) => { match context.operation() { // special case for Connect operation, since it doesn't really have a "resource". - Operation::Connect(_) => true, + Operation::Connect => true, // for pub or sub just match the topic filter. _ => { if let Ok(filter) = TopicFilter::from_str(policy) { diff --git a/mqtt/mqtt-policy/src/substituter.rs b/mqtt/mqtt-policy/src/substituter.rs index 777e4e7fbdf..aa86dfefc72 100644 --- a/mqtt/mqtt-policy/src/substituter.rs +++ b/mqtt/mqtt-policy/src/substituter.rs @@ -25,23 +25,33 @@ impl MqttSubstituter { } fn replace_variable(&self, value: &str, context: &Request) -> String { - if let Some(context) = context.context() { - if let Some(variable) = extract_variable(value) { - return match variable { - crate::CLIENT_ID_VAR => { - replace(value, variable, context.client_info().client_id().as_str()) - } - crate::IDENTITY_VAR => { - replace(value, variable, context.client_info().auth_id().as_str()) - } - crate::DEVICE_ID_VAR => replace(value, variable, extract_device_id(&context)), - crate::MODULE_ID_VAR => replace(value, variable, extract_module_id(&context)), - crate::EDGEHUB_ID_VAR => replace(value, variable, self.device_id()), - _ => value.to_string(), - }; + match context.context() { + Some(context) => { + let mut result = value.to_owned(); + for variable in VariableIter::new(value) { + result = match variable { + crate::CLIENT_ID_VAR => replace( + &result, + variable, + context.client_info().client_id().as_str(), + ), + crate::IDENTITY_VAR => { + replace(&result, variable, context.client_info().auth_id().as_str()) + } + crate::DEVICE_ID_VAR => { + replace(&result, variable, extract_device_id(&context)) + } + crate::MODULE_ID_VAR => { + replace(&result, variable, extract_module_id(&context)) + } + crate::EDGEHUB_ID_VAR => replace(&result, variable, self.device_id()), + _ => result, + }; + } + result } + None => value.to_owned(), } - value.to_string() } } @@ -57,13 +67,36 @@ impl Substituter for MqttSubstituter { } } -pub(super) fn extract_variable(value: &str) -> Option<&str> { - if let Some(start) = value.find("{{") { - if let Some(end) = value.find("}}") { - return Some(&value[start..end + 2]); +/// A simple iterator that returns all occurrences +/// of variable substrings like `{{var_name}}` in the +/// provided string value. 
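/// For example, "{{iot:this_device_id}}/{{iot:module_id}}-suffix" yields
/// "{{iot:this_device_id}}" and then "{{iot:module_id}}", while malformed
/// inputs like "{}bad}}" yield nothing (see the test cases below).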
+#[derive(Debug)] +pub(super) struct VariableIter<'a> { + value: &'a str, + index: usize, +} + +impl<'a> VariableIter<'a> { + pub fn new(value: &'a str) -> Self { + Self { value, index: 0 } + } +} + +impl<'a> Iterator for VariableIter<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + let value = &self.value[self.index..]; + if let Some(start) = value.find("{{") { + if let Some(end) = value.find("}}") { + if start < end { + self.index = self.index + end + 2; + return Some(&value[start..end + 2]); + } + } } + None } - None } fn replace(value: &str, variable: &str, substitution: &str) -> String { @@ -82,6 +115,7 @@ fn extract_module_id(activity: &Activity) -> &str { #[cfg(test)] mod tests { + use proptest::prelude::*; use test_case::test_case; use crate::tests; @@ -148,6 +182,26 @@ mod tests { "test_device_client_id", "{{{}bad}}}"; "bad variable")] + #[test_case("{{{}bad}", + "test_device_auth_id", + "test_device_client_id", + "{{{}bad}"; + "bad variable 2")] + #[test_case("{}bad}}", + "test_device_auth_id", + "test_device_client_id", + "{}bad}}"; + "bad variable 3")] + #[test_case("{{iot:this_device_id}}{{iot:module_id}}", + "test_device_auth_id/test_module", + "test_device_client_id", + "edge_devicetest_module"; + "multiple variable")] + #[test_case("namespace-{{iot:this_device_id}}/{{iot:module_id}}-suffix", + "test_device_auth_id/test_module", + "test_device_client_id", + "namespace-edge_device/test_module-suffix"; + "multiple variable substring")] fn visit_identity_test(input: &str, auth_id: &str, client_id: &str, expected: &str) { let request = Request::with_context( "some_identity", @@ -225,6 +279,26 @@ mod tests { "test_device_client_id", "{{{}bad}}}"; "bad variable")] + #[test_case("{{{}bad}", + "test_device_auth_id", + "test_device_client_id", + "{{{}bad}"; + "bad variable 2")] + #[test_case("{}bad}}", + "test_device_auth_id", + "test_device_client_id", + "{}bad}}"; + "bad variable 3")] + #[test_case("{{iot:this_device_id}}{{iot:module_id}}", + "test_device_auth_id/test_module", + "test_device_client_id", + "edge_devicetest_module"; + "multiple variable")] + #[test_case("namespace-{{iot:this_device_id}}/{{iot:module_id}}-suffix", + "test_device_auth_id/test_module", + "test_device_client_id", + "namespace-edge_device/test_module-suffix"; + "multiple variable substring")] fn visit_resource_test(input: &str, auth_id: &str, client_id: &str, expected: &str) { let request = Request::with_context( "some_identity", @@ -241,4 +315,11 @@ mod tests { .unwrap() ); } + + proptest! { + #[test] + fn iterator_does_not_crash(value in "[a-z\\{\\}]+") { + drop(VariableIter::new(&value).collect::>()); + } + } } diff --git a/mqtt/mqtt-policy/src/validator.rs b/mqtt/mqtt-policy/src/validator.rs index 2783fda8680..60049314fa9 100644 --- a/mqtt/mqtt-policy/src/validator.rs +++ b/mqtt/mqtt-policy/src/validator.rs @@ -6,7 +6,7 @@ use lazy_static::lazy_static; use mqtt_broker::TopicFilter; use policy::{PolicyDefinition, PolicyValidator, Statement}; -use crate::{errors::Error, substituter}; +use crate::{errors::Error, substituter::VariableIter}; /// MQTT-specific implementation of `PolicyValidator`. It checks the following rules: /// * Valid schema version. 
@@ -78,7 +78,7 @@ fn visit_identity(value: &str) -> Result<(), Error> { if value.is_empty() { return Err(Error::InvalidIdentity(value.into())); } - if let Some(variable) = substituter::extract_variable(value) { + for variable in VariableIter::new(value) { if VALID_VARIABLES.get(variable).is_none() { return Err(Error::InvalidIdentityVariable(variable.into())); } @@ -97,7 +97,7 @@ fn visit_resource(value: &str) -> Result<(), Error> { if value.is_empty() { return Err(Error::InvalidResource(value.into())); } - if let Some(variable) = substituter::extract_variable(value) { + for variable in VariableIter::new(value) { if VALID_VARIABLES.get(variable).is_none() { return Err(Error::InvalidResourceVariable(variable.into())); } diff --git a/mqtt/mqtt3/Cargo.toml b/mqtt/mqtt3/Cargo.toml index 0cfc0a33724..c6fc9b5fd13 100644 --- a/mqtt/mqtt3/Cargo.toml +++ b/mqtt/mqtt3/Cargo.toml @@ -12,7 +12,6 @@ futures-channel = { version = "0.3", features = ["sink"] } futures-sink = "0.3" futures-util = { version = "0.3", features = ["sink"] } log = "0.4" -mockall = "0.8" serde = { version = "1.0", optional = true, features = ["derive"] } tokio = { version = "0.2", features = ["time"] } tokio-util = { version = "0.2", features = ["codec"] } @@ -23,5 +22,4 @@ structopt = "0.3" tokio = { version = "0.2", features = ["rt-core", "signal", "stream", "tcp"] } [features] -serde1 = ["serde"] - +serde1 = ["serde"] \ No newline at end of file diff --git a/mqtt/mqtt3/src/client/mod.rs b/mqtt/mqtt3/src/client/mod.rs index db243eaa692..711a015df64 100644 --- a/mqtt/mqtt3/src/client/mod.rs +++ b/mqtt/mqtt3/src/client/mod.rs @@ -5,12 +5,10 @@ mod connect; mod ping; mod publish; -pub use publish::{MockPublishHandle, PublishError, PublishHandle}; +pub use publish::{PublishError, PublishHandle}; mod subscriptions; -pub use subscriptions::{ - MockUpdateSubscriptionHandle, UpdateSubscriptionError, UpdateSubscriptionHandle, -}; +pub use subscriptions::{UpdateSubscriptionError, UpdateSubscriptionHandle}; /// An MQTT v3.1.1 client. /// diff --git a/mqtt/mqtt3/src/client/publish.rs b/mqtt/mqtt3/src/client/publish.rs index ce7e489a4ed..491a681b2fc 100644 --- a/mqtt/mqtt3/src/client/publish.rs +++ b/mqtt/mqtt3/src/client/publish.rs @@ -1,7 +1,5 @@ use std::future::Future; -use mockall::automock; - #[derive(Debug)] pub(super) struct State { publish_request_send: futures_channel::mpsc::Sender, @@ -366,7 +364,6 @@ impl Default for State { #[derive(Clone, Debug)] pub struct PublishHandle(futures_channel::mpsc::Sender); -#[automock] impl PublishHandle { /// Publish the given message to the server pub async fn publish( diff --git a/mqtt/mqtt3/src/client/subscriptions.rs b/mqtt/mqtt3/src/client/subscriptions.rs index 414795af178..ee1389fad9a 100644 --- a/mqtt/mqtt3/src/client/subscriptions.rs +++ b/mqtt/mqtt3/src/client/subscriptions.rs @@ -1,5 +1,3 @@ -use mockall::automock; - #[derive(Debug)] pub(super) struct State { subscriptions: std::collections::BTreeMap, @@ -86,7 +84,7 @@ impl State { crate::proto::SubAckQos::Success(actual_qos) => { if actual_qos >= expected_qos { log::debug!( - "Subscribed to {} with {:?}", + "Subscribed to {} with qos {:?}", topic_filter, actual_qos ); @@ -225,66 +223,42 @@ impl State { // // So we cannot just make a group of all Subscribes, send that packet, then make a group of all Unsubscribes, then send that packet. // Instead, we have to respect the ordering of Subscribes with Unsubscribes. 
- // So we make an intermediate set of all subscriptions based on the updates waiting to be sent, compute the diff from the current subscriptions, + // So we make an intermediate set of all subscriptions and unsubscriptions and if for same topic an Unsubscribe is before a Subscribe, only Subscribe remains in the intermediate set + // and if a Unsubscribe is set after Subscribe they are both removed // then send a SUBSCRIBE packet for any net new subscriptions and an UNSUBSCRIBE packet for any net new unsubscriptions. - - let mut current_subscriptions: std::collections::BTreeMap<_, _> = self - .subscriptions - .iter() - .map(|(topic_filter, qos)| (std::borrow::Cow::Borrowed(&**topic_filter), *qos)) - .collect(); - - for (_, subscription_update) in &self.subscription_updates_waiting_to_be_acked { - match subscription_update { - BatchedSubscriptionUpdate::Subscribe(subscribe_to) => { - for subscribe_to in subscribe_to { - current_subscriptions.insert( - std::borrow::Cow::Borrowed(&*subscribe_to.topic_filter), - subscribe_to.qos, - ); - } - } - - BatchedSubscriptionUpdate::Unsubscribe(unsubscribe_from) => { - for unsubscribe_from in unsubscribe_from { - current_subscriptions.remove(&**unsubscribe_from); - } - } - } - } - - let mut target_subscriptions = current_subscriptions.clone(); + let mut target_subscriptions = std::collections::BTreeMap::new(); + let mut target_unsubscriptions = std::collections::BTreeMap::new(); while let Some(subscription_update) = self.subscription_updates_waiting_to_be_sent.pop_front() { match subscription_update { - SubscriptionUpdate::Subscribe(subscribe_to) => target_subscriptions.insert( - std::borrow::Cow::Owned(subscribe_to.topic_filter), - subscribe_to.qos, - ), + SubscriptionUpdate::Subscribe(subscribe_to) => { + target_unsubscriptions.remove(&subscribe_to.topic_filter); + target_subscriptions.insert( + std::borrow::Cow::Owned(subscribe_to.topic_filter), + subscribe_to.qos, + ); + } SubscriptionUpdate::Unsubscribe(unsubscribe_from) => { - target_subscriptions.remove(&*unsubscribe_from) + if target_subscriptions.remove(&*unsubscribe_from).is_none() { + target_unsubscriptions.insert(unsubscribe_from, true); + } } }; } let mut pending_subscriptions: std::collections::VecDeque<_> = Default::default(); for (topic_filter, &qos) in &target_subscriptions { - if current_subscriptions.get(topic_filter) != Some(&qos) { - // Current subscription doesn't exist, or exists but has different QoS - pending_subscriptions.push_back(crate::proto::SubscribeTo { - topic_filter: topic_filter.clone().into_owned(), - qos, - }); - } + pending_subscriptions.push_back(crate::proto::SubscribeTo { + topic_filter: topic_filter.clone().into_owned(), + qos, + }); } let mut pending_unsubscriptions: std::collections::VecDeque<_> = Default::default(); - for topic_filter in current_subscriptions.keys() { - if !target_subscriptions.contains_key(topic_filter) { - pending_unsubscriptions.push_back(topic_filter.clone().into_owned()); - } + for topic_filter in target_unsubscriptions.keys() { + pending_unsubscriptions.push_back(topic_filter.clone()); } // Save the error, if any, from reserving a packet identifier @@ -608,7 +582,6 @@ impl Iterator for NewConnectionIter { #[derive(Clone, Debug)] pub struct UpdateSubscriptionHandle(futures_channel::mpsc::Sender); -#[automock] impl UpdateSubscriptionHandle { #[allow(clippy::doc_markdown)] /// Subscribe to a topic with the given parameters. 
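Reviewer note: the rewritten batching above stops diffing against the full current-subscription map and instead folds the pending queue into two disjoint sets, with two cancellation rules: an Unsubscribe followed by a Subscribe on the same topic leaves only the Subscribe, and a Subscribe followed by an Unsubscribe drops both (the broker never saw the Subscribe, so there is nothing to undo). A minimal sketch of that fold under simplified types; `fold_updates`, `Update`, and `Qos` are our stand-ins:

```rust
use std::collections::BTreeMap;

// Simplified stand-ins for the crate's SubscriptionUpdate and QoS types.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Qos { AtMostOnce, AtLeastOnce, ExactlyOnce }

enum Update {
    Subscribe(String, Qos),
    Unsubscribe(String),
}

// Fold pending updates into net subscriptions and net unsubscriptions,
// applying the two cancellation rules described in the comment above.
fn fold_updates(updates: Vec<Update>) -> (BTreeMap<String, Qos>, Vec<String>) {
    let mut subs = BTreeMap::new();
    let mut unsubs = BTreeMap::new();
    for update in updates {
        match update {
            Update::Subscribe(topic, qos) => {
                // A later Subscribe supersedes an earlier Unsubscribe.
                unsubs.remove(&topic);
                subs.insert(topic, qos);
            }
            Update::Unsubscribe(topic) => {
                // A later Unsubscribe cancels an earlier Subscribe outright;
                // only a "net new" unsubscription is recorded.
                if subs.remove(&topic).is_none() {
                    unsubs.insert(topic, ());
                }
            }
        }
    }
    (subs, unsubs.into_keys().collect())
}

fn main() {
    let (subs, unsubs) = fold_updates(vec![
        Update::Unsubscribe("topic4".into()),
        Update::Subscribe("topic4".into(), Qos::ExactlyOnce),
        Update::Subscribe("topic2".into(), Qos::AtLeastOnce),
        Update::Unsubscribe("topic2".into()),
        Update::Unsubscribe("topic5".into()),
    ]);
    assert_eq!(subs.get("topic4"), Some(&Qos::ExactlyOnce)); // Subscribe wins
    assert!(!subs.contains_key("topic2")); // both cancelled
    assert_eq!(unsubs, vec!["topic5".to_string()]); // net new unsubscription
    println!("{:?} {:?}", subs, unsubs);
}
```

This matches the expectations added to `should_combine_pending_subscription_updates` further down: `topic4` ends up subscribed and only `topic5` produces an UNSUBSCRIBE packet.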
diff --git a/mqtt/mqtt3/src/lib.rs b/mqtt/mqtt3/src/lib.rs index e99c140fdad..0bacb6434cd 100644 --- a/mqtt/mqtt3/src/lib.rs +++ b/mqtt/mqtt3/src/lib.rs @@ -27,9 +27,9 @@ pub const PROTOCOL_LEVEL: u8 = 0x04; mod client; pub use client::{ - Client, ConnectionError, Error, Event, IoSource, MockPublishHandle, - MockUpdateSubscriptionHandle, PublishError, PublishHandle, ReceivedPublication, ShutdownError, - ShutdownHandle, SubscriptionUpdateEvent, UpdateSubscriptionError, UpdateSubscriptionHandle, + Client, ConnectionError, Error, Event, IoSource, PublishError, PublishHandle, + ReceivedPublication, ShutdownError, ShutdownHandle, SubscriptionUpdateEvent, + UpdateSubscriptionError, UpdateSubscriptionHandle, }; mod logging_framed; diff --git a/mqtt/mqtt3/src/proto/packet.rs b/mqtt/mqtt3/src/proto/packet.rs index 79e584fd514..23ee44308e1 100644 --- a/mqtt/mqtt3/src/proto/packet.rs +++ b/mqtt/mqtt3/src/proto/packet.rs @@ -144,8 +144,6 @@ impl std::fmt::Debug for Connect { .field("will", &self.will) .field("client_id", &self.client_id) .field("keep_alive", &self.keep_alive) - .field("protocol_name", &self.protocol_name) - .field("protocol_level", &self.protocol_level) .finish() } } @@ -874,7 +872,7 @@ pub struct SubscribeTo { /// The level of reliability for a publication /// /// Ref: 4.3 Quality of Service levels and protocol flows -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde1", derive(Deserialize, Serialize))] pub enum QoS { AtMostOnce, @@ -892,6 +890,12 @@ impl From for u8 { } } +impl std::fmt::Debug for QoS { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", u8::from(*self))) + } +} + #[allow(clippy::doc_markdown)] /// QoS returned in a SUBACK packet. Either one of the [`QoS`] values, or an error code. 
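Reviewer note: the `packet.rs` hunk above trims `protocol_name`/`protocol_level` from `Connect`'s Debug output and swaps the derived `Debug` on `QoS` for a hand-written one that prints the numeric wire level, so the broker's "Subscribed to {} with qos {:?}" log line reads `qos 1` rather than `AtLeastOnce`. The impl in isolation, with the standard MQTT level mapping spelled out:

```rust
// Sketch of the hand-written Debug for QoS from the hunk above: format
// the variant as its numeric MQTT level instead of the variant name.
#[derive(Clone, Copy)]
enum QoS { AtMostOnce, AtLeastOnce, ExactlyOnce }

impl From<QoS> for u8 {
    fn from(qos: QoS) -> u8 {
        match qos {
            QoS::AtMostOnce => 0,
            QoS::AtLeastOnce => 1,
            QoS::ExactlyOnce => 2,
        }
    }
}

impl std::fmt::Debug for QoS {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", u8::from(*self)))
    }
}

fn main() {
    // "Subscribed to {} with qos {:?}" now logs e.g. "... with qos 1".
    assert_eq!(format!("{:?}", QoS::AtLeastOnce), "1");
    println!("{:?} {:?} {:?}", QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce);
}
```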
#[derive(Clone, Copy, Debug, Eq, PartialEq)] diff --git a/mqtt/mqtt3/tests/subscriptions.rs b/mqtt/mqtt3/tests/subscriptions.rs index b1a2c935717..82a76625e05 100644 --- a/mqtt/mqtt3/tests/subscriptions.rs +++ b/mqtt/mqtt3/tests/subscriptions.rs @@ -474,16 +474,30 @@ fn should_combine_pending_subscription_updates() { topic_filter: "topic3".to_string(), qos: mqtt3::proto::QoS::ExactlyOnce, }, + mqtt3::proto::SubscribeTo { + topic_filter: "topic4".to_string(), + qos: mqtt3::proto::QoS::ExactlyOnce, + }, ], }, )), + common::TestConnectionStep::Receives(mqtt3::proto::Packet::Unsubscribe( + mqtt3::proto::Unsubscribe { + packet_identifier: mqtt3::proto::PacketIdentifier::new(2).unwrap(), + unsubscribe_from: vec!["topic5".to_string()], + }, + )), common::TestConnectionStep::Sends(mqtt3::proto::Packet::SubAck(mqtt3::proto::SubAck { packet_identifier: mqtt3::proto::PacketIdentifier::new(1).unwrap(), qos: vec![ mqtt3::proto::SubAckQos::Success(mqtt3::proto::QoS::AtLeastOnce), mqtt3::proto::SubAckQos::Success(mqtt3::proto::QoS::ExactlyOnce), + mqtt3::proto::SubAckQos::Success(mqtt3::proto::QoS::ExactlyOnce), ], })), + common::TestConnectionStep::Sends(mqtt3::proto::Packet::UnsubAck(mqtt3::proto::UnsubAck { + packet_identifier: mqtt3::proto::PacketIdentifier::new(2).unwrap(), + })), common::TestConnectionStep::Receives(mqtt3::proto::Packet::PingReq(mqtt3::proto::PingReq)), common::TestConnectionStep::Sends(mqtt3::proto::Packet::PingResp(mqtt3::proto::PingResp)), ]]); @@ -508,12 +522,20 @@ fn should_combine_pending_subscription_updates() { qos: mqtt3::proto::QoS::AtLeastOnce, }) .unwrap(); + client .subscribe(mqtt3::proto::SubscribeTo { topic_filter: "topic3".to_string(), qos: mqtt3::proto::QoS::ExactlyOnce, }) .unwrap(); + client.unsubscribe("topic4".to_string()).unwrap(); + client + .subscribe(mqtt3::proto::SubscribeTo { + topic_filter: "topic4".to_string(), + qos: mqtt3::proto::QoS::ExactlyOnce, + }) + .unwrap(); client .subscribe(mqtt3::proto::SubscribeTo { topic_filter: "topic1".to_string(), @@ -521,6 +543,7 @@ fn should_combine_pending_subscription_updates() { }) .unwrap(); client.unsubscribe("topic2".to_string()).unwrap(); + client.unsubscribe("topic5".to_string()).unwrap(); common::verify_client_events( &mut runtime, @@ -538,7 +561,15 @@ fn should_combine_pending_subscription_updates() { topic_filter: "topic3".to_string(), qos: mqtt3::proto::QoS::ExactlyOnce, }), + mqtt3::SubscriptionUpdateEvent::Subscribe(mqtt3::proto::SubscribeTo { + topic_filter: "topic4".to_string(), + qos: mqtt3::proto::QoS::ExactlyOnce, + }), ]), + mqtt3::Event::SubscriptionUpdates(vec![mqtt3::SubscriptionUpdateEvent::Unsubscribe( + "topic5".to_string(), + )]), + mqtt3::Event::Disconnected(mqtt3::ConnectionError::ServerClosedConnection), ], ); diff --git a/mqtt/mqttd/Cargo.toml b/mqtt/mqttd/Cargo.toml index 55a99a9f716..3df59e93c06 100644 --- a/mqtt/mqttd/Cargo.toml +++ b/mqtt/mqttd/Cargo.toml @@ -7,13 +7,15 @@ edition = "2018" [dependencies] anyhow = "1.0" -atty = "0.2" +async-trait = "0.1" +cfg-if = "1.0" chrono = "0.4" clap = "2.33" futures-util = { version = "0.3", features = ["sink"] } tokio = { version = "0.2", features = ["dns", "macros", "rt-threaded", "signal", "stream", "tcp", "time"] } thiserror = "1.0" tracing = "0.1" +tracing-log = "0.1" tracing-subscriber = "0.1" edgelet-client = { path = "../edgelet-client", optional = true } diff --git a/mqtt/mqttd/src/app/edgehub.rs b/mqtt/mqttd/src/app/edgehub.rs new file mode 100644 index 00000000000..fb28855b1bc --- /dev/null +++ b/mqtt/mqttd/src/app/edgehub.rs @@ -0,0 
+1,308 @@ +use std::{ + env, fs, + future::Future, + path::{Path, PathBuf}, +}; + +use anyhow::{bail, Context, Result}; +use async_trait::async_trait; +use chrono::{DateTime, Duration, Utc}; +use futures_util::{ + future::{self, Either}, + FutureExt, +}; +use tracing::{debug, error, info}; + +use mqtt_bridge::{settings::BridgeSettings, BridgeController}; +use mqtt_broker::{ + auth::Authorizer, + sidecar::{Sidecar, SidecarShutdownHandle}, + Broker, BrokerBuilder, BrokerHandle, BrokerReady, BrokerSnapshot, FilePersistor, + MakeMqttPacketProcessor, Message, Persist, Server, ServerCertificate, SystemEvent, + VersionedFileFormat, +}; +use mqtt_edgehub::{ + auth::{ + EdgeHubAuthenticator, EdgeHubAuthorizer, LocalAuthenticator, LocalAuthorizer, + PolicyAuthorizer, + }, + command::{ + AuthorizedIdentitiesCommand, BridgeUpdateCommand, CommandHandler, DisconnectCommand, + PolicyUpdateCommand, + }, + connection::MakeEdgeHubPacketProcessor, + settings::Settings, +}; + +use super::{shutdown, Bootstrap}; + +const DEVICE_ID_ENV: &str = "IOTEDGE_DEVICEID"; +const IOTHUB_HOSTNAME_ENV: &str = "IOTEDGE_IOTHUBHOSTNAME"; + +#[derive(Default)] +pub struct EdgeHubBootstrap { + broker_ready: BrokerReady, +} + +#[async_trait] +impl Bootstrap for EdgeHubBootstrap { + type Settings = Settings; + + fn load_config>(&self, path: P) -> Result { + info!("loading settings from a file {}", path.as_ref().display()); + Ok(Self::Settings::from_file(path)?) + } + + type Authorizer = LocalAuthorizer>; + + async fn make_broker( + &self, + settings: &Self::Settings, + ) -> Result<(Broker, FilePersistor)> { + info!("loading state..."); + let persistence_config = settings.broker().persistence(); + let state_dir = persistence_config.file_path(); + + fs::create_dir_all(state_dir.clone())?; + let mut persistor = FilePersistor::new(state_dir, VersionedFileFormat::default()); + let state = persistor.load().await?; + info!("state loaded."); + + let device_id = env::var(DEVICE_ID_ENV).context(DEVICE_ID_ENV)?; + let iothub_id = env::var(IOTHUB_HOSTNAME_ENV).context(IOTHUB_HOSTNAME_ENV)?; + + let authorizer = LocalAuthorizer::new(EdgeHubAuthorizer::new( + PolicyAuthorizer::new(device_id.clone(), self.broker_ready.handle()), + device_id, + iothub_id, + self.broker_ready.handle(), + )); + + let broker = BrokerBuilder::default() + .with_authorizer(authorizer) + .with_state(state.unwrap_or_default()) + .with_config(settings.broker().clone()) + .build(); + + Ok((broker, persistor)) + } + + fn snapshot_interval(&self, settings: &Self::Settings) -> std::time::Duration { + settings.broker().persistence().time_interval() + } + + async fn run( + self, + config: Self::Settings, + broker: Broker, + ) -> Result { + let mut broker_handle = broker.handle(); + let sidecars = make_sidecars(&broker_handle, &config)?; + + info!("starting server..."); + let server = make_server(config, broker, self.broker_ready).await?; + + let shutdown_signal = shutdown_signal(&server); + let server = tokio::spawn(server.serve(shutdown_signal)); + + info!("starting sidecars..."); + + let mut shutdowns = Vec::new(); + let mut sidecar_joins = Vec::new(); + + for sidecar in sidecars { + shutdowns.push(sidecar.shutdown_handle()?); + sidecar_joins.push(tokio::spawn(sidecar.run())); + } + + let state = match future::select(server, future::select_all(sidecar_joins)).await { + // server exited first + Either::Left((snapshot, sidecars)) => { + // send shutdown event to each sidecar + let shutdowns = shutdowns.into_iter().map(SidecarShutdownHandle::shutdown); + 
future::join_all(shutdowns).await; + + // awaits for at least one to finish + let (_res, _stopped, sidecars) = sidecars.await; + + // wait for the rest to exit + future::join_all(sidecars).await; + + snapshot?? + } + // one of sidecars exited first + Either::Right(((res, stopped, sidecars), server)) => { + debug!("a sidecar has stopped. shutting down all sidecars..."); + if let Err(e) = res { + error!(message = "failed waiting for sidecar shutdown", error = %e); + } + + // send shutdown event to each of the rest sidecars + shutdowns.remove(stopped); + let shutdowns = shutdowns.into_iter().map(SidecarShutdownHandle::shutdown); + future::join_all(shutdowns).await; + + // wait for the rest to exit + future::join_all(sidecars).await; + + // signal server + broker_handle.send(Message::System(SystemEvent::Shutdown))?; + server.await?? + } + }; + + Ok(state) + } +} + +async fn make_server( + config: Settings, + broker: Broker, + broker_ready: BrokerReady, +) -> Result>> +where + Z: Authorizer + Send + 'static, +{ + let broker_handle = broker.handle(); + + let make_processor = MakeEdgeHubPacketProcessor::new_default(broker_handle.clone()); + let mut server = Server::from_broker(broker).with_packet_processor(make_processor); + + // Add system transport to allow communication between edgehub components + let authenticator = LocalAuthenticator::new(); + server.with_tcp(config.listener().system().addr(), authenticator, None)?; + + // Add regular MQTT over TCP transport + let authenticator = EdgeHubAuthenticator::new(config.auth().url()); + + if let Some(tcp) = config.listener().tcp() { + let broker_ready = Some(broker_ready.signal()); + server.with_tcp(tcp.addr(), authenticator.clone(), broker_ready)?; + } + + // Add regular MQTT over TLS transport + if let Some(tls) = config.listener().tls() { + let identity = if let Some(config) = tls.certificate() { + info!("loading identity from {}", config.cert_path().display()); + ServerCertificate::from_pem(config.cert_path(), config.private_key_path()) + .with_context(|| { + ServerCertificateLoadError::File( + config.cert_path().to_path_buf(), + config.private_key_path().to_path_buf(), + ) + })? + } else { + info!("downloading identity from edgelet"); + download_server_certificate() + .await + .with_context(|| ServerCertificateLoadError::Edgelet)? 
+ }; + + let broker_ready = Some(broker_ready.signal()); + server.with_tls(tls.addr(), identity, authenticator.clone(), broker_ready)?; + }; + + Ok(server) +} + +fn make_sidecars( + broker_handle: &BrokerHandle, + config: &Settings, +) -> Result>> { + let mut sidecars: Vec> = Vec::new(); + + let system_address = config.listener().system().addr().to_string(); + let device_id = env::var(DEVICE_ID_ENV).context(DEVICE_ID_ENV)?; + + let settings = BridgeSettings::new()?; + let bridge_controller = + BridgeController::new(system_address.clone(), device_id.to_owned(), settings); + let bridge_controller_handle = bridge_controller.handle(); + + sidecars.push(Box::new(bridge_controller)); + + let mut command_handler = CommandHandler::new(system_address, &device_id); + command_handler.add_command(DisconnectCommand::new(&broker_handle)); + command_handler.add_command(AuthorizedIdentitiesCommand::new(&broker_handle)); + command_handler.add_command(PolicyUpdateCommand::new(broker_handle)); + command_handler.add_command(BridgeUpdateCommand::new(bridge_controller_handle)); + sidecars.push(Box::new(command_handler)); + + Ok(sidecars) +} + +pub const WORKLOAD_URI: &str = "IOTEDGE_WORKLOADURI"; +pub const EDGE_DEVICE_HOST_NAME: &str = "EdgeDeviceHostName"; +pub const MODULE_ID: &str = "IOTEDGE_MODULEID"; +pub const MODULE_GENERATION_ID: &str = "IOTEDGE_MODULEGENERATIONID"; + +pub const CERTIFICATE_VALIDITY_DAYS: i64 = 90; + +async fn download_server_certificate() -> Result { + let uri = env::var(WORKLOAD_URI).context(WORKLOAD_URI)?; + let hostname = env::var(EDGE_DEVICE_HOST_NAME).context(EDGE_DEVICE_HOST_NAME)?; + let module_id = env::var(MODULE_ID).context(MODULE_ID)?; + let generation_id = env::var(MODULE_GENERATION_ID).context(MODULE_GENERATION_ID)?; + let expiration = Utc::now() + Duration::days(CERTIFICATE_VALIDITY_DAYS); + + let client = edgelet_client::workload(&uri)?; + let cert = client + .create_server_cert(&module_id, &generation_id, &hostname, expiration) + .await?; + + if cert.private_key().type_() != "key" { + bail!( + "unknown type of private key: {}", + cert.private_key().type_() + ); + } + + if let Some(private_key) = cert.private_key().bytes() { + let identity = ServerCertificate::from_pem_pair(cert.certificate(), private_key)?; + Ok(identity) + } else { + bail!("missing private key"); + } +} + +fn shutdown_signal(server: &Server) -> impl Future { + server + .listeners() + .iter() + .find_map(|listener| listener.transport().identity()) + .map_or_else( + || Either::Left(shutdown::shutdown()), + |identity| { + let system_or_cert_expired = future::select( + Box::pin(server_certificate_renewal(identity.not_after())), + Box::pin(shutdown::shutdown()), + ); + Either::Right(system_or_cert_expired.map(drop)) + }, + ) +} + +async fn server_certificate_renewal(renew_at: DateTime) { + let delay = renew_at - Utc::now(); + if delay > Duration::zero() { + info!( + "scheduled server certificate renewal timer for {}", + renew_at + ); + let delay = delay.to_std().expect("duration must not be negative"); + crate::time::sleep(delay).await; + + info!("restarting the broker to perform certificate renewal"); + } else { + error!("server certificate expired at {}", renew_at); + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerCertificateLoadError { + #[error("unable to load server certificate from file {0} and private key {1}")] + File(PathBuf, PathBuf), + + #[error("unable to download certificate from edgelet")] + Edgelet, +} diff --git a/mqtt/mqttd/src/app/generic.rs b/mqtt/mqttd/src/app/generic.rs new 
file mode 100644 index 00000000000..7644cbc5b1b --- /dev/null +++ b/mqtt/mqttd/src/app/generic.rs @@ -0,0 +1,115 @@ +use std::{ + fs, + path::{Path, PathBuf}, +}; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use futures_util::pin_mut; +use tracing::{error, info}; + +use mqtt_broker::{ + auth::{AllowAll, Authorizer, authenticate_fn_ok}, + AuthId, Broker, BrokerBuilder, BrokerSnapshot, FilePersistor, + MakeMqttPacketProcessor, Persist, Server, ServerCertificate, VersionedFileFormat, +}; +use mqtt_generic::settings::{CertificateConfig, Settings}; + +use super::{shutdown, Bootstrap}; + +#[derive(Default)] +pub struct GenericBootstrap; + +#[async_trait] +impl Bootstrap for GenericBootstrap { + type Settings = Settings; + + fn load_config>(&self, path: P) -> Result { + info!("loading settings from a file {}", path.as_ref().display()); + Ok(Self::Settings::from_file(path)?) + } + + type Authorizer = AllowAll; + + async fn make_broker( + &self, + settings: &Self::Settings, + ) -> Result<(Broker, FilePersistor)> { + info!("loading state..."); + let persistence_config = settings.broker().persistence(); + let state_dir = persistence_config.file_path(); + + fs::create_dir_all(state_dir.clone())?; + let mut persistor = FilePersistor::new(state_dir, VersionedFileFormat::default()); + let state = persistor.load().await?; + info!("state loaded."); + + let broker = BrokerBuilder::default() + .with_authorizer(AllowAll) + .with_state(state.unwrap_or_default()) + .with_config(settings.broker().clone()) + .build(); + + Ok((broker, persistor)) + } + + fn snapshot_interval(&self, settings: &Self::Settings) -> std::time::Duration { + settings.broker().persistence().time_interval() + } + + async fn run( + self, + config: Self::Settings, + broker: Broker, + ) -> Result { + let shutdown_signal = shutdown::shutdown(); + pin_mut!(shutdown_signal); + + info!("starting server..."); + let server = make_server(config, broker).await?; + let state = server.serve(shutdown_signal).await?; + + Ok(state) + } +} + +async fn make_server( + config: Settings, + broker: Broker, +) -> Result> +where + Z: Authorizer + Send + 'static, +{ + let mut server = Server::from_broker(broker); + + if let Some(tcp) = config.listener().tcp() { + let authenticator = authenticate_fn_ok(|_| Some(AuthId::Anonymous)); + server.with_tcp(tcp.addr(), authenticator, None)?; + } + + if let Some(tls) = config.listener().tls() { + let authenticator = authenticate_fn_ok(|_| Some(AuthId::Anonymous)); + let identity = load_server_certificate(tls.certificate())?; + server.with_tls(tls.addr(), identity, authenticator, None)?; + } + + Ok(server) +} + +fn load_server_certificate(config: &CertificateConfig) -> Result { + let identity = ServerCertificate::from_pem(config.cert_path(), config.private_key_path()) + .with_context(|| { + ServerCertificateLoadError::ParseCertificate( + config.cert_path().to_path_buf(), + config.private_key_path().to_path_buf(), + ) + })?; + + Ok(identity) +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerCertificateLoadError { + #[error("unable to decode server certificate {0} and private key {1}")] + ParseCertificate(PathBuf, PathBuf), +} diff --git a/mqtt/mqttd/src/app/mod.rs b/mqtt/mqttd/src/app/mod.rs new file mode 100644 index 00000000000..589c588f672 --- /dev/null +++ b/mqtt/mqttd/src/app/mod.rs @@ -0,0 +1,118 @@ +mod shutdown; +mod snapshot; + +cfg_if! 
{ + if #[cfg(feature = "edgehub")] { + mod edgehub; + + pub fn new() -> App { + App::new(edgehub::EdgeHubBootstrap::default()) + } + } else { + mod generic; + + pub fn new() -> App { + App::new(generic::GenericBootstrap::default()) + } + } +} + +use std::path::Path; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use cfg_if::cfg_if; +use tracing::{error, info}; + +use mqtt_broker::{ + auth::Authorizer, Broker, BrokerSnapshot, FilePersistor, Persist, VersionedFileFormat, +}; + +/// Main entrypoint to the app. +pub struct App +where + B: Bootstrap, +{ + bootstrap: B, + settings: B::Settings, +} + +impl App +where + B: Bootstrap, +{ + /// Returns a new instance of the app. + pub fn new(bootstrap: B) -> Self { + Self { + bootstrap, + settings: B::Settings::default(), + } + } + + /// Configures app with settings. + pub fn setup
<P>
(&mut self, config_path: P) -> Result<()> + where + P: AsRef, + { + self.settings = self + .bootstrap + .load_config(config_path) + .context("An error occurred loading configuration.")?; + Ok(()) + } + + /// Starts up all routines and runs MQTT server. + pub async fn run(self) -> Result<()> { + let (broker, persistor) = self.bootstrap.make_broker(&self.settings).await?; + + let snapshot_interval = self.bootstrap.snapshot_interval(&self.settings); + let (mut snapshotter_shutdown_handle, snapshotter_join_handle) = + snapshot::start_snapshotter(broker.handle(), persistor, snapshot_interval).await; + + let state = self.bootstrap.run(self.settings, broker).await?; + + snapshotter_shutdown_handle.shutdown().await?; + let mut persistor = snapshotter_join_handle.await?; + info!("state snapshotter shutdown."); + + info!("persisting state before exiting..."); + persistor.store(state).await?; + info!("state persisted."); + + info!("exiting... goodbye"); + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("An error occurred loading configuration.")] +pub struct LoadConfigurationError; + +/// Defines a common steps for an app to start. +#[async_trait] +pub trait Bootstrap { + /// A type describing app configuration. + type Settings: Default; + + /// Reads configuration from the file on disk and returns settings. + fn load_config>(&self, path: P) -> Result; + + /// An `Authorizer` type. + type Authorizer: Authorizer + Send + 'static; + + /// Creates a new instance of the `Broker` configured for the app. + async fn make_broker( + &self, + settings: &Self::Settings, + ) -> Result<(Broker, FilePersistor)>; + + /// Returns update interval for snapshotter. + fn snapshot_interval(&self, settings: &Self::Settings) -> std::time::Duration; + + /// Runs all configured routines: MQTT server, sidecars, etc.. 
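Reviewer note: the `App`/`Bootstrap` split above gives both feature flavors (edgehub and generic) one lifecycle: load settings, build the broker and persistor, start the snapshotter, run until shutdown, persist the final snapshot. A compressed, synchronous sketch of that shape with toy stand-ins (no async, no real broker; every type here is illustrative):

```rust
// Minimal sketch of the App/Bootstrap split introduced above, with the
// async machinery and the real broker types stripped out.
use std::path::Path;

trait Bootstrap {
    type Settings: Default;
    fn load_config(&self, path: &Path) -> Self::Settings;
    // Stands in for make_broker + run: serve until shutdown, return state.
    fn run(self, settings: Self::Settings) -> String;
}

struct App<B: Bootstrap> {
    bootstrap: B,
    settings: B::Settings,
}

impl<B: Bootstrap> App<B> {
    fn new(bootstrap: B) -> Self {
        Self { bootstrap, settings: B::Settings::default() }
    }
    fn setup(&mut self, path: &Path) {
        self.settings = self.bootstrap.load_config(path);
    }
    fn run(self) -> String {
        // The real App::run also starts the snapshotter here and persists
        // the returned BrokerSnapshot before exiting.
        self.bootstrap.run(self.settings)
    }
}

struct GenericBootstrap;
impl Bootstrap for GenericBootstrap {
    type Settings = ();
    fn load_config(&self, _path: &Path) -> Self::Settings {}
    fn run(self, _settings: Self::Settings) -> String {
        "broker state".into()
    }
}

fn main() {
    let mut app = App::new(GenericBootstrap);
    app.setup(Path::new("/etc/mqttd/config.json")); // path is illustrative
    println!("{}", app.run());
}
```

The `cfg_if!` block selects which `Bootstrap` implementation `app::new()` returns at compile time, so `main.rs` stays identical for both flavors.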
+ async fn run( + self, + config: Self::Settings, + broker: Broker, + ) -> Result; +} diff --git a/mqtt/mqttd/src/broker/shutdown.rs b/mqtt/mqttd/src/app/shutdown.rs similarity index 100% rename from mqtt/mqttd/src/broker/shutdown.rs rename to mqtt/mqttd/src/app/shutdown.rs diff --git a/mqtt/mqttd/src/broker/snapshot.rs b/mqtt/mqttd/src/app/snapshot.rs similarity index 92% rename from mqtt/mqttd/src/broker/snapshot.rs rename to mqtt/mqttd/src/app/snapshot.rs index 0546d055373..7cc7e4ed62c 100644 --- a/mqtt/mqttd/src/broker/snapshot.rs +++ b/mqtt/mqttd/src/app/snapshot.rs @@ -44,7 +44,7 @@ async fn tick_snapshot( mut broker_handle: BrokerHandle, snapshot_handle: StateSnapshotHandle, ) { - info!("Persisting state every {:?}", period); + info!("persisting state every {:?}", period); let start = Instant::now() + period; let mut interval = tokio::time::interval_at(start, period); loop { @@ -52,7 +52,7 @@ async fn tick_snapshot( if let Err(e) = broker_handle.send(Message::System(SystemEvent::StateSnapshot( snapshot_handle.clone(), ))) { - warn!(message = "failed to tick the snapshotter", error=%e); + warn!(message = "failed to tick the snapshotter", error = %e); } } } @@ -72,19 +72,19 @@ mod imp { let mut stream = match signal(SignalKind::user_defined1()) { Ok(stream) => stream, Err(e) => { - warn!(message = "an error occurred setting up the signal handler", error=%e); + warn!(message = "an error occurred setting up the signal handler", error = %e); return; } }; - info!("Setup to persist state on USR1 signal"); + info!("setup to persist state on USR1 signal"); loop { stream.recv().await; - info!("Received signal USR1"); + info!("received signal USR1"); if let Err(e) = broker_handle.send(Message::System(SystemEvent::StateSnapshot( snapshot_handle.clone(), ))) { - warn!(message = "failed to signal the snapshotter", error=%e); + warn!(message = "failed to signal the snapshotter", error = %e); } } } diff --git a/mqtt/mqttd/src/broker/bootstrap/edgehub.rs b/mqtt/mqttd/src/broker/bootstrap/edgehub.rs deleted file mode 100644 index 3d2b9a9c30d..00000000000 --- a/mqtt/mqttd/src/broker/bootstrap/edgehub.rs +++ /dev/null @@ -1,320 +0,0 @@ -use std::{ - env, - future::Future, - path::{Path, PathBuf}, -}; - -use anyhow::{bail, Context, Result}; -use chrono::{DateTime, Duration, Utc}; -use futures_util::{ - future::{self, Either}, - pin_mut, FutureExt, -}; -use tokio::time; -use tracing::{error, info, warn}; - -use super::SidecarManager; -use mqtt_bridge::{settings::BridgeSettings, BridgeController}; -use mqtt_broker::{ - auth::{AllowAll, Authorizer}, - Broker, BrokerBuilder, BrokerConfig, BrokerHandle, BrokerReady, BrokerSnapshot, Server, - ServerCertificate, -}; -use mqtt_edgehub::{ - auth::{EdgeHubAuthenticator, EdgeHubAuthorizer, LocalAuthenticator, LocalAuthorizer}, - command::{ - AuthorizedIdentitiesCommand, BridgeUpdateCommand, CommandHandler, CommandHandlerError, - DisconnectCommand, PolicyUpdateCommand, ShutdownHandle, - }, - connection::MakeEdgeHubPacketProcessor, - settings::{ListenerConfig, Settings}, -}; - -const DEVICE_ID_ENV: &str = "IOTEDGE_DEVICEID"; - -pub fn config
<P>(config_path: Option<P>
) -> Result -where - P: AsRef, -{ - let config = if let Some(path) = config_path { - info!("loading settings from a file {}", path.as_ref().display()); - Settings::from_file(path)? - } else { - info!("using default settings"); - Settings::new()? - }; - - Ok(config) -} - -pub async fn broker( - config: &BrokerConfig, - state: Option, - broker_ready: &BrokerReady, -) -> Result> { - // TODO: Use AllowAll as bottom level authorizer until Policies are sent over from EdgeHub - let authorizer = LocalAuthorizer::new(EdgeHubAuthorizer::new(AllowAll, broker_ready.handle())); - - let broker = BrokerBuilder::default() - .with_authorizer(authorizer) - .with_state(state.unwrap_or_default()) - .with_config(config.clone()) - .build(); - - Ok(broker) -} - -pub async fn start_server( - config: Settings, - broker: Broker, - shutdown_signal: F, - broker_ready: BrokerReady, -) -> Result -where - Z: Authorizer + Send + 'static, - F: Future, -{ - info!("starting server..."); - - let broker_handle = broker.handle(); - - let make_processor = MakeEdgeHubPacketProcessor::new_default(broker_handle.clone()); - let mut server = Server::from_broker(broker).with_packet_processor(make_processor); - - // Add system transport to allow communication between edgehub components - let authenticator = LocalAuthenticator::new(); - server.with_tcp(config.listener().system().addr(), authenticator, None)?; - - // Add regular MQTT over TCP transport - let authenticator = EdgeHubAuthenticator::new(config.auth().url()); - - if let Some(tcp) = config.listener().tcp() { - let broker_ready = Some(broker_ready.signal()); - server.with_tcp(tcp.addr(), authenticator.clone(), broker_ready)?; - } - - // Add regular MQTT over TLS transport - let renewal_signal = match config.listener().tls() { - Some(tls) => { - let identity = if let Some(config) = tls.certificate() { - info!("loading identity from {}", config.cert_path().display()); - ServerCertificate::from_pem(config.cert_path(), config.private_key_path()) - .with_context(|| { - ServerCertificateLoadError::File( - config.cert_path().to_path_buf(), - config.private_key_path().to_path_buf(), - ) - })? - } else { - info!("downloading identity from edgelet"); - download_server_certificate() - .await - .with_context(|| ServerCertificateLoadError::Edgelet)? 
- }; - let renew_at = identity.not_after(); - - let broker_ready = Some(broker_ready.signal()); - server.with_tls(tls.addr(), identity, authenticator.clone(), broker_ready)?; - - let renewal_signal = server_certificate_renewal(renew_at); - Either::Left(renewal_signal) - } - None => Either::Right(future::pending()), - }; - - // Prepare shutdown signal which is either SYSTEM shutdown signal or cert renewal timout - pin_mut!(shutdown_signal); - pin_mut!(renewal_signal); - let shutdown = future::select(shutdown_signal, renewal_signal).map(drop); - - // Start serving new connections - let state = server.serve(shutdown).await?; - - Ok(state) -} - -#[derive(Debug, thiserror::Error)] -pub enum SidecarError { - #[error("Failed to shutdown command handler")] - CommandHandlerShutdown(#[from] CommandHandlerError), -} - -#[derive(Clone, Debug)] -pub struct SidecarShutdownHandle { - command_handler_shutdown: ShutdownHandle, -} - -impl SidecarShutdownHandle { - pub fn new(command_handler_shutdown: ShutdownHandle) -> Self { - Self { - command_handler_shutdown, - } - } - - pub async fn shutdown(self) -> Result<(), SidecarError> { - self.command_handler_shutdown - .shutdown() - .await - .map_err(SidecarError::CommandHandlerShutdown) - } -} - -pub async fn start_sidecars( - broker_handle: BrokerHandle, - listener_settings: ListenerConfig, -) -> Result> { - info!("starting sidecars..."); - - let system_address = listener_settings.system().addr().to_string(); - - let bridge_controller = BridgeController::new(); - - let device_id = env::var(DEVICE_ID_ENV)?; - let mut command_handler = CommandHandler::new(system_address.clone(), &device_id); - command_handler.add_command(DisconnectCommand::new(&broker_handle)); - command_handler.add_command(AuthorizedIdentitiesCommand::new(&broker_handle)); - command_handler.add_command(PolicyUpdateCommand::new(&broker_handle)); - command_handler.add_command(BridgeUpdateCommand::new(&bridge_controller.handle())); - command_handler.init().await?; - let command_handler_shutdown = command_handler.shutdown_handle()?; - let command_handler_join_handle = tokio::spawn(command_handler.run()); - - let settings = BridgeSettings::new()?; - let bridge_controller_join_handle = - tokio::spawn(bridge_controller.run(system_address, device_id, settings)); - - let join_handles = vec![command_handler_join_handle, bridge_controller_join_handle]; - let shutdown_handle = SidecarShutdownHandle::new(command_handler_shutdown); - - Ok(Some(SidecarManager::new(join_handles, shutdown_handle))) -} - -async fn server_certificate_renewal(renew_at: DateTime) { - let delay = renew_at - Utc::now(); - if delay > Duration::zero() { - info!( - "scheduled server certificate renewal timer for {}", - renew_at - ); - let delay = delay.to_std().expect("duration must not be negative"); - time::delay_for(delay).await; - - info!("restarting the broker to perform certificate renewal"); - } else { - warn!("server certificate expired at {}", renew_at); - } -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerCertificateLoadError { - #[error("unable to load server certificate from file {0} and private key {1}")] - File(PathBuf, PathBuf), - - #[error("unable to download certificate from edgelet")] - Edgelet, -} - -pub const WORKLOAD_URI: &str = "IOTEDGE_WORKLOADURI"; -pub const EDGE_DEVICE_HOST_NAME: &str = "EdgeDeviceHostName"; -pub const MODULE_ID: &str = "IOTEDGE_MODULEID"; -pub const MODULE_GENERATION_ID: &str = "IOTEDGE_MODULEGENERATIONID"; - -pub const CERTIFICATE_VALIDITY_DAYS: i64 = 90; - -async fn 
download_server_certificate() -> Result { - let uri = env::var(WORKLOAD_URI).context(WORKLOAD_URI)?; - let hostname = env::var(EDGE_DEVICE_HOST_NAME).context(EDGE_DEVICE_HOST_NAME)?; - let module_id = env::var(MODULE_ID).context(MODULE_GENERATION_ID)?; - let generation_id = env::var(MODULE_GENERATION_ID).context(MODULE_GENERATION_ID)?; - let expiration = Utc::now() + Duration::days(CERTIFICATE_VALIDITY_DAYS); - - let client = edgelet_client::workload(&uri)?; - let cert = client - .create_server_cert(&module_id, &generation_id, &hostname, expiration) - .await?; - - if cert.private_key().type_() != "key" { - bail!( - "unknown type of private key: {}", - cert.private_key().type_() - ); - } - - if let Some(private_key) = cert.private_key().bytes() { - let identity = ServerCertificate::from_pem_pair(cert.certificate(), private_key)?; - Ok(identity) - } else { - bail!("missing private key"); - } -} - -#[cfg(test)] -mod tests { - use std::{ - env, - time::{Duration as StdDuration, Instant}, - }; - - use chrono::{Duration, Utc}; - use mockito::mock; - use serde_json::json; - - use super::{ - download_server_certificate, server_certificate_renewal, EDGE_DEVICE_HOST_NAME, - MODULE_GENERATION_ID, MODULE_ID, WORKLOAD_URI, - }; - - const PRIVATE_KEY: &str = include_str!("../../../../mqtt-broker/test/tls/pkey.pem"); - - const CERTIFICATE: &str = include_str!("../../../../mqtt-broker/test/tls/cert.pem"); - - #[tokio::test] - async fn it_downloads_server_cert() { - let expiration = Utc::now() + Duration::days(90); - let res = json!( - { - "privateKey": { "type": "key", "bytes": PRIVATE_KEY }, - "certificate": CERTIFICATE, - "expiration": expiration.to_rfc3339() - } - ); - - let _m = mock( - "POST", - "/modules/$edgeHub/genid/12345678/certificate/server?api-version=2019-01-30", - ) - .with_status(201) - .with_body(serde_json::to_string(&res).unwrap()) - .create(); - - env::set_var(WORKLOAD_URI, mockito::server_url()); - env::set_var(EDGE_DEVICE_HOST_NAME, "localhost"); - env::set_var(MODULE_ID, "$edgeHub"); - env::set_var(MODULE_GENERATION_ID, "12345678"); - - let res = download_server_certificate().await; - assert!(res.is_ok()); - } - - #[tokio::test] - async fn it_schedules_cert_renewal_in_future() { - let now = Instant::now(); - - let renew_at = Utc::now() + Duration::milliseconds(100); - server_certificate_renewal(renew_at).await; - - let elapsed = now.elapsed(); - assert!(elapsed > StdDuration::from_millis(100)); - assert!(elapsed < StdDuration::from_millis(500)); - } - - #[tokio::test] - async fn it_does_not_schedule_cert_renewal_in_past() { - let now = Instant::now(); - - let renew_at = Utc::now(); - server_certificate_renewal(renew_at).await; - - assert!(now.elapsed() < StdDuration::from_millis(100)); - } -} diff --git a/mqtt/mqttd/src/broker/bootstrap/generic.rs b/mqtt/mqttd/src/broker/bootstrap/generic.rs deleted file mode 100644 index fe2151b4cee..00000000000 --- a/mqtt/mqttd/src/broker/bootstrap/generic.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::{ - future::Future, - path::{Path, PathBuf}, -}; - -use anyhow::{Context, Result}; -use futures_util::pin_mut; -use thiserror::Error; -use tracing::info; - -use super::SidecarManager; -use mqtt_broker::{ - auth::{authenticate_fn_ok, AllowAll, Authorizer}, - settings::BrokerConfig, - AuthId, Broker, BrokerBuilder, BrokerHandle, BrokerReady, BrokerSnapshot, Error, Server, - ServerCertificate, -}; -use mqtt_generic::settings::{CertificateConfig, ListenerConfig, Settings}; - -pub fn config
<P>(config_path: Option<P>
) -> Result -where - P: AsRef, -{ - let config = if let Some(path) = config_path { - info!("loading settings from a file {}", path.as_ref().display()); - Settings::from_file(path)? - } else { - info!("using default settings"); - Settings::default() - }; - - Ok(config) -} - -pub async fn broker( - config: &BrokerConfig, - state: Option, - _: &BrokerReady, -) -> Result, Error> { - let broker = BrokerBuilder::default() - .with_authorizer(AllowAll) - .with_state(state.unwrap_or_default()) - .with_config(config.clone()) - .build(); - - Ok(broker) -} - -pub async fn start_server( - config: Settings, - broker: Broker, - shutdown_signal: F, - _: BrokerReady, -) -> Result -where - Z: Authorizer + Send + 'static, - F: Future, -{ - info!("starting server..."); - let mut server = Server::from_broker(broker); - - if let Some(tcp) = config.listener().tcp() { - let authenticator = authenticate_fn_ok(|_| Some(AuthId::Anonymous)); - server.with_tcp(tcp.addr(), authenticator, None)?; - } - - if let Some(tls) = config.listener().tls() { - let authenticator = authenticate_fn_ok(|_| Some(AuthId::Anonymous)); - let identity = load_server_certificate(tls.certificate())?; - server.with_tls(tls.addr(), identity, authenticator, None)?; - } - - pin_mut!(shutdown_signal); - let state = server.serve(shutdown_signal).await?; - Ok(state) -} - -// There are currently no sidecars for the generic feature flag, so this is empty -#[derive(Debug, Error)] -pub enum SidecarError {} - -#[derive(Clone, Debug)] -pub struct SidecarShutdownHandle; - -// There are currently no sidecars for the generic feature flag, so this is a no-op -impl SidecarShutdownHandle { - pub async fn shutdown(self) -> Result<(), SidecarError> { - info!("no sidecars to stop"); - Ok(()) - } -} - -pub async fn start_sidecars(_: BrokerHandle, _: ListenerConfig) -> Result> { - info!("no sidecars to start"); - Ok(None) -} - -fn load_server_certificate(config: &CertificateConfig) -> Result { - let identity = ServerCertificate::from_pem(config.cert_path(), config.private_key_path()) - .with_context(|| { - ServerCertificateLoadError::ParseCertificate( - config.cert_path().to_path_buf(), - config.private_key_path().to_path_buf(), - ) - })?; - - Ok(identity) -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerCertificateLoadError { - #[error("unable to decode server certificate {0} and private key {1}")] - ParseCertificate(PathBuf, PathBuf), -} diff --git a/mqtt/mqttd/src/broker/bootstrap/mod.rs b/mqtt/mqttd/src/broker/bootstrap/mod.rs deleted file mode 100644 index 1b42775cdee..00000000000 --- a/mqtt/mqttd/src/broker/bootstrap/mod.rs +++ /dev/null @@ -1,61 +0,0 @@ -use futures_util::future::select_all; -use tokio::task::JoinHandle; - -use tracing::{error, info}; - -#[cfg(feature = "edgehub")] -mod edgehub; - -#[cfg(feature = "edgehub")] -pub use edgehub::{ - broker, config, start_server, start_sidecars, SidecarError, SidecarShutdownHandle, -}; - -#[cfg(all(not(feature = "edgehub"), feature = "generic"))] -mod generic; - -#[cfg(all(not(feature = "edgehub"), feature = "generic"))] -pub use generic::{ - broker, config, start_server, start_sidecars, SidecarError, SidecarShutdownHandle, -}; - -/// Wraps join handles for sidecar processes and exposes single future -/// Exposed future will wait for any sidecar to complete, then shut down the rest -/// Also exposes shutdown handle used to shut down all the sidecars -pub struct SidecarManager { - join_handles: Vec>, - shutdown_handle: SidecarShutdownHandle, -} - -impl SidecarManager { - #![allow(dead_code)] // needed 
because we have no sidecars for the generic feature - pub fn new(join_handles: Vec>, shutdown_handle: SidecarShutdownHandle) -> Self { - Self { - join_handles, - shutdown_handle, - } - } - - pub async fn wait_for_shutdown(self) -> Result<(), SidecarError> { - let (sidecar_output, _, other_handles) = select_all(self.join_handles).await; - - info!("A sidecar has stopped. Shutting down sidecars..."); - if let Err(e) = sidecar_output { - error!(message = "failed waiting for sidecar shutdown", err = %e); - } - - self.shutdown_handle.shutdown().await?; - - for handle in other_handles { - if let Err(e) = handle.await { - error!(message = "failed waiting for sidecar shutdown", err = %e); - } - } - - Ok(()) - } - - pub fn shutdown_handle(&self) -> SidecarShutdownHandle { - self.shutdown_handle.clone() - } -} diff --git a/mqtt/mqttd/src/broker/bootstrap/server.rs b/mqtt/mqttd/src/broker/bootstrap/server.rs deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mqtt/mqttd/src/broker/mod.rs b/mqtt/mqttd/src/broker/mod.rs deleted file mode 100644 index 873831879eb..00000000000 --- a/mqtt/mqttd/src/broker/mod.rs +++ /dev/null @@ -1,104 +0,0 @@ -mod bootstrap; -mod shutdown; -mod snapshot; - -use std::{fs, path::Path}; - -use anyhow::{Context, Result}; -use futures_util::{ - future::{select, Either}, - pin_mut, -}; -use tracing::{error, info}; - -use mqtt_broker::{BrokerReady, FilePersistor, Message, Persist, SystemEvent, VersionedFileFormat}; - -use crate::broker::snapshot::start_snapshotter; - -pub async fn run
<P>(config_path: Option<P>
) -> Result<()> -where - P: AsRef, -{ - let settings = bootstrap::config(config_path).context(LoadConfigurationError)?; - let listener_settings = settings.listener().clone(); - - info!("loading state..."); - let persistence_config = settings.broker().persistence(); - let state_dir = persistence_config.file_path(); - - fs::create_dir_all(state_dir.clone())?; - let mut persistor = FilePersistor::new(state_dir, VersionedFileFormat::default()); - let state = persistor.load().await?; - info!("state loaded."); - - let broker_ready = BrokerReady::new(); - - let broker = bootstrap::broker(settings.broker(), state, &broker_ready).await?; - let mut broker_handle = broker.handle(); - - let snapshot_interval = persistence_config.time_interval(); - let (mut snapshotter_shutdown_handle, snapshotter_join_handle) = - start_snapshotter(broker.handle(), persistor, snapshot_interval).await; - - let shutdown_signal = shutdown::shutdown(); - - // start broker - let server_join_handle = tokio::spawn(bootstrap::start_server( - settings, - broker, - shutdown_signal, - broker_ready, - )); - - // start sidecars if they should run - // if not wait for server shutdown - let state = if let Some(sidecar_manager) = - bootstrap::start_sidecars(broker_handle.clone(), listener_settings).await? - { - // wait on future for sidecars or broker - // if one of them exits then shut the other down - let sidecar_shutdown_handle = sidecar_manager.shutdown_handle(); - let sidecars_fut = sidecar_manager.wait_for_shutdown(); - pin_mut!(sidecars_fut); - match select(server_join_handle, sidecars_fut).await { - // server finished first - Either::Left((server_output, sidecars_fut)) => { - // shutdown sidecars - sidecar_shutdown_handle.shutdown().await?; - - // wait for sidecars to finish - if let Err(e) = sidecars_fut.await { - error!(message = "failed running sidecars", err = %e) - } - - // extract state from server - server_output - } - // sidecars finished first - Either::Right((_, server_join_handle)) => { - // signal server and sidecars shutdown - broker_handle.send(Message::System(SystemEvent::Shutdown))?; - - // extract state from server - server_join_handle.await - } - } - } else { - server_join_handle.await - }??; - - snapshotter_shutdown_handle.shutdown().await?; - let mut persistor = snapshotter_join_handle.await?; - info!("state snapshotter shutdown."); - - info!("persisting state before exiting..."); - persistor.store(state).await?; - info!("state persisted."); - - info!("exiting... 
goodbye"); - Ok(()) -} - -#[derive(Debug, thiserror::Error)] -#[error("An error occurred loading configuration.")] -pub struct LoadConfigurationError; diff --git a/mqtt/mqttd/src/lib.rs b/mqtt/mqttd/src/lib.rs index 710f3c15945..0d8456982c1 100644 --- a/mqtt/mqttd/src/lib.rs +++ b/mqtt/mqttd/src/lib.rs @@ -11,5 +11,6 @@ clippy::missing_errors_doc )] -pub mod broker; +pub mod app; +pub mod time; pub mod tracing; diff --git a/mqtt/mqttd/src/main.rs b/mqtt/mqttd/src/main.rs index f9c05b385cf..581b82c8017 100644 --- a/mqtt/mqttd/src/main.rs +++ b/mqtt/mqttd/src/main.rs @@ -4,7 +4,7 @@ use std::{env, path::PathBuf}; use anyhow::Result; use clap::{crate_description, crate_name, crate_version, App, Arg}; -use mqttd::{broker, tracing}; +use mqttd::{app, tracing}; #[tokio::main] async fn main() -> Result<()> { @@ -15,7 +15,13 @@ async fn main() -> Result<()> { .value_of("config") .map(PathBuf::from); - broker::run(config_path).await?; + let mut app = app::new(); + if let Some(config_path) = config_path { + app.setup(config_path)?; + } + + app.run().await?; + Ok(()) } diff --git a/mqtt/mqttd/src/time.rs b/mqtt/mqttd/src/time.rs new file mode 100644 index 00000000000..47383481f08 --- /dev/null +++ b/mqtt/mqttd/src/time.rs @@ -0,0 +1,57 @@ +//! `tokio::time::delay_for` and `tokio::time::delay_until` has a bug that +//! prevents to schedule a very long timeout (more than 2 years). +//! To mitigate this issue we split a long interval by the number +//! of small intervals and schedule small intervals and waits a small one until +//! we reach a target. + +use std::{ + cmp, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ready, FutureExt}; +use tokio::time::{self, Delay, Duration, Instant}; + +const DEFAULT_DURATION: Duration = Duration::from_secs(30 * 24 * 60 * 60); // 30 days + +/// Waits until `duration` has elapsed. +pub fn sleep(duration: Duration) -> Sleep { + sleep_until(Instant::now() + duration) +} + +/// Waits until `deadline` is reached. 
+pub fn sleep_until(deadline: Instant) -> Sleep { + Sleep { + deadline, + delay: next_delay(deadline, DEFAULT_DURATION), + } +} + +/// A future returned by `sleep` and `sleep_until` +pub struct Sleep { + deadline: Instant, + delay: Delay, +} + +impl Future for Sleep { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + ready!(self.delay.poll_unpin(cx)); + + if Instant::now() >= self.deadline { + Poll::Ready(()) + } else { + self.delay = next_delay(self.deadline, DEFAULT_DURATION); + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +fn next_delay(deadline: Instant, duration: Duration) -> Delay { + let delay = cmp::min(deadline, Instant::now() + duration); + time::delay_until(delay) +} diff --git a/mqtt/mqttd/src/tracing/edgehub.rs b/mqtt/mqttd/src/tracing/edgehub.rs index f8d8b7eafeb..c653db837fa 100644 --- a/mqtt/mqttd/src/tracing/edgehub.rs +++ b/mqtt/mqttd/src/tracing/edgehub.rs @@ -1,23 +1,28 @@ -use std::io; +use std::env; -use tracing::Level; +use tracing::{log::LevelFilter, Level}; +use tracing_log::LogTracer; use tracing_subscriber::{fmt, EnvFilter}; +use super::Format; + const BROKER_LOG_LEVEL_ENV: &str = "BROKER_LOG"; const EDGE_HUB_LOG_LEVEL_ENV: &str = "RuntimeLogLevel"; pub fn init() { - let log_level = EnvFilter::try_from_env(BROKER_LOG_LEVEL_ENV) - .or_else(|_| EnvFilter::try_from_env(EDGE_HUB_LOG_LEVEL_ENV)) - .or_else(|_| EnvFilter::try_from_default_env()) - .unwrap_or_else(|_| EnvFilter::new("info")); + let log_level = env::var(BROKER_LOG_LEVEL_ENV) + .or_else(|_| env::var(EDGE_HUB_LOG_LEVEL_ENV)) + .or_else(|_| env::var(EnvFilter::DEFAULT_ENV)) + .map_or_else(|_| "info".into(), |level| level.to_lowercase()); let subscriber = fmt::Subscriber::builder() - .with_ansi(atty::is(atty::Stream::Stderr)) .with_max_level(Level::TRACE) - .with_writer(io::stderr) - .with_env_filter(log_level) + .on_event(Format::default()) + .with_env_filter(EnvFilter::new(log_level.clone())) .finish(); let _ = tracing::subscriber::set_global_default(subscriber); + + let filter = log_level.parse().unwrap_or(LevelFilter::Info); + let _ = LogTracer::builder().with_max_level(filter).init(); } diff --git a/mqtt/mqttd/src/tracing/format.rs b/mqtt/mqttd/src/tracing/format.rs new file mode 100644 index 00000000000..cd80380b3fd --- /dev/null +++ b/mqtt/mqttd/src/tracing/format.rs @@ -0,0 +1,117 @@ +use std::marker::PhantomData; + +use tracing::{Event, Level}; +use tracing_log::NormalizeEvent; +use tracing_subscriber::fmt::{ + time::ChronoLocal, time::FormatTime, Context, FormatEvent, NewVisitor, +}; + +/// Marker for `Format` that indicates that the syslog format should be used. +pub(crate) struct EdgeHub; + +/// Custom event formatter. 
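Reviewer note: the new `mqttd::time` module exists because `tokio::time::delay_for`/`delay_until` in tokio 0.2 cannot schedule a very long timeout (more than about two years), which matters for far-off certificate renewal deadlines. `Sleep` therefore re-arms the timer in bounded hops (30 days by default) until the real deadline passes. The same workaround written as a plain async loop; a sketch only, using tokio 1.x names (`sleep_until`) rather than the crate's poll-based future:

```rust
use std::time::Duration;
use tokio::time::Instant;

// Upper bound for a single timer registration, mirroring the module's
// DEFAULT_DURATION of 30 days.
const MAX_HOP: Duration = Duration::from_secs(30 * 24 * 60 * 60);

// Never arm the runtime timer for more than MAX_HOP at once; keep
// re-arming until the requested deadline is actually reached.
async fn sleep_until_chunked(deadline: Instant) {
    loop {
        let now = Instant::now();
        if now >= deadline {
            return;
        }
        // Next hop: the remaining time, capped at MAX_HOP.
        let hop = std::cmp::min(deadline, now + MAX_HOP);
        tokio::time::sleep_until(hop).await;
    }
}

#[tokio::main]
async fn main() {
    // A short sleep completes in a single hop; a 90-day
    // certificate-renewal timer would take several hops.
    sleep_until_chunked(Instant::now() + Duration::from_millis(10)).await;
    println!("done");
}
```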
+pub(crate) struct Format { + format: PhantomData, + timer: T, +} + +impl Default for Format { + fn default() -> Self { + Format { + format: PhantomData, + timer: ChronoLocal::with_format("%F %T.%3f %:z".into()), + } + } +} + +impl FormatEvent for Format +where + N: for<'a> NewVisitor<'a>, + T: FormatTime, +{ + fn format_event( + &self, + ctx: &Context<'_, N>, + writer: &mut dyn std::fmt::Write, + event: &Event<'_>, + ) -> std::fmt::Result { + let normalized_meta = event.normalized_metadata(); + let meta = normalized_meta.as_ref().unwrap_or_else(|| event.metadata()); + + let (fmt_level, fmt_ctx) = (FmtLevel::new(meta.level()), FullCtx::new(&ctx)); + + write!(writer, "<{}> ", fmt_level.syslog_level())?; + self.timer.format_time(writer)?; + write!(writer, "[{}] [{}{}] - ", fmt_level, fmt_ctx, meta.target(),)?; + + { + let mut recorder = ctx.new_visitor(writer, true); + event.record(&mut recorder); + } + + writeln!(writer) + } +} + +/// Wrapper around `Level` to format it accordingly to `syslog` rules. +struct FmtLevel<'a> { + level: &'a Level, +} + +impl<'a> FmtLevel<'a> { + fn new(level: &'a Level) -> Self { + Self { level } + } + + fn syslog_level(&self) -> i8 { + match *self.level { + Level::ERROR => 3, + Level::WARN => 4, + Level::INFO => 6, + Level::DEBUG | Level::TRACE => 7, + } + } +} + +impl<'a> std::fmt::Display for FmtLevel<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self.level { + Level::TRACE => f.pad("TRC"), + Level::DEBUG => f.pad("DBG"), + Level::INFO => f.pad("INF"), + Level::WARN => f.pad("WRN"), + Level::ERROR => f.pad("ERR"), + } + } +} + +/// Wrapper around log entry context to format entry. +struct FullCtx<'a, N> { + ctx: &'a Context<'a, N>, +} + +impl<'a, N: 'a> FullCtx<'a, N> { + fn new(ctx: &'a Context<'a, N>) -> Self { + Self { ctx } + } +} + +impl<'a, N> std::fmt::Display for FullCtx<'a, N> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut seen = false; + self.ctx.visit_spans(|_, span| { + write!(f, "{}", span.name())?; + seen = true; + + let fields = span.fields(); + if !fields.is_empty() { + write!(f, "{{{}}}", fields)?; + } + ":".fmt(f) + })?; + if seen { + f.pad(" ")?; + } + Ok(()) + } +} diff --git a/mqtt/mqttd/src/tracing/generic.rs b/mqtt/mqttd/src/tracing/generic.rs index 5d1d50378df..393d8ac8170 100644 --- a/mqtt/mqttd/src/tracing/generic.rs +++ b/mqtt/mqttd/src/tracing/generic.rs @@ -1,7 +1,7 @@ -use std::io; - use tracing::Level; -use tracing_subscriber::{fmt, EnvFilter}; +use tracing_subscriber::{fmt::Subscriber, EnvFilter}; + +use super::Format; const BROKER_LOG_LEVEL_ENV: &str = "BROKER_LOG"; @@ -10,11 +10,10 @@ pub fn init() { .or_else(|_| EnvFilter::try_from_default_env()) .unwrap_or_else(|_| EnvFilter::new("info")); - let subscriber = fmt::Subscriber::builder() - .with_ansi(atty::is(atty::Stream::Stderr)) + let subscriber = Subscriber::builder() .with_max_level(Level::TRACE) - .with_writer(io::stderr) .with_env_filter(log_level) + .on_event(Format::default()) .finish(); let _ = tracing::subscriber::set_global_default(subscriber); } diff --git a/mqtt/mqttd/src/tracing/mod.rs b/mqtt/mqttd/src/tracing/mod.rs index 46b65944b91..07457d391a5 100644 --- a/mqtt/mqttd/src/tracing/mod.rs +++ b/mqtt/mqttd/src/tracing/mod.rs @@ -9,3 +9,6 @@ mod generic; #[cfg(all(not(feature = "edgehub"), feature = "generic"))] pub use generic::init; + +mod format; +use format::Format; diff --git a/mqtt/policy/Cargo.toml b/mqtt/policy/Cargo.toml index c08edff3a58..74c3b67ca13 100644 --- 
a/mqtt/policy/Cargo.toml +++ b/mqtt/policy/Cargo.toml @@ -7,6 +7,7 @@ version = "0.1.0" [dependencies] lazy_static = "1" +proptest = { version = "0.9", optional = true } regex = "1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -14,3 +15,4 @@ thiserror = "1.0" [dev-dependencies] matches = "0.1" +itertools = "0.9" \ No newline at end of file diff --git a/mqtt/policy/src/core/builder.rs b/mqtt/policy/src/core/builder.rs index 39f22078923..a9f67a6d370 100644 --- a/mqtt/policy/src/core/builder.rs +++ b/mqtt/policy/src/core/builder.rs @@ -19,12 +19,12 @@ pub struct PolicyBuilder { validator: V, matcher: M, substituter: S, - json: String, + source: Source, default_decision: Decision, } impl PolicyBuilder { - /// Constructs a `PolicyBuilder` from provided json policy definition and + /// Constructs a `PolicyBuilder` from provided json policy definition, with /// default configuration. /// /// Call to this method does not parse or validate the json, all heavy work @@ -33,7 +33,24 @@ impl PolicyBuilder json: impl Into, ) -> PolicyBuilder { PolicyBuilder { - json: json.into(), + source: Source::Json(json.into()), + validator: DefaultValidator, + matcher: DefaultResourceMatcher, + substituter: DefaultSubstituter, + default_decision: Decision::Denied, + } + } + + /// Constructs a `PolicyBuilder` from provided policy definition struct, with + /// default configuration. + /// + /// Call to this method does not validate the definition, all heavy work + /// is done in `build` method. + pub fn from_definition( + definition: PolicyDefinition, + ) -> PolicyBuilder { + PolicyBuilder { + source: Source::Definition(definition), validator: DefaultValidator, matcher: DefaultResourceMatcher, substituter: DefaultSubstituter, @@ -52,7 +69,7 @@ where /// Specifies the `PolicyValidator` to validate the policy definition. pub fn with_validator(self, validator: V1) -> PolicyBuilder { PolicyBuilder { - json: self.json, + source: self.source, validator, matcher: self.matcher, substituter: self.substituter, @@ -63,7 +80,7 @@ where /// Specifies the `ResourceMatcher` to use with `Policy`. pub fn with_matcher(self, matcher: M1) -> PolicyBuilder { PolicyBuilder { - json: self.json, + source: self.source, validator: self.validator, matcher, substituter: self.substituter, @@ -74,7 +91,7 @@ where /// Specifies the `Substituter` to use with `Policy`. pub fn with_substituter(self, substituter: S1) -> PolicyBuilder { PolicyBuilder { - json: self.json, + source: self.source, validator: self.validator, matcher: self.matcher, substituter, @@ -96,13 +113,26 @@ where /// /// Any validation errors are collected and returned as `Error::ValidationSummary`. 
pub fn build(self) -> Result> { - let mut definition: PolicyDefinition = PolicyDefinition::from_json(&self.json)?; + let PolicyBuilder { + validator, + matcher, + substituter, + source, + default_decision, + } = self; + + let mut definition: PolicyDefinition = match source { + Source::Json(json) => PolicyDefinition::from_json(&json)?, + Source::Definition(definition) => definition, + }; for (order, mut statement) in definition.statements.iter_mut().enumerate() { statement.order = order; } - self.validate(&definition)?; + validator + .validate(&definition) + .map_err(|e| Error::Validation(e.into()))?; let mut static_rules = Identities::new(); let mut variable_rules = Identities::new(); @@ -112,19 +142,13 @@ where } Ok(Policy { - default_decision: self.default_decision, - resource_matcher: self.matcher, - substituter: self.substituter, + default_decision, + resource_matcher: matcher, + substituter, static_rules: static_rules.0, variable_rules: variable_rules.0, }) } - - fn validate(&self, definition: &PolicyDefinition) -> Result<()> { - self.validator - .validate(definition) - .map_err(|e| Error::Validation(e.into())) - } } fn process_statement( @@ -215,11 +239,16 @@ fn is_variable_rule(value: &str) -> bool { VAR_PATTERN.is_match(value) } +enum Source { + Json(String), + Definition(PolicyDefinition), +} + /// Represents a deserialized policy definition. -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PolicyDefinition { - statements: Vec, + pub(super) statements: Vec, } impl PolicyDefinition { @@ -236,18 +265,18 @@ impl PolicyDefinition { } /// Represents a statement in a policy definition. -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Statement { #[serde(default)] - order: usize, + pub(super) order: usize, #[serde(default)] - description: String, - effect: Effect, - identities: Vec, - operations: Vec, + pub(super) description: String, + pub(super) effect: Effect, + pub(super) identities: Vec, + pub(super) operations: Vec, #[serde(default)] - resources: Vec, + pub(super) resources: Vec, } impl Statement { @@ -277,7 +306,7 @@ impl Statement { } /// Represents an effect on a statement. 
-#[derive(Deserialize, Copy, Clone)] +#[derive(Debug, Deserialize, Copy, Clone)] #[serde(rename_all = "camelCase")] pub enum Effect { Allow, @@ -291,7 +320,7 @@ mod tests { use matches::assert_matches; use crate::{ - core::{tests::build_policy, Effect as CoreEffect, EffectOrd}, + core::{tests::build_policy, EffectImpl, EffectOrd}, validator::ValidatorError, }; @@ -544,7 +573,7 @@ mod tests { assert_eq!( EffectOrd { order: 0, - effect: CoreEffect::Allow + effect: EffectImpl::Allow }, policy.static_rules["actor_a"].0["write"].0["events/telemetry"] ); @@ -553,7 +582,7 @@ mod tests { assert_eq!( EffectOrd { order: 2, - effect: CoreEffect::Allow + effect: EffectImpl::Allow }, policy.variable_rules["actor_a"].0["read"].0["{{variable}}/#"] ); @@ -591,28 +620,28 @@ mod tests { assert_eq!( policy.static_rules["actor_a"].0["write"].0["events/telemetry"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.static_rules["actor_a"].0["read"].0["events/telemetry"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.static_rules["actor_b"].0["write"].0["events/telemetry"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.static_rules["actor_b"].0["read"].0["events/telemetry"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); @@ -622,42 +651,42 @@ mod tests { assert_eq!( policy.variable_rules["actor_a"].0["write"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.variable_rules["actor_a"].0["read"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.variable_rules["actor_b"].0["write"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.variable_rules["actor_b"].0["read"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.variable_rules["{{var_actor}}"].0["write"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); assert_eq!( policy.variable_rules["{{var_actor}}"].0["read"].0["devices/{{variable}}/#"], EffectOrd { - effect: CoreEffect::Allow, + effect: EffectImpl::Allow, order: 0 } ); diff --git a/mqtt/policy/src/core/mod.rs b/mqtt/policy/src/core/mod.rs index a90d4f41140..3f753a787ae 100644 --- a/mqtt/policy/src/core/mod.rs +++ b/mqtt/policy/src/core/mod.rs @@ -7,7 +7,7 @@ use crate::errors::Result; use crate::{substituter::Substituter, Error, ResourceMatcher}; mod builder; -pub use builder::{PolicyBuilder, PolicyDefinition, Statement}; +pub use builder::{Effect, PolicyBuilder, PolicyDefinition, Statement}; /// Policy engine. Represents a read-only set of rules and can /// evaluate `Request` based on those rules. @@ -38,12 +38,12 @@ where match self.eval_static_rules(request) { // static rules not defined. Need to check variable rules. Ok(EffectOrd { - effect: Effect::Undefined, + effect: EffectImpl::Undefined, .. }) => match self.eval_variable_rules(request) { // variable rules undefined as well. Return default decision. Ok(EffectOrd { - effect: Effect::Undefined, + effect: EffectImpl::Undefined, .. }) => Ok(self.default_decision), // variable rules defined. Return the decision. 
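Taken together, the `evaluate` hunks above and below implement a small decision cascade. A condensed, self-contained sketch of that control flow (the type and function names here are illustrative stand-ins, not the crate's private API):

```rust
// Condensed sketch of the cascade in `Policy::evaluate`.
#[derive(Clone, Copy, PartialEq)]
enum Eff {
    Allow,
    Deny,
    Undefined,
}

#[derive(Clone, Copy)]
struct EffectOrd2 {
    order: usize,
    effect: Eff,
}

fn evaluate(static_rule: EffectOrd2, variable_rule: EffectOrd2, default_allow: bool) -> bool {
    let winner = match (static_rule.effect, variable_rule.effect) {
        // Neither kind of rule matched: fall back to the default decision.
        (Eff::Undefined, Eff::Undefined) => return default_allow,
        // Only one kind matched: it decides.
        (_, Eff::Undefined) => static_rule,
        (Eff::Undefined, _) => variable_rule,
        // Both matched: the smaller `order` (earlier statement) wins.
        _ if variable_rule.order < static_rule.order => variable_rule,
        _ => static_rule,
    };
    winner.effect == Eff::Allow
}
```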
@@ -55,7 +55,7 @@ where
         match self.eval_variable_rules(request) {
             // variable rules undefined. Proceed with static rule decision.
             Ok(EffectOrd {
-                effect: Effect::Undefined,
+                effect: EffectImpl::Undefined,
                 ..
             }) => Ok(static_effect.into()),
             // variable rules defined. Compare priority and return.
@@ -83,8 +83,8 @@ where
             Some(operations) => match operations.0.get(&request.operation) {
                 // operation exists.
                 Some(resources) => {
-                    // Iterate over and match resources.
-                    // We need to go through all resources and find one with highest priority (smallest order).
+                    // iterate over and match resources.
+                    // we need to go through all resources and find one with highest priority (smallest order).
                     let mut result = &EffectOrd::undefined();
                     for (resource, effect) in &resources.0 {
                         if effect.order < result.order // check the order
@@ -115,8 +115,8 @@ where
                 return match operations.0.get(&request.operation) {
                     // operation exists.
                     Some(resources) => {
-                        // Iterate over and match resources.
-                        // We need to go through all resources and find one with highest priority (smallest order).
+                        // iterate over and match resources.
+                        // we need to go through all resources and find one with highest priority (smallest order).
                         let mut result = &EffectOrd::undefined();
                         for (resource, effect) in &resources.0 {
                             let resource = self.substituter.visit_resource(resource, request)?;
@@ -130,9 +130,16 @@ where
                                 result = effect;
                             }
                         }
+                        // continue to look for other identity variable rules
+                        // if no resources matched the current one.
+                        if result == &EffectOrd::undefined() {
+                            continue;
+                        }
                         Ok(*result)
                     }
-                    None => Ok(EffectOrd::undefined()),
+                    // continue to look for other identity variable rules
+                    // if no operation is found for the current one.
+                    None => continue,
                 };
             }
         }
@@ -297,18 +304,8 @@ pub enum Decision {
     Denied,
 }

-impl From<Effect> for Decision {
-    fn from(effect: Effect) -> Self {
-        match effect {
-            Effect::Allow => Decision::Allowed,
-            Effect::Deny => Decision::Denied,
-            Effect::Undefined => Decision::Denied,
-        }
-    }
-}
-
 #[derive(Debug, Copy, Clone, PartialOrd, PartialEq)]
-enum Effect {
+enum EffectImpl {
     Allow,
     Deny,
     Undefined,
@@ -317,18 +314,18 @@ enum Effect {
 #[derive(Debug, Copy, Clone, PartialEq)]
 struct EffectOrd {
     order: usize,
-    effect: Effect,
+    effect: EffectImpl,
 }

 impl EffectOrd {
-    pub fn new(effect: Effect, order: usize) -> Self {
+    pub fn new(effect: EffectImpl, order: usize) -> Self {
         Self { order, effect }
     }

     pub fn undefined() -> Self {
         Self {
             order: usize::MAX,
-            effect: Effect::Undefined,
+            effect: EffectImpl::Undefined,
         }
     }
@@ -351,9 +348,9 @@ impl PartialOrd for EffectOrd {
 impl From<EffectOrd> for Decision {
     fn from(effect: EffectOrd) -> Self {
         match effect.effect {
-            Effect::Allow => Decision::Allowed,
-            Effect::Deny => Decision::Denied,
-            Effect::Undefined => Decision::Denied,
+            EffectImpl::Allow => Decision::Allowed,
+            EffectImpl::Deny => Decision::Denied,
+            EffectImpl::Undefined => Decision::Denied,
         }
     }
 }
@@ -361,8 +358,8 @@ impl From<EffectOrd> for Decision {
 impl From<&Statement> for EffectOrd {
     fn from(statement: &Statement) -> Self {
         match statement.effect() {
-            builder::Effect::Allow => EffectOrd::new(Effect::Allow, statement.order()),
-            builder::Effect::Deny => EffectOrd::new(Effect::Deny, statement.order()),
+            builder::Effect::Allow => EffectOrd::new(EffectImpl::Allow, statement.order()),
+            builder::Effect::Deny => EffectOrd::new(EffectImpl::Deny, statement.order()),
         }
     }
 }
@@ -433,7 +430,7 @@
             {
                 "effect": "allow",
                 "identities": [
-                    "contoso.azure-devices.net/some_device"
+                    "actor_a"
                 ],
"operations": [ "write" @@ -525,7 +522,7 @@ pub(crate) mod tests { let policy = PolicyBuilder::from_json(json) .with_default_decision(Decision::Denied) - .with_substituter(TestSubstituter) + .with_substituter(TestIdentitySubstituter) .build() .expect("Unable to build policy from json."); @@ -592,7 +589,7 @@ pub(crate) mod tests { let policy = PolicyBuilder::from_json(json) .with_default_decision(Decision::Denied) - .with_substituter(TestSubstituter) + .with_substituter(TestIdentitySubstituter) .build() .expect("Unable to build policy from json."); @@ -647,7 +644,7 @@ pub(crate) mod tests { let policy = PolicyBuilder::from_json(json) .with_default_decision(Decision::Denied) - .with_substituter(TestSubstituter) + .with_substituter(TestIdentitySubstituter) .with_matcher(StartWithMatcher) .build() .expect("Unable to build policy from json."); @@ -692,7 +689,7 @@ pub(crate) mod tests { let policy = PolicyBuilder::from_json(json) .with_default_decision(Decision::Denied) - .with_substituter(TestSubstituter) + .with_substituter(TestIdentitySubstituter) .with_matcher(StartWithMatcher) .build() .expect("Unable to build policy from json."); @@ -702,25 +699,131 @@ pub(crate) mod tests { assert_matches!(policy.evaluate(&request), Ok(Decision::Allowed)); } - /// `TestSubstituter` replaces any value with the corresponding identity or resource + /// Scenario: + /// - Have a policy with a custom identity matcher + /// - Have two variable rules (deny and allow) for an identity, such that + /// both rules match a given request identity. + /// - But the two rules must be different in resources. + /// - Make a request to the allowed resource. + /// - The deny rule resources do not match the request. + /// - The allow rule resources do match the request. + /// - Expected: request allowed. + /// + /// This case is created as a result of a discovered bug. + #[test] + fn all_identity_variable_rules_must_be_evaluated_resources_do_not_match() { + let json = r###"{ + "schemaVersion": "2020-10-30", + "statements": [ + { + "effect": "deny", + "identities": [ + "{{any}}" + ], + "operations": [ + "write" + ], + "resources": [ + "hello/b" + ] + }, + { + "effect": "allow", + "identities": [ + "{{identity}}" + ], + "operations": [ + "write" + ], + "resources": [ + "hello/a" + ] + } + ] + }"###; + + let policy = PolicyBuilder::from_json(json) + .with_default_decision(Decision::Denied) + .with_substituter(TestIdentitySubstituter) + .with_matcher(DefaultResourceMatcher) + .build() + .expect("Unable to build policy from json."); + + let request = Request::new("actor_a", "write", "hello/a").unwrap(); + + assert_matches!(policy.evaluate(&request), Ok(Decision::Allowed)); + } + + /// Scenario: + /// - The same as test case above, + /// but statement operations are different. + /// + /// This case is created as a result of a discovered bug. 
+    #[test]
+    fn all_identity_variable_rules_must_be_evaluated_operations_do_not_match() {
+        let json = r###"{
+            "schemaVersion": "2020-10-30",
+            "statements": [
+                {
+                    "effect": "deny",
+                    "identities": [
+                        "{{any}}"
+                    ],
+                    "operations": [
+                        "read"
+                    ],
+                    "resources": [
+                        "hello/b"
+                    ]
+                },
+                {
+                    "effect": "allow",
+                    "identities": [
+                        "{{identity}}"
+                    ],
+                    "operations": [
+                        "write"
+                    ],
+                    "resources": [
+                        "hello/a"
+                    ]
+                }
+            ]
+        }"###;
+
+        let policy = PolicyBuilder::from_json(json)
+            .with_default_decision(Decision::Denied)
+            .with_substituter(TestIdentitySubstituter)
+            .with_matcher(DefaultResourceMatcher)
+            .build()
+            .expect("Unable to build policy from json.");
+
+        let request = Request::new("actor_a", "write", "hello/a").unwrap();
+
+        assert_matches!(policy.evaluate(&request), Ok(Decision::Allowed));
+    }
+
+    /// `TestIdentitySubstituter` replaces any value with the corresponding identity
     /// from the request, thus making the variable rule always match the request.
-    struct TestSubstituter;
+    #[derive(Debug)]
+    struct TestIdentitySubstituter;

-    impl Substituter for TestSubstituter {
+    impl Substituter for TestIdentitySubstituter {
         type Context = ();

         fn visit_identity(&self, _value: &str, context: &Request<Self::Context>) -> Result<String> {
             Ok(context.identity.clone())
         }

-        fn visit_resource(&self, _value: &str, context: &Request<Self::Context>) -> Result<String> {
-            Ok(context.resource.clone())
+        fn visit_resource(&self, value: &str, _context: &Request<Self::Context>) -> Result<String> {
+            Ok(value.into())
         }
     }

     /// `StartWithMatcher` matches resources that start with the requested value. For
     /// example, if a policy defines a resource "hello/world", then a request for "hello/"
     /// will match.
+    #[derive(Debug)]
     struct StartWithMatcher;

     impl ResourceMatcher for StartWithMatcher {
@@ -730,4 +833,98 @@ pub(crate) mod tests {
             policy.starts_with(input)
         }
     }
+
+    #[cfg(feature = "proptest")]
+    mod proptests {
+        use crate::{Decision, Effect, PolicyBuilder, PolicyDefinition, Request, Statement};
+        use proptest::{collection::vec, prelude::*};
+
+        proptest! {
+            /// The goal of this test is to verify the following scenarios:
+            /// - PolicyBuilder does not crash.
+            /// - All combinations of identity/operation/resource in a statement in the definition
+            ///   should produce the expected result.
+            /// Since some statements can be overridden by the previous ones,
+            /// we can only safely verify the very first statement.
+            #[test]
+            fn policy_engine_proptest(definition in arb_policy_definition()){
+                use itertools::iproduct;
+
+                // take the very first statement, which should have top priority.
+                let statement = &definition.statements()[0];
+                let expected = match statement.effect() {
+                    Effect::Allow => Decision::Allowed,
+                    Effect::Deny => Decision::Denied,
+                };
+
+                // collect all combos of identity/operation/resource
+                // in the statement.
+                let requests = iproduct!(
+                    statement.identities(),
+                    statement.operations(),
+                    statement.resources()
+                )
+                .map(|item| Request::new(item.0, item.1, item.2).expect("unable to create a request"))
+                .collect::<Vec<_>>();
+
+                let policy = PolicyBuilder::from_definition(definition)
+                    .build()
+                    .expect("unable to build policy from definition");
+
+                // evaluate and assert.
+                for request in requests {
+                    assert_eq!(policy.evaluate(&request).unwrap(), expected);
+                }
+            }
+        }
+
+        prop_compose! {
+            pub fn arb_policy_definition()(
+                statements in vec(arb_statement(), 1..5)
+            ) -> PolicyDefinition {
+                PolicyDefinition {
+                    statements
+                }
+            }
+        }
+
+        prop_compose! {
+            pub fn arb_statement()(
+                description in arb_description(),
+                effect in arb_effect(),
+                identities in vec(arb_identity(), 1..5),
+                operations in vec(arb_operation(), 1..5),
+                resources in vec(arb_resource(), 1..5),
+            ) -> Statement {
+                Statement {
+                    order: 0,
+                    description,
+                    effect,
+                    identities,
+                    operations,
+                    resources,
+                }
+            }
+        }
+
+        pub fn arb_effect() -> impl Strategy<Value = Effect> {
+            prop_oneof![Just(Effect::Allow), Just(Effect::Deny)]
+        }
+
+        pub fn arb_description() -> impl Strategy<Value = String> {
+            "\\PC+"
+        }
+
+        pub fn arb_identity() -> impl Strategy<Value = String> {
+            "(\\PC+)|(\\{\\{\\PC+\\}\\})"
+        }
+
+        pub fn arb_operation() -> impl Strategy<Value = String> {
+            "\\PC+"
+        }
+
+        pub fn arb_resource() -> impl Strategy<Value = String> {
+            "\\PC+(/(\\PC+|\\{\\{\\PC+\\}\\}))*"
+        }
+    }
 }
diff --git a/mqtt/policy/src/lib.rs b/mqtt/policy/src/lib.rs
index a43907427ae..40058f69826 100644
--- a/mqtt/policy/src/lib.rs
+++ b/mqtt/policy/src/lib.rs
@@ -17,7 +17,7 @@ mod matcher;
 mod substituter;
 mod validator;

-pub use crate::core::{Decision, Policy, Request};
+pub use crate::core::{Decision, Effect, Policy, Request};
 pub use crate::core::{PolicyBuilder, PolicyDefinition, Statement};
 pub use crate::errors::{Error, Result};
 pub use crate::matcher::{DefaultResourceMatcher, ResourceMatcher};
diff --git a/scripts/linux/buildAPIProxy.sh b/scripts/linux/buildAPIProxy.sh
index 90b86a4ed7d..282a398b433 100644
--- a/scripts/linux/buildAPIProxy.sh
+++ b/scripts/linux/buildAPIProxy.sh
@@ -151,11 +151,11 @@ build_project() {
     # build project with cross
     if [[ "$ARCH" == "amd64" ]]; then
-        execute scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch "amd64" --build-path edge-modules/api-proxy-module
+        execute scripts/linux/cross-platform-rust-build.sh --os alpine --arch "amd64" --build-path edge-modules/api-proxy-module
     elif [[ "$ARCH" == "arm32v7" ]]; then
-        docker run --rm -t -v "${API_PROXY_DIR}"/../..:/home/rust/src messense/rust-musl-cross:armv7-musleabihf /bin/bash -c " rm -frv ~/.rustup/toolchains/* &&curl -sSLf https://sh.rustup.rs | sh -s -- -y && rustup target add armv7-unknown-linux-musleabihf && cargo build --target=armv7-unknown-linux-musleabihf --release --manifest-path /home/rust/src/edge-modules/api-proxy-module/Cargo.toml"
+        execute scripts/linux/cross-platform-rust-build.sh --os alpine --arch "arm32v7" --build-path edge-modules/api-proxy-module
     elif [[ "$ARCH" == "arm64v8" ]]; then
-        execute scripts/linux/cross-platform-rust-build.sh --os ubuntu18.04 --arch "aarch64" --build-path edge-modules/api-proxy-module
+        execute scripts/linux/cross-platform-rust-build.sh --os alpine --arch "aarch64" --build-path edge-modules/api-proxy-module
     else
         echo "Cannot run script Unsupported architecture $ARCH"
         exit 1
diff --git a/scripts/linux/createArmBase.sh b/scripts/linux/createArmBase.sh
index 9bf13f4183e..64153fb7567 100755
--- a/scripts/linux/createArmBase.sh
+++ b/scripts/linux/createArmBase.sh
@@ -46,8 +46,12 @@ usage()
     echo "Note: You might have to run this as root or sudo."
     echo "Note: This script is only applicable on ARM architectures."
     echo ""
+    echo "Note: When pushing base images, please familiarize yourself with the versioning semantics."
+    echo "Note: All base images in a branch should have the same version."
+    echo "Note: This means that if one base image gets updated, the others should be retagged and pushed as well."
+    echo ""
     echo " -a, --arch Architecture of the Docker image; either 'armv7l' or 'aarch64'.
Defaults to 'uname -m'" - echo " -i, --image-name Image name (azureiotedge-module-base, azureiotedge-agent-base, or azureiotedge-hub-base)" + echo " -i, --image-name Image name (azureiotedge-module-base, azureiotedge-module-base-full, azureiotedge-hub-base, azureiotedge-agent-base, azureiotedge-iotedged-base, or azureiotedge-proxy-base)" echo " -d, --project-dir Project directory (required)." echo " Directory which contains docker/linux/arm32v7/base/Dockerfile or docker/linux/arm64v8/base/Dockerfile" echo " -n, --namespace Docker namespace (default: $DEFAULT_DOCKER_NAMESPACE)" @@ -114,7 +118,7 @@ process_args() fi if [[ "azureiotedge-module-base" != ${DOCKER_IMAGENAME} ]] && [[ "azureiotedge-hub-base" != ${DOCKER_IMAGENAME} ]] && [[ "azureiotedge-agent-base" != ${DOCKER_IMAGENAME} ]] && [[ "azureiotedge-iotedged-base" != ${DOCKER_IMAGENAME} ]] && [[ "azureiotedge-proxy-base" != ${DOCKER_IMAGENAME} ]] && [[ "azureiotedge-module-base-full" != ${DOCKER_IMAGENAME} ]]; then - echo "Docker image name must be one of azureiotedge-module-base, azureiotedge-module-base-full, azureiotedge-hub-base, azureiotedge-agent-base, azureiotedge-iotedged-base or azureiotedge-proxy-base" + echo "Docker image name must be one of azureiotedge-module-base, azureiotedge-module-base-full, azureiotedge-hub-base, azureiotedge-agent-base, azureiotedge-iotedged-base, or azureiotedge-proxy-base" print_help_and_exit fi diff --git a/scripts/linux/cross-platform-rust-build.sh b/scripts/linux/cross-platform-rust-build.sh index c5c7d71947f..ccc553c51d3 100755 --- a/scripts/linux/cross-platform-rust-build.sh +++ b/scripts/linux/cross-platform-rust-build.sh @@ -70,6 +70,10 @@ case "$PACKAGE_OS" in 'ubuntu18.04') DOCKER_IMAGE='ubuntu:18.04' ;; + + 'alpine') + DOCKER_IMAGE='ubuntu:18.04' + ;; esac if [ -z "$DOCKER_IMAGE" ]; then @@ -77,27 +81,9 @@ if [ -z "$DOCKER_IMAGE" ]; then exit 1 fi -case "$PACKAGE_ARCH" in - 'amd64') - RUST_TARGET='x86_64-unknown-linux-musl' - ;; - - 'arm32v7') - RUST_TARGET='armv7-unknown-linux-gnueabihf' - ;; - - 'aarch64') - RUST_TARGET='aarch64-unknown-linux-gnu' - ;; -esac - -if [ -n "$RUST_TARGET" ]; then - RUST_TARGET_COMMAND="rustup target add $RUST_TARGET &&" -fi - - case "$PACKAGE_OS.$PACKAGE_ARCH" in - ubuntu18.04.amd64) + alpine.amd64) + RUST_TARGET='x86_64-unknown-linux-musl' # The below SETUP was copied from https://github.com/emk/rust-musl-builder/blob/master/Dockerfile. SETUP_COMMAND=$' OPENSSL_VERSION=1.1.1g @@ -149,10 +135,19 @@ case "$PACKAGE_OS.$PACKAGE_ARCH" in export PKG_CONFIG_ALL_STATIC=true export LIBZ_SYS_STATIC=1 export TARGET=musl + cd /project/$BUILD_PATH && + echo \'Installing rustup\' && + curl -sSLf https://sh.rustup.rs | sh -s -- -y && + . ~/.cargo/env && ' + MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target x86_64-unknown-linux-musl'" + MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/x86_64-unknown-linux-musl/release'" + MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=strip'" ;; ubuntu18.04.arm32v7) + RUST_TARGET='armv7-unknown-linux-gnueabihf' + SETUP_COMMAND=$' sources="$(cat /etc/apt/sources.list | grep -E \'^[^#]\')" && # Update existing repos to be specifically for amd64 @@ -178,10 +173,105 @@ case "$PACKAGE_OS.$PACKAGE_ARCH" in echo \'linker = \"arm-linux-gnueabihf-gcc\"\' >> ~/.cargo/config && export ARMV7_UNKNOWN_LINUX_GNUEABIHF_OPENSSL_LIB_DIR=/usr/lib/arm-linux-gnueabihf && export ARMV7_UNKNOWN_LINUX_GNUEABIHF_OPENSSL_INCLUDE_DIR=/usr/include && + cd /project/$BUILD_PATH && + echo \'Installing rustup\' && + curl -sSLf https://sh.rustup.rs | sh -s -- -y && + . 
~/.cargo/env && ' + MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target armv7-unknown-linux-gnueabihf'" + MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/armv7-unknown-linux-gnueabihf/release'" + MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=/usr/arm-linux-gnueabihf/bin/strip'" ;; - ubuntu18.04.aarch64) + alpine.arm32v7) + RUST_TARGET='armv7-unknown-linux-musleabihf' + + SETUP_COMMAND=$' + TOOLCHAIN=stable && \ + TARGET=armv7-unknown-linux-musleabihf && \ + OPENSSL_ARCH=linux-generic32 && \ + RUST_MUSL_CROSS_TARGET=$TARGET && \ + + apt-get update && \ + apt-get install -y \ + build-essential \ + cmake \ + curl \ + file \ + git \ + sudo \ + xutils-dev \ + unzip \ + && \ + apt-get clean && rm -rf /var/lib/apt/lists/* && \ + +echo \'OUTPUT = /usr/local/musl\r\nGCC_VER = 7.2.0\r\nDL_CMD = curl -C - -L -o\r\nCOMMON_CONFIG += CFLAGS=\"-g0 -Os\" CXXFLAGS=\"-g0 -Os\" LDFLAGS=\"-s\"\r\nCOMMON_CONFIG += --disable-nls\r\nGCC_CONFIG += --enable-languages=c,c++\r\nGCC_CONFIG += --disable-libquadmath --disable-decimal-float\r\nGCC_CONFIG += --disable-multilib\r\nCOMMON_CONFIG += --with-debug-prefix-map=$(CURDIR)=\r\n\' > /tmp/config.mak && +less /tmp/config.mak && +cd /tmp && \ + curl -Lsq -o musl-cross-make.zip https://github.com/richfelker/musl-cross-make/archive/v0.9.8.zip && \ + unzip -q musl-cross-make.zip && \ + rm musl-cross-make.zip && \ + mv musl-cross-make-0.9.8 musl-cross-make && \ + cp /tmp/config.mak /tmp/musl-cross-make/config.mak && \ + cd /tmp/musl-cross-make && \ + TARGET=$TARGET make install > /tmp/musl-cross-make.log && \ + ln -s /usr/local/musl/bin/$TARGET-strip /usr/local/musl/bin/musl-strip && \ + cd /tmp && \ + rm -rf /tmp/musl-cross-make /tmp/musl-cross-make.log && + + mkdir -p /home/rust/libs /home/rust/src && + + export PATH=/root/.cargo/bin:/usr/local/musl/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin && + export TARGET_CC=$TARGET-gcc && + export TARGET_CXX=$TARGET-g++ && + export TARGET_C_INCLUDE_PATH=/usr/local/musl/$TARGET/include/ && + + chmod 755 /root/ && \ + curl https://sh.rustup.rs -sqSf | \ + sh -s -- -y --default-toolchain $TOOLCHAIN && \ + rustup target add $TARGET && + echo \'[build]\ntarget = \"armv7-unknown-linux-musleabihf\"\n\n[target.armv7-unknown-linux-musleabihf]\nlinker = \"armv7-unknown-linux-musleabihf-gcc\"\n\' > /root/.cargo/config && \ + + cd /home/rust/libs && \ + export CC=$TARGET_CC && \ + export C_INCLUDE_PATH=$TARGET_C_INCLUDE_PATH && \ + echo "Building zlib" && \ + VERS=1.2.11 && \ + CHECKSUM=c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1 && \ + cd /home/rust/libs && \ + curl -sqLO https://zlib.net/zlib-$VERS.tar.gz && \ + echo "$CHECKSUM zlib-$VERS.tar.gz" > checksums.txt && \ + sha256sum -c checksums.txt && \ + tar xzf zlib-$VERS.tar.gz && cd zlib-$VERS && \ + ./configure --static --archs="-fPIC" --prefix=/usr/local/musl/$TARGET && \ + make && sudo make install && \ + cd .. && rm -rf zlib-$VERS.tar.gz zlib-$VERS checksums.txt && \ + echo "Building OpenSSL" && \ + VERS=1.0.2q && \ + CHECKSUM=5744cfcbcec2b1b48629f7354203bc1e5e9b5466998bbccc5b5fcde3b18eb684 && \ + curl -sqO https://www.openssl.org/source/openssl-$VERS.tar.gz && \ + echo "$CHECKSUM openssl-$VERS.tar.gz" > checksums.txt && \ + sha256sum -c checksums.txt && \ + tar xzf openssl-$VERS.tar.gz && cd openssl-$VERS && \ + ./Configure $OPENSSL_ARCH -fPIC --prefix=/usr/local/musl/$TARGET && \ + make depend && \ + make && sudo make install && \ + cd .. 
&& rm -rf openssl-$VERS.tar.gz openssl-$VERS checksums.txt && \ + export OPENSSL_DIR=/usr/local/musl/$TARGET/ && \ + export OPENSSL_INCLUDE_DIR=/usr/local/musl/$TARGET/include/ && \ + export DEP_OPENSSL_INCLUDE=/usr/local/musl/$TARGET/include/ && \ + export OPENSSL_LIB_DIR=/usr/local/musl/$TARGET/lib/ && \ + export OPENSSL_STATIC=1 && \ + ' + + MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target armv7-unknown-linux-musleabihf'" + MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/armv7-unknown-linux-musleabihf/release'" + MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=musl-strip'" + ;; + + ubuntu18.04.aarch64| alpine.aarch64) + RUST_TARGET='aarch64-unknown-linux-gnu' + SETUP_COMMAND=$' sources="$(cat /etc/apt/sources.list | grep -E \'^[^#]\')" && # Update existing repos to be specifically for amd64 @@ -206,39 +296,29 @@ case "$PACKAGE_OS.$PACKAGE_ARCH" in echo \'[target.aarch64-unknown-linux-gnu]\' > ~/.cargo/config && echo \'linker = \"aarch64-linux-gnu-gcc\"\' >> ~/.cargo/config && export AARCH64_UNKNOWN_LINUX_GNU_OPENSSL_LIB_DIR=/usr/lib/aarch64-linux-gnu && - export AARCH64_UNKNOWN_LINUX_GNU_OPENSSL_INCLUDE_DIR=/usr/include && + export AARCH64_UNKNOWN_LINUX_GNU_OPENSSL_INCLUDE_DIR=/usr/include && + cd /project/$BUILD_PATH && + echo \'Installing rustup\' && + curl -sSLf https://sh.rustup.rs | sh -s -- -y && + . ~/.cargo/env && ' + MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target aarch64-unknown-linux-gnu'" + MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/aarch64-unknown-linux-gnu/release'" + MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=/usr/aarch64-linux-gnu/bin/strip'" ;; esac +if [ -n "$RUST_TARGET" ]; then + RUST_TARGET_COMMAND="rustup target add $RUST_TARGET &&" +fi + if [ -z "$SETUP_COMMAND" ]; then echo "Unrecognized target [$PACKAGE_OS.$PACKAGE_ARCH]" >&2 exit 1 fi -case "$PACKAGE_OS" in - *) - case "$PACKAGE_ARCH" in - amd64) - MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target x86_64-unknown-linux-musl'" - MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/x86_64-unknown-linux-musl/release'" - MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=strip'" - ;; - arm32v7) - MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target armv7-unknown-linux-gnueabihf'" - MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/armv7-unknown-linux-gnueabihf/release'" - MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=/usr/arm-linux-gnueabihf/bin/strip'" - ;; - aarch64) - MAKE_FLAGS="'CARGOFLAGS=$CARGOFLAGS --target aarch64-unknown-linux-gnu'" - MAKE_FLAGS="$MAKE_FLAGS 'TARGET=target/aarch64-unknown-linux-gnu/release'" - MAKE_FLAGS="$MAKE_FLAGS 'STRIP_COMMAND=/usr/aarch64-linux-gnu/bin/strip'" - ;; - esac +MAKE_COMMAND="make release $MAKE_FLAGS" - MAKE_COMMAND="make release $MAKE_FLAGS" - ;; -esac docker run --rm \ --user root \ @@ -253,12 +333,7 @@ docker run --rm \ cat /etc/os-release && $SETUP_COMMAND - - cd /project/$BUILD_PATH && - echo 'Installing rustup' && - curl -sSLf https://sh.rustup.rs | sh -s -- -y && - . ~/.cargo/env && - + cd /project/$BUILD_PATH && # build artifacts $RUST_TARGET_COMMAND $MAKE_COMMAND diff --git a/test/Microsoft.Azure.Devices.Edge.Test.Common/config/EdgeConfiguration.cs b/test/Microsoft.Azure.Devices.Edge.Test.Common/config/EdgeConfiguration.cs index 0af9a0a7dcb..c6ae6d6dd7d 100644 --- a/test/Microsoft.Azure.Devices.Edge.Test.Common/config/EdgeConfiguration.cs +++ b/test/Microsoft.Azure.Devices.Edge.Test.Common/config/EdgeConfiguration.cs @@ -29,7 +29,7 @@ public EdgeConfiguration( this.expectedConfig = expectedConfig; this.moduleImages = moduleImages; this.ModuleNames = moduleNames - .Select(id => id.StartsWith('$') ? 
id.Substring(1) : id)
+                .Select(id => id)
                 .ToArray();
         }
diff --git a/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/EdgeDaemon.cs b/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/EdgeDaemon.cs
index 96b46dbe7d7..faf0b7a758c 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/EdgeDaemon.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/EdgeDaemon.cs
@@ -84,7 +84,7 @@ public async Task InstallAsync(Option<string> packagesPath, Option<Uri> proxy, CancellationToken token)
             string[] commands = packagesPath.Match(
                 p => this.packageManagement.GetInstallCommandsFromLocal(p),
-                () => this.packageManagement.GetInstallCommandsFromMicrosoftProd());
+                () => this.packageManagement.GetInstallCommandsFromMicrosoftProd(proxy));

             await Profiler.Run(
                 async () =>
diff --git a/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/PackageManagement.cs b/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/PackageManagement.cs
index b84dbf613fb..fe2975bb6f0 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/PackageManagement.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test.Common/linux/PackageManagement.cs
@@ -4,6 +4,7 @@ namespace Microsoft.Azure.Devices.Edge.Test.Common.Linux
     using System;
     using System.IO;
     using System.Linq;
+    using Microsoft.Azure.Devices.Edge.Util;

     public enum SupportedPackageExtension
     {
@@ -59,30 +60,41 @@ public string[] GetInstallCommandsFromLocal(string path)
             };
         }

-        public string[] GetInstallCommandsFromMicrosoftProd() => this.packageExtension switch
+        public string[] GetInstallCommandsFromMicrosoftProd(Option<Uri> proxy)
         {
-            SupportedPackageExtension.Deb => new[]
+            var curl = "curl";
+            var prefix = string.Empty;
+            proxy.ForEach(url =>
             {
-                // Based on instructions at:
-                // https://github.com/MicrosoftDocs/azure-docs/blob/058084949656b7df518b64bfc5728402c730536a/articles/iot-edge/how-to-install-iot-edge-linux.md
-                // TODO: 8/30/2019 support curl behind a proxy
-                $"curl https://packages.microsoft.com/config/{this.os}/{this.version}/multiarch/prod.list > /etc/apt/sources.list.d/microsoft-prod.list",
-                "curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > /etc/apt/trusted.gpg.d/microsoft.gpg",
-                $"apt-get update",
-                $"apt-get install --yes iotedge"
-            },
-            SupportedPackageExtension.Rpm => new[]
+                curl += $" -x {url}";
+                prefix = $"http_proxy={url} https_proxy={url} ";
+            });
+
+            return this.packageExtension switch
             {
-                $"rpm -iv --replacepkgs https://packages.microsoft.com/config/{this.os}/{this.version}/packages-microsoft-prod.rpm",
-                $"yum updateinfo",
-                $"yum install --yes iotedge",
-                "pathToSystemdConfig=$(systemctl cat iotedge | head -n 1)",
-                "sed 's/=on-failure/=no/g' ${pathToSystemdConfig#?} > ~/override.conf",
-                "sudo mv -f ~/override.conf ${pathToSystemdConfig#?}",
-                "sudo systemctl daemon-reload"
-            },
-            _ => throw new NotImplementedException($"Don't know how to install daemon on for '.{this.packageExtension}'"),
-            };
+                SupportedPackageExtension.Deb => new[]
+                {
+                    // Based on instructions at:
+                    // https://github.com/MicrosoftDocs/azure-docs/blob/058084949656b7df518b64bfc5728402c730536a/articles/iot-edge/how-to-install-iot-edge-linux.md
+                    $"{curl} https://packages.microsoft.com/config/{this.os}/{this.version}/multiarch/prod.list > /etc/apt/sources.list.d/microsoft-prod.list",
+                    $"{curl} https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > /etc/apt/trusted.gpg.d/microsoft.gpg",
+                    $"{prefix}apt-get update",
+                    $"{prefix}apt-get install --yes iotedge"
+                },
+                SupportedPackageExtension.Rpm => new[]
+                {
+                    // No proxy support here because our proxy test
environment uses Ubuntu. + $"rpm -iv --replacepkgs https://packages.microsoft.com/config/{this.os}/{this.version}/packages-microsoft-prod.rpm", + $"yum updateinfo", + $"yum install --yes iotedge", + "pathToSystemdConfig=$(systemctl cat iotedge | head -n 1)", + "sed 's/=on-failure/=no/g' ${pathToSystemdConfig#?} > ~/override.conf", + "sudo mv -f ~/override.conf ${pathToSystemdConfig#?}", + "sudo systemctl daemon-reload" + }, + _ => throw new NotImplementedException($"Don't know how to install daemon on for '.{this.packageExtension}'"), + }; + } public string[] GetUninstallCommands() => this.packageExtension switch { diff --git a/test/Microsoft.Azure.Devices.Edge.Test/AuthorizationPolicy.cs b/test/Microsoft.Azure.Devices.Edge.Test/AuthorizationPolicy.cs new file mode 100644 index 00000000000..1eedf09d3f6 --- /dev/null +++ b/test/Microsoft.Azure.Devices.Edge.Test/AuthorizationPolicy.cs @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace Microsoft.Azure.Devices.Edge.Test +{ + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Devices.Client.Exceptions; + using Microsoft.Azure.Devices.Edge.Test.Common; + using Microsoft.Azure.Devices.Edge.Test.Common.Certs; + using Microsoft.Azure.Devices.Edge.Test.Common.Config; + using Microsoft.Azure.Devices.Edge.Test.Helpers; + using Microsoft.Azure.Devices.Edge.Util; + using Microsoft.Azure.Devices.Edge.Util.Test.Common.NUnit; + using NUnit.Framework; + + [EndToEnd] + public class AuthorizationPolicy : SasManualProvisioningFixture + { + ///
+ /// Scenario: + /// - Create a deployment with broker and a policy that denies the connection. + /// - Create a device and validate that it cannot connect. + /// - Update deployment with new policy that allows the connection. + /// - Validate that new device can connect. + /// + [Test] + public async Task AuthorizationPolicyUpdateTest() + { + CancellationToken token = this.TestToken; + + string deviceId1 = DeviceId.Current.Generate(); + string deviceId2 = DeviceId.Current.Generate(); + + EdgeDeployment deployment = await this.runtime.DeployConfigurationAsync( + builder => + { + builder.GetModule(ModuleName.EdgeHub) + .WithEnvironment(new[] + { + ("experimentalFeatures__enabled", "true"), + ("experimentalFeatures__mqttBrokerEnabled", "true"), + }) + // deploy with deny policy + .WithDesiredProperties(new Dictionary + { + ["mqttBroker"] = new + { + authorizations = new[] + { + new + { + identities = new[] { $"{this.iotHub.Hostname}/{deviceId1}" }, + deny = new[] + { + new + { + operations = new[] { "mqtt:connect" } + } + } + } + } + } + }); + }, + token); + + EdgeModule edgeHub = deployment.Modules[ModuleName.EdgeHub]; + await edgeHub.WaitForReportedPropertyUpdatesAsync( + new + { + properties = new + { + reported = new + { + lastDesiredStatus = new + { + code = 200, + description = string.Empty + } + } + } + }, + token); + + // verify devices are not authorized after policy update. + Assert.ThrowsAsync(async () => + { + var leaf = await LeafDevice.CreateAsync( + deviceId1, + Protocol.Mqtt, + AuthenticationType.Sas, + Option.Some(this.runtime.DeviceId), + false, + CertificateAuthority.GetQuickstart(), + this.iotHub, + token, + Option.None()); + DateTime seekTime = DateTime.Now; + await leaf.SendEventAsync(token); + await leaf.WaitForEventsReceivedAsync(seekTime, token); + }); + + // deploy new allow policy + EdgeDeployment deployment2 = await this.runtime.DeployConfigurationAsync( + builder => + { + builder.GetModule(ModuleName.EdgeHub) + .WithEnvironment(new[] + { + ("experimentalFeatures__enabled", "true"), + ("experimentalFeatures__mqttBrokerEnabled", "true"), + }) + .WithDesiredProperties(new Dictionary + { + ["mqttBroker"] = new + { + authorizations = new[] + { + new + { + identities = new[] { $"{this.iotHub.Hostname}/{deviceId2}" }, + allow = new[] + { + new + { + operations = new[] { "mqtt:connect" } + } + } + } + } + } + }); + }, + token); + + EdgeModule edgeHub2 = deployment2.Modules[ModuleName.EdgeHub]; + await edgeHub2.WaitForReportedPropertyUpdatesAsync( + new + { + properties = new + { + reported = new + { + lastDesiredStatus = new + { + code = 200, + description = string.Empty + } + } + } + }, + token); + + var leaf = await LeafDevice.CreateAsync( + deviceId2, + Protocol.Mqtt, + AuthenticationType.Sas, + Option.Some(this.runtime.DeviceId), + false, + CertificateAuthority.GetQuickstart(), + this.iotHub, + token, + Option.None()); + + // verify device is authorized after policy update. + await TryFinally.DoAsync( + async () => + { + DateTime seekTime = DateTime.Now; + await leaf.SendEventAsync(token); + await leaf.WaitForEventsReceivedAsync(seekTime, token); + }, + async () => + { + await leaf.DeleteIdentityAsync(token); + }); + } + + /// + /// Scenario: + /// - Create a deployment with broker and two authorization rules: + /// allow device1 connect, deny device2 connect. + /// - Create devices and validate that they can/cannot connect. 
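    /// The desired-properties patch built below serializes to roughly this
    /// JSON (a sketch derived from the anonymous objects in the test body;
    /// the identity values are placeholders):
    ///
    ///   "mqttBroker": {
    ///     "authorizations": [
    ///       { "identities": [ "{hub}/{deviceId1}" ], "allow": [ { "operations": [ "mqtt:connect" ] } ] },
    ///       { "identities": [ "{hub}/{deviceId2}" ], "deny":  [ { "operations": [ "mqtt:connect" ] } ] }
    ///     ]
    ///   }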
+        ///
+        [Test]
+        public async Task AuthorizationPolicyExplicitPolicyTest()
+        {
+            CancellationToken token = this.TestToken;
+
+            string deviceId1 = DeviceId.Current.Generate();
+            string deviceId2 = DeviceId.Current.Generate();
+
+            EdgeDeployment deployment = await this.runtime.DeployConfigurationAsync(
+                builder =>
+                {
+                    builder.GetModule(ModuleName.EdgeHub)
+                        .WithEnvironment(new[]
+                        {
+                            ("experimentalFeatures__enabled", "true"),
+                            ("experimentalFeatures__mqttBrokerEnabled", "true"),
+                        })
+                        .WithDesiredProperties(new Dictionary<string, object>
+                        {
+                            ["mqttBroker"] = new
+                            {
+                                authorizations = new dynamic[]
+                                {
+                                    new
+                                    {
+                                        identities = new[] { $"{this.iotHub.Hostname}/{deviceId1}" },
+                                        allow = new[]
+                                        {
+                                            new
+                                            {
+                                                operations = new[] { "mqtt:connect" }
+                                            }
+                                        }
+                                    },
+                                    new
+                                    {
+                                        identities = new[] { $"{this.iotHub.Hostname}/{deviceId2}" },
+                                        deny = new[]
+                                        {
+                                            new
+                                            {
+                                                operations = new[] { "mqtt:connect" }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        });
+                },
+                token);
+
+            EdgeModule edgeHub = deployment.Modules[ModuleName.EdgeHub];
+            await edgeHub.WaitForReportedPropertyUpdatesAsync(
+                new
+                {
+                    properties = new
+                    {
+                        reported = new
+                        {
+                            lastDesiredStatus = new
+                            {
+                                code = 200,
+                                description = string.Empty
+                            }
+                        }
+                    }
+                },
+                token);
+
+            // verify device1 is authorized
+            var leaf = await LeafDevice.CreateAsync(
+                deviceId1,
+                Protocol.Mqtt,
+                AuthenticationType.Sas,
+                Option.Some(this.runtime.DeviceId),
+                false,
+                CertificateAuthority.GetQuickstart(),
+                this.iotHub,
+                token,
+                Option.None());
+
+            await TryFinally.DoAsync(
+                async () =>
+                {
+                    DateTime seekTime = DateTime.Now;
+                    await leaf.SendEventAsync(token);
+                    await leaf.WaitForEventsReceivedAsync(seekTime, token);
+                },
+                async () =>
+                {
+                    await leaf.DeleteIdentityAsync(token);
+                });
+
+            // verify device2 is not authorized
+            Assert.ThrowsAsync<UnauthorizedException>(async () =>
+            {
+                var leaf = await LeafDevice.CreateAsync(
+                    deviceId2,
+                    Protocol.Mqtt,
+                    AuthenticationType.Sas,
+                    Option.Some(this.runtime.DeviceId),
+                    false,
+                    CertificateAuthority.GetQuickstart(),
+                    this.iotHub,
+                    token,
+                    Option.None());
+                DateTime seekTime = DateTime.Now;
+                await leaf.SendEventAsync(token);
+                await leaf.WaitForEventsReceivedAsync(seekTime, token);
+            });
+        }
+    }
+}
diff --git a/test/Microsoft.Azure.Devices.Edge.Test/EdgeAgentDirectMethods.cs b/test/Microsoft.Azure.Devices.Edge.Test/EdgeAgentDirectMethods.cs
index 6f0e2b11f3a..0336adcce35 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test/EdgeAgentDirectMethods.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test/EdgeAgentDirectMethods.cs
@@ -130,7 +130,7 @@ public async Task TestUploadModuleLogs()
         {
             string moduleName = "NumberLogger";
             int count = 10;
-            string sasUrl = "https://lefitcheblobtest1.blob.core.windows.net/upload-test?sv=2019-02-02&st=2020-08-03T17%3A14%3A16Z&se=2020-11-04T18%3A14%3A00Z&sr=c&sp=racwdl&sig=phKgqaaxSJTcZzUcggE%2FnhDljs4%2BhvCg7IOKk8iWTcY%3D";
+            string sasUrl = Context.Current.BlobSasUrl.Expect(() => new InvalidOperationException("Missing Blob SAS url"));

             CancellationToken token = this.TestToken;

@@ -143,27 +143,82 @@ await this.runtime.DeployConfigurationAsync(
                 }, token);
             await Task.Delay(10000);

-            var request = new ModuleLogsUploadRequest("1.0", new List<LogRequestItem> { new LogRequestItem(moduleName, new ModuleLogFilter(Option.None(), Option.None(), Option.None(), Option.None(), Option.None())) }, LogsContentEncoding.None, LogsContentType.Text, sasUrl);
+            var request = new
+            {
+                schemaVersion = "1.0",
+                items = new
+                {
+                    id = "NumberLogger",
+                    filter = new { },
+                },
+                encoding = 0,
+                contentType = 1,
+                sasUrl,
+            };
+
+            var payload = JsonConvert.SerializeObject(request);
+
+            CloudToDeviceMethodResult result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ConfigModuleName.EdgeAgent, new CloudToDeviceMethod("UploadModuleLogs", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)).SetPayloadJson(payload), token);
+
+            var response = JsonConvert.DeserializeObject<TaskStatusResponse>(result.GetPayloadAsJson());
+            await this.WaitForTaskCompletion(response.CorrelationId, token);
+        }
+
+        [Test]
+        public async Task TestUploadSupportBundle()
+        {
+            string moduleName = "NumberLogger";
+            int count = 10;
+            string sasUrl = Context.Current.BlobSasUrl.Expect(() => new InvalidOperationException("Missing Blob SAS url"));
+
+            CancellationToken token = this.TestToken;
+
+            string numberLoggerImage = Context.Current.NumberLoggerImage.Expect(() => new InvalidOperationException("Missing Number Logger image"));
+            await this.runtime.DeployConfigurationAsync(
+                builder =>
+                {
+                    builder.AddModule(moduleName, numberLoggerImage)
+                        .WithEnvironment(new[] { ("Count", count.ToString()) });
+                }, token);
+            await Task.Delay(10000);
+
+            var request = new
+            {
+                schemaVersion = "1.0",
+                sasUrl,
+            };

-            CloudToDeviceMethodResult result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ConfigModuleName.EdgeAgent, new CloudToDeviceMethod("UploadModuleLogs", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)).SetPayloadJson(JsonConvert.SerializeObject(request)), token);
+            var payload = JsonConvert.SerializeObject(request);

-            TaskStatusResponse response = null;
-            for (int i = 0; i < 10; i++)
+            CloudToDeviceMethodResult result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ConfigModuleName.EdgeAgent, new CloudToDeviceMethod("UploadSupportBundle", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)).SetPayloadJson(payload), token);
+
+            var response = JsonConvert.DeserializeObject<TaskStatusResponse>(result.GetPayloadAsJson());
+            await this.WaitForTaskCompletion(response.CorrelationId, token);
+        }
+
+        async Task WaitForTaskCompletion(string correlationId, CancellationToken token)
+        {
+            while (true)
             {
+                var request = new
+                {
+                    schemaVersion = "1.0",
+                    correlationId
+                };
+
+                var result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ConfigModuleName.EdgeAgent, new CloudToDeviceMethod("GetTaskStatus", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)).SetPayloadJson(JsonConvert.SerializeObject(request)), token);
+
                 Assert.AreEqual((int)HttpStatusCode.OK, result.Status);
-                response = JsonConvert.DeserializeObject<TaskStatusResponse>(result.GetPayloadAsJson());
+                var response = JsonConvert.DeserializeObject<TaskStatusResponse>(result.GetPayloadAsJson());

                 if (response.Status != BackgroundTaskRunStatus.NotStarted && response.Status != BackgroundTaskRunStatus.Running)
                 {
-                    break;
+                    Assert.AreEqual(BackgroundTaskRunStatus.Completed, response.Status, response.Message);
+                    return;
                 }

                 await Task.Delay(5000);
-                var correlation = new TaskStatusRequest("1.0", response.CorrelationId);
-                result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ConfigModuleName.EdgeAgent, new CloudToDeviceMethod("GetTaskStatus", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)).SetPayloadJson(JsonConvert.SerializeObject(correlation)), token);
             }
-
-            Assert.AreEqual(BackgroundTaskRunStatus.Completed, response.Status, response.Message);
         }

         class LogResponse
diff --git a/test/Microsoft.Azure.Devices.Edge.Test/Metrics.cs b/test/Microsoft.Azure.Devices.Edge.Test/Metrics.cs
index 75feb69d09b..83bdb8b2122 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test/Metrics.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test/Metrics.cs
@@ -7,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Test using System.Net; using System.Threading; using System.Threading.Tasks; + using Microsoft.Azure.Devices.Edge.Test.Common; using Microsoft.Azure.Devices.Edge.Test.Common.Config; using Microsoft.Azure.Devices.Edge.Test.Helpers; using Microsoft.Azure.Devices.Edge.Util.Test.Common.NUnit; @@ -26,6 +27,9 @@ public async Task ValidateMetrics() CancellationToken token = this.TestToken; await this.DeployAsync(token); + var agent = new EdgeAgent(this.runtime.DeviceId, this.iotHub); + await agent.PingAsync(token); + var result = await this.iotHub.InvokeMethodAsync(this.runtime.DeviceId, ModuleName, new CloudToDeviceMethod("ValidateMetrics", TimeSpan.FromSeconds(300), TimeSpan.FromSeconds(300)), token); Assert.AreEqual(result.Status, (int)HttpStatusCode.OK); diff --git a/test/Microsoft.Azure.Devices.Edge.Test/Module.cs b/test/Microsoft.Azure.Devices.Edge.Test/Module.cs index e7cc6d3bd22..655943589f9 100644 --- a/test/Microsoft.Azure.Devices.Edge.Test/Module.cs +++ b/test/Microsoft.Azure.Devices.Edge.Test/Module.cs @@ -21,6 +21,7 @@ public class Module : SasManualProvisioningFixture [Category("CentOsSafe")] public async Task TempSensor() { + Assert.Ignore("Temporarily disabling flaky test while we figure out what is wrong"); string sensorImage = Context.Current.TempSensorImage.GetOrElse(DefaultSensorImage); CancellationToken token = this.TestToken; @@ -103,6 +104,7 @@ public async Task TempFilter() // Test Temperature Filter Function: https://docs.microsoft.com/en-us/azure/iot-edge/tutorial-deploy-function public async Task TempFilterFunc() { + Assert.Ignore("Temporarily disabling flaky test while we figure out what is wrong"); if (OsPlatform.IsArm() && OsPlatform.Is64Bit()) { Assert.Ignore("TempFilterFunc is disabled for arm64 because azureiotedge-functions-filter does not exist for arm64"); diff --git a/test/Microsoft.Azure.Devices.Edge.Test/PlugAndPlay.cs b/test/Microsoft.Azure.Devices.Edge.Test/PlugAndPlay.cs index 9c291220144..c894708bcab 100644 --- a/test/Microsoft.Azure.Devices.Edge.Test/PlugAndPlay.cs +++ b/test/Microsoft.Azure.Devices.Edge.Test/PlugAndPlay.cs @@ -39,6 +39,7 @@ public PlugAndPlay() [Test] public async Task PlugAndPlayDeviceClient() { + Assert.Ignore("Temporarily disabling flaky test while we figure out what is wrong"); CancellationToken token = this.TestToken; EdgeDeployment deployment = await this.runtime.DeployConfigurationAsync( builder => diff --git a/test/Microsoft.Azure.Devices.Edge.Test/helpers/BaseFixture.cs b/test/Microsoft.Azure.Devices.Edge.Test/helpers/BaseFixture.cs index 3b83b7e089d..2213c982860 100644 --- a/test/Microsoft.Azure.Devices.Edge.Test/helpers/BaseFixture.cs +++ b/test/Microsoft.Azure.Devices.Edge.Test/helpers/BaseFixture.cs @@ -17,9 +17,12 @@ public class BaseFixture protected CancellationToken TestToken => this.cts.Token; + protected virtual Task BeforeTestTimerStarts() => Task.CompletedTask; + [SetUp] - protected void BeforeEachTest() + protected async Task BeforeEachTestAsync() { + await this.BeforeTestTimerStarts(); this.cts = new CancellationTokenSource(Context.Current.TestTimeout); this.testStartTime = DateTime.Now; this.profiler = Profiler.Start(); diff --git a/test/Microsoft.Azure.Devices.Edge.Test/helpers/Context.cs b/test/Microsoft.Azure.Devices.Edge.Test/helpers/Context.cs index d95dc18a00a..7bacfddf05c 100644 --- a/test/Microsoft.Azure.Devices.Edge.Test/helpers/Context.cs +++ b/test/Microsoft.Azure.Devices.Edge.Test/helpers/Context.cs @@ -109,6 +109,7 @@ IEnumerable 
GetAndValidateRegistries()
             this.TestTimeout = TimeSpan.FromMinutes(context.GetValue("testTimeoutMinutes", 5));
             this.Verbose = context.GetValue("verbose");
             this.ParentHostname = Option.Maybe(Get("parentHostname"));
+            this.BlobSasUrl = Option.Maybe(Get("BLOB_STORE_SAS"));
         }

         static readonly Lazy<Context> Default = new Lazy<Context>(() => new Context());
@@ -182,5 +183,7 @@ IEnumerable GetAndValidateRegistries()
         public bool Verbose { get; }

         public Option<string> ParentHostname { get; }
+
+        public Option<string> BlobSasUrl { get; }
     }
 }
diff --git a/test/Microsoft.Azure.Devices.Edge.Test/helpers/CustomCertificatesFixture.cs b/test/Microsoft.Azure.Devices.Edge.Test/helpers/CustomCertificatesFixture.cs
index 985ccd7e7d9..f2bfb2bb1a1 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test/helpers/CustomCertificatesFixture.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test/helpers/CustomCertificatesFixture.cs
@@ -12,19 +12,15 @@ public class CustomCertificatesFixture : SasManualProvisioningFixture
     {
         protected CertificateAuthority ca;

-        [SetUp]
-        public override Task SasProvisionEdgeAsync()
-        {
-            // Do nothing; everything happens at [OneTimeSetUp] instead. We do this to avoid
-            // creating a new device for every permutation of the Transparent Gateway tests.
-            return Task.CompletedTask;
-        }
+        // Do nothing; everything happens at [OneTimeSetUp] instead. We do this to avoid
+        // creating a new device for every permutation of the Transparent Gateway tests.
+        protected override Task BeforeTestTimerStarts() => Task.CompletedTask;

         [OneTimeSetUp]
         public async Task SetUpCertificatesAsync()
         {
             await Profiler.Run(
-                () => base.SasProvisionEdgeAsync(),
+                () => this.SasProvisionEdgeAsync(),
                 "Completed edge manual provisioning with SAS token");

             await Profiler.Run(
diff --git a/test/Microsoft.Azure.Devices.Edge.Test/helpers/SasManualProvisioningFixture.cs b/test/Microsoft.Azure.Devices.Edge.Test/helpers/SasManualProvisioningFixture.cs
index 4bb3dafbe6a..2f1d90845f7 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test/helpers/SasManualProvisioningFixture.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test/helpers/SasManualProvisioningFixture.cs
@@ -21,8 +21,9 @@ public SasManualProvisioningFixture(string connectionString, string eventHubEndp
         {
         }

-        [SetUp]
-        public virtual async Task SasProvisionEdgeAsync()
+        protected override Task BeforeTestTimerStarts() => this.SasProvisionEdgeAsync();
+
+        protected virtual async Task SasProvisionEdgeAsync()
         {
             using (var cts = new CancellationTokenSource(Context.Current.SetupTimeout))
             {
diff --git a/test/Microsoft.Azure.Devices.Edge.Test/helpers/X509ManualProvisioningFixture.cs b/test/Microsoft.Azure.Devices.Edge.Test/helpers/X509ManualProvisioningFixture.cs
index 13194453c8a..7ff53a1b0be 100644
--- a/test/Microsoft.Azure.Devices.Edge.Test/helpers/X509ManualProvisioningFixture.cs
+++ b/test/Microsoft.Azure.Devices.Edge.Test/helpers/X509ManualProvisioningFixture.cs
@@ -88,7 +88,8 @@ await this.ConfigureDaemonAsync(
                 return (new X509Thumbprint()
                 {
-                    PrimaryThumbprint = deviceCert.Thumbprint
+                    PrimaryThumbprint = deviceCert.Thumbprint,
+                    SecondaryThumbprint = deviceCert.Thumbprint
                 },
                 identityCerts);
         }
diff --git a/test/README.md b/test/README.md
index b84a92b6c6c..e33ab5a18b6 100644
--- a/test/README.md
+++ b/test/README.md
@@ -64,6 +64,7 @@ The tests also expect to find several _secret_ parameters. While these can techn
| `[E2E_]PREVIEW_EVENT_HUB_ENDPOINT` | * | Alternate Event Hub that is in a region that has preview bits deployed for preview testing - required only for PlugAndPlay tests.
|
| `[E2E_]REGISTRIES__{n}__PASSWORD` || Password associated with a container registry entry in the `registries` array of `context.json`. `{n}` is the number corresponding to the (zero-based) array entry. For example, if you specified a single container registry in the `registries` array, the corresponding parameter would be `[E2E_]REGISTRIES__0__PASSWORD`. |
| `[E2E_]ROOT_CA_PASSWORD` || The password associated with the root certificate specified in `rootCaCertificatePath`. |
+| `[E2E_]BLOB_STORE_SAS` || The SAS token used to upload module logs and support bundles in the tests. |

_Note: the definitive source for information about test parameters is `test/Microsoft.Azure.Devices.Edge.Test/helpers/Context.cs`._
diff --git a/test/connectivity/modules/NetworkController/docker/linux/arm32v7/Dockerfile b/test/connectivity/modules/NetworkController/docker/linux/arm32v7/Dockerfile
index fdb6d122b55..7ff3492eb4c 100644
--- a/test/connectivity/modules/NetworkController/docker/linux/arm32v7/Dockerfile
+++ b/test/connectivity/modules/NetworkController/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/connectivity/modules/NetworkController/docker/linux/arm64v8/Dockerfile b/test/connectivity/modules/NetworkController/docker/linux/arm64v8/Dockerfile
index deb39e3216a..2d7b9ab6df0 100644
--- a/test/connectivity/modules/NetworkController/docker/linux/arm64v8/Dockerfile
+++ b/test/connectivity/modules/NetworkController/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/CloudToDeviceMessageTester/docker/linux/arm32v7/Dockerfile b/test/modules/CloudToDeviceMessageTester/docker/linux/arm32v7/Dockerfile
index a204d7a4604..869959e818b 100644
--- a/test/modules/CloudToDeviceMessageTester/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/CloudToDeviceMessageTester/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/CloudToDeviceMessageTester/docker/linux/arm64v8/Dockerfile b/test/modules/CloudToDeviceMessageTester/docker/linux/arm64v8/Dockerfile
index 601c8adeebf..2b77fac1972 100644
--- a/test/modules/CloudToDeviceMessageTester/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/CloudToDeviceMessageTester/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/DeploymentTester/docker/linux/arm32v7/Dockerfile b/test/modules/DeploymentTester/docker/linux/arm32v7/Dockerfile
index 1a369fbf2d6..678c53c2800 100644
--- a/test/modules/DeploymentTester/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/DeploymentTester/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/DeploymentTester/docker/linux/arm64v8/Dockerfile b/test/modules/DeploymentTester/docker/linux/arm64v8/Dockerfile index 52a6538d22f..b8a3f8642e9 100644 --- a/test/modules/DeploymentTester/docker/linux/arm64v8/Dockerfile +++ b/test/modules/DeploymentTester/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/DirectMethodReceiver/docker/linux/arm32v7/Dockerfile b/test/modules/DirectMethodReceiver/docker/linux/arm32v7/Dockerfile index b16979b0ecf..21e9b309752 100644 --- a/test/modules/DirectMethodReceiver/docker/linux/arm32v7/Dockerfile +++ b/test/modules/DirectMethodReceiver/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/DirectMethodReceiver/docker/linux/arm64v8/Dockerfile b/test/modules/DirectMethodReceiver/docker/linux/arm64v8/Dockerfile index 18321fe93a2..22e8f57b87d 100644 --- a/test/modules/DirectMethodReceiver/docker/linux/arm64v8/Dockerfile +++ b/test/modules/DirectMethodReceiver/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/DirectMethodSender/docker/linux/arm32v7/Dockerfile b/test/modules/DirectMethodSender/docker/linux/arm32v7/Dockerfile index c0b1b7b6626..d7f2677f034 100644 --- a/test/modules/DirectMethodSender/docker/linux/arm32v7/Dockerfile +++ b/test/modules/DirectMethodSender/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/DirectMethodSender/docker/linux/arm64v8/Dockerfile b/test/modules/DirectMethodSender/docker/linux/arm64v8/Dockerfile index ff112740ca8..36d34854725 100644 --- a/test/modules/DirectMethodSender/docker/linux/arm64v8/Dockerfile +++ b/test/modules/DirectMethodSender/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/EdgeHubRestartTester/docker/linux/arm32v7/Dockerfile b/test/modules/EdgeHubRestartTester/docker/linux/arm32v7/Dockerfile index da7787e04f5..2a2c82888a7 100644 --- a/test/modules/EdgeHubRestartTester/docker/linux/arm32v7/Dockerfile +++ b/test/modules/EdgeHubRestartTester/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/EdgeHubRestartTester/docker/linux/arm64v8/Dockerfile b/test/modules/EdgeHubRestartTester/docker/linux/arm64v8/Dockerfile index 6cb12d1c5eb..2c60cab9943 100644 --- a/test/modules/EdgeHubRestartTester/docker/linux/arm64v8/Dockerfile +++ b/test/modules/EdgeHubRestartTester/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. 
diff --git a/test/modules/MetricsValidator/docker/linux/arm32v7/Dockerfile b/test/modules/MetricsValidator/docker/linux/arm32v7/Dockerfile
index fcf7e2e8437..a7e252ebcce 100644
--- a/test/modules/MetricsValidator/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/MetricsValidator/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/MetricsValidator/docker/linux/arm64v8/Dockerfile b/test/modules/MetricsValidator/docker/linux/arm64v8/Dockerfile
index b10cf1885b8..47c948b7deb 100644
--- a/test/modules/MetricsValidator/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/MetricsValidator/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}

 ARG EXE_DIR=.
diff --git a/test/modules/MetricsValidator/src/tests/ValidateDocumentedMetrics.cs b/test/modules/MetricsValidator/src/tests/ValidateDocumentedMetrics.cs
index 0ea02d10949..948a227a20b 100644
--- a/test/modules/MetricsValidator/src/tests/ValidateDocumentedMetrics.cs
+++ b/test/modules/MetricsValidator/src/tests/ValidateDocumentedMetrics.cs
@@ -36,11 +36,6 @@ protected override async Task Test(CancellationToken cancellationToken)
             var metrics = await this.scraper.ScrapeEndpointsAsync(cancellationToken);
             var expected = this.GetExpectedMetrics();

-            if (RuntimeInformation.OSArchitecture == Architecture.Arm || RuntimeInformation.OSArchitecture == Architecture.Arm64)
-            {
-                // Docker doesn't return this on arm
-                expected.Remove("edgeAgent_created_pids_total");
-            }

             if (OsPlatform.IsWindows())
             {
@@ -76,6 +71,26 @@
                 }
             }

+            // The following metrics should not be populated on a happy E2E path.
+            // We list them here and remove them from the set of unreturned metrics
+            // so that their absence is not counted as a failure.
+ IEnumerable skippingMetrics = new HashSet + { + "edgeAgent_unsuccessful_iothub_syncs_total", + "edgehub_client_connect_failed_total", + "edgehub_messages_dropped_total", + "edgehub_messages_unack_total", + "edgehub_offline_count_total", + "edgehub_operation_retry_total" + }; + + foreach (string skippingMetric in skippingMetrics) + { + if (unreturnedMetrics.Remove(skippingMetric)) + { + log.LogInformation($"\"{skippingMetric}\" was depopulated"); + } + } + foreach (string unreturnedMetric in unreturnedMetrics) { this.testReporter.Assert(unreturnedMetric, false, $"Metric did not exist in scrape."); @@ -106,11 +121,13 @@ Dictionary GetExpectedMetrics() async Task SeedMetrics(CancellationToken cancellationToken) { + string deviceId = Environment.GetEnvironmentVariable("IOTEDGE_DEVICEID"); + await this.moduleClient.SendEventAsync(new Message(Encoding.UTF8.GetBytes("Test message to seed metrics")), cancellationToken); const string methodName = "FakeDirectMethod"; await this.moduleClient.SetMethodHandlerAsync(methodName, (_, __) => Task.FromResult(new MethodResponse(200)), null); - await this.moduleClient.InvokeMethodAsync(Environment.GetEnvironmentVariable("IOTEDGE_DEVICEID"), Environment.GetEnvironmentVariable("IOTEDGE_MODULEID"), new MethodRequest(methodName), cancellationToken); + await this.moduleClient.InvokeMethodAsync(deviceId, Environment.GetEnvironmentVariable("IOTEDGE_MODULEID"), new MethodRequest(methodName), cancellationToken); await this.moduleClient.UpdateReportedPropertiesAsync(new TwinCollection(), cancellationToken); await this.moduleClient.GetTwinAsync(cancellationToken); diff --git a/test/modules/ModuleRestarter/docker/linux/arm32v7/Dockerfile b/test/modules/ModuleRestarter/docker/linux/arm32v7/Dockerfile index bb9d06e7d3d..9af4da5c80c 100644 --- a/test/modules/ModuleRestarter/docker/linux/arm32v7/Dockerfile +++ b/test/modules/ModuleRestarter/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/ModuleRestarter/docker/linux/arm64v8/Dockerfile b/test/modules/ModuleRestarter/docker/linux/arm64v8/Dockerfile index 7a540b68cde..915759edc22 100644 --- a/test/modules/ModuleRestarter/docker/linux/arm64v8/Dockerfile +++ b/test/modules/ModuleRestarter/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/NumberLogger/docker/linux/arm32v7/Dockerfile b/test/modules/NumberLogger/docker/linux/arm32v7/Dockerfile index 1cd3ee51f52..82d29e6ba64 100644 --- a/test/modules/NumberLogger/docker/linux/arm32v7/Dockerfile +++ b/test/modules/NumberLogger/docker/linux/arm32v7/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm32v7 +ARG base_tag=1.0.6.4-linux-arm32v7 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. diff --git a/test/modules/NumberLogger/docker/linux/arm64v8/Dockerfile b/test/modules/NumberLogger/docker/linux/arm64v8/Dockerfile index a6ba865a5b0..f9ccebaa306 100644 --- a/test/modules/NumberLogger/docker/linux/arm64v8/Dockerfile +++ b/test/modules/NumberLogger/docker/linux/arm64v8/Dockerfile @@ -1,4 +1,4 @@ -ARG base_tag=1.0.6.2-linux-arm64v8 +ARG base_tag=1.0.6.4-linux-arm64v8 FROM azureiotedge/azureiotedge-module-base:${base_tag} ARG EXE_DIR=. 
diff --git a/test/modules/Relayer/docker/linux/arm32v7/Dockerfile b/test/modules/Relayer/docker/linux/arm32v7/Dockerfile
index d597fd58e4e..5383197dc33 100644
--- a/test/modules/Relayer/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/Relayer/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/Relayer/docker/linux/arm64v8/Dockerfile b/test/modules/Relayer/docker/linux/arm64v8/Dockerfile
index e75ef36fa17..78fbbf4e180 100644
--- a/test/modules/Relayer/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/Relayer/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TemperatureFilter/docker/linux/arm32v7/Dockerfile b/test/modules/TemperatureFilter/docker/linux/arm32v7/Dockerfile
index c7243c55956..9af5f38a24a 100644
--- a/test/modules/TemperatureFilter/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/TemperatureFilter/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TemperatureFilter/docker/linux/arm64v8/Dockerfile b/test/modules/TemperatureFilter/docker/linux/arm64v8/Dockerfile
index 76820faf802..6dc39cf0380 100644
--- a/test/modules/TemperatureFilter/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/TemperatureFilter/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TestAnalyzer/docker/linux/arm32v7/Dockerfile b/test/modules/TestAnalyzer/docker/linux/arm32v7/Dockerfile
index 14375ff628b..0476fe9eb61 100644
--- a/test/modules/TestAnalyzer/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/TestAnalyzer/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TestAnalyzer/docker/linux/arm32v7/base/Dockerfile b/test/modules/TestAnalyzer/docker/linux/arm32v7/base/Dockerfile
index fefa5e36214..7106e0fa126 100644
--- a/test/modules/TestAnalyzer/docker/linux/arm32v7/base/Dockerfile
+++ b/test/modules/TestAnalyzer/docker/linux/arm32v7/base/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=3.1.7-bionic-arm32v7
+ARG base_tag=3.1.10-bionic-arm32v7
 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag}
 
 RUN apt-get update && apt-get install -y libcap2-bin libsnappy1v5 && \
diff --git a/test/modules/TestAnalyzer/docker/linux/arm64v8/Dockerfile b/test/modules/TestAnalyzer/docker/linux/arm64v8/Dockerfile
index a37dd5119cb..539d3c2d485 100644
--- a/test/modules/TestAnalyzer/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/TestAnalyzer/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TestAnalyzer/docker/linux/arm64v8/base/Dockerfile b/test/modules/TestAnalyzer/docker/linux/arm64v8/base/Dockerfile
index 13b5a370d80..76738ad7f5f 100644
--- a/test/modules/TestAnalyzer/docker/linux/arm64v8/base/Dockerfile
+++ b/test/modules/TestAnalyzer/docker/linux/arm64v8/base/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=3.1.7-bionic-arm64v8
+ARG base_tag=3.1.10-bionic-arm64v8
 FROM mcr.microsoft.com/dotnet/core/aspnet:${base_tag}
 
 RUN apt-get update && \
diff --git a/test/modules/TestResultCoordinator/docker/linux/arm32v7/Dockerfile b/test/modules/TestResultCoordinator/docker/linux/arm32v7/Dockerfile
index 02c4d7e49ab..18dacc71802 100644
--- a/test/modules/TestResultCoordinator/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/TestResultCoordinator/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TestResultCoordinator/docker/linux/arm64v8/Dockerfile b/test/modules/TestResultCoordinator/docker/linux/arm64v8/Dockerfile
index f53e4d72718..ec67e4717d6 100644
--- a/test/modules/TestResultCoordinator/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/TestResultCoordinator/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TwinTester/docker/linux/arm32v7/Dockerfile b/test/modules/TwinTester/docker/linux/arm32v7/Dockerfile
index 5b5e34cf263..c13ef185a73 100644
--- a/test/modules/TwinTester/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/TwinTester/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/TwinTester/docker/linux/arm64v8/Dockerfile b/test/modules/TwinTester/docker/linux/arm64v8/Dockerfile
index 88ba95f4a4b..54b0670827d 100644
--- a/test/modules/TwinTester/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/TwinTester/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base-full:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/load-gen/docker/linux/arm32v7/Dockerfile b/test/modules/load-gen/docker/linux/arm32v7/Dockerfile
index 0e6d79106a7..bd07b5cf9f0 100644
--- a/test/modules/load-gen/docker/linux/arm32v7/Dockerfile
+++ b/test/modules/load-gen/docker/linux/arm32v7/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm32v7
+ARG base_tag=1.0.6.4-linux-arm32v7
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.
diff --git a/test/modules/load-gen/docker/linux/arm64v8/Dockerfile b/test/modules/load-gen/docker/linux/arm64v8/Dockerfile
index ae92f69e7cc..50c1c1bfaed 100644
--- a/test/modules/load-gen/docker/linux/arm64v8/Dockerfile
+++ b/test/modules/load-gen/docker/linux/arm64v8/Dockerfile
@@ -1,4 +1,4 @@
-ARG base_tag=1.0.6.2-linux-arm64v8
+ARG base_tag=1.0.6.4-linux-arm64v8
 FROM azureiotedge/azureiotedge-module-base:${base_tag}
 
 ARG EXE_DIR=.