Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 99 additions & 41 deletions core/src/test/scala/org/apache/spark/SparkTestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Path}
import java.util.{Locale, TimeZone}

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.jdk.CollectionConverters._

import org.apache.logging.log4j._
Expand Down Expand Up @@ -301,59 +301,117 @@ trait SparkTestSuite
parameters: Map[String, String] = Map.empty,
matchPVals: Boolean = false,
queryContext: Array[ExpectedContext] = Array.empty): Unit = {
assert(exception.getCondition === condition)
sqlState.foreach(state => assert(exception.getSqlState === state))
val mismatches = new ListBuffer[String]

if (exception.getCondition != condition) {
mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'"
}
sqlState.foreach { state =>
if (exception.getSqlState != state) {
mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'"
}
}

val actualParameters = exception.getMessageParameters.asScala
val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String])
val actualParametersToCompare = actualParameters.filter { case (k, _) =>
!ignorable.contains(k) || parameters.contains(k)
}
if (matchPVals) {
assert(actualParametersToCompare.size === parameters.size)
actualParametersToCompare.foreach(
exp => {
val parm = parameters.getOrElse(exp._1,
throw new IllegalArgumentException("Missing parameter" + exp._1))
if (!exp._2.matches(parm)) {
throw new IllegalArgumentException("For parameter '" + exp._1 + "' value '" + exp._2 +
"' does not match: " + parm)
}
if (actualParametersToCompare.size != parameters.size) {
mismatches += s"parameters size: expected ${parameters.size} but got" +
s" ${actualParametersToCompare.size}"
}
actualParametersToCompare.foreach { case (key, actualVal) =>
parameters.get(key) match {
case None =>
mismatches += s"parameters: unexpected key '$key' with value '$actualVal'"
case Some(pattern) if !actualVal.matches(pattern) =>
mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" +
s" '$pattern'"
case _ =>
}
)
} else {
assert(actualParametersToCompare === parameters)
}
parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key =>
mismatches += s"parameters: missing expected key '$key'"
}
} else if (actualParametersToCompare != parameters) {
mismatches += s"parameters: expected $parameters but got $actualParametersToCompare"
}

val actualQueryContext = exception.getQueryContext()
assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
// If startIndex and stopIndex are -1, it means we simply want to check the
// fragment of the query context. This should be the case when the fragment is
// distinguished within the query text.
if (expected.startIndex != -1) {
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
if (actualQueryContext.length != queryContext.length) {
mismatches += s"queryContext.length: expected ${queryContext.length}" +
s" but got ${actualQueryContext.length}"
}
actualQueryContext.zip(queryContext).zipWithIndex.foreach {
case ((actual, expected), idx) =>
if (actual.contextType() != expected.contextType) {
mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" +
s" but got ${actual.contextType()}"
}
if (expected.stopIndex != -1) {
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
if (actual.objectType() != expected.objectType) {
mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" +
s" but got '${actual.objectType()}'"
}
if (actual.objectName() != expected.objectName) {
mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" +
s" but got '${actual.objectName()}'"
}
if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) {
mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" +
s" but got ${actual.startIndex()}"
}
if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) {
mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" +
s" but got ${actual.stopIndex()}"
}
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
} else if (actual.contextType() == QueryContextType.DataFrame) {
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
if (expected.callSitePattern.nonEmpty &&
!actual.callSite().matches(expected.callSitePattern)) {
mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" +
s" does not match pattern '${expected.callSitePattern}'"
}
}
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
if (expected.callSitePattern.nonEmpty) {
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}

