Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 102 additions & 41 deletions core/src/test/scala/org/apache/spark/SparkTestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Path}
import java.util.{Locale, TimeZone}

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.jdk.CollectionConverters._

import org.apache.logging.log4j._
Expand Down Expand Up @@ -301,59 +301,120 @@ trait SparkTestSuite
parameters: Map[String, String] = Map.empty,
matchPVals: Boolean = false,
queryContext: Array[ExpectedContext] = Array.empty): Unit = {
assert(exception.getCondition === condition)
sqlState.foreach(state => assert(exception.getSqlState === state))
val mismatches = new ListBuffer[String]

if (exception.getCondition != condition) {
mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'"
}
sqlState.foreach { state =>
if (exception.getSqlState != state) {
mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'"
}
}

val actualParameters = exception.getMessageParameters.asScala
val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String])
val actualParametersToCompare = actualParameters.filter { case (k, _) =>
!ignorable.contains(k) || parameters.contains(k)
}
if (matchPVals) {
assert(actualParametersToCompare.size === parameters.size)
actualParametersToCompare.foreach(
exp => {
val parm = parameters.getOrElse(exp._1,
throw new IllegalArgumentException("Missing parameter" + exp._1))
if (!exp._2.matches(parm)) {
throw new IllegalArgumentException("For parameter '" + exp._1 + "' value '" + exp._2 +
"' does not match: " + parm)
}
if (actualParametersToCompare.size != parameters.size) {
mismatches += s"parameters size: expected ${parameters.size} but got" +
s" ${actualParametersToCompare.size}"
}
actualParametersToCompare.foreach { case (key, actualVal) =>
parameters.get(key) match {
case None =>
mismatches += s"parameters: unexpected key '$key' with value '$actualVal'"
case Some(pattern) if !actualVal.matches(pattern) =>
mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" +
s" '$pattern'"
case _ =>
}
)
} else {
assert(actualParametersToCompare === parameters)
}
parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key =>
mismatches += s"parameters: missing expected key '$key'"
}
} else if (actualParametersToCompare != parameters) {
mismatches += s"parameters: expected $parameters but got $actualParametersToCompare"
}

val actualQueryContext = exception.getQueryContext()
assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
// If startIndex and stopIndex are -1, it means we simply want to check the
// fragment of the query context. This should be the case when the fragment is
// distinguished within the query text.
if (expected.startIndex != -1) {
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
if (actualQueryContext.length != queryContext.length) {
mismatches += s"queryContext.length: expected ${queryContext.length}" +
s" but got ${actualQueryContext.length}"
}
actualQueryContext.zip(queryContext).zipWithIndex.foreach {
case ((actual, expected), idx) =>
if (actual.contextType() != expected.contextType) {
mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" +
s" but got ${actual.contextType()}"
}
if (expected.stopIndex != -1) {
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
if (actual.objectType() != expected.objectType) {
mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" +
s" but got '${actual.objectType()}'"
}
if (actual.objectName() != expected.objectName) {
mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" +
s" but got '${actual.objectName()}'"
}
// If startIndex and stopIndex are -1, it means we simply want to check the
// fragment of the query context. This should be the case when the fragment is
// distinguished within the query text.
if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) {
mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" +
s" but got ${actual.startIndex()}"
}
if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) {
mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" +
s" but got ${actual.stopIndex()}"
}
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
} else if (actual.contextType() == QueryContextType.DataFrame) {
if (actual.fragment() != expected.fragment) {
mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
s" but got '${actual.fragment()}'"
}
if (expected.callSitePattern.nonEmpty &&
!actual.callSite().matches(expected.callSitePattern)) {
mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" +
s" does not match pattern '${expected.callSitePattern}'"
}
}
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
if (expected.callSitePattern.nonEmpty) {
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}

if (mismatches.nonEmpty) {
val sb = new StringBuilder
sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n")
sb.append("=== Actual Exception State ===\n")
sb.append(s" condition: ${exception.getCondition}\n")
sb.append(s" sqlState: ${exception.getSqlState}\n")
sb.append(s" parameters:\n")
if (actualParameters.isEmpty) {
sb.append(" (empty)\n")
} else {
actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") }
}
actualQueryContext.zipWithIndex.foreach { case (ctx, idx) =>
sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n")
if (ctx.contextType() == QueryContextType.SQL) {
sb.append(s" objectType: ${ctx.objectType()}\n")
sb.append(s" objectName: ${ctx.objectName()}\n")
sb.append(s" startIndex: ${ctx.startIndex()}\n")
sb.append(s" stopIndex: ${ctx.stopIndex()}\n")
sb.append(s" fragment: ${ctx.fragment()}\n")
} else if (ctx.contextType() == QueryContextType.DataFrame) {
sb.append(s" fragment: ${ctx.fragment()}\n")
sb.append(s" callSite: ${ctx.callSite()}\n")
}
}
sb.append("\n=== Mismatches ===\n")
mismatches.foreach(m => sb.append(s" $m\n"))
fail(sb.toString())
}
}

Expand Down
59 changes: 59 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.fasterxml.jackson.core.util.{DefaultIndenter, DefaultPrettyPrinter}
import com.fasterxml.jackson.databind.SerializationFeature
import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkThrowableHelper._
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -800,4 +801,62 @@ class SparkThrowableSuite extends SparkFunSuite {
assert(errorWithNull.getSqlState == "22018",
"Should fall back to error class reader SQL state when custom is null")
}

test("checkError reports all mismatches in a single failure message") {
class TestQueryContext extends QueryContext {
override val contextType = QueryContextType.SQL
override val objectName = "v1"
override val objectType = "VIEW"
override val startIndex = 2
override val stopIndex = 10
override val fragment = "1 / 0"
override val callSite = ""
override val summary = ""
}

val exception = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
messageParameters = Map("config" -> "CONFIG"),
context = Array(new TestQueryContext),
summary = "")

val error = intercept[TestFailedException] {
checkError(
exception = exception,
condition = "WRONG_CONDITION",
sqlState = Some("99999"),
parameters = Map("config" -> "WRONG_VALUE"),
queryContext = Array(ExpectedContext(
objectType = "TABLE",
objectName = "t1",
startIndex = 0,
stopIndex = 5,
fragment = "wrong fragment")))
}
val msg = error.getMessage
assert(msg.contains("=== Actual Exception State ==="), "Should contain actual state header")
assert(msg.contains("=== Mismatches ==="), "Should contain mismatches header")
assert(msg.contains("condition:"), "Should report condition mismatch")
assert(msg.contains("sqlState:"), "Should report sqlState mismatch")
assert(msg.contains("parameters:"), "Should report parameters mismatch")
assert(msg.contains("queryContext[0].objectType:"), "Should report objectType mismatch")
assert(msg.contains("queryContext[0].objectName:"), "Should report objectName mismatch")
assert(msg.contains("queryContext[0].startIndex:"), "Should report startIndex mismatch")
assert(msg.contains("queryContext[0].stopIndex:"), "Should report stopIndex mismatch")
assert(msg.contains("queryContext[0].fragment:"), "Should report fragment mismatch")
}

test("checkError succeeds when all fields match") {
val exception = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
messageParameters = Map("config" -> "CONFIG"),
context = Array.empty,
summary = "")

checkError(
exception = exception,
condition = "DIVIDE_BY_ZERO",
sqlState = Some("22012"),
parameters = Map("config" -> "CONFIG"))
}
}
Loading