Skip to content

Commit

Permalink
!fix(orderby): use BuildOptions instead of allowedColumns (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored May 2, 2024
1 parent f7859c0 commit d72967c
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 33 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.5.0]
## [1.5.1]
- !fix(orderby): use BuildOption instead of allowedColumns (#46)

## [1.5.0] - 2024-04-30
### Changed
- !renamed `Context` with `Client` (#45)

Expand Down
2 changes: 1 addition & 1 deletion sqlbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (b *Builder) Build() (string, []any, error) {

// quoteColumn escapes the given column name using the Builder's Quote character.
func (b *Builder) quoteColumn(c string) string {
if strings.ContainsAny(c, "(") || strings.ContainsAny(c, " ") || strings.ContainsAny(c, "as") {
if strings.ContainsAny(c, "(") || strings.ContainsAny(c, " ") {
return c
} else {
return b.Quote + c + b.Quote
Expand Down
63 changes: 42 additions & 21 deletions sqlbuilder_orderby.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
package sqle

import (
"slices"
"strings"
)

// OrderByBuilder represents a SQL ORDER BY clause builder.
// It is used to construct ORDER BY clauses for SQL queries.
type OrderByBuilder struct {
*Builder // The underlying SQL query builder.
written bool // Indicates if the ORDER BY clause has been written.
allowedColumns []string // The list of allowed columns for ordering.
*Builder // The underlying SQL query builder.
written bool // Indicates if the ORDER BY clause has been written.
options *BuilderOptions // The list of allowed columns for ordering.
}

// NewOrderBy creates a new instance of the OrderByBuilder.
// It takes a variadic parameter `allowedColumns` which specifies the columns that are allowed to be used in the ORDER BY clause.
func NewOrderBy(allowedColumns ...string) *OrderByBuilder {
return &OrderByBuilder{
Builder: New(),
allowedColumns: allowedColumns,
func NewOrderBy(opts ...BuilderOption) *OrderByBuilder {
ob := &OrderByBuilder{
Builder: New(),
options: &BuilderOptions{},
}

for _, o := range opts {
o(ob.options)
}

return ob
}

// WithOrderBy sets the order by clause for the SQL query.
Expand All @@ -31,32 +36,45 @@ func (b *Builder) WithOrderBy(ob *OrderByBuilder) *OrderByBuilder {
return nil
}

n := b.Order(ob.allowedColumns...)
n := b.Order()

b.SQL(ob.String())

return n
}

// Order create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided
func (b *Builder) Order(allowedColumns ...string) *OrderByBuilder {
func (b *Builder) Order(opts ...BuilderOption) *OrderByBuilder {
ob := &OrderByBuilder{
Builder: b,
allowedColumns: allowedColumns,
Builder: b,
options: &BuilderOptions{},
}

for _, o := range opts {
o(ob.options)
}

return ob
}

// isAllowed check if column is included in allowed columns. It will remove any untrust input from client
func (ob *OrderByBuilder) isAllowed(col string) bool {
if ob.allowedColumns == nil {
return true
func (ob *OrderByBuilder) getColumn(col string) (string, bool) {
if ob.options.Columns == nil {
return col, true
}

return slices.ContainsFunc(ob.allowedColumns, func(c string) bool {
return strings.EqualFold(c, col)
})
if ob.options.ToName != nil {
col = ob.options.ToName(col)
}

for _, c := range ob.options.Columns {
if strings.EqualFold(c, col) {
return c, true
}
}

return "", false

}

// By order by raw sql. eg By("a asc, b desc")
Expand Down Expand Up @@ -107,16 +125,19 @@ func (ob *OrderByBuilder) ByDesc(columns ...string) *OrderByBuilder {
// If the column has already been written, it appends a comma before adding the column.
// If it's the first column being added, it appends "ORDER BY" before adding the column.
func (ob *OrderByBuilder) add(col, direction string) {
if ob.isAllowed(col) {
c, ok := ob.getColumn(col)

if ok {
if ob.written {
ob.Builder.SQL(", ").SQL(col).SQL(direction)

ob.Builder.SQL(", ").SQL(ob.quoteColumn(c)).SQL(direction)
} else {
// only write once
if !ob.written {
ob.Builder.SQL(" ORDER BY ")
}

ob.Builder.SQL(col).SQL(direction)
ob.Builder.SQL(ob.quoteColumn(c)).SQL(direction)

ob.written = true
}
Expand Down
37 changes: 27 additions & 10 deletions sqlbuilder_orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sqle
import (
"testing"

"github.com/iancoleman/strcase"
"github.com/stretchr/testify/require"
)

Expand All @@ -23,51 +24,51 @@ func TestOrderByBuilder(t *testing.T) {

return b
},
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, name ASC, updated_at ASC",
wanted: "SELECT * FROM users ORDER BY `created_at` DESC, `id` ASC, `name` ASC, `updated_at` ASC",
},
{
name: "safe_columns_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.Order("id", "created_at", "updated_at").
b.Order(WithAllow("id", "created_at", "updated_at")).
ByAsc("id", "name").
ByDesc("created_at", "unsafe_input").
ByAsc("updated_at")

return b
},
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC",
wanted: "SELECT * FROM users ORDER BY `id` ASC, `created_at` DESC, `updated_at` ASC",
},
{
name: "order_by_raw_sql_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.Order("id", "created_at", "updated_at", "age").
b.Order(WithAllow("id", "created_at", "updated_at", "age")).
By("created_at desc, id, name asc, updated_at asc, age invalid_by, unsafe_asc, unsafe_desc desc")

return b
},
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, updated_at ASC",
wanted: "SELECT * FROM users ORDER BY `created_at` DESC, `id` ASC, `updated_at` ASC",
},
{
name: "with_order_by_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")

ob := NewOrderBy("id", "created_at", "updated_at", "age")
ob := NewOrderBy(WithAllow("id", "created_at", "updated_at", "age"))
ob.By("created_at desc, id, name asc, updated_at asc, age invalid_by, unsafe_asc, unsafe_desc desc")

b.WithOrderBy(ob)

return b
},
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, updated_at ASC",
wanted: "SELECT * FROM users ORDER BY `created_at` DESC, `id` ASC, `updated_at` ASC",
},
{
name: "with_nil_order_by_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")
b.Order("id", "created_at", "updated_at").
b.Order(WithAllow("id", "created_at", "updated_at")).
ByAsc("id", "name").
ByDesc("created_at", "unsafe_input").
ByAsc("updated_at")
Expand All @@ -76,14 +77,14 @@ func TestOrderByBuilder(t *testing.T) {

return b
},
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC",
wanted: "SELECT * FROM users ORDER BY `id` ASC, `created_at` DESC, `updated_at` ASC",
},
{
name: "with_empty_order_by_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")

ob := NewOrderBy("age").
ob := NewOrderBy(WithAllow("age")).
ByAsc("id", "name").
ByDesc("created_at", "unsafe_input").
ByAsc("updated_at")
Expand All @@ -94,6 +95,22 @@ func TestOrderByBuilder(t *testing.T) {
},
wanted: "SELECT * FROM users",
},
{
name: "with_to_name_order_by_should_work",
build: func() *Builder {
b := New("SELECT * FROM users")

ob := NewOrderBy(WithToName(strcase.ToSnake), WithAllow("created_at")).
ByAsc("id", "name").
ByDesc("createdAt", "unsafe_input").
ByAsc("updated_at")

b.WithOrderBy(ob)

return b
},
wanted: "SELECT * FROM users ORDER BY `created_at` DESC",
},
}

for _, test := range tests {
Expand Down

0 comments on commit d72967c

Please sign in to comment.