mirror of
https://github.com/harness/drone.git
synced 2025-05-06 16:42:31 +08:00
global mutex for the dbtx package (#118)
This commit is contained in:
parent
204548169d
commit
a734db6f0c
@ -32,7 +32,7 @@ func GetAccessor(ctx context.Context, db *sqlx.DB) Accessor {
|
|||||||
if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok {
|
if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return db
|
return New(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTransaction returns Transaction interface from the context if it exists or return nil.
|
// GetTransaction returns Transaction interface from the context if it exists or return nil.
|
||||||
|
@ -13,7 +13,11 @@ import (
|
|||||||
|
|
||||||
// New returns new database Runner interface.
|
// New returns new database Runner interface.
|
||||||
func New(db *sqlx.DB) Transactor {
|
func New(db *sqlx.DB) Transactor {
|
||||||
run := &runnerDB{sqlDB{db}}
|
mx := getLocker(db)
|
||||||
|
run := &runnerDB{
|
||||||
|
db: sqlDB{db},
|
||||||
|
mx: mx,
|
||||||
|
}
|
||||||
return run
|
return run
|
||||||
}
|
}
|
||||||
|
|
||||||
|
42
internal/store/database/dbtx/locker.go
Normal file
42
internal/store/database/dbtx/locker.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
// Copyright 2022 Harness Inc. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Polyform Free Trial License
|
||||||
|
// that can be found in the LICENSE.md file for this repository.
|
||||||
|
|
||||||
|
package dbtx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
postgres = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
|
type locker interface {
|
||||||
|
Lock()
|
||||||
|
Unlock()
|
||||||
|
RLock()
|
||||||
|
RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalMx sync.RWMutex
|
||||||
|
|
||||||
|
func needsLocking(driver string) bool {
|
||||||
|
return driver != postgres
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLocker(db *sqlx.DB) locker {
|
||||||
|
if needsLocking(db.DriverName()) {
|
||||||
|
return &globalMx
|
||||||
|
}
|
||||||
|
return lockerNop{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type lockerNop struct{}
|
||||||
|
|
||||||
|
func (lockerNop) RLock() {}
|
||||||
|
func (lockerNop) RUnlock() {}
|
||||||
|
func (lockerNop) Lock() {}
|
||||||
|
func (lockerNop) Unlock() {}
|
@ -8,10 +8,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// runnerDB executes individual sqlx database calls wrapped with the locker calls (Lock/Unlock)
|
||||||
|
// or a transaction wrapped with the locker calls (RLock/RUnlock or Lock/Unlock depending on the transaction type).
|
||||||
type runnerDB struct {
|
type runnerDB struct {
|
||||||
transactor
|
db transactor
|
||||||
|
mx locker
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Transactor = runnerDB{}
|
var _ Transactor = runnerDB{}
|
||||||
@ -28,7 +33,15 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
|
|||||||
txOpts = TxDefault
|
txOpts = TxDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := r.startTx(ctx, txOpts)
|
if txOpts.ReadOnly {
|
||||||
|
r.mx.RLock()
|
||||||
|
defer r.mx.RUnlock()
|
||||||
|
} else {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.db.startTx(ctx, txOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -64,6 +77,80 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) DriverName() string {
|
||||||
|
return r.db.DriverName()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) Rebind(query string) string {
|
||||||
|
return r.db.Rebind(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
|
||||||
|
return r.db.BindNamed(query, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.QueryxContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.QueryRowxContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.PreparexContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.PrepareNamedContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.GetContext(ctx, dest, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r runnerDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
r.mx.Lock()
|
||||||
|
defer r.mx.Unlock()
|
||||||
|
return r.db.SelectContext(ctx, dest, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runnerTx executes sqlx database transaction calls.
|
||||||
|
// Locking is not used because runnerDB locks the entire transaction.
|
||||||
type runnerTx struct {
|
type runnerTx struct {
|
||||||
Tx
|
Tx
|
||||||
commit bool
|
commit bool
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:gocognit
|
//nolint:gocognit
|
||||||
@ -106,7 +107,10 @@ func TestWithTx(t *testing.T) {
|
|||||||
t: t,
|
t: t,
|
||||||
errCommit: test.errCommit,
|
errCommit: test.errCommit,
|
||||||
}
|
}
|
||||||
run := &runnerDB{mock}
|
run := &runnerDB{
|
||||||
|
db: mock,
|
||||||
|
mx: lockerNop{},
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
@ -203,3 +207,129 @@ func (tx *txMock) Rollback() error {
|
|||||||
tx.rollback = true
|
tx.rollback = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLocking(t *testing.T) {
|
||||||
|
const dummyQuery = ""
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(db Transactor, l *lockerCounter)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exec-lock",
|
||||||
|
fn: func(db Transactor, l *lockerCounter) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||||
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||||
|
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||||
|
|
||||||
|
assert.Zero(t, l.RLocks)
|
||||||
|
assert.Zero(t, l.RUnlocks)
|
||||||
|
assert.Equal(t, 3, l.Locks)
|
||||||
|
assert.Equal(t, 3, l.Unlocks)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tx-lock",
|
||||||
|
fn: func(db Transactor, l *lockerCounter) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
||||||
|
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
||||||
|
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Zero(t, l.RLocks)
|
||||||
|
assert.Zero(t, l.RUnlocks)
|
||||||
|
assert.Equal(t, 1, l.Locks)
|
||||||
|
assert.Equal(t, 1, l.Unlocks)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tx-read-lock",
|
||||||
|
fn: func(db Transactor, l *lockerCounter) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
||||||
|
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
||||||
|
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
||||||
|
return nil
|
||||||
|
}, TxDefaultReadOnly)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, l.RLocks)
|
||||||
|
assert.Equal(t, 1, l.RUnlocks)
|
||||||
|
assert.Zero(t, l.Locks)
|
||||||
|
assert.Zero(t, l.Unlocks)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
l := &lockerCounter{}
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
test.fn(runnerDB{
|
||||||
|
db: dbMockNop{},
|
||||||
|
mx: l,
|
||||||
|
}, l)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type lockerCounter struct {
|
||||||
|
Locks int
|
||||||
|
Unlocks int
|
||||||
|
RLocks int
|
||||||
|
RUnlocks int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *lockerCounter) Lock() { l.Locks++ }
|
||||||
|
func (l *lockerCounter) Unlock() { l.Unlocks++ }
|
||||||
|
func (l *lockerCounter) RLock() { l.RLocks++ }
|
||||||
|
func (l *lockerCounter) RUnlock() { l.RUnlocks++ }
|
||||||
|
|
||||||
|
type dbMockNop struct{}
|
||||||
|
|
||||||
|
func (dbMockNop) DriverName() string { return "" }
|
||||||
|
func (dbMockNop) Rebind(string) string { return "" }
|
||||||
|
func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil }
|
||||||
|
|
||||||
|
//nolint:nilnil // it's a mock
|
||||||
|
func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:nilnil // it's a mock
|
||||||
|
func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil }
|
||||||
|
func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:nilnil // it's a mock
|
||||||
|
func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:nilnil // it's a mock
|
||||||
|
func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:nilnil // it's a mock
|
||||||
|
func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dbMockNop) Commit() error { return nil }
|
||||||
|
func (dbMockNop) Rollback() error { return nil }
|
||||||
|
|
||||||
|
func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (Tx, error) { return d, nil }
|
||||||
|
Loading…
Reference in New Issue
Block a user