Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ripoff-export command #6

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
tmp
/ripoff
.DS_Store
/export
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ rows:

For more (sometimes wildly complex) examples, see `./testdata`.

## More on valueFuncs and row keys
## More on valueFuncs

valueFuncs allow you to generate random data that's seeded with a static string. This ensures that repeat runs of ripoff are deterministic, which enables upserts (consistent primary keys).

Expand Down Expand Up @@ -90,6 +90,18 @@ rows:
- `rowId` - The map key of the row using this template, ex `users:uuid(fooBar)`. Useful for allowing the "caller" to provide their own ID for the "main" row being created, if there is one. Optional to use if you find it awkward.
- `enums` - A map of SQL enums names to an array of enum values. Useful for creating one row for each value of an enum (ex: each user role).

# Export from your database to ripoff files

An experimental command has been added to generate ripoff files from your database. This may be useful to users just starting to use ripoff who don't yet have enough fake data to require templating.

Currently, it attempts to export all data from all tables into a single ripoff file. In the future flags may be added to allow you to include/exclude tables, add arbitrary `WHERE` conditions, modify the row id/key, export multiple files, or use existing templates.

## Installation

1. Run `go install github.com/mortenson/ripoff/cmd/ripoff-export@latest`
2. Set the `DATABASE_URL` env variable to your local PostgreSQL database
3. Run `ripoff-export <directory to be deleted and exported to>`

# Security

This project explicitly allows SQL injection due to the way queries are constructed. Do not run `ripoff` on directories you do not trust.
Expand Down
111 changes: 111 additions & 0 deletions cmd/ripoff-export/ripoff_export.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package main

import (
"bytes"
"context"
"fmt"
"log/slog"
"os"
"path"
"path/filepath"

"github.com/jackc/pgx/v5"
"gopkg.in/yaml.v3"

"github.com/mortenson/ripoff"
)

// errAttr wraps err in a structured log attribute keyed "error", for use
// with slog.Error calls throughout this command.
func errAttr(err error) slog.Attr {
	return slog.Attr{Key: "error", Value: slog.AnyValue(err)}
}

// main exports the PostgreSQL database identified by the DATABASE_URL env
// variable into a single ripoff YAML file ("ripoff.yml") inside the directory
// given as the sole command line argument. If the directory already exists it
// is deleted and re-created, but only after verifying it contains nothing but
// YAML files, so user data can't be destroyed by a mistyped path.
func main() {
	dburl := os.Getenv("DATABASE_URL")
	if dburl == "" {
		slog.Error("DATABASE_URL env variable is required")
		os.Exit(1)
	}

	if len(os.Args) != 2 {
		slog.Error("Path to export directory is required")
		os.Exit(1)
	}

	// Connect to database.
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, dburl)
	if err != nil {
		slog.Error("Could not connect to database", errAttr(err))
		os.Exit(1)
	}
	defer conn.Close(ctx)

	exportDirectory := path.Clean(os.Args[1])
	dirInfo, err := os.Stat(exportDirectory)
	if err != nil && !os.IsNotExist(err) {
		// A Stat failure other than "does not exist" (ex: permissions) would
		// previously be silently ignored; surface it instead.
		slog.Error("Could not stat export directory", errAttr(err))
		os.Exit(1)
	}
	if err == nil && !dirInfo.IsDir() {
		slog.Error("Export directory is not a directory")
		os.Exit(1)
	}

	// Directory exists, delete it after verifying that it's safe to do so.
	if err == nil {
		err = filepath.WalkDir(exportDirectory, func(path string, entry os.DirEntry, err error) error {
			if err != nil {
				return err
			}
			if !entry.IsDir() && filepath.Ext(path) != ".yaml" && filepath.Ext(path) != ".yml" {
				return fmt.Errorf("ripoff-export can only safely delete directories that only contain YAML files, found: %s", path)
			}
			return nil
		})
		if err != nil {
			slog.Error("Error verifying export directory", errAttr(err))
			os.Exit(1)
		}
		err = os.RemoveAll(exportDirectory)
		if err != nil {
			slog.Error("Could not delete export directory", errAttr(err))
			os.Exit(1)
		}
	}

	err = os.MkdirAll(exportDirectory, 0755)
	if err != nil {
		slog.Error("Could not re-create export directory", errAttr(err))
		os.Exit(1)
	}

	// A transaction is used so the export sees a consistent snapshot; it is
	// always rolled back since the export only reads.
	tx, err := conn.Begin(ctx)
	if err != nil {
		slog.Error("Could not create transaction", errAttr(err))
		os.Exit(1)
	}
	defer func() {
		err = tx.Rollback(ctx)
		if err != nil && err != pgx.ErrTxClosed {
			slog.Error("Could not rollback transaction", errAttr(err))
			os.Exit(1)
		}
	}()

	ripoffFile, err := ripoff.ExportToRipoff(ctx, tx)
	if err != nil {
		slog.Error("Could not assemble ripoff file from database", errAttr(err))
		os.Exit(1)
	}

	var ripoffFileBuf bytes.Buffer
	yamlEncoder := yaml.NewEncoder(&ripoffFileBuf)
	yamlEncoder.SetIndent(2)
	err = yamlEncoder.Encode(ripoffFile)
	if err != nil {
		slog.Error("Could not marshal yaml from ripoff file", errAttr(err))
		os.Exit(1)
	}

	err = os.WriteFile(path.Join(exportDirectory, "ripoff.yml"), ripoffFileBuf.Bytes(), 0644)
	if err != nil {
		slog.Error("Could not write ripoff file", errAttr(err))
		os.Exit(1)
	}
}
170 changes: 150 additions & 20 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"math/rand"
Expand Down Expand Up @@ -107,7 +108,7 @@ func GetEnumValues(ctx context.Context, tx pgx.Tx) (EnumValuesResult, error) {
}

