From 5f1a516e0b5c542affb8f2b6463a7a813de5809f Mon Sep 17 00:00:00 2001 From: tsosunchia <59512455+tsosunchia@users.noreply.github.com> Date: Sun, 8 Oct 2023 11:11:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pow_client.go | 20 +++++++++++++++----- pow_client_test.go | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pow_client.go b/pow_client.go index c155584..9089335 100644 --- a/pow_client.go +++ b/pow_client.go @@ -81,6 +81,9 @@ func RetToken(getTokenParams *GetTokenParams) (string, error) { } } if getTokenParams.Proxy != nil { + if client.Transport == nil { + client.Transport = &http.Transport{} + } client.Transport.(*http.Transport).Proxy = http.ProxyURL(getTokenParams.Proxy) } challengeParams := &ChallengeParams{ @@ -116,11 +119,7 @@ func requestChallenge(challengeParams *ChallengeParams) (*RequestResponse, error //req.Header.Add("Host", getTokenParams.Host) req.Host = challengeParams.Host resp, err := challengeParams.Client.Do(req) - if err != nil || resp.StatusCode != http.StatusOK { - // 如果http_code为429 - if resp.StatusCode == http.StatusTooManyRequests { - log.Fatalln("请求次数超限,请稍后再试") - } + if err != nil { return nil, err } defer func(Body io.ReadCloser) { @@ -129,6 +128,14 @@ func requestChallenge(challengeParams *ChallengeParams) (*RequestResponse, error fmt.Println(err) } }(resp.Body) + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusTooManyRequests { + log.Fatalln("请求次数超限,请稍后再试") + } + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + var challengeResponse RequestResponse err = json.NewDecoder(resp.Body).Decode(&challengeResponse) if err != nil { @@ -184,6 +191,9 @@ func submitAnswer(challengeParams *ChallengeParams, challengeResponse *RequestRe }(resp.Body) if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusTooManyRequests { + return "", errors.New("请求次数超限") + } bodyBytes, _ := io.ReadAll(resp.Body) return "", errors.New(string(bodyBytes)) } diff --git a/pow_client_test.go b/pow_client_test.go index bf3ce0d..1bc8da4 100644 --- a/pow_client_test.go +++ b/pow_client_test.go @@ -9,7 +9,7 @@ import ( ) func TestGetToken(t *testing.T) { - token, err := getToken("103.120.18.35", "api.leo.moe", "443") + token, err := getToken("api.leo.moe", "api.leo.moe", "443") fmt.Println(token, err) assert.NoError(t, err, "GetToken() returned an error") }