From c59a74bcfc5bd60abd5b2d02718434472a73a252 Mon Sep 17 00:00:00 2001 From: Roman Langolf Date: Tue, 19 Nov 2024 14:35:14 +0700 Subject: [PATCH 1/4] add generated columns support --- .../scala/com/anymindgroup/PgCodeGen.scala | 65 ++++++++++--------- .../test/resources/db/migration/V1__test.sql | 3 +- .../com/anymindgroup/GeneratedCodeTest._scala | 18 ++--- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala index 1dc804e..1086ffd 100644 --- a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala +++ b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala @@ -97,8 +97,8 @@ class PgCodeGen( val filterFragment: Fragment[Void] = sql" AND table_name NOT IN (#${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})" - val q: Query[Void, String ~ String ~ String ~ Option[Int] ~ Option[Int] ~ Option[Int] ~ String ~ Option[String]] = - sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default + val q = + sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION (SELECT cls.relname AS table_name, @@ -114,24 +114,27 @@ class PgCodeGen( WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text ELSE 'YES'::text END::information_schema.yes_or_no AS is_nullable, - NULL AS column_default + NULL AS column_default, + 'NEVER' AS is_generated FROM pg_catalog.pg_attribute as attr JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public' ORDER by attr.attnum) - """.query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt) + """.query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt ~ varchar) - s.execute(q.map { case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default => - ( - tName, - colName, - toType(udt, maxCharLength, numPrecision, numScale), - nullable == "YES", - default.flatMap(ColumnDefault.fromString), - ) - }).map(_.map { case (tName, colName, udt, isNullable, default) => + s.execute(q.map { + case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default ~ is_generated => + ( + tName, + colName, + toType(udt, maxCharLength, numPrecision, numScale), + nullable == "YES", + default.flatMap(ColumnDefault.fromString), + is_generated == "ALWAYS", + ) + }).map(_.map { case (tName, colName, udt, isNullable, default, isAlwaysGenerated) => toScalaType(udt, isNullable, enums).map { st => ( tName, @@ -142,6 +145,7 @@ class PgCodeGen( scalaType = st, isNullable = isNullable, default = default, + isAlwaysGenerated = isAlwaysGenerated, ), ) }.leftMap(new Exception(_)) @@ -316,7 +320,7 @@ class PgCodeGen( columns.toList.map { case (tname, tableCols) => val tableConstraints = constraints.getOrElse(tname, Nil) - val autoIncColumns = findAutoIncColumns(tname) + val generatedCols = findAutoIncColumns(tname) ::: tableCols.filter(_.isAlwaysGenerated) val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap { _.refs.flatMap { ref => tableCols.find(c => c.columnName == ref.fromColName).filter { _ => @@ -327,8 +331,8 @@ class PgCodeGen( Table( name = tname, - columns = tableCols.filterNot((autoIncColumns ::: autoIncFk).contains), - autoIncColumns = autoIncColumns, + columns = tableCols.filterNot((generatedCols ::: autoIncFk).contains), + generatedColumns = generatedCols, constraints = tableConstraints, indexes = indexes.getOrElse(tname, Nil), autoIncFk = autoIncFk, @@ -640,18 +644,20 @@ class PgCodeGen( val allColNames = allCols.map(_.columnName).mkString(",") val (insertScalaType, insertCodec) = queryTypesStr(table) - val returningStatement = autoIncColumns match { + val returningStatement = generatedColumns match { case Nil => "" - case _ => autoIncColumns.map(_.columnName).mkString(" RETURNING ", ",", "") + case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "") } - val returningType = autoIncColumns.map(_.scalaType).mkString(" *: ") - val fragmentType = autoIncColumns match { + val returningType = generatedColumns + .map(_.scalaType) + .mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "") + val fragmentType = generatedColumns match { case Nil => "command" - case _ => s"query(${autoIncColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})" + case _ => s"query(${generatedColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})" } val upsertQ = primaryUniqueConstraint.map { cstr => - val queryType = autoIncColumns match { + val queryType = generatedColumns match { case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]" case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]" } @@ -662,7 +668,7 @@ class PgCodeGen( | DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin } - val queryType = autoIncColumns match { + val queryType = generatedColumns match { case Nil => s"Command[$insertScalaType]" case _ => s"Query[$insertScalaType, $returningType]" } @@ -693,7 +699,7 @@ class PgCodeGen( } private def tableColumns(table: Table): (Option[String], String) = { - val allCols = table.autoIncColumns ::: table.autoIncFk ::: table.columns + val allCols = table.generatedColumns ::: table.autoIncFk ::: table.columns val cols = allCols.map(column => s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)""" @@ -714,10 +720,10 @@ class PgCodeGen( private def selectAllStatement(table: Table): String = { import table.* - val autoIncStm = if (autoIncColumns.nonEmpty) { - val types = autoIncColumns.map(_.codecName).mkString(" *: ") - val sTypes = autoIncColumns.map(_.scalaType).mkString(" *: ") - val colNamesStr = (autoIncColumns ::: columns).map(_.columnName).mkString(", ") + val autoIncStm = if (generatedColumns.nonEmpty) { + val types = generatedColumns.map(_.codecName).mkString(" *: ") + val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ") + val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ") s""" | def selectAllWithId[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] = @@ -806,6 +812,7 @@ object PgCodeGen { isEnum: Boolean, isNullable: Boolean, default: Option[ColumnDefault], + isAlwaysGenerated: Boolean, ) { val scalaName: String = toScalaName(columnName) val snakeCaseScalaName: String = escapeScalaKeywords(columnName) @@ -849,7 +856,7 @@ object PgCodeGen { final case class Table( name: String, columns: List[Column], - autoIncColumns: List[Column], + generatedColumns: List[Column], constraints: List[Constraint], indexes: List[Index], autoIncFk: List[Column], diff --git a/modules/core/src/test/resources/db/migration/V1__test.sql b/modules/core/src/test/resources/db/migration/V1__test.sql index c14fa79..6707756 100644 --- a/modules/core/src/test/resources/db/migration/V1__test.sql +++ b/modules/core/src/test/resources/db/migration/V1__test.sql @@ -14,7 +14,8 @@ CREATE TABLE test ( tla_var varchar(3) NOT NULL, numeric_default numeric NOT NULL, numeric_24p numeric(24) NOT NULL, - numeric_16p_2s numeric(16, 2) NOT NULL + numeric_16p_2s numeric(16, 2) NOT NULL, + gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED ); CREATE TABLE test_ref_only ( diff --git a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala index 6752886..fb91c16 100644 --- a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala +++ b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala @@ -16,7 +16,7 @@ import better.files.File import skunk.* import skunk.codec.all.* import skunk.util.Origin -import skunk.{Command as SqlCommand} +import skunk.Command as SqlCommand import cats.implicits.* import java.time.temporal.ChronoUnit import java.time.ZoneOffset @@ -44,7 +44,7 @@ object GeneratedCodeTest extends IOApp { tlaVar = "abc", numericDefault = BigDecimal(1), numeric24p = BigDecimal(2), - numeric16p2s = BigDecimal(3) + numeric16p2s = BigDecimal(3), ).withUpdateAll val testBRow = TestBRow( @@ -83,8 +83,10 @@ object GeneratedCodeTest extends IOApp { // Test table p <- s.prepare(TestTable.upsertQuery(testUpdateFr)) _ <- s.prepare(TestTable.insertQuery(ignoreConflict = true)) - id <- p.option((testRow, testUpdateFr.argument)) - _ <- IO.raiseWhen(id.isEmpty)(new Throwable("test A did not return a generated id")) + res <- p.option((testRow, testUpdateFr.argument)) + _ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns")) + id = res.get.head + _ <- IO.raiseWhen(res.get.tail.head != 2)(new Throwable("unexpected result for generated column")) all <- s.execute(TestTable.selectAll()) allWithId <- s.execute(TestTable.selectAllWithId()) _ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal")) @@ -96,7 +98,7 @@ object GeneratedCodeTest extends IOApp { sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}""" .query(idAndName2.codec ~ TestTable.column.name.codec) ) - _ <- IO.raiseWhen(xs != List((id.get, testRow.name2) -> testRow.name))( + _ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))( new Throwable("test A select fields not equal") ) all2 <- s.execute(TestTable.select(TestTable.all)) @@ -179,10 +181,10 @@ object GeneratedCodeTest extends IOApp { _ <- IO.raiseWhen( Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById )(new Throwable("test B result missing update")) - _ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command) + _ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command) result <- s.execute(TestMaterializedViewTable.selectAll()) - _ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}")) - _ <- IO.println("Test successful!") + _ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}")) + _ <- IO.println("Test successful!") } yield () } From 2c65aed71efc0eea317403ac4d0b4b1634a696bb Mon Sep 17 00:00:00 2001 From: Roman Langolf Date: Tue, 19 Nov 2024 14:51:38 +0700 Subject: [PATCH 2/4] update method name and test --- .../core/src/main/scala/com/anymindgroup/PgCodeGen.scala | 2 +- .../test/scala/com/anymindgroup/GeneratedCodeTest._scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala index 1086ffd..104d28e 100644 --- a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala +++ b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala @@ -726,7 +726,7 @@ class PgCodeGen( val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ") s""" - | def selectAllWithId[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] = + | def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] = | sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec) | """.stripMargin diff --git a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala index fb91c16..1fc0063 100644 --- a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala +++ b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala @@ -85,12 +85,12 @@ object GeneratedCodeTest extends IOApp { _ <- s.prepare(TestTable.insertQuery(ignoreConflict = true)) res <- p.option((testRow, testUpdateFr.argument)) _ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns")) - id = res.get.head - _ <- IO.raiseWhen(res.get.tail.head != 2)(new Throwable("unexpected result for generated column")) + id = res.get._1 + _ <- IO.raiseWhen(res.get._2 != 2)(new Throwable("unexpected result for generated column")) all <- s.execute(TestTable.selectAll()) - allWithId <- s.execute(TestTable.selectAllWithId()) + allWithGen <- s.execute(TestTable.selectAllWithGenerated()) _ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal")) - _ <- IO.raiseWhen(allWithId.map(_._2) != List(testRow))(new Throwable("test A result with id not equal")) + _ <- IO.raiseWhen(allWithGen.map(_._3) != List(testRow))(new Throwable("test A result with id not equal")) aliasedTestTable = TestTable.withAlias("t") idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2 xs <- From e6c6de51638ae933222bf869f3a747b5442ee69e Mon Sep 17 00:00:00 2001 From: Roman Langolf Date: Tue, 19 Nov 2024 18:30:05 +0700 Subject: [PATCH 3/4] fix codec type for optional values --- .../core/src/main/scala/com/anymindgroup/PgCodeGen.scala | 4 ++-- modules/core/src/test/resources/db/migration/V1__test.sql | 3 ++- .../test/scala/com/anymindgroup/GeneratedCodeTest._scala | 8 +++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala index 104d28e..6fb4171 100644 --- a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala +++ b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala @@ -629,7 +629,7 @@ class PgCodeGen( if (autoIncFk.isEmpty) { (rowClassName, s"${rowClassName}.codec") } else { - val autoIncFkCodecs = autoIncFk.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ") + val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ") val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ") (s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec") } @@ -653,7 +653,7 @@ class PgCodeGen( .mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "") val fragmentType = generatedColumns match { case Nil => "command" - case _ => s"query(${generatedColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})" + case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})" } val upsertQ = primaryUniqueConstraint.map { cstr => diff --git a/modules/core/src/test/resources/db/migration/V1__test.sql b/modules/core/src/test/resources/db/migration/V1__test.sql index 6707756..72c2bfd 100644 --- a/modules/core/src/test/resources/db/migration/V1__test.sql +++ b/modules/core/src/test/resources/db/migration/V1__test.sql @@ -15,7 +15,8 @@ CREATE TABLE test ( numeric_default numeric NOT NULL, numeric_24p numeric(24) NOT NULL, numeric_16p_2s numeric(16, 2) NOT NULL, - gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED + gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED, + gen_opt INT GENERATED ALWAYS AS (1 + 1) STORED ); CREATE TABLE test_ref_only ( diff --git a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala index 1fc0063..f5ad61a 100644 --- a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala +++ b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala @@ -85,12 +85,14 @@ object GeneratedCodeTest extends IOApp { _ <- s.prepare(TestTable.insertQuery(ignoreConflict = true)) res <- p.option((testRow, testUpdateFr.argument)) _ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns")) - id = res.get._1 - _ <- IO.raiseWhen(res.get._2 != 2)(new Throwable("unexpected result for generated column")) + id = res.get._1 + _ <- IO.raiseWhen(res.get._2 != 2 && res.get._3 != Some(2))( + new Throwable("unexpected result for generated columns") + ) all <- s.execute(TestTable.selectAll()) allWithGen <- s.execute(TestTable.selectAllWithGenerated()) _ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal")) - _ <- IO.raiseWhen(allWithGen.map(_._3) != List(testRow))(new Throwable("test A result with id not equal")) + _ <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal")) aliasedTestTable = TestTable.withAlias("t") idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2 xs <- From 05ef5c8dc845ca6d9775b03518493de58f0222eb Mon Sep 17 00:00:00 2001 From: Roman Langolf Date: Wed, 20 Nov 2024 17:28:48 +0700 Subject: [PATCH 4/4] update naming --- modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala index 6fb4171..a101e18 100644 --- a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala +++ b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala @@ -720,7 +720,7 @@ class PgCodeGen( private def selectAllStatement(table: Table): String = { import table.* - val autoIncStm = if (generatedColumns.nonEmpty) { + val generatedColStm = if (generatedColumns.nonEmpty) { val types = generatedColumns.map(_.codecName).mkString(" *: ") val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ") val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ") @@ -746,7 +746,7 @@ class PgCodeGen( val selectCol = s"""| def select[A, B](cols: Cols[A], rest: Fragment[B] = Fragment.empty): Query[B, A] = | sql"SELECT #$${cols.name} FROM #$$tableName $$rest".query(cols.codec) |""".stripMargin - autoIncStm ++ defaultStm ++ selectCol + generatedColStm ++ defaultStm ++ selectCol } private def lastModified(modified: List[Long]): Option[Long] =