global mutex for the dbtx package (#118)

This commit is contained in:
Marko Gaćeša 2022-12-13 17:52:16 +01:00 committed by GitHub
parent 204548169d
commit a734db6f0c
5 changed files with 268 additions and 5 deletions

View File

@ -32,7 +32,7 @@ func GetAccessor(ctx context.Context, db *sqlx.DB) Accessor {
if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok { if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok {
return a return a
} }
return db return New(db)
} }
// GetTransaction returns Transaction interface from the context if it exists or return nil. // GetTransaction returns Transaction interface from the context if it exists or return nil.

View File

@ -13,7 +13,11 @@ import (
// New returns new database Runner interface. // New returns new database Runner interface.
func New(db *sqlx.DB) Transactor { func New(db *sqlx.DB) Transactor {
run := &runnerDB{sqlDB{db}} mx := getLocker(db)
run := &runnerDB{
db: sqlDB{db},
mx: mx,
}
return run return run
} }

View File

@ -0,0 +1,42 @@
// Copyright 2022 Harness Inc. All rights reserved.
// Use of this source code is governed by the Polyform Free Trial License
// that can be found in the LICENSE.md file for this repository.
package dbtx
import (
"sync"
"github.com/jmoiron/sqlx"
)
const (
postgres = "postgres"
)
type locker interface {
Lock()
Unlock()
RLock()
RUnlock()
}
var globalMx sync.RWMutex
func needsLocking(driver string) bool {
return driver != postgres
}
func getLocker(db *sqlx.DB) locker {
if needsLocking(db.DriverName()) {
return &globalMx
}
return lockerNop{}
}
type lockerNop struct{}
func (lockerNop) RLock() {}
func (lockerNop) RUnlock() {}
func (lockerNop) Lock() {}
func (lockerNop) Unlock() {}

View File

@ -8,10 +8,15 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/jmoiron/sqlx"
) )
// runnerDB executes individual sqlx database calls wrapped with the locker calls (Lock/Unlock)
// or a transaction wrapped with the locker calls (RLock/RUnlock or Lock/Unlock depending on the transaction type).
type runnerDB struct { type runnerDB struct {
transactor db transactor
mx locker
} }
var _ Transactor = runnerDB{} var _ Transactor = runnerDB{}
@ -28,7 +33,15 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
txOpts = TxDefault txOpts = TxDefault
} }
tx, err := r.startTx(ctx, txOpts) if txOpts.ReadOnly {
r.mx.RLock()
defer r.mx.RUnlock()
} else {
r.mx.Lock()
defer r.mx.Unlock()
}
tx, err := r.db.startTx(ctx, txOpts)
if err != nil { if err != nil {
return err return err
} }
@ -64,6 +77,80 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
return err return err
} }
func (r runnerDB) DriverName() string {
return r.db.DriverName()
}
func (r runnerDB) Rebind(query string) string {
return r.db.Rebind(query)
}
func (r runnerDB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
return r.db.BindNamed(query, arg)
}
func (r runnerDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.QueryContext(ctx, query, args...)
}
func (r runnerDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.QueryxContext(ctx, query, args...)
}
func (r runnerDB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.QueryRowxContext(ctx, query, args...)
}
func (r runnerDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.ExecContext(ctx, query, args...)
}
func (r runnerDB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.QueryRowContext(ctx, query, args...)
}
func (r runnerDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.PrepareContext(ctx, query)
}
func (r runnerDB) PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.PreparexContext(ctx, query)
}
func (r runnerDB) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.PrepareNamedContext(ctx, query)
}
func (r runnerDB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.GetContext(ctx, dest, query, args...)
}
func (r runnerDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
r.mx.Lock()
defer r.mx.Unlock()
return r.db.SelectContext(ctx, dest, query, args...)
}
// runnerTx executes sqlx database transaction calls.
// Locking is not used because runnerDB locks the entire transaction.
type runnerTx struct { type runnerTx struct {
Tx Tx
commit bool commit bool

View File

@ -11,6 +11,7 @@ import (
"testing" "testing"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
) )
//nolint:gocognit //nolint:gocognit
@ -106,7 +107,10 @@ func TestWithTx(t *testing.T) {
t: t, t: t,
errCommit: test.errCommit, errCommit: test.errCommit,
} }
run := &runnerDB{mock} run := &runnerDB{
db: mock,
mx: lockerNop{},
}
ctx, cancelFn := context.WithCancel(context.Background()) ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn() defer cancelFn()
@ -203,3 +207,129 @@ func (tx *txMock) Rollback() error {
tx.rollback = true tx.rollback = true
return nil return nil
} }
func TestLocking(t *testing.T) {
const dummyQuery = ""
tests := []struct {
name string
fn func(db Transactor, l *lockerCounter)
}{
{
name: "exec-lock",
fn: func(db Transactor, l *lockerCounter) {
ctx := context.Background()
_, _ = db.ExecContext(ctx, dummyQuery)
_, _ = db.ExecContext(ctx, dummyQuery)
_, _ = db.ExecContext(ctx, dummyQuery)
assert.Zero(t, l.RLocks)
assert.Zero(t, l.RUnlocks)
assert.Equal(t, 3, l.Locks)
assert.Equal(t, 3, l.Unlocks)
},
},
{
name: "tx-lock",
fn: func(db Transactor, l *lockerCounter) {
ctx := context.Background()
_ = db.WithTx(ctx, func(ctx context.Context) error {
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
return nil
})
assert.Zero(t, l.RLocks)
assert.Zero(t, l.RUnlocks)
assert.Equal(t, 1, l.Locks)
assert.Equal(t, 1, l.Unlocks)
},
},
{
name: "tx-read-lock",
fn: func(db Transactor, l *lockerCounter) {
ctx := context.Background()
_ = db.WithTx(ctx, func(ctx context.Context) error {
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
return nil
}, TxDefaultReadOnly)
assert.Equal(t, 1, l.RLocks)
assert.Equal(t, 1, l.RUnlocks)
assert.Zero(t, l.Locks)
assert.Zero(t, l.Unlocks)
},
},
}
for _, test := range tests {
l := &lockerCounter{}
t.Run(test.name, func(t *testing.T) {
test.fn(runnerDB{
db: dbMockNop{},
mx: l,
}, l)
})
}
}
type lockerCounter struct {
Locks int
Unlocks int
RLocks int
RUnlocks int
}
func (l *lockerCounter) Lock() { l.Locks++ }
func (l *lockerCounter) Unlock() { l.Unlocks++ }
func (l *lockerCounter) RLock() { l.RLocks++ }
func (l *lockerCounter) RUnlock() { l.RUnlocks++ }
type dbMockNop struct{}
func (dbMockNop) DriverName() string { return "" }
func (dbMockNop) Rebind(string) string { return "" }
func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil }
//nolint:nilnil // it's a mock
func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) {
return nil, nil
}
func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil }
func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
return nil, nil
}
func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row {
return nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) {
return nil, nil
}
//nolint:nilnil // it's a mock
func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) {
return nil, nil
}
func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error {
return nil
}
func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error {
return nil
}
func (dbMockNop) Commit() error { return nil }
func (dbMockNop) Rollback() error { return nil }
func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (Tx, error) { return d, nil }