diff --git a/v2/internal/runtime/store.go b/v2/internal/runtime/store.go index 9d7822747..4199154c6 100644 --- a/v2/internal/runtime/store.go +++ b/v2/internal/runtime/store.go @@ -71,19 +71,16 @@ func (p *StoreProvider) New(name string, defaultValue interface{}) *Store { golog.Fatal("Cannot initialise a store with nil") } - dataCopy := deepcopy.Copy(defaultValue) - dataType := reflect.TypeOf(dataCopy) - result := Store{ - name: name, - runtime: p.runtime, - data: reflect.ValueOf(dataCopy), - dataType: dataType, + name: name, + runtime: p.runtime, } // Setup the sync listener result.setupListener() + result.Set(defaultValue) + return &result } @@ -215,13 +212,20 @@ func (s *Store) notify() { } } -// set will update the data held by the store +// Set will update the data held by the store // and notify listeners of the change -func (s *Store) set(data interface{}) error { +func (s *Store) Set(data interface{}) error { + + if data == nil { + return fmt.Errorf("cannot set store to nil") + } + s.lock() - if data != nil { - inType := reflect.TypeOf(data) + dataCopy := deepcopy.Copy(data) + + if dataCopy != nil { + inType := reflect.TypeOf(dataCopy) if inType != s.dataType && s.data.IsValid() { s.unlock() @@ -230,11 +234,11 @@ func (s *Store) set(data interface{}) error { } if s.dataType == nil { - s.dataType = reflect.TypeOf(data) + s.dataType = reflect.TypeOf(dataCopy) } // Save data - s.data = reflect.ValueOf(data) + s.data = reflect.ValueOf(dataCopy) s.unlock() @@ -347,7 +351,7 @@ func (s *Store) Update(updater interface{}) { results := reflect.ValueOf(updater).Call(args) // We will only have 1 result. Set the store to it - s.set(results[0].Interface()) + s.Set(results[0].Interface()) } // Get returns the value of the data that's kept in the current state / Store diff --git a/v2/internal/runtime/store_test.go b/v2/internal/runtime/store_test.go index 93d866214..62327c578 100644 --- a/v2/internal/runtime/store_test.go +++ b/v2/internal/runtime/store_test.go @@ -14,6 +14,33 @@ import ( is2 "github.com/matryer/is" ) +func TestStoreProvider_NewWithNilDefault(t *testing.T) { + is := is2.New(t) + + defaultLogger := logger.NewDefaultLogger() + testLogger := internallogger.New(defaultLogger) + //testLogger.SetLogLevel(logger.TRACE) + serviceBus := servicebus.New(testLogger) + err := serviceBus.Start() + is.NoErr(err) + defer serviceBus.Stop() + + testRuntime := New(serviceBus) + storeProvider := newStore(testRuntime) + + testStore := storeProvider.New("test", 0) + + // You should be able to write a new value into a + // store initialised with nil + err = testStore.Set(100) + is.NoErr(err) + + // You shouldn't be able to write different types to the + // store + err = testStore.Set(false) + is.True(err != nil) +} + func TestStoreProvider_NewWithScalarDefault(t *testing.T) { is := is2.New(t) @@ -58,9 +85,8 @@ func TestStoreProvider_NewWithStructDefault(t *testing.T) { testStore := storeProvider.New("test", testValue) - testStore.Update(func(current *TestValue) *TestValue { - return testValue - }) + err = testStore.Set(testValue) + is.NoErr(err) testStore.resync() value := testStore.Get() is.Equal(value, testValue) @@ -69,9 +95,8 @@ func TestStoreProvider_NewWithStructDefault(t *testing.T) { testValue = &TestValue{ Name: "there", } - testStore.Update(func(current *TestValue) *TestValue { - return testValue - }) + err = testStore.Set(testValue) + is.NoErr(err) testStore.resync() value = testStore.Get() is.Equal(value, testValue) @@ -128,9 +153,8 @@ func TestStoreProvider_RapidReadWrite(t *testing.T) { wg.Done() return default: - store.Update(func(current int) int { - return rand.Int() - }) + err := store.Set(rand.Int()) + is.NoErr(err) } } }(testStore, ctx, writerCount)