diff --git a/docs/references/cli.md b/docs/references/cli.md index bd6fd496eaeab0cae479315baf113980dd366ad3..72aac7c8517c1dbcf8ff74231bfed12b307e9614 100644 --- a/docs/references/cli.md +++ b/docs/references/cli.md @@ -87,7 +87,12 @@ See pg_migrate --help for more informations. $ ``` -`pg_migrate init` refuses to reset an existing project. +`pg_migrate init` can re-initialize an existing project: +it updates the DSNs in `.env` and the metadata in `.pg_migrate/Info.json`. +When `pg_migrate.toml` exists, PostgreSQL Migrator leaves TOML, report template and jq files untouched. + +PostgreSQL Migrator refuses to initialize a non-empty directory +that contains neither `pg_migrate.toml` nor `.pg_migrate`. ## pg_migrate **inspect** diff --git a/internal/catalog/metadata.go b/internal/catalog/metadata.go index c128098aa88d4c3f1b140220fd2af33270c2ae68..7a84f8406265495b645812c5146dd991c0641ff0 100644 --- a/internal/catalog/metadata.go +++ b/internal/catalog/metadata.go @@ -19,6 +19,9 @@ type Metadata struct { // Version queries PostgreSQL server name and version func Version(ctx context.Context) (software, version string, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ctx, err = database.WithConnection(ctx, database.Target) if err != nil { return diff --git a/internal/cmd/init/cmd.go b/internal/cmd/init/cmd.go index bea18c4beae0e57d0d55f0acf3bd5b3004cb8708..f380186ff955a6f8495662ff2595ae9ec73d124d 100644 --- a/internal/cmd/init/cmd.go +++ b/internal/cmd/init/cmd.go @@ -1,20 +1,29 @@ package initc import ( + "errors" "fmt" "log/slog" "os" + "path/filepath" + "strings" + "github.com/knadh/koanf/parsers/dotenv" + "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/v2" "github.com/lithammer/dedent" "github.com/spf13/pflag" "gitlab.com/dalibo/pg_migrate/internal/cmd" + "gitlab.com/dalibo/pg_migrate/internal/database" "gitlab.com/dalibo/pg_migrate/internal/project" ) type Cmd struct { - path string - k *koanf.Koanf + k *koanf.Koanf + path string + hasDir bool + hasToml bool +
empty bool } func (Cmd) String() string { @@ -27,43 +36,50 @@ func (c *Cmd) Run(args ...string) error { return err } - err = c.initDir() + c.checkDir() + + if !c.empty && !c.hasDir && !c.hasToml { + slog.Error("Directory is not empty and not a PostgreSQL Migrator project.", "path", c.path) + return cmd.Exit(1) + } + + defer database.Close() + err = saveSource(c.k.String("source")) if err != nil { - return fmt.Errorf("dir: %w", err) + return fmt.Errorf("source: %w", err) } - err = c.initGitignore() + err = saveTarget(c.k.String("target")) if err != nil { - return fmt.Errorf(".gitignore: %w", err) + return fmt.Errorf("target: %w", err) } - err = c.initDotEnv() + err = ensureDir(c.path) // moves inside project if err != nil { - rmerr := c.removeDir() - if rmerr != nil { - slog.Warn("Failed to clean .pg_migrate directory.", "path", c.path, "err", rmerr) - } - return err + return fmt.Errorf("dir: %w", err) } - err = c.initInfo() + err = saveInfo() if err != nil { return fmt.Errorf("info: %w", err) } - err = c.initConfig() - if err != nil { - return fmt.Errorf("config: %w", err) + if c.hasToml { + slog.Info("Existing configuration preserved.") + } else { + err = writeNewConfig() + if err != nil { + return fmt.Errorf("config: %w", err) + } } - err = c.writeReports() + err = gitIgnoreDotEnv(!c.hasToml) if err != nil { - return fmt.Errorf("report: %w", err) + return fmt.Errorf(".gitignore: %w", err) } - project.Current.Path = c.path - err = project.Current.Read() + err = writeSecrets() if err != nil { - panic(fmt.Errorf("unreadable project: %w", err)) + return fmt.Errorf(".env: %w", err) } slog.Info( @@ -118,7 +134,46 @@ func (c *Cmd) parseFlags(args ...string) error { flags.Usage() return cmd.Exit(0) } + c.path, _ = filepath.Abs(flags.Arg(0)) c.k = cmd.Koanf(&flags) - c.path = flags.Arg(0) - return nil + + // Load .env from parameterized project path.
+ dotEnv := filepath.Join(c.path, ".env") + err = c.k.Load(file.Provider(dotEnv), dotenv.ParserEnv("PGM", c.k.Delim(), func(n string) string { + n = strings.TrimPrefix(n, "PGM") + n = strings.ToLower(n) + // Don't overwrite flags. + if c.k.String(n) != "" { + return "" + } + return n + })) + + if errors.Is(err, os.ErrNotExist) { + return nil // ignore missing .env + } + + return err +} + +func (c *Cmd) checkDir() { + _, err := os.Stat(filepath.Join(c.path, ".pg_migrate")) + c.hasDir = err == nil + + _, err = os.Stat(filepath.Join(c.path, "pg_migrate.toml")) + c.hasToml = err == nil + + c.empty = !c.hasDir && !c.hasToml + if !c.empty { + return + } + + f, err := os.Open(c.path) + if err != nil { + return + } + + defer f.Close() //nolint:errcheck + names, _ := f.Readdirnames(1) + c.empty = len(names) == 0 } diff --git a/internal/cmd/init/write.go b/internal/cmd/init/write.go index eef4fd87dab950a46d70828f4b15e3fd6bbdf696..dd01878bd07ec4aef519def457f0922cf6a6448f 100644 --- a/internal/cmd/init/write.go +++ b/internal/cmd/init/write.go @@ -1,14 +1,17 @@ package initc import ( + "bufio" "context" "embed" "encoding/json" "errors" "fmt" + "io" "log/slog" "os" "path/filepath" + "strings" "text/template" "time" @@ -19,49 +22,14 @@ import ( "gitlab.com/dalibo/pg_migrate/internal/project" ) -func (c *Cmd) initDir() error { - var err error - if c.path == "" { - c.path = "."
- } - c.path, err = filepath.Abs(c.path) - if err != nil { - return err - } - d := filepath.Join(c.path, ".pg_migrate") - _, err = os.Stat(d) - if err == nil { - return fmt.Errorf("project already exists") - } - slog.Debug("Creating directory.", "path", d) - err = os.MkdirAll(d, 0700) - if err != nil { - return fmt.Errorf("create: %w", err) - } - return nil -} - -func (c *Cmd) removeDir() error { - var err error - c.path, err = filepath.Abs(c.path) - if err != nil { - return err - } - d := filepath.Join(c.path, ".pg_migrate") - slog.Debug("Cleaning directory.", "path", d) - return os.RemoveAll(d) -} - -// initDotEnv checks DSN and stores them in .env -func (c *Cmd) initDotEnv() error { - dsn := c.k.String("source") +func saveSource(dsn string) error { if dsn == "" { - return errors.New("empty source") + return errors.New("empty DSN") } u, err := dburl.Parse(dsn) if err != nil { - return fmt.Errorf("source: %w", err) + return err } project.Current.Source = u project.Current.Info.Driver = u.Driver @@ -75,49 +43,54 @@ func (c *Cmd) initDotEnv() error { return fmt.Errorf("version: %w", err) } - dsn = c.k.String("target") - if dsn != "" { - u, err = dburl.Parse(dsn) - if err != nil { - return fmt.Errorf("target: %w", err) - } - u.Driver = "pgx" // Ensure we use pgx for PostgreSQL targets.
- project.Current.Target = u - project.Current.Info.Target.InitialDSN = u.Short() - project.Current.Info.Target.Software, project.Current.Info.Target.Version, err = catalog.Version(context.Background()) - if err != nil { - return fmt.Errorf("version: %w", err) - } - } else { + return nil +} + +func saveTarget(dsn string) error { + if dsn == "" { project.Current.Info.Target.Software = "PostgreSQL" project.Current.Info.Target.Version = "unknown" + return nil } - slog.Debug("Writing secrets.", "path", ".env") - f, err := os.Create(filepath.Join(c.path, ".env")) + u, err := dburl.Parse(dsn) if err != nil { return err } - defer f.Close() //nolint:errcheck - err = os.Chmod(f.Name(), 0600) + u.Driver = "pgx" // Ensure we use pgx for PostgreSQL targets. + project.Current.Target = u + project.Current.Info.Target.InitialDSN = u.Short() + project.Current.Info.Target.Software, project.Current.Info.Target.Version, err = catalog.Version(context.Background()) if err != nil { - return err - } - - _, err = fmt.Fprintf(f, "PGMSOURCE=%s\n", project.Current.Source) - - if project.Current.Target != nil { - _, err = fmt.Fprintf(f, "PGMTARGET=%s\n", project.Current.Target) + return fmt.Errorf("version: %w", err) } - - return err + return nil } type versionner interface { Version(context.Context) (string, string, error) } -func (c *Cmd) initInfo() error { +func ensureDir(path string) error { + var err error + if path == "" { + path = "."
+ } + d := filepath.Join(path, ".pg_migrate") + _, err = os.Stat(d) + if err == nil { + slog.Debug("Reusing directory.", "path", d) + return os.Chdir(path) + } + slog.Debug("Creating directory.", "path", d) + err = os.MkdirAll(d, 0700) + if err != nil { + return fmt.Errorf("create: %w", err) + } + return os.Chdir(path) +} + +func saveInfo() error { project.Current.Info.Created = time.Now() project.Current.Info.Creator = fmt.Sprintf("pg_migrate (%s)", cmd.Version) jsonData, err := json.Marshal(project.Current.Info) @@ -125,109 +98,157 @@ func (c *Cmd) initInfo() error { panic(err) } - return os.WriteFile(filepath.Join(c.path, ".pg_migrate/Info.json"), jsonData, 0644) - + p, _ := filepath.Abs(".pg_migrate/Info.json") + return os.WriteFile(p, jsonData, 0644) } -func (c *Cmd) initGitignore() error { - if _, err := os.Stat(filepath.Join(c.path, ".gitignore")); err == nil { - return fmt.Errorf(".gitignore already exists") - } - slog.Debug("Writing gitignore.", "path", ".gitignore") - f, err := os.Create(filepath.Join(c.path, ".gitignore")) +func writeSecrets() error { + slog.Debug("Writing secrets.", "path", ".env") + f, err := os.OpenFile(".env", os.O_CREATE|os.O_RDWR, 0600) if err != nil { return err } defer f.Close() //nolint:errcheck - _, err = f.WriteString(dedent.Dedent(` - # .env holds secrets - .env - # Generated files - *.json - report.md - `)[1:]) + // OpenFile only applies 0600 on creation: enforce it on a pre-existing .env too. + err = f.Chmod(0600) + if err != nil { + return err + } + + sourceDef := fmt.Sprintf("PGMSOURCE=%s", project.Current.Source) + targetDef := fmt.Sprintf("PGMTARGET=%s", project.Current.Target) + + // Walk lines and replace declaration.
+ var lines []string + var hasSource, hasTarget bool + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "PGMSOURCE=") { + slog.Debug("Updating source DSN definition.") + line = sourceDef + hasSource = true + } + if project.Current.Target != nil && strings.HasPrefix(line, "PGMTARGET=") { + slog.Debug("Updating target DSN definition.") + line = targetDef + hasTarget = true + } + lines = append(lines, line) + } + // Abort on read error (e.g. bufio.ErrTooLong): rewriting partial content would drop lines. + if err := scanner.Err(); err != nil { + return err + } + // Append missing declaration. + if !hasSource { + lines = append(lines, sourceDef) + } + if project.Current.Target != nil && !hasTarget { + lines = append(lines, targetDef) + } + + // Overwrite back file. + err = f.Truncate(0) + if err != nil { + return err + } + _, err = f.Seek(0, io.SeekStart) + if err != nil { + return err + } + _, err = f.WriteString(strings.Join(lines, "\n") + "\n") + return err } -//go:embed pg_migrate.toml -var defaultConfig string +func gitIgnoreDotEnv(ignoreReports bool) error { + if _, err := os.Stat(".gitignore"); err == nil { + slog.Debug("Preserve existing .gitignore.") + return nil + } -func (c *Cmd) initConfig() error { - slog.Debug("Writing config.", "path", "pg_migrate.toml") - f, err := os.Create(filepath.Join(c.path, "pg_migrate.toml")) + slog.Debug("Writing gitignore.", "path", ".gitignore") + f, err := os.Create(".gitignore") if err != nil { return err } defer f.Close() //nolint:errcheck - t, err := template.New("config").Parse(defaultConfig) + _, err = f.WriteString(dedent.Dedent(` + # .env holds secrets + .env + `)[1:]) if err != nil { - panic(err) // config is builtin.
+ return err } - return t.Execute(f, map[string]any{ - "Driver": project.Current.Info.Driver, - }) + if ignoreReports { + _, err = f.WriteString(dedent.Dedent(` + # Generated files + *.json + report.md + `)[1:]) + return err + } + return nil } +//go:embed pg_migrate.toml +var defaultToml string + //go:embed report.md.tmpl var defaultReportTemplate string //go:embed *.jq var jqFiles embed.FS -func (c *Cmd) writeReports() error { - slog.Debug("Writing report template.", "path", "report.md.tmpl") - f, err := os.Create(filepath.Join(c.path, "report.md.tmpl")) +func writeNewConfig() error { + slog.Debug("Writing config.", "path", "pg_migrate.toml") + err := render("pg_migrate.toml", defaultToml) if err != nil { - return err + return fmt.Errorf("pg_migrate.toml: %w", err) } - defer f.Close() //nolint:errcheck - _, err = f.WriteString(defaultReportTemplate) + slog.Debug("Writing report template.", "path", "report.md.tmpl") + err = os.WriteFile("report.md.tmpl", []byte(defaultReportTemplate), 0600) if err != nil { - return err + return fmt.Errorf("report.md.tmpl: %w", err) } files, err := jqFiles.ReadDir(".") if err != nil { - return err + panic(err) } for _, file := range files { - err = c.writeJq(file.Name()) + name := file.Name() + slog.Debug("Writing jq file.", "path", name) + content, err := jqFiles.ReadFile(name) if err != nil { - return err + panic(err) + } + + err = render(name, string(content)) + if err != nil { + return fmt.Errorf("%s: %w", name, err) } } return nil } -func (c *Cmd) writeJq(name string) error { - slog.Debug("Writing jq files.", "path", name) - f, err := os.Create(filepath.Join(c.path, name)) +func render(name, content string) error { + f, err := os.Create(name) if err != nil { return err } defer f.Close() //nolint:errcheck - content, err := jqFiles.ReadFile(name) - if err != nil { - return err - } - - t, err := template.New("jq").Parse(string(content)) - if err != nil { - return err - } - - err = t.Execute(f, map[string]any{ - "Driver":
project.Current.Info.Driver, - }) + t, err := template.New("root").Parse(content) if err != nil { return err } - return err + return t.Execute(f, project.Current.Info) } diff --git a/test/cli/mysql.bats b/test/cli/mysql.bats index 9aff3e121e1235037ec7c242728e44b9296e1a4a..d6b552b6facc0f2bac89d4926eed078826312bde 100644 --- a/test/cli/mysql.bats +++ b/test/cli/mysql.bats @@ -25,6 +25,8 @@ setup_file() { @test "init" { pg_migrate --verbose init --source "mysql://sakila:N0tSecret@${MYSQL_HOST-localhost}:3306/sakila" + # Reinit + pg_migrate --verbose init --source "mysql://sakila:N0tSecret@${MYSQL_HOST-localhost}/sakila" } @test "git" {