Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add generated columns support #20

Merged
merged 4 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class PgCodeGen(
val filterFragment: Fragment[Void] =
sql" AND table_name NOT IN (#${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})"

val q: Query[Void, String ~ String ~ String ~ Option[Int] ~ Option[Int] ~ Option[Int] ~ String ~ Option[String]] =
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default
val q =
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated
FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION
(SELECT
cls.relname AS table_name,
Expand All @@ -114,24 +114,27 @@ class PgCodeGen(
WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text
ELSE 'YES'::text
END::information_schema.yes_or_no AS is_nullable,
NULL AS column_default
NULL AS column_default,
'NEVER' AS is_generated
FROM pg_catalog.pg_attribute as attr
JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid
JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace
JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid
WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public'
ORDER by attr.attnum)
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt)
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt ~ varchar)

s.execute(q.map { case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default =>
(
tName,
colName,
toType(udt, maxCharLength, numPrecision, numScale),
nullable == "YES",
default.flatMap(ColumnDefault.fromString),
)
}).map(_.map { case (tName, colName, udt, isNullable, default) =>
s.execute(q.map {
case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default ~ is_generated =>
(
tName,
colName,
toType(udt, maxCharLength, numPrecision, numScale),
nullable == "YES",
default.flatMap(ColumnDefault.fromString),
is_generated == "ALWAYS",
)
}).map(_.map { case (tName, colName, udt, isNullable, default, isAlwaysGenerated) =>
toScalaType(udt, isNullable, enums).map { st =>
(
tName,
Expand All @@ -142,6 +145,7 @@ class PgCodeGen(
scalaType = st,
isNullable = isNullable,
default = default,
isAlwaysGenerated = isAlwaysGenerated,
),
)
}.leftMap(new Exception(_))
Expand Down Expand Up @@ -316,7 +320,7 @@ class PgCodeGen(

columns.toList.map { case (tname, tableCols) =>
val tableConstraints = constraints.getOrElse(tname, Nil)
val autoIncColumns = findAutoIncColumns(tname)
val generatedCols = findAutoIncColumns(tname) ::: tableCols.filter(_.isAlwaysGenerated)
val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap {
_.refs.flatMap { ref =>
tableCols.find(c => c.columnName == ref.fromColName).filter { _ =>
Expand All @@ -327,8 +331,8 @@ class PgCodeGen(

Table(
name = tname,
columns = tableCols.filterNot((autoIncColumns ::: autoIncFk).contains),
autoIncColumns = autoIncColumns,
columns = tableCols.filterNot((generatedCols ::: autoIncFk).contains),
generatedColumns = generatedCols,
constraints = tableConstraints,
indexes = indexes.getOrElse(tname, Nil),
autoIncFk = autoIncFk,
Expand Down Expand Up @@ -625,7 +629,7 @@ class PgCodeGen(
if (autoIncFk.isEmpty) {
(rowClassName, s"${rowClassName}.codec")
} else {
val autoIncFkCodecs = autoIncFk.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")
val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ")
val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ")
(s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec")
}
Expand All @@ -640,18 +644,20 @@ class PgCodeGen(
val allColNames = allCols.map(_.columnName).mkString(",")
val (insertScalaType, insertCodec) = queryTypesStr(table)

val returningStatement = autoIncColumns match {
val returningStatement = generatedColumns match {
case Nil => ""
case _ => autoIncColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
}
val returningType = autoIncColumns.map(_.scalaType).mkString(" *: ")
val fragmentType = autoIncColumns match {
val returningType = generatedColumns
.map(_.scalaType)
.mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "")
val fragmentType = generatedColumns match {
case Nil => "command"
case _ => s"query(${autoIncColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})"
case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})"
}

val upsertQ = primaryUniqueConstraint.map { cstr =>
val queryType = autoIncColumns match {
val queryType = generatedColumns match {
case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]"
case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]"
}
Expand All @@ -662,7 +668,7 @@ class PgCodeGen(
| DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin
}

val queryType = autoIncColumns match {
val queryType = generatedColumns match {
case Nil => s"Command[$insertScalaType]"
case _ => s"Query[$insertScalaType, $returningType]"
}
Expand Down Expand Up @@ -693,7 +699,7 @@ class PgCodeGen(
}

private def tableColumns(table: Table): (Option[String], String) = {
val allCols = table.autoIncColumns ::: table.autoIncFk ::: table.columns
val allCols = table.generatedColumns ::: table.autoIncFk ::: table.columns
val cols =
allCols.map(column =>
s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)"""
Expand All @@ -714,13 +720,13 @@ class PgCodeGen(
private def selectAllStatement(table: Table): String = {
import table.*

val autoIncStm = if (autoIncColumns.nonEmpty) {
val types = autoIncColumns.map(_.codecName).mkString(" *: ")
val sTypes = autoIncColumns.map(_.scalaType).mkString(" *: ")
val colNamesStr = (autoIncColumns ::: columns).map(_.columnName).mkString(", ")
val autoIncStm = if (generatedColumns.nonEmpty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think u forgot to rename autoIncStm

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, updated!

val types = generatedColumns.map(_.codecName).mkString(" *: ")
val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ")
val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ")

s"""
| def selectAllWithId[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
| def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
| sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec)
|
""".stripMargin
Expand Down Expand Up @@ -806,6 +812,7 @@ object PgCodeGen {
isEnum: Boolean,
isNullable: Boolean,
default: Option[ColumnDefault],
isAlwaysGenerated: Boolean,
) {
val scalaName: String = toScalaName(columnName)
val snakeCaseScalaName: String = escapeScalaKeywords(columnName)
Expand Down Expand Up @@ -849,7 +856,7 @@ object PgCodeGen {
final case class Table(
name: String,
columns: List[Column],
autoIncColumns: List[Column],
generatedColumns: List[Column],
constraints: List[Constraint],
indexes: List[Index],
autoIncFk: List[Column],
Expand Down
4 changes: 3 additions & 1 deletion modules/core/src/test/resources/db/migration/V1__test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ CREATE TABLE test (
tla_var varchar(3) NOT NULL,
numeric_default numeric NOT NULL,
numeric_24p numeric(24) NOT NULL,
numeric_16p_2s numeric(16, 2) NOT NULL
numeric_16p_2s numeric(16, 2) NOT NULL,
gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED,
gen_opt INT GENERATED ALWAYS AS (1 + 1) STORED
);

CREATE TABLE test_ref_only (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import better.files.File
import skunk.*
import skunk.codec.all.*
import skunk.util.Origin
import skunk.{Command as SqlCommand}
import skunk.Command as SqlCommand
import cats.implicits.*
import java.time.temporal.ChronoUnit
import java.time.ZoneOffset
Expand Down Expand Up @@ -44,7 +44,7 @@ object GeneratedCodeTest extends IOApp {
tlaVar = "abc",
numericDefault = BigDecimal(1),
numeric24p = BigDecimal(2),
numeric16p2s = BigDecimal(3)
numeric16p2s = BigDecimal(3),
).withUpdateAll

val testBRow = TestBRow(
Expand Down Expand Up @@ -83,20 +83,24 @@ object GeneratedCodeTest extends IOApp {
// Test table
p <- s.prepare(TestTable.upsertQuery(testUpdateFr))
_ <- s.prepare(TestTable.insertQuery(ignoreConflict = true))
id <- p.option((testRow, testUpdateFr.argument))
_ <- IO.raiseWhen(id.isEmpty)(new Throwable("test A did not return a generated id"))
res <- p.option((testRow, testUpdateFr.argument))
_ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns"))
id = res.get._1
_ <- IO.raiseWhen(res.get._2 != 2 && res.get._3 != Some(2))(
new Throwable("unexpected result for generated columns")
)
all <- s.execute(TestTable.selectAll())
allWithId <- s.execute(TestTable.selectAllWithId())
allWithGen <- s.execute(TestTable.selectAllWithGenerated())
_ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal"))
_ <- IO.raiseWhen(allWithId.map(_._2) != List(testRow))(new Throwable("test A result with id not equal"))
_ <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal"))
aliasedTestTable = TestTable.withAlias("t")
idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2
xs <-
s.execute(
sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}"""
.query(idAndName2.codec ~ TestTable.column.name.codec)
)
_ <- IO.raiseWhen(xs != List((id.get, testRow.name2) -> testRow.name))(
_ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))(
new Throwable("test A select fields not equal")
)
all2 <- s.execute(TestTable.select(TestTable.all))
Expand Down Expand Up @@ -179,10 +183,10 @@ object GeneratedCodeTest extends IOApp {
_ <- IO.raiseWhen(
Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById
)(new Throwable("test B result missing update"))
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
result <- s.execute(TestMaterializedViewTable.selectAll())
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
_ <- IO.println("Test successful!")
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
_ <- IO.println("Test successful!")
} yield ()

}
Expand Down