Skip to content

Commit

Permalink
Merge pull request #4 from noborus/guessmode
Browse files Browse the repository at this point in the history
Guessmode
  • Loading branch information
Noboru Saito authored Jul 28, 2017
2 parents 57fe342 + 794874b commit 667a488
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 79 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Options:
database connection option.
-id string
Field delimiter for input. (default ",")
-ig
Guess format from extension.
-ih
The first line is interpreted as column names.
-iltsv
Expand All @@ -43,7 +45,7 @@ Options:
-oh
Output column name as header.
-ojson
Json format for output.
JSON format for output.
-oltsv
LTSV format for output.
-omd
Expand Down
40 changes: 19 additions & 21 deletions csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,27 @@ func csvheader(reader *csv.Reader) ([]string, error) {
return header, err
}

func (trdsql TRDSQL) csvReader(db *DDB, sqlstr string, tablenames []string) (string, int) {
func (trdsql TRDSQL) csvReader(db *DDB, sqlstr string, tablename string) (string, int) {
var header []string
for _, tablename := range tablenames {
reader, err := csvOpen(tablename, trdsql.inSep, trdsql.iskip)
if err != nil {
// no file
continue
}
rtable := db.escapetable(tablename)
sqlstr = rewrite(sqlstr, tablename, rtable)
header, err = csvheader(reader)
if err != nil {
log.Println(err)
return sqlstr, 1
}
db.Create(rtable, header, trdsql.ihead)
err = db.ImportPrepare(rtable, header, trdsql.ihead)
if err != nil {
log.Println(err)
return sqlstr, 1
}
db.csvImport(reader, header, trdsql.ihead)
reader, err := csvOpen(tablename, trdsql.inSep, trdsql.iskip)
if err != nil {
// no file
return sqlstr, 0
}
rtable := db.escapetable(tablename)
sqlstr = db.rewrite(sqlstr, tablename, rtable)
header, err = csvheader(reader)
if err != nil {
log.Println(err)
return sqlstr, 1
}
db.Create(rtable, header, trdsql.ihead)
err = db.ImportPrepare(rtable, header, trdsql.ihead)
if err != nil {
log.Println(err)
return sqlstr, 1
}
db.csvImport(reader, header, trdsql.ihead)
return sqlstr, 0
}

Expand Down
16 changes: 12 additions & 4 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (

// DDB is *sql.DB wrapper.
type DDB struct {
driver string
dsn string
escape string
driver string
dsn string
escape string
rewritten []string
*sql.DB
stmt *sql.Stmt
}
Expand Down Expand Up @@ -170,8 +171,15 @@ func (db *DDB) escapetable(oldname string) string {
return newname
}

func rewrite(sqlstr string, oldname string, newname string) (rewrite string) {
func (db *DDB) rewrite(sqlstr string, oldname string, newname string) (rewrite string) {
for _, rewritten := range db.rewritten {
if rewritten == newname {
// Rewritten
return sqlstr
}
}
rewrite = strings.Replace(sqlstr, oldname, newname, -1)
db.rewritten = append(db.rewritten, newname)
return rewrite
}

Expand Down
42 changes: 20 additions & 22 deletions ltsv.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,27 @@ func ltsvOpen(filename string, delimiter string, skip int) (*ltsv.Reader, error)
return reader, nil
}

func (trdsql TRDSQL) ltsvReader(db *DDB, sqlstr string, tablenames []string) (string, int) {
for _, tablename := range tablenames {
reader, err := ltsvOpen(tablename, trdsql.inSep, trdsql.iskip)
if err != nil {
// no file
continue
}
rtable := db.escapetable(tablename)
sqlstr = rewrite(sqlstr, tablename, rtable)
first, err := reader.Read()
if err != nil {
return sqlstr, 1
}
header := keys(first)
db.Create(rtable, header, true)
err = db.ImportPrepare(rtable, header, true)
if err != nil {
log.Println(err)
return sqlstr, 1
}

db.ltsvImport(reader, first, header)
func (trdsql TRDSQL) ltsvReader(db *DDB, sqlstr string, tablename string) (string, int) {
reader, err := ltsvOpen(tablename, "\t", trdsql.iskip)
if err != nil {
// no file
return sqlstr, 0
}
rtable := db.escapetable(tablename)
sqlstr = db.rewrite(sqlstr, tablename, rtable)
first, err := reader.Read()
if err != nil {
return sqlstr, 1
}
header := keys(first)
db.Create(rtable, header, true)
err = db.ImportPrepare(rtable, header, true)
if err != nil {
log.Println(err)
return sqlstr, 1
}

db.ltsvImport(reader, first, header)
return sqlstr, 0
}

Expand Down
2 changes: 0 additions & 2 deletions ltsv_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"fmt"
"strings"
"testing"

Expand Down Expand Up @@ -31,7 +30,6 @@ ID:2 name:testb
reader := ltsv.NewReader(s)
r, _ := reader.Read()
if r["ID"] != "1" || r["name"] != "testa" {
fmt.Printf("[%s]\n", r["ID"])
t.Error("invalid value", r["ID"])
}
}
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
type TRDSQL struct {
outStream io.Writer
errStream io.Writer
iguess bool
iltsv bool
inSep string
ihead bool
iskip int
Expand Down
48 changes: 32 additions & 16 deletions trdsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"strings"
)

const VERSION = `0.3.0`

var debug = debugT(false)

type debugT bool
Expand All @@ -29,9 +27,6 @@ func (trdsql TRDSQL) Run(args []string) int {
version bool
odriver string
odsn string
iltsv bool
inSep string
ihead bool
iskip int
query string
driver string
Expand Down Expand Up @@ -63,10 +58,11 @@ Options:
flags.StringVar(&cfg.Db, "db", cfg.Db, "Specify db name of the setting.")
flags.StringVar(&odriver, "driver", "", "database driver. [ "+strings.Join(sql.Drivers(), " | ")+" ]")
flags.StringVar(&odsn, "dsn", "", "database connection option.")
flags.BoolVar(&iltsv, "iltsv", false, "LTSV format for input.")
flags.StringVar(&inSep, "id", ",", "Field delimiter for input.")
flags.BoolVar(&trdsql.iguess, "ig", false, "Guess format from extension.")
flags.BoolVar(&trdsql.iltsv, "iltsv", false, "LTSV format for input.")
flags.StringVar(&trdsql.inSep, "id", ",", "Field delimiter for input.")
flags.StringVar(&trdsql.outSep, "od", ",", "Field delimiter for output.")
flags.BoolVar(&ihead, "ih", false, "The first line is interpreted as column names.")
flags.BoolVar(&trdsql.ihead, "ih", false, "The first line is interpreted as column names.")
flags.BoolVar(&oltsv, "oltsv", false, "LTSV format for output.")
flags.BoolVar(&oat, "oat", false, "ASCII Table format for output.")
flags.BoolVar(&omd, "omd", false, "Mark Down format for output.")
Expand Down Expand Up @@ -134,17 +130,11 @@ Options:
}
trdsql.iskip = iskip
var r int
if iltsv {
trdsql.inSep = "\t"
sqlstr, r = trdsql.ltsvReader(db, sqlstr, tablenames)
} else {
trdsql.inSep = inSep
trdsql.ihead = ihead
sqlstr, r = trdsql.csvReader(db, sqlstr, tablenames)
}
sqlstr, r = trdsql.tableReader(db, sqlstr, tablenames)
if r != 0 {
return r
}

switch {
case oltsv:
r = trdsql.ltsvWrite(db, sqlstr)
Expand All @@ -163,6 +153,32 @@ Options:
return r
}

func (trdsql TRDSQL) tableReader(db *DDB, sqlstr string, tablenames []string) (string, int) {
var r int
for _, tablename := range tablenames {
ltsv := false
if trdsql.iltsv {
ltsv = true
} else if trdsql.iguess {
ltsv = guessExtension(tablename)
}
if ltsv {
sqlstr, r = trdsql.ltsvReader(db, sqlstr, tablename)
} else {
sqlstr, r = trdsql.csvReader(db, sqlstr, tablename)
}
}
return sqlstr, r
}

func guessExtension(tablename string) bool {
pos := strings.LastIndex(tablename, ".")
if pos > 0 && tablename[pos:] == ".ltsv" {
return true
}
return false
}

func getSeparator(sepString string) (rune, error) {
if sepString == "" {
return 0, nil
Expand Down
35 changes: 22 additions & 13 deletions trdsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ var outformat = []string{
"-oltsv",
"-oat",
"-omd",
"-ojson",
"-oraw",
}

func TestRun(t *testing.T) {
Expand Down Expand Up @@ -61,15 +63,12 @@ func TestLtsvRun(t *testing.T) {
}
}

var tsql = []string{
"test.sql",
}

func TestQueryfileRun(t *testing.T) {
func TestGuessRun(t *testing.T) {
outStream, errStream := new(bytes.Buffer), new(bytes.Buffer)
trdsql := &TRDSQL{outStream: outStream, errStream: errStream}
for _, c := range tsql {
args := []string{"trdsql", "-q", "testdata/" + c}
for _, c := range append(tcsv, tltsv...) {
sql := "SELECT * FROM testdata/" + c
args := []string{"trdsql", "-ig", sql}
if trdsql.Run(args) != 0 {
t.Errorf("trdsql error.")
}
Expand All @@ -79,13 +78,15 @@ func TestQueryfileRun(t *testing.T) {
}
}

/*
func TestPgRun(t *testing.T) {
var tsql = []string{
"test.sql",
}

func TestQueryfileRun(t *testing.T) {
outStream, errStream := new(bytes.Buffer), new(bytes.Buffer)
trdsql := &TRDSQL{outStream: outStream, errStream: errStream}
for _, c := range tcsv {
sql := "SELECT * FROM testdata/" + c
args := []string{"trdsql", "-driver", "postgres", sql}
for _, c := range tsql {
args := []string{"trdsql", "-q", "testdata/" + c}
if trdsql.Run(args) != 0 {
t.Errorf("trdsql error.")
}
Expand All @@ -94,4 +95,12 @@ func TestPgRun(t *testing.T) {
}
}
}
*/

func TestGuessExtension(t *testing.T) {
if guessExtension("test.ltsv") != true {
t.Errorf("guessExtension error.")
}
if guessExtension("test.csv") != false {
t.Errorf("guessExtension error.")
}
}
4 changes: 4 additions & 0 deletions version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package main

// VERSION is trdsql version
const VERSION = `0.3.1`

0 comments on commit 667a488

Please sign in to comment.