From 03047fcf6e5c3ee193b8a6ab030fdcd7dcd038d5 Mon Sep 17 00:00:00 2001 From: bparees Date: Thu, 31 Aug 2023 17:16:59 -0400 Subject: [PATCH] refactor filtering to happen per model --- cmd/wisdom/main.go | 8 +++--- pkg/api/types.go | 43 +++++++++++++++++++++++++++++++ pkg/filters/filters.go | 44 -------------------------------- pkg/model/ibm/model_ibm.go | 13 ++++++++++ pkg/model/invoker.go | 7 +++-- pkg/model/openai/model_openai.go | 8 ++++++ pkg/server/handler.go | 2 -- pkg/server/inference_handler.go | 2 +- 8 files changed, 71 insertions(+), 56 deletions(-) diff --git a/cmd/wisdom/main.go b/cmd/wisdom/main.go index c813ee8..6dfae5e 100644 --- a/cmd/wisdom/main.go +++ b/cmd/wisdom/main.go @@ -15,7 +15,6 @@ import ( "gopkg.in/yaml.v2" "github.com/openshift/wisdom/pkg/api" - "github.com/openshift/wisdom/pkg/filters" "github.com/openshift/wisdom/pkg/model" "github.com/openshift/wisdom/pkg/model/ibm" "github.com/openshift/wisdom/pkg/model/openai" @@ -90,7 +89,6 @@ func newStartServerCommand() *cobra.Command { models = initModels(config) h := server.Handler{ - Filter: filters.NewFilter(), DefaultProvider: config.DefaultProvider, DefaultModel: config.DefaultModelId, Models: models, @@ -195,7 +193,6 @@ func newInferCommand() *cobra.Command { if o.prompt == "" { return fmt.Errorf("model prompt is required") } - filter := filters.NewFilter() // If the user didn't specify a provider or model, use the defaults from the config file if o.provider == "" { @@ -213,8 +210,8 @@ func newInferCommand() *cobra.Command { input := api.ModelInput{ Prompt: o.prompt, } - log.Debugf("invoking model %s/%s", o.provider, o.modelId) - response, err := model.InvokeModel(input, m, filter) + log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", o.provider, o.modelId, o.prompt) + response, err := model.InvokeModel(input, m) if err != nil { if response != nil && response.Output != "" { log.Debugf("Response(Error):\n%s", response.Output) @@ -246,6 +243,7 @@ func initModels(config api.Config) map[string]api.Model { switch m.Provider { case "ibm": models[m.Provider+"/"+m.ModelId] = ibm.NewIBMModel(m.ModelId, m.URL, m.UserId, m.APIKey) + case "openai": models[m.Provider+"/"+m.ModelId] = openai.NewOpenAIModel(m.ModelId, m.URL, m.APIKey) default: diff --git a/pkg/api/types.go b/pkg/api/types.go index eb4c0f2..19d0098 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -4,8 +4,51 @@ import ( "github.com/golang-jwt/jwt/v4" ) +type Filter struct { + InputFilterChain []InputFilter + ResponseFilterChain []ResponseFilter +} + +type InputFilter func(input *ModelInput) (*ModelInput, error) +type ResponseFilter func(response *ModelResponse) (*ModelResponse, error) + +func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Filter { + filter := Filter{ + InputFilterChain: inputFilters, + ResponseFilterChain: responseFilters, + } + return filter +} + +func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) { + output := input + var err error + for _, filter := range f.InputFilterChain { + output, err = filter(output) + if err != nil { + return output, err + } + } + return output, err +} + +func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error) { + output := response + var err error + for _, filter := range f.ResponseFilterChain { + output, err = filter(output) + if err != nil { + return output, err + } + } + return output, err +} + type Model interface { Invoke(ModelInput) (*ModelResponse, error) + GetFilter() Filter + //FilterInput(*ModelInput) (*ModelInput, error) + //FilterResponse(*ModelResponse) (*ModelResponse, error) } // ModelInput represents the payload for the prompt_request endpoint. diff --git a/pkg/filters/filters.go b/pkg/filters/filters.go index 20019d3..66bd55a 100644 --- a/pkg/filters/filters.go +++ b/pkg/filters/filters.go @@ -1,45 +1 @@ package filters - -import ( - "github.com/openshift/wisdom/pkg/filters/yaml" - - "github.com/openshift/wisdom/pkg/api" -) - -type Filter struct { - inputFilterChain []InputFilter - responseFilterChain []ResponseFilter -} -type InputFilter func(input *api.ModelInput) (*api.ModelInput, error) -type ResponseFilter func(response *api.ModelResponse) (*api.ModelResponse, error) - -func NewFilter() Filter { - filter := Filter{} - //filter.responseFilterChain = append(filter.responseFilterChain, markdown.MarkdownStripper, yaml.YamlLinter) - filter.responseFilterChain = append(filter.responseFilterChain, yaml.YamlLinter) - return filter -} - -func (f *Filter) FilterInput(input *api.ModelInput) (*api.ModelInput, error) { - output := input - var err error - for _, filter := range f.inputFilterChain { - output, err = filter(output) - if err != nil { - return output, err - } - } - return output, err -} - -func (f *Filter) FilterResponse(response *api.ModelResponse) (*api.ModelResponse, error) { - output := response - var err error - for _, filter := range f.responseFilterChain { - output, err = filter(output) - if err != nil { - return output, err - } - } - return output, err -} diff --git a/pkg/model/ibm/model_ibm.go b/pkg/model/ibm/model_ibm.go index f59d509..b791f2c 100644 --- a/pkg/model/ibm/model_ibm.go +++ b/pkg/model/ibm/model_ibm.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/openshift/wisdom/pkg/api" + "github.com/openshift/wisdom/pkg/filters/markdown" + "github.com/openshift/wisdom/pkg/filters/yaml" ) const ( @@ -36,17 +38,24 @@ type IBMModel struct { url string apiKey string userId string + filter api.Filter } func NewIBMModel(modelId, url, userId, apiKey string) *IBMModel { + filter := api.NewFilter(nil, []api.ResponseFilter{markdown.MarkdownStripper, yaml.YamlLinter}) return &IBMModel{ modelId: modelId, url: url, apiKey: apiKey, userId: userId, + filter: filter, } } +func (m *IBMModel) GetFilter() api.Filter { + return m.filter +} + func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) { if input.UserId == "" && m.userId == "" { @@ -122,3 +131,7 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) { return &response, err } + +func (m *IBMModel) FilterInput(input *api.ModelInput) (*api.ModelInput, error) { + return m.filter.FilterInput(input) +} diff --git a/pkg/model/invoker.go b/pkg/model/invoker.go index 2bdccc8..a2a59a9 100644 --- a/pkg/model/invoker.go +++ b/pkg/model/invoker.go @@ -4,11 +4,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/openshift/wisdom/pkg/api" - "github.com/openshift/wisdom/pkg/filters" ) -func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) (*api.ModelResponse, error) { - +func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, error) { response, err := model.Invoke(input) if response == nil { response = &api.ModelResponse{} @@ -18,7 +16,8 @@ func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) ( response.Error = err.Error() return response, err } - output, err := filter.FilterResponse(response) + + output, err := model.GetFilter().FilterResponse(response) if err != nil { response.Error = err.Error() } diff --git a/pkg/model/openai/model_openai.go b/pkg/model/openai/model_openai.go index 8324af3..2e1930f 100644 --- a/pkg/model/openai/model_openai.go +++ b/pkg/model/openai/model_openai.go @@ -37,16 +37,24 @@ type OpenAIModel struct { modelId string url string apiKey string + filter api.Filter } func NewOpenAIModel(modelId, url, apiKey string) *OpenAIModel { + filter := api.NewFilter(nil, nil) + return &OpenAIModel{ modelId: modelId, url: url, apiKey: apiKey, + filter: filter, } } +func (m *OpenAIModel) GetFilter() api.Filter { + return m.filter +} + func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) { if input.APIKey == "" && m.apiKey == "" { diff --git a/pkg/server/handler.go b/pkg/server/handler.go index 2fd4e4f..afa2efb 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -3,12 +3,10 @@ package server import ( "github.com/gorilla/sessions" "github.com/openshift/wisdom/pkg/api" - "github.com/openshift/wisdom/pkg/filters" "golang.org/x/oauth2" ) type Handler struct { - Filter filters.Filter DefaultModel string DefaultProvider string Models map[string]api.Model diff --git a/pkg/server/inference_handler.go b/pkg/server/inference_handler.go index 6bea772..5245d6c 100644 --- a/pkg/server/inference_handler.go +++ b/pkg/server/inference_handler.go @@ -48,7 +48,7 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) { log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", payload.Provider, payload.ModelId, payload.Prompt) - response, err := model.InvokeModel(payload, m, h.Filter) + response, err := model.InvokeModel(payload, m) buf := bytes.Buffer{} if response != nil {