From 3b191f0e6f8b832f9888f00013fa7add2de39722 Mon Sep 17 00:00:00 2001 From: Noboru Saito Date: Thu, 13 Feb 2020 13:50:49 +0900 Subject: [PATCH] Extension to help use `LOAD DATA LOCAL INFILE` Allows data to be imported by Exec(parameter) after using special `Data::Data` as file path of `LOAD DATA LOCAL INFILE`. It is easy to use instead of `INSERT INTO`. This is easier to import than io.Reader when you want to import data generated from the program. Implementation method. 1. Add the `LOAD DATA` flag to Prepare() and Exec(). 2. Exec() with `LOAD DATA` flag set will send the parameter converted to TSV. --- AUTHORS | 1 + README.md | 26 +++++++ connection.go | 25 +++++++ driver_test.go | 173 ++++++++++++++++++++++++++++++++++++++++++++++ infile.go | 181 +++++++++++++++++++++++++++++++++++++++++++++++++ infile_test.go | 145 +++++++++++++++++++++++++++++++++++++++ statement.go | 6 ++ 7 files changed, 557 insertions(+) create mode 100644 infile_test.go diff --git a/AUTHORS b/AUTHORS index ad5989800..4bc4427b3 100644 --- a/AUTHORS +++ b/AUTHORS @@ -65,6 +65,7 @@ Maciej Zimnoch Michael Woolnough Nathanial Murphy Nicola Peduzzi +Noboru Saito Olivier Mengué oscarzhao Paul Bonser diff --git a/README.md b/README.md index d2627a41a..1c1d95f94 100644 --- a/README.md +++ b/README.md @@ -451,6 +451,32 @@ To use a `io.Reader` a handler function must be registered with `mysql.RegisterR See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. +### Execute `LOAD DATA LOCAL INFILE` instead of `INSERT INTO` + +Enables `LOAD DATA LOCAL INFILE` without the need to call a special function. +Using `LOAD DATA LOCAL INFILE` instead of `INSERT INTO` is available with the filepath `Data::Data`. +Create a statement by executing `LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE table name` as a query to the prepare function. +Execute the returned statement with a value in Exec. +Exec is imported as LOAD DATA until the statement is closed. + +```go +//stmt, _ = db.Prepare("INSERT INTO test (id, value1, value2 ) VALUES (?, ?, ?);") +stmt, _ = db.Prepare("LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE test (id, value1, value2);") +stmt.Exec(1, "test11", "test12") +stmt.Exec(2, "test21", "test22") +stmt.Close() +``` + +It is also possible to perform a `LOAD DATA LOCAL INFILE Data::Data' INTO TABLE table name` Query with Exec. +In that case, import the following Exec parameters as LOAD DATA. +To finish importing LOAD data, run the parameter with nil. + +```go +db.Exec("LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE test (id, value1, value2);") +db.Exec("", 1, "test11", "test12") +db.Exec("", 2, "test21", "test22") +db.Exec("") +``` ### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. diff --git a/connection.go b/connection.go index e4bb59e67..46efbd53e 100644 --- a/connection.go +++ b/connection.go @@ -34,6 +34,9 @@ type mysqlConn struct { sequence uint8 parseTime bool reset bool // set when the Go SQL package calls ResetSession + inLoadData bool + loadData []byte + maxLoadDataSize int // for context support (Go 1.8+) watching bool @@ -151,6 +154,18 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } + if len(query) > 4 && strings.EqualFold(query[:4], "LOAD") { + err := mc.exec(query) + if err != nil { + return nil, err + } + mc.inLoadData = true + stmt := &mysqlStmt{ + mc: mc, + paramCount: -1, + } + return stmt, err + } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { @@ -310,6 +325,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } + if mc.inLoadData { + return nil, mc.loadDataWrite(args) + } if len(args) != 0 { if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip @@ -524,6 +542,10 @@ func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []drive return nil, err } + if mc.inLoadData { + return nil, mc.loadDataWrite(dargs) + } + if err := mc.watchCancel(ctx); err != nil { return nil, err } @@ -577,6 +599,9 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue return nil, err } + if stmt.mc.inLoadData { + return nil, stmt.mc.loadDataWrite(dargs) + } if err := stmt.mc.watchCancel(ctx); err != nil { return nil, err } diff --git a/driver_test.go b/driver_test.go index ace083dfc..14e6c6677 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1275,6 +1275,179 @@ func TestLoadData(t *testing.T) { }) } +func TestExecLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := [][]driver.Value{{"tt1", "tt1"}, {"tt2", "tt2"}, {"tt3", nil}} + v := [][]string{{"tt1", "tt1"}, {"tt2", "tt2"}, {"tt3", ""}} + dbt.db.Exec("DROP TABLE IF EXISTS test") + + dbt.mustExec("CREATE TABLE test (value text, value2 text)") + dbt.mustExec("LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE test") + + for _, v := range in { + dbt.mustExec("", v[0], v[1]) + } + dbt.mustExec("") + + rows := dbt.mustQuery("SELECT * FROM test") + for n := 0; rows.Next(); n++ { + var out, out2 sql.RawBytes + rows.Scan(&out, &out2) + if !bytes.Equal([]byte(v[n][0]), out) { + dbt.Errorf("expected %v, got %v", v[n][0], out) + } + if !bytes.Equal([]byte(v[n][1]), out2) { + dbt.Errorf("expected %v, got %v", v[n][1], out2) + } + rows.Close() + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +func TestPrepareLoadData(t *testing.T) { + count := 1000 + runTests(t, dsn, func(dbt *DBTest) { + in := []driver.Value{"test1", "test2"} + v := []string{"test1", "test2"} + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT, value1 text, value2 text)") + + // Uncomment the next 'INSERT' line and comment out 'LOAD DATA' so it works. + //stmt, err := dbt.db.Prepare("INSERT INTO test (id, value1, value2) VALUES(?, ?, ?);") + stmt, err := dbt.db.Prepare("LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE test (id, value1, value2)") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + + for i := 0; i < count; i++ { + _, err = stmt.Exec(i, in[0], in[1]) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + } + err = stmt.Close() + if err != nil { + t.Fatalf("error close statement: %s", err.Error()) + } + + rows := dbt.mustQuery("SELECT COUNT(*) FROM test") + if !rows.Next() { + dbt.Fatalf("error rows: %s", rows.Err()) + } + + c := 0 + if err := rows.Scan(&c); err != nil { + dbt.Fatal(err.Error()) + } + if count != c { + dbt.Errorf("Load Data Error:%v != %v", c, count) + } + rows = dbt.mustQuery("SELECT id, value1, value2 FROM test") + for rows.Next() { + var id int + var out, out2 sql.RawBytes + if err := rows.Scan(&id, &out, &out2); err != nil { + dbt.Fatal(err.Error()) + } + if !bytes.Equal([]byte(v[0]), out) { + dbt.Errorf("expected %v, got %v", v[0], out) + } + if !bytes.Equal([]byte(v[1]), out2) { + dbt.Errorf("expected %v, got %v", v[1], out2) + } + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +func TestExecLoadDataType(t *testing.T) { + type tType struct { + dbType string + value driver.Value + result sql.RawBytes + } + tTypes := []tType{ + { + dbType: "varchar(10)", + value: "ttt", + result: []byte("ttt"), + }, + { + dbType: "text", + value: "ttt", + result: []byte("ttt"), + }, + { + dbType: "text", + value: nil, + result: []byte(""), + }, + { + dbType: "text", + value: []byte(nil), + result: []byte(""), + }, + { + dbType: "int", + value: 42, + result: []byte("42"), + }, + { + dbType: "float", + value: 42.23, + result: []byte("42.23"), + }, + { + dbType: "datetime", + value: time.Date(2001, 5, 20, 23, 59, 59, 0, time.UTC), + result: []byte("2001-05-20 23:59:59"), + }, + { + dbType: "datetime", + value: time.Time{}, + result: []byte("0000-00-00 00:00:00"), + }, + { + dbType: "bool", + value: true, + result: []byte("1"), + }, + { + dbType: "bool", + value: false, + result: []byte("0"), + }, + } + runTests(t, dsn, func(dbt *DBTest) { + var rows *sql.Rows + for _, tType := range tTypes { + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (value " + tType.dbType + ")") + + dbt.mustExec("LOAD DATA LOCAL INFILE 'Data::Data' INTO TABLE test") + dbt.mustExec("", tType.value) + dbt.mustExec("") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + var out sql.RawBytes + rows.Scan(&out) + if !bytes.Equal(tType.result, out) { + dbt.Errorf("%s: expected %v, got %v(%s)", tType.dbType, tType.result, out, out) + } + } else { + dbt.Errorf("%s: no data", tType.dbType) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + func TestFoundRows(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") diff --git a/infile.go b/infile.go index 273cb0ba5..3e3122501 100644 --- a/infile.go +++ b/infile.go @@ -9,11 +9,14 @@ package mysql import ( + "database/sql/driver" "fmt" "io" "os" + "strconv" "strings" "sync" + "time" ) var ( @@ -101,6 +104,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { packetSize = mc.maxWriteSize } + if name == "Data::Data" { + return mc.loadDataStart() + } + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader // The server might return an an absolute path. See issue #355. name = name[idx+8:] @@ -180,3 +187,177 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { mc.readPacket() return err } + +func (mc *mysqlConn) loadDataStart() (err error) { + mc.inLoadData = true + mc.maxLoadDataSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if (mc.maxWriteSize / 2) < mc.maxLoadDataSize { + mc.maxLoadDataSize = mc.maxWriteSize / 2 + } + mc.loadData = []byte{0, 0, 0, 0} + return nil +} + +func (mc *mysqlConn) loadDataWrite(args []driver.Value) (err error) { + if len(args) == 0 { + return mc.loadDataTerminate() + } + + for n, column := range args { + if n > 0 { + mc.loadData = append(mc.loadData, '\t') + } + mc.loadData = mc.appendEncode(mc.loadData, column) + } + mc.loadData = append(mc.loadData, '\n') + if len(mc.loadData) > mc.maxLoadDataSize { + err = mc.loadDataWritePacket() + if err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) loadDataWritePacket() (err error) { + if ioErr := mc.writePacket(mc.loadData); ioErr != nil { + return ioErr + } + mc.loadData = mc.loadData[:4] + return nil +} + +func (mc *mysqlConn) loadDataTerminate() (err error) { + defer func() { + mc.inLoadData = false + }() + if ioErr := mc.loadDataWritePacket(); ioErr != nil { + return ioErr + } + mc.loadData = mc.loadData[:4] + if ioErr := mc.writePacket(mc.loadData); ioErr != nil { + return ioErr + } + + // read OK packet + if err == nil { + return mc.readResultOK() + } + + mc.readPacket() + return err +} + +func (mc *mysqlConn) appendEncode(buf []byte, x driver.Value) []byte { + switch v := x.(type) { + case int64: + return strconv.AppendInt(buf, v, 10) + case uint64: + return strconv.AppendUint(buf, v, 10) + case float64: + return strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + return append(buf, '1') + } else { + return append(buf, '0') + } + case time.Time: + if v.IsZero() { + return append(buf, "0000-00-00"...) + } else { + v := v.In(mc.cfg.Loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond + year := v.Year() + year100 := year / 100 + year1 := year % 100 + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf := append(buf, []byte{ + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], + }...) + if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 + buf = append(buf, []byte{ + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], + }...) + } + return buf + } + case []byte: + if v == nil { + return append(buf, "\\N"...) + } else { + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + return buf + } + case string: + return appendEscaped(buf, v) + case nil: + return append(buf, "\\N"...) + default: + errLog.Print("unsupported type") + return buf + } +} + +func appendEscaped(buf []byte, v string) []byte { + escapeNeeded := false + startPos := 0 + var c byte + + for i := 0; i < len(v); i++ { + c = v[i] + if c == '\\' || c == '\n' || c == '\r' || c == '\t' { + escapeNeeded = true + startPos = i + break + } + } + if !escapeNeeded { + return append(buf, v...) + } + + result := append(buf, v[:startPos]...) + for i := startPos; i < len(v); i++ { + c = v[i] + switch c { + case '\\': + result = append(result, '\\', '\\') + case '\n': + result = append(result, '\\', 'n') + case '\r': + result = append(result, '\\', 'r') + case '\t': + result = append(result, '\\', 't') + default: + result = append(result, c) + } + } + return result +} diff --git a/infile_test.go b/infile_test.go new file mode 100644 index 000000000..9c236b8ba --- /dev/null +++ b/infile_test.go @@ -0,0 +1,145 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql/driver" + "reflect" + "testing" + "time" +) + +func Test_mysqlConn_appendEncode(t *testing.T) { + type args struct { + buf []byte + x driver.Value + } + tests := []struct { + name string + args args + want []byte + }{ + { + name: "test String", + args: args{[]byte{}, "test"}, + want: []byte("test"), + }, + { + name: "test Int64", + args: args{[]byte{}, driver.Value(int64(42))}, + want: []byte("42"), + }, + { + name: "test Uint64", + args: args{[]byte{}, driver.Value(uint64(42))}, + want: []byte("42"), + }, + { + name: "test Fload64", + args: args{[]byte{}, driver.Value(float64(42.23))}, + want: []byte("42.23"), + }, + { + name: "test Bool", + args: args{[]byte{}, driver.Value(bool(true))}, + want: []byte("1"), + }, + { + name: "test BoolFalse", + args: args{[]byte{}, driver.Value(bool(false))}, + want: []byte("0"), + }, + { + name: "test nil", + args: args{[]byte{}, driver.Value(nil)}, + want: []byte("\\N"), + }, + { + name: "test TimeNULL", + args: args{[]byte{}, driver.Value(time.Time{})}, + want: []byte("0000-00-00"), + }, + { + name: "test Time", + args: args{[]byte{}, driver.Value(time.Date(2014, time.December, 31, 12, 13, 24, 0, time.UTC))}, + want: []byte("2014-12-31 12:13:24"), + }, + { + name: "test byteNil", + args: args{[]byte{}, driver.Value([]byte(nil))}, + want: []byte("\\N"), + }, + { + name: "test byte", + args: args{[]byte{}, driver.Value([]byte("test"))}, + want: []byte("test"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mc := &mysqlConn{ + cfg: NewConfig(), + maxAllowedPacket: defaultMaxAllowedPacket, + } + if got := mc.appendEncode(tt.args.buf, tt.args.x); !reflect.DeepEqual(got, tt.want) { + t.Errorf("mysqlConn.appendEncode() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_appendEscaped(t *testing.T) { + type args struct { + buf []byte + v string + } + tests := []struct { + name string + args args + want []byte + }{ + { + name: "test1", + args: args{[]byte{}, "test"}, + want: []byte("test"), + }, + { + name: "test TAB", + args: args{[]byte{}, "t\test"}, + want: []byte("t\\test"), + }, + { + name: "test LF", + args: args{[]byte{}, "t\nest"}, + want: []byte("t\\nest"), + }, + { + name: "test CR", + args: args{[]byte{}, "t\rest"}, + want: []byte("t\\rest"), + }, + { + name: "test BackSlash", + args: args{[]byte{}, "t\\est"}, + want: []byte("t\\\\est"), + }, + { + name: "test All", + args: args{[]byte{}, "t\t\n\r\\est"}, + want: []byte("t\\t\\n\\r\\\\est"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := appendEscaped(tt.args.buf, tt.args.v); !reflect.DeepEqual(got, tt.want) { + t.Errorf("appendEscaped() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/statement.go b/statement.go index f7e370939..3b4ea3d98 100644 --- a/statement.go +++ b/statement.go @@ -29,6 +29,9 @@ func (stmt *mysqlStmt) Close() error { //errLog.Print(ErrInvalidConn) return driver.ErrBadConn } + if stmt.mc.inLoadData { + return stmt.mc.loadDataTerminate() + } err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) stmt.mc = nil @@ -48,6 +51,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } + if stmt.mc.inLoadData { + return nil, stmt.mc.loadDataWrite(args) + } // Send command err := stmt.writeExecutePacket(args) if err != nil {