diff --git a/internal/wrench/http_pkg.go b/internal/wrench/http_pkg.go
index 5c0e0f0..dc20556 100644
--- a/internal/wrench/http_pkg.go
+++ b/internal/wrench/http_pkg.go
@@ -1,6 +1,7 @@
 package wrench
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -355,6 +356,7 @@ func (b *Bot) httpPkgEnsureZigVersionCached(version, versionKind string) {
 var (
 	cachedResponsesMu sync.Mutex
 	cachedResponses   = map[string]error{}
+	fsParallelismLock sync.Mutex
 )
 
 func (b *Bot) httpPkgEnsureZigDownloadCached(version, versionKind, fname string) error {
@@ -406,15 +408,15 @@ func (b *Bot) httpPkgEnsureZigDownloadCached(version, versionKind, fname string)
 	logWriter := b.idWriter("zig")
 
 	cachedResponsesMu.Lock()
-	defer cachedResponsesMu.Unlock()
-
-	if cachedError, ok := cachedResponses[url]; ok {
+	cachedError, isCachedError := cachedResponses[url]
+	cachedResponsesMu.Unlock()
+	if isCachedError {
 		fmt.Fprintf(logWriter, "not fetching: %s (cached error %s)\n", url, cachedError)
 		return cachedError
 	}
 
 	fmt.Fprintf(logWriter, "fetch: %s > %s\n", url, filePath)
-	resp, err := http.Get(url)
+	resp, err := httpGet(url, 60*time.Second)
 	if err != nil {
 		return errors.Wrap(err, "Get")
 	}
@@ -423,13 +425,20 @@ func (b *Bot) httpPkgEnsureZigDownloadCached(version, versionKind, fname string)
 	if resp.StatusCode >= 400 && resp.StatusCode <= 500 {
 		// 404 not found, 403 forbidden, etc.
 		err := fmt.Errorf("bad response status: %s", resp.Status)
+		cachedResponsesMu.Lock()
 		cachedResponses[url] = err
+		cachedResponsesMu.Unlock()
 		return err
 	}
 	if resp.StatusCode != http.StatusOK {
 		return fmt.Errorf("bad response status: %s", resp.Status)
 	}
 
+	// Two goroutines may have fetched this file at the same time (assuming neither hit the cache,
+	// new file download) - which is fine - but we can't have them write to disk at the same time.
+	fsParallelismLock.Lock()
+	defer fsParallelismLock.Unlock()
+
 	if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
 		return errors.Wrap(err, "MkdirAll "+dirPath)
 	}
@@ -625,7 +634,7 @@ func (b *Bot) httpPkgZigIndexCached() ([]byte, error) {
 	}
 
 	// Fetch the latest upstream Zig index.json
-	resp, err := http.Get("https://ziglang.org/download/index.json")
+	resp, err := httpGet("https://ziglang.org/download/index.json", 60*time.Second)
 	if err != nil {
 		return nil, errors.Wrap(err, "fetching upstream https://ziglang.org/download/index.json")
 	}
@@ -637,7 +646,7 @@ func (b *Bot) httpPkgZigIndexCached() ([]byte, error) {
 
 	// Fetch the Mach index.json which contains Mach nominated versions, but is otherwise not as
 	// up-to-date as ziglang.org's version.
-	resp, err = http.Get("https://machengine.org/zig/index.json")
+	resp, err = httpGet("https://machengine.org/zig/index.json", 60*time.Second)
 	if err != nil {
 		return nil, errors.Wrap(err, "fetching mach https://machengine.org/zig/index.json")
 	}
@@ -703,3 +712,14 @@ func (b *Bot) httpPkgZigIndexCached() ([]byte, error) {
 
 	return httpPkgZigIndexCached, nil
 }
+
+// Like http.Get, but actually respects a timeout instead of leaking a goroutine to forever run.
+func httpGet(url string, timeout time.Duration) (*http.Response, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	if err != nil {
+		return nil, err
+	}
+	return http.DefaultClient.Do(req)
+}
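
For context on the locking change in `httpPkgEnsureZigDownloadCached`, here is a minimal standalone sketch of the shape the hunk moves to: the mutex guards only the map lookup and store, never the slow fetch itself, so concurrent downloads of different URLs no longer serialize on `cachedResponsesMu`, while two goroutines may still race past the check and fetch the same URL twice, which the added comment in the diff explicitly tolerates. This is illustration only, not wrench code; the names `fetchOnceish`, `cacheMu`, and `cache` are hypothetical.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

var (
	cacheMu sync.Mutex
	cache   = map[string]error{}
)

// fetchOnceish caches the error result of a slow fetch. The lock is held only while
// touching the map, not across the fetch, so duplicate concurrent fetches of the same
// URL are possible but harmless.
func fetchOnceish(url string) error {
	cacheMu.Lock()
	cached, ok := cache[url]
	cacheMu.Unlock()
	if ok {
		return cached
	}

	// Simulate the slow network fetch outside the lock.
	time.Sleep(10 * time.Millisecond)
	err := fmt.Errorf("bad response status: 404 Not Found")

	cacheMu.Lock()
	cache[url] = err
	cacheMu.Unlock()
	return err
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(fetchOnceish("https://example.invalid/zig.tar.xz"))
		}()
	}
	wg.Wait()
}
```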