diff --git a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala index 3d5fc6aaa18c8..10504684be9fd 100644 --- a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Path} import java.util.{Locale, TimeZone} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.jdk.CollectionConverters._ import org.apache.logging.log4j._ @@ -301,59 +301,120 @@ trait SparkTestSuite parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getCondition === condition) - sqlState.foreach(state => assert(exception.getSqlState === state)) + val mismatches = new ListBuffer[String] + + if (exception.getCondition != condition) { + mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'" + } + sqlState.foreach { state => + if (exception.getSqlState != state) { + mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'" + } + } + val actualParameters = exception.getMessageParameters.asScala val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String]) val actualParametersToCompare = actualParameters.filter { case (k, _) => !ignorable.contains(k) || parameters.contains(k) } if (matchPVals) { - assert(actualParametersToCompare.size === parameters.size) - actualParametersToCompare.foreach( - exp => { - val parm = parameters.getOrElse(exp._1, - throw new IllegalArgumentException("Missing parameter" + exp._1)) - if (!exp._2.matches(parm)) { - throw new IllegalArgumentException("For parameter '" + exp._1 + "' value '" + exp._2 + - "' does not match: " + parm) - } + if (actualParametersToCompare.size != parameters.size) { + mismatches += s"parameters size: expected ${parameters.size} but got" + + s" ${actualParametersToCompare.size}" + } + actualParametersToCompare.foreach { case (key, actualVal) => + parameters.get(key) match { + case None => + mismatches += s"parameters: unexpected key '$key' with value '$actualVal'" + case Some(pattern) if !actualVal.matches(pattern) => + mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" + + s" '$pattern'" + case _ => } - ) - } else { - assert(actualParametersToCompare === parameters) + } + parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key => + mismatches += s"parameters: missing expected key '$key'" + } + } else if (actualParametersToCompare != parameters) { + mismatches += s"parameters: expected $parameters but got $actualParametersToCompare" } + val actualQueryContext = exception.getQueryContext() - assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") - actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.contextType() === expected.contextType, - "Invalid contextType of a query context Actual:" + actual.toString) - if (actual.contextType() == QueryContextType.SQL) { - assert(actual.objectType() === expected.objectType, - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName, - "Invalid objectName of a query context. Actual:" + actual.toString) - // If startIndex and stopIndex are -1, it means we simply want to check the - // fragment of the query context. This should be the case when the fragment is - // distinguished within the query text. - if (expected.startIndex != -1) { - assert(actual.startIndex() === expected.startIndex, - "Invalid startIndex of a query context. Actual:" + actual.toString) + if (actualQueryContext.length != queryContext.length) { + mismatches += s"queryContext.length: expected ${queryContext.length}" + + s" but got ${actualQueryContext.length}" + } + actualQueryContext.zip(queryContext).zipWithIndex.foreach { + case ((actual, expected), idx) => + if (actual.contextType() != expected.contextType) { + mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" + + s" but got ${actual.contextType()}" } - if (expected.stopIndex != -1) { - assert(actual.stopIndex() === expected.stopIndex, - "Invalid stopIndex of a query context. Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + if (actual.objectType() != expected.objectType) { + mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" + + s" but got '${actual.objectType()}'" + } + if (actual.objectName() != expected.objectName) { + mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" + + s" but got '${actual.objectName()}'" + } + // If startIndex and stopIndex are -1, it means we simply want to check the + // fragment of the query context. This should be the case when the fragment is + // distinguished within the query text. + if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) { + mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" + + s" but got ${actual.startIndex()}" + } + if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) { + mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" + + s" but got ${actual.stopIndex()}" + } + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } + } else if (actual.contextType() == QueryContextType.DataFrame) { + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } + if (expected.callSitePattern.nonEmpty && + !actual.callSite().matches(expected.callSitePattern)) { + mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" + + s" does not match pattern '${expected.callSitePattern}'" + } } - assert(actual.fragment() === expected.fragment, - "Invalid fragment of a query context. Actual:" + actual.toString) - } else if (actual.contextType() == QueryContextType.DataFrame) { - assert(actual.fragment() === expected.fragment, - "Invalid code fragment of a query context. Actual:" + actual.toString) - if (expected.callSitePattern.nonEmpty) { - assert(actual.callSite().matches(expected.callSitePattern), - "Invalid callSite of a query context. Actual:" + actual.toString) + } + + if (mismatches.nonEmpty) { + val sb = new StringBuilder + sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n") + sb.append("=== Actual Exception State ===\n") + sb.append(s" condition: ${exception.getCondition}\n") + sb.append(s" sqlState: ${exception.getSqlState}\n") + sb.append(s" parameters:\n") + if (actualParameters.isEmpty) { + sb.append(" (empty)\n") + } else { + actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") } + } + actualQueryContext.zipWithIndex.foreach { case (ctx, idx) => + sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n") + if (ctx.contextType() == QueryContextType.SQL) { + sb.append(s" objectType: ${ctx.objectType()}\n") + sb.append(s" objectName: ${ctx.objectName()}\n") + sb.append(s" startIndex: ${ctx.startIndex()}\n") + sb.append(s" stopIndex: ${ctx.stopIndex()}\n") + sb.append(s" fragment: ${ctx.fragment()}\n") + } else if (ctx.contextType() == QueryContextType.DataFrame) { + sb.append(s" fragment: ${ctx.fragment()}\n") + sb.append(s" callSite: ${ctx.callSite()}\n") } } + sb.append("\n=== Mismatches ===\n") + mismatches.foreach(m => sb.append(s" $m\n")) + fail(sb.toString()) } } diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 9612c7b3eb17b..b04e735d6bfbc 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -31,6 +31,7 @@ import com.fasterxml.jackson.core.util.{DefaultIndenter, DefaultPrettyPrinter} import com.fasterxml.jackson.databind.SerializationFeature import com.fasterxml.jackson.databind.json.JsonMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkThrowableHelper._ import org.apache.spark.util.Utils @@ -800,4 +801,62 @@ class SparkThrowableSuite extends SparkFunSuite { assert(errorWithNull.getSqlState == "22018", "Should fall back to error class reader SQL state when custom is null") } + + test("checkError reports all mismatches in a single failure message") { + class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL + override val objectName = "v1" + override val objectType = "VIEW" + override val startIndex = 2 + override val stopIndex = 10 + override val fragment = "1 / 0" + override val callSite = "" + override val summary = "" + } + + val exception = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext), + summary = "") + + val error = intercept[TestFailedException] { + checkError( + exception = exception, + condition = "WRONG_CONDITION", + sqlState = Some("99999"), + parameters = Map("config" -> "WRONG_VALUE"), + queryContext = Array(ExpectedContext( + objectType = "TABLE", + objectName = "t1", + startIndex = 0, + stopIndex = 5, + fragment = "wrong fragment"))) + } + val msg = error.getMessage + assert(msg.contains("=== Actual Exception State ==="), "Should contain actual state header") + assert(msg.contains("=== Mismatches ==="), "Should contain mismatches header") + assert(msg.contains("condition:"), "Should report condition mismatch") + assert(msg.contains("sqlState:"), "Should report sqlState mismatch") + assert(msg.contains("parameters:"), "Should report parameters mismatch") + assert(msg.contains("queryContext[0].objectType:"), "Should report objectType mismatch") + assert(msg.contains("queryContext[0].objectName:"), "Should report objectName mismatch") + assert(msg.contains("queryContext[0].startIndex:"), "Should report startIndex mismatch") + assert(msg.contains("queryContext[0].stopIndex:"), "Should report stopIndex mismatch") + assert(msg.contains("queryContext[0].fragment:"), "Should report fragment mismatch") + } + + test("checkError succeeds when all fields match") { + val exception = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array.empty, + summary = "") + + checkError( + exception = exception, + condition = "DIVIDE_BY_ZERO", + sqlState = Some("22012"), + parameters = Map("config" -> "CONFIG")) + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala index da3b32b408f58..51e13c240e8f0 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/QueryTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.test import java.util.TimeZone +import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import org.scalatest.Assertions @@ -112,58 +113,115 @@ abstract class QueryTest extends ConnectFunSuite with SQLHelper { parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getCondition === condition) - sqlState.foreach(state => assert(exception.getSqlState === state)) - val expectedParameters = exception.getMessageParameters.asScala + val mismatches = new ListBuffer[String] + + if (exception.getCondition != condition) { + mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'" + } + sqlState.foreach { state => + if (exception.getSqlState != state) { + mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'" + } + } + + val actualParameters = exception.getMessageParameters.asScala if (matchPVals) { - assert(expectedParameters.size === parameters.size) - expectedParameters.foreach(exp => { - val parm = parameters.getOrElse( - exp._1, - throw new IllegalArgumentException("Missing parameter" + exp._1)) - if (!exp._2.matches(parm)) { - throw new IllegalArgumentException( - "For parameter '" + exp._1 + "' value '" + exp._2 + - "' does not match: " + parm) + if (actualParameters.size != parameters.size) { + mismatches += s"parameters size: expected ${parameters.size} but got" + + s" ${actualParameters.size}" + } + actualParameters.foreach { case (key, actualVal) => + parameters.get(key) match { + case None => + mismatches += s"parameters: unexpected key '$key' with value '$actualVal'" + case Some(pattern) if !actualVal.matches(pattern) => + mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" + + s" '$pattern'" + case _ => } - }) - } else { - assert(expectedParameters === parameters) + } + parameters.keys.filterNot(actualParameters.contains).foreach { key => + mismatches += s"parameters: missing expected key '$key'" + } + } else if (actualParameters != parameters) { + mismatches += s"parameters: expected $parameters but got $actualParameters" } + val actualQueryContext = exception.getQueryContext() - assert( - actualQueryContext.length === queryContext.length, - "Invalid length of the query context") - actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert( - actual.contextType() === expected.contextType, - "Invalid contextType of a query context Actual:" + actual.toString) + if (actualQueryContext.length != queryContext.length) { + mismatches += s"queryContext.length: expected ${queryContext.length}" + + s" but got ${actualQueryContext.length}" + } + actualQueryContext.zip(queryContext).zipWithIndex.foreach { case ((actual, expected), idx) => + if (actual.contextType() != expected.contextType) { + mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" + + s" but got ${actual.contextType()}" + } if (actual.contextType() == QueryContextType.SQL) { - assert( - actual.objectType() === expected.objectType, - "Invalid objectType of a query context Actual:" + actual.toString) - assert( - actual.objectName() === expected.objectName, - "Invalid objectName of a query context. Actual:" + actual.toString) - assert( - actual.startIndex() === expected.startIndex, - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert( - actual.stopIndex() === expected.stopIndex, - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert( - actual.fragment() === expected.fragment, - "Invalid fragment of a query context. Actual:" + actual.toString) + if (actual.objectType() != expected.objectType) { + mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" + + s" but got '${actual.objectType()}'" + } + if (actual.objectName() != expected.objectName) { + mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" + + s" but got '${actual.objectName()}'" + } + // If startIndex and stopIndex are -1, it means we simply want to check the + // fragment of the query context. This should be the case when the fragment is + // distinguished within the query text. + if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) { + mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" + + s" but got ${actual.startIndex()}" + } + if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) { + mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" + + s" but got ${actual.stopIndex()}" + } + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } } else if (actual.contextType() == QueryContextType.DataFrame) { - assert( - actual.fragment() === expected.fragment, - "Invalid code fragment of a query context. Actual:" + actual.toString) - if (expected.callSitePattern.nonEmpty) { - assert( - actual.callSite().matches(expected.callSitePattern), - "Invalid callSite of a query context. Actual:" + actual.toString) + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } + if (expected.callSitePattern.nonEmpty && + !actual.callSite().matches(expected.callSitePattern)) { + mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" + + s" does not match pattern '${expected.callSitePattern}'" + } + } + } + + if (mismatches.nonEmpty) { + val sb = new StringBuilder + sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n") + sb.append("=== Actual Exception State ===\n") + sb.append(s" condition: ${exception.getCondition}\n") + sb.append(s" sqlState: ${exception.getSqlState}\n") + sb.append(s" parameters:\n") + if (actualParameters.isEmpty) { + sb.append(" (empty)\n") + } else { + actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") } + } + actualQueryContext.zipWithIndex.foreach { case (ctx, idx) => + sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n") + if (ctx.contextType() == QueryContextType.SQL) { + sb.append(s" objectType: ${ctx.objectType()}\n") + sb.append(s" objectName: ${ctx.objectName()}\n") + sb.append(s" startIndex: ${ctx.startIndex()}\n") + sb.append(s" stopIndex: ${ctx.stopIndex()}\n") + sb.append(s" fragment: ${ctx.fragment()}\n") + } else if (ctx.contextType() == QueryContextType.DataFrame) { + sb.append(s" fragment: ${ctx.fragment()}\n") + sb.append(s" callSite: ${ctx.callSite()}\n") } } + sb.append("\n=== Mismatches ===\n") + mismatches.foreach(m => sb.append(s" $m\n")) + fail(sb.toString()) } }