From 8b88f63ced31bd71b5d751383850a2d8e864e3bf Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 03:46:06 +0000 Subject: [PATCH] test(cardinal): improve component type safety and message testing Co-Authored-By: Scott Sunarto --- cardinal/testsuite/components.go | 10 ++++++++++ cardinal/testsuite/utils_test.go | 22 ++++++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/cardinal/testsuite/components.go b/cardinal/testsuite/components.go index 0f941de1f..d9c153a75 100644 --- a/cardinal/testsuite/components.go +++ b/cardinal/testsuite/components.go @@ -1,10 +1,16 @@ package testsuite +import ( + "pkg.world.dev/world-engine/cardinal/types" +) + // LocationComponent is a test component for location-based tests type LocationComponent struct { X, Y uint64 } +var _ types.Component = (*LocationComponent)(nil) + func (l LocationComponent) Name() string { return "location" } @@ -14,6 +20,8 @@ type ValueComponent struct { Value int64 } +var _ types.Component = (*ValueComponent)(nil) + func (v ValueComponent) Name() string { return "value" } @@ -23,6 +31,8 @@ type PowerComponent struct { Power int64 } +var _ types.Component = (*PowerComponent)(nil) + func (p PowerComponent) Name() string { return "power" } diff --git a/cardinal/testsuite/utils_test.go b/cardinal/testsuite/utils_test.go index 8662483f7..5ec7e08e1 100644 --- a/cardinal/testsuite/utils_test.go +++ b/cardinal/testsuite/utils_test.go @@ -2,13 +2,13 @@ package testsuite import ( "errors" + "reflect" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "pkg.world.dev/world-engine/cardinal" "pkg.world.dev/world-engine/cardinal/types" ) @@ -148,17 +148,17 @@ func (t *testOutputMsg) GetInFieldInformation() map[string]any { return map[stri func TestGetMessage(t *testing.T) { tests := []struct { name string - msgType string + msgID types.MessageID shouldError bool }{ { name: "get registered message", - msgType: "test.test_input_msg", + msgID: 1, shouldError: false, }, { name: "get unregistered message", - msgType: "test.unregistered_msg", + msgID: 999, shouldError: true, }, } @@ -166,20 +166,22 @@ func TestGetMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create a test world - world := cardinal.NewTestWorld(t) + world := NewTestWorld(t) // Register test messages - err := world.RegisterMessage(&testInputMsg{}) + err := world.RegisterMessage(&testInputMsg{}, reflect.TypeOf(testInputMsg{})) require.NoError(t, err) - err = world.RegisterMessage(&testOutputMsg{}) + err = world.RegisterMessage(&testOutputMsg{}, reflect.TypeOf(testOutputMsg{})) require.NoError(t, err) // Test message retrieval - _, err = world.GetMessage(tt.msgType) + msg, found := world.GetMessageByID(tt.msgID) if tt.shouldError { - require.Error(t, err) + assert.False(t, found) + assert.Nil(t, msg) } else { - require.NoError(t, err) + assert.True(t, found) + assert.NotNil(t, msg) } }) }