diff --git a/cloudflare/d1/rows.go b/cloudflare/d1/rows.go index 47079e8..326b995 100644 --- a/cloudflare/d1/rows.go +++ b/cloudflare/d1/rows.go @@ -7,8 +7,6 @@ import ( "math" "sync" "syscall/js" - - "github.com/syumai/workers/internal/jsutil" ) type rows struct { @@ -16,8 +14,7 @@ type rows struct { currentRow int // columns is cached value of Columns method. // do not use this directly. - _columns []string - onceColumns sync.Once + columns []string // _rowsLen is cached value of rowsLen method. // do not use this directly. _rowsLen int @@ -27,23 +24,10 @@ type rows struct { var _ driver.Rows = (*rows)(nil) -// Columns returns column names retrieved from query result object's keys. +// Columns returns column names retrieved from query result. // If rows are empty, this returns nil. func (r *rows) Columns() []string { - r.onceColumns.Do(func() { - if r.rowsObj.Length() == 0 { - // return nothing when row count is zero. - return - } - colsArray := jsutil.ObjectClass.Call("keys", r.rowsObj.Index(0)) - colsLen := colsArray.Length() - cols := make([]string, colsLen) - for i := 0; i < colsLen; i++ { - cols[i] = colsArray.Index(i).String() - } - r._columns = cols - }) - return r._columns + return r.columns } func (r *rows) Close() error { @@ -91,9 +75,9 @@ func (r *rows) Next(dest []driver.Value) error { return io.EOF } rowObj := r.rowsObj.Index(r.currentRow) - cols := r.Columns() - for i, col := range cols { - v, err := convertRowColumnValueToAny(rowObj.Get(col)) + rowObjLen := rowObj.Length() + for i := 0; i < rowObjLen; i++ { + v, err := convertRowColumnValueToAny(rowObj.Index(i)) if err != nil { return err } diff --git a/cloudflare/d1/stmt.go b/cloudflare/d1/stmt.go index eec4262..c1c92e1 100644 --- a/cloudflare/d1/stmt.go +++ b/cloudflare/d1/stmt.go @@ -59,15 +59,30 @@ func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver for i, arg := range args { argValues[i] = arg.Value } - resultPromise := s.stmtObj.Call("bind", argValues...).Call("all") + resultPromise := s.stmtObj.Call("bind", argValues...).Call("raw", map[string]any{"columnNames": true}) rowsObj, err := jsutil.AwaitPromise(resultPromise) if err != nil { return nil, err } - if !rowsObj.Get("success").Bool() { - return nil, errors.New("d1: failed to query") + // If there are no rows to retrieve, length is 0. + if rowsObj.Length() == 0 { + return &rows{ + columns: nil, + rowsObj: rowsObj, + }, nil } + + // The first result array includes the column names. + colsArray := rowsObj.Index(0) + colsLen := colsArray.Length() + cols := make([]string, colsLen) + for i := 0; i < colsLen; i++ { + cols[i] = colsArray.Index(i).String() + } + // Remove the first result array from the rowsObj. + rowsObj.Call("shift") return &rows{ - rowsObj: rowsObj.Get("results"), + columns: cols, + rowsObj: rowsObj, }, nil }