-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sqlbuilder): added OrderBy (#32)
- Loading branch information
Showing
3 changed files
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package sqle | ||
|
||
import ( | ||
"slices" | ||
"strings" | ||
) | ||
|
||
type OrderByBuilder struct { | ||
*Builder | ||
isWritten bool | ||
allowedColumns []string | ||
} | ||
|
||
// OrderBy create an OrderByBuilder with allowed columns to prevent sql injection. NB: any input is allowed if it is not provided | ||
func (b *Builder) OrderBy(allowedColumns ...string) *OrderByBuilder { | ||
ob := &OrderByBuilder{ | ||
Builder: b, | ||
allowedColumns: allowedColumns, | ||
} | ||
|
||
b.SQL(" ORDER BY ") | ||
|
||
return ob | ||
} | ||
|
||
// isAllowed check if column is included in allowed columns. It will remove any untrust input from client | ||
func (ob *OrderByBuilder) isAllowed(col string) bool { | ||
if ob.allowedColumns == nil { | ||
return true | ||
} | ||
|
||
return slices.ContainsFunc(ob.allowedColumns, func(c string) bool { | ||
return strings.EqualFold(c, col) | ||
}) | ||
} | ||
|
||
// Asc order by ASC with columns | ||
func (ob *OrderByBuilder) Asc(columns ...string) *OrderByBuilder { | ||
for _, c := range columns { | ||
if ob.isAllowed(c) { | ||
if ob.isWritten { | ||
ob.Builder.SQL(", ").SQL(c).SQL(" ASC") | ||
} else { | ||
ob.Builder.SQL(c).SQL(" ASC") | ||
ob.isWritten = true | ||
} | ||
} | ||
} | ||
return ob | ||
} | ||
|
||
// Desc order by desc with columns | ||
func (ob *OrderByBuilder) Desc(columns ...string) *OrderByBuilder { | ||
for _, c := range columns { | ||
if ob.isAllowed(c) { | ||
if ob.isWritten { | ||
ob.Builder.SQL(", ").SQL(c).SQL(" DESC") | ||
} else { | ||
ob.Builder.SQL(c).SQL(" DESC") | ||
ob.isWritten = true | ||
} | ||
} | ||
} | ||
return ob | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package sqle | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestOrderByBuilder(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
build func() *Builder | ||
wanted string | ||
}{ | ||
{ | ||
name: "no_safe_columns_should_work", | ||
build: func() *Builder { | ||
b := New("SELECT * FROM users") | ||
b.OrderBy(). | ||
Desc("created_at"). | ||
Asc("id", "name"). | ||
Asc("updated_at") | ||
|
||
return b | ||
}, | ||
wanted: "SELECT * FROM users ORDER BY created_at DESC, id ASC, name ASC, updated_at ASC", | ||
}, | ||
{ | ||
name: "safe_columns_should_work", | ||
build: func() *Builder { | ||
b := New("SELECT * FROM users") | ||
b.OrderBy("id", "created_at", "updated_at"). | ||
Asc("id", "name"). | ||
Desc("created_at", "unsafe_input"). | ||
Asc("updated_at") | ||
|
||
return b | ||
}, | ||
wanted: "SELECT * FROM users ORDER BY id ASC, created_at DESC, updated_at ASC", | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
actual := test.build().String() | ||
|
||
require.Equal(t, test.wanted, actual) | ||
}) | ||
} | ||
} |