diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index 8730e82..3456946 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -185,6 +185,8 @@ func (sc *Syncer) init() error { types.RBACRole, types.RBACEndpointPermission, types.ServicePackage, types.ServiceVersion, types.Document, + + types.FilterChain, } sc.entityDiffers = map[types.EntityType]types.Differ{} diff --git a/pkg/diff/order.go b/pkg/diff/order.go index 1364e98..bf1ae3b 100644 --- a/pkg/diff/order.go +++ b/pkg/diff/order.go @@ -22,7 +22,8 @@ L3 +---------------------------> Service <---+ +-> Route | | Version | | | | | | | | | | | v | -L4 +----------> Document <---------+ +-> Plugins <----------+ +L4 +----------> Document <---------+ +-> Plugins / <---------+ + FilterChains */ // dependencyOrder defines the order in which entities will be synced by decK. @@ -61,6 +62,7 @@ var dependencyOrder = [][]types.EntityType{ }, { types.Plugin, + types.FilterChain, types.Document, }, } diff --git a/pkg/dump/dump.go b/pkg/dump/dump.go index 61a6b84..8cb0a21 100644 --- a/pkg/dump/dump.go +++ b/pkg/dump/dump.go @@ -42,6 +42,9 @@ type Config struct { // IsConsumerGroupScopedPluginSupported IsConsumerGroupScopedPluginSupported bool + + // IsFilterChainsSupported + IsFilterChainsSupported bool } func deduplicate(stringSlice []string) []string { @@ -252,6 +255,19 @@ func getProxyConfiguration(ctx context.Context, group *errgroup.Group, return nil }) + if config.IsFilterChainsSupported { + group.Go(func() error { + filterChains, err := GetAllFilterChains(ctx, client, config.SelectorTags) + if err != nil { + return fmt.Errorf("filter chains: %w", err) + } + state.FilterChains = filterChains + return nil + }) + } else { + state.FilterChains = make([]*kong.FilterChain, 0) + } + group.Go(func() error { certificates, err := GetAllCertificates(ctx, client, config.SelectorTags) if err != nil { @@ -441,6 +457,30 @@ func GetAllPlugins(ctx context.Context, return plugins, nil } +// GetAllFilterChains queries Kong for all the filter chains using client. +func GetAllFilterChains(ctx context.Context, + client *kong.Client, tags []string, +) ([]*kong.FilterChain, error) { + var filterChains []*kong.FilterChain + opt := newOpt(tags) + + for { + s, nextopt, err := client.FilterChains.List(ctx, opt) + if err != nil { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + filterChains = append(filterChains, s...) + if nextopt == nil { + break + } + opt = nextopt + } + return filterChains, nil +} + // GetAllCertificates queries Kong for all the certificates using client. func GetAllCertificates(ctx context.Context, client *kong.Client, tags []string, diff --git a/pkg/file/builder.go b/pkg/file/builder.go index 822252d..70bed9c 100644 --- a/pkg/file/builder.go +++ b/pkg/file/builder.go @@ -93,6 +93,7 @@ func (b *stateBuilder) build() (*utils.KongRawState, *utils.KonnectRawState, err b.consumerGroups() b.consumers() b.plugins() + b.filterChains() b.enterprise() // konnect @@ -894,6 +895,16 @@ func (b *stateBuilder) ingestService(s *FService) error { return err } + // filter chains for the service + var filterChains []FFilterChain + for _, f := range s.FilterChains { + f.Service = utils.GetServiceReference(s.Service) + filterChains = append(filterChains, *f) + } + if err := b.ingestFilterChains(filterChains); err != nil { + return err + } + // routes for the service for _, r := range s.Routes { r := r @@ -1153,6 +1164,48 @@ func (b *stateBuilder) plugins() { } } +func (b *stateBuilder) filterChains() { + if b.err != nil { + return + } + + var filterChains []FFilterChain + for _, f := range b.targetContent.FilterChains { + f := f + if f.Service != nil && !utils.Empty(f.Service.ID) { + s, err := b.intermediate.Services.Get(*f.Service.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("service %v for filterChain %v: %w", + f.Service.FriendlyName(), *f.Name, err) + + return + } else if err != nil { + b.err = err + return + } + f.Service = utils.GetServiceReference(s.Service) + } + if f.Route != nil && !utils.Empty(f.Route.ID) { + r, err := b.intermediate.Routes.Get(*f.Route.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("route %v for filterChain %v: %w", + f.Route.FriendlyName(), *f.Name, err) + + return + } else if err != nil { + b.err = err + return + } + f.Route = utils.GetRouteReference(r.Route) + } + filterChains = append(filterChains, f) + } + if err := b.ingestFilterChains(filterChains); err != nil { + b.err = err + return + } +} + func (b *stateBuilder) validatePlugin(p FPlugin) error { if b.isConsumerGroupScopedPluginSupported && *p.Name == ratelimitingAdvancedPluginName { // check if deprecated consumer-groups configuration is present in the config @@ -1270,6 +1323,16 @@ func (b *stateBuilder) ingestRoute(r FRoute) error { return err } + // filter chains for the route + var filterChains []FFilterChain + for _, f := range r.FilterChains { + f.Route = utils.GetRouteReference(r.Route) + filterChains = append(filterChains, *f) + } + if err := b.ingestFilterChains(filterChains); err != nil { + return err + } + // plugins for the route var plugins []FPlugin for _, p := range r.Plugins { @@ -1398,6 +1461,36 @@ func pluginRelations(plugin *kong.Plugin) (cID, rID, sID, cgID string) { return } +func (b *stateBuilder) ingestFilterChains(filterChains []FFilterChain) error { + for _, f := range filterChains { + f := f + if utils.Empty(f.ID) { + rID, sID := filterChainRelations(&f.FilterChain) + filterChain, err := b.currentState.FilterChains.GetByProp(sID, rID) + if errors.Is(err, state.ErrNotFound) { + f.ID = uuid() + } else if err != nil { + return err + } else { + f.ID = kong.String(*filterChain.ID) + } + } + utils.MustMergeTags(&f, b.selectTags) + b.rawState.FilterChains = append(b.rawState.FilterChains, &f.FilterChain) + } + return nil +} + +func filterChainRelations(filterChain *kong.FilterChain) (rID, sID string) { + if filterChain.Route != nil && !utils.Empty(filterChain.Route.ID) { + rID = *filterChain.Route.ID + } + if filterChain.Service != nil && !utils.Empty(filterChain.Service.ID) { + sID = *filterChain.Service.ID + } + return +} + func defaulter( ctx context.Context, client *kong.Client, fileContent *Content, disableDynamicDefaults, isKonnect bool, ) (*utils.Defaulter, error) { diff --git a/pkg/file/codegen/main.go b/pkg/file/codegen/main.go index 4506d50..22928e7 100644 --- a/pkg/file/codegen/main.go +++ b/pkg/file/codegen/main.go @@ -76,6 +76,49 @@ func main() { }, } + schema.Definitions["FFilterChain"].AnyOf = []*jsonschema.Type{ + { + Required: []string{"filters", "name"}, + }, + { + Required: []string{"filters", "id"}, + }, + } + schema.Definitions["FFilterChain"].Properties["enabled"] = &jsonschema.Type{ + Type: "boolean", + } + schema.Definitions["FFilterChain"].Properties["filters"] = &jsonschema.Type{ + Type: "array", + Items: &jsonschema.Type{ + Ref: "#/definitions/FFilter", + }, + } + + schema.Definitions["FFilter"] = &jsonschema.Type{ + Type: "object", + Required: []string{"name"}, + AdditionalProperties: json.RawMessage(`false`), + Properties: map[string]*jsonschema.Type{ + "name": { + Type: "string", + }, + "config": { + OneOf: []*jsonschema.Type{ + {Type: "array"}, + {Type: "boolean"}, + {Type: "integer"}, + {Type: "number"}, + {Type: "null"}, + {Type: "object"}, + {Type: "string"}, + }, + }, + "enabled": { + Type: "boolean", + }, + }, + } + // creds schema.Definitions["ACLGroup"].Required = []string{"group"} schema.Definitions["BasicAuth"].Required = []string{"username", "password"} @@ -102,6 +145,9 @@ func main() { schema.Definitions["FPlugin"].Properties["route"] = stringType schema.Definitions["FPlugin"].Properties["consumer_group"] = stringType + schema.Definitions["FFilterChain"].Properties["service"] = stringType + schema.Definitions["FFilterChain"].Properties["route"] = stringType + schema.Definitions["FService"].Properties["client_certificate"] = stringType // konnect resources diff --git a/pkg/file/kong_json_schema.json b/pkg/file/kong_json_schema.json index 3e1bdca..d5cda3c 100644 --- a/pkg/file/kong_json_schema.json +++ b/pkg/file/kong_json_schema.json @@ -55,6 +55,12 @@ }, "type": "array" }, + "filter_chains": { + "items": { + "$ref": "#/definitions/FFilterChain" + }, + "type": "array" + }, "licenses": { "items": { "$schema": "http://json-schema.org/draft-04/schema#", @@ -571,6 +577,99 @@ "additionalProperties": false, "type": "object" }, + "FFilter": { + "required": [ + "name" + ], + "properties": { + "config": { + "oneOf": [ + { + "type": "array" + }, + { + "type": "boolean" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "null" + }, + { + "type": "object" + }, + { + "type": "string" + } + ] + }, + "enabled": { + "type": "boolean" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, + "FFilterChain": { + "properties": { + "created_at": { + "type": "integer" + }, + "enabled": { + "type": "boolean" + }, + "filters": { + "items": { + "$ref": "#/definitions/FFilter" + }, + "type": "array" + }, + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "route": { + "type": "string" + }, + "service": { + "type": "string" + }, + "tags": { + "items": { + "type": "string" + }, + "type": "array" + }, + "updated_at": { + "type": "integer" + } + }, + "additionalProperties": false, + "type": "object", + "anyOf": [ + { + "required": [ + "filters", + "name" + ] + }, + { + "required": [ + "filters", + "id" + ] + } + ] + }, "FLicense": { "properties": { "created_at": { @@ -731,6 +830,13 @@ "expression": { "type": "string" }, + "filter_chains": { + "items": { + "$schema": "http://json-schema.org/draft-04/schema#", + "$ref": "#/definitions/FFilterChain" + }, + "type": "array" + }, "headers": { "patternProperties": { ".*": { @@ -863,6 +969,12 @@ "enabled": { "type": "boolean" }, + "filter_chains": { + "items": { + "$ref": "#/definitions/FFilterChain" + }, + "type": "array" + }, "host": { "type": "string" }, @@ -1131,6 +1243,24 @@ "additionalProperties": false, "type": "object" }, + "Filter": { + "properties": { + "config": { + "items": { + "type": "integer" + }, + "type": "array" + }, + "enabled": { + "type": "boolean" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + }, "HMACAuth": { "required": [ "username", diff --git a/pkg/file/types.go b/pkg/file/types.go index 09fa016..f682fbb 100644 --- a/pkg/file/types.go +++ b/pkg/file/types.go @@ -30,12 +30,180 @@ const ( httpsPort = 443 ) +// FFilterChain represents a Kong FilterChain. +// +k8s:deepcopy-gen=true +type FFilterChain struct { + kong.FilterChain `yaml:",inline,omitempty"` +} + +// SerializableFilter is a shadow type +// used for custom marshalling of filters in a FilterChain. +type SerializableFilter struct { + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + Config *json.RawMessage `json:"config,omitempty" yaml:"name,omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` +} + +// SerializableFilterChain is a shadow type +// used for custom marshalling of FilterChain. +type SerializableFilterChain struct { + CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` + UpdatedAt *int `json:"updated_at,omitempty" yaml:"updated_at,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + Route string `json:"route,omitempty" yaml:",omitempty"` + Service string `json:"service,omitempty" yaml:",omitempty"` + Filters []*SerializableFilter `json:"filters,omitempty" yaml:",omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` +} + +func copyToSerializableFilterChain(f FFilterChain) SerializableFilterChain { + sf := SerializableFilterChain{} + if f.CreatedAt != nil { + sf.CreatedAt = f.CreatedAt + } + if f.UpdatedAt != nil { + sf.UpdatedAt = f.UpdatedAt + } + if f.ID != nil { + sf.ID = f.ID + } + if f.Name != nil { + sf.Name = f.Name + } + if f.Enabled != nil { + sf.Enabled = f.Enabled + } + if f.FilterChain.Route != nil { + sf.Route = *f.FilterChain.Route.ID + } + if f.FilterChain.Service != nil { + sf.Service = *f.FilterChain.Service.ID + } + if f.Filters != nil { + sf.Filters = []*SerializableFilter{} + for _, filter := range f.Filters { + sfilter := &SerializableFilter{} + if filter.Name != nil { + sfilter.Name = filter.Name + } + if filter.Config != nil { + sfilter.Config = filter.Config + } + if filter.Enabled != nil { + sfilter.Enabled = filter.Enabled + } + sf.Filters = append(sf.Filters, sfilter) + } + } + if f.Tags != nil { + sf.Tags = f.Tags + } + return sf +} + +func copyFromSerializableFilterChain(sf SerializableFilterChain, f *FFilterChain) { + if sf.CreatedAt != nil { + f.CreatedAt = sf.CreatedAt + } + if sf.UpdatedAt != nil { + f.UpdatedAt = sf.UpdatedAt + } + if sf.ID != nil { + f.ID = sf.ID + } + if sf.Name != nil { + f.Name = sf.Name + } + if sf.Enabled != nil { + f.Enabled = sf.Enabled + } + if sf.Filters != nil { + f.Filters = []*kong.Filter{} + for _, sfilter := range sf.Filters { + filter := &kong.Filter{} + if sfilter.Name != nil { + filter.Name = sfilter.Name + } + if sfilter.Config != nil { + filter.Config = sfilter.Config + } + if sfilter.Enabled != nil { + filter.Enabled = sfilter.Enabled + } + f.Filters = append(f.Filters, filter) + } + } + if sf.Tags != nil { + f.Tags = sf.Tags + } + if sf.Route != "" { + f.Route = &kong.Route{ + ID: kong.String(sf.Route), + } + } + if sf.Service != "" { + f.Service = &kong.Service{ + ID: kong.String(sf.Service), + } + } +} + +// MarshalYAML is a custom marshal method to handle +// foreign references. +func (f FFilterChain) MarshalYAML() (interface{}, error) { + return copyToSerializableFilterChain(f), nil +} + +// UnmarshalYAML is a custom marshal method to handle +// foreign references. +func (f *FFilterChain) UnmarshalYAML(unmarshal func(interface{}) error) error { + var sf SerializableFilterChain + if err := unmarshal(&sf); err != nil { + return err + } + copyFromSerializableFilterChain(sf, f) + return nil +} + +// MarshalJSON is a custom marshal method to handle +// foreign references. +func (f FFilterChain) MarshalJSON() ([]byte, error) { + sf := copyToSerializableFilterChain(f) + return json.Marshal(sf) +} + +// UnmarshalJSON is a custom marshal method to handle +// foreign references. +func (f *FFilterChain) UnmarshalJSON(b []byte) error { + var sf SerializableFilterChain + err := json.Unmarshal(b, &sf) + if err != nil { + return err + } + copyFromSerializableFilterChain(sf, f) + return nil +} + +// sortKey is used for sorting. +func (f FFilterChain) sortKey() string { + if f.Name != nil { + return *f.Name + } + if f.ID != nil { + return *f.ID + } + return "" +} + // FService represents a Kong Service and it's associated routes and plugins. // +k8s:deepcopy-gen=true type FService struct { kong.Service - Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` - Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + FilterChains []*FFilterChain `json:"filter_chains,omitempty" yaml:",omitempty"` // sugar property URL *string `json:"url,omitempty" yaml:",omitempty"` @@ -53,26 +221,27 @@ func (s FService) sortKey() string { } type service struct { - ClientCertificate *string `json:"client_certificate,omitempty" yaml:"client_certificate,omitempty"` - ConnectTimeout *int `json:"connect_timeout,omitempty" yaml:"connect_timeout,omitempty"` - CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` - Host *string `json:"host,omitempty" yaml:"host,omitempty"` - ID *string `json:"id,omitempty" yaml:"id,omitempty"` - Name *string `json:"name,omitempty" yaml:"name,omitempty"` - Path *string `json:"path,omitempty" yaml:"path,omitempty"` - Port *int `json:"port,omitempty" yaml:"port,omitempty"` - Protocol *string `json:"protocol,omitempty" yaml:"protocol,omitempty"` - ReadTimeout *int `json:"read_timeout,omitempty" yaml:"read_timeout,omitempty"` - Retries *int `json:"retries,omitempty" yaml:"retries,omitempty"` - UpdatedAt *int `json:"updated_at,omitempty" yaml:"updated_at,omitempty"` - WriteTimeout *int `json:"write_timeout,omitempty" yaml:"write_timeout,omitempty"` - Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` - TLSVerify *bool `json:"tls_verify,omitempty" yaml:"tls_verify,omitempty"` - TLSVerifyDepth *int `json:"tls_verify_depth,omitempty" yaml:"tls_verify_depth,omitempty"` - CACertificates []*string `json:"ca_certificates,omitempty" yaml:"ca_certificates,omitempty"` - Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` - Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` - Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + ClientCertificate *string `json:"client_certificate,omitempty" yaml:"client_certificate,omitempty"` + ConnectTimeout *int `json:"connect_timeout,omitempty" yaml:"connect_timeout,omitempty"` + CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` + Host *string `json:"host,omitempty" yaml:"host,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + Path *string `json:"path,omitempty" yaml:"path,omitempty"` + Port *int `json:"port,omitempty" yaml:"port,omitempty"` + Protocol *string `json:"protocol,omitempty" yaml:"protocol,omitempty"` + ReadTimeout *int `json:"read_timeout,omitempty" yaml:"read_timeout,omitempty"` + Retries *int `json:"retries,omitempty" yaml:"retries,omitempty"` + UpdatedAt *int `json:"updated_at,omitempty" yaml:"updated_at,omitempty"` + WriteTimeout *int `json:"write_timeout,omitempty" yaml:"write_timeout,omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` + TLSVerify *bool `json:"tls_verify,omitempty" yaml:"tls_verify,omitempty"` + TLSVerifyDepth *int `json:"tls_verify_depth,omitempty" yaml:"tls_verify_depth,omitempty"` + CACertificates []*string `json:"ca_certificates,omitempty" yaml:"ca_certificates,omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + Routes []*FRoute `json:"routes,omitempty" yaml:",omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + FilterChains []*FFilterChain `json:"filter_chains,omitempty" yaml:",omitempty"` // sugar property URL *string `json:"url,omitempty" yaml:",omitempty"` @@ -102,6 +271,7 @@ func copyToService(fService FService) service { s.Tags = fService.Tags s.Routes = fService.Routes s.Plugins = fService.Plugins + s.FilterChains = fService.FilterChains s.Enabled = fService.Enabled return s @@ -180,6 +350,7 @@ func copyFromService(service service, fService *FService) error { fService.TLSVerifyDepth = service.TLSVerifyDepth fService.Routes = service.Routes fService.Plugins = service.Plugins + fService.FilterChains = service.FilterChains fService.Enabled = service.Enabled return nil } @@ -218,11 +389,12 @@ func (s *FService) UnmarshalJSON(b []byte) error { return copyFromService(service, s) } -// FRoute represents a Kong Route and it's associated plugins. +// FRoute represents a Kong Route and it's associated plugins and filter chains. // +k8s:deepcopy-gen=true type FRoute struct { - kong.Route `yaml:",inline,omitempty"` - Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + kong.Route `yaml:",inline,omitempty"` + Plugins []*FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + FilterChains []*FFilterChain `json:"filter_chains,omitempty" yaml:",omitempty"` } // sortKey is used for sorting. @@ -736,6 +908,7 @@ type Content struct { Consumers []FConsumer `json:"consumers,omitempty" yaml:",omitempty"` ConsumerGroups []FConsumerGroupObject `json:"consumer_groups,omitempty" yaml:",omitempty"` Plugins []FPlugin `json:"plugins,omitempty" yaml:",omitempty"` + FilterChains []FFilterChain `json:"filter_chains,omitempty" yaml:",omitempty"` Upstreams []FUpstream `json:"upstreams,omitempty" yaml:",omitempty"` Certificates []FCertificate `json:"certificates,omitempty" yaml:",omitempty"` CACertificates []FCACertificate `json:"ca_certificates,omitempty" yaml:"ca_certificates,omitempty"` diff --git a/pkg/file/writer.go b/pkg/file/writer.go index eb439bb..ca87c05 100644 --- a/pkg/file/writer.go +++ b/pkg/file/writer.go @@ -83,6 +83,11 @@ func KongStateToContent(kongState *state.KongState, config WriteConfig) (*Conten return nil, err } + err = populateFilterChains(kongState, file, config) + if err != nil { + return nil, err + } + err = populateUpstreams(kongState, file, config) if err != nil { return nil, err @@ -149,6 +154,11 @@ func KonnectStateToFile(kongState *state.KongState, config WriteConfig) error { return err } + err = populateFilterChains(kongState, file, config) + if err != nil { + return err + } + err = populateUpstreams(kongState, file, config) if err != nil { return err @@ -284,6 +294,49 @@ func populateServices(kongState *state.KongState, file *Content, return nil } +func getFRouteFromRoute(r *state.Route, kongState *state.KongState, config WriteConfig) (*FRoute, error) { + plugins, err := kongState.Plugins.GetAllByRouteID(*r.ID) + if err != nil { + return nil, err + } + filterChains, err := kongState.FilterChains.GetAllByRouteID(*r.ID) + if err != nil { + return nil, err + } + utils.ZeroOutID(r, r.Name, config.WithID) + utils.ZeroOutTimestamps(r) + utils.MustRemoveTags(&r.Route, config.SelectTags) + + route := &FRoute{Route: r.Route} + + for _, p := range plugins { + if p.Service != nil || p.Consumer != nil { + continue + } + p.Route = nil + utils.ZeroOutID(p, p.Name, config.WithID) + utils.ZeroOutTimestamps(p) + utils.MustRemoveTags(&p.Plugin, config.SelectTags) + route.Plugins = append(route.Plugins, &FPlugin{Plugin: p.Plugin}) + } + sort.SliceStable(route.Plugins, func(i, j int) bool { + return compareOrder(route.Plugins[i], route.Plugins[j]) + }) + + for _, f := range filterChains { + f.Route = nil + utils.ZeroOutID(f, f.Name, config.WithID) + utils.ZeroOutTimestamps(f) + utils.MustRemoveTags(&f.FilterChain, config.SelectTags) + route.FilterChains = append(route.FilterChains, &FFilterChain{FilterChain: f.FilterChain}) + } + sort.SliceStable(route.FilterChains, func(i, j int) bool { + return compareOrder(route.FilterChains[i], route.FilterChains[j]) + }) + + return route, nil +} + func fetchService(id string, kongState *state.KongState, config WriteConfig) (*FService, error) { kongService, err := kongState.Services.Get(id) if err != nil { @@ -298,6 +351,10 @@ func fetchService(id string, kongState *state.KongState, config WriteConfig) (*F if err != nil { return nil, err } + filterChains, err := kongState.FilterChains.GetAllByServiceID(*s.ID) + if err != nil { + return nil, err + } for _, p := range plugins { p := p if p.Route != nil || p.Consumer != nil || p.ConsumerGroup != nil { @@ -314,34 +371,26 @@ func fetchService(id string, kongState *state.KongState, config WriteConfig) (*F }) for _, r := range routes { r := r - plugins, err := kongState.Plugins.GetAllByRouteID(*r.ID) + r.Service = nil + route, err := getFRouteFromRoute(r, kongState, config) if err != nil { return nil, err } - r.Service = nil - utils.ZeroOutID(r, r.Name, config.WithID) - utils.ZeroOutTimestamps(r) - utils.MustRemoveTags(&r.Route, config.SelectTags) - route := &FRoute{Route: r.Route} - for _, p := range plugins { - p := p - if p.Service != nil || p.Consumer != nil || p.ConsumerGroup != nil { - continue - } - p.Route = nil - utils.ZeroOutID(p, p.Name, config.WithID) - utils.ZeroOutTimestamps(p) - utils.MustRemoveTags(&p.Plugin, config.SelectTags) - route.Plugins = append(route.Plugins, &FPlugin{Plugin: p.Plugin}) - } - sort.SliceStable(route.Plugins, func(i, j int) bool { - return compareOrder(route.Plugins[i], route.Plugins[j]) - }) s.Routes = append(s.Routes, route) } sort.SliceStable(s.Routes, func(i, j int) bool { return compareOrder(s.Routes[i], s.Routes[j]) }) + for _, f := range filterChains { + f.Service = nil + utils.ZeroOutID(f, f.Name, config.WithID) + utils.ZeroOutTimestamps(f) + utils.MustRemoveTags(&f.FilterChain, config.SelectTags) + s.FilterChains = append(s.FilterChains, &FFilterChain{FilterChain: f.FilterChain}) + } + sort.SliceStable(s.FilterChains, func(i, j int) bool { + return compareOrder(s.FilterChains[i], s.FilterChains[j]) + }) utils.ZeroOutID(&s, s.Name, config.WithID) utils.ZeroOutTimestamps(&s) utils.MustRemoveTags(&s, config.SelectTags) @@ -360,28 +409,10 @@ func populateServicelessRoutes(kongState *state.KongState, file *Content, if r.Service != nil { continue } - plugins, err := kongState.Plugins.GetAllByRouteID(*r.ID) + route, err := getFRouteFromRoute(r, kongState, config) if err != nil { return err } - utils.ZeroOutID(r, r.Name, config.WithID) - utils.ZeroOutTimestamps(r) - utils.MustRemoveTags(&r.Route, config.SelectTags) - route := &FRoute{Route: r.Route} - for _, p := range plugins { - p := p - if p.Service != nil || p.Consumer != nil || p.ConsumerGroup != nil { - continue - } - p.Route = nil - utils.ZeroOutID(p, p.Name, config.WithID) - utils.ZeroOutTimestamps(p) - utils.MustRemoveTags(&p.Plugin, config.SelectTags) - route.Plugins = append(route.Plugins, &FPlugin{Plugin: p.Plugin}) - } - sort.SliceStable(route.Plugins, func(i, j int) bool { - return compareOrder(route.Plugins[i], route.Plugins[j]) - }) file.Routes = append(file.Routes, *route) } sort.SliceStable(file.Routes, func(i, j int) bool { @@ -462,6 +493,50 @@ func populatePlugins(kongState *state.KongState, file *Content, return nil } +func populateFilterChains(kongState *state.KongState, file *Content, + _ WriteConfig, +) error { + filterChains, err := kongState.FilterChains.GetAll() + if err != nil { + return err + } + + for _, f := range filterChains { + associations := 0 + if f.Service != nil { + associations++ + sID := *f.Service.ID + service, err := kongState.Services.Get(sID) + if err != nil { + return fmt.Errorf("unable to get service %s for filter chain %s [%s]: %w", sID, *f.Name, *f.ID, err) + } + if !utils.Empty(service.Name) { + sID = *service.Name + } + f.Service.ID = &sID + } + if f.Route != nil { + associations++ + rID := *f.Route.ID + route, err := kongState.Routes.Get(rID) + if err != nil { + return fmt.Errorf("unable to get route %s for filter chain %s [%s]: %w", rID, *f.Name, *f.ID, err) + } + if !utils.Empty(route.Name) { + rID = *route.Name + } + f.Route.ID = &rID + } + if associations != 1 { + return fmt.Errorf("unable to determine route or service entity associated with filter chain %s [%s]", *f.Name, *f.ID) + } + } + sort.SliceStable(file.FilterChains, func(i, j int) bool { + return compareOrder(file.FilterChains[i], file.FilterChains[j]) + }) + return nil +} + func populateUpstreams(kongState *state.KongState, file *Content, config WriteConfig, ) error { diff --git a/pkg/file/writer_test.go b/pkg/file/writer_test.go index 5fed7a1..991a284 100644 --- a/pkg/file/writer_test.go +++ b/pkg/file/writer_test.go @@ -195,6 +195,31 @@ func Test_compareOrder(t *testing.T) { }, expected: true, }, + { + sortable1: FFilterChain{ + FilterChain: kong.FilterChain{ + Name: kong.String("my-filter-chain-1"), + ID: kong.String("my-id-1"), + Filters: []*kong.Filter{ + { + Name: kong.String("example-filter"), + }, + }, + }, + }, + sortable2: FFilterChain{ + FilterChain: kong.FilterChain{ + Name: kong.String("my-filter-chain-2"), + ID: kong.String("my-id-2"), + Filters: []*kong.Filter{ + { + Name: kong.String("example-filter"), + }, + }, + }, + }, + expected: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/file/zz_generated.deepcopy.go b/pkg/file/zz_generated.deepcopy.go index a7ab0b1..29b74ab 100644 --- a/pkg/file/zz_generated.deepcopy.go +++ b/pkg/file/zz_generated.deepcopy.go @@ -78,6 +78,13 @@ func (in *Content) DeepCopyInto(out *Content) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.FilterChains != nil { + in, out := &in.FilterChains, &out.FilterChains + *out = make([]FFilterChain, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } if in.Upstreams != nil { in, out := &in.Upstreams, &out.Upstreams *out = make([]FUpstream, len(*in)) @@ -409,6 +416,23 @@ func (in *FDocument) DeepCopy() *FDocument { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FFilterChain) DeepCopyInto(out *FFilterChain) { + *out = *in + in.FilterChain.DeepCopyInto(&out.FilterChain) + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FFilterChain. +func (in *FFilterChain) DeepCopy() *FFilterChain { + if in == nil { + return nil + } + out := new(FFilterChain) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *FLicense) DeepCopyInto(out *FLicense) { *out = *in @@ -508,6 +532,17 @@ func (in *FRoute) DeepCopyInto(out *FRoute) { } } } + if in.FilterChains != nil { + in, out := &in.FilterChains, &out.FilterChains + *out = make([]*FFilterChain, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FFilterChain) + (*in).DeepCopyInto(*out) + } + } + } return } @@ -547,6 +582,17 @@ func (in *FService) DeepCopyInto(out *FService) { } } } + if in.FilterChains != nil { + in, out := &in.FilterChains, &out.FilterChains + *out = make([]*FFilterChain, len(*in)) + for i := range *in { + if (*in)[i] != nil { + in, out := &(*in)[i], &(*out)[i] + *out = new(FFilterChain) + (*in).DeepCopyInto(*out) + } + } + } if in.URL != nil { in, out := &in.URL, &out.URL *out = new(string) diff --git a/pkg/state/builder.go b/pkg/state/builder.go index 93d2281..59f43ba 100644 --- a/pkg/state/builder.go +++ b/pkg/state/builder.go @@ -310,6 +310,31 @@ func buildKong(kongState *KongState, raw *utils.KongRawState) error { } } + for _, f := range raw.FilterChains { + if f.Service != nil && !utils.Empty(f.Service.ID) { + ok, s, err := ensureService(kongState, *f.Service.ID) + if err != nil { + return err + } + if ok { + f.Service = s + } + } + if f.Route != nil && !utils.Empty(f.Route.ID) { + ok, r, err := ensureRoute(kongState, *f.Route.ID) + if err != nil { + return err + } + if ok { + f.Route = r + } + } + err := kongState.FilterChains.Add(FilterChain{FilterChain: *f}) + if err != nil { + return fmt.Errorf("inserting filter chains into state: %w", err) + } + } + for _, r := range raw.RBACRoles { err := kongState.RBACRoles.Add(RBACRole{RBACRole: *r}) if err != nil { diff --git a/pkg/state/filter_chain.go b/pkg/state/filter_chain.go new file mode 100644 index 0000000..3ce99b4 --- /dev/null +++ b/pkg/state/filter_chain.go @@ -0,0 +1,281 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/go-database-reconciler/pkg/state/indexers" + "github.com/kong/go-database-reconciler/pkg/utils" +) + +const ( + filterChainTableName = "filterChain" + filterChainsByForeign = "filterChainsByForeign" +) + +var filterChainTableSchema = &memdb.TableSchema{ + Name: filterChainTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + // combined foreign fields + // FIXME bug: collision if svc/route has the same ID + // and same type of filter chain is created. Consider the case when only + // of the association is present + filterChainsByForeign: { + Name: filterChainsByForeign, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "Service", + Sub: "ID", + }, + { + Struct: "Route", + Sub: "ID", + }, + }, + }, + }, + }, +} + +// FilterChainsCollection stores and indexes Kong Services. +type FilterChainsCollection collection + +// Add adds a filter chain to FilterChainsCollection +func (k *FilterChainsCollection) Add(filterChain FilterChain) error { + txn := k.db.Txn(true) + defer txn.Abort() + + err := insertFilterChain(txn, filterChain) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func insertFilterChain(txn *memdb.Txn, filterChain FilterChain) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(filterChain.ID) { + return errIDRequired + } + + // err out if filter chain with same ID is present + _, err := getFilterChainByID(txn, *filterChain.ID) + if err == nil { + return fmt.Errorf("inserting filter chain %v: %w", filterChain.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // err out if another filter chain with exact same combination is present + sID, rID := "", "" + if filterChain.Service != nil && !utils.Empty(filterChain.Service.ID) { + sID = *filterChain.Service.ID + } + if filterChain.Route != nil && !utils.Empty(filterChain.Route.ID) { + rID = *filterChain.Route.ID + } + + _, err = getFilterChainBy(txn, sID, rID) + if err == nil { + return fmt.Errorf("inserting filter chain %v: %w", filterChain.Console(), ErrAlreadyExists) + } else if !errors.Is(err, ErrNotFound) { + return err + } + + // all good + err = txn.Insert(filterChainTableName, &filterChain) + if err != nil { + return err + } + return nil +} + +func getFilterChainByID(txn *memdb.Txn, id string) (*FilterChain, error) { + res, err := multiIndexLookupUsingTxn(txn, filterChainTableName, + []string{"id"}, id) + if err != nil { + return nil, err + } + + filterChain, ok := res.(*FilterChain) + if !ok { + panic(unexpectedType) + } + return &FilterChain{FilterChain: *filterChain.DeepCopy()}, nil +} + +// Get gets a filter chain by id. +func (k *FilterChainsCollection) Get(id string) (*FilterChain, error) { + if id == "" { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + + filterChain, err := getFilterChainByID(txn, id) + if err != nil { + return nil, err + } + return filterChain, nil +} + +func getFilterChainBy(txn *memdb.Txn, svcID, routeID string) ( + *FilterChain, error, +) { + res, err := txn.First(filterChainTableName, filterChainsByForeign, + svcID, routeID) + if err != nil { + return nil, err + } + if res == nil { + return nil, ErrNotFound + } + f, ok := res.(*FilterChain) + if !ok { + panic(unexpectedType) + } + return &FilterChain{FilterChain: *f.DeepCopy()}, nil +} + +// GetByProp returns a filter chain which matches all the properties passed in +// the arguments. If serviceID and routeID, +// are empty strings, then a global filter chain is searched. +// Otherwise, a filter chain with name and the supplied foreign references is +// searched. +// name is required. +func (k *FilterChainsCollection) GetByProp(serviceID, routeID string) (*FilterChain, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + return getFilterChainBy(txn, serviceID, routeID) +} + +func (k *FilterChainsCollection) getAllFilterChainsBy(index string, identifier ...string) ( + []*FilterChain, error, +) { + haveID := false + args := make([]interface{}, len(identifier)) + for i, v := range identifier { + haveID = haveID || v != "" + args[i] = v + } + + if !haveID { + return nil, errIDRequired + } + + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(filterChainTableName, index, args...) + if err != nil { + return nil, err + } + var res []*FilterChain + for el := iter.Next(); el != nil; el = iter.Next() { + f, ok := el.(*FilterChain) + if !ok { + panic(unexpectedType) + } + res = append(res, &FilterChain{FilterChain: *f.DeepCopy()}) + } + return res, nil +} + +// GetAllByServiceID returns all filter chains referencing a service +// by its id. +func (k *FilterChainsCollection) GetAllByServiceID(id string) ([]*FilterChain, + error, +) { + return k.getAllFilterChainsBy(filterChainsByForeign, id, "") +} + +// GetAllByRouteID returns all filter chains referencing a route +// by its id. +func (k *FilterChainsCollection) GetAllByRouteID(id string) ([]*FilterChain, + error, +) { + return k.getAllFilterChainsBy(filterChainsByForeign, "", id) +} + +// Update updates a filter chain +func (k *FilterChainsCollection) Update(filterChain FilterChain) error { + // TODO abstract this check in the go-memdb library itself + if utils.Empty(filterChain.ID) { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteFilterChain(txn, *filterChain.ID) + if err != nil { + return err + } + + err = insertFilterChain(txn, filterChain) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +func deleteFilterChain(txn *memdb.Txn, id string) error { + filterChain, err := getFilterChainByID(txn, id) + if err != nil { + return err + } + return txn.Delete(filterChainTableName, filterChain) +} + +// Delete deletes a filter chain by ID. +func (k *FilterChainsCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteFilterChain(txn, id) + if err != nil { + return err + } + + txn.Commit() + return nil +} + +// GetAll gets a filter chain by name or ID. +func (k *FilterChainsCollection) GetAll() ([]*FilterChain, error) { + txn := k.db.Txn(false) + defer txn.Abort() + + iter, err := txn.Get(filterChainTableName, all, true) + if err != nil { + return nil, err + } + + var res []*FilterChain + for el := iter.Next(); el != nil; el = iter.Next() { + f, ok := el.(*FilterChain) + if !ok { + panic(unexpectedType) + } + res = append(res, &FilterChain{FilterChain: *f.DeepCopy()}) + } + return res, nil +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 1c5e255..d19d06c 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -22,6 +22,7 @@ type KongState struct { SNIs *SNIsCollection CACertificates *CACertificatesCollection Plugins *PluginsCollection + FilterChains *FilterChainsCollection Consumers *ConsumersCollection Vaults *VaultsCollection Licenses *LicensesCollection @@ -65,6 +66,7 @@ func NewKongState() (*KongState, error) { sniTableName: sniTableSchema, caCertTableName: caCertTableSchema, pluginTableName: pluginTableSchema, + filterChainTableName: filterChainTableSchema, consumerTableName: consumerTableSchema, consumerGroupTableName: consumerGroupTableSchema, consumerGroupConsumerTableName: consumerGroupConsumerTableSchema, @@ -107,6 +109,7 @@ func NewKongState() (*KongState, error) { state.SNIs = (*SNIsCollection)(&state.common) state.CACertificates = (*CACertificatesCollection)(&state.common) state.Plugins = (*PluginsCollection)(&state.common) + state.FilterChains = (*FilterChainsCollection)(&state.common) state.Consumers = (*ConsumersCollection)(&state.common) state.ConsumerGroups = (*ConsumerGroupsCollection)(&state.common) state.ConsumerGroupConsumers = (*ConsumerGroupConsumersCollection)(&state.common) diff --git a/pkg/state/types.go b/pkg/state/types.go index 42db24c..ca9ebb2 100644 --- a/pkg/state/types.go +++ b/pkg/state/types.go @@ -434,6 +434,75 @@ func (s1 *SNI) EqualWithOpts(s2 *SNI, ignoreID, return reflect.DeepEqual(s1Copy, s2Copy) } +// FilterChain represents a filter chain in Kong. +type FilterChain struct { + kong.FilterChain `yaml:",inline"` + Meta +} + +// Console returns an entity's identity in a human +// readable string. +func (p1 *FilterChain) Console() string { + res := "" + if p1.Name != nil { + res += *p1.Name + " " + } else if p1.ID != nil { + res += *p1.ID + " " + } + + if p1.Service != nil { + res += "for service " + p1.Service.FriendlyName() + } else if p1.Route != nil { + res += "for route " + p1.Route.FriendlyName() + } + + return res +} + +// EqualWithOpts returns true if p1 and p2 are equal. +// If ignoreID is set to true, IDs will be ignored while comparison. +// If ignoreTS is set to true, timestamp fields will be ignored. +func (p1 *FilterChain) EqualWithOpts(p2 *FilterChain, ignoreID, + ignoreTS, ignoreForeign bool, +) bool { + p1Copy := p1.FilterChain.DeepCopy() + p2Copy := p2.FilterChain.DeepCopy() + + sort.Slice(p1Copy.Tags, func(i, j int) bool { return *(p1Copy.Tags[i]) < *(p1Copy.Tags[j]) }) + sort.Slice(p2Copy.Tags, func(i, j int) bool { return *(p2Copy.Tags[i]) < *(p2Copy.Tags[j]) }) + + if ignoreID { + p1Copy.ID = nil + p2Copy.ID = nil + } + if ignoreTS { + p1Copy.CreatedAt = nil + p2Copy.CreatedAt = nil + p1Copy.UpdatedAt = nil + p2Copy.UpdatedAt = nil + } + if ignoreForeign { + p1Copy.Service = nil + p1Copy.Route = nil + p2Copy.Service = nil + p2Copy.Route = nil + } + + if p1Copy.Service != nil { + p1Copy.Service.Name = nil + } + if p2Copy.Service != nil { + p2Copy.Service.Name = nil + } + if p1Copy.Route != nil { + p1Copy.Route.Name = nil + } + if p2Copy.Route != nil { + p2Copy.Route.Name = nil + } + return reflect.DeepEqual(p1Copy, p2Copy) +} + // Plugin represents a route in Kong. // It adds some helper methods along with Meta to the original Plugin object. type Plugin struct { diff --git a/pkg/types/core.go b/pkg/types/core.go index 1fdc1d7..05d39f9 100644 --- a/pkg/types/core.go +++ b/pkg/types/core.go @@ -124,6 +124,9 @@ const ( Vault EntityType = "vault" // License identifies a License in Kong Enterprise. License EntityType = "license" + + // FilterChain identifies a FilterChain in Kong. + FilterChain EntityType = "filter-chain" ) // AllTypes represents all types defined in the @@ -146,6 +149,8 @@ var AllTypes = []EntityType{ ServicePackage, ServiceVersion, Document, Vault, License, + + FilterChain, } func entityTypeToKind(t EntityType) crud.Kind { @@ -545,7 +550,22 @@ func NewEntity(t EntityType, opts EntityOpts) (Entity, error) { currentState: opts.CurrentState, }, differ: &licenseDiffer{ - kind: entityTypeToKind(License), + kind: entityTypeToKind(License), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil + case FilterChain: + return entityImpl{ + typ: FilterChain, + crudActions: &filterChainCRUD{ + client: opts.KongClient, + }, + postProcessActions: &filterChainPostAction{ + currentState: opts.CurrentState, + }, + differ: &filterChainDiffer{ + kind: entityTypeToKind(FilterChain), currentState: opts.CurrentState, targetState: opts.TargetState, }, diff --git a/pkg/types/filter_chain.go b/pkg/types/filter_chain.go new file mode 100644 index 0000000..d6c57df --- /dev/null +++ b/pkg/types/filter_chain.go @@ -0,0 +1,199 @@ +package types + +import ( + "context" + "errors" + "fmt" + + "github.com/kong/go-database-reconciler/pkg/crud" + "github.com/kong/go-database-reconciler/pkg/state" + "github.com/kong/go-kong/kong" +) + +// filterChainCRUD implements crud.Actions interface. +type filterChainCRUD struct { + client *kong.Client +} + +// kong and konnect APIs only require IDs for referenced entities. +func stripFilterChainReferencesName(filterChain *state.FilterChain) { + if filterChain.FilterChain.Service != nil && filterChain.FilterChain.Service.Name != nil { + filterChain.FilterChain.Service.Name = nil + } + if filterChain.FilterChain.Route != nil && filterChain.FilterChain.Route.Name != nil { + filterChain.FilterChain.Route.Name = nil + } +} + +func filterChainFromStruct(arg crud.Event) *state.FilterChain { + filterChain, ok := arg.Obj.(*state.FilterChain) + if !ok { + panic("unexpected type, expected *state.FilterChain") + } + stripFilterChainReferencesName(filterChain) + return filterChain +} + +// Create creates a FilterChain in Kong. +// The arg should be of type crud.Event, containing the filter chain to be created, +// else the function will panic. +// It returns a the created *state.FilterChain. +func (s *filterChainCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + filterChain := filterChainFromStruct(event) + + createdFilterChain, err := s.client.FilterChains.Create(ctx, &filterChain.FilterChain) + if err != nil { + return nil, err + } + return &state.FilterChain{FilterChain: *createdFilterChain}, nil +} + +// Delete deletes a FilterChain in Kong. +// The arg should be of type crud.Event, containing the filter chain to be deleted, +// else the function will panic. +// It returns a the deleted *state.FilterChain. +func (s *filterChainCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + filterChain := filterChainFromStruct(event) + err := s.client.FilterChains.Delete(ctx, filterChain.ID) + if err != nil { + return nil, err + } + return filterChain, nil +} + +// Update updates a FilterChain in Kong. +// The arg should be of type crud.Event, containing the filter chain to be updated, +// else the function will panic. +// It returns a the updated *state.FilterChain. +func (s *filterChainCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + filterChain := filterChainFromStruct(event) + + updatedFilterChain, err := s.client.FilterChains.Create(ctx, &filterChain.FilterChain) + if err != nil { + return nil, err + } + return &state.FilterChain{FilterChain: *updatedFilterChain}, nil +} + +type filterChainDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +func (d *filterChainDiffer) Deletes(handler func(crud.Event) error) error { + currentFilterChains, err := d.currentState.FilterChains.GetAll() + if err != nil { + return fmt.Errorf("error fetching filter chains from state: %w", err) + } + + for _, filterChain := range currentFilterChains { + n, err := d.deleteFilterChain(filterChain) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *filterChainDiffer) deleteFilterChain(filterChain *state.FilterChain) (*crud.Event, error) { + filterChain = &state.FilterChain{FilterChain: *filterChain.DeepCopy()} + + serviceID, routeID := filterChainForeignNames(filterChain) + _, err := d.targetState.FilterChains.GetByProp( + serviceID, routeID, + ) + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: d.kind, + Obj: filterChain, + }, nil + } + if err != nil { + return nil, fmt.Errorf("looking up filter chain %q: %w", *filterChain.ID, err) + } + return nil, nil +} + +func (d *filterChainDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetFilterChains, err := d.targetState.FilterChains.GetAll() + if err != nil { + return fmt.Errorf("error fetching filter chains from state: %w", err) + } + + for _, filterChain := range targetFilterChains { + n, err := d.createUpdateFilterChain(filterChain) + if err != nil { + return err + } + if n != nil { + err = handler(*n) + if err != nil { + return err + } + } + } + return nil +} + +func (d *filterChainDiffer) createUpdateFilterChain(filterChain *state.FilterChain) (*crud.Event, error) { + filterChain = &state.FilterChain{FilterChain: *filterChain.DeepCopy()} + + name := "" + if filterChain.Name != nil { + name = *filterChain.Name + } + + serviceID, routeID := filterChainForeignNames(filterChain) + currentFilterChain, err := d.currentState.FilterChains.GetByProp( + serviceID, routeID, + ) + if errors.Is(err, state.ErrNotFound) { + // filter chain not present, create it + + return &crud.Event{ + Op: crud.Create, + Kind: d.kind, + Obj: filterChain, + }, nil + } + if err != nil { + return nil, fmt.Errorf("error looking up filter chain %q: %w", + name, err) + } + currentFilterChain = &state.FilterChain{FilterChain: *currentFilterChain.DeepCopy()} + // found, check if update needed + + if !currentFilterChain.EqualWithOpts(filterChain, false, true, false) { + return &crud.Event{ + Op: crud.Update, + Kind: d.kind, + Obj: filterChain, + OldObj: currentFilterChain, + }, nil + } + return nil, nil +} + +func filterChainForeignNames(p *state.FilterChain) (serviceID, routeID string) { + if p == nil { + return + } + if p.Service != nil && p.Service.ID != nil { + serviceID = *p.Service.ID + } + if p.Route != nil && p.Route.ID != nil { + routeID = *p.Route.ID + } + return +} diff --git a/pkg/types/postProcess.go b/pkg/types/postProcess.go index 56ea79a..9df0ddf 100644 --- a/pkg/types/postProcess.go +++ b/pkg/types/postProcess.go @@ -30,6 +30,18 @@ func (crud *servicePostAction) Delete(_ context.Context, args ...crud.Arg) (crud return nil, fmt.Errorf("error deleting plugin '%v' for service '%v': %w", *plugin.ID, serviceID, err) } } + + // Delete all filterChains associated with this service as that's the implicit behavior of Kong (cascade delete). + filterChains, err := crud.currentState.FilterChains.GetAllByServiceID(serviceID) + if err != nil { + return nil, fmt.Errorf("error looking up filterChains for service '%v': %w", serviceID, err) + } + for _, filterChain := range filterChains { + err = crud.currentState.FilterChains.Delete(*filterChain.ID) + if err != nil { + return nil, fmt.Errorf("error deleting filterChain '%v' for service '%v': %w", *filterChain.ID, serviceID, err) + } + } return nil, crud.currentState.Services.Delete(serviceID) } @@ -59,6 +71,19 @@ func (crud *routePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.A return nil, fmt.Errorf("error deleting plugin '%v' for route '%v': %w", *plugin.ID, routeID, err) } } + + // Delete all filterChains associated with this route as that's the implicit behavior of Kong (cascade delete). + filterChains, err := crud.currentState.FilterChains.GetAllByRouteID(routeID) + if err != nil { + return nil, fmt.Errorf("error looking up filterChains for route '%v': %w", routeID, err) + } + for _, filterChain := range filterChains { + err = crud.currentState.FilterChains.Delete(*filterChain.ID) + if err != nil { + return nil, fmt.Errorf("error deleting filterChain '%v' for route '%v': %w", *filterChain.ID, routeID, err) + } + } + return nil, crud.currentState.Routes.Delete(routeID) } @@ -471,3 +496,19 @@ func (crud licensePostAction) Delete(_ context.Context, args ...crud.Arg) (crud. func (crud licensePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { return nil, crud.currentState.Licenses.Update(*args[0].(*state.License)) } + +type filterChainPostAction struct { + currentState *state.KongState +} + +func (crud *filterChainPostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.FilterChains.Add(*args[0].(*state.FilterChain)) +} + +func (crud *filterChainPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.FilterChains.Delete(*((args[0].(*state.FilterChain)).ID)) +} + +func (crud *filterChainPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.FilterChains.Update(*args[0].(*state.FilterChain)) +} diff --git a/pkg/utils/types.go b/pkg/utils/types.go index ed48896..6846f1e 100644 --- a/pkg/utils/types.go +++ b/pkg/utils/types.go @@ -29,7 +29,8 @@ type KongRawState struct { Services []*kong.Service Routes []*kong.Route - Plugins []*kong.Plugin + Plugins []*kong.Plugin + FilterChains []*kong.FilterChain Upstreams []*kong.Upstream Targets []*kong.Target