diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 00000000..48b62928 --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,25 @@ +name: golangci-lint +on: + push: + branches: + - master + pull_request: + +permissions: + contents: read + # Optional: allow read access to pull request. Use with `only-new-issues` option. + # pull-requests: read + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.62.2 diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..48cbd710 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,225 @@ +# This configuration file is not a recommendation. +# +# We intentionally use a limited set of linters. +# This configuration file is used with different version of golangci-lint to avoid regressions: +# the linters can change between version, +# their configuration may be not compatible or their reports can be different, +# and this can break some of our tests. +# Also, some linters are not relevant for the project (e.g. linters related to SQL). +# +# We have specific constraints, so we use a specific configuration. +# +# See the file `.golangci.reference.yml` to have a list of all available configuration options. + +linters: + disable-all: true + # This list of linters is not a recommendation (same thing for all this configuration file). + # We intentionally use a limited set of linters. + # See the comment on top of this file. + enable: + - bodyclose + - copyloopvar + - depguard + - dogsled + - dupl + - errcheck + - errorlint + - funlen + - gocheckcompilerdirectives + #- gochecknoinits + - goconst + - gocritic + - gocyclo + - godox + - gofmt + - goimports + #- mnd + - goprintffuncname + - gosec + - gosimple + - govet + - intrange + - ineffassign + - lll + - misspell + - nakedret + - noctx + - nolintlint + - revive + - staticcheck + - stylecheck + - testifylint + - unconvert + - unparam + - unused + - whitespace + +linters-settings: + depguard: + rules: + logger: + deny: + - pkg: "github.com/pkg/errors" + desc: Should be replaced by standard lib errors package. + - pkg: "github.com/instana/testify" + desc: It's a fork of github.com/stretchr/testify. + files: + # logrus is allowed to use only in logutils package. + - "!**/pkg/logutils/**.go" + dupl: + threshold: 200 + funlen: + lines: -1 # the number of lines (code + empty lines) is not a right metric and leads to code without empty line or one-liner. + statements: 9999999 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + disabled-checks: + - dupImport # https://github.com/go-critic/go-critic/issues/845 + - ifElseChain + - octalLiteral + - whyNoLint + gocyclo: + min-complexity: 9999999 + godox: + keywords: + - FIXME + gofmt: + rewrite-rules: + - pattern: "interface{}" + replacement: "any" + goimports: + local-prefixes: github.com/golangci/golangci-lint + mnd: + # don't include the "operation" and "assign" + checks: + - argument + - case + - condition + - return + ignored-numbers: + - "0" + - "1" + - "2" + - "3" + ignored-functions: + - strings.SplitN + govet: + settings: + printf: + funcs: + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf + enable: + - nilness + - shadow + errorlint: + asserts: false + lll: + line-length: 140 + misspell: + locale: US + ignore-words: + - "importas" # linter name + nolintlint: + allow-unused: false # report any unused nolint directives + require-explanation: true # require an explanation for nolint directives + require-specific: true # require nolint directives to be specific about which linter is being skipped + revive: + rules: + - name: indent-error-flow + - name: unexported-return + disabled: true + - name: unused-parameter + - name: unused-receiver + +issues: + exclude-rules: + - path: (.+)_test\.go + linters: + - dupl + - mnd + - lll + + # The logic of creating a linter is similar between linters, it's not duplication. + - path: pkg/golinters + linters: + - dupl + + # Deprecated configuration options. + - path: pkg/commands/run.go + linters: [staticcheck] + text: "SA1019: c.cfg.Run.ShowStats is deprecated: use Output.ShowStats instead." + + # Deprecated linter options. + - path: pkg/golinters/errcheck/errcheck.go + linters: [staticcheck] + text: "SA1019: errCfg.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/golinters/errcheck/errcheck.go + linters: [staticcheck] + text: "SA1019: errCfg.Ignore is deprecated: use ExcludeFunctions instead" + - path: pkg/golinters/govet/govet.go + linters: [staticcheck] + text: "SA1019: cfg.CheckShadowing is deprecated: the linter should be enabled inside Enable." + - path: pkg/golinters/godot/godot.go + linters: [staticcheck] + text: "SA1019: settings.CheckAll is deprecated: use Scope instead" + - path: pkg/golinters/gci/gci.go + linters: [staticcheck] + text: "SA1019: settings.LocalPrefixes is deprecated: use Sections instead." + - path: pkg/golinters/mnd/mnd.go + linters: [staticcheck] + text: "SA1019: settings.Settings is deprecated: use root level settings instead." + - path: pkg/golinters/mnd/mnd.go + linters: [staticcheck] + text: "SA1019: config.GoMndSettings is deprecated: use MndSettings." + + # Related to `run.go`, it cannot be removed. + - path: pkg/golinters/gofumpt/gofumpt.go + linters: [staticcheck] + text: "SA1019: settings.LangVersion is deprecated: use the global `run.go` instead." + - path: pkg/golinters/internal/staticcheck_common.go + linters: [staticcheck] + text: "SA1019: settings.GoVersion is deprecated: use the global `run.go` instead." + - path: pkg/lint/lintersdb/manager.go + linters: [staticcheck] + text: "SA1019: (.+).(GoVersion|LangVersion) is deprecated: use the global `run.go` instead." + + # Based on existing code, the modifications should be limited to make maintenance easier. + - path: pkg/golinters/unused/unused.go + linters: [gocritic] + text: "rangeValCopy: each iteration copies 160 bytes \\(consider pointers or indexing\\)" + + # Related to file sizes. + - path: pkg/goanalysis/runner_loadingpackage.go + linters: [gosec] + text: "G115: integer overflow conversion uintptr -> int" + + # Related to PID. + - path: test/bench/bench_test.go + linters: [gosec] + text: "G115: integer overflow conversion int -> int32" + + # Related to the result of computation but divided multiple times by 1024. + - path: test/bench/bench_test.go + linters: [gosec] + text: "G115: integer overflow conversion uint64 -> int" + + exclude-dirs: + - test/testdata_etc # test files + - internal/go # extracted from Go code + - internal/x # extracted from x/tools code + exclude-files: + - pkg/goanalysis/runner_checker.go # extracted from x/tools code + +run: + timeout: 5m diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..1a653cd9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "go.lintTool": "golangci-lint", + "go.lintFlags": [ + "--fast" + ] +} \ No newline at end of file diff --git a/api/apimodel.go b/api/apimodel.go index 38c3be67..2f7891b2 100644 --- a/api/apimodel.go +++ b/api/apimodel.go @@ -13,6 +13,26 @@ const ( RuleNotModified = "rules not modified" ) +const ( + TransportProtocolTCP = "tcp" + TransportProtocolWS = "ws" + TransportProtocolGRPC = "grpc" +) + +const ( + NodeTypeV2ray = "V2ray" + NodeTypeTrojan = "Trojan" + NodeTypeShadowsocks = "Shadowsocks" + NodeTypeShadowsocksPlugin = "Shadowsocks-Plugin" + NodeTypeVLess = "Vless" + NodeTypeVMESS = "Vmess" + NodeTypeDokodemo = "dokodemo-door" +) + +const ( + SecurityTypeTLS = "tls" +) + // Config API config type Config struct { APIHost string `mapstructure:"ApiHost"` @@ -58,7 +78,7 @@ type NodeInfo struct { ServiceName string Method string Header json.RawMessage - HttpHeaders map[string]*conf.StringList + HTTPHeaders map[string]*conf.StringList Headers map[string]string NameServerConfig []*conf.NameServerConfig EnableREALITY bool diff --git a/api/fac/fac.go b/api/fac/fac.go index 6637441a..1cec0c1a 100644 --- a/api/fac/fac.go +++ b/api/fac/fac.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "os" "reflect" "regexp" @@ -55,7 +56,7 @@ func New(apiConfig *api.Config) *APIClient { } else { client.SetTimeout(5 * time.Second) } - client.OnError(func(req *resty.Request, err error) { + client.OnError(func(_ *resty.Request, err error) { var v *resty.ResponseError if errors.As(err, &v) { // v.Response contains the last response from the server @@ -90,14 +91,14 @@ func New(apiConfig *api.Config) *APIClient { } // readLocalRuleList reads the local rule list file -func readLocalRuleList(path string) (LocalRuleList []api.DetectRule) { - LocalRuleList = make([]api.DetectRule, 0) +func readLocalRuleList(path string) (localRuleList []api.DetectRule) { + localRuleList = make([]api.DetectRule, 0) if path != "" { // open the file file, err := os.Open(path) defer func(file *os.File) { - err := file.Close() + err = file.Close() if err != nil { log.Printf("Error when closing file: %s", err) } @@ -105,26 +106,26 @@ func readLocalRuleList(path string) (LocalRuleList []api.DetectRule) { // handle errors while opening if err != nil { log.Printf("Error when opening file: %s", err) - return LocalRuleList + return localRuleList } fileScanner := bufio.NewScanner(file) // read line by line for fileScanner.Scan() { - LocalRuleList = append(LocalRuleList, api.DetectRule{ + localRuleList = append(localRuleList, api.DetectRule{ ID: -1, Pattern: regexp.MustCompile(fileScanner.Text()), }) } // handle first encountered error while reading if err := fileScanner.Err(); err != nil { - log.Fatalf("Error while reading file: %s", err) - return + log.Errorf("Error while reading file: %s", err) + return localRuleList } } - return LocalRuleList + return localRuleList } // Describe return a description of the client @@ -143,12 +144,12 @@ func (c *APIClient) assembleURL(path string) string { func (c *APIClient) parseResponse(res *resty.Response, path string, err error) (*Response, error) { if err != nil { - return nil, fmt.Errorf("request %s failed: %s", c.assembleURL(path), err) + return nil, fmt.Errorf("request %s failed: %w", c.assembleURL(path), err) } if res.StatusCode() > 400 { body := res.Body() - return nil, fmt.Errorf("request %s failed: %s, %v", c.assembleURL(path), string(body), err) + return nil, fmt.Errorf("request %s failed: %s, %w", c.assembleURL(path), string(body), err) } response := res.Result().(*Response) @@ -161,7 +162,7 @@ func (c *APIClient) parseResponse(res *resty.Response, path string, err error) ( // GetNodeInfo will pull NodeInfo Config from ssPanel func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { - path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) + path := fmt.Sprintf("/mod_mu/nodes/%s/info", c.NodeID) res, err := c.client.R(). SetResult(&Response{}). SetHeader("If-None-Match", c.eTags["node"]). @@ -183,8 +184,8 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { nodeInfoResponse := new(NodeInfoResponse) - if err := json.Unmarshal(response.Data, nodeInfoResponse); err != nil { - return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(nodeInfoResponse), err) + if err = json.Unmarshal(response.Data, nodeInfoResponse); err != nil { + return nil, fmt.Errorf("unmarshal %s failed: %w", reflect.TypeOf(nodeInfoResponse), err) } // determine ssPanel version, if disable custom config or version < 2021.11, then use old api @@ -215,20 +216,21 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { nodeInfo, err = c.ParseSSPanelNodeInfo(nodeInfoResponse) if err != nil { res, _ := json.Marshal(nodeInfoResponse) - return nil, fmt.Errorf("parse node info failed: %s, \nError: %s, \nPlease check the doc of custom_config for help: https://xrayr-project.github.io/XrayR-doc/dui-jie-sspanel/sspanel/sspanel_custom_config", string(res), err) + return nil, fmt.Errorf("parse node info failed: %s, \nError: %s, \nPlease check the doc of custom_config for help: "+ + "https://xrayr-project.github.io/XrayR-doc/dui-jie-sspanel/sspanel/sspanel_custom_config", string(res), err) } } if err != nil { res, _ := json.Marshal(nodeInfoResponse) - return nil, fmt.Errorf("parse node info failed: %s, \nError: %s", string(res), err) + return nil, fmt.Errorf("parse node info failed: %s, \nError: %w", string(res), err) } return nodeInfo, nil } // GetUserList will pull user form ssPanel -func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { +func (c *APIClient) GetUserList() (userList *[]api.UserInfo, err error) { path := "/mod_mu/users" res, err := c.client.R(). SetQueryParam("node_id", c.NodeID). @@ -252,10 +254,10 @@ func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { userListResponse := new([]UserResponse) - if err := json.Unmarshal(response.Data, userListResponse); err != nil { - return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(userListResponse), err) + if err = json.Unmarshal(response.Data, userListResponse); err != nil { + return nil, fmt.Errorf("unmarshal %s failed: %w", reflect.TypeOf(userListResponse), err) } - userList, err := c.ParseUserListResponse(userListResponse) + userList, err = c.ParseUserListResponse(userListResponse) if err != nil { res, _ := json.Marshal(userListResponse) return nil, fmt.Errorf("parse user list failed: %s", string(res)) @@ -267,7 +269,7 @@ func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { func (c *APIClient) ReportNodeStatus(nodeStatus *api.NodeStatus) (err error) { // Determine whether a status report is in need if compareVersion(c.version, "2023.2") == -1 { - path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) + path := fmt.Sprintf("/mod_mu/nodes/%s/info", c.NodeID) systemLoad := SystemLoad{ Uptime: strconv.FormatUint(nodeStatus.Uptime, 10), Load: fmt.Sprintf("%.2f %.2f %.2f", nodeStatus.CPU/100, nodeStatus.Mem/100, nodeStatus.Disk/100), @@ -319,7 +321,6 @@ func (c *APIClient) ReportNodeOnlineUsers(onlineUserList *[]api.OnlineUser) erro // ReportUserTraffic reports the user traffic func (c *APIClient) ReportUserTraffic(userTraffic *[]api.UserTraffic) error { - data := make([]UserTraffic, len(*userTraffic)) for i, traffic := range *userTraffic { data[i] = UserTraffic{ @@ -370,7 +371,7 @@ func (c *APIClient) GetNodeRule() (*[]api.DetectRule, error) { ruleListResponse := new([]RuleItem) if err := json.Unmarshal(response.Data, ruleListResponse); err != nil { - return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(ruleListResponse), err) + return nil, fmt.Errorf("unmarshal %s failed: %w", reflect.TypeOf(ruleListResponse), err) } for _, r := range *ruleListResponse { @@ -384,7 +385,6 @@ func (c *APIClient) GetNodeRule() (*[]api.DetectRule, error) { // ReportIllegal reports the user illegal behaviors func (c *APIClient) ReportIllegal(detectResultList *[]api.DetectResult) error { - data := make([]IllegalItem, len(*detectResultList)) for i, r := range *detectResultList { data[i] = IllegalItem{ @@ -412,29 +412,35 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( var enableTLS bool var path, host, transportProtocol, serviceName, HeaderType string var header json.RawMessage - var speedLimit uint64 = 0 + var speedLimit uint64 + if nodeInfoResponse.RawServerString == "" { return nil, fmt.Errorf("no server info in response") } - // nodeInfo.RawServerString = strings.ToLower(nodeInfo.RawServerString) serverConf := strings.Split(nodeInfoResponse.RawServerString, ";") parsedPort, err := strconv.ParseInt(serverConf[1], 10, 32) if err != nil { return nil, err } + if parsedPort < 0 || parsedPort > math.MaxUint32 { + return nil, fmt.Errorf("parsed port %d is out of range for uint32", parsedPort) + } port := uint32(parsedPort) parsedAlterID, err := strconv.ParseInt(serverConf[2], 10, 16) if err != nil { return nil, err } + if parsedAlterID < 0 || parsedAlterID > math.MaxUint16 { + return nil, fmt.Errorf("parsed alterID %d is out of range for uint16", parsedAlterID) + } alterID := uint16(parsedAlterID) // Compatible with more node types config for _, value := range serverConf[3:5] { switch value { - case "tls": + case api.SecurityTypeTLS: enableTLS = true default: if value != "" { @@ -463,6 +469,7 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( HeaderType = value } } + if c.SpeedLimit > 0 { speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { @@ -475,7 +482,7 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( } if err != nil { - return nil, fmt.Errorf("marshal Header Type %s into config failed: %s", header, err) + return nil, fmt.Errorf("marshal Header Type %s into config failed: %w", header, err) } // Create GeneralNodeInfo @@ -501,7 +508,7 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( // ParseSSNodeResponse parse the response for the given node info format func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var port uint32 = 0 - var speedLimit uint64 = 0 + var speedLimit uint64 var method string path := "/mod_mu/users" res, err := c.client.R(). @@ -518,7 +525,7 @@ func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*ap userListResponse := new([]UserResponse) if err := json.Unmarshal(response.Data, userListResponse); err != nil { - return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(userListResponse), err) + return nil, fmt.Errorf("unmarshal %s failed: %w", reflect.TypeOf(userListResponse), err) } // init server port @@ -548,15 +555,18 @@ func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*ap func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var enableTLS bool var path, host, transportProtocol string - var speedLimit uint64 = 0 + var speedLimit uint64 serverConf := strings.Split(nodeInfoResponse.RawServerString, ";") parsedPort, err := strconv.ParseInt(serverConf[1], 10, 32) if err != nil { return nil, err } + if parsedPort < 0 || parsedPort > math.MaxUint32 { + return nil, fmt.Errorf("parsed port %d is out of range for uint32", parsedPort) + } port := uint32(parsedPort) - port = port - 1 // Shadowsocks-Plugin requires two ports, one for ss the other for other stream protocol + port-- // Shadowsocks-Plugin requires two ports, one for ss the other for other stream protocol if port <= 0 { return nil, fmt.Errorf("Shadowsocks-Plugin listen port must bigger than 1") } @@ -566,9 +576,9 @@ func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse case "tls": enableTLS = true case "ws": - transportProtocol = "ws" + transportProtocol = api.TransportProtocolWS case "obfs": - transportProtocol = "tcp" + transportProtocol = api.TransportProtocolTCP } } @@ -614,7 +624,7 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) // 域名或IP;port=连接端口#偏移端口|host=xx // gz.aaa.com;port=443#12345|host=hk.aaa.com var p, host, outsidePort, insidePort, transportProtocol, serviceName string - var speedLimit uint64 = 0 + var speedLimit uint64 if nodeInfoResponse.RawServerString == "" { return nil, fmt.Errorf("no server info in response") @@ -639,11 +649,14 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) if err != nil { return nil, err } + if parsedPort < 0 || parsedPort > math.MaxUint32 { + return nil, fmt.Errorf("parsed port %d is out of range for uint32", parsedPort) + } port := uint32(parsedPort) serverConf := strings.Split(nodeInfoResponse.RawServerString, ";") extraServerConf := strings.Split(serverConf[1], "|") - transportProtocol = "tcp" + transportProtocol = api.TransportProtocolTCP serviceName = "" for _, item := range extraServerConf { conf := strings.Split(item, "=") @@ -654,7 +667,7 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) value := conf[1] switch key { case "grpc": - transportProtocol = "grpc" + transportProtocol = api.TransportProtocolGRPC case "servicename": serviceName = value } @@ -689,8 +702,8 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[] c.access.Unlock() }() - var deviceLimit, localDeviceLimit = 0, 0 - var speedLimit uint64 = 0 + var deviceLimit, localDeviceLimit int + var speedLimit uint64 var userList []api.UserInfo for _, user := range *userInfoResponse { if c.DeviceLimit > 0 { @@ -741,7 +754,7 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[] // Only available for SSPanel version >= 2021.11 func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var ( - speedLimit uint64 = 0 + speedLimit uint64 enableTLS, enableVless bool alterID uint16 = 0 transportProtocol string @@ -755,7 +768,7 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a nodeConfig := new(CustomConfig) err := json.Unmarshal(nodeInfoResponse.CustomConfig, nodeConfig) if err != nil { - return nil, fmt.Errorf("custom_config format error: %v", err) + return nil, fmt.Errorf("custom_config format error: %w", err) } if c.SpeedLimit > 0 { @@ -768,7 +781,9 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a if err != nil { return nil, err } - + if parsedPort < 0 || parsedPort > math.MaxUint32 { + return nil, fmt.Errorf("parsed port %d is out of range for uint32", parsedPort) + } port := uint32(parsedPort) switch c.NodeType { diff --git a/api/fac/fac_test.go b/api/fac/fac_test.go index 77de309c..11994685 100644 --- a/api/fac/fac_test.go +++ b/api/fac/fac_test.go @@ -105,7 +105,6 @@ func TestReportReportNodeOnlineUsers(t *testing.T) { IP: fmt.Sprintf("1.1.1.%d", i), } } - // client.Debug() err = client.ReportNodeOnlineUsers(&onlineUserList) if err != nil { t.Error(err) @@ -126,7 +125,6 @@ func TestReportReportUserTraffic(t *testing.T) { Download: 114514, } } - // client.Debug() err = client.ReportUserTraffic(&generalUserTraffic) if err != nil { t.Error(err) diff --git a/api/fac/model.go b/api/fac/model.go index 626e20c3..9e12d645 100644 --- a/api/fac/model.go +++ b/api/fac/model.go @@ -57,7 +57,7 @@ type Response struct { // PostData is the data structure of post data type PostData struct { - Data interface{} `json:"data"` + Data any `json:"data"` } // SystemLoad is the data structure of system load diff --git a/app/mydispatcher/default.go b/app/mydispatcher/default.go index 7a510706..58b54f43 100644 --- a/app/mydispatcher/default.go +++ b/app/mydispatcher/default.go @@ -4,7 +4,9 @@ package mydispatcher import ( "context" + errorsType "errors" "fmt" + "math" "strings" "sync" "time" @@ -32,6 +34,10 @@ import ( var errSniffingTimeout = newError("timeout on sniffing") +const ( + ProtocolFakeDNS = "fakedns" +) + type cachedReader struct { sync.Mutex reader *pipe.Reader @@ -47,7 +53,11 @@ func (r *cachedReader) Cache(b *buf.Buffer) { b.Clear() rawBytes := b.Extend(buf.Size) n := r.cache.Copy(rawBytes) - b.Resize(0, int32(n)) + if n > math.MaxInt32 { + b.Resize(0, math.MaxInt32) + } else { + b.Resize(0, int32(n)) //nolint:gosec //n is not user input + } r.Unlock() } @@ -103,15 +113,19 @@ type DefaultDispatcher struct { RuleManager *rule.Manager } -func init() { - common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { +func init() { //nolint:gochecknoinits // bypass + common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config any) (any, error) { d := new(DefaultDispatcher) - if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { - d.fdns = fdns - }) - return d.Init(config.(*Config), om, router, pm, sm, dc) - }); err != nil { + if err := core.RequireFeatures( + ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { + err := core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { + d.fdns = fdns + }) + if err != nil { + return err + } + return d.Init(config.(*Config), om, router, pm, sm, dc) + }); err != nil { return nil, err } return d, nil @@ -119,19 +133,25 @@ func init() { } // Init initializes DefaultDispatcher. -func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { +func (d *DefaultDispatcher) Init( + _ *Config, + om outbound.Manager, + router routing.Router, + pm policy.Manager, + sm stats.Manager, + dnsClient dns.Client) error { d.ohm = om d.router = router d.policy = pm d.stats = sm d.Limiter = limiter.New() d.RuleManager = rule.New() - d.dns = dns + d.dns = dnsClient return nil } // Type implements common.HasType. -func (*DefaultDispatcher) Type() interface{} { +func (*DefaultDispatcher) Type() any { return routing.DispatcherType() } @@ -145,17 +165,20 @@ func (*DefaultDispatcher) Close() error { return nil } -func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link, error) { +func (d *DefaultDispatcher) getLink( + ctx context.Context, + _ net.Network, + _ session.SniffingRequest) (inboundLink, outboundLink *transport.Link, err error) { opt := pipe.OptionsFromContext(ctx) uplinkReader, uplinkWriter := pipe.New(opt...) downlinkReader, downlinkWriter := pipe.New(opt...) - inboundLink := &transport.Link{ + inboundLink = &transport.Link{ Reader: downlinkReader, Writer: uplinkWriter, } - outboundLink := &transport.Link{ + outboundLink = &transport.Link{ Reader: uplinkReader, Writer: downlinkWriter, } @@ -166,15 +189,21 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn user = sessionInbound.User } - if user != nil && len(user.Email) > 0 { + if user != nil && user.Email != "" { // Speed Limit and Device Limit bucket, ok, reject := d.Limiter.GetUserBucket(sessionInbound.Tag, user.Email, sessionInbound.Source.Address.IP().String()) if reject { errors.LogWarning(ctx, "Devices reach the limit: ", user.Email) common.Close(outboundLink.Writer) common.Close(inboundLink.Writer) - common.Interrupt(outboundLink.Reader) - common.Interrupt(inboundLink.Reader) + err = common.Interrupt(outboundLink.Reader) + if err != nil { + return nil, nil, err + } + err = common.Interrupt(inboundLink.Reader) + if err != nil { + return nil, nil, err + } return nil, nil, newError("Devices reach the limit: ", user.Email) } if ok { @@ -206,10 +235,14 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn return inboundLink, outboundLink, nil } -func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { +func (d *DefaultDispatcher) shouldOverride( + ctx context.Context, + result SniffResult, + request session.SniffingRequest, + destination net.Destination) bool { domain := result.Domain() for _, d := range request.ExcludeForDomain { - if strings.ToLower(domain) == d { + if strings.EqualFold(domain, d) { return false } } @@ -221,7 +254,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) { return true } - if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && + if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == ProtocolFakeDNS && destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) { errors.LogInfo(ctx, "Using sniffer ", protocolString, " since the fake DNS missed") return true @@ -256,18 +289,18 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin } sniffingRequest := content.SniffingRequest - inbound, outbound, err := d.getLink(ctx, destination.Network, sniffingRequest) + inbound, outboundLink, err := d.getLink(ctx, destination.Network, sniffingRequest) if err != nil { return nil, err } if !sniffingRequest.Enabled { - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outboundLink, destination) } else { go func() { cReader := &cachedReader{ - reader: outbound.Reader.(*pipe.Reader), + reader: outboundLink.Reader.(*pipe.Reader), } - outbound.Reader = cReader + outboundLink.Reader = cReader result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) if err == nil { content.Protocol = result.Protocol() @@ -276,20 +309,20 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin domain := result.Domain() errors.LogInfo(ctx, "sniffed domain: ", domain) destination.Address = net.ParseAddress(domain) - if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + if sniffingRequest.RouteOnly && result.Protocol() != ProtocolFakeDNS { ob.RouteTarget = destination } else { ob.Target = destination } } - d.routedDispatch(ctx, outbound, destination) + d.routedDispatch(ctx, outboundLink, destination) }() } return inbound, nil } // DispatchLink implements routing.Dispatcher. -func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { +func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outboundLink *transport.Link) error { if !destination.IsValid() { return newError("Dispatcher: Invalid destination.") } @@ -308,13 +341,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De } sniffingRequest := content.SniffingRequest if !sniffingRequest.Enabled { - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outboundLink, destination) } else { go func() { cReader := &cachedReader{ - reader: outbound.Reader.(*pipe.Reader), + reader: outboundLink.Reader.(*pipe.Reader), } - outbound.Reader = cReader + outboundLink.Reader = cReader result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) if err == nil { content.Protocol = result.Protocol() @@ -323,13 +356,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De domain := result.Domain() errors.LogInfo(ctx, "sniffed domain: ", domain) destination.Address = net.ParseAddress(domain) - if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { + if sniffingRequest.RouteOnly && result.Protocol() != ProtocolFakeDNS { ob.RouteTarget = destination } else { ob.Target = destination } } - d.routedDispatch(ctx, outbound, destination) + d.routedDispatch(ctx, outboundLink, destination) }() } @@ -363,7 +396,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw cReader.Cache(payload) if !payload.IsEmpty() { result, err := sniffer.Sniff(ctx, payload.Bytes(), network) - if err != common.ErrNoClue { + if !errorsType.Is(err, common.ErrNoClue) { return result, err } } @@ -406,9 +439,11 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. if sessionInbound.User != nil { if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) { errors.LogError(ctx, fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())) - newError("destination is reject by rule") common.Close(link.Writer) - common.Interrupt(link.Reader) + err := common.Interrupt(link.Reader) + if err != nil { + errors.LogError(ctx, "interrupt link reader failed: ", err) + } return } } @@ -425,7 +460,10 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } else { errors.LogError(ctx, "non existing tag for platform initialized detour: ", forcedOutboundTag) common.Close(link.Writer) - common.Interrupt(link.Reader) + err := common.Interrupt(link.Reader) + if err != nil { + errors.LogError(ctx, "interrupt link reader failed: ", err) + } return } } else if d.router != nil { @@ -455,7 +493,10 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. if handler == nil { errors.LogInfo(ctx, "default outbound handler not exist") common.Close(link.Writer) - common.Interrupt(link.Reader) + err := common.Interrupt(link.Reader) + if err != nil { + errors.LogError(ctx, "interrupt link reader failed: ", err) + } return } diff --git a/app/mydispatcher/errors.generated.go b/app/mydispatcher/errors.generated.go index 516bc40f..761d0b31 100644 --- a/app/mydispatcher/errors.generated.go +++ b/app/mydispatcher/errors.generated.go @@ -2,6 +2,6 @@ package mydispatcher import "github.com/xtls/xray-core/common/errors" -func newError(values ...interface{}) *errors.Error { +func newError(values ...any) *errors.Error { return errors.New(values...) } diff --git a/app/mydispatcher/fakednssniffer.go b/app/mydispatcher/fakednssniffer.go index b2683220..2471e0e9 100644 --- a/app/mydispatcher/fakednssniffer.go +++ b/app/mydispatcher/fakednssniffer.go @@ -26,7 +26,7 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error) errNotInit := newError("FakeDNSEngine is not initialized, but such a sniffer is used").AtError() return protocolSnifferWithMetadata{}, errNotInit } - return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) { + return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, _ []byte) (SniffResult, error) { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] Target := ob.Target @@ -87,8 +87,10 @@ func (f DNSThenOthersSniffResult) Domain() string { return f.domainName } -func newFakeDNSThenOthers(ctx context.Context, fakeDNSSniffer protocolSnifferWithMetadata, others []protocolSnifferWithMetadata) ( - protocolSnifferWithMetadata, error) { // nolint: unparam +func newFakeDNSThenOthers( + ctx context.Context, + fakeDNSSniffer protocolSnifferWithMetadata, + others []protocolSnifferWithMetadata) (protocolSnifferWithMetadata, error) { //nolint: unparam // non failable // ctx may be used in the future _ = ctx return protocolSnifferWithMetadata{ diff --git a/app/mydispatcher/sniffer.go b/app/mydispatcher/sniffer.go index f808ce77..5b9fc586 100644 --- a/app/mydispatcher/sniffer.go +++ b/app/mydispatcher/sniffer.go @@ -2,6 +2,7 @@ package mydispatcher import ( "context" + "errors" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" @@ -34,11 +35,11 @@ type Sniffer struct { func NewSniffer(ctx context.Context) *Sniffer { ret := &Sniffer{ sniffer: []protocolSnifferWithMetadata{ - {func(c context.Context, b []byte) (SniffResult, error) { return http.SniffHTTP(b) }, false, net.Network_TCP}, - {func(c context.Context, b []byte) (SniffResult, error) { return tls.SniffTLS(b) }, false, net.Network_TCP}, - {func(c context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffBittorrent(b) }, false, net.Network_TCP}, - {func(c context.Context, b []byte) (SniffResult, error) { return quic.SniffQUIC(b) }, false, net.Network_UDP}, - {func(c context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffUTP(b) }, false, net.Network_UDP}, + {func(_ context.Context, b []byte) (SniffResult, error) { return http.SniffHTTP(b) }, false, net.Network_TCP}, + {func(_ context.Context, b []byte) (SniffResult, error) { return tls.SniffTLS(b) }, false, net.Network_TCP}, + {func(_ context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffBittorrent(b) }, false, net.Network_TCP}, + {func(_ context.Context, b []byte) (SniffResult, error) { return quic.SniffQUIC(b) }, false, net.Network_UDP}, + {func(_ context.Context, b []byte) (SniffResult, error) { return bittorrent.SniffUTP(b) }, false, net.Network_UDP}, }, } if sniffer, err := newFakeDNSSniffer(ctx); err == nil { @@ -62,7 +63,7 @@ func (s *Sniffer) Sniff(c context.Context, payload []byte, network net.Network) continue } result, err := s(c, payload) - if err == common.ErrNoClue { + if errors.Is(err, common.ErrNoClue) { pendingSniffer = append(pendingSniffer, si) continue } @@ -89,7 +90,7 @@ func (s *Sniffer) SniffMetadata(c context.Context) (SniffResult, error) { continue } result, err := s(c, nil) - if err == common.ErrNoClue { + if errors.Is(err, common.ErrNoClue) { pendingSniffer = append(pendingSniffer, si) continue } @@ -107,7 +108,7 @@ func (s *Sniffer) SniffMetadata(c context.Context) (SniffResult, error) { return nil, errUnknownContent } -func CompositeResult(domainResult SniffResult, protocolResult SniffResult) SniffResult { +func CompositeResult(domainResult, protocolResult SniffResult) SniffResult { return &compositeResult{domainResult: domainResult, protocolResult: protocolResult} } diff --git a/app/mydispatcher/stats.go b/app/mydispatcher/stats.go index 5296ba66..a896da6a 100644 --- a/app/mydispatcher/stats.go +++ b/app/mydispatcher/stats.go @@ -1,6 +1,7 @@ package mydispatcher import ( + "github.com/sirupsen/logrus" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/features/stats" @@ -21,5 +22,8 @@ func (w *SizeStatWriter) Close() error { } func (w *SizeStatWriter) Interrupt() { - common.Interrupt(w.Writer) + err := common.Interrupt(w.Writer) + if err != nil { + logrus.Error(newError("failed to interrupt writer").Base(err)) + } } diff --git a/cmd/root.go b/cmd/root.go index 3fb87c4b..cf04e6c4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -23,7 +23,7 @@ var ( cfgFile string rootCmd = &cobra.Command{ Use: "XrayR", - Run: func(cmd *cobra.Command, args []string) { + Run: func(_ *cobra.Command, _ []string) { if err := run(); err != nil { log.Fatal(err) } @@ -55,7 +55,6 @@ func getConfig() *viper.Viper { config.SetConfigName("config") config.SetConfigType("yml") config.AddConfigPath(".") - } if err := config.ReadInConfig(); err != nil { @@ -73,7 +72,7 @@ func run() error { config := getConfig() panelConfig := &panel.Config{} if err := config.Unmarshal(panelConfig); err != nil { - return fmt.Errorf("Parse config file %v failed: %s \n", cfgFile, err) + return fmt.Errorf("parse config file %v failed: %w", cfgFile, err) } if panelConfig.LogConfig.Level == "debug" { @@ -110,7 +109,7 @@ func run() error { runtime.GC() // Running backend osSignals := make(chan os.Signal, 1) - signal.Notify(osSignals, os.Interrupt, os.Kill, syscall.SIGTERM) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) <-osSignals return nil diff --git a/cmd/version.go b/cmd/version.go index 457eb361..e53e0b22 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -16,7 +16,7 @@ func init() { rootCmd.AddCommand(&cobra.Command{ Use: "version", Short: "Print current version of XrayR", - Run: func(cmd *cobra.Command, args []string) { + Run: func(_ *cobra.Command, _ []string) { showVersion() }, }) diff --git a/cmd/x25519.go b/cmd/x25519.go index 0a54bb27..ab003da4 100644 --- a/cmd/x25519.go +++ b/cmd/x25519.go @@ -15,7 +15,7 @@ var ( x25519Cmd = &cobra.Command{ Use: "x25519", Short: "Generate key pair for x25519 key exchange", - Run: func(cmd *cobra.Command, args []string) { + Run: func(_ *cobra.Command, _ []string) { if err := x25519(); err != nil { fmt.Println(err) } diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index f919adbb..d32bac5f 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -50,7 +50,11 @@ func New() *Limiter { } } -func (l *Limiter) AddInboundLimiter(tag string, nodeSpeedLimit uint64, userList *[]api.UserInfo, globalLimit *GlobalDeviceLimitConfig) error { +func (l *Limiter) AddInboundLimiter( + tag string, + nodeSpeedLimit uint64, + userList *[]api.UserInfo, + globalLimit *GlobalDeviceLimitConfig) error { inboundInfo := &InboundInfo{ Tag: tag, NodeSpeedLimit: nodeSpeedLimit, @@ -135,17 +139,17 @@ func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) { if value, ok := l.InboundInfo.Load(tag); ok { inboundInfo := value.(*InboundInfo) // Clear Speed Limiter bucket for users who are not online - inboundInfo.BucketHub.Range(func(key, value interface{}) bool { + inboundInfo.BucketHub.Range(func(key, _ any) bool { email := key.(string) if _, exists := inboundInfo.UserOnlineIP.Load(email); !exists { inboundInfo.BucketHub.Delete(email) } return true }) - inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool { + inboundInfo.UserOnlineIP.Range(func(key, value any) bool { email := key.(string) ipMap := value.(*sync.Map) - ipMap.Range(func(key, value interface{}) bool { + ipMap.Range(func(key, value any) bool { uid := value.(int) ip := key.(string) onlineUser = append(onlineUser, api.OnlineUser{UID: uid, IP: ip}) @@ -161,7 +165,7 @@ func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) { return &onlineUser, nil } -func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *rate.Limiter, SpeedLimit bool, Reject bool) { +func (l *Limiter) GetUserBucket(tag, email, ip string) (limiter *rate.Limiter, speedLimit, reject bool) { if value, ok := l.InboundInfo.Load(tag); ok { var ( userLimit uint64 = 0 @@ -187,7 +191,7 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r // If this is a new ip if _, ok := ipMap.LoadOrStore(ip, uid); !ok { counter := 0 - ipMap.Range(func(key, value interface{}) bool { + ipMap.Range(func(_, _ any) bool { counter++ return true }) @@ -208,25 +212,21 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r // Speed limit limit := determineRate(nodeLimit, userLimit) // Determine the speed limit rate if limit > 0 { - limiter := rate.NewLimiter(rate.Limit(limit), int(limit)) // Byte/s + limiter := rate.NewLimiter(rate.Limit(limit), int(limit)) //nolint:gosec // Limit is not from user input if v, ok := inboundInfo.BucketHub.LoadOrStore(email, limiter); ok { bucket := v.(*rate.Limiter) return bucket, true, false - } else { - return limiter, true, false } - } else { - return nil, false, false + return limiter, true, false } - } else { - errors.LogDebug(context.Background(), "Get Inbound Limiter information failed") return nil, false, false } + errors.LogDebug(context.Background(), "Get Inbound Limiter information failed") + return nil, false, false } // Global device limit func globalLimit(inboundInfo *InboundInfo, email string, uid int, ip string, deviceLimit int) bool { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(inboundInfo.GlobalLimit.config.Timeout)*time.Second) defer cancel() @@ -237,22 +237,22 @@ func globalLimit(inboundInfo *InboundInfo, email string, uid int, ip string, dev if err != nil { if _, ok := err.(*store.NotFound); ok { // If the email is a new device - go pushIP(inboundInfo, uniqueKey, &map[string]int{ip: uid}) + go pushIP(inboundInfo, uniqueKey, map[string]int{ip: uid}) } else { errors.LogErrorInner(context.Background(), err, "cache service") } return false } - ipMap := v.(*map[string]int) + ipMap := v.(map[string]int) // Reject device reach limit directly - if deviceLimit > 0 && len(*ipMap) > deviceLimit { + if deviceLimit > 0 && len(ipMap) > deviceLimit { return true } // If the ip is not in cache - if _, ok := (*ipMap)[ip]; !ok { - (*ipMap)[ip] = uid + if _, ok := (ipMap)[ip]; !ok { + (ipMap)[ip] = uid go pushIP(inboundInfo, uniqueKey, ipMap) } @@ -260,7 +260,7 @@ func globalLimit(inboundInfo *InboundInfo, email string, uid int, ip string, dev } // push the ip to cache -func pushIP(inboundInfo *InboundInfo, uniqueKey string, ipMap *map[string]int) { +func pushIP(inboundInfo *InboundInfo, uniqueKey string, ipMap map[string]int) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(inboundInfo.GlobalLimit.config.Timeout)*time.Second) defer cancel() @@ -276,16 +276,11 @@ func determineRate(nodeLimit, userLimit uint64) (limit uint64) { return nodeLimit } else if nodeLimit < userLimit { return userLimit - } else { - return 0 - } - } else { - if nodeLimit > userLimit { - return userLimit - } else if nodeLimit < userLimit { - return nodeLimit - } else { - return nodeLimit } + return 0 + } + if nodeLimit > userLimit { + return userLimit } + return nodeLimit } diff --git a/common/limiter/rate.go b/common/limiter/rate.go index e8911217..7e774627 100644 --- a/common/limiter/rate.go +++ b/common/limiter/rate.go @@ -2,7 +2,6 @@ package limiter import ( "context" - "io" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -12,10 +11,9 @@ import ( type Writer struct { writer buf.Writer limiter *rate.Limiter - w io.Writer } -func (l *Limiter) RateWriter(writer buf.Writer, limiter *rate.Limiter) buf.Writer { +func (*Limiter) RateWriter(writer buf.Writer, limiter *rate.Limiter) buf.Writer { return &Writer{ writer: writer, limiter: limiter, @@ -28,6 +26,9 @@ func (w *Writer) Close() error { func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error { ctx := context.Background() - w.limiter.WaitN(ctx, int(mb.Len())) + err := w.limiter.WaitN(ctx, int(mb.Len())) + if err != nil { + return err + } return w.writer.WriteMultiBuffer(mb) } diff --git a/common/mylego/model.go b/common/mylego/model.go index 270907bf..3085db85 100644 --- a/common/mylego/model.go +++ b/common/mylego/model.go @@ -11,6 +11,12 @@ type CertConfig struct { RejectUnknownSni bool `mapstructure:"RejectUnknownSni"` } +const ( + CertModeHTTP = "http" + CertModeDNS = "dns" + CertModeTLS = "tls" +) + type LegoCMD struct { C *CertConfig path string diff --git a/common/mylego/mylego.go b/common/mylego/mylego.go index c326d7ab..afd83688 100644 --- a/common/mylego/mylego.go +++ b/common/mylego/mylego.go @@ -13,7 +13,7 @@ var defaultPath string func New(certConf *CertConfig) (*LegoCMD, error) { // Set default path to configPath/cert - var p = "" + var p string configPath := os.Getenv("XRAY_LOCATION_CONFIG") if configPath != "" { p = configPath @@ -32,16 +32,16 @@ func New(certConf *CertConfig) (*LegoCMD, error) { return lego, nil } -func (l *LegoCMD) getPath() string { +func (l *LegoCMD) getPath() string { //nolint:unused // leave it here return l.path } -func (l *LegoCMD) getCertConfig() *CertConfig { +func (l *LegoCMD) getCertConfig() *CertConfig { //nolint:unused // leave it here return l.C } // DNSCert cert a domain using DNS API -func (l *LegoCMD) DNSCert() (CertPath string, KeyPath string, err error) { +func (l *LegoCMD) DNSCert() (certPath, keyPath string, err error) { defer func() (string, string, error) { // Handle any error if r := recover(); r != nil { @@ -55,8 +55,8 @@ func (l *LegoCMD) DNSCert() (CertPath string, KeyPath string, err error) { } return "", "", err } - return CertPath, KeyPath, nil - }() + return certPath, keyPath, nil + }() //nolint:errcheck // leave it here // Set Env for DNS configuration for key, value := range l.C.DNSEnv { @@ -64,24 +64,24 @@ func (l *LegoCMD) DNSCert() (CertPath string, KeyPath string, err error) { } // First check if the certificate exists - CertPath, KeyPath, err = checkCertFile(l.C.CertDomain) + certPath, keyPath, err = checkCertFile(l.C.CertDomain) if err == nil { - return CertPath, KeyPath, err + return certPath, keyPath, nil } err = l.Run() if err != nil { return "", "", err } - CertPath, KeyPath, err = checkCertFile(l.C.CertDomain) + certPath, keyPath, err = checkCertFile(l.C.CertDomain) if err != nil { return "", "", err } - return CertPath, KeyPath, nil + return certPath, keyPath, nil } // HTTPCert cert a domain using http methods -func (l *LegoCMD) HTTPCert() (CertPath string, KeyPath string, err error) { +func (l *LegoCMD) HTTPCert() (certPath, keyPath string, err error) { defer func() (string, string, error) { // Handle any error if r := recover(); r != nil { @@ -95,13 +95,13 @@ func (l *LegoCMD) HTTPCert() (CertPath string, KeyPath string, err error) { } return "", "", err } - return CertPath, KeyPath, nil - }() + return certPath, keyPath, nil + }() //nolint:errcheck // leave it here // First check if the certificate exists - CertPath, KeyPath, err = checkCertFile(l.C.CertDomain) + certPath, keyPath, err = checkCertFile(l.C.CertDomain) if err == nil { - return CertPath, KeyPath, err + return certPath, keyPath, nil } err = l.Run() @@ -109,16 +109,16 @@ func (l *LegoCMD) HTTPCert() (CertPath string, KeyPath string, err error) { return "", "", err } - CertPath, KeyPath, err = checkCertFile(l.C.CertDomain) + certPath, keyPath, err = checkCertFile(l.C.CertDomain) if err != nil { return "", "", err } - return CertPath, KeyPath, nil + return certPath, keyPath, nil } // RenewCert renew a domain cert -func (l *LegoCMD) RenewCert() (CertPath string, KeyPath string, ok bool, err error) { +func (l *LegoCMD) RenewCert() (certPath, keyPath string, ok bool, err error) { defer func() (string, string, bool, error) { // Handle any error if r := recover(); r != nil { @@ -132,15 +132,15 @@ func (l *LegoCMD) RenewCert() (CertPath string, KeyPath string, ok bool, err err } return "", "", false, err } - return CertPath, KeyPath, ok, nil - }() + return certPath, keyPath, ok, nil + }() //nolint:errcheck // leave it here ok, err = l.Renew() if err != nil { return } - CertPath, KeyPath, err = checkCertFile(l.C.CertDomain) + certPath, keyPath, err = checkCertFile(l.C.CertDomain) if err != nil { return } @@ -148,7 +148,7 @@ func (l *LegoCMD) RenewCert() (CertPath string, KeyPath string, ok bool, err err return } -func checkCertFile(domain string) (string, string, error) { +func checkCertFile(domain string) (absCertPath, absKeyPath string, err error) { keyPath := path.Join(defaultPath, "certificates", fmt.Sprintf("%s.key", sanitizedDomain(domain))) certPath := path.Join(defaultPath, "certificates", fmt.Sprintf("%s.crt", sanitizedDomain(domain))) if _, err := os.Stat(keyPath); os.IsNotExist(err) { @@ -157,7 +157,7 @@ func checkCertFile(domain string) (string, string, error) { if _, err := os.Stat(certPath); os.IsNotExist(err) { return "", "", fmt.Errorf("cert cert failed: %s", domain) } - absKeyPath, _ := filepath.Abs(keyPath) - absCertPath, _ := filepath.Abs(certPath) + absKeyPath, _ = filepath.Abs(keyPath) + absCertPath, _ = filepath.Abs(certPath) return absCertPath, absKeyPath, nil } diff --git a/common/mylego/renew_test.go b/common/mylego/renew_test.go index fc37fe67..c6a584cc 100644 --- a/common/mylego/renew_test.go +++ b/common/mylego/renew_test.go @@ -48,7 +48,6 @@ func Test_merge(t *testing.T) { } for _, test := range testCases { - test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() @@ -108,7 +107,6 @@ func Test_needRenewal(t *testing.T) { } for _, test := range testCases { - test := test t.Run(test.desc, func(t *testing.T) { actual := needRenewal(test.x509Cert, "foo.com", test.days) diff --git a/common/rule/rule.go b/common/rule/rule.go index 40348fb4..8a5b1852 100644 --- a/common/rule/rule.go +++ b/common/rule/rule.go @@ -49,14 +49,14 @@ func (r *Manager) GetDetectResult(tag string) (*[]api.DetectResult, error) { return &detectResult, nil } -func (r *Manager) Detect(tag string, destination string, email string) (reject bool) { +func (r *Manager) Detect(tag, destination, email string) (reject bool) { reject = false var hitRuleID = -1 // If we have some rule for this inbound if value, ok := r.InboundRule.Load(tag); ok { ruleList := value.([]api.DetectRule) for _, r := range ruleList { - if r.Pattern.Match([]byte(destination)) { + if r.Pattern.MatchString(destination) { hitRuleID = r.ID reject = true break diff --git a/common/serverstatus/serverstatus.go b/common/serverstatus/serverstatus.go index aa8f36bc..a3647606 100644 --- a/common/serverstatus/serverstatus.go +++ b/common/serverstatus/serverstatus.go @@ -2,6 +2,7 @@ package serverstatus import ( + "errors" "fmt" "github.com/shirou/gopsutil/v3/cpu" @@ -11,16 +12,15 @@ import ( ) // GetSystemInfo get the system info of a given periodic -func GetSystemInfo() (Cpu float64, Mem float64, Disk float64, Uptime uint64, err error) { - +func GetSystemInfo() (c, m, d float64, up uint64, err error) { errorString := "" cpuPercent, err := cpu.Percent(0, false) // Check if cpuPercent is empty if len(cpuPercent) > 0 && err == nil { - Cpu = cpuPercent[0] + c = cpuPercent[0] } else { - Cpu = 0 + c = 0 errorString += fmt.Sprintf("get cpu usage failed: %s ", err) } @@ -28,26 +28,26 @@ func GetSystemInfo() (Cpu float64, Mem float64, Disk float64, Uptime uint64, err if err != nil { errorString += fmt.Sprintf("get mem usage failed: %s ", err) } else { - Mem = memUsage.UsedPercent + m = memUsage.UsedPercent } diskUsage, err := disk.Usage("/") if err != nil { errorString += fmt.Sprintf("get disk usage failed: %s ", err) } else { - Disk = diskUsage.UsedPercent + d = diskUsage.UsedPercent } uptime, err := host.Uptime() if err != nil { errorString += fmt.Sprintf("get uptime failed: %s ", err) } else { - Uptime = uptime + up = uptime } if errorString != "" { - err = fmt.Errorf(errorString) + err = errors.New(errorString) } - return Cpu, Mem, Disk, Uptime, err + return c, m, d, up, err } diff --git a/panel/config.go b/panel/config.go index a9aa9643..f447db2e 100644 --- a/panel/config.go +++ b/panel/config.go @@ -7,7 +7,7 @@ import ( type Config struct { LogConfig *LogConfig `mapstructure:"Log"` - DnsConfigPath string `mapstructure:"DnsConfigPath"` + DNSConfigPath string `mapstructure:"DnsConfigPath"` InboundConfigPath string `mapstructure:"InboundConfigPath"` OutboundConfigPath string `mapstructure:"OutboundConfigPath"` RouteConfigPath string `mapstructure:"RouteConfigPath"` @@ -17,7 +17,7 @@ type Config struct { type NodesConfig struct { PanelType string `mapstructure:"PanelType"` - ApiConfig *api.Config `mapstructure:"ApiConfig"` + APIConfig *api.Config `mapstructure:"ApiConfig"` ControllerConfig *controller.Config `mapstructure:"ControllerConfig"` } diff --git a/panel/panel.go b/panel/panel.go index b8100393..56732bfb 100644 --- a/panel/panel.go +++ b/panel/panel.go @@ -35,7 +35,7 @@ func New(panelConfig *Config) *Panel { return p } -func (p *Panel) loadCore(panelConfig *Config) *core.Instance { +func (*Panel) loadCore(panelConfig *Config) *core.Instance { // Log Config coreLogConfig := &conf.LogConfig{} logConfig := getDefaultLogConfig() @@ -49,13 +49,13 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { coreLogConfig.ErrorLog = logConfig.ErrorPath // DNS config - coreDnsConfig := &conf.DNSConfig{} - if panelConfig.DnsConfigPath != "" { - if data, err := os.ReadFile(panelConfig.DnsConfigPath); err != nil { - log.Panicf("Failed to read DNS config file at: %s", panelConfig.DnsConfigPath) + coreDNSConfig := &conf.DNSConfig{} + if panelConfig.DNSConfigPath != "" { + if data, err := os.ReadFile(panelConfig.DNSConfigPath); err != nil { + log.Panicf("Failed to read DNS config file at: %s", panelConfig.DNSConfigPath) } else { - if err = json.Unmarshal(data, coreDnsConfig); err != nil { - log.Panicf("Failed to unmarshal DNS config: %s", panelConfig.DnsConfigPath) + if err = json.Unmarshal(data, coreDNSConfig); err != nil { + log.Panicf("Failed to unmarshal DNS config: %s", panelConfig.DNSConfigPath) } } } @@ -65,7 +65,7 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { // config.ControllerConfig.DNSConfig = coreDnsConfig // } - dnsConfig, err := coreDnsConfig.Build() + dnsConfig, err := coreDNSConfig.Build() if err != nil { log.Panicf("Failed to understand DNS config, Please check: https://xtls.github.io/config/dns.html for help: %s", err) } @@ -73,7 +73,8 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { // Routing config coreRouterConfig := &conf.RouterConfig{} if panelConfig.RouteConfigPath != "" { - if data, err := os.ReadFile(panelConfig.RouteConfigPath); err != nil { + var data []byte + if data, err = os.ReadFile(panelConfig.RouteConfigPath); err != nil { log.Panicf("Failed to read Routing config file at: %s", panelConfig.RouteConfigPath) } else { if err = json.Unmarshal(data, coreRouterConfig); err != nil { @@ -88,7 +89,8 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { // Custom Inbound config var coreCustomInboundConfig []conf.InboundDetourConfig if panelConfig.InboundConfigPath != "" { - if data, err := os.ReadFile(panelConfig.InboundConfigPath); err != nil { + var data []byte + if data, err = os.ReadFile(panelConfig.InboundConfigPath); err != nil { log.Panicf("Failed to read Custom Inbound config file at: %s", panelConfig.OutboundConfigPath) } else { if err = json.Unmarshal(data, &coreCustomInboundConfig); err != nil { @@ -98,7 +100,8 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { } var inBoundConfig []*core.InboundHandlerConfig for _, config := range coreCustomInboundConfig { - oc, err := config.Build() + var oc *core.InboundHandlerConfig + oc, err = config.Build() if err != nil { log.Panicf("Failed to understand Inbound config, Please check: https://xtls.github.io/config/inbound.html for help: %s", err) } @@ -107,7 +110,8 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { // Custom Outbound config var coreCustomOutboundConfig []conf.OutboundDetourConfig if panelConfig.OutboundConfigPath != "" { - if data, err := os.ReadFile(panelConfig.OutboundConfigPath); err != nil { + var data []byte + if data, err = os.ReadFile(panelConfig.OutboundConfigPath); err != nil { log.Panicf("Failed to read Custom Outbound config file at: %s", panelConfig.OutboundConfigPath) } else { if err = json.Unmarshal(data, &coreCustomOutboundConfig); err != nil { @@ -117,7 +121,8 @@ func (p *Panel) loadCore(panelConfig *Config) *core.Instance { } var outBoundConfig []*core.OutboundHandlerConfig for _, config := range coreCustomOutboundConfig { - oc, err := config.Build() + var oc *core.OutboundHandlerConfig + oc, err = config.Build() if err != nil { log.Panicf("Failed to understand Outbound config, Please check: https://xtls.github.io/config/outbound.html for help: %s", err) } @@ -169,7 +174,7 @@ func (p *Panel) Start() { switch nodeConfig.PanelType { case "SSpanel": case "FAC": - apiClient = fac.New(nodeConfig.ApiConfig) + apiClient = fac.New(nodeConfig.APIConfig) default: log.Panicf("Unsupport panel type: %s", nodeConfig.PanelType) } @@ -183,7 +188,6 @@ func (p *Panel) Start() { } controllerService = controller.New(server, apiClient, controllerConfig, nodeConfig.PanelType) p.Service = append(p.Service, controllerService) - } // Start all the service diff --git a/service/controller/control.go b/service/controller/control.go index e06f4fe3..19f18b5c 100644 --- a/service/controller/control.go +++ b/service/controller/control.go @@ -32,7 +32,7 @@ func (c *Controller) addInbound(config *core.InboundHandlerConfig) error { } handler, ok := rawHandler.(inbound.Handler) if !ok { - return fmt.Errorf("not an InboundHandler: %s", err) + return fmt.Errorf("not an InboundHandler: %w", err) } if err := c.ibm.AddHandler(context.Background(), handler); err != nil { return err @@ -47,7 +47,7 @@ func (c *Controller) addOutbound(config *core.OutboundHandlerConfig) error { } handler, ok := rawHandler.(outbound.Handler) if !ok { - return fmt.Errorf("not an InboundHandler: %s", err) + return fmt.Errorf("not an InboundHandler: %w", err) } if err := c.obm.AddHandler(context.Background(), handler); err != nil { return err @@ -58,7 +58,7 @@ func (c *Controller) addOutbound(config *core.OutboundHandlerConfig) error { func (c *Controller) addUsers(users []*protocol.User, tag string) error { handler, err := c.ibm.GetHandler(context.Background(), tag) if err != nil { - return fmt.Errorf("no such inbound tag: %s", err) + return fmt.Errorf("no such inbound tag: %w", err) } inboundInstance, ok := handler.(proxy.GetInbound) if !ok { @@ -85,7 +85,7 @@ func (c *Controller) addUsers(users []*protocol.User, tag string) error { func (c *Controller) removeUsers(users []string, tag string) error { handler, err := c.ibm.GetHandler(context.Background(), tag) if err != nil { - return fmt.Errorf("no such inbound tag: %s", err) + return fmt.Errorf("no such inbound tag: %w", err) } inboundInstance, ok := handler.(proxy.GetInbound) if !ok { @@ -94,7 +94,7 @@ func (c *Controller) removeUsers(users []string, tag string) error { userManager, ok := inboundInstance.GetInbound().(proxy.UserManager) if !ok { - return fmt.Errorf("handler %s is not implement proxy.UserManager", err) + return fmt.Errorf("handler is not implement proxy.UserManager %w", err) } for _, email := range users { err = userManager.RemoveUser(context.Background(), email) @@ -105,7 +105,7 @@ func (c *Controller) removeUsers(users []string, tag string) error { return nil } -func (c *Controller) getTraffic(email string) (up int64, down int64, upCounter stats.Counter, downCounter stats.Counter) { +func (c *Controller) getTraffic(email string) (up, down int64, upCounter, downCounter stats.Counter) { upName := "user>>>" + email + ">>>traffic>>>uplink" downName := "user>>>" + email + ">>>traffic>>>downlink" upCounter = c.stm.GetCounter(upName) @@ -123,7 +123,7 @@ func (c *Controller) getTraffic(email string) (up int64, down int64, upCounter s return up, down, upCounter, downCounter } -func (c *Controller) resetTraffic(upCounterList *[]stats.Counter, downCounterList *[]stats.Counter) { +func (*Controller) resetTraffic(upCounterList, downCounterList *[]stats.Counter) { for _, upCounter := range *upCounterList { upCounter.Set(0) } @@ -132,7 +132,11 @@ func (c *Controller) resetTraffic(upCounterList *[]stats.Counter, downCounterLis } } -func (c *Controller) AddInboundLimiter(tag string, nodeSpeedLimit uint64, userList *[]api.UserInfo, globalDeviceLimitConfig *limiter.GlobalDeviceLimitConfig) error { +func (c *Controller) AddInboundLimiter( + tag string, + nodeSpeedLimit uint64, + userList *[]api.UserInfo, + globalDeviceLimitConfig *limiter.GlobalDeviceLimitConfig) error { err := c.dispatcher.Limiter.AddInboundLimiter(tag, nodeSpeedLimit, userList, globalDeviceLimitConfig) return err } diff --git a/service/controller/controller.go b/service/controller/controller.go index de049468..f7f36025 100644 --- a/service/controller/controller.go +++ b/service/controller/controller.go @@ -53,16 +53,16 @@ type periodicTask struct { } // New return a Controller service with default parameters. -func New(server *core.Instance, api api.API, config *Config, panelType string) *Controller { +func New(server *core.Instance, apiClient api.API, config *Config, panelType string) *Controller { logger := log.NewEntry(log.StandardLogger()).WithFields(log.Fields{ - "Host": api.Describe().APIHost, - "Type": api.Describe().NodeType, - "ID": api.Describe().NodeID, + "Host": apiClient.Describe().APIHost, + "Type": apiClient.Describe().NodeType, + "ID": apiClient.Describe().NodeID, }) controller := &Controller{ server: server, config: config, - apiClient: api, + apiClient: apiClient, panelType: panelType, ibm: server.GetFeature(inbound.ManagerType()).(inbound.Manager), obm: server.GetFeature(outbound.ManagerType()).(outbound.Manager), @@ -151,7 +151,7 @@ func (c *Controller) Start() error { ) // Check cert service in need - if c.nodeInfo.EnableTLS && c.config.EnableREALITY == false { + if c.nodeInfo.EnableTLS && !c.config.EnableREALITY { c.tasks = append(c.tasks, periodicTask{ tag: "cert monitor", Periodic: &task.Periodic{ @@ -163,7 +163,11 @@ func (c *Controller) Start() error { // Start periodic tasks for i := range c.tasks { c.logger.Printf("Start %s periodic task", c.tasks[i].tag) - go c.tasks[i].Start() + go func(task periodicTask) { + if err := task.Start(); err != nil { + c.logger.Printf("Error starting %s periodic task: %v", task.tag, err) + } + }(c.tasks[i]) } return nil @@ -223,12 +227,12 @@ func (c *Controller) nodeInfoMonitor() (err error) { if !reflect.DeepEqual(c.nodeInfo, newNodeInfo) { // Remove old tag oldTag := c.Tag - err := c.removeOldTag(oldTag) + err = c.removeOldTag(oldTag) if err != nil { c.logger.Print(err) return nil } - if c.nodeInfo.NodeType == "Shadowsocks-Plugin" { + if c.nodeInfo.NodeType == api.NodeTypeShadowsocksPlugin { err = c.removeOldTag(fmt.Sprintf("dokodemo-door_%s+1", c.Tag)) } if err != nil { @@ -256,12 +260,13 @@ func (c *Controller) nodeInfoMonitor() (err error) { // Check Rule if !c.config.DisableGetRule { - if ruleList, err := c.apiClient.GetNodeRule(); err != nil { + var ruleList *[]api.DetectRule + if ruleList, err = c.apiClient.GetNodeRule(); err != nil { if err.Error() != api.RuleNotModified { c.logger.Printf("Get rule list filed: %s", err) } } else if len(*ruleList) > 0 { - if err := c.UpdateRule(c.Tag, *ruleList); err != nil { + if err = c.UpdateRule(c.Tag, *ruleList); err != nil { c.logger.Print(err) } } @@ -275,11 +280,10 @@ func (c *Controller) nodeInfoMonitor() (err error) { } // Add Limiter - if err := c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, newUserInfo, c.config.GlobalDeviceLimitConfig); err != nil { + if err = c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, newUserInfo, c.config.GlobalDeviceLimitConfig); err != nil { c.logger.Print(err) return nil } - } else { var deleted, added []api.UserInfo if usersChanged { @@ -290,7 +294,7 @@ func (c *Controller) nodeInfoMonitor() (err error) { for i, u := range deleted { deletedEmail[i] = fmt.Sprintf("%s|%s|%d", c.Tag, u.Email, u.UID) } - err := c.removeUsers(deletedEmail, c.Tag) + err = c.removeUsers(deletedEmail, c.Tag) if err != nil { c.logger.Print(err) } @@ -325,37 +329,33 @@ func (c *Controller) removeOldTag(oldTag string) (err error) { } func (c *Controller) addNewTag(newNodeInfo *api.NodeInfo) (err error) { - if newNodeInfo.NodeType != "Shadowsocks-Plugin" { + if newNodeInfo.NodeType != api.NodeTypeShadowsocksPlugin { inboundConfig, err := InboundBuilder(c.config, newNodeInfo, c.Tag) if err != nil { return err } err = c.addInbound(inboundConfig) if err != nil { - return err } outBoundConfig, err := OutboundBuilder(c.config, newNodeInfo, c.Tag) if err != nil { - return err } err = c.addOutbound(outBoundConfig) if err != nil { - return err } - } else { return c.addInboundForSSPlugin(*newNodeInfo) } return nil } -func (c *Controller) addInboundForSSPlugin(newNodeInfo api.NodeInfo) (err error) { +func (c *Controller) addInboundForSSPlugin(newNodeInfo api.NodeInfo) (err error) { //nolint:gocritic // ignore // Shadowsocks-Plugin require a separate inbound for other TransportProtocol likes: ws, grpc fakeNodeInfo := newNodeInfo - fakeNodeInfo.TransportProtocol = "tcp" + fakeNodeInfo.TransportProtocol = api.TransportProtocolTCP fakeNodeInfo.EnableTLS = false // Add a regular Shadowsocks inbound and outbound inboundConfig, err := InboundBuilder(c.config, &fakeNodeInfo, c.Tag) @@ -364,23 +364,20 @@ func (c *Controller) addInboundForSSPlugin(newNodeInfo api.NodeInfo) (err error) } err = c.addInbound(inboundConfig) if err != nil { - return err } outBoundConfig, err := OutboundBuilder(c.config, &fakeNodeInfo, c.Tag) if err != nil { - return err } err = c.addOutbound(outBoundConfig) if err != nil { - return err } // Add an inbound for upper streaming protocol fakeNodeInfo = newNodeInfo fakeNodeInfo.Port++ - fakeNodeInfo.NodeType = "dokodemo-door" + fakeNodeInfo.NodeType = api.NodeTypeDokodemo dokodemoTag := fmt.Sprintf("dokodemo-door_%s+1", c.Tag) inboundConfig, err = InboundBuilder(c.config, &fakeNodeInfo, dokodemoTag) if err != nil { @@ -388,36 +385,33 @@ func (c *Controller) addInboundForSSPlugin(newNodeInfo api.NodeInfo) (err error) } err = c.addInbound(inboundConfig) if err != nil { - return err } outBoundConfig, err = OutboundBuilder(c.config, &fakeNodeInfo, dokodemoTag) if err != nil { - return err } err = c.addOutbound(outBoundConfig) if err != nil { - return err } return nil } func (c *Controller) addNewUser(userInfo *[]api.UserInfo, nodeInfo *api.NodeInfo) (err error) { - users := make([]*protocol.User, 0) + var users []*protocol.User switch nodeInfo.NodeType { - case "V2ray", "Vmess", "Vless": - if nodeInfo.EnableVless || (nodeInfo.NodeType == "Vless" && nodeInfo.NodeType != "Vmess") { + case api.NodeTypeV2ray, api.NodeTypeVMESS, api.NodeTypeVLess: + if nodeInfo.EnableVless || (nodeInfo.NodeType == api.NodeTypeVLess && nodeInfo.NodeType != api.NodeTypeVMESS) { users = c.buildVlessUser(userInfo) } else { users = c.buildVmessUser(userInfo) } - case "Trojan": + case api.NodeTypeTrojan: users = c.buildTrojanUser(userInfo) - case "Shadowsocks": + case api.NodeTypeShadowsocks: users = c.buildSSUser(userInfo, nodeInfo.CypherMethod) - case "Shadowsocks-Plugin": + case api.NodeTypeShadowsocksPlugin: users = c.buildSSPluginUser(userInfo) default: return fmt.Errorf("unsupported node type: %s", nodeInfo.NodeType) @@ -431,7 +425,7 @@ func (c *Controller) addNewUser(userInfo *[]api.UserInfo, nodeInfo *api.NodeInfo return nil } -func compareUserList(old, new *[]api.UserInfo) (deleted, added []api.UserInfo) { +func compareUserList(old, new *[]api.UserInfo) (deleted, added []api.UserInfo) { //nolint:gocritic // ignore mSrc := make(map[api.UserInfo]byte) // 按源数组建索引 mAll := make(map[api.UserInfo]byte) // 源+目所有元素建索引 @@ -447,7 +441,7 @@ func compareUserList(old, new *[]api.UserInfo) (deleted, added []api.UserInfo) { l := len(mAll) mAll[v] = 1 if l != len(mAll) { // 长度变化,即可以存 - l = len(mAll) + l = len(mAll) //nolint: ineffassign,staticcheck // ignore } else { // 存不了,进并集 set = append(set, v) } @@ -469,14 +463,17 @@ func compareUserList(old, new *[]api.UserInfo) (deleted, added []api.UserInfo) { return deleted, added } -func limitUser(c *Controller, user api.UserInfo, silentUsers *[]api.UserInfo) { +func limitUser(c *Controller, user api.UserInfo, silentUsers *[]api.UserInfo) { //nolint:gocritic // ignore c.limitedUsers[user] = LimitInfo{ end: time.Now().Unix() + int64(c.config.AutoSpeedLimitConfig.LimitDuration*60), currentSpeedLimit: c.config.AutoSpeedLimitConfig.LimitSpeed, originSpeedLimit: user.SpeedLimit, } - c.logger.Printf("Limit User: %s Speed: %d End: %s", c.buildUserTag(&user), c.config.AutoSpeedLimitConfig.LimitSpeed, time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05")) - user.SpeedLimit = uint64((c.config.AutoSpeedLimitConfig.LimitSpeed * 1000000) / 8) + c.logger.Printf("Limit User: %s Speed: %d End: %s", + c.buildUserTag(&user), + c.config.AutoSpeedLimitConfig.LimitSpeed, + time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05")) + user.SpeedLimit = uint64((c.config.AutoSpeedLimitConfig.LimitSpeed * 1000000) / 8) //nolint:gosec // ignore *silentUsers = append(*silentUsers, user) } @@ -512,7 +509,10 @@ func (c *Controller) userInfoMonitor() (err error) { c.logger.Printf("User: %s Speed: %d End: nil (Unlimit)", c.buildUserTag(&user), user.SpeedLimit) delete(c.limitedUsers, user) } else { - c.logger.Printf("User: %s Speed: %d End: %s", c.buildUserTag(&user), limitInfo.currentSpeedLimit, time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05")) + c.logger.Printf("User: %s Speed: %d End: %s", + c.buildUserTag(&user), + limitInfo.currentSpeedLimit, + time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05")) } } if len(toReleaseUsers) > 0 { @@ -604,7 +604,6 @@ func (c *Controller) userInfoMonitor() (err error) { } else { c.logger.Printf("Report %d illegal behaviors", len(*detectResult)) } - } return nil } @@ -619,9 +618,9 @@ func (c *Controller) buildNodeTag() string { // Check Cert func (c *Controller) certMonitor() error { - if c.nodeInfo.EnableTLS && c.config.EnableREALITY == false { + if c.nodeInfo.EnableTLS && !c.config.EnableREALITY { switch c.config.CertConfig.CertMode { - case "dns", "http", "tls": + case mylego.CertModeHTTP, mylego.CertModeDNS, mylego.CertModeTLS: lego, err := mylego.New(c.config.CertConfig) if err != nil { c.logger.Print(err) diff --git a/service/controller/controller_test.go b/service/controller/controller_test.go index 73621dea..57c4f6f8 100644 --- a/service/controller/controller_test.go +++ b/service/controller/controller_test.go @@ -31,14 +31,6 @@ func TestController(t *testing.T) { serverConfig.Policy = policyConfig config, _ := serverConfig.Build() - // config := &core.Config{ - // App: []*serial.TypedMessage{ - // serial.ToTypedMessage(&dispatcher.Config{}), - // serial.ToTypedMessage(&proxyman.InboundConfig{}), - // serial.ToTypedMessage(&proxyman.OutboundConfig{}), - // serial.ToTypedMessage(&stats.Config{}), - // }} - server, err := core.New(config) if err != nil { t.Errorf("failed to create instance: %s", err) @@ -76,7 +68,7 @@ func TestController(t *testing.T) { { osSignals := make(chan os.Signal, 1) - signal.Notify(osSignals, os.Interrupt, os.Kill, syscall.SIGTERM) + signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) <-osSignals } } diff --git a/service/controller/errors.generated.go b/service/controller/errors.generated.go index 18d55320..e7b0ca81 100644 --- a/service/controller/errors.generated.go +++ b/service/controller/errors.generated.go @@ -2,6 +2,6 @@ package controller import "github.com/xtls/xray-core/common/errors" -func newError(values ...interface{}) *errors.Error { +func newError(values ...any) *errors.Error { return errors.New(values...) } diff --git a/service/controller/inboundbuilder.go b/service/controller/inboundbuilder.go index c9f05879..1e4da400 100644 --- a/service/controller/inboundbuilder.go +++ b/service/controller/inboundbuilder.go @@ -80,7 +80,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I protocol = "vmess" proxySetting = &conf.VMessInboundConfig{} } - case "Trojan": + case api.NodeTypeTrojan: protocol = "trojan" // Enable fallback if config.EnableFallback { @@ -95,7 +95,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I } else { proxySetting = &conf.TrojanServerConfig{} } - case "Shadowsocks", "Shadowsocks-Plugin": + case api.NodeTypeShadowsocks, api.NodeTypeShadowsocksPlugin: protocol = "shadowsocks" cipher := strings.ToLower(nodeInfo.CypherMethod) @@ -108,7 +108,10 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I // shadowsocks must have a random password // shadowsocks2022's password == user PSK, thus should a length of string >= 32 and base64 encoder b := make([]byte, 32) - rand.Read(b) + _, err := rand.Read(b) + if err != nil { + return nil, fmt.Errorf("generate random password failed: %w", err) + } randPasswd := hex.EncodeToString(b) if C.Contains(shadowaead_2022.List, cipher) { proxySetting.Users = append(proxySetting.Users, &conf.ShadowsocksUserConfig{ @@ -124,7 +127,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I proxySetting.IVCheck = false } - case "dokodemo-door": + case api.NodeTypeDokodemo: protocol = "dokodemo-door" proxySetting = struct { Host string `json:"address"` @@ -139,7 +142,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I setting, err := json.Marshal(proxySetting) if err != nil { - return nil, fmt.Errorf("marshal proxy %s config failed: %s", nodeInfo.NodeType, err) + return nil, fmt.Errorf("marshal proxy %s config failed: %w", nodeInfo.NodeType, err) } inboundDetourConfig.Protocol = protocol inboundDetourConfig.Settings = &setting @@ -149,7 +152,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I transportProtocol := conf.TransportProtocol(nodeInfo.TransportProtocol) networkType, err := transportProtocol.Build() if err != nil { - return nil, fmt.Errorf("convert TransportProtocol failed: %s", err) + return nil, fmt.Errorf("convert TransportProtocol failed: %w", err) } switch networkType { @@ -175,7 +178,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I Host: &hosts, Path: nodeInfo.Path, Method: nodeInfo.Method, - Headers: nodeInfo.HttpHeaders, + Headers: nodeInfo.HTTPHeaders, } streamSetting.HTTPSettings = httpSettings case "grpc": @@ -270,7 +273,7 @@ func InboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core.I return inboundDetourConfig.Build() } -func getCertFile(certConfig *mylego.CertConfig) (certFile string, keyFile string, err error) { +func getCertFile(certConfig *mylego.CertConfig) (certFile, keyFile string, err error) { switch certConfig.CertMode { case "file": if certConfig.CertFile == "" || certConfig.KeyFile == "" { @@ -309,7 +312,6 @@ func buildVlessFallbacks(fallbackConfigs []*FallBackConfig) ([]*conf.VLessInboun vlessFallBacks := make([]*conf.VLessInboundFallback, len(fallbackConfigs)) for i, c := range fallbackConfigs { - if c.Dest == "" { return nil, fmt.Errorf("dest is required for fallback failed") } @@ -317,7 +319,7 @@ func buildVlessFallbacks(fallbackConfigs []*FallBackConfig) ([]*conf.VLessInboun var dest json.RawMessage dest, err := json.Marshal(c.Dest) if err != nil { - return nil, fmt.Errorf("marshal dest %s config failed: %s", dest, err) + return nil, fmt.Errorf("marshal dest %s config failed: %w", dest, err) } vlessFallBacks[i] = &conf.VLessInboundFallback{ Name: c.SNI, @@ -337,7 +339,6 @@ func buildTrojanFallbacks(fallbackConfigs []*FallBackConfig) ([]*conf.TrojanInbo trojanFallBacks := make([]*conf.TrojanInboundFallback, len(fallbackConfigs)) for i, c := range fallbackConfigs { - if c.Dest == "" { return nil, fmt.Errorf("dest is required for fallback failed") } @@ -345,7 +346,7 @@ func buildTrojanFallbacks(fallbackConfigs []*FallBackConfig) ([]*conf.TrojanInbo var dest json.RawMessage dest, err := json.Marshal(c.Dest) if err != nil { - return nil, fmt.Errorf("marshal dest %s config failed: %s", dest, err) + return nil, fmt.Errorf("marshal dest %s config failed: %w", dest, err) } trojanFallBacks[i] = &conf.TrojanInboundFallback{ Name: c.SNI, diff --git a/service/controller/outboundbuilder.go b/service/controller/outboundbuilder.go index 6aaad4b0..8423e1b6 100644 --- a/service/controller/outboundbuilder.go +++ b/service/controller/outboundbuilder.go @@ -38,7 +38,7 @@ func OutboundBuilder(config *Config, nodeInfo *api.NodeInfo, tag string) (*core. var setting json.RawMessage setting, err := json.Marshal(proxySetting) if err != nil { - return nil, fmt.Errorf("marshal proxy %s config failed: %s", nodeInfo.NodeType, err) + return nil, fmt.Errorf("marshal proxy %s config failed: %w", nodeInfo.NodeType, err) } outboundDetourConfig.Settings = &setting return outboundDetourConfig.Build() diff --git a/service/controller/userbuilder.go b/service/controller/userbuilder.go index 60b2dc64..375080b9 100644 --- a/service/controller/userbuilder.go +++ b/service/controller/userbuilder.go @@ -167,7 +167,7 @@ func (c *Controller) buildUserTag(user *api.UserInfo) string { return fmt.Sprintf("%s|%s|%d", c.Tag, user.Email, user.UID) } -func (c *Controller) checkShadowsocksPassword(password string, method string) (string, error) { +func (c *Controller) checkShadowsocksPassword(password, method string) (string, error) { if strings.Contains(c.panelType, "V2board") { var userKey string if len(password) < 16 { @@ -182,7 +182,6 @@ func (c *Controller) checkShadowsocksPassword(password string, method string) (s userKey = password[:32] } return base64.StdEncoding.EncodeToString([]byte(userKey)), nil - } else { - return password, nil } + return password, nil }