Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stream http request instead of consuming all #174

Merged
merged 6 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 40 additions & 21 deletions pkg/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,44 +73,63 @@ func NewFileFromReader(r io.Reader) (File, error) {
return nil, errors.New("invalid file reader")
}

dummy := &dummyFile{}

decoder := json.NewDecoder(r)
jsonDecodeErr := decoder.Decode(dummy)
// Take a peek and see if we encounter '{' (would imply the contents is JSON)
preview := make([]byte, 1024)
n, err := io.ReadFull(r, preview)
switch {
case err == io.ErrUnexpectedEOF:
preview = preview[:n]
r = bytes.NewReader(preview)
case err != nil:
return nil, err
default:
r = io.MultiReader(bytes.NewReader(preview), r)
}

// reset file reader
// need to read first block to detect json or metro format
// after that, need to reset seek point of reader
if sk, ok := r.(io.Seeker); ok {
sk.Seek(0, io.SeekStart)
// Look for the start of JSON
var isJSON bool
for i := range preview {
if preview[i] == '{' {
isJSON = true
break
}
}

if jsonDecodeErr != nil {
// Parse metro file
// Decode contents as Metro2 formatting when it's not JSON
if !isJSON {
return NewReader(r).Read()
}

// Parse json file
if dummy.Header == nil {
return nil, errors.New("invalid json file")
// Determine the file format
var buf bytes.Buffer
r = io.TeeReader(r, &buf)

var dummy dummyFile
err = json.NewDecoder(r).Decode(&dummy)
if err != nil {
return nil, fmt.Errorf("reading header: %w", err)
}

fileFormat := utils.CharacterFileFormat
if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength {
fileFormat = utils.CharacterFileFormat
} else if dummy.Header.BlockDescriptorWord > 0 {
fileFormat = utils.PackedFileFormat
if dummy.Header != nil {
if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength {
fileFormat = utils.CharacterFileFormat
} else if dummy.Header.BlockDescriptorWord > 0 {
fileFormat = utils.PackedFileFormat
}
}

// Decode the file as JSON now
f, err := NewFile(fileFormat)
if err != nil {
return nil, err
}

if err = decoder.Decode(f); err != nil {
return nil, err
r = io.MultiReader(&buf, r)
err = json.NewDecoder(r).Decode(f)
if err != nil {
return f, fmt.Errorf("reading file: %w", err)
}

return f, nil
}

Expand Down
116 changes: 116 additions & 0 deletions pkg/server/large_requests_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package server

import (
"bytes"
"compress/gzip"
"context"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestServer__LargeRequests(t *testing.T) {
timeout, _ := time.ParseDuration("30s")
handler, _ := ConfigureHandlers()

svr := &http.Server{
Addr: "0.0.0.0:15551",
Handler: handler,
ReadTimeout: timeout,
ReadHeaderTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
}
go svr.ListenAndServe()
t.Cleanup(func() {
svr.Shutdown(context.Background())
})

// Read each file
files := []string{
filepath.Join("testdata", "10k_record.json.gz"),
}
if os.Getenv("GITHUB_ACTIONS") == "" {
// The 25k file is having issues on Github Actions right now
files = append(files, filepath.Join("testdata", "25k_record.json.gz"))
}
if false {
files = append(files, filepath.Join("testdata", "50k_record.json.gz")) // Currently fails on CI and locally
}

for i := range files {
path := files[i]

t.Run("validate "+path, func(t *testing.T) {
var body bytes.Buffer
w := multipart.NewWriter(&body)
part, err := w.CreateFormFile("file", filepath.Base(path))
require.NoError(t, err)
_, err = io.Copy(part, open(t, path))
require.NoError(t, err)
require.NoError(t, w.Close()) // flush

req, err := http.NewRequest("POST", "http://localhost:15551/validator", &body)
require.NoError(t, err)
req.Header.Set("Content-Type", w.FormDataContentType())

resp, err := http.DefaultClient.Do(req)
if resp != nil && resp.StatusCode != http.StatusOK {
if resp != nil && resp.Body != nil {
t.Cleanup(func() { resp.Body.Close() })

bs, _ := io.ReadAll(resp.Body)
t.Logf("Response: %v", string(bs))
}
require.Equal(t, http.StatusOK, resp.StatusCode)
}
require.NoError(t, err)
})

t.Run("convert "+path, func(t *testing.T) {
var body bytes.Buffer
w := multipart.NewWriter(&body)
part, err := w.CreateFormFile("file", filepath.Base(path))
require.NoError(t, err)
_, err = io.Copy(part, open(t, path))
require.NoError(t, err)
w.WriteField("format", "metro")
require.NoError(t, w.Close()) // flush

req, err := http.NewRequest("POST", "http://localhost:15551/convert", &body)
require.NoError(t, err)
req.Header.Set("Content-Type", w.FormDataContentType())

resp, err := http.DefaultClient.Do(req)
if resp != nil && resp.StatusCode != http.StatusOK {
if resp != nil && resp.Body != nil {
t.Cleanup(func() { resp.Body.Close() })

bs, _ := io.ReadAll(resp.Body)
t.Logf("Response: %v", string(bs))
}
require.Equal(t, http.StatusOK, resp.StatusCode)
}
require.NoError(t, err)
})
}
}

func open(t *testing.T, path string) io.Reader {
t.Helper()

fd, err := os.Open(path)
require.NoError(t, err)
t.Cleanup(func() { fd.Close() })

r, err := gzip.NewReader(fd)
require.NoError(t, err)

return r
}
43 changes: 24 additions & 19 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
package server

import (
"bytes"
"encoding/json"
"errors"
"io"
"fmt"
"net/http"
"strings"

Expand All @@ -18,29 +16,24 @@ import (
)

func parseInputFromRequest(r *http.Request) (file.File, error) {
src, _, err := r.FormFile("file")
if err != nil {

buf, err := io.ReadAll(r.Body)
contentType := strings.ToLower(r.Header.Get("Content-Type"))
if strings.HasPrefix(contentType, "multipart/") {
src, _, err := r.FormFile("file")
if err != nil {
return nil, errors.New("unable to read request body")
return nil, fmt.Errorf("reading multipart request: %w", err)
}
defer src.Close()

defer r.Body.Close()

mf, err := file.NewFileFromReader(bytes.NewReader(buf))
mf, err := file.NewFileFromReader(src)
if err != nil {
return nil, err
return nil, fmt.Errorf("parsing file as multipart: %w", err)
}

return mf, nil
}

defer src.Close()

mf, err := file.NewFileFromReader(src)
mf, err := file.NewFileFromReader(r.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("parsing request body as reader: %w", err)
}
return mf, nil
}
Expand Down Expand Up @@ -70,7 +63,7 @@ func messageToBuf(format string, metroFile file.File, newline bool) ([]byte, err
case utils.MessageMetroFormat:
output = []byte(metroFile.String(newline))
default:
return nil, errors.New("invalid format")
return nil, fmt.Errorf("invalid format: %v", format)
}
return output, err
}
Expand Down Expand Up @@ -99,7 +92,7 @@ func getFormat(r *http.Request) (string, error) {
format = utils.MessageJsonFormat
}
if format != utils.MessageMetroFormat && format != utils.MessageJsonFormat {
return format, errors.New("invalid format")
return format, fmt.Errorf("invalid format: %v", format)
}
return format, nil
}
Expand All @@ -123,6 +116,10 @@ func getIsNewLine(r *http.Request) bool {
// 400: Bad Request
// 501: Not Implemented
func validator(w http.ResponseWriter, r *http.Request) {
if r != nil && r.Body != nil {
defer r.Body.Close()
}

metroFile, err := parseInputFromRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
Expand All @@ -148,6 +145,10 @@ func validator(w http.ResponseWriter, r *http.Request) {
// 400: Bad Request
// 501: Not Implemented
func print(w http.ResponseWriter, r *http.Request) {
if r != nil && r.Body != nil {
defer r.Body.Close()
}

metroFile, err := parseInputFromRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
Expand Down Expand Up @@ -180,6 +181,10 @@ func print(w http.ResponseWriter, r *http.Request) {
// 400: Bad Request
// 501: Not Implemented
func convert(w http.ResponseWriter, r *http.Request) {
if r != nil && r.Body != nil {
defer r.Body.Close()
}

metroFile, err := parseInputFromRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
Expand Down
8 changes: 4 additions & 4 deletions pkg/server/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (t *ServerTest) TestWithInvalidForm(c *check.C) {
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
}

func (t *ServerTest) TestPrintWithInvalidData(c *check.C) {
func (t *ServerTest) TestPrintWithoutContentType(c *check.C) {
writer, body := t.getWriter("base_segment.json", c)
err := writer.WriteField("format", "json")
c.Assert(err, check.IsNil)
Expand All @@ -227,10 +227,10 @@ func (t *ServerTest) TestPrintWithInvalidData(c *check.C) {
recorder, request := t.makeRequest(http.MethodPost, "/print", body.String(), c)
request.Header.Set("Content-Type", writer.FormDataContentType())
t.testServer.ServeHTTP(recorder, request)
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
c.Assert(recorder.Code, check.Equals, http.StatusOK)
}

func (t *ServerTest) TestConvertWithInvalidData(c *check.C) {
func (t *ServerTest) TestConvertWithoutContentType(c *check.C) {
writer, body := t.getWriter("base_segment.json", c)
err := writer.WriteField("format", "json")
c.Assert(err, check.IsNil)
Expand All @@ -239,7 +239,7 @@ func (t *ServerTest) TestConvertWithInvalidData(c *check.C) {
recorder, request := t.makeRequest(http.MethodPost, "/convert", body.String(), c)
request.Header.Set("Content-Type", writer.FormDataContentType())
t.testServer.ServeHTTP(recorder, request)
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
c.Assert(recorder.Code, check.Equals, http.StatusOK)
}

func (t *ServerTest) TestConvertWithValidJsonRequest(c *check.C) {
Expand Down
Binary file added pkg/server/testdata/10k_record.json.gz
Binary file not shown.
Binary file added pkg/server/testdata/25k_record.json.gz
Binary file not shown.
Binary file added pkg/server/testdata/50k_record.json.gz
Binary file not shown.
Loading