var valueFuncRegex = regexp.MustCompile(`([a-zA-Z]+)\((.*)\)$`)
var referenceRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+:`)
var referenceRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+:[a-zA-Z]+\(`)

func prepareValue(rawValue string) (string, error) {
valueFuncMatches := valueFuncRegex.FindStringSubmatch(rawValue)
Expand Down Expand Up @@ -181,32 +182,60 @@ func buildQueryForRow(primaryKeys PrimaryKeysResult, rowId string, row Row, depe
if column == "~conflict" {
continue
}
// Explicit dependencies, for foreign keys to non-primary keys.
if column == "~dependencies" {
dependencies := []string{}
switch v := valueRaw.(type) {
// Coming from yaml
case []interface{}:
for _, curr := range v {
dependencies = append(dependencies, curr.(string))
}
// Coming from Go, probably a test
case []string:
dependencies = v
default:
return "", fmt.Errorf("cannot parse ~dependencies value in row %s", rowId)
}
for _, dependency := range dependencies {
err := dependencyGraph.AddEdge(rowId, dependency)
if isRealGraphError(err) {
return "", err
}
}
continue
}

// Technically we allow more than strings in ripoff files for templating purposes,
// Technically we allow more than null strings in ripoff files for templating purposes,
// but full support (ex: escaping arrays, what to do with maps, etc.) is quite hard so tabling that for now.
value := fmt.Sprint(valueRaw)
if valueRaw == nil {
values = append(values, "NULL")
setStatements = append(setStatements, fmt.Sprintf("%s = %s", pq.QuoteIdentifier(column), "NULL"))
} else {
value := fmt.Sprint(valueRaw)

// Assume that if a valueFunc is prefixed with a table name, it's a primary/foreign key.
addEdge := referenceRegex.MatchString(value)
// Don't add edges to and from the same row.
if addEdge && rowId != value {
err := dependencyGraph.AddEdge(rowId, value)
if isRealGraphError(err) {
return "", err
}
}

// Assume that if a valueFunc is prefixed with a table name, it's a primary/foreign key.
addEdge := referenceRegex.MatchString(value)
// Don't add edges to and from the same row.
if addEdge && rowId != value {
err := dependencyGraph.AddEdge(rowId, value)
columns = append(columns, pq.QuoteIdentifier(column))
valuePrepared, err := prepareValue(value)
if err != nil {
return "", err
}
// Assume this column is the primary key.
if rowId == value && onConflictColumn == "" {
onConflictColumn = pq.QuoteIdentifier(column)
}
values = append(values, pq.QuoteLiteral(valuePrepared))
setStatements = append(setStatements, fmt.Sprintf("%s = %s", pq.QuoteIdentifier(column), pq.QuoteLiteral(valuePrepared)))
}

columns = append(columns, pq.QuoteIdentifier(column))
valuePrepared, err := prepareValue(value)
if err != nil {
return "", err
}
// Assume this column is the primary key.
if rowId == value && onConflictColumn == "" {
onConflictColumn = pq.QuoteIdentifier(column)
}
values = append(values, pq.QuoteLiteral(valuePrepared))
setStatements = append(setStatements, fmt.Sprintf("%s = %s", pq.QuoteIdentifier(column), pq.QuoteLiteral(valuePrepared)))
}

if onConflictColumn == "" {
Expand Down Expand Up @@ -261,3 +290,104 @@ func buildQueriesForRipoff(primaryKeys PrimaryKeysResult, totalRipoff RipoffFile
}
return sortedQueries, nil
}

// columnsWithForeignKeysQuery returns one row per column in the public
// schema. When the column participates in a FOREIGN KEY constraint, the
// referenced table/column and the constraint name are included; otherwise
// those fields are coalesced to empty strings.
const columnsWithForeignKeysQuery = `
select col.table_name as table,
       col.column_name,
       COALESCE(rel.table_name, '') as primary_table,
       COALESCE(rel.column_name, '') as primary_column,
       COALESCE(kcu.constraint_name, '')
from information_schema.columns col
left join (select kcu.constraint_schema, 
                  kcu.constraint_name, 
                  kcu.table_schema,
                  kcu.table_name, 
                  kcu.column_name, 
                  kcu.ordinal_position,
                  kcu.position_in_unique_constraint
           from information_schema.key_column_usage kcu
           join information_schema.table_constraints tco
                on kcu.constraint_schema = tco.constraint_schema
                and kcu.constraint_name = tco.constraint_name
                and tco.constraint_type = 'FOREIGN KEY'
          ) as kcu
          on col.table_schema = kcu.table_schema
          and col.table_name = kcu.table_name
          and col.column_name = kcu.column_name
left join information_schema.referential_constraints rco
          on rco.constraint_name = kcu.constraint_name
          and rco.constraint_schema = kcu.table_schema
left join information_schema.key_column_usage rel
          on rco.unique_constraint_name = rel.constraint_name
          and rco.unique_constraint_schema = rel.constraint_schema
          and rel.ordinal_position = kcu.position_in_unique_constraint
where col.table_schema = 'public';
`

// ForeignKey describes one foreign key constraint from a source table to
// another table.
type ForeignKey struct {
	// ToTable is the referenced table.
	ToTable string
	// ColumnConditions pairs each referencing column with the referenced
	// column: [from_column, to_column].
	ColumnConditions [][2]string
}

// ForeignKeyResultTable holds the column list and foreign keys of one table.
type ForeignKeyResultTable struct {
	Columns []string
	// Constraint -> Fkey
	ForeignKeys map[string]*ForeignKey
}

// Map of table name to foreign keys.
type ForeignKeysResult map[string]*ForeignKeyResultTable

// getForeignKeysResult queries information_schema and returns, for every
// table in the public schema, its column names and its foreign keys grouped
// by constraint name.
func getForeignKeysResult(ctx context.Context, conn pgx.Tx) (ForeignKeysResult, error) {
	rows, err := conn.Query(ctx, columnsWithForeignKeysQuery)
	if err != nil {
		return ForeignKeysResult{}, err
	}
	defer rows.Close()

	result := ForeignKeysResult{}

	for rows.Next() {
		var fromTableName string
		var fromColumnName string
		var toTableName string
		var toColumnName string
		var constraintName string
		err = rows.Scan(&fromTableName, &fromColumnName, &toTableName, &toColumnName, &constraintName)
		if err != nil {
			return ForeignKeysResult{}, err
		}
		table, tableExists := result[fromTableName]
		if !tableExists {
			table = &ForeignKeyResultTable{
				Columns:     []string{},
				ForeignKeys: map[string]*ForeignKey{},
			}
			result[fromTableName] = table
		}
		table.Columns = append(table.Columns, fromColumnName)
		// Rows without a foreign key constraint have an empty constraint name.
		if constraintName != "" {
			fkey, fkeyExists := table.ForeignKeys[constraintName]
			if !fkeyExists {
				fkey = &ForeignKey{
					ToTable:          toTableName,
					ColumnConditions: [][2]string{},
				}
				table.ForeignKeys[constraintName] = fkey
			}
			if fromColumnName != "" && toColumnName != "" {
				fkey.ColumnConditions = append(fkey.ColumnConditions, [2]string{fromColumnName, toColumnName})
			}
		}
	}
	// rows.Next returns false on both exhaustion and error; without this
	// check, iteration errors (ex: dropped connection) were silently swallowed.
	if err := rows.Err(); err != nil {
		return ForeignKeysResult{}, err
	}

	return result, nil
}

// isRealGraphError reports whether err is a graph error worth propagating.
// A duplicate-edge error is treated as benign rather than as a failure.
func isRealGraphError(err error) bool {
	return err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists)
}
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestRipoff(t *testing.T) {
defer conn.Close(ctx)

_, filename, _, _ := runtime.Caller(0)
dir := path.Join(path.Dir(filename), "testdata")
dir := path.Join(path.Dir(filename), "testdata", "import")
dirEntry, err := os.ReadDir(dir)
require.NoError(t, err)

Expand Down
Loading