Skip to content

Commit

Permalink
add generated columns support (#20)
Browse files Browse the repository at this point in the history
* add generated columns support
  • Loading branch information
rolang authored Nov 20, 2024
1 parent c2f11e5 commit 83b7b0b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 43 deletions.
71 changes: 39 additions & 32 deletions modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class PgCodeGen(
val filterFragment: Fragment[Void] =
sql" AND table_name NOT IN (#${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})"

val q: Query[Void, String ~ String ~ String ~ Option[Int] ~ Option[Int] ~ Option[Int] ~ String ~ Option[String]] =
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default
val q =
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated
FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION
(SELECT
cls.relname AS table_name,
Expand All @@ -114,24 +114,27 @@ class PgCodeGen(
WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text
ELSE 'YES'::text
END::information_schema.yes_or_no AS is_nullable,
NULL AS column_default
NULL AS column_default,
'NEVER' AS is_generated
FROM pg_catalog.pg_attribute as attr
JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid
JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace
JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid
WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public'
ORDER by attr.attnum)
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt)
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt ~ varchar)

s.execute(q.map { case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default =>
(
tName,
colName,
toType(udt, maxCharLength, numPrecision, numScale),
nullable == "YES",
default.flatMap(ColumnDefault.fromString),
)
}).map(_.map { case (tName, colName, udt, isNullable, default) =>
s.execute(q.map {
case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default ~ is_generated =>
(
tName,
colName,
toType(udt, maxCharLength, numPrecision, numScale),
nullable == "YES",
default.flatMap(ColumnDefault.fromString),
is_generated == "ALWAYS",
)
}).map(_.map { case (tName, colName, udt, isNullable, default, isAlwaysGenerated) =>
toScalaType(udt, isNullable, enums).map { st =>
(
tName,
Expand All @@ -142,6 +145,7 @@ class PgCodeGen(
scalaType = st,
isNullable = isNullable,
default = default,
isAlwaysGenerated = isAlwaysGenerated,
),
)
}.leftMap(new Exception(_))
Expand Down Expand Up @@ -316,7 +320,7 @@ class PgCodeGen(

columns.toList.map { case (tname, tableCols) =>
val tableConstraints = constraints.getOrElse(tname, Nil)
val autoIncColumns = findAutoIncColumns(tname)
val generatedCols = findAutoIncColumns(tname) ::: tableCols.filter(_.isAlwaysGenerated)
val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap {
_.refs.flatMap { ref =>
tableCols.find(c => c.columnName == ref.fromColName).filter { _ =>
Expand All @@ -327,8 +331,8 @@ class PgCodeGen(

Table(
name = tname,
columns = tableCols.filterNot((autoIncColumns ::: autoIncFk).contains),
autoIncColumns = autoIncColumns,
columns = tableCols.filterNot((generatedCols ::: autoIncFk).contains),
generatedColumns = generatedCols,
constraints = tableConstraints,
indexes = indexes.getOrElse(tname, Nil),
autoIncFk = autoIncFk,
Expand Down Expand Up @@ -625,7 +629,7 @@ class PgCodeGen(
if (autoIncFk.isEmpty) {
(rowClassName, s"${rowClassName}.codec")
} else {
val autoIncFkCodecs = autoIncFk.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")
val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ")
val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ")
(s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec")
}
Expand All @@ -640,18 +644,20 @@ class PgCodeGen(
val allColNames = allCols.map(_.columnName).mkString(",")
val (insertScalaType, insertCodec) = queryTypesStr(table)

val returningStatement = autoIncColumns match {
val returningStatement = generatedColumns match {
case Nil => ""
case _ => autoIncColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
}
val returningType = autoIncColumns.map(_.scalaType).mkString(" *: ")
val fragmentType = autoIncColumns match {
val returningType = generatedColumns
.map(_.scalaType)
.mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "")
val fragmentType = generatedColumns match {
case Nil => "command"
case _ => s"query(${autoIncColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})"
case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})"
}

val upsertQ = primaryUniqueConstraint.map { cstr =>
val queryType = autoIncColumns match {
val queryType = generatedColumns match {
case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]"
case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]"
}
Expand All @@ -662,7 +668,7 @@ class PgCodeGen(
| DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin
}

val queryType = autoIncColumns match {
val queryType = generatedColumns match {
case Nil => s"Command[$insertScalaType]"
case _ => s"Query[$insertScalaType, $returningType]"
}
Expand Down Expand Up @@ -693,7 +699,7 @@ class PgCodeGen(
}

private def tableColumns(table: Table): (Option[String], String) = {
val allCols = table.autoIncColumns ::: table.autoIncFk ::: table.columns
val allCols = table.generatedColumns ::: table.autoIncFk ::: table.columns
val cols =
allCols.map(column =>
s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)"""
Expand All @@ -714,13 +720,13 @@ class PgCodeGen(
private def selectAllStatement(table: Table): String = {
import table.*

val autoIncStm = if (autoIncColumns.nonEmpty) {
val types = autoIncColumns.map(_.codecName).mkString(" *: ")
val sTypes = autoIncColumns.map(_.scalaType).mkString(" *: ")
val colNamesStr = (autoIncColumns ::: columns).map(_.columnName).mkString(", ")
val generatedColStm = if (generatedColumns.nonEmpty) {
val types = generatedColumns.map(_.codecName).mkString(" *: ")
val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ")
val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ")

s"""
| def selectAllWithId[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
| def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
| sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec)
|
""".stripMargin
Expand All @@ -740,7 +746,7 @@ class PgCodeGen(
val selectCol = s"""| def select[A, B](cols: Cols[A], rest: Fragment[B] = Fragment.empty): Query[B, A] =
| sql"SELECT #$${cols.name} FROM #$$tableName $$rest".query(cols.codec)
|""".stripMargin
autoIncStm ++ defaultStm ++ selectCol
generatedColStm ++ defaultStm ++ selectCol
}

private def lastModified(modified: List[Long]): Option[Long] =
Expand Down Expand Up @@ -806,6 +812,7 @@ object PgCodeGen {
isEnum: Boolean,
isNullable: Boolean,
default: Option[ColumnDefault],
isAlwaysGenerated: Boolean,
) {
val scalaName: String = toScalaName(columnName)
val snakeCaseScalaName: String = escapeScalaKeywords(columnName)
Expand Down Expand Up @@ -849,7 +856,7 @@ object PgCodeGen {
final case class Table(
name: String,
columns: List[Column],
autoIncColumns: List[Column],
generatedColumns: List[Column],
constraints: List[Constraint],
indexes: List[Index],
autoIncFk: List[Column],
Expand Down
4 changes: 3 additions & 1 deletion modules/core/src/test/resources/db/migration/V1__test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ CREATE TABLE test (
tla_var varchar(3) NOT NULL,
numeric_default numeric NOT NULL,
numeric_24p numeric(24) NOT NULL,
numeric_16p_2s numeric(16, 2) NOT NULL
numeric_16p_2s numeric(16, 2) NOT NULL,
gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED,
gen_opt INT GENERATED ALWAYS AS (1 + 1) STORED
);

CREATE TABLE test_ref_only (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import better.files.File
import skunk.*
import skunk.codec.all.*
import skunk.util.Origin
import skunk.{Command as SqlCommand}
import skunk.Command as SqlCommand
import cats.implicits.*
import java.time.temporal.ChronoUnit
import java.time.ZoneOffset
Expand Down Expand Up @@ -44,7 +44,7 @@ object GeneratedCodeTest extends IOApp {
tlaVar = "abc",
numericDefault = BigDecimal(1),
numeric24p = BigDecimal(2),
numeric16p2s = BigDecimal(3)
numeric16p2s = BigDecimal(3),
).withUpdateAll

val testBRow = TestBRow(
Expand Down Expand Up @@ -83,20 +83,24 @@ object GeneratedCodeTest extends IOApp {
// Test table
p <- s.prepare(TestTable.upsertQuery(testUpdateFr))
_ <- s.prepare(TestTable.insertQuery(ignoreConflict = true))
id <- p.option((testRow, testUpdateFr.argument))
_ <- IO.raiseWhen(id.isEmpty)(new Throwable("test A did not return a generated id"))
res <- p.option((testRow, testUpdateFr.argument))
_ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns"))
id = res.get._1
_ <- IO.raiseWhen(res.get._2 != 2 && res.get._3 != Some(2))(
new Throwable("unexpected result for generated columns")
)
all <- s.execute(TestTable.selectAll())
allWithId <- s.execute(TestTable.selectAllWithId())
allWithGen <- s.execute(TestTable.selectAllWithGenerated())
_ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal"))
_ <- IO.raiseWhen(allWithId.map(_._2) != List(testRow))(new Throwable("test A result with id not equal"))
_ <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal"))
aliasedTestTable = TestTable.withAlias("t")
idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2
xs <-
s.execute(
sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}"""
.query(idAndName2.codec ~ TestTable.column.name.codec)
)
_ <- IO.raiseWhen(xs != List((id.get, testRow.name2) -> testRow.name))(
_ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))(
new Throwable("test A select fields not equal")
)
all2 <- s.execute(TestTable.select(TestTable.all))
Expand Down Expand Up @@ -179,10 +183,10 @@ object GeneratedCodeTest extends IOApp {
_ <- IO.raiseWhen(
Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById
)(new Throwable("test B result missing update"))
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
result <- s.execute(TestMaterializedViewTable.selectAll())
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
_ <- IO.println("Test successful!")
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
_ <- IO.println("Test successful!")
} yield ()

}
Expand Down

0 comments on commit 83b7b0b

Please sign in to comment.