mirror of
https://github.com/harness/drone.git
synced 2025-05-06 12:41:08 +08:00
global mutex for the dbtx package (#118)
This commit is contained in:
parent
204548169d
commit
a734db6f0c
@ -32,7 +32,7 @@ func GetAccessor(ctx context.Context, db *sqlx.DB) Accessor {
|
||||
if a, ok := ctx.Value(ctxKeyTx{}).(Accessor); ok {
|
||||
return a
|
||||
}
|
||||
return db
|
||||
return New(db)
|
||||
}
|
||||
|
||||
// GetTransaction returns Transaction interface from the context if it exists or return nil.
|
||||
|
@ -13,7 +13,11 @@ import (
|
||||
|
||||
// New returns new database Runner interface.
|
||||
func New(db *sqlx.DB) Transactor {
|
||||
run := &runnerDB{sqlDB{db}}
|
||||
mx := getLocker(db)
|
||||
run := &runnerDB{
|
||||
db: sqlDB{db},
|
||||
mx: mx,
|
||||
}
|
||||
return run
|
||||
}
|
||||
|
||||
|
42
internal/store/database/dbtx/locker.go
Normal file
42
internal/store/database/dbtx/locker.go
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright 2022 Harness Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Polyform Free Trial License
|
||||
// that can be found in the LICENSE.md file for this repository.
|
||||
|
||||
package dbtx
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const (
|
||||
postgres = "postgres"
|
||||
)
|
||||
|
||||
type locker interface {
|
||||
Lock()
|
||||
Unlock()
|
||||
RLock()
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
var globalMx sync.RWMutex
|
||||
|
||||
func needsLocking(driver string) bool {
|
||||
return driver != postgres
|
||||
}
|
||||
|
||||
func getLocker(db *sqlx.DB) locker {
|
||||
if needsLocking(db.DriverName()) {
|
||||
return &globalMx
|
||||
}
|
||||
return lockerNop{}
|
||||
}
|
||||
|
||||
type lockerNop struct{}
|
||||
|
||||
func (lockerNop) RLock() {}
|
||||
func (lockerNop) RUnlock() {}
|
||||
func (lockerNop) Lock() {}
|
||||
func (lockerNop) Unlock() {}
|
@ -8,10 +8,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// runnerDB executes individual sqlx database calls wrapped with the locker calls (Lock/Unlock)
|
||||
// or a transaction wrapped with the locker calls (RLock/RUnlock or Lock/Unlock depending on the transaction type).
|
||||
type runnerDB struct {
|
||||
transactor
|
||||
db transactor
|
||||
mx locker
|
||||
}
|
||||
|
||||
var _ Transactor = runnerDB{}
|
||||
@ -28,7 +33,15 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
|
||||
txOpts = TxDefault
|
||||
}
|
||||
|
||||
tx, err := r.startTx(ctx, txOpts)
|
||||
if txOpts.ReadOnly {
|
||||
r.mx.RLock()
|
||||
defer r.mx.RUnlock()
|
||||
} else {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
}
|
||||
|
||||
tx, err := r.db.startTx(ctx, txOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -64,6 +77,80 @@ func (r runnerDB) WithTx(ctx context.Context, txFn func(context.Context) error,
|
||||
return err
|
||||
}
|
||||
|
||||
func (r runnerDB) DriverName() string {
|
||||
return r.db.DriverName()
|
||||
}
|
||||
|
||||
func (r runnerDB) Rebind(query string) string {
|
||||
return r.db.Rebind(query)
|
||||
}
|
||||
|
||||
func (r runnerDB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
|
||||
return r.db.BindNamed(query, arg)
|
||||
}
|
||||
|
||||
func (r runnerDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.QueryxContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.QueryRowxContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (r runnerDB) PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.PreparexContext(ctx, query)
|
||||
}
|
||||
|
||||
func (r runnerDB) PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.PrepareNamedContext(ctx, query)
|
||||
}
|
||||
|
||||
func (r runnerDB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.GetContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
func (r runnerDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return r.db.SelectContext(ctx, dest, query, args...)
|
||||
}
|
||||
|
||||
// runnerTx executes sqlx database transaction calls.
|
||||
// Locking is not used because runnerDB locks the entire transaction.
|
||||
type runnerTx struct {
|
||||
Tx
|
||||
commit bool
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//nolint:gocognit
|
||||
@ -106,7 +107,10 @@ func TestWithTx(t *testing.T) {
|
||||
t: t,
|
||||
errCommit: test.errCommit,
|
||||
}
|
||||
run := &runnerDB{mock}
|
||||
run := &runnerDB{
|
||||
db: mock,
|
||||
mx: lockerNop{},
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
defer cancelFn()
|
||||
@ -203,3 +207,129 @@ func (tx *txMock) Rollback() error {
|
||||
tx.rollback = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestLocking(t *testing.T) {
|
||||
const dummyQuery = ""
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(db Transactor, l *lockerCounter)
|
||||
}{
|
||||
{
|
||||
name: "exec-lock",
|
||||
fn: func(db Transactor, l *lockerCounter) {
|
||||
ctx := context.Background()
|
||||
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||
_, _ = db.ExecContext(ctx, dummyQuery)
|
||||
|
||||
assert.Zero(t, l.RLocks)
|
||||
assert.Zero(t, l.RUnlocks)
|
||||
assert.Equal(t, 3, l.Locks)
|
||||
assert.Equal(t, 3, l.Unlocks)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tx-lock",
|
||||
fn: func(db Transactor, l *lockerCounter) {
|
||||
ctx := context.Background()
|
||||
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
||||
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
||||
_, _ = GetAccessor(ctx, nil).ExecContext(ctx, dummyQuery)
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Zero(t, l.RLocks)
|
||||
assert.Zero(t, l.RUnlocks)
|
||||
assert.Equal(t, 1, l.Locks)
|
||||
assert.Equal(t, 1, l.Unlocks)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tx-read-lock",
|
||||
fn: func(db Transactor, l *lockerCounter) {
|
||||
ctx := context.Background()
|
||||
_ = db.WithTx(ctx, func(ctx context.Context) error {
|
||||
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
||||
_, _ = GetAccessor(ctx, nil).QueryContext(ctx, dummyQuery)
|
||||
return nil
|
||||
}, TxDefaultReadOnly)
|
||||
|
||||
assert.Equal(t, 1, l.RLocks)
|
||||
assert.Equal(t, 1, l.RUnlocks)
|
||||
assert.Zero(t, l.Locks)
|
||||
assert.Zero(t, l.Unlocks)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
l := &lockerCounter{}
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
test.fn(runnerDB{
|
||||
db: dbMockNop{},
|
||||
mx: l,
|
||||
}, l)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type lockerCounter struct {
|
||||
Locks int
|
||||
Unlocks int
|
||||
RLocks int
|
||||
RUnlocks int
|
||||
}
|
||||
|
||||
func (l *lockerCounter) Lock() { l.Locks++ }
|
||||
func (l *lockerCounter) Unlock() { l.Unlocks++ }
|
||||
func (l *lockerCounter) RLock() { l.RLocks++ }
|
||||
func (l *lockerCounter) RUnlock() { l.RUnlocks++ }
|
||||
|
||||
type dbMockNop struct{}
|
||||
|
||||
func (dbMockNop) DriverName() string { return "" }
|
||||
func (dbMockNop) Rebind(string) string { return "" }
|
||||
func (dbMockNop) BindNamed(string, interface{}) (string, []interface{}, error) { return "", nil, nil }
|
||||
|
||||
//nolint:nilnil // it's a mock
|
||||
func (dbMockNop) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//nolint:nilnil // it's a mock
|
||||
func (dbMockNop) QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (dbMockNop) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row { return nil }
|
||||
func (dbMockNop) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (dbMockNop) QueryRowContext(context.Context, string, ...any) *sql.Row {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:nilnil // it's a mock
|
||||
func (dbMockNop) PrepareContext(context.Context, string) (*sql.Stmt, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//nolint:nilnil // it's a mock
|
||||
func (dbMockNop) PreparexContext(context.Context, string) (*sqlx.Stmt, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//nolint:nilnil // it's a mock
|
||||
func (dbMockNop) PrepareNamedContext(context.Context, string) (*sqlx.NamedStmt, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (dbMockNop) GetContext(context.Context, interface{}, string, ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (dbMockNop) SelectContext(context.Context, interface{}, string, ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dbMockNop) Commit() error { return nil }
|
||||
func (dbMockNop) Rollback() error { return nil }
|
||||
|
||||
func (d dbMockNop) startTx(context.Context, *sql.TxOptions) (Tx, error) { return d, nil }
|
||||
|
Loading…
Reference in New Issue
Block a user