From f1195b71fb1fb71caba070b6b80847e6a7443184 Mon Sep 17 00:00:00 2001
From: Nuno Cruces <ncruces@users.noreply.github.com>
Date: Sun, 3 Nov 2024 13:13:11 +0000
Subject: [PATCH] Fix #178.

---
 context.go                |   1 +
 ext/regexp/regexp.go      | 124 ++++++++++++++++++++++++++++++++++----
 ext/regexp/regexp_test.go |  18 +++++-
 stmt.go                   |   1 +
 4 files changed, 129 insertions(+), 15 deletions(-)

diff --git a/context.go b/context.go
index be5dd92c..86be214e 100644
--- a/context.go
+++ b/context.go
@@ -89,6 +89,7 @@ func (ctx Context) ResultText(value string) {
 }
 
 // ResultRawText sets the text result of the function to a []byte.
+// Returning a nil slice is the same as calling [Context.ResultNull].
 //
 // https://sqlite.org/c3ref/result_blob.html
 func (ctx Context) ResultRawText(value []byte) {
diff --git a/ext/regexp/regexp.go b/ext/regexp/regexp.go
index 8253c532..55b42123 100644
--- a/ext/regexp/regexp.go
+++ b/ext/regexp/regexp.go
@@ -24,8 +24,17 @@ func Register(db *sqlite3.Conn) error {
 	return errors.Join(
 		db.CreateFunction("regexp", 2, flags, regex),
 		db.CreateFunction("regexp_like", 2, flags, regexLike),
+		db.CreateFunction("regexp_count", 2, flags, regexCount),
+		db.CreateFunction("regexp_count", 3, flags, regexCount),
+		db.CreateFunction("regexp_instr", 2, flags, regexInstr),
+		db.CreateFunction("regexp_instr", 3, flags, regexInstr),
+		db.CreateFunction("regexp_instr", 4, flags, regexInstr),
+		db.CreateFunction("regexp_instr", 5, flags, regexInstr),
 		db.CreateFunction("regexp_substr", 2, flags, regexSubstr),
-		db.CreateFunction("regexp_replace", 3, flags, regexReplace))
+		db.CreateFunction("regexp_substr", 3, flags, regexSubstr),
+		db.CreateFunction("regexp_substr", 4, flags, regexSubstr),
+		db.CreateFunction("regexp_replace", 3, flags, regexReplace),
+		db.CreateFunction("regexp_replace", 4, flags, regexReplace))
 }
 
 func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
@@ -44,35 +53,126 @@ func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
 func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
 	re, err := load(ctx, 0, arg[0].Text())
 	if err != nil {
-		ctx.ResultError(err) // notest
-	} else {
-		ctx.ResultBool(re.Match(arg[1].RawText()))
+		ctx.ResultError(err)
+		return // notest
 	}
+	text := arg[1].RawText()
+	ctx.ResultBool(re.Match(text))
 }
 
 func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
 	re, err := load(ctx, 1, arg[1].Text())
 	if err != nil {
-		ctx.ResultError(err) // notest
-	} else {
-		ctx.ResultBool(re.Match(arg[0].RawText()))
+		ctx.ResultError(err)
+		return // notest
 	}
+	text := arg[0].RawText()
+	ctx.ResultBool(re.Match(text))
+}
+
+func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
+	re, err := load(ctx, 1, arg[1].Text())
+	if err != nil {
+		ctx.ResultError(err)
+		return // notest
+	}
+	text := arg[0].RawText()
+	if len(arg) > 2 {
+		pos := arg[2].Int()
+		_, text = split(text, pos)
+	}
+	ctx.ResultInt(len(re.FindAll(text, -1)))
 }
 
 func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
 	re, err := load(ctx, 1, arg[1].Text())
 	if err != nil {
-		ctx.ResultError(err) // notest
+		ctx.ResultError(err)
+		return // notest
+	}
+	text := arg[0].RawText()
+	if len(arg) > 2 {
+		pos := arg[2].Int()
+		_, text = split(text, pos)
+	}
+	n := 0
+	if len(arg) > 3 {
+		n = arg[3].Int()
+	}
+
+	var res []byte
+	if n <= 1 {
+		res = re.Find(text)
 	} else {
-		ctx.ResultRawText(re.Find(arg[0].RawText()))
+		all := re.FindAll(text, n)
+		if n <= len(all) {
+			res = all[n-1]
+		}
 	}
+	ctx.ResultRawText(res)
 }
 
