Skip to content

Commit

Permalink
refactor filtering to happen per model
Browse files Browse the repository at this point in the history
  • Loading branch information
bparees committed Aug 31, 2023
1 parent 71f52c9 commit 03047fc
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 56 deletions.
8 changes: 3 additions & 5 deletions cmd/wisdom/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"gopkg.in/yaml.v2"

"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters"
"github.com/openshift/wisdom/pkg/model"
"github.com/openshift/wisdom/pkg/model/ibm"
"github.com/openshift/wisdom/pkg/model/openai"
Expand Down Expand Up @@ -90,7 +89,6 @@ func newStartServerCommand() *cobra.Command {
models = initModels(config)

h := server.Handler{
Filter: filters.NewFilter(),
DefaultProvider: config.DefaultProvider,
DefaultModel: config.DefaultModelId,
Models: models,
Expand Down Expand Up @@ -195,7 +193,6 @@ func newInferCommand() *cobra.Command {
if o.prompt == "" {
return fmt.Errorf("model prompt is required")
}
filter := filters.NewFilter()

// If the user didn't specify a provider or model, use the defaults from the config file
if o.provider == "" {
Expand All @@ -213,8 +210,8 @@ func newInferCommand() *cobra.Command {
input := api.ModelInput{
Prompt: o.prompt,
}
log.Debugf("invoking model %s/%s", o.provider, o.modelId)
response, err := model.InvokeModel(input, m, filter)
log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", o.provider, o.modelId, o.prompt)
response, err := model.InvokeModel(input, m)
if err != nil {
if response != nil && response.Output != "" {
log.Debugf("Response(Error):\n%s", response.Output)
Expand Down Expand Up @@ -246,6 +243,7 @@ func initModels(config api.Config) map[string]api.Model {
switch m.Provider {
case "ibm":
models[m.Provider+"/"+m.ModelId] = ibm.NewIBMModel(m.ModelId, m.URL, m.UserId, m.APIKey)

case "openai":
models[m.Provider+"/"+m.ModelId] = openai.NewOpenAIModel(m.ModelId, m.URL, m.APIKey)
default:
Expand Down
43 changes: 43 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,51 @@ import (
"github.com/golang-jwt/jwt/v4"
)

type Filter struct {
InputFilterChain []InputFilter
ResponseFilterChain []ResponseFilter
}

type InputFilter func(input *ModelInput) (*ModelInput, error)
type ResponseFilter func(response *ModelResponse) (*ModelResponse, error)

func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Filter {
filter := Filter{
InputFilterChain: inputFilters,
ResponseFilterChain: responseFilters,
}
return filter
}

func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) {
output := input
var err error
for _, filter := range f.InputFilterChain {
output, err = filter(output)
if err != nil {
return output, err
}
}
return output, err
}

func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error) {
output := response
var err error
for _, filter := range f.ResponseFilterChain {
output, err = filter(output)
if err != nil {
return output, err
}
}
return output, err
}

type Model interface {
Invoke(ModelInput) (*ModelResponse, error)
GetFilter() Filter
//FilterInput(*ModelInput) (*ModelInput, error)
//FilterResponse(*ModelResponse) (*ModelResponse, error)
}

// ModelInput represents the payload for the prompt_request endpoint.
Expand Down
44 changes: 0 additions & 44 deletions pkg/filters/filters.go
Original file line number Diff line number Diff line change
@@ -1,45 +1 @@
package filters

import (
"github.com/openshift/wisdom/pkg/filters/yaml"

"github.com/openshift/wisdom/pkg/api"
)

type Filter struct {
inputFilterChain []InputFilter
responseFilterChain []ResponseFilter
}
type InputFilter func(input *api.ModelInput) (*api.ModelInput, error)
type ResponseFilter func(response *api.ModelResponse) (*api.ModelResponse, error)

func NewFilter() Filter {
filter := Filter{}
//filter.responseFilterChain = append(filter.responseFilterChain, markdown.MarkdownStripper, yaml.YamlLinter)
filter.responseFilterChain = append(filter.responseFilterChain, yaml.YamlLinter)
return filter
}

func (f *Filter) FilterInput(input *api.ModelInput) (*api.ModelInput, error) {
output := input
var err error
for _, filter := range f.inputFilterChain {
output, err = filter(output)
if err != nil {
return output, err
}
}
return output, err
}

func (f *Filter) FilterResponse(response *api.ModelResponse) (*api.ModelResponse, error) {
output := response
var err error
for _, filter := range f.responseFilterChain {
output, err = filter(output)
if err != nil {
return output, err
}
}
return output, err
}
13 changes: 13 additions & 0 deletions pkg/model/ibm/model_ibm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"

"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters/markdown"
"github.com/openshift/wisdom/pkg/filters/yaml"
)

const (
Expand Down Expand Up @@ -36,17 +38,24 @@ type IBMModel struct {
url string
apiKey string
userId string
filter api.Filter
}

func NewIBMModel(modelId, url, userId, apiKey string) *IBMModel {
filter := api.NewFilter(nil, []api.ResponseFilter{markdown.MarkdownStripper, yaml.YamlLinter})
return &IBMModel{
modelId: modelId,
url: url,
apiKey: apiKey,
userId: userId,
filter: filter,
}
}

func (m *IBMModel) GetFilter() api.Filter {
return m.filter
}

func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {

if input.UserId == "" && m.userId == "" {
Expand Down Expand Up @@ -122,3 +131,7 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {

return &response, err
}

func (m *IBMModel) FilterInput(input *api.ModelInput) (*api.ModelInput, error) {
return m.filter.FilterInput(input)
}
7 changes: 3 additions & 4 deletions pkg/model/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import (
log "github.com/sirupsen/logrus"

"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters"
)

func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) (*api.ModelResponse, error) {

func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, error) {
response, err := model.Invoke(input)
if response == nil {
response = &api.ModelResponse{}
Expand All @@ -18,7 +16,8 @@ func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) (
response.Error = err.Error()
return response, err
}
output, err := filter.FilterResponse(response)

output, err := model.GetFilter().FilterResponse(response)
if err != nil {
response.Error = err.Error()
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/model/openai/model_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,24 @@ type OpenAIModel struct {
modelId string
url string
apiKey string
filter api.Filter
}

func NewOpenAIModel(modelId, url, apiKey string) *OpenAIModel {
filter := api.NewFilter(nil, nil)

return &OpenAIModel{
modelId: modelId,
url: url,
apiKey: apiKey,
filter: filter,
}
}

func (m *OpenAIModel) GetFilter() api.Filter {
return m.filter
}

func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {

if input.APIKey == "" && m.apiKey == "" {
Expand Down
2 changes: 0 additions & 2 deletions pkg/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package server
import (
"github.com/gorilla/sessions"
"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters"
"golang.org/x/oauth2"
)

type Handler struct {
Filter filters.Filter
DefaultModel string
DefaultProvider string
Models map[string]api.Model
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/inference_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) {

log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", payload.Provider, payload.ModelId, payload.Prompt)

response, err := model.InvokeModel(payload, m, h.Filter)
response, err := model.InvokeModel(payload, m)

buf := bytes.Buffer{}
if response != nil {
Expand Down

0 comments on commit 03047fc

Please sign in to comment.