diff --git a/handlers/deploy.go b/handlers/deploy.go index 55f254fd6..c188637a9 100644 --- a/handlers/deploy.go +++ b/handlers/deploy.go @@ -160,7 +160,10 @@ func makeDeploymentSpec(request requests.CreateFunctionRequest, existingSecrets FailureThreshold: 3, } - initialReplicas := getMinReplicaCount(request.Labels) + initialReplicas, replicaErr := getMinReplicaCount(request.Labels) + if replicaErr != nil { + return nil, replicaErr + } labels := buildLabels(request.Service, request.Labels) nodeSelector := createSelector(request.Constraints) resources, resourceErr := createResources(request) diff --git a/handlers/labels.go b/handlers/labels.go index 58ac97f43..736cd3bcf 100644 --- a/handlers/labels.go +++ b/handlers/labels.go @@ -1,8 +1,8 @@ package handlers import ( + "errors" "fmt" - "log" "strconv" "time" ) @@ -42,20 +42,24 @@ func buildLabels(functionName string, requestLables *map[string]string) map[stri // getMinReplicaCount extracts the functions minimum replica count from the user's // request labels. If the value is not found, this will return the default value, 1. -func getMinReplicaCount(labels *map[string]string) *int32 { +func getMinReplicaCount(labels *map[string]string) (*int32, error) { if labels == nil { - return int32p(initialReplicasCount) + return int32p(initialReplicasCount), nil } l := *labels if value, exists := l[FunctionMinReplicaCount]; exists { minReplicas, err := strconv.Atoi(value) - if err == nil && minReplicas > 0 { - return int32p(int32(minReplicas)) + if err != nil { + return nil, errors.New("could not parse the minimum replica value") } - log.Println(err) + if minReplicas > 0 { + return int32p(int32(minReplicas)), nil + } + + return nil, errors.New("replica count must be a positive integer") } - return int32p(initialReplicasCount) + return int32p(initialReplicasCount), nil } diff --git a/handlers/lables_test.go b/handlers/lables_test.go index 5ca411f9a..1f1371d8d 100644 --- a/handlers/lables_test.go +++ b/handlers/lables_test.go @@ -1,6 +1,9 @@ package handlers -import "testing" +import ( + "strings" + "testing" +) func Test_getMinReplicaCount(t *testing.T) { scenarios := []struct { @@ -27,7 +30,10 @@ func Test_getMinReplicaCount(t *testing.T) { for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { - output := getMinReplicaCount(s.labels) + output, err := getMinReplicaCount(s.labels) + if err != nil { + t.Errorf("getMinReplicaCount should not error on an empty or valid label") + } if output == nil { t.Errorf("getMinReplicaCount should not return nil pointer") } @@ -40,6 +46,47 @@ func Test_getMinReplicaCount(t *testing.T) { } } +func Test_getMinReplicaCount_Error(t *testing.T) { + scenarios := []struct { + name string + labels *map[string]string + msg string + }{ + { + name: "negative values should return an error", + labels: &map[string]string{FunctionMinReplicaCount: "-2"}, + msg: "replica count must be a positive integer", + }, + { + name: "zero values should return an error", + labels: &map[string]string{FunctionMinReplicaCount: "0"}, + msg: "replica count must be a positive integer", + }, + { + name: "decimal values should return an error", + labels: &map[string]string{FunctionMinReplicaCount: "10.5"}, + msg: "could not parse the minimum replica value", + }, + { + name: "non-integer values should return an error", + labels: &map[string]string{FunctionMinReplicaCount: "test"}, + msg: "could not parse the minimum replica value", + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + output, err := getMinReplicaCount(s.labels) + if output != nil { + t.Errorf("getMinReplicaCount should return nil value on invalid input") + } + if !strings.Contains(err.Error(), s.msg) { + t.Errorf("unexpected error: expected '%s', got '%s'", s.msg, err.Error()) + } + }) + } +} + func Test_parseLabels(t *testing.T) { scenarios := []struct { name string diff --git a/handlers/update.go b/handlers/update.go index b3d5ed71d..2aab246a5 100644 --- a/handlers/update.go +++ b/handlers/update.go @@ -75,7 +75,12 @@ func updateDeploymentSpec( deployment.Labels = labels deployment.Spec.Template.ObjectMeta.Labels = labels - deployment.Spec.Replicas = getMinReplicaCount(request.Labels) + + replicaCount, replicaErr := getMinReplicaCount(request.Labels) + if replicaErr != nil { + return replicaErr, http.StatusBadRequest + } + deployment.Spec.Replicas = replicaCount deployment.Annotations = annotations deployment.Spec.Template.Annotations = annotations