searcherside/infrastructure/db/migrate/atlas_migrator.go

81 lines
1.6 KiB
Go
Raw Permalink Normal View History

package migrate
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"code.icb4dc0.de/prskr/searcherside/core/ports"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/postgres"
"ariga.io/atlas/sql/sqlite"
)
// Compile-time check that *AtlasMigrator implements ports.Migrator.
var _ ports.Migrator = (*AtlasMigrator)(nil)
// AtlasMigrator applies versioned SQL migrations using ariga.io/atlas.
// Migration files are read from the sub-directory of MigrationsFS whose name
// is the target driver's String() value (one directory per SQL dialect).
type AtlasMigrator struct {
// MigrationsFS contains one sub-directory of migration files per dialect.
MigrationsFS fs.FS
// RevisionRW is handed to migrate.NewExecutor as the revision store Atlas
// uses to track which migrations have been applied.
RevisionRW ports.RevisionReadWriter
}
// Migrate applies all pending migrations for req.Driver to the database at
// req.URL.
//
// Migration files are taken from the sub-directory of MigrationsFS named after
// the driver (req.Driver.String()); applied revisions are tracked through
// RevisionRW. A database that is already up to date is not an error. Any error
// from closing the database connection is joined into the returned error.
func (a AtlasMigrator) Migrate(ctx context.Context, req ports.MigrationRequest) (err error) {
	dialect := req.Driver.String()

	// fs.Sub only validates the path syntax; it does not check that the
	// directory exists. Stat it explicitly so a missing dialect directory is
	// reported instead of silently yielding zero migrations.
	info, err := fs.Stat(a.MigrationsFS, dialect)
	if err == nil && !info.IsDir() {
		err = fmt.Errorf("%s is not a directory", dialect)
	}
	if err != nil {
		return fmt.Errorf("no migrations sub-directory found for dialect %s: %w", req.Driver, err)
	}

	dialectFS, err := fs.Sub(a.MigrationsFS, dialect)
	if err != nil {
		return fmt.Errorf("no migrations sub-directory found for dialect %s: %w", req.Driver, err)
	}

	migrateDriver, conn, err := migrationDriverFor(req.Driver, req.URL)
	if err != nil {
		return err
	}
	defer func() {
		err = errors.Join(err, conn.Close())
	}()

	executor, err := migrate.NewExecutor(migrateDriver, readOnlyFSDir{FS: dialectFS}, a.RevisionRW)
	if err != nil {
		return fmt.Errorf("creating migration executor: %w", err)
	}

	pendingFiles, err := executor.Pending(ctx)
	if errors.Is(err, migrate.ErrNoPendingFiles) {
		// Atlas signals "database already up to date" with this sentinel;
		// for a migrator that is success, not failure.
		return nil
	}
	if err != nil {
		return fmt.Errorf("determining pending migrations: %w", err)
	}

	for _, file := range pendingFiles {
		if err = executor.Execute(ctx, file); err != nil {
			return fmt.Errorf("applying migration %s: %w", file.Name(), err)
		}
	}
	return nil
}
// migrationDriverFor opens a database handle for driverName/url and wraps it
// in the matching Atlas migrate.Driver.
//
// Only ports.DriverSQLite and ports.DriverPostgres are supported; any other
// driver is rejected before a connection is opened. On success the caller owns
// the returned *sql.DB and must close it. On failure both returned values are
// nil and any handle opened here has already been closed.
func migrationDriverFor(driverName ports.Driver, url string) (migrate.Driver, *sql.DB, error) {
	// Pick the Atlas opener first so unsupported drivers fail with a clear
	// message instead of whatever sql.Open reports for an unknown name.
	var openAtlas func(*sql.DB) (migrate.Driver, error)
	switch driverName {
	case ports.DriverSQLite:
		openAtlas = func(db *sql.DB) (migrate.Driver, error) { return sqlite.Open(db) }
	case ports.DriverPostgres:
		openAtlas = func(db *sql.DB) (migrate.Driver, error) { return postgres.Open(db) }
	default:
		return nil, nil, fmt.Errorf("unknown driver: %s", driverName)
	}

	conn, err := sql.Open(driverName.String(), url)
	if err != nil {
		return nil, nil, fmt.Errorf("opening %s database: %w", driverName, err)
	}

	drv, err := openAtlas(conn)
	if err != nil {
		// Release the handle so a failed setup does not leak it, and never
		// hand an already-closed *sql.DB back to the caller.
		return nil, nil, errors.Join(
			fmt.Errorf("initializing atlas %s driver: %w", driverName, err),
			conn.Close(),
		)
	}
	return drv, conn, nil
}