Skip to content

Commit

Permalink
add unit test for FlowProtocolDispatcher and FunctionProtocolDispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
zeyu10 committed Dec 20, 2024
1 parent 2870421 commit 083708f
Show file tree
Hide file tree
Showing 2 changed files with 349 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright 2021-2023 Weibo, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.weibo.rill.flow.service.dispatcher

import com.alibaba.fastjson.JSON
import com.weibo.rill.flow.interfaces.model.resource.Resource
import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo
import com.weibo.rill.flow.interfaces.model.task.FunctionPattern
import com.weibo.rill.flow.interfaces.model.task.FunctionTask
import com.weibo.rill.flow.interfaces.model.task.TaskInfo
import com.weibo.rill.flow.olympicene.core.model.dag.DAG
import com.weibo.rill.flow.olympicene.traversal.Olympicene
import com.weibo.rill.flow.service.dconfs.BizDConfs
import com.weibo.rill.flow.service.service.DAGDescriptorService
import com.weibo.rill.flow.service.statistic.DAGResourceStatistic
import spock.lang.Specification

class FlowProtocolDispatcherTest extends Specification {
FlowProtocolDispatcher dispatcher
DAGDescriptorService dagDescriptorService
BizDConfs bizDConfs
Olympicene olympicene
DAGResourceStatistic dagResourceStatistic

def setup() {
dispatcher = new FlowProtocolDispatcher()
dagDescriptorService = Mock(DAGDescriptorService)
bizDConfs = Mock(BizDConfs)
olympicene = Mock(Olympicene)
dagResourceStatistic = Mock(DAGResourceStatistic)

dispatcher.dagDescriptorService = dagDescriptorService
dispatcher.bizDConfs = bizDConfs
dispatcher.olympicene = olympicene
dispatcher.dagResourceStatistic = dagResourceStatistic
}

def "test handle method with valid input"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": "123", "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(123L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with null input map"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> null
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with invalid uid"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": null, "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with non-numeric uid"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": "not-a-number", "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
thrown(NumberFormatException)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,193 @@

package com.weibo.rill.flow.service.dispatcher

import com.weibo.rill.flow.common.exception.TaskException
import com.weibo.rill.flow.interfaces.model.http.HttpParameter
import com.weibo.rill.flow.interfaces.model.resource.Resource
import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo
import com.weibo.rill.flow.interfaces.model.task.FunctionTask
import com.weibo.rill.flow.interfaces.model.task.TaskInfo
import com.weibo.rill.flow.olympicene.core.switcher.SwitcherManager
import com.weibo.rill.flow.service.invoke.HttpInvokeHelper
import com.weibo.rill.flow.service.statistic.DAGResourceStatistic
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.http.MediaType
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.client.RestClientResponseException
import spock.lang.Specification
import spock.lang.Unroll

class FunctionProtocolDispatcherTest extends Specification {
FunctionProtocolDispatcher dispatcher = new FunctionProtocolDispatcher();
FunctionProtocolDispatcher dispatcher
HttpInvokeHelper httpInvokeHelper
DAGResourceStatistic dagResourceStatistic
SwitcherManager switcherManager

def "buildHttpEntity test"() {
def setup() {
dispatcher = new FunctionProtocolDispatcher()
httpInvokeHelper = Mock(HttpInvokeHelper)
dagResourceStatistic = Mock(DAGResourceStatistic)
switcherManager = Mock(SwitcherManager)

dispatcher.httpInvokeHelper = httpInvokeHelper
dispatcher.dagResourceStatistic = dagResourceStatistic
dispatcher.switcherManagerImpl = switcherManager
}

def "test handle method with successful HTTP request"() {
given:
def resource = Mock(Resource)
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getRequestType() >> "POST"
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "test-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["key": "value"]
getHeaders() >> new LinkedMultiValueMap<String, String>()
}
def httpParameter = HttpParameter.builder()
.header(inputHeader)
.body(inputBody)
.queryParams([:])
.body([:])
.callback([:])
.header([:])
.build()
MultiValueMap<String, String> header = new LinkedMultiValueMap<>()
Optional.ofNullable(httpParameter.getHeader())
.ifPresent { it -> it.forEach { key, value -> header.add(key, value) } }
def expectedResponse = '{"status": "success"}'

when:
def httpEntity = dispatcher.buildHttpEntity(method, header, httpParameter)
def result = dispatcher.handle(resource, dispatchInfo)

then:
httpEntity.body == body
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(_, _, _, _) >> httpParameter
1 * httpInvokeHelper.buildUrl(_, _) >> "http://test.url"
1 * httpInvokeHelper.invokeRequest(_, _, _, _, _, _) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(_, _, _, expectedResponse)
result == expectedResponse
}

def "test handle method with RestClientResponseException"() {
given:
def resource = Mock(Resource) {
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getRequestType() >> "POST"
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "test-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["key": "value"]
getHeaders() >> new LinkedMultiValueMap<String, String>()
}
def httpParameter = HttpParameter.builder()
.queryParams([:])
.body([:])
.callback([:])
.header([:])
.build()
def errorResponse = "Error response"
def exception = Mock(RestClientResponseException) {
getRawStatusCode() >> 500
getResponseBodyAsString() >> errorResponse
}

when:
dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(_, _, _, _) >> httpParameter
1 * httpInvokeHelper.buildUrl(_, _) >> "http://test.url"
1 * httpInvokeHelper.invokeRequest(_, _, _, _, _, _) >> { throw exception }
1 * dagResourceStatistic.updateUrlTypeResourceStatus(_, _, _, errorResponse)
thrown(TaskException)
}

@Unroll
def "test handle method with different HTTP methods: #requestType"() {
given:
def resource = Mock(Resource)
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getRequestType() >> requestType
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "test-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["key": "value"]
getHeaders() >> new LinkedMultiValueMap<String, String>()
}
def httpParameter = HttpParameter.builder()
.queryParams([:])
.body([:])
.callback([:])
.header([:])
.build()
def expectedResponse = '{"status": "success"}'

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(_, _, _, _) >> httpParameter
1 * httpInvokeHelper.buildUrl(_, _) >> "http://test.url"
1 * httpInvokeHelper.invokeRequest(_, _, _, _, expectedMethod, _) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(_, _, _, expectedResponse)
result == expectedResponse

where:
method | inputHeader | inputBody | body
null | [:] | [:] | null
HttpMethod.GET | [:] | [:] | null
HttpMethod.POST | [:] | [:] | [:]
HttpMethod.POST | [:] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]]
HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_JSON_VALUE] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]]
HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_FORM_URLENCODED_VALUE] | [k: "v", name: "Bob"] | [k: ["v"], name: ["Bob"]]
requestType | expectedMethod
"POST" | HttpMethod.POST
"GET" | HttpMethod.GET
"PUT" | HttpMethod.PUT
null | HttpMethod.POST // default method
}

def "test handle method with form-urlencoded content type"() {
given:
def resource = Mock(Resource)
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getRequestType() >> "POST"
}
}
def headers = new LinkedMultiValueMap<String, String>()
headers.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "test-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["key": "value"]
getHeaders() >> headers
}
def httpParameter = HttpParameter.builder()
.queryParams([:])
.body(["formKey": "formValue"])
.callback([:])
.header([:])
.build()
def expectedResponse = '{"status": "success"}'

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(_, _, _, _) >> httpParameter
1 * httpInvokeHelper.buildUrl(_, _) >> "http://test.url"
1 * httpInvokeHelper.invokeRequest(_, _, _, _, _, _) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(_, _, _, expectedResponse)
result == expectedResponse
}
}

0 comments on commit 083708f

Please sign in to comment.