Issue #515: Add ColumnStats Schema for JSON parsing (#522)
Add schema to columnStats

---------

Co-authored-by: jiawei <jiaweihu08@gmail.com>
osopardo1 and Jiaweihu08 authored Jan 16, 2025
1 parent b2e2f85 commit 6cfe206
Showing 12 changed files with 524 additions and 82 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/io/qbeast/core/model/ColumnToIndex.scala
@@ -28,10 +28,10 @@ import org.apache.spark.sql.types.StructType
case class ColumnToIndex(columnName: String, transformerType: Option[String]) {

def toTransformer(schema: StructType): Transformer = {
val getColumnQType = ColumnToIndexUtils.getColumnQType(columnName, schema)
val qDataType = ColumnToIndexUtils.getColumnQType(columnName, schema)
transformerType match {
case Some(tt) => Transformer(tt, columnName, getColumnQType)
case None => Transformer(columnName, getColumnQType)
case Some(tt) => Transformer(tt, columnName, qDataType)
case None => Transformer(columnName, qDataType)
}
}

@@ -25,7 +25,7 @@ object ColumnToIndexUtils {
val SpecExtractor: Regex = "([^:]+):([A-z]+)".r

def getColumnQType(columnName: String, schema: StructType): QDataType = {
SparkToQTypesUtils.convertDataTypes(schema(columnName).dataType)
SparkToQTypesUtils.convertToQDataType(schema(columnName).dataType)
}

}
149 changes: 149 additions & 0 deletions core/src/main/scala/io/qbeast/core/model/QbeastColumnStats.scala
@@ -0,0 +1,149 @@
/*
* Copyright 2021 Qbeast Analytics, S.L.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.qbeast.core.model

import io.qbeast.core.transform._
import io.qbeast.spark.utils.SparkToQTypesUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.types._
import org.apache.spark.sql.AnalysisExceptionFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

/**
* Container for Qbeast Column Stats
*
* @param schema
* the column stats schema
* @param rowOption
* the column stats row
*/
case class QbeastColumnStats(schema: StructType, rowOption: Option[Row]) extends Logging {

/**
* Creates a transformation for the given transformer. If the stats are not available, it
* returns None. An IllegalArgumentException is thrown if the output is an
 * IdentityTransformation. Otherwise, the provided stats are used to create and return the
* transformation.
*
* @param transformer
* the transformer
* @return
*/
def createTransformation(transformer: Transformer): Option[Transformation] = rowOption match {
case Some(row) =>
val hasStats = transformer.stats.statsNames.exists(row.getAs[Object](_) != null)
if (hasStats) {
val transformation = transformer.makeTransformation(row.getAs[Object])
(transformer, transformation) match {
case (_: LinearTransformer, _: IdentityTransformation) =>
// If the transformation is IdentityTransformation, it means:
// 1. Either columnName_min or columnName_max is missing or
// 2. columnName_min and columnName_max are the same
throw new IllegalArgumentException(
s"The provided columnStats for column ${transformer.columnName} are not valid. " +
s"Make sure both min and max values are provided and satisfy the condition: min < max.")
case _ => Some(transformation)
}
} else None
case None => None
}

}

/**
* Companion object for QbeastColumnStats
*/
object QbeastColumnStats {

/**
* Builds the column stats schema
*
 * For each column transformer, creates a sequence of StructFields for its column stats
* @param columnTransformers
* the column transformers
* @return
*/
private[model] def buildColumnStatsSchema(columnTransformers: Seq[Transformer]): StructType = {
val builder = Seq.newBuilder[StructField]
columnTransformers.foreach { t =>
val fields = t match {
case lt: LinearTransformer =>
val sparkDataType = SparkToQTypesUtils.convertToSparkDataType(lt.dataType)
lt.stats.statsNames.map(StructField(_, sparkDataType, nullable = true))
case nq: CDFNumericQuantilesTransformer =>
nq.stats.statsNames.map(StructField(_, ArrayType(DoubleType), nullable = true))
case sq: CDFStringQuantilesTransformer =>
sq.stats.statsNames.map(StructField(_, ArrayType(StringType), nullable = true))
case sh: StringHistogramTransformer =>
sh.stats.statsNames.map(StructField(_, ArrayType(StringType), nullable = true))
case _ => Seq.empty
// TODO: Add support for other transformers
}
builder ++= fields
}
StructType(builder.result())
}

/**
* Builds the column stats row
*
* @param stats
* the stats in a JSON string
* @param columnStatsSchema
* the column stats schema
* @return
*/
private[model] def buildColumnStatsRow(
stats: String,
columnStatsSchema: StructType): Option[Row] = {
if (stats.isEmpty) None // No stats are provided
else {
val spark = SparkSession.active
import spark.implicits._
val columnStatsJSON = Seq(stats).toDS()
val row = spark.read
.option("inferTimestamp", "true")
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS'Z'")
.schema(columnStatsSchema)
.json(columnStatsJSON)
.first()
// All values will be null if the input JSON is invalid
val isInvalidJSON = row.toSeq.forall(_ == null)
if (isInvalidJSON) {
throw AnalysisExceptionFactory.create(
s"The columnStats provided is not a valid JSON: $stats")
}
Some(row)
}
}

/**
* Builds the QbeastColumnStats
*
* @param statsString
* the stats in a JSON string
* @param columnTransformers
 * the column transformers to build the stats from
* @return
*/
def apply(statsString: String, columnTransformers: Seq[Transformer]): QbeastColumnStats = {
val statsSchema = buildColumnStatsSchema(columnTransformers)
val statsRowOption = buildColumnStatsRow(statsString, statsSchema)
QbeastColumnStats(statsSchema, statsRowOption)
}

}
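For orientation, here is a minimal usage sketch of the new QbeastColumnStats API (illustrative, not part of the diff). It assumes a single column "price" indexed linearly; the "linear" transformer type name and the <column>_min/<column>_max stats naming follow LinearTransformer's conventions and are assumptions of this sketch.

// Hypothetical usage sketch: parse a user-provided columnStats JSON string into a typed
// Row and derive a Transformation from it.
import io.qbeast.core.model.{IntegerDataType, QbeastColumnStats}
import io.qbeast.core.transform.Transformer

val transformers = Seq(Transformer("linear", "price", IntegerDataType)) // "linear" name assumed
val statsJson = """{"price_min": 0, "price_max": 1000}"""
val columnStats = QbeastColumnStats(statsJson, transformers)
// columnStats.schema    -> StructType with price_min and price_max as IntegerType fields
// columnStats.rowOption -> Some(Row(0, 1000)), parsed against that schema
val transformation = columnStats.createTransformation(transformers.head)
// -> Some(linear transformation over [0, 1000]); None if no stats had been provided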
@@ -41,7 +41,7 @@ case class CDFNumericQuantilesTransformer(columnName: String, orderedDataType: O
override def makeTransformation(row: String => Any): Transformation = {
row(columnTransformerName) match {
case null => EmptyTransformation()
case q: Seq[_] if q.nonEmpty =>
case q: Seq[_] =>
try {
val quantiles = q.map(_.asInstanceOf[Double]).toIndexedSeq
CDFNumericQuantilesTransformation(quantiles, orderedDataType)
@@ -50,9 +50,6 @@ case class CDFNumericQuantilesTransformer(columnName: String, orderedDataType: O
throw AnalysisExceptionFactory.create(
"Quantiles should be of type Double, but found another type")
}
case q: Seq[_] if q.isEmpty =>
throw AnalysisExceptionFactory.create(
s"Quantiles for column $columnName size should be greater than 1")
case _ =>
throw AnalysisExceptionFactory.create(
s"Quantiles for column $columnName should be of type Array[Double]")
@@ -38,12 +38,9 @@ case class CDFStringQuantilesTransformer(columnName: String) extends CDFQuantile
override def makeTransformation(row: String => Any): Transformation = {
row(columnTransformerName) match {
case null => EmptyTransformation()
case q: Seq[_] if q.nonEmpty =>
case q: Seq[_] =>
val quantiles = q.map(_.toString).toIndexedSeq
CDFStringQuantilesTransformation(quantiles)
case q: Seq[_] if q.isEmpty =>
throw AnalysisExceptionFactory.create(
s"Quantiles for column $columnName size should be greater than 1")
case _ =>
throw AnalysisExceptionFactory.create(
s"Quantiles for column $columnName should be of type Array[String]")
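For reference, hypothetical columnStats payloads for the quantile transformers above; the <column>_quantiles stats name is an assumption made for illustration.

// With the schema built by QbeastColumnStats, numeric quantiles are read as
// ArrayType(DoubleType) and string quantiles as ArrayType(StringType).
val numericQuantilesStats = """{"price_quantiles": [0.0, 25.0, 50.0, 75.0, 100.0]}"""
val stringQuantilesStats = """{"brand_quantiles": ["audi", "bmw", "seat", "volvo"]}"""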
@@ -52,8 +52,8 @@ case class LinearTransformer(columnName: String, dataType: QDataType) extends Tr

override def stats: ColumnStats =
ColumnStats(
names = Seq(colMax, colMin),
predicates = Seq(s"max($columnName) AS $colMax", s"min($columnName) AS $colMin"))
names = Seq(colMin, colMax),
predicates = Seq(s"min($columnName) AS $colMin", s"max($columnName) AS $colMax"))

override def makeTransformation(row: String => Any): Transformation = dataType match {
case ordered: OrderedDataType =>
@@ -15,25 +15,14 @@
*/
package io.qbeast.spark.index

import io.qbeast.core.model.ColumnToIndex
import io.qbeast.core.model.QTableID
import io.qbeast.core.model.QbeastOptions
import io.qbeast.core.model.Revision
import io.qbeast.core.model.RevisionChange
import io.qbeast.core.model.RevisionFactory
import io.qbeast.core.model.StagingUtils
import io.qbeast.core.transform.CDFQuantilesTransformer
import io.qbeast.core.transform.EmptyTransformation
import io.qbeast.core.transform.EmptyTransformer
import io.qbeast.core.transform.Transformation
import io.qbeast.core.transform.Transformer
import io.qbeast.core.model._
import io.qbeast.core.transform._
import io.qbeast.IISeq
import org.apache.spark.internal.Logging
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.AnalysisExceptionFactory
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

/**
* Spark implementation of RevisionBuilder
@@ -86,19 +75,24 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
options: QbeastOptions,
data: DataFrame): (Option[RevisionChange], Long) = {
checkColumnChanges(revision, options)
// 1. Compute transformer changes
val transformerChanges =
computeTransformerChanges(revision.columnTransformers, options, data.schema)
// 2. Update transformers if necessary
val updatedTransformers =
computeUpdatedTransformers(revision.columnTransformers, transformerChanges)
val statsRow = getDataFrameStats(data, updatedTransformers)
val numElements = statsRow.getAs[Long]("count")
// 3. Get the stats from the DataFrame
val dataFrameStats = getDataFrameStats(data, updatedTransformers)
val numElements = dataFrameStats.getAs[Long]("count")
// 4. Compute the cube size changes
val cubeSizeChanges = computeCubeSizeChanges(revision, options)
// 5. Compute the Transformation changes given the input data and the user input
val transformationChanges =
computeTransformationChanges(
updatedTransformers,
revision.transformations,
options,
statsRow)
dataFrameStats)
val hasRevisionChanges =
cubeSizeChanges.isDefined ||
transformerChanges.flatten.nonEmpty ||
@@ -238,11 +232,13 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
transformations: IISeq[Transformation],
options: QbeastOptions,
row: Row): IISeq[Option[Transformation]] = {
// Compute transformations from columnStats and DataFrame stats, and merge them
// Compute transformations from dataFrameStats
val transformationsFromDataFrameStats =
computeTransformationsFromDataFrameStats(transformers, row)
// Compute transformations from columnStats
val transformationsFromColumnsStats =
computeTransformationsFromColumnStats(transformers, options)
// Merge transformations from DataFrame and columnStats
val newTransformations = transformationsFromDataFrameStats
.zip(transformationsFromColumnsStats)
.map {
@@ -284,16 +280,12 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
private[index] def computeTransformationsFromColumnStats(
transformers: IISeq[Transformer],
options: QbeastOptions): IISeq[Option[Transformation]] = {
val (columnStats, availableColumnStats) = parseColumnStats(options)
transformers.map { t =>
if (t.stats.statsNames.forall(availableColumnStats.contains)) {
// Create transformation with columnStats
Some(t.makeTransformation(columnStats.getAs[Object]))
} else {
// Ignore the transformation if the stats are not available
None
}
}
// 1. Get the columnStats from the options
val columnStatsString = options.columnStats.getOrElse("")
// 2. Build the QbeastColumnStats
val columnStats = QbeastColumnStats(columnStatsString, transformers)
// 3. Compute transformations from the columnStats
transformers.map(columnStats.createTransformation)
}

/**
@@ -310,22 +302,4 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
transformers.map(_.makeTransformation(row.getAs[Object]))
}

private[index] def parseColumnStats(options: QbeastOptions): (Row, Set[String]) = {
val (row, statsNames) = if (options.columnStats.isDefined) {
val spark = SparkSession.active
import spark.implicits._
val stats = spark.read
.option("inferTimestamp", "true")
.option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS'Z'")
.json(Seq(options.columnStats.get).toDS())
.first()
(stats, stats.schema.fieldNames.toSet)
} else (Row.empty, Set.empty[String])
if (statsNames.contains("_corrupt_record")) {
throw AnalysisExceptionFactory.create(
"The columnStats provided is not a valid JSON: " + row.getAs[String]("_corrupt_record"))
}
(row, statsNames)
}

}
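For context, a sketch of how user-provided column stats typically reach this code path through the write options (values and path are illustrative; columnsToIndex and columnStats are the qbeast-spark write option names used elsewhere in the project):

// Illustrative end-to-end write: the "columnStats" option carries the statsString that
// SparkRevisionChangesUtils forwards to QbeastColumnStats.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(1000).toDF("price") // toy data; "price" is a Long column
df.write
  .format("qbeast")
  .option("columnsToIndex", "price")
  .option("columnStats", """{"price_min": 0, "price_max": 1000}""")
  .save("/tmp/qbeast_table")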
44 changes: 33 additions & 11 deletions core/src/main/scala/io/qbeast/spark/utils/SparkToQTypesUtils.scala
@@ -15,21 +15,43 @@
*/
package io.qbeast.spark.utils

import io.qbeast.core.{model => qmodel}
import io.qbeast.core.model.DateDataType
import io.qbeast.core.model.DecimalDataType
import io.qbeast.core.model.DoubleDataType
import io.qbeast.core.model.FloatDataType
import io.qbeast.core.model.IntegerDataType
import io.qbeast.core.model.LongDataType
import io.qbeast.core.model.QDataType
import io.qbeast.core.model.StringDataType
import io.qbeast.core.model.TimestampDataType
import org.apache.spark.sql.types._

object SparkToQTypesUtils {

def convertDataTypes(sparkType: DataType): io.qbeast.core.model.QDataType = sparkType match {
case _: DoubleType => qmodel.DoubleDataType
case _: IntegerType => qmodel.IntegerDataType
case _: FloatType => qmodel.FloatDataType
case _: LongType => qmodel.LongDataType
case _: StringType => qmodel.StringDataType
case _: DecimalType => qmodel.DecimalDataType
case _: TimestampType => qmodel.TimestampDataType
case _: DateType => qmodel.DateDataType
case _ => throw new RuntimeException(s"${sparkType.typeName} is not supported yet")
def convertToQDataType(sparkDataType: DataType): QDataType = sparkDataType match {
case _: DoubleType => DoubleDataType
case _: IntegerType => IntegerDataType
case _: FloatType => FloatDataType
case _: LongType => LongDataType
case _: StringType => StringDataType
case _: DecimalType => DecimalDataType
case _: TimestampType => TimestampDataType
case _: DateType => DateDataType
case _ => throw new RuntimeException(s"${sparkDataType.typeName} is not supported.")
// TODO add more types
}

def convertToSparkDataType(qDataType: QDataType): DataType = qDataType match {
case DoubleDataType => DoubleType
case IntegerDataType => IntegerType
case FloatDataType => FloatType
case LongDataType => LongType
case StringDataType => StringType
case DecimalDataType => DecimalType.SYSTEM_DEFAULT
case TimestampDataType => TimestampType
case DateDataType => DateType
case _ =>
throw new RuntimeException(s"No corresponding spark type is found for ${qDataType.name}.")
// TODO add more types
}

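A quick round-trip sketch for the two converters above (illustrative, not part of the diff):

import io.qbeast.core.model.QDataType
import io.qbeast.spark.utils.SparkToQTypesUtils
import org.apache.spark.sql.types.{DataType, IntegerType}

val qType: QDataType = SparkToQTypesUtils.convertToQDataType(IntegerType)   // IntegerDataType
val sparkType: DataType = SparkToQTypesUtils.convertToSparkDataType(qType)  // IntegerType
// Note: DecimalDataType maps back to DecimalType.SYSTEM_DEFAULT, so the original
// precision and scale of a DecimalType are not preserved on the round trip.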