if (mismatches.nonEmpty) {
val sb = new StringBuilder
sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n")
sb.append("=== Actual Exception State ===\n")
sb.append(s" condition: ${exception.getCondition}\n")
sb.append(s" sqlState: ${exception.getSqlState}\n")
sb.append(s" parameters:\n")
if (actualParameters.isEmpty) {
sb.append(" (empty)\n")
} else {
actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") }
}
actualQueryContext.zipWithIndex.foreach { case (ctx, idx) =>
sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n")
if (ctx.contextType() == QueryContextType.SQL) {
sb.append(s" objectType: ${ctx.objectType()}\n")
sb.append(s" objectName: ${ctx.objectName()}\n")
sb.append(s" startIndex: ${ctx.startIndex()}\n")
sb.append(s" stopIndex: ${ctx.stopIndex()}\n")
sb.append(s" fragment: ${ctx.fragment()}\n")
} else if (ctx.contextType() == QueryContextType.DataFrame) {
sb.append(s" fragment: ${ctx.fragment()}\n")
sb.append(s" callSite: ${ctx.callSite()}\n")
}
}
sb.append("\n=== Mismatches ===\n")
mismatches.foreach(m => sb.append(s" $m\n"))
fail(sb.toString())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.test

import java.util.TimeZone

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

import org.scalatest.Assertions
Expand Down Expand Up @@ -112,58 +113,113 @@ abstract class QueryTest extends ConnectFunSuite with SQLHelper {
parameters: Map[String, String] = Map.empty,
matchPVals: Boolean = false,
queryContext: Array[ExpectedContext] = Array.empty): Unit = {
assert(exception.getCondition === condition)
sqlState.foreach(state => assert(exception.getSqlState === state))
val expectedParameters = exception.getMessageParameters.asScala
val mismatches = new ListBuffer[String]

if (exception.getCondition != condition) {
mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'"
}
sqlState.foreach { state =>
if (exception.getSqlState != state) {
mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'"
}
}

val actualParameters = exception.getMessageParameters.asScala
if (matchPVals) {
assert(expectedParameters.size === parameters.size)
expectedParameters.foreach(exp => {
val parm = parameters.getOrElse(
exp._1,
throw new IllegalArgumentException("Missing parameter" + exp._1))
if (!exp._2.matches(parm)) {
throw new IllegalArgumentException(
"For parameter '" + exp._1 + "' value '" + exp._2 +
"' does not match: " + parm)
if (actualParameters.size != parameters.size) {
mismatches += s"parameters size: expected ${parameters.size} but got" +
s" ${actualParameters.size}"
}
actualParameters.foreach { case (key, actualVal) =>
parameters.get(key) match {
case None =>
mismatches += s"parameters: unexpected key '$key' with value '$actualVal'"
case Some(pattern) if !actualVal.matches(pattern) =>
mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" +
s" '$pattern'"
case _ =>
}
})
} else {
assert(expectedParameters === parameters)
}
parameters.keys.filterNot(actualParameters.contains).foreach { key =>
mismatches += s"parameters: missing expected key '$key'"
}
} else if (actualParameters != parameters) {
mismatches += s"parameters: expected $parameters but got $actualParameters"
}

val actualQueryContext = exception.getQueryContext()
assert(
actualQueryContext.length === queryContext.length,
"Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(
actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(
actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(
actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(
actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(
actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(
actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(
actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
if (expected.callSitePattern.nonEmpty) {
assert(
actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
if (actualQueryContext.length != queryContext.length) {
mismatches += s"queryContext.length: expected ${queryContext.length}" +
s" but got ${actualQueryContext.length}"
}
actualQueryContext.zip(queryContext).zipWithIndex.foreach {
case ((actual, expected), idx) =>
if (actual.contextType() != expected.contextType) {
mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" +
s" but got ${actual.contextType()}"
}
if (actual.contextType() == QueryContextType.SQL) {
if (actual.objectType() != expected.objectType) {
mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" +
s" but got '${actual.objectType()}'"
}
if (actual.objectName() != expected.objectName) {
mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" +
s" but got '${actual.objectName()}'"
}
if (actual.startIndex() != expected.startIndex) {
mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" +
s" but got ${actual.startIndex()}"
}
if (actual.stopIndex() != expected.stopIndex) {
mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" +
s" but got ${actual.stopIndex()}"
}
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
} else if (actual.contextType() == QueryContextType.DataFrame) {
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
if (expected.callSitePattern.nonEmpty &&
!actual.callSite().matches(expected.callSitePattern)) {
mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" +
s" does not match pattern '${expected.callSitePattern}'"
}
}
}

if (mismatches.nonEmpty) {
val sb = new StringBuilder
sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n")
sb.append("=== Actual Exception State ===\n")
sb.append(s" condition: ${exception.getCondition}\n")
sb.append(s" sqlState: ${exception.getSqlState}\n")
sb.append(s" parameters:\n")
if (actualParameters.isEmpty) {
sb.append(" (empty)\n")
} else {
actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") }
}
actualQueryContext.zipWithIndex.foreach { case (ctx, idx) =>
sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n")
if (ctx.contextType() == QueryContextType.SQL) {
sb.append(s" objectType: ${ctx.objectType()}\n")
sb.append(s" objectName: ${ctx.objectName()}\n")
sb.append(s" startIndex: ${ctx.startIndex()}\n")
sb.append(s" stopIndex: ${ctx.stopIndex()}\n")
sb.append(s" fragment: ${ctx.fragment()}\n")
} else if (ctx.contextType() == QueryContextType.DataFrame) {
sb.append(s" fragment: ${ctx.fragment()}\n")
sb.append(s" callSite: ${ctx.callSite()}\n")
}
}
sb.append("\n=== Mismatches ===\n")
mismatches.foreach(m => sb.append(s" $m\n"))
fail(sb.toString())
}
}

Expand Down