mirror of
https://github.com/harness/drone.git
synced 2025-05-12 06:59:54 +08:00
Fix migration step not checked against current version.
Add tests for DropColumns.
This commit is contained in:
parent
54a9544044
commit
4465b2654d
@ -53,7 +53,6 @@ DELETE FROM migration where revision = ?
|
|||||||
// Implementation details is specific for each database,
|
// Implementation details is specific for each database,
|
||||||
// see migrate/sqlite.go for implementation reference.
|
// see migrate/sqlite.go for implementation reference.
|
||||||
type Operation interface {
|
type Operation interface {
|
||||||
|
|
||||||
CreateTable(tableName string, args []string) (sql.Result, error)
|
CreateTable(tableName string, args []string) (sql.Result, error)
|
||||||
|
|
||||||
RenameTable(tableName, newName string) (sql.Result, error)
|
RenameTable(tableName, newName string) (sql.Result, error)
|
||||||
@ -147,7 +146,7 @@ func (m *Migration) up(target, current int64) error {
|
|||||||
|
|
||||||
// loop through and execute revisions
|
// loop through and execute revisions
|
||||||
for _, rev := range m.revs {
|
for _, rev := range m.revs {
|
||||||
if rev.Revision() >= target {
|
if rev.Revision() > current {
|
||||||
current = rev.Revision()
|
current = rev.Revision()
|
||||||
// execute the revision Upgrade.
|
// execute the revision Upgrade.
|
||||||
if err := rev.Up(op); err != nil {
|
if err := rev.Up(op); err != nil {
|
||||||
@ -191,7 +190,7 @@ func (m *Migration) down(target, current int64) error {
|
|||||||
current = rev.Revision()
|
current = rev.Revision()
|
||||||
// execute the revision Upgrade.
|
// execute the revision Upgrade.
|
||||||
if err := rev.Down(op); err != nil {
|
if err := rev.Down(op); err != nil {
|
||||||
log.Printf("Failed to downgrade to Revision Number %v\n", current)
|
log.Printf("Failed to downgrade from Revision Number %v\n", current)
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
return tx.Rollback()
|
return tx.Rollback()
|
||||||
}
|
}
|
||||||
@ -202,7 +201,7 @@ func (m *Migration) down(target, current int64) error {
|
|||||||
return tx.Rollback()
|
return tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Successfully downgraded to Revision %v\n", current)
|
log.Printf("Successfully downgraded from Revision %v\n", current)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
"github.com/dchest/uniuri"
|
"github.com/dchest/uniuri"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SQLiteDriver MigrationDriver
|
type SQLiteDriver MigrationDriver
|
||||||
@ -48,7 +48,8 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
|
|||||||
}
|
}
|
||||||
|
|
||||||
columnNames := selectName(columns)
|
columnNames := selectName(columns)
|
||||||
preparedColumns := make([]string, len(columnNames)-len(columnsToDrop))
|
|
||||||
|
var preparedColumns []string
|
||||||
for k, column := range columnNames {
|
for k, column := range columnNames {
|
||||||
listed := false
|
listed := false
|
||||||
for _, dropped := range columnsToDrop {
|
for _, dropped := range columnsToDrop {
|
||||||
@ -98,8 +99,8 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldColumns := make([]string, len(columnChanges))
|
var oldColumns []string
|
||||||
newColumns := make([]string, len(columnChanges))
|
var newColumns []string
|
||||||
for k, column := range selectName(columns) {
|
for k, column := range selectName(columns) {
|
||||||
for Old, New := range columnChanges {
|
for Old, New := range columnChanges {
|
||||||
if column == Old {
|
if column == Old {
|
||||||
@ -126,7 +127,7 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
|
|||||||
|
|
||||||
func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
|
func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
|
||||||
var sql string
|
var sql string
|
||||||
query := `SELECT sql FROM sqlite_master WHERE type='table' and name='?';`
|
query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?;`
|
||||||
err := s.Tx.QueryRow(query, tableName).Scan(&sql)
|
err := s.Tx.QueryRow(query, tableName).Scan(&sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -76,6 +76,29 @@ func (r *revision2) Revision() int64 {
|
|||||||
|
|
||||||
// ---------- end of revision 2
|
// ---------- end of revision 2
|
||||||
|
|
||||||
|
// ---------- revision 3
|
||||||
|
|
||||||
|
type revision3 struct{}
|
||||||
|
|
||||||
|
func (r *revision3) Up(op Operation) error {
|
||||||
|
if _, err := op.AddColumn("samples", "url VARCHAR(255)"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := op.AddColumn("samples", "likes INTEGER")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *revision3) Down(op Operation) error {
|
||||||
|
_, err := op.DropColumns("samples", []string{"likes", "url"})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *revision3) Revision() int64 {
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- end of revision 3
|
||||||
|
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
|
|
||||||
var testSchema = `
|
var testSchema = `
|
||||||
@ -144,6 +167,51 @@ func TestMigrateRenameTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TableInfo struct {
|
||||||
|
CID int64 `meddler:"cid,pk"`
|
||||||
|
Name string `meddler:"name"`
|
||||||
|
Type string `meddler:"type"`
|
||||||
|
Notnull bool `meddler:"notnull"`
|
||||||
|
DfltValue interface{} `meddler:"dflt_value"`
|
||||||
|
PK bool `meddler:"pk"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateAddRemoveColumns(t *testing.T) {
|
||||||
|
defer tearDown()
|
||||||
|
if err := setUp(); err != nil {
|
||||||
|
t.Fatalf("Error preparing database: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
Driver = SQLite
|
||||||
|
|
||||||
|
mgr := New(db)
|
||||||
|
if err := mgr.Add(&revision1{}).Add(&revision3{}).Migrate(); err != nil {
|
||||||
|
t.Errorf("Can not migrate: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var columns []*TableInfo
|
||||||
|
if err := meddler.QueryAll(db, &columns, `PRAGMA table_info(samples);`); err != nil {
|
||||||
|
t.Errorf("Can not access table info: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) < 5 {
|
||||||
|
t.Errorf("Expect length columns: %d\nGot: %d", 5, len(columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.MigrateTo(1); err != nil {
|
||||||
|
t.Errorf("Can not migrate: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var another_columns []*TableInfo
|
||||||
|
if err := meddler.QueryAll(db, &another_columns, `PRAGMA table_info(samples);`); err != nil {
|
||||||
|
t.Errorf("Can not access table info: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(another_columns) != 3 {
|
||||||
|
t.Errorf("Expect length columns: %d\nGot: %d", 3, len(columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setUp() error {
|
func setUp() error {
|
||||||
var err error
|
var err error
|
||||||
db, err = sql.Open("sqlite3", "migration_tests.sqlite")
|
db, err = sql.Open("sqlite3", "migration_tests.sqlite")
|
||||||
|
@ -15,7 +15,7 @@ func fetchColumns(sql string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func selectName(columns []string) []string {
|
func selectName(columns []string) []string {
|
||||||
results := make([]string, len(columns))
|
var results []string
|
||||||
for _, column := range columns {
|
for _, column := range columns {
|
||||||
col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2)
|
col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2)
|
||||||
results = append(results, col[0])
|
results = append(results, col[0])
|
||||||
@ -24,7 +24,7 @@ func selectName(columns []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setForUpdate(left []string, right []string) string {
|
func setForUpdate(left []string, right []string) string {
|
||||||
results := make([]string, len(left))
|
var results []string
|
||||||
for k, str := range left {
|
for k, str := range left {
|
||||||
results = append(results, fmt.Sprintf("%s = %s", str, right[k]))
|
results = append(results, fmt.Sprintf("%s = %s", str, right[k]))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user