From 36eaa924e78b88f429a374549cb6e2d2f0490ff0 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 26 Aug 2025 20:37:24 +0200 Subject: [PATCH 1/2] [SPARK-53399][SQL] Improve CollapseProject --- .../sql/catalyst/optimizer/Optimizer.scala | 106 ++++++++++++------ .../optimizer/CollapseProjectSuite.scala | 16 +++ 2 files changed, 86 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ef505a0144113..ba6ccd0670a9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1169,9 +1169,12 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { def apply(plan: LogicalPlan, alwaysInline: Boolean): LogicalPlan = { plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) { - case p1 @ Project(_, p2: Project) - if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => - p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) + case p1 @ Project(_, p2: Project) => + mergeProjectExpressions(p1.projectList, p2.projectList, alwaysInline) match { + case (Seq(), merged) => p2.copy(projectList = merged) + case (newUpper, newLower) => + p1.copy(projectList = newUpper, child = p2.copy(projectList = newLower)) + } case p @ Project(_, agg: Aggregate) if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) && canCollapseAggregate(p, agg) => @@ -1191,6 +1194,67 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { } } + private def cheapToInlineProducer( + producer: NamedExpression, + relatedConsumers: Iterable[Expression]) = trimAliases(producer) match { + // These collection creation functions are not cheap as a producer, but we have + // optimizer rules that can optimize them out if they are only consumed by + // ExtractValue (See SimplifyExtractValueOps), so we need to allow to inline them to + // avoid perf regression. As an example: + // Project(s.a, s.b, Project(create_struct(a, b, c) as s, child)) + // We should collapse these two projects and eventually get Project(a, b, child) + case e @ (_: CreateNamedStruct | _: UpdateFields | _: CreateMap | _: CreateArray) => + // We can inline the collection creation producer if at most one of its access + // is non-cheap. Cheap access here means the access can be optimized by + // `SimplifyExtractValueOps` and become a cheap expression. For example, + // `create_struct(a, b, c).a` is a cheap access as it can be optimized to `a`. + // For a query: + // Project(s.a, s, Project(create_struct(a, b, c) as s, child)) + // We should collapse these two projects and eventually get + // Project(a, create_struct(a, b, c) as s, child) + var nonCheapAccessSeen = false + def nonCheapAccessVisitor(): Boolean = { + // Returns true for all calls after the first. + try { + nonCheapAccessSeen + } finally { + nonCheapAccessSeen = true + } + } + + !relatedConsumers + .exists(findNonCheapAccesses(_, producer.toAttribute, e, nonCheapAccessVisitor)) + + case other => isCheap(other) + } + + private def mergeProjectExpressions( + consumers: Seq[NamedExpression], + producers: Seq[NamedExpression], + alwaysInline: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = { + lazy val producerAttributes = AttributeSet(producers.collect { case a: Alias => a.toAttribute }) + lazy val producerReferences = AttributeMap(consumers + .flatMap(e => collectReferences(e).filter(producerAttributes.contains).map(_ -> e)) + .groupMap(_._1)(_._2) + .transform((_, v) => (v.size, ExpressionSet(v)))) + + val (substitute, keep) = producers.partition { + case a: Alias if producerReferences.contains(a.toAttribute) => + val (count, relatedConsumers) = producerReferences(a.toAttribute) + a.deterministic && + (alwaysInline || count == 1 || cheapToInlineProducer(a, relatedConsumers)) + + case _ => true + } + + val substituted = buildCleanedProjectList(consumers, substitute) + if (keep.isEmpty) { + (Seq.empty, substituted) + } else { + (substituted, keep ++ AttributeSet(substitute.flatMap(_.references))) + } + } + /** * Check if we can collapse expressions safely. */ @@ -1206,7 +1270,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { */ def canCollapseExpressions( consumers: Seq[Expression], - producerMap: Map[Attribute, Expression], + producerMap: Map[Attribute, NamedExpression], alwaysInline: Boolean = false): Boolean = { // We can only collapse expressions if all input expressions meet the following criteria: // - The input is deterministic. @@ -1221,38 +1285,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { val producer = producerMap.getOrElse(reference, reference) val relatedConsumers = consumers.filter(_.references.contains(reference)) - def cheapToInlineProducer: Boolean = trimAliases(producer) match { - // These collection creation functions are not cheap as a producer, but we have - // optimizer rules that can optimize them out if they are only consumed by - // ExtractValue (See SimplifyExtractValueOps), so we need to allow to inline them to - // avoid perf regression. As an example: - // Project(s.a, s.b, Project(create_struct(a, b, c) as s, child)) - // We should collapse these two projects and eventually get Project(a, b, child) - case e @ (_: CreateNamedStruct | _: UpdateFields | _: CreateMap | _: CreateArray) => - // We can inline the collection creation producer if at most one of its access - // is non-cheap. Cheap access here means the access can be optimized by - // `SimplifyExtractValueOps` and become a cheap expression. For example, - // `create_struct(a, b, c).a` is a cheap access as it can be optimized to `a`. - // For a query: - // Project(s.a, s, Project(create_struct(a, b, c) as s, child)) - // We should collapse these two projects and eventually get - // Project(a, create_struct(a, b, c) as s, child) - var nonCheapAccessSeen = false - def nonCheapAccessVisitor(): Boolean = { - // Returns true for all calls after the first. - try { - nonCheapAccessSeen - } finally { - nonCheapAccessSeen = true - } - } - - !relatedConsumers.exists(findNonCheapAccesses(_, reference, e, nonCheapAccessVisitor)) - - case other => isCheap(other) - } - - producer.deterministic && (count == 1 || alwaysInline || cheapToInlineProducer) + producer.deterministic && + (count == 1 || alwaysInline || cheapToInlineProducer(producer, relatedConsumers)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index e83f231c188e7..6e97231e41b1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -298,4 +298,20 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, expected) } } + + test("SPARK-53399: Merge expressions") { + val query = testRelation + .select($"a" + 1 as "a_plus_1", $"b" + 1 as "b_plus_1") + .select($"a_plus_1" + $"a_plus_1", $"b_plus_1") + .analyze + + val optimized = Optimize.execute(query) + + val expected = testRelation + .select($"a" + 1 as "a_plus_1", $"b") + .select($"a_plus_1" + $"a_plus_1", $"b" + 1 as "b_plus_1") + .analyze + + comparePlans(optimized, expected) + } } From 0abacd79b4a5a2c135e3fb5df6c9ecc85761e78d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 3 Sep 2025 18:39:44 +0200 Subject: [PATCH 2/2] comment to `producerReferences` map --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ba6ccd0670a9d..025ecc98b868a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1233,6 +1233,10 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { producers: Seq[NamedExpression], alwaysInline: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = { lazy val producerAttributes = AttributeSet(producers.collect { case a: Alias => a.toAttribute }) + + // A map from producer attributes to tuples of: + // - how many times the producer is referenced from consumers and + // - the set of consumers that reference the producer. lazy val producerReferences = AttributeMap(consumers .flatMap(e => collectReferences(e).filter(producerAttributes.contains).map(_ -> e)) .groupMap(_._1)(_._2)