Commit 9ada1e7

mihailotim-db authored and cloud-fan committed
[SPARK-51820][FOLLOWUP][CONNECT] Replace literal in SortOrder only under Sort operator
### What changes were proposed in this pull request?
Replace literals in `SortOrder` with ordinals only when the `SortOrder` appears under a `Sort` operator.

### Why are the changes needed?
SPARK-51820 introduced a bug where literals under all `SortOrder` expressions were treated as ordinals, breaking window expressions in Spark Connect. This PR fixes that.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a test case

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #52189 from mihailotim-db/mihailotim-db/fix_window_ordinal.

Authored-by: Mihailo Timotic <mihailo.timotic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 166895a commit 9ada1e7
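
To make the regression concrete, here is a minimal sketch of the behavior this follow-up restores, using the Spark Connect Scala client. The session `spark`, the DataFrame `df`, and its column `col1` are illustrative assumptions, not part of this commit:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Assumed setup: a Spark Connect session `spark`; build a one-row
// DataFrame with an integer column col1 = 1.
val df = spark.range(1).selectExpr("1 AS col1")

// Under a Sort operator, a literal integer is still replaced with a column
// ordinal (the behavior SPARK-51820 introduced), so this sorts by the
// first output column rather than by a constant.
val sorted = df.sort(lit(1))

// Under a window's ORDER BY, the literal must stay a literal. Before this
// follow-up, the Connect planner also turned it into an ordinal, which
// broke window expressions like the one below.
val w = Window.orderBy(lit(4))
val result = df.select(sum(col("col1")).over(w).as("sum_over"))
result.collect() // expected: a single row with sum_over = 1
```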

File tree

2 files changed: +80 -6 lines changed


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 17 additions & 6 deletions
@@ -1752,7 +1752,8 @@ class SparkConnectPlanner(
         transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
       case proto.Expression.ExprTypeCase.UPDATE_FIELDS =>
         transformUpdateFields(exp.getUpdateFields)
-      case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
+      case proto.Expression.ExprTypeCase.SORT_ORDER =>
+        transformSortOrder(order = exp.getSortOrder, shouldReplaceOrdinals = false)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
       case proto.Expression.ExprTypeCase.UNRESOLVED_NAMED_LAMBDA_VARIABLE =>
@@ -2230,7 +2231,8 @@

     val windowSpec = WindowSpecDefinition(
       partitionSpec = window.getPartitionSpecList.asScala.toSeq.map(transformExpression),
-      orderSpec = window.getOrderSpecList.asScala.toSeq.map(transformSortOrder),
+      orderSpec = window.getOrderSpecList.asScala.toSeq.map(orderSpec =>
+        transformSortOrder(order = orderSpec, shouldReplaceOrdinals = false)),
       frameSpecification = frameSpec)

     WindowExpression(
@@ -2382,12 +2384,20 @@
     logical.Sort(
       child = transformRelation(sort.getInput),
       global = sort.getIsGlobal,
-      order = sort.getOrderList.asScala.toSeq.map(transformSortOrder))
+      order = sort.getOrderList.asScala.toSeq.map(order =>
+        transformSortOrder(order = order, shouldReplaceOrdinals = true)))
   }

-  private def transformSortOrder(order: proto.Expression.SortOrder) = {
+  private def transformSortOrder(
+      order: proto.Expression.SortOrder,
+      shouldReplaceOrdinals: Boolean = false) = {
+    val childWithReplacedOrdinals = if (shouldReplaceOrdinals) {
+      transformSortOrderAndReplaceOrdinals(order.getChild)
+    } else {
+      transformExpression(order.getChild)
+    }
     expressions.SortOrder(
-      child = transformSortOrderAndReplaceOrdinals(order.getChild),
+      child = childWithReplacedOrdinals,
       direction = order.getDirection match {
         case proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING =>
           expressions.Ascending
@@ -4081,7 +4091,8 @@
           .map(transformExpression)
           .toSeq,
         orderSpec = options.getOrderSpecList.asScala
-          .map(transformSortOrder)
+          .map(orderSpec =>
+            transformSortOrder(order = orderSpec, shouldReplaceOrdinals = false))
          .toSeq,
        withSinglePartition =
          options.hasWithSinglePartition && options.getWithSinglePartition)
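
For reference, a hedged sketch of the Catalyst sort orders the two paths are expected to produce. That the ordinal path wraps the literal in `UnresolvedOrdinal` is an inference from the `transformSortOrderAndReplaceOrdinals` helper and the `UNRESOLVED_ORDINAL` tree pattern checked in the existing suite, not something this diff states directly:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NullsFirst, SortOrder}

// Sort operator path (shouldReplaceOrdinals = true): ORDER BY 4 is read as
// "sort by the 4th output column" (assumed UnresolvedOrdinal wrapping).
val underSort = SortOrder(UnresolvedOrdinal(4), Ascending, NullsFirst, Seq.empty)

// Window / table-argument path (shouldReplaceOrdinals = false): the literal
// is kept as-is, so the ordering is by the constant 4.
val underWindow = SortOrder(Literal(4), Ascending, NullsFirst, Seq.empty)
```

Defaulting `shouldReplaceOrdinals` to `false` means every call site keeps literal semantics unless it explicitly opts in, which is why only the `Sort` path changes behavior in this diff.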

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala

Lines changed: 63 additions & 0 deletions
@@ -962,6 +962,69 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         !aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL)))
   }

+  test("SPARK-51820 Literals in SortOrder should only be replaced under Sort node") {
+    val schema = StructType(Seq(StructField("col1", IntegerType)))
+    val data = Seq(InternalRow(1))
+    val inputRows = data.map { row =>
+      val proj = UnsafeProjection.create(schema)
+      proj(row).copy()
+    }
+    val localRelation = createLocalRelationProto(schema, inputRows)
+
+    val sumFunction = proto.Expression
+      .newBuilder()
+      .setUnresolvedFunction(
+        proto.Expression.UnresolvedFunction
+          .newBuilder()
+          .setFunctionName("sum")
+          .addArguments(
+            proto.Expression
+              .newBuilder()
+              .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
+                .newBuilder()
+                .setUnparsedIdentifier("col1"))))
+      .build()
+
+    val windowExpression = proto.Expression
+      .newBuilder()
+      .setWindow(
+        proto.Expression.Window
+          .newBuilder()
+          .setWindowFunction(sumFunction)
+          .addOrderSpec(
+            proto.Expression.SortOrder
+              .newBuilder()
+              .setChild(proto.Expression
+                .newBuilder()
+                .setLiteral(proto.Expression.Literal.newBuilder().setInteger(4)))
+              .setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING)
+              .setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST)))
+      .build()
+
+    val aliasedWindowExpression = proto.Expression
+      .newBuilder()
+      .setAlias(
+        proto.Expression.Alias
+          .newBuilder()
+          .setExpr(windowExpression)
+          .addName("sum_over"))
+      .build()
+
+    val project = proto.Project
+      .newBuilder()
+      .setInput(localRelation)
+      .addExpressions(aliasedWindowExpression)
+      .build()
+
+    val result = transform(proto.Relation.newBuilder().setProject(project).build())
+    val df = Dataset.ofRows(spark, result)
+
+    val collected = df.collect()
+    assert(collected.length == 1)
+    assert(df.schema.fields.head.name == "sum_over")
+    assert(collected(0).getAs[Long]("sum_over") == 1L)
+  }
+
   test("Time literal") {
     val project = proto.Project.newBuilder
       .addExpressions(
