diff --git a/internal/store/database/dbtx/ctx.go b/internal/store/database/dbtx/ctx.go index 93ef09f34..d80a51c40 100644 --- a/internal/store/database/dbtx/ctx.go +++ b/internal/store/database/dbtx/ctx.go @@ -32,7 +32,7 @@ func GetAccessor(ctx context.Context, db *sqlx.DB) Accessor { if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok { return a } - return db + return New(db) } // GetTransaction returns Transaction interface from the context if it exists or return nil. diff --git a/internal/store/database/dbtx/db.go b/internal/store/database/dbtx/db.go index 1d1853a79..9c7cbd8f4 100644 --- a/internal/store/database/dbtx/db.go +++ b/internal/store/database/dbtx/db.go @@ -13,7 +13,11 @@ import ( // New returns new database Runner interface. func New(db *sqlx.DB) Transactor { - run := &runnerDB{sqlDB{db}} + mx := getLocker(db) + run := &runnerDB{ + db: sqlDB{db}, + mx: mx, + } return run } diff --git a/internal/store/database/dbtx/locker.go b/internal/store/database/dbtx/locker.go new file mode 100644 index 000000000..a575f30f1 --- /dev/null +++ b/internal/store/database/dbtx/locker.go @@ -0,0 +1,42 @@ +// Copyright 2022 Harness Inc. All rights reserved. +// Use of this source code is governed by the Polyform Free Trial License +// that can be found in the LICENSE.md file for this repository. + +package dbtx + +import ( + "sync" + + "github.com/jmoiron/sqlx" +) + +const ( + postgres = "postgres" +) + +type locker interface { + Lock() + Unlock() + RLock() + RUnlock() +} + +var globalMx sync.RWMutex + +func needsLocking(driver string) bool { + return driver != postgres +} + +func getLocker(db *sqlx.DB) locker { + if needsLocking(db.DriverName()) { + return &globalMx + } + return lockerNop{} +} + +type lockerNop struct{} + +func (lockerNop) RLock() {} +func (lockerNop) RUnlock() {} +func (lockerNop) Lock() {} +func (lockerNop) Unlock() {} diff --git a/internal/store/database/dbtx/runner.go b/internal/store/database/dbtx/runner.go index e403e7662..2bf8c417c 100644 --- a/internal/store/database/dbtx/runner.go +++ b/internal/store/database/dbtx/runner.go @@ -8,10 +8,15 @@ import ( "context" "database/sql" "errors" + + "github.com/jmoiron/sqlx" ) +// runnerDB executes individual sqlx database calls wrapped with the locker calls (Lock/Unlock) +// or a transaction wrapped with the locker calls (RLock/RUnlock or Lock/Unlock depending on the transaction type). type runnerDB struct { - transactor + db transactor + mx locker } var _ Transactor = runnerDB{} @@ -28,7 +33,15 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error, txOpts = TxDefault } - tx, err := r.startTx(ctx, txOpts) + if txOpts.ReadOnly { + r.mx.RLock() + defer r.mx.RUnlock() + } else { + r.mx.Lock() + defer r.mx.Unlock() + } + + tx, err := r.db.startTx(ctx, txOpts) if err != nil { return err } @@ -64,6 +77,80 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error, return err } +func (r runnerDB) DriverName() string { + return r.db.DriverName() +} + +func (r runnerDB) Rebind(query string) string { + return r.db.Rebind(query) +} + +func (r runnerDB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { + return r.db.BindNamed(query, arg) +} + +func (r runnerDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.QueryContext(ctx, query, args...) +} + +func (r runnerDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.QueryxContext(ctx, query, args...) +} + +func (r runnerDB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.QueryRowxContext(ctx, query, args...) +} + +func (r runnerDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.ExecContext(ctx, query, args...) +} + +func (r runnerDB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.QueryRowContext(ctx, query, args...) +} + +func (r runnerDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.PrepareContext(ctx, query) +} + +func (r runnerDB) PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.PreparexContext(ctx, query) +} + +func (r runnerDB) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.PrepareNamedContext(ctx, query) +} + +func (r runnerDB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.GetContext(ctx, dest, query, args...) +} + +func (r runnerDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + r.mx.Lock() + defer r.mx.Unlock() + return r.db.SelectContext(ctx, dest, query, args...) +} + +// runnerTx executes sqlx database transaction calls. +// Locking is not used because runnerDB locks the entire transaction. type runnerTx struct { Tx commit bool diff --git a/internal/store/database/dbtx/runner_test.go b/internal/store/database/dbtx/runner_test.go index c94050db4..d964e060e 100644 --- a/internal/store/database/dbtx/runner_test.go +++ b/internal/store/database/dbtx/runner_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" ) //nolint:gocognit @@ -106,7 +107,10 @@ func TestWithTx(t *testing.T) { t: t, errCommit: test.errCommit, } - run := &runnerDB{mock} + run := &runnerDB{ + db: mock, + mx: lockerNop{}, + } ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() @@ -203,3 +207,129 @@ func (tx *txMock) Rollback() error { tx.rollback = true return nil } + +func TestLocking(t *testing.T) { + const dummyQuery = "" + tests := []struct { + name string + fn func(db Transactor, l *lockerCounter) + }{ + { + name: "exec-lock", + fn: func(db Transactor, l *lockerCounter) { + ctx := context.Background() + _, _ = db.ExecContext(ctx, dummyQuery) + _, _ = db.ExecContext(ctx, dummyQuery) + _, _ = db.ExecContext(ctx, dummyQuery) + + assert.Zero(t, l.RLocks) + assert.Zero(t, l.RUnlocks) + assert.Equal(t, 3, l.Locks) + assert.Equal(t, 3, l.Unlocks) + }, + }, + { + name: "tx-lock", + fn: func(db Transactor, l *lockerCounter) { + ctx := context.Background() + _ = db.WithTx(ctx, func(ctx context.Context) error { + _, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery) + _, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery) + return nil + }) + + assert.Zero(t, l.RLocks) + assert.Zero(t, l.RUnlocks) + assert.Equal(t, 1, l.Locks) + assert.Equal(t, 1, l.Unlocks) + }, + }, + { + name: "tx-read-lock", + fn: func(db Transactor, l *lockerCounter) { + ctx := context.Background() + _ = db.WithTx(ctx, func(ctx context.Context) error { + _, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery) + _, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery) + return nil + }, TxDefaultReadOnly) + + assert.Equal(t, 1, l.RLocks) + assert.Equal(t, 1, l.RUnlocks) + assert.Zero(t, l.Locks) + assert.Zero(t, l.Unlocks) + }, + }, + } + + for _, test := range tests { + l := &lockerCounter{} + t.Run(test.name, func(t *testing.T) { + test.fn(runnerDB{ + db: dbMockNop{}, + mx: l, + }, l) + }) + } +} + +type lockerCounter struct { + Locks int + Unlocks int + RLocks int + RUnlocks int +} + +func (l *lockerCounter) Lock() { l.Locks++ } +func (l *lockerCounter) Unlock() { l.Unlocks++ } +func (l *lockerCounter) RLock() { l.RLocks++ } +func (l *lockerCounter) RUnlock() { l.RUnlocks++ } + +type dbMockNop struct{} + +func (dbMockNop) DriverName() string { return "" } +func (dbMockNop) Rebind(string) string { return "" } +func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil } + +//nolint:nilnil // it's a mock +func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) { + return nil, nil +} + +//nolint:nilnil // it's a mock +func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) { + return nil, nil +} +func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil } +func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) { + return nil, nil +} +func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row { + return nil +} + +//nolint:nilnil // it's a mock +func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) { + return nil, nil +} + +//nolint:nilnil // it's a mock +func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) { + return nil, nil +} + +//nolint:nilnil // it's a mock +func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) { + return nil, nil +} +func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error { + return nil +} +func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error { + return nil +} + +func (dbMockNop) Commit() error { return nil } +func (dbMockNop) Rollback() error { return nil } + +func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (Tx, error) { return d, nil }