From 363b12ee4c908756ad5ccad6f1d427b257dcf6c0 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 3 Nov 2024 13:13:11 +0000 Subject: [PATCH] Fix #178. --- context.go | 1 + ext/regexp/regexp.go | 126 ++++++++++++++++++++++++++++++++++---- ext/regexp/regexp_test.go | 18 +++++- stmt.go | 1 + 4 files changed, 131 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index be5dd92c..86be214e 100644 --- a/context.go +++ b/context.go @@ -89,6 +89,7 @@ func (ctx Context) ResultText(value string) { } // ResultRawText sets the text result of the function to a []byte. +// Returning a nil slice is the same as calling [Context.ResultNull]. // // https://sqlite.org/c3ref/result_blob.html func (ctx Context) ResultRawText(value []byte) { diff --git a/ext/regexp/regexp.go b/ext/regexp/regexp.go index 8253c532..6210147f 100644 --- a/ext/regexp/regexp.go +++ b/ext/regexp/regexp.go @@ -2,6 +2,8 @@ // // It provides the following Unicode aware functions: // - regexp_like(), +// - regexp_count(), +// - regexp_instr(), // - regexp_substr(), // - regexp_replace(), // - and a REGEXP operator. @@ -24,8 +26,17 @@ func Register(db *sqlite3.Conn) error { return errors.Join( db.CreateFunction("regexp", 2, flags, regex), db.CreateFunction("regexp_like", 2, flags, regexLike), + db.CreateFunction("regexp_count", 2, flags, regexCount), + db.CreateFunction("regexp_count", 3, flags, regexCount), + db.CreateFunction("regexp_instr", 2, flags, regexInstr), + db.CreateFunction("regexp_instr", 3, flags, regexInstr), + db.CreateFunction("regexp_instr", 4, flags, regexInstr), + db.CreateFunction("regexp_instr", 5, flags, regexInstr), db.CreateFunction("regexp_substr", 2, flags, regexSubstr), - db.CreateFunction("regexp_replace", 3, flags, regexReplace)) + db.CreateFunction("regexp_substr", 3, flags, regexSubstr), + db.CreateFunction("regexp_substr", 4, flags, regexSubstr), + db.CreateFunction("regexp_replace", 3, flags, regexReplace), + db.CreateFunction("regexp_replace", 4, flags, regexReplace)) } func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) { @@ -44,35 +55,126 @@ func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) { func regex(ctx sqlite3.Context, arg ...sqlite3.Value) { re, err := load(ctx, 0, arg[0].Text()) if err != nil { - ctx.ResultError(err) // notest - } else { - ctx.ResultBool(re.Match(arg[1].RawText())) + ctx.ResultError(err) + return // notest } + text := arg[1].RawText() + ctx.ResultBool(re.Match(text)) } func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) { re, err := load(ctx, 1, arg[1].Text()) if err != nil { - ctx.ResultError(err) // notest - } else { - ctx.ResultBool(re.Match(arg[0].RawText())) + ctx.ResultError(err) + return // notest } + text := arg[0].RawText() + ctx.ResultBool(re.Match(text)) +} + +func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) { + re, err := load(ctx, 1, arg[1].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + text := arg[0].RawText() + if len(arg) > 2 { + pos := arg[2].Int() + _, text = split(text, pos) + } + ctx.ResultInt(len(re.FindAll(text, -1))) } func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) { re, err := load(ctx, 1, arg[1].Text()) if err != nil { - ctx.ResultError(err) // notest + ctx.ResultError(err) + return // notest + } + text := arg[0].RawText() + if len(arg) > 2 { + pos := arg[2].Int() + _, text = split(text, pos) + } + n := 0 + if len(arg) > 3 { + n = arg[3].Int() + } + + var res []byte + if n <= 1 { + res = re.Find(text) } else { - ctx.ResultRawText(re.Find(arg[0].RawText())) + all := re.FindAll(text, n) + if n <= len(all) { + res = all[n-1] + } } + ctx.ResultRawText(res) } -func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) { +func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) { re, err := load(ctx, 1, arg[1].Text()) if err != nil { - ctx.ResultError(err) // notest + ctx.ResultError(err) + return // notest + } + pos := 1 + text := arg[0].RawText() + if len(arg) > 2 { + pos = arg[2].Int() + _, text = split(text, pos) + } + n := 0 + if len(arg) > 3 { + n = arg[3].Int() + } + + var loc []int + if n <= 1 { + loc = re.FindIndex(text) } else { - ctx.ResultRawText(re.ReplaceAll(arg[0].RawText(), arg[2].RawText())) + all := re.FindAllIndex(text, n) + if n <= len(all) { + loc = all[n-1] + } + } + if loc == nil { + return + } + + end := 0 + if len(arg) > 4 && arg[4].Bool() { + end = 1 + } + ctx.ResultInt(pos + loc[end]) +} + +func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) { + re, err := load(ctx, 1, arg[1].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + var head, tail []byte + tail = arg[0].RawText() + if len(arg) > 3 { + pos := arg[3].Int() + head, tail = split(tail, pos) + } + tail = re.ReplaceAll(tail, arg[2].RawText()) + if head != nil { + tail = append(head, tail...) + } + ctx.ResultRawText(tail) +} + +func split(s []byte, i int) (head, tail []byte) { + for pos := range string(s) { + if i--; i <= 0 { + return s[:pos:pos], s[pos:] + } } + return s, nil } diff --git a/ext/regexp/regexp_test.go b/ext/regexp/regexp_test.go index 13232057..e9d86a79 100644 --- a/ext/regexp/regexp_test.go +++ b/ext/regexp/regexp_test.go @@ -1,6 +1,7 @@ package regexp import ( + "database/sql" "testing" "github.com/ncruces/go-sqlite3/driver" @@ -29,18 +30,27 @@ func TestRegister(t *testing.T) { {`regexp_like('Hello', 'elo')`, "0"}, {`regexp_like('Hello', 'ell')`, "1"}, {`regexp_like('Hello', 'el.')`, "1"}, + {`regexp_count('Hello', 'l')`, "2"}, + {`regexp_instr('Hello', 'el.')`, "2"}, + {`regexp_instr('Hello', '.', 6)`, ""}, {`regexp_substr('Hello', 'el.')`, "ell"}, + {`regexp_substr('Hello', 'l', 2, 2)`, "l"}, {`regexp_replace('Hello', 'llo', 'll')`, "Hell"}, + + {`regexp_count('123123123123123', '(12)3', 1)`, "5"}, + {`regexp_instr('500 Oracle Parkway, Redwood Shores, CA', '(?i)[s|r|p][[:alpha:]]{6}', 3, 2, 1)`, "28"}, + {`regexp_substr('500 Oracle Parkway, Redwood Shores, CA', ',[^,]+,', 3, 1)`, ", Redwood Shores,"}, + {`regexp_replace('500 Oracle Parkway, Redwood Shores, CA', '( ){2,}', ' ', 3)`, "500 Oracle Parkway, Redwood Shores, CA"}, } for _, tt := range tests { - var got string + var got sql.NullString err := db.QueryRow(`SELECT ` + tt.test).Scan(&got) if err != nil { t.Fatal(err) } - if got != tt.want { - t.Errorf("got %q, want %q", got, tt.want) + if got.String != tt.want { + t.Errorf("got %q, want %q", got.String, tt.want) } } } @@ -58,6 +68,8 @@ func TestRegister_errors(t *testing.T) { tests := []string{ `'' REGEXP ?`, `regexp_like('', ?)`, + `regexp_count('', ?)`, + `regexp_instr('', ?)`, `regexp_substr('', ?)`, `regexp_replace('', ?, '')`, } diff --git a/stmt.go b/stmt.go index 9da2a2ea..139dd352 100644 --- a/stmt.go +++ b/stmt.go @@ -255,6 +255,7 @@ func (s *Stmt) BindText(param int, value string) error { // BindRawText binds a []byte to the prepared statement as text. // The leftmost SQL parameter has an index of 1. +// Binding a nil slice is the same as calling [Stmt.BindNull]. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindRawText(param int, value []byte) error {