Skip to content

Commit

Permalink
feat(sqlbuilder): added OrderBy (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 10, 2024
1 parent 65ab358 commit 91ead50
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.4.2] - 2014-04-10
### Added
- added `OrderByBuilder` to prevent sql injection (#32)

## [1.4.1] - 2014-04-09
### Added
- added multi-dht support on `DB` (#31)
Expand Down
65 changes: 65 additions & 0 deletions sqlbuilder_orderby.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package sqle

import (
"slices"
"strings"
)

type OrderByBuilder struct {
*Builder
isWritten bool
allowedColumns []string
}

// OrderBy create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided
func (b *Builder) OrderBy(allowedColumns ...string) *OrderByBuilder {
ob := &OrderByBuilder{
Builder: b,
allowedColumns: allowedColumns,
}

b.SQL(" ORDER BY ")

return ob
}

// isAllowed check if column is included in allowed columns. It will remove any untrust input from client
func (ob *OrderByBuilder) isAllowed(col string) bool {
if ob.allowedColumns == nil {
return true
}

return slices.ContainsFunc(ob.allowedColumns, func(c string) bool {
return strings.EqualFold(c, col)
})
}

// Asc order by ASC with columns
func (ob *OrderByBuilder) Asc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
ob.Builder.SQL(", ").SQL(c).SQL(" ASC")
} else {
ob.Builder.SQL(c).SQL(" ASC")
ob.isWritten = true
}
}
}
return ob
}

// Desc order by desc with columns
func (ob *OrderByBuilder) Desc(columns ...string) *OrderByBuilder {
for _, c := range columns {
if ob.isAllowed(c) {
if ob.isWritten {
ob.Builder.SQL(", ").SQL(c).SQL(" DESC")
} else {
ob.Builder.SQL(c).SQL(" DESC")
ob.isWritten = true
}
}
}
return ob
}
50 changes: 50 additions & 0 deletions sqlbuilder_orderby_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package sqle

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestOrderByBuilder(t *testing.T) {
tests := []struct {
name string
build func() *Builder
wanted string
}{
{
name: "no_safe_columns_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.OrderBy().
Desc("created_at").
Asc("id", "name").
Asc("updated_at")

return b
},
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, name ASC, updated_at ASC",
},
{
name: "safe_columns_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.OrderBy("id", "created_at", "updated_at").
Asc("id", "name").
Desc("created_at", "unsafe_input").
Asc("updated_at")

return b
},
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := test.build().String()

require.Equal(t, test.wanted, actual)
})
}
}

0 comments on commit 91ead50

Please sign in to comment.