@@ -962,6 +962,60 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
962
962
! aggregateExpression.containsPattern(TreePattern .UNRESOLVED_ORDINAL )))
963
963
}
964
964
965
+ test(" Window function with orderBy literal" ) {
966
+ // Create a local relation with test data
967
+ val schema = StructType (Seq (StructField (" col1" , IntegerType )))
968
+ val data = Seq (InternalRow (1 ))
969
+ val inputRows = data.map { row =>
970
+ val proj = UnsafeProjection .create(schema)
971
+ proj(row).copy()
972
+ }
973
+ val localRelation = createLocalRelationProto(schema, inputRows)
974
+
975
+ // Create the sum(col1) function
976
+ val sumFunction = proto.Expression .newBuilder()
977
+ .setUnresolvedFunction(
978
+ proto.Expression .UnresolvedFunction .newBuilder()
979
+ .setFunctionName(" sum" )
980
+ .addArguments(proto.Expression .newBuilder()
981
+ .setUnresolvedAttribute(proto.Expression .UnresolvedAttribute .newBuilder()
982
+ .setUnparsedIdentifier(" col1" ))))
983
+ .build()
984
+
985
+ // Create window expression: sum(col1).over(Window.orderBy(lit(4)))
986
+ val windowExpression = proto.Expression .newBuilder()
987
+ .setWindow(proto.Expression .Window .newBuilder()
988
+ .setWindowFunction(sumFunction)
989
+ .addOrderSpec(proto.Expression .SortOrder .newBuilder()
990
+ .setChild(proto.Expression .newBuilder()
991
+ .setLiteral(proto.Expression .Literal .newBuilder().setInteger(4 )))
992
+ .setDirection(proto.Expression .SortOrder .SortDirection .SORT_DIRECTION_ASCENDING )
993
+ .setNullOrdering(proto.Expression .SortOrder .NullOrdering .SORT_NULLS_FIRST )))
994
+ .build()
995
+
996
+ // Create alias for the result column
997
+ val aliasedWindowExpression = proto.Expression .newBuilder()
998
+ .setAlias(proto.Expression .Alias .newBuilder()
999
+ .setExpr(windowExpression)
1000
+ .addName(" sum_over" ))
1001
+ .build()
1002
+
1003
+ // Build the project relation
1004
+ val project = proto.Project .newBuilder()
1005
+ .setInput(localRelation)
1006
+ .addExpressions(aliasedWindowExpression)
1007
+ .build()
1008
+
1009
+ val result = transform(proto.Relation .newBuilder().setProject(project).build())
1010
+ val df = Dataset .ofRows(spark, result)
1011
+
1012
+ // Verify the result
1013
+ val collected = df.collect()
1014
+ assert(collected.length == 1 )
1015
+ assert(df.schema.fields.head.name == " sum_over" )
1016
+ assert(collected(0 ).getAs[Long ](" sum_over" ) == 1L )
1017
+ }
1018
+
965
1019
test(" Time literal" ) {
966
1020
val project = proto.Project .newBuilder
967
1021
.addExpressions(
0 commit comments