-func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
+func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
 	re, err := load(ctx, 1, arg[1].Text())
 	if err != nil {
-		ctx.ResultError(err) // notest
+		ctx.ResultError(err)
+		return // notest
+	}
+	pos := 1
+	text := arg[0].RawText()
+	if len(arg) > 2 {
+		pos = arg[2].Int()
+		_, text = split(text, pos)
+	}
+	n := 0
+	if len(arg) > 3 {
+		n = arg[3].Int()
+	}
+
+	var loc []int
+	if n <= 1 {
+		loc = re.FindIndex(text)
 	} else {
-		ctx.ResultRawText(re.ReplaceAll(arg[0].RawText(), arg[2].RawText()))
+		all := re.FindAllIndex(text, n)
+		if n <= len(all) {
+			loc = all[n-1]
+		}
+	}
+	if loc == nil {
+		return
+	}
+
+	end := 0
+	if len(arg) > 4 && arg[4].Bool() {
+		end = 1
+	}
+	ctx.ResultInt(pos + loc[end])
+}
+
+func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
+	re, err := load(ctx, 1, arg[1].Text())
+	if err != nil {
+		ctx.ResultError(err)
+		return // notest
+	}
+	var head, tail []byte
+	tail = arg[0].RawText()
+	if len(arg) > 3 {
+		pos := arg[3].Int()
+		head, tail = split(tail, pos)
+	}
+	tail = re.ReplaceAll(tail, arg[2].RawText())
+	if head != nil {
+		tail = append(head, tail...)
+	}
+	ctx.ResultRawText(tail)
+}
+
+func split(s []byte, i int) (head, tail []byte) {
+	for pos := range string(s) {
+		if i--; i <= 0 {
+			return s[:pos:pos], s[pos:]
+		}
 	}
+	return s, nil
 }
diff --git a/ext/regexp/regexp_test.go b/ext/regexp/regexp_test.go
index 13232057..e9d86a79 100644
--- a/ext/regexp/regexp_test.go
+++ b/ext/regexp/regexp_test.go
@@ -1,6 +1,7 @@
 package regexp
 
 import (
+	"database/sql"
 	"testing"
 
 	"github.com/ncruces/go-sqlite3/driver"
@@ -29,18 +30,27 @@ func TestRegister(t *testing.T) {
 		{`regexp_like('Hello', 'elo')`, "0"},
 		{`regexp_like('Hello', 'ell')`, "1"},
 		{`regexp_like('Hello', 'el.')`, "1"},
+		{`regexp_count('Hello', 'l')`, "2"},
+		{`regexp_instr('Hello', 'el.')`, "2"},
+		{`regexp_instr('Hello', '.', 6)`, ""},
 		{`regexp_substr('Hello', 'el.')`, "ell"},
+		{`regexp_substr('Hello', 'l', 2, 2)`, "l"},
 		{`regexp_replace('Hello', 'llo', 'll')`, "Hell"},
+
+		{`regexp_count('123123123123123', '(12)3', 1)`, "5"},
+		{`regexp_instr('500 Oracle Parkway, Redwood Shores, CA', '(?i)[s|r|p][[:alpha:]]{6}', 3, 2, 1)`, "28"},
+		{`regexp_substr('500 Oracle Parkway, Redwood Shores, CA', ',[^,]+,', 3, 1)`, ", Redwood Shores,"},
+		{`regexp_replace('500   Oracle     Parkway,    Redwood  Shores, CA', '( ){2,}', ' ', 3)`, "500 Oracle Parkway, Redwood Shores, CA"},
 	}
 
 	for _, tt := range tests {
-		var got string
+		var got sql.NullString
 		err := db.QueryRow(`SELECT ` + tt.test).Scan(&got)
 		if err != nil {
 			t.Fatal(err)
 		}
-		if got != tt.want {
-			t.Errorf("got %q, want %q", got, tt.want)
+		if got.String != tt.want {
+			t.Errorf("got %q, want %q", got.String, tt.want)
 		}
 	}
 }
@@ -58,6 +68,8 @@ func TestRegister_errors(t *testing.T) {
 	tests := []string{
 		`'' REGEXP ?`,
 		`regexp_like('', ?)`,
+		`regexp_count('', ?)`,
+		`regexp_instr('', ?)`,
 		`regexp_substr('', ?)`,
 		`regexp_replace('', ?, '')`,
 	}
diff --git a/stmt.go b/stmt.go
index 9da2a2ea..139dd352 100644
--- a/stmt.go
+++ b/stmt.go
@@ -255,6 +255,7 @@ func (s *Stmt) BindText(param int, value string) error {
 
 // BindRawText binds a []byte to the prepared statement as text.
 // The leftmost SQL parameter has an index of 1.
+// Binding a nil slice is the same as calling [Stmt.BindNull].
 //
 // https://sqlite.org/c3ref/bind_blob.html
 func (s *Stmt) BindRawText(param int, value []byte) error {