Skip to content

Commit 6ced69c

Browse files
committed
fix
1 parent a74d50b commit 6ced69c

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,60 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
962962
!aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL)))
963963
}
964964

965+
test("Window function with orderBy literal") {
966+
// Create a local relation with test data
967+
val schema = StructType(Seq(StructField("col1", IntegerType)))
968+
val data = Seq(InternalRow(1))
969+
val inputRows = data.map { row =>
970+
val proj = UnsafeProjection.create(schema)
971+
proj(row).copy()
972+
}
973+
val localRelation = createLocalRelationProto(schema, inputRows)
974+
975+
// Create the sum(col1) function
976+
val sumFunction = proto.Expression.newBuilder()
977+
.setUnresolvedFunction(
978+
proto.Expression.UnresolvedFunction.newBuilder()
979+
.setFunctionName("sum")
980+
.addArguments(proto.Expression.newBuilder()
981+
.setUnresolvedAttribute(proto.Expression.UnresolvedAttribute.newBuilder()
982+
.setUnparsedIdentifier("col1"))))
983+
.build()
984+
985+
// Create window expression: sum(col1).over(Window.orderBy(lit(4)))
986+
val windowExpression = proto.Expression.newBuilder()
987+
.setWindow(proto.Expression.Window.newBuilder()
988+
.setWindowFunction(sumFunction)
989+
.addOrderSpec(proto.Expression.SortOrder.newBuilder()
990+
.setChild(proto.Expression.newBuilder()
991+
.setLiteral(proto.Expression.Literal.newBuilder().setInteger(4)))
992+
.setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING)
993+
.setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST)))
994+
.build()
995+
996+
// Create alias for the result column
997+
val aliasedWindowExpression = proto.Expression.newBuilder()
998+
.setAlias(proto.Expression.Alias.newBuilder()
999+
.setExpr(windowExpression)
1000+
.addName("sum_over"))
1001+
.build()
1002+
1003+
// Build the project relation
1004+
val project = proto.Project.newBuilder()
1005+
.setInput(localRelation)
1006+
.addExpressions(aliasedWindowExpression)
1007+
.build()
1008+
1009+
val result = transform(proto.Relation.newBuilder().setProject(project).build())
1010+
val df = Dataset.ofRows(spark, result)
1011+
1012+
// Verify the result
1013+
val collected = df.collect()
1014+
assert(collected.length == 1)
1015+
assert(df.schema.fields.head.name == "sum_over")
1016+
assert(collected(0).getAs[Long]("sum_over") == 1L)
1017+
}
1018+
9651019
test("Time literal") {
9661020
val project = proto.Project.newBuilder
9671021
.addExpressions(

0 commit comments

Comments
 (0)