-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathfilter.go
181 lines (176 loc) · 5.4 KB
/
filter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
package WebGuard
import (
"fmt"
"github.com/spf13/viper"
"io/ioutil"
"net/http"
"path/filepath"
"strings"
"time"
)
// WebGuardFilter checks req against the "proxy-rules.*" settings in cfg and
// reports whether the request may pass.
//
// Rules are evaluated in a fixed order: allow-host, allow-location,
// allow-time, allow-ip-list, allow-token-header, allow-user-agent,
// allow-path, allow-file-ext, refuse-ip-list, then the body-length bounds.
// The first rule that rejects the request makes the function return false
// together with an error naming that rule. A rule whose value is empty or set
// to its wildcard ("*"; ".*" for allow-file-ext; "-" for refuse-ip-list) is
// skipped. Entries of comma-separated list values may be surrounded by spaces.
//
// When a body-length bound is configured, the request body is read in full
// and then replaced with an in-memory copy so later handlers still receive it.
func WebGuardFilter(cfg *viper.Viper, req *http.Request) (bool, error) {
	if req == nil {
		return false, fmt.Errorf("request is nil pointer")
	}
	var (
		allowHost        = cfg.GetString("proxy-rules.allow-host")
		allowLocation    = cfg.GetString("proxy-rules.allow-location")
		allowTime        = cfg.GetString("proxy-rules.allow-time")
		allowPath        = cfg.GetString("proxy-rules.allow-path")
		allowIPList      = cfg.GetString("proxy-rules.allow-ip-list")
		allowTokenHeader = cfg.GetStringMapString("proxy-rules.allow-token-header")
		allowUserAgent   = cfg.GetStringSlice("proxy-rules.allow-user-agent")
		refuseIPList     = cfg.GetString("proxy-rules.refuse-ip-list")
		allowMaxLength   = cfg.GetInt("proxy-rules.allow-body-max-length")
		allowMinLength   = cfg.GetInt("proxy-rules.allow-body-min-length")
		allowFileExt     = cfg.GetString("proxy-rules.allow-file-ext")
	)

	// Host must exactly equal one of the configured entries (req.Host may carry a port).
	if allowHost != "" && allowHost != "*" {
		if _, ok := _find(filterSplitList(allowHost), req.Host); !ok {
			return false, fmt.Errorf("allow-host rules trigger")
		}
	}

	address := filterClientAddr(req)

	if allowLocation != "" && allowLocation != "*" {
		if !LoopUpLocation(filterSplitList(allowLocation), address) {
			return false, fmt.Errorf("allow-location rules trigger")
		}
	}

	if allowTime != "" && allowTime != "*" {
		inWindow, err := filterInTimeWindow(allowTime, time.Now())
		if err != nil {
			return false, fmt.Errorf("allow-time format is error: %w", err)
		}
		if !inWindow {
			return false, fmt.Errorf("allow-time rules trigger")
		}
	}

	if allowIPList != "" && allowIPList != "*" {
		allowed, err := filterIPListContains(allowIPList, address)
		if err != nil {
			return false, fmt.Errorf("allow-ip-list format is error: %w", err)
		}
		if !allowed {
			return false, fmt.Errorf("allow-ip-list rules trigger")
		}
	}

	// Every configured header must be present and carry exactly the configured value.
	for name, want := range allowTokenHeader {
		got := req.Header.Get(name)
		if got == "" {
			return false, fmt.Errorf("allow-token-header rules trigger")
		}
		if got != want {
			return false, fmt.Errorf("not meeting allow-token-header")
		}
	}

	// An unset user-agent list means "no restriction", like every other rule.
	if len(allowUserAgent) > 0 && !(len(allowUserAgent) == 1 && allowUserAgent[0] == "*") {
		if _, ok := _find(allowUserAgent, req.UserAgent()); !ok {
			return false, fmt.Errorf("allow-user-agent rules trigger")
		}
	}

	// NOTE(review): this is a substring match, so "/api" also admits
	// "/other/api-x"; confirm whether a prefix match was intended.
	if allowPath != "" && allowPath != "*" {
		matched := false
		for _, p := range filterSplitList(allowPath) {
			if strings.Contains(req.URL.Path, p) {
				matched = true
				break
			}
		}
		if !matched {
			return false, fmt.Errorf("allow-path rules trigger")
		}
	}

	if allowFileExt != "" && allowFileExt != ".*" {
		if _, ok := _find(filterSplitList(allowFileExt), filepath.Ext(req.URL.Path)); !ok {
			return false, fmt.Errorf("allow-file-ext rules trigger")
		}
	}

	if refuseIPList != "" && refuseIPList != "-" {
		blocked, err := filterIPListContains(refuseIPList, address)
		if err != nil {
			return false, fmt.Errorf("refuse-ip-list format is error: %w", err)
		}
		if blocked {
			return false, fmt.Errorf("refuse-ip-list rules trigger")
		}
	}

	if err := filterCheckBodyLength(req, allowMinLength, allowMaxLength); err != nil {
		return false, err
	}
	return true, nil
}

