diff --git a/driver/driver.go b/driver/driver.go index 477e9a94..b9bb03bb 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -379,7 +379,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e if err != nil { return nil, err } - if tail != "" { + if notWhitespace(tail) { s.Close() return nil, util.TailErr } diff --git a/driver/driver_test.go b/driver/driver_test.go index bef4bf62..871b70e4 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -225,8 +225,8 @@ func Test_Prepare(t *testing.T) { } _, err = db.Prepare(`SELECT 1; `) - if err.Error() != string(util.TailErr) { - t.Error("want tailErr") + if err != nil { + t.Error(err) } _, err = db.Prepare(`SELECT 1; SELECT`) diff --git a/driver/time_test.go b/driver/time_test.go index a19380d4..0b56ba8b 100644 --- a/driver/time_test.go +++ b/driver/time_test.go @@ -27,12 +27,12 @@ func Fuzz_stringOrTime_1(f *testing.F) { // Make sure times round-trip to the same string: // https://pkg.go.dev/database/sql#Rows.Scan if v.Format(time.RFC3339Nano) != str { - t.Fatalf("did not round-trip: %q", str) + t.Errorf("did not round-trip: %q", str) } } else { date, err := time.Parse(time.RFC3339Nano, str) if err == nil && date.Format(time.RFC3339Nano) == str { - t.Fatalf("would round-trip: %q", str) + t.Errorf("would round-trip: %q", str) } } }) diff --git a/driver/whitespace.go b/driver/whitespace.go new file mode 100644 index 00000000..7e7e0038 --- /dev/null +++ b/driver/whitespace.go @@ -0,0 +1,67 @@ +package driver + +func notWhitespace(sql string) bool { + const ( + code = iota + minus + slash + ccomment + endcomment + sqlcomment + ) + + state := code + for _, b := range ([]byte)(sql) { + if b == 0 { + break + } + + switch state { + case code: + switch b { + case '-': + state = minus + case '/': + state = slash + case ' ', ';', '\t', '\n', '\v', '\f', '\r': + continue + default: + return true + } + case minus: + if b != '-' { + return true + } + state = sqlcomment + case slash: + if b != '*' { + return true + } + state = ccomment + case ccomment: + if b == '*' { + state = endcomment + } + case endcomment: + switch b { + case '/': + state = code + case '*': + state = endcomment + default: + state = ccomment + } + case sqlcomment: + if b == '\n' { + state = code + } + } + } + + switch state { + case code, ccomment, endcomment, sqlcomment: + return false + default: + return true + } +} diff --git a/driver/whitespace_test.go b/driver/whitespace_test.go new file mode 100644 index 00000000..b5ef4b79 --- /dev/null +++ b/driver/whitespace_test.go @@ -0,0 +1,56 @@ +package driver + +import ( + "context" + "testing" +) + +func Fuzz_isWhitespace(f *testing.F) { + f.Add("") + f.Add(" ") + f.Add(";") + f.Add("0") + f.Add("-") + f.Add("--") + f.Add("/*") + f.Add("/*/") + f.Add("/**") + f.Add("/**0/") + f.Add("\v") + f.Add(" \v") + f.Add("\xf0") + + db, err := Open(":memory:") + if err != nil { + f.Fatal(err) + } + defer db.Close() + + f.Fuzz(func(t *testing.T, str string) { + c, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + + c.Raw(func(driverConn any) error { + conn := driverConn.(*conn).Conn + stmt, tail, err := conn.Prepare(str) + stmt.Close() + + // It's hard to be bug for bug compatible with SQLite. + // We settle for somewhat less: + // - if SQLite reports whitespace, we must too + // - if we report whitespace, SQLite must not parse a statement + if notWhitespace(str) { + if stmt == nil && tail == "" && err == nil { + t.Errorf("was whitespace: %q", str) + } + } else { + if stmt != nil { + t.Errorf("was not whitespace: %q (%v)", str, err) + } + } + return nil + }) + }) +} diff --git a/vfs/os_windows.go b/vfs/os_windows.go index a606f9c2..0398f476 100644 --- a/vfs/os_windows.go +++ b/vfs/os_windows.go @@ -45,6 +45,7 @@ func osGetExclusiveLock(file *os.File, state *LockLevel) _ErrorCode { osUnlock(file, _SHARED_FIRST, _SHARED_SIZE) // Acquire the EXCLUSIVE lock. + // Can't wait here, because the file is not OVERLAPPED. rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, 0) if rc != _OK {