From 00aa3a2d5ad1914288c43ec9a1a4083ed1adf1e5 Mon Sep 17 00:00:00 2001 From: nathannaveen <42319948+nathannaveen@users.noreply.github.com> Date: Sun, 8 Jan 2023 14:06:16 -0600 Subject: [PATCH 1/2] Included tests for internal/signalio/csv - Included tests for the functions `maybeWriteHeader()` and `marshalValue` in `internal/signalio/csv`. - Included additional comments for `maybeWriteHeader()`. Signed-off-by: nathannaveen <42319948+nathannaveen@users.noreply.github.com> --- internal/signalio/csv.go | 35 ++++++++++++-- internal/signalio/csv_test.go | 86 +++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 internal/signalio/csv_test.go diff --git a/internal/signalio/csv.go b/internal/signalio/csv.go index 4a5b3dd05..eac3f2879 100644 --- a/internal/signalio/csv.go +++ b/internal/signalio/csv.go @@ -50,13 +50,40 @@ func (w *csvWriter) WriteSignals(signals []signal.Set, extra ...Field) error { } func (w *csvWriter) maybeWriteHeader() error { - // Check headerWritten without the lock to avoid holding the lock if the - // header has already been written. + /* + The variable w.headerWritten is checked twice to avoid what is known as a "race condition". + A race condition can occur when two or more goroutines try to access a shared resource + (in this case, the csvWriter instance) concurrently, and the outcome of the program depends on + the interleaving of their execution. + + Imagine the following scenario: + + 1. Goroutine A reads the value of w.headerWritten as false. + 2. Goroutine B reads the value of w.headerWritten as false. + 3. Goroutine A acquires the mutex lock and sets w.headerWritten to true. + 4. Goroutine B acquires the mutex lock and sets w.headerWritten to true. + + If this happens, the header will be written twice, which is not the desired behavior. + By checking w.headerWritten twice, once before acquiring the mutex lock and once after acquiring the lock, + the function can ensure that only one goroutine enters the critical section and writes the header. + + Here's how the function works: + + 1. Goroutine A reads the value of w.headerWritten as false. + 2. Goroutine A acquires the mutex lock. + 3. Goroutine A re-checks the value of w.headerWritten and finds it to be false. + 4. Goroutine A sets w.headerWritten to true and writes the header. + 5. Goroutine A releases the mutex lock. + + If Goroutine B tries to enter the critical section at any point after step 2, + it will have to wait until Goroutine A releases the lock in step 5. Once the lock is released, + Goroutine B will re-check the value of w.headerWritten and find it to be true, + so it will not write the header again. + */ + if w.headerWritten { return nil } - // Grab the lock and re-check headerWritten just in case another goroutine - // entered the same critical section. Also prevent concurrent writes to w. w.mu.Lock() defer w.mu.Unlock() if w.headerWritten { diff --git a/internal/signalio/csv_test.go b/internal/signalio/csv_test.go new file mode 100644 index 000000000..f138f3e01 --- /dev/null +++ b/internal/signalio/csv_test.go @@ -0,0 +1,86 @@ +package signalio + +import ( + "encoding/csv" + "sync" + "testing" + "time" +) + +func TestMarshalValue(t *testing.T) { + tests := []struct { + value any + expected string + wantErr bool + }{ + {value: true, expected: "true", wantErr: false}, + {value: 1, expected: "1", wantErr: false}, + {value: int16(2), expected: "2", wantErr: false}, + {value: int32(3), expected: "3", wantErr: false}, + {value: int64(4), expected: "4", wantErr: false}, + {value: uint(5), expected: "5", wantErr: false}, + {value: uint16(6), expected: "6", wantErr: false}, + {value: uint32(7), expected: "7", wantErr: false}, + {value: uint64(8), expected: "8", wantErr: false}, + {value: byte(9), expected: "9", wantErr: false}, + {value: float32(10.1), expected: "10.1", wantErr: false}, + {value: 11.1, expected: "11.1", wantErr: false}, // float64 + {value: "test", expected: "test", wantErr: false}, + {value: time.Now(), expected: time.Now().Format(time.RFC3339), wantErr: false}, + {value: nil, expected: "", wantErr: false}, + {value: []int{1, 2, 3}, expected: "", wantErr: true}, + {value: map[string]string{"key": "value"}, expected: "", wantErr: true}, + {value: struct{}{}, expected: "", wantErr: true}, + } + for _, test := range tests { + res, err := marshalValue(test.value) + if (err != nil) != test.wantErr { + t.Errorf("Unexpected error for value %v: wantErr %v, got %v", test.value, test.wantErr, err) + } + if res != test.expected { + t.Errorf("Unexpected result for value %v: expected %v, got %v", test.value, test.expected, res) + } + } +} + +func Test_csvWriter_maybeWriteHeader(t *testing.T) { + type fields struct { + w *csv.Writer + header []string + headerWritten bool + } + tests := []struct { + name string + fields fields + }{ + { + name: "write header", + fields: fields{ + w: csv.NewWriter(nil), + header: []string{}, + headerWritten: false, + }, + }, + { + name: "header already written", + fields: fields{ + w: csv.NewWriter(nil), + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := &csvWriter{ + w: test.fields.w, + header: test.fields.header, + headerWritten: test.fields.headerWritten, + mu: sync.Mutex{}, + } + if err := w.maybeWriteHeader(); err != nil { // never want an error + t.Errorf("maybeWriteHeader() error = %v", err) + } + }) + } +} From 53a5ee148929d2fdb87340c4f58577e4dd2d4a3f Mon Sep 17 00:00:00 2001 From: nathannaveen <42319948+nathannaveen@users.noreply.github.com> Date: Mon, 9 Jan 2023 12:07:06 -0600 Subject: [PATCH 2/2] Included additional tests for csv - Inlcuded additional tests and refactored some variable names Signed-off-by: nathannaveen <42319948+nathannaveen@users.noreply.github.com> --- internal/signalio/csv.go | 60 +++++++-------- internal/signalio/csv_test.go | 135 +++++++++++++++++++++++++++++++++- 2 files changed, 164 insertions(+), 31 deletions(-) diff --git a/internal/signalio/csv.go b/internal/signalio/csv.go index eac3f2879..53bc347c5 100644 --- a/internal/signalio/csv.go +++ b/internal/signalio/csv.go @@ -33,83 +33,83 @@ type csvWriter struct { mu sync.Mutex } -func CSVWriter(w io.Writer, emptySets []signal.Set, extra ...string) Writer { +func CSVWriter(writer io.Writer, emptySets []signal.Set, extra ...string) Writer { return &csvWriter{ header: fieldsFromSignalSets(emptySets, extra), - w: csv.NewWriter(w), + w: csv.NewWriter(writer), } } // WriteSignals implements the Writer interface. -func (w *csvWriter) WriteSignals(signals []signal.Set, extra ...Field) error { +func (writer *csvWriter) WriteSignals(signals []signal.Set, extra ...Field) error { values, err := marshalToMap(signals, extra...) if err != nil { return err } - return w.writeRecord(values) + return writer.writeRecord(values) } -func (w *csvWriter) maybeWriteHeader() error { +func (writer *csvWriter) maybeWriteHeader() error { /* - The variable w.headerWritten is checked twice to avoid what is known as a "race condition". + The variable writer.headerWritten is checked twice to avoid what is known as a "race condition". A race condition can occur when two or more goroutines try to access a shared resource (in this case, the csvWriter instance) concurrently, and the outcome of the program depends on the interleaving of their execution. Imagine the following scenario: - 1. Goroutine A reads the value of w.headerWritten as false. - 2. Goroutine B reads the value of w.headerWritten as false. - 3. Goroutine A acquires the mutex lock and sets w.headerWritten to true. - 4. Goroutine B acquires the mutex lock and sets w.headerWritten to true. + 1. Goroutine A reads the value of writer.headerWritten as false. + 2. Goroutine B reads the value of writer.headerWritten as false. + 3. Goroutine A acquires the mutex lock and sets writer.headerWritten to true. + 4. Goroutine B acquires the mutex lock and sets writer.headerWritten to true. If this happens, the header will be written twice, which is not the desired behavior. - By checking w.headerWritten twice, once before acquiring the mutex lock and once after acquiring the lock, + By checking writer.headerWritten twice, once before acquiring the mutex lock and once after acquiring the lock, the function can ensure that only one goroutine enters the critical section and writes the header. Here's how the function works: - 1. Goroutine A reads the value of w.headerWritten as false. + 1. Goroutine A reads the value of writer.headerWritten as false. 2. Goroutine A acquires the mutex lock. - 3. Goroutine A re-checks the value of w.headerWritten and finds it to be false. - 4. Goroutine A sets w.headerWritten to true and writes the header. + 3. Goroutine A re-checks the value of writer.headerWritten and finds it to be false. + 4. Goroutine A sets writer.headerWritten to true and writes the header. 5. Goroutine A releases the mutex lock. If Goroutine B tries to enter the critical section at any point after step 2, it will have to wait until Goroutine A releases the lock in step 5. Once the lock is released, - Goroutine B will re-check the value of w.headerWritten and find it to be true, + Goroutine B will re-check the value of writer.headerWritten and find it to be true, so it will not write the header again. */ - if w.headerWritten { + if writer.headerWritten { return nil } - w.mu.Lock() - defer w.mu.Unlock() - if w.headerWritten { + writer.mu.Lock() + defer writer.mu.Unlock() + if writer.headerWritten { return nil } - w.headerWritten = true - return w.w.Write(w.header) + writer.headerWritten = true + return writer.w.Write(writer.header) } -func (w *csvWriter) writeRecord(values map[string]string) error { - if err := w.maybeWriteHeader(); err != nil { +func (writer *csvWriter) writeRecord(values map[string]string) error { + if err := writer.maybeWriteHeader(); err != nil { return err } var rec []string - for _, k := range w.header { + for _, k := range writer.header { rec = append(rec, values[k]) } // Grab the lock when we're ready to write the record to prevent - // concurrent writes to w. - w.mu.Lock() - defer w.mu.Unlock() - if err := w.w.Write(rec); err != nil { + // concurrent writes to writer. + writer.mu.Lock() + defer writer.mu.Unlock() + if err := writer.w.Write(rec); err != nil { return err } - w.w.Flush() - return w.w.Error() + writer.w.Flush() + return writer.w.Error() } func marshalValue(value any) (string, error) { diff --git a/internal/signalio/csv_test.go b/internal/signalio/csv_test.go index f138f3e01..5cd0fb66f 100644 --- a/internal/signalio/csv_test.go +++ b/internal/signalio/csv_test.go @@ -5,6 +5,8 @@ import ( "sync" "testing" "time" + + "github.com/ossf/criticality_score/internal/collector/signal" ) func TestMarshalValue(t *testing.T) { @@ -78,9 +80,140 @@ func Test_csvWriter_maybeWriteHeader(t *testing.T) { headerWritten: test.fields.headerWritten, mu: sync.Mutex{}, } - if err := w.maybeWriteHeader(); err != nil { // never want an error + if err := w.maybeWriteHeader(); err != nil { // never want an error with these test cases t.Errorf("maybeWriteHeader() error = %v", err) } }) } } + +func Test_csvWriter_writeRecord(t *testing.T) { + type fields struct { + w *csv.Writer + header []string + headerWritten bool + } + tests := []struct { //nolint:govet + name string + fields fields + values map[string]string + wantErr bool + }{ + { + name: "write record with regular error", + fields: fields{ + w: csv.NewWriter(&mockWriter{ + written: []byte{'a', 'b', 'c'}, + error: nil, + }), + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + wantErr: true, + }, + { + name: "write record with write error", + fields: fields{ + w: &csv.Writer{}, + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + wantErr: true, + }, + { + name: "write record with maybeWriteHeader error", + fields: fields{ + w: &csv.Writer{}, + header: []string{"a", "b", "c"}, + headerWritten: false, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &csvWriter{ + w: tt.fields.w, + header: tt.fields.header, + headerWritten: tt.fields.headerWritten, + mu: sync.Mutex{}, + } + if err := w.writeRecord(tt.values); (err != nil) != tt.wantErr { + t.Errorf("writeRecord() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +type mockWriter struct { //nolint:govet + written []byte + error error +} + +func (m *mockWriter) Write(p []byte) (n int, err error) { + return 0, m.error +} + +func Test_csvWriter_WriteSignals(t *testing.T) { + type args struct { + signals []signal.Set + extra []Field + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "write signals with marshal error", + args: args{ + signals: []signal.Set{ + &testSet{ + UpdatedCount: signal.Val(1), + }, + }, + extra: []Field{ + { + Key: "a", + Value: []int{1, 2, 3}, + }, + { + Key: "b", + Value: map[string]string{"key": "value"}, + }, + }, + }, + wantErr: true, + }, + { + name: "write signals with write error", + args: args{ + extra: []Field{ + { + Key: "a", + Value: "1", + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writer := CSVWriter(&mockWriter{}, []signal.Set{}, "a", "b") + + if err := writer.WriteSignals(tt.args.signals, tt.args.extra...); (err != nil) != tt.wantErr { + t.Errorf("WriteSignals() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +type testSet struct { //nolint:govet + UpdatedCount signal.Field[int] + Field string +} + +func (t testSet) Namespace() signal.Namespace { + return "test" +}