// filterSplitList splits a comma-separated rule value and trims surrounding
// whitespace from each entry, so "a, b" behaves like "a,b".
func filterSplitList(value string) []string {
	parts := strings.Split(value, ",")
	for i, p := range parts {
		parts[i] = strings.TrimSpace(p)
	}
	return parts
}

// filterClientAddr returns the client address used by the location and IP
// rules: the first X-Forwarded-For entry when that header is set, otherwise
// req.RemoteAddr, with any port removed.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
// in front of this service overwrites it, so the IP allow/deny rules can be
// bypassed by spoofing the header. Consider honouring it only from trusted
// proxies.
func filterClientAddr(req *http.Request) string {
	addr := req.RemoteAddr
	if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
		addr = strings.Split(xff, ",")[0]
	}
	return filterStripPort(addr)
}

// filterStripPort removes an optional ":port" suffix from addr. It handles
// "host:port", "[ipv6]:port" and bare IPv6 literals (which contain several
// colons and are returned unchanged).
func filterStripPort(addr string) string {
	addr = strings.TrimSpace(addr)
	if strings.HasPrefix(addr, "[") {
		if end := strings.Index(addr, "]"); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	if strings.Count(addr, ":") > 1 {
		return addr
	}
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		return addr[:i]
	}
	return addr
}

// filterInTimeWindow reports whether now's wall-clock time (in now's own
// location) falls inside window, written as "HH:MM-HH:MM". The start minute
// is included and the end minute excluded. A window whose start is later
// than its end (e.g. "22:00-06:00") is taken to span midnight.
func filterInTimeWindow(window string, now time.Time) (bool, error) {
	bounds := strings.SplitN(window, "-", 2)
	if len(bounds) != 2 {
		return false, fmt.Errorf("want HH:MM-HH:MM, got %q", window)
	}
	start, err := time.Parse("15:04", strings.TrimSpace(bounds[0]))
	if err != nil {
		return false, err
	}
	end, err := time.Parse("15:04", strings.TrimSpace(bounds[1]))
	if err != nil {
		return false, err
	}
	// Compare minutes since midnight so the date parts never interfere.
	from := start.Hour()*60 + start.Minute()
	to := end.Hour()*60 + end.Minute()
	cur := now.Hour()*60 + now.Minute()
	if from <= to {
		return from <= cur && cur < to, nil
	}
	return cur >= from || cur < to, nil
}

// filterIPListContains reports whether address appears in the comma-separated
// IP list. Entries containing "/" or "-" are expanded with IPIntoSlices;
// presumably these are CIDR blocks and ranges, so verify against
// IPIntoSlices. Every entry is expanded, so a malformed entry is reported even
// if address matched an earlier one.
func filterIPListContains(list, address string) (bool, error) {
	found := false
	for _, entry := range filterSplitList(list) {
		candidates := []string{entry}
		if strings.Contains(entry, "/") || strings.Contains(entry, "-") {
			ips, err := IPIntoSlices(entry)
			if err != nil {
				return false, err
			}
			candidates = ips
		}
		if _, ok := _find(candidates, address); ok {
			found = true
		}
	}
	return found, nil
}

// filterCheckBodyLength enforces allow-body-min-length and
// allow-body-max-length. A bound <= 0 is treated as "not configured". When
// either bound is set, the body is read fully and req.Body is replaced with an
// identical in-memory copy, because reading drains the original stream.
func filterCheckBodyLength(req *http.Request, minLen, maxLen int) error {
	if minLen <= 0 && maxLen <= 0 {
		return nil
	}
	size := 0
	if req.Body != nil {
		body, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return fmt.Errorf("parse request body length is error: %w", err)
		}
		size = len(body)
		if size == 0 {
			req.Body = http.NoBody
		} else {
			req.Body = ioutil.NopCloser(strings.NewReader(string(body)))
		}
	}
	if (minLen > 0 && size < minLen) || (maxLen > 0 && size > maxLen) {
		return fmt.Errorf("allow-body-length rules trigger")
	}
	return nil
}
// _find searches slice for the first element equal to val. It returns that
// element's index and true on a hit, or -1 and false when val is absent
// (including when slice is empty or nil).
func _find(slice []string, val string) (int, bool) {
	pos := -1
	for j := 0; j < len(slice) && pos < 0; j++ {
		if slice[j] == val {
			pos = j
		}
	}
	return pos, pos >= 0
}