81 lines
1.6 KiB
Go
81 lines
1.6 KiB
Go
|
package migrate
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io/fs"
|
||
|
|
||
|
"code.icb4dc0.de/prskr/searcherside/core/ports"
|
||
|
|
||
|
"ariga.io/atlas/sql/migrate"
|
||
|
"ariga.io/atlas/sql/postgres"
|
||
|
"ariga.io/atlas/sql/sqlite"
|
||
|
)
|
||
|
|
||
|
var _ ports.Migrator = (*AtlasMigrator)(nil)
|
||
|
|
||
|
type AtlasMigrator struct {
|
||
|
MigrationsFS fs.FS
|
||
|
RevisionRW ports.RevisionReadWriter
|
||
|
}
|
||
|
|
||
|
func (a AtlasMigrator) Migrate(ctx context.Context, req ports.MigrationRequest) (err error) {
|
||
|
dialectFS, err := fs.Sub(a.MigrationsFS, req.Driver.String())
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("no migrations sub-directory found for dialect %s: %w", req.Driver, err)
|
||
|
}
|
||
|
|
||
|
migrateDriver, conn, err := migrationDriverFor(req.Driver, req.URL)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
err = errors.Join(err, conn.Close())
|
||
|
}()
|
||
|
|
||
|
executor, err := migrate.NewExecutor(migrateDriver, readOnlyFSDir{FS: dialectFS}, a.RevisionRW)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
pendingFiles, err := executor.Pending(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for idx := range pendingFiles {
|
||
|
if err = executor.Execute(ctx, pendingFiles[idx]); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func migrationDriverFor(driverName ports.Driver, url string) (drv migrate.Driver, db *sql.DB, err error) {
|
||
|
conn, err := sql.Open(driverName.String(), url)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
err = errors.Join(err, conn.Close())
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
switch driverName {
|
||
|
case ports.DriverSQLite:
|
||
|
drv, err = sqlite.Open(conn)
|
||
|
return drv, conn, err
|
||
|
case ports.DriverPostgres:
|
||
|
drv, err = postgres.Open(conn)
|
||
|
return drv, conn, err
|
||
|
default:
|
||
|
return nil, nil, fmt.Errorf("unknown driver: %s", driverName)
|
||
|
}
|
||
|
}
|