Fix migration step not checked against current version.

Add tests for DropColumns.
This commit is contained in:
Nurahmadie 2014-02-15 22:17:22 +07:00
parent 54a9544044
commit 4465b2654d
4 changed files with 79 additions and 11 deletions

View File

@ -53,7 +53,6 @@ DELETE FROM migration where revision = ?
// Implementation details is specific for each database, // Implementation details is specific for each database,
// see migrate/sqlite.go for implementation reference. // see migrate/sqlite.go for implementation reference.
type Operation interface { type Operation interface {
CreateTable(tableName string, args []string) (sql.Result, error) CreateTable(tableName string, args []string) (sql.Result, error)
RenameTable(tableName, newName string) (sql.Result, error) RenameTable(tableName, newName string) (sql.Result, error)
@ -147,7 +146,7 @@ func (m *Migration) up(target, current int64) error {
// loop through and execute revisions // loop through and execute revisions
for _, rev := range m.revs { for _, rev := range m.revs {
if rev.Revision() >= target { if rev.Revision() > current {
current = rev.Revision() current = rev.Revision()
// execute the revision Upgrade. // execute the revision Upgrade.
if err := rev.Up(op); err != nil { if err := rev.Up(op); err != nil {
@ -191,7 +190,7 @@ func (m *Migration) down(target, current int64) error {
current = rev.Revision() current = rev.Revision()
// execute the revision Upgrade. // execute the revision Upgrade.
if err := rev.Down(op); err != nil { if err := rev.Down(op); err != nil {
log.Printf("Failed to downgrade to Revision Number %v\n", current) log.Printf("Failed to downgrade from Revision Number %v\n", current)
log.Println(err) log.Println(err)
return tx.Rollback() return tx.Rollback()
} }
@ -202,7 +201,7 @@ func (m *Migration) down(target, current int64) error {
return tx.Rollback() return tx.Rollback()
} }
log.Printf("Successfully downgraded to Revision %v\n", current) log.Printf("Successfully downgraded from Revision %v\n", current)
} }
} }

View File

@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"strings" "strings"
_ "github.com/mattn/go-sqlite3"
"github.com/dchest/uniuri" "github.com/dchest/uniuri"
_ "github.com/mattn/go-sqlite3"
) )
type SQLiteDriver MigrationDriver type SQLiteDriver MigrationDriver
@ -48,7 +48,8 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
} }
columnNames := selectName(columns) columnNames := selectName(columns)
preparedColumns := make([]string, len(columnNames)-len(columnsToDrop))
var preparedColumns []string
for k, column := range columnNames { for k, column := range columnNames {
listed := false listed := false
for _, dropped := range columnsToDrop { for _, dropped := range columnsToDrop {
@ -98,8 +99,8 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
return nil, err return nil, err
} }
oldColumns := make([]string, len(columnChanges)) var oldColumns []string
newColumns := make([]string, len(columnChanges)) var newColumns []string
for k, column := range selectName(columns) { for k, column := range selectName(columns) {
for Old, New := range columnChanges { for Old, New := range columnChanges {
if column == Old { if column == Old {
@ -126,7 +127,7 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
var sql string var sql string
query := `SELECT sql FROM sqlite_master WHERE type='table' and name='?';` query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?;`
err := s.Tx.QueryRow(query, tableName).Scan(&sql) err := s.Tx.QueryRow(query, tableName).Scan(&sql)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -76,6 +76,29 @@ func (r *revision2) Revision() int64 {
// ---------- end of revision 2 // ---------- end of revision 2
// ---------- revision 3
type revision3 struct{}
func (r *revision3) Up(op Operation) error {
if _, err := op.AddColumn("samples", "url VARCHAR(255)"); err != nil {
return err
}
_, err := op.AddColumn("samples", "likes INTEGER")
return err
}
func (r *revision3) Down(op Operation) error {
_, err := op.DropColumns("samples", []string{"likes", "url"})
return err
}
func (r *revision3) Revision() int64 {
return 3
}
// ---------- end of revision 3
var db *sql.DB var db *sql.DB
var testSchema = ` var testSchema = `
@ -144,6 +167,51 @@ func TestMigrateRenameTable(t *testing.T) {
} }
} }
type TableInfo struct {
CID int64 `meddler:"cid,pk"`
Name string `meddler:"name"`
Type string `meddler:"type"`
Notnull bool `meddler:"notnull"`
DfltValue interface{} `meddler:"dflt_value"`
PK bool `meddler:"pk"`
}
func TestMigrateAddRemoveColumns(t *testing.T) {
defer tearDown()
if err := setUp(); err != nil {
t.Fatalf("Error preparing database: %q", err)
}
Driver = SQLite
mgr := New(db)
if err := mgr.Add(&revision1{}).Add(&revision3{}).Migrate(); err != nil {
t.Errorf("Can not migrate: %q", err)
}
var columns []*TableInfo
if err := meddler.QueryAll(db, &columns, `PRAGMA table_info(samples);`); err != nil {
t.Errorf("Can not access table info: %q", err)
}
if len(columns) < 5 {
t.Errorf("Expect length columns: %d\nGot: %d", 5, len(columns))
}
if err := mgr.MigrateTo(1); err != nil {
t.Errorf("Can not migrate: %q", err)
}
var another_columns []*TableInfo
if err := meddler.QueryAll(db, &another_columns, `PRAGMA table_info(samples);`); err != nil {
t.Errorf("Can not access table info: %q", err)
}
if len(another_columns) != 3 {
t.Errorf("Expect length columns: %d\nGot: %d", 3, len(columns))
}
}
func setUp() error { func setUp() error {
var err error var err error
db, err = sql.Open("sqlite3", "migration_tests.sqlite") db, err = sql.Open("sqlite3", "migration_tests.sqlite")

View File

@ -15,7 +15,7 @@ func fetchColumns(sql string) ([]string, error) {
} }
func selectName(columns []string) []string { func selectName(columns []string) []string {
results := make([]string, len(columns)) var results []string
for _, column := range columns { for _, column := range columns {
col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2) col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2)
results = append(results, col[0]) results = append(results, col[0])
@ -24,7 +24,7 @@ func selectName(columns []string) []string {
} }
func setForUpdate(left []string, right []string) string { func setForUpdate(left []string, right []string) string {
results := make([]string, len(left)) var results []string
for k, str := range left { for k, str := range left {
results = append(results, fmt.Sprintf("%s = %s", str, right[k])) results = append(results, fmt.Sprintf("%s = %s", str, right[k]))
} }