From dfd25c20df1b5f9b0ddd8ae2a80c659f30d5656f Mon Sep 17 00:00:00 2001 From: eyes_on_me Date: Fri, 29 Aug 2025 11:19:32 +0800 Subject: [PATCH] [Enhancement] support expr reuse in outer join where predicates (#62139) Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> (cherry picked from commit 6c220d0dac21ab7545393ea17c82eed0c91f8fea) --- be/src/exec/cross_join_node.cpp | 15 +- be/src/exec/cross_join_node.h | 2 + be/src/exec/hash_join_node.cpp | 10 +- be/src/exec/hash_join_node.h | 2 + be/src/exec/hash_joiner.cpp | 12 ++ be/src/exec/hash_joiner.h | 8 +- .../pipeline/hashjoin/hash_joiner_factory.cpp | 3 + .../pipeline/nljoin/nljoin_probe_operator.cpp | 34 ++++- .../pipeline/nljoin/nljoin_probe_operator.h | 5 + .../spillable_nljoin_probe_operator.cpp | 17 ++- .../nljoin/spillable_nljoin_probe_operator.h | 9 +- be/src/exprs/expr.cpp | 22 +++ be/src/exprs/expr.h | 3 + be/src/storage/chunk_helper.cpp | 19 +++ be/src/storage/chunk_helper.h | 25 ++++ .../com/starrocks/planner/HashJoinNode.java | 3 + .../java/com/starrocks/planner/JoinNode.java | 15 ++ .../starrocks/planner/NestLoopJoinNode.java | 4 + .../operator/scalar/DictMappingOperator.java | 5 + .../exprreuse/ScalarOperatorsReuseRule.java | 13 +- .../validate/InputDependenciesChecker.java | 3 + .../sql/plan/PlanFragmentBuilder.java | 56 +++++--- .../planner/PushDownSubfieldHashJoinTest.java | 16 ++- .../optimizer/JoinPredicateExprReuseTest.java | 130 +++++++++++++++++ .../java/com/starrocks/sql/plan/JoinTest.java | 11 +- .../sql/plan/LowCardinalityTest.java | 14 +- .../sql/plan/LowCardinalityTest2.java | 5 +- gensrc/thrift/PlanNodes.thrift | 3 + .../sql/test_join/R/test_predicate_expr_reuse | 135 ++++++++++++++++++ .../sql/test_join/T/test_predicate_expr_reuse | 106 ++++++++++++++ 30 files changed, 643 insertions(+), 62 deletions(-) create mode 100644 fe/fe-core/src/test/java/com/starrocks/sql/optimizer/JoinPredicateExprReuseTest.java create mode 100644 test/sql/test_join/R/test_predicate_expr_reuse create mode 100644 test/sql/test_join/T/test_predicate_expr_reuse diff --git a/be/src/exec/cross_join_node.cpp b/be/src/exec/cross_join_node.cpp index b97e6035a7cc62..16be4bfe1674f2 100644 --- a/be/src/exec/cross_join_node.cpp +++ b/be/src/exec/cross_join_node.cpp @@ -88,6 +88,13 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { _build_runtime_filters.emplace_back(rf_desc); } } + if (tnode.nestloop_join_node.__isset.common_slot_map) { + for (const auto& [key, val] : tnode.nestloop_join_node.common_slot_map) { + ExprContext* context; + RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true)); + _common_expr_ctxs.insert({key, context}); + } + } return Status::OK(); } @@ -608,10 +615,10 @@ std::vector> CrossJoinNode::_decompos OpFactories left_ops = _children[0]->decompose_to_pipeline(context); // communication with CrossJoinRight through shared_data. - auto left_factory = - std::make_shared(context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(), - child(1)->row_desc(), _sql_join_conjuncts, std::move(_join_conjuncts), - std::move(_conjunct_ctxs), std::move(cross_join_context), _join_op); + auto left_factory = std::make_shared( + context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(), child(1)->row_desc(), + _sql_join_conjuncts, std::move(_join_conjuncts), std::move(_conjunct_ctxs), std::move(_common_expr_ctxs), + std::move(cross_join_context), _join_op); // Initialize OperatorFactory's fields involving runtime filters. this->init_runtime_filter_for_operator(left_factory.get(), context, rc_rf_probe_collector); if (!context->is_colocate_group()) { diff --git a/be/src/exec/cross_join_node.h b/be/src/exec/cross_join_node.h index 01f7e7e6c6dd27..58873d723c53fd 100644 --- a/be/src/exec/cross_join_node.h +++ b/be/src/exec/cross_join_node.h @@ -128,6 +128,8 @@ class CrossJoinNode final : public ExecNode { std::vector _build_runtime_filters; bool _interpolate_passthrough = false; + + std::map _common_expr_ctxs; }; } // namespace starrocks diff --git a/be/src/exec/hash_join_node.cpp b/be/src/exec/hash_join_node.cpp index 5532839e3a15ed..d96d137b1a7e57 100644 --- a/be/src/exec/hash_join_node.cpp +++ b/be/src/exec/hash_join_node.cpp @@ -126,6 +126,14 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { _build_equivalence_partition_expr_ctxs = _build_expr_ctxs; } + if (tnode.__isset.hash_join_node && tnode.hash_join_node.__isset.common_slot_map) { + for (const auto& [key, val] : tnode.hash_join_node.common_slot_map) { + ExprContext* context; + RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true)); + _common_expr_ctxs.insert({key, context}); + } + } + RETURN_IF_ERROR(Expr::create_expr_trees(_pool, tnode.hash_join_node.other_join_conjuncts, &_other_join_conjunct_ctxs, state)); @@ -484,7 +492,7 @@ pipeline::OpFactories HashJoinNode::_decompose_to_pipeline(pipeline::PipelineBui _other_join_conjunct_ctxs, _conjunct_ctxs, child(1)->row_desc(), child(0)->row_desc(), child(1)->type(), child(0)->type(), child(1)->conjunct_ctxs().empty(), _build_runtime_filters, _output_slots, _output_slots, context->degree_of_parallelism(), _distribution_mode, - _enable_late_materialization, _enable_partition_hash_join, _is_skew_join); + _enable_late_materialization, _enable_partition_hash_join, _is_skew_join, _common_expr_ctxs); auto hash_joiner_factory = std::make_shared(param); // Create a shared RefCountedRuntimeFilterCollector diff --git a/be/src/exec/hash_join_node.h b/be/src/exec/hash_join_node.h index 2e6966a7c66283..10006135d4913e 100644 --- a/be/src/exec/hash_join_node.h +++ b/be/src/exec/hash_join_node.h @@ -140,6 +140,8 @@ class HashJoinNode final : public ExecNode { bool _probe_eos = false; // probe table scan finished; size_t _runtime_join_filter_pushdown_limit = 1024000; + std::map _common_expr_ctxs; + RuntimeProfile::Counter* _build_timer = nullptr; RuntimeProfile::Counter* _build_ht_timer = nullptr; RuntimeProfile::Counter* _copy_right_table_chunk_timer = nullptr; diff --git a/be/src/exec/hash_joiner.cpp b/be/src/exec/hash_joiner.cpp index 08fc2f3ea3d316..c8955a9360ef25 100644 --- a/be/src/exec/hash_joiner.cpp +++ b/be/src/exec/hash_joiner.cpp @@ -33,6 +33,7 @@ #include "pipeline/hashjoin/hash_joiner_fwd.h" #include "runtime/current_thread.h" #include "simd/simd.h" +#include "storage/chunk_helper.h" #include "util/runtime_profile.h" namespace starrocks { @@ -73,6 +74,7 @@ HashJoiner::HashJoiner(const HashJoinerParam& param) _probe_expr_ctxs(param._probe_expr_ctxs), _other_join_conjunct_ctxs(param._other_join_conjunct_ctxs), _conjunct_ctxs(param._conjunct_ctxs), + _common_expr_ctxs(param._common_expr_ctxs), _build_row_descriptor(param._build_row_descriptor), _probe_row_descriptor(param._probe_row_descriptor), _build_node_type(param._build_node_type), @@ -158,6 +160,11 @@ void HashJoiner::_init_hash_table_param(HashTableParam* param, RuntimeState* sta param->column_view_concat_rows_limit = state->column_view_concat_rows_limit(); param->column_view_concat_bytes_limit = state->column_view_concat_bytes_limit(); std::set predicate_slots; + for (const auto& [slot_id, ctx] : _common_expr_ctxs) { + std::vector expr_slots; + ctx->root()->get_slot_ids(&expr_slots); + predicate_slots.insert(expr_slots.begin(), expr_slots.end()); + } for (ExprContext* expr_context : _conjunct_ctxs) { std::vector expr_slots; expr_context->root()->get_slot_ids(&expr_slots); @@ -388,6 +395,9 @@ Status HashJoiner::_calc_filter_for_other_conjunct(ChunkPtr* chunk, Filter& filt hit_all = false; filter.assign((*chunk)->num_rows(), 1); + CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); + for (auto* ctx : _other_join_conjunct_ctxs) { ASSIGN_OR_RETURN(ColumnPtr column, ctx->evaluate((*chunk).get())) size_t true_count = ColumnHelper::count_true_with_notnull(column); @@ -516,6 +526,8 @@ Status HashJoiner::_process_other_conjunct(ChunkPtr* chunk, JoinHashTable& hash_ Status HashJoiner::_process_where_conjunct(ChunkPtr* chunk) { SCOPED_TIMER(probe_metrics().where_conjunct_evaluate_timer); + CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); return ExecNode::eval_conjuncts(_conjunct_ctxs, (*chunk).get()); } diff --git a/be/src/exec/hash_joiner.h b/be/src/exec/hash_joiner.h index 458d2bfb47dc6e..10d41313664748 100644 --- a/be/src/exec/hash_joiner.h +++ b/be/src/exec/hash_joiner.h @@ -72,7 +72,8 @@ struct HashJoinerParam { bool build_conjunct_ctxs_is_empty, std::list build_runtime_filters, std::set build_output_slots, std::set probe_output_slots, size_t max_dop, const TJoinDistributionMode::type distribution_mode, bool enable_late_materialization, - bool enable_partition_hash_join, bool is_skew_join) + bool enable_partition_hash_join, bool is_skew_join, + const std::map& common_expr_ctxs) : _pool(pool), _hash_join_node(hash_join_node), _is_null_safes(std::move(is_null_safes)), @@ -92,7 +93,8 @@ struct HashJoinerParam { _distribution_mode(distribution_mode), _enable_late_materialization(enable_late_materialization), _enable_partition_hash_join(enable_partition_hash_join), - _is_skew_join(is_skew_join) {} + _is_skew_join(is_skew_join), + _common_expr_ctxs(common_expr_ctxs) {} HashJoinerParam(HashJoinerParam&&) = default; HashJoinerParam(HashJoinerParam&) = default; @@ -120,6 +122,7 @@ struct HashJoinerParam { const bool _enable_late_materialization; const bool _enable_partition_hash_join; const bool _is_skew_join; + const std::map _common_expr_ctxs; }; inline bool could_short_circuit(TJoinOp::type join_type) { @@ -439,6 +442,7 @@ class HashJoiner final : public pipeline::ContextWithDependency { const std::vector& _other_join_conjunct_ctxs; // Conjuncts in Join followed by a filter predicate, usually in Where and Having. const std::vector& _conjunct_ctxs; + const std::map& _common_expr_ctxs; const RowDescriptor& _build_row_descriptor; const RowDescriptor& _probe_row_descriptor; const TPlanNodeType::type _build_node_type; diff --git a/be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp b/be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp index 925ec4002ef0b9..cacfb4e71b1d95 100644 --- a/be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp +++ b/be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp @@ -19,16 +19,19 @@ namespace starrocks::pipeline { Status HashJoinerFactory::prepare(RuntimeState* state) { RETURN_IF_ERROR(Expr::prepare(_param._build_expr_ctxs, state)); RETURN_IF_ERROR(Expr::prepare(_param._probe_expr_ctxs, state)); + RETURN_IF_ERROR(Expr::prepare(_param._common_expr_ctxs, state)); RETURN_IF_ERROR(Expr::prepare(_param._other_join_conjunct_ctxs, state)); RETURN_IF_ERROR(Expr::prepare(_param._conjunct_ctxs, state)); RETURN_IF_ERROR(Expr::open(_param._build_expr_ctxs, state)); RETURN_IF_ERROR(Expr::open(_param._probe_expr_ctxs, state)); + RETURN_IF_ERROR(Expr::open(_param._common_expr_ctxs, state)); RETURN_IF_ERROR(Expr::open(_param._other_join_conjunct_ctxs, state)); RETURN_IF_ERROR(Expr::open(_param._conjunct_ctxs, state)); return Status::OK(); } void HashJoinerFactory::close(RuntimeState* state) { + Expr::close(_param._common_expr_ctxs, state); Expr::close(_param._conjunct_ctxs, state); Expr::close(_param._other_join_conjunct_ctxs, state); Expr::close(_param._probe_expr_ctxs, state); diff --git a/be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp b/be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp index 9aa2a170f30b69..ba9851b9d2f8f5 100644 --- a/be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp +++ b/be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp @@ -22,6 +22,7 @@ #include "runtime/current_thread.h" #include "runtime/descriptors.h" #include "simd/simd.h" +#include "storage/chunk_helper.h" namespace starrocks::pipeline { @@ -30,6 +31,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i const std::string& sql_join_conjuncts, const std::vector& join_conjuncts, const std::vector& conjunct_ctxs, + const std::map& common_expr_ctxs, const std::vector& col_types, size_t probe_column_count, const std::shared_ptr& cross_join_context) : OperatorWithDependency(factory, id, "nestloop_join_probe", plan_node_id, false, driver_sequence), @@ -39,6 +41,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i _sql_join_conjuncts(sql_join_conjuncts), _join_conjuncts(join_conjuncts), _conjunct_ctxs(conjunct_ctxs), + _common_expr_ctxs(common_expr_ctxs), _cross_join_context(cross_join_context) {} Status NLJoinProbeOperator::prepare(RuntimeState* state) { @@ -309,6 +312,9 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk // for null-aware left anti join, join_conjunct[0] is on-predicate // others are other-conjuncts // process on conjuncts + CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); + { ASSIGN_OR_RETURN(ColumnPtr column, _join_conjuncts[0]->evaluate(chunk.get())); size_t num_false = ColumnHelper::count_false_with_notnull(column); @@ -354,6 +360,8 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk Status NLJoinProbeOperator::_probe_for_inner_join(const ChunkPtr& chunk) { if (!_join_conjuncts.empty() && chunk && !chunk->is_empty()) { + CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), nullptr, true)); } return Status::OK(); @@ -374,7 +382,10 @@ Status NLJoinProbeOperator::_probe_for_other_join(const ChunkPtr& chunk) { if (_is_null_aware_left_anti_join()) { RETURN_IF_ERROR(_eval_nullaware_anti_conjuncts(chunk, &filter)); } else { + CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), &filter, apply_filter)); + chunk->check_or_die(); } DCHECK(!!filter); // The filter has not been assigned if no rows matched @@ -652,8 +663,11 @@ Status NLJoinProbeOperator::_permute_right_join(size_t chunk_size) { } } permute_rows += chunk->num_rows(); - - RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr)); + { + CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); + RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr)); + } RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk))); match_flag_index += cur_chunk_size; } @@ -703,7 +717,11 @@ StatusOr NLJoinProbeOperator::_pull_chunk_for_other_join(size_t chunk_ ASSIGN_OR_RETURN(ChunkPtr chunk, _permute_chunk_for_other_join(chunk_size)); DCHECK(chunk); RETURN_IF_ERROR(_probe_for_other_join(chunk)); - RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr)); + { + CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs); + RETURN_IF_ERROR(guard.evaluate()); + RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr)); + } RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk))); if (ChunkPtr res = _output_accumulator.pull()) { @@ -800,9 +818,9 @@ void NLJoinProbeOperatorFactory::_init_row_desc() { } OperatorPtr NLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) { - return std::make_shared(this, _id, _plan_node_id, driver_sequence, _join_op, - _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, _col_types, - _probe_column_count, _cross_join_context); + return std::make_shared( + this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, + _common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context); } Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) { @@ -812,6 +830,9 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) { _cross_join_context->ref(); _init_row_desc(); + + RETURN_IF_ERROR(Expr::prepare(_common_expr_ctxs, state)); + RETURN_IF_ERROR(Expr::open(_common_expr_ctxs, state)); RETURN_IF_ERROR(Expr::prepare(_join_conjuncts, state)); RETURN_IF_ERROR(Expr::open(_join_conjuncts, state)); RETURN_IF_ERROR(Expr::prepare(_conjunct_ctxs, state)); @@ -821,6 +842,7 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) { } void NLJoinProbeOperatorFactory::close(RuntimeState* state) { + Expr::close(_common_expr_ctxs, state); Expr::close(_join_conjuncts, state); Expr::close(_conjunct_ctxs, state); diff --git a/be/src/exec/pipeline/nljoin/nljoin_probe_operator.h b/be/src/exec/pipeline/nljoin/nljoin_probe_operator.h index fbe66c657ffb45..fdbb52345c8980 100644 --- a/be/src/exec/pipeline/nljoin/nljoin_probe_operator.h +++ b/be/src/exec/pipeline/nljoin/nljoin_probe_operator.h @@ -37,6 +37,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency { NLJoinProbeOperator(OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence, TJoinOp::type join_op, const std::string& sql_join_conjuncts, const std::vector& join_conjuncts, const std::vector& conjunct_ctxs, + const std::map& common_expr_ctxs, const std::vector& col_types, size_t probe_column_count, const std::shared_ptr& cross_join_context); @@ -115,6 +116,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency { const std::vector& _join_conjuncts; const std::vector& _conjunct_ctxs; + const std::map& _common_expr_ctxs; const std::shared_ptr& _cross_join_context; bool _input_finished = false; @@ -147,6 +149,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory { const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc, std::string sql_join_conjuncts, std::vector&& join_conjuncts, std::vector&& conjunct_ctxs, + std::map&& common_expr_ctxs, std::shared_ptr&& cross_join_context, TJoinOp::type join_op) : OperatorWithDependencyFactory(id, "cross_join_left", plan_node_id), _join_op(join_op), @@ -155,6 +158,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory { _sql_join_conjuncts(std::move(sql_join_conjuncts)), _join_conjuncts(std::move(join_conjuncts)), _conjunct_ctxs(std::move(conjunct_ctxs)), + _common_expr_ctxs(std::move(common_expr_ctxs)), _cross_join_context(std::move(cross_join_context)) {} ~NLJoinProbeOperatorFactory() override = default; @@ -178,6 +182,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory { std::string _sql_join_conjuncts; std::vector _join_conjuncts; std::vector _conjunct_ctxs; + std::map _common_expr_ctxs; std::shared_ptr _cross_join_context; }; diff --git a/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.cpp b/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.cpp index c7a2a92c7091b0..3ec1501fd64f35 100644 --- a/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.cpp +++ b/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.cpp @@ -27,12 +27,14 @@ namespace starrocks::pipeline { NLJoinProber::NLJoinProber(TJoinOp::type join_op, const std::vector& join_conjuncts, const std::vector& conjunct_ctxs, + const std::map& common_expr_ctxs, const std::vector& col_types, size_t probe_column_count) : _join_op(join_op), _col_types(col_types), _probe_column_count(probe_column_count), _join_conjuncts(join_conjuncts), - _conjunct_ctxs(conjunct_ctxs) {} + _conjunct_ctxs(conjunct_ctxs), + _common_expr_ctxs(common_expr_ctxs) {} Status NLJoinProber::prepare(RuntimeState* state, RuntimeProfile* profile) { _permute_rows_counter = ADD_COUNTER(profile, "PermuteRows", TUnit::UNIT); @@ -115,10 +117,11 @@ void NLJoinProber::_permute_probe_row(Chunk* dst, const ChunkPtr& build_chunk) { SpillableNLJoinProbeOperator::SpillableNLJoinProbeOperator( OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence, TJoinOp::type join_op, const std::string& sql_join_conjuncts, const std::vector& join_conjuncts, - const std::vector& conjunct_ctxs, const std::vector& col_types, - size_t probe_column_count, const std::shared_ptr& cross_join_context) + const std::vector& conjunct_ctxs, const std::map& common_expr_ctxs, + const std::vector& col_types, size_t probe_column_count, + const std::shared_ptr& cross_join_context) : OperatorWithDependency(factory, id, "spillable_nestloop_join_probe", plan_node_id, false, driver_sequence), - _prober(join_op, join_conjuncts, conjunct_ctxs, col_types, probe_column_count), + _prober(join_op, join_conjuncts, conjunct_ctxs, common_expr_ctxs, col_types, probe_column_count), _cross_join_context(cross_join_context) {} Status SpillableNLJoinProbeOperator::prepare(RuntimeState* state) { @@ -244,9 +247,9 @@ void SpillableNLJoinProbeOperatorFactory::_init_row_desc() { } OperatorPtr SpillableNLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) { - return std::make_shared(this, _id, _plan_node_id, driver_sequence, _join_op, - _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, - _col_types, _probe_column_count, _cross_join_context); + return std::make_shared( + this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, + _common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context); } Status SpillableNLJoinProbeOperatorFactory::prepare(RuntimeState* state) { diff --git a/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.h b/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.h index a9849e75abf23b..3d4b819c7e5f86 100644 --- a/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.h +++ b/be/src/exec/pipeline/nljoin/spillable_nljoin_probe_operator.h @@ -30,8 +30,8 @@ namespace starrocks::pipeline { class NLJoinProber { public: NLJoinProber(TJoinOp::type join_op, const std::vector& join_conjuncts, - const std::vector& conjunct_ctxs, const std::vector& col_types, - size_t probe_column_count); + const std::vector& conjunct_ctxs, const std::map& common_expr_ctxs, + const std::vector& col_types, size_t probe_column_count); ~NLJoinProber() = default; @@ -80,6 +80,7 @@ class NLJoinProber { const std::vector& _join_conjuncts; const std::vector& _conjunct_ctxs; + const std::map& _common_expr_ctxs; // ChunkPtr _probe_chunk = nullptr; @@ -98,6 +99,7 @@ class SpillableNLJoinProbeOperator final : public OperatorWithDependency { TJoinOp::type join_op, const std::string& sql_join_conjuncts, const std::vector& join_conjuncts, const std::vector& conjunct_ctxs, + const std::map& common_expr_ctxs, const std::vector& col_types, size_t probe_column_count, const std::shared_ptr& cross_join_context); @@ -153,6 +155,7 @@ class SpillableNLJoinProbeOperatorFactory final : public OperatorWithDependencyF const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc, std::string sql_join_conjuncts, std::vector&& join_conjuncts, std::vector&& conjunct_ctxs, + std::map&& common_expr_ctxs, std::shared_ptr&& cross_join_context, TJoinOp::type join_op) : OperatorWithDependencyFactory(id, "spillable_nl_join_left", plan_node_id), _join_op(join_op), @@ -161,6 +164,7 @@ class SpillableNLJoinProbeOperatorFactory final : public OperatorWithDependencyF _sql_join_conjuncts(std::move(sql_join_conjuncts)), _join_conjuncts(std::move(join_conjuncts)), _conjunct_ctxs(std::move(conjunct_ctxs)), + _common_expr_ctxs(std::move(common_expr_ctxs)), _cross_join_context(std::move(cross_join_context)) {} ~SpillableNLJoinProbeOperatorFactory() override = default; @@ -184,6 +188,7 @@ class SpillableNLJoinProbeOperatorFactory final : public OperatorWithDependencyF std::string _sql_join_conjuncts; std::vector _join_conjuncts; std::vector _conjunct_ctxs; + std::map _common_expr_ctxs; std::shared_ptr _cross_join_context; }; diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 71ae1db541a854..ad82cb0f869063 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -513,6 +513,13 @@ Status Expr::prepare(const std::vector& ctxs, RuntimeState* state) return Status::OK(); } +Status Expr::prepare(const std::map& ctxs, RuntimeState* state) { + for (const auto& [_, ctx] : ctxs) { + RETURN_IF_ERROR(ctx->prepare(state)); + } + return Status::OK(); +} + Status Expr::prepare(RuntimeState* state, ExprContext* context) { FAIL_POINT_TRIGGER_RETURN_ERROR(randome_error); DCHECK(_type.type != TYPE_UNKNOWN); @@ -529,6 +536,13 @@ Status Expr::open(const std::vector& ctxs, RuntimeState* state) { return Status::OK(); } +Status Expr::open(const std::map& ctxs, RuntimeState* state) { + for (const auto& [_, ctx] : ctxs) { + RETURN_IF_ERROR(ctx->open(state)); + } + return Status::OK(); +} + Status Expr::open(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) { FAIL_POINT_TRIGGER_RETURN_ERROR(random_error); DCHECK(_type.type != TYPE_UNKNOWN); @@ -546,6 +560,14 @@ void Expr::close(const std::vector& ctxs, RuntimeState* state) { } } +void Expr::close(const std::map& ctxs, RuntimeState* state) { + for (const auto& [_, ctx] : ctxs) { + if (ctx != nullptr) { + ctx->close(state); + } + } +} + void Expr::close(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) { for (auto& i : _children) { i->close(state, context, scope); diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index 7ebe955a38dad8..c04f8eb856bea4 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -179,9 +179,11 @@ class Expr { /// Convenience function for preparing multiple expr trees. static Status prepare(const std::vector& ctxs, RuntimeState* state); + static Status prepare(const std::map& ctxs, RuntimeState* state); /// Convenience function for opening multiple expr trees. static Status open(const std::vector& ctxs, RuntimeState* state); + static Status open(const std::map& ctxs, RuntimeState* state); /// Clones each ExprContext for multiple expr trees. 'new_ctxs' must be non-NULL. /// Idempotent: if '*new_ctxs' is empty, a clone of each context in 'ctxs' will be added @@ -192,6 +194,7 @@ class Expr { /// Convenience function for closing multiple expr trees. static void close(const std::vector& ctxs, RuntimeState* state); + static void close(const std::map& ctxs, RuntimeState* state); /// Convenience functions for closing a list of ScalarExpr. static void close(const std::vector& exprs); diff --git a/be/src/storage/chunk_helper.cpp b/be/src/storage/chunk_helper.cpp index 99924df458b28d..e9e33a4c7fd4b5 100644 --- a/be/src/storage/chunk_helper.cpp +++ b/be/src/storage/chunk_helper.cpp @@ -26,6 +26,7 @@ #include "column/struct_column.h" #include "column/type_traits.h" #include "column/vectorized_fwd.h" +#include "exprs/expr_context.h" #include "gutil/strings/fastmem.h" #include "runtime/current_thread.h" #include "runtime/descriptors.h" @@ -1037,4 +1038,22 @@ void SegmentedChunk::check_or_die() { } } +CommonExprEvalScopeGuard::CommonExprEvalScopeGuard(const ChunkPtr& chunk, + const std::map& common_expr_ctxs) + : _chunk(chunk), _common_expr_ctxs(common_expr_ctxs) {} + +CommonExprEvalScopeGuard::~CommonExprEvalScopeGuard() { + for (const auto& [slot_id, _] : _common_expr_ctxs) { + _chunk->remove_column_by_slot_id(slot_id); + } +} + +Status CommonExprEvalScopeGuard::evaluate() { + for (const auto& [slot_id, ctx] : _common_expr_ctxs) { + ASSIGN_OR_RETURN(auto column, ctx->evaluate(_chunk.get())); + _chunk->append_column(std::move(column), slot_id); + } + return Status::OK(); +} + } // namespace starrocks diff --git a/be/src/storage/chunk_helper.h b/be/src/storage/chunk_helper.h index 7499d30e7a8e2d..ad19edc9cb9d7e 100644 --- a/be/src/storage/chunk_helper.h +++ b/be/src/storage/chunk_helper.h @@ -221,4 +221,29 @@ class SegmentedChunk final : public std::enable_shared_from_this const size_t _segment_size; }; +class ExprContext; +/** + * RAII guard for evaluating common expressions on a chunk. + * + * This class provides automatic scope management for evaluating common expressions + * that are temporarily used during expression computation. Common expressions are + * computed once and reused across multiple expressions to avoid redundant computation, + * but they are only needed during the computation phase and should be cleaned up + * from the chunk after computation completes. + * + * The destructor automatically removes the common expressions from the chunk + * to prevent memory leaks and ensure proper cleanup. + */ +class CommonExprEvalScopeGuard { +public: + CommonExprEvalScopeGuard(const ChunkPtr& chunk, const std::map& common_expr_ctxs); + ~CommonExprEvalScopeGuard(); + + Status evaluate(); + +private: + const ChunkPtr& _chunk; + const std::map& _common_expr_ctxs; +}; + } // namespace starrocks diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/HashJoinNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/HashJoinNode.java index 7589baf53b64cb..3bb02c53aad7bd 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/HashJoinNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/HashJoinNode.java @@ -192,6 +192,9 @@ protected void toThrift(TPlanNode msg) { if (isSkewJoin) { msg.hash_join_node.setIs_skew_join(isSkewJoin); } + if (commonSlotMap != null) { + commonSlotMap.forEach((key, value) -> msg.hash_join_node.putToCommon_slot_map(key.asInt(), value.treeToThrift())); + } } @Override diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java index 79df8d571bacc8..c5d262cf338069 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java @@ -58,7 +58,9 @@ import org.apache.logging.log4j.Logger; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -95,6 +97,7 @@ public abstract class JoinNode extends PlanNode implements RuntimeFilterBuildNod // The partitionByExprs which need to check the probe side for partition join. protected List probePartitionByExprs; protected boolean canLocalShuffle = false; + protected Map commonSlotMap; public List getBuildRuntimeFilters() { return buildRuntimeFilters; @@ -492,6 +495,10 @@ public void setUkfkProperty(UKFKConstraints.JoinProperty ukfkProperty) { this.ukfkProperty = ukfkProperty; } + public void setCommonSlotMap(Map commonSlotMap) { + this.commonSlotMap = commonSlotMap; + } + @Override protected String getNodeExplainString(String detailPrefix, TExplainLevel detailLevel) { String distrModeStr = @@ -527,6 +534,14 @@ protected String getNodeExplainString(String detailPrefix, TExplainLevel detailL .append("\n"); } + if (commonSlotMap != null && !commonSlotMap.isEmpty()) { + output.append(detailPrefix + " common sub expr:" + "\n"); + for (Map.Entry entry : commonSlotMap.entrySet()) { + output.append(detailPrefix + " : " + + getExplainString(Arrays.asList(entry.getValue())) + "\n"); + } + } + if (detailLevel == TExplainLevel.VERBOSE) { if (!buildRuntimeFilters.isEmpty()) { diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java index 27d10f593a5e90..08a63f69dc8855 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java @@ -136,6 +136,10 @@ protected void toThrift(TPlanNode msg) { msg.nestloop_join_node.setBuild_runtime_filters( RuntimeFilterDescription.toThriftRuntimeFilterDescriptions(buildRuntimeFilters)); } + if (commonSlotMap != null) { + commonSlotMap.forEach((key, value) -> + msg.nestloop_join_node.putToCommon_slot_map(key.asInt(), value.treeToThrift())); + } } @Override diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/DictMappingOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/DictMappingOperator.java index a021961206d492..9bf6638187f443 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/DictMappingOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/DictMappingOperator.java @@ -78,6 +78,11 @@ public ScalarOperator getChild(int index) { public void setChild(int index, ScalarOperator child) { } + @Override + public boolean isConstant() { + return false; + } + @Override public String toString() { String stringOperator = stringProvideOperator == null ? "" : ", " + stringProvideOperator; diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/exprreuse/ScalarOperatorsReuseRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/exprreuse/ScalarOperatorsReuseRule.java index 297288f3eac385..e0c7a825b35081 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/exprreuse/ScalarOperatorsReuseRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/exprreuse/ScalarOperatorsReuseRule.java @@ -20,7 +20,7 @@ import com.starrocks.sql.optimizer.base.ColumnRefFactory; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.Projection; -import com.starrocks.sql.optimizer.operator.physical.PhysicalFilterOperator; +import com.starrocks.sql.optimizer.operator.physical.PhysicalOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.optimizer.rule.tree.TreeRewriteRule; @@ -49,10 +49,10 @@ public Void visit(OptExpression opt, TaskContext context) { if (shouldRewritePredicate(opt, context)) { Projection result = rewritePredicate(opt, context); if (!result.getCommonSubOperatorMap().isEmpty()) { - PhysicalFilterOperator filter = (PhysicalFilterOperator) opt.getOp(); + PhysicalOperator op = (PhysicalOperator) opt.getOp(); ScalarOperator newPredicate = result.getColumnRefMap().values().iterator().next(); - filter.setPredicate(newPredicate); - filter.setPredicateCommonOperators(result.getCommonSubOperatorMap()); + op.setPredicate(newPredicate); + op.setPredicateCommonOperators(result.getCommonSubOperatorMap()); } } @@ -89,8 +89,9 @@ private boolean shouldRewritePredicate(OptExpression input, TaskContext context) || input.getOp().getPredicate() == null) { return false; } - // for now, only support rewrite predicates in PhysicalFilterOperator - if (input.getOp().getOpType() == OperatorType.PHYSICAL_FILTER) { + if (input.getOp().getOpType() == OperatorType.PHYSICAL_FILTER || + input.getOp().getOpType() == OperatorType.PHYSICAL_HASH_JOIN || + input.getOp().getOpType() == OperatorType.PHYSICAL_NESTLOOP_JOIN) { return true; } return false; diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/validate/InputDependenciesChecker.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/validate/InputDependenciesChecker.java index 95cc6a4fe4e37e..0ccec597ce4b29 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/validate/InputDependenciesChecker.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/validate/InputDependenciesChecker.java @@ -162,6 +162,9 @@ private void checkJoinOpt(OptExpression optExpression) { } if (joinOperator.getPredicate() != null) { usedCols.union(joinOperator.getPredicate().getUsedColumns()); + if (joinOperator.getPredicateCommonOperators() != null) { + inputCols.union(joinOperator.getPredicateCommonOperators().keySet()); + } } } checkInputCols(inputCols, usedCols, optExpression); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index d20e3f3d946c2a..12411e9286094e 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -2732,6 +2732,7 @@ public PlanFragment visitPhysicalNestLoopJoin(OptExpression optExpr, ExecPlan co rightFragment.getPlanRoot().forceCollectExecStats(); this.currentExecGroup = leftExecGroup; + Map commonSubExprMap = buildCommonSubExprMap(node.getPredicateCommonOperators(), context); List conjuncts = extractConjuncts(node.getPredicate(), context); List joinOnConjuncts = extractConjuncts(node.getOnPredicate(), context); List probePartitionByExprs = Lists.newArrayList(); @@ -2749,6 +2750,7 @@ public PlanFragment visitPhysicalNestLoopJoin(OptExpression optExpr, ExecPlan co NestLoopJoinNode joinNode = new NestLoopJoinNode(context.getNextNodeId(), leftFragment.getPlanRoot(), rightFragment.getPlanRoot(), null, node.getJoinType(), Lists.newArrayList(), joinOnConjuncts); + joinNode.setCommonSlotMap(commonSubExprMap); joinNode.setLimit(node.getLimit()); joinNode.computeStatistics(optExpr.getStatistics()); @@ -2863,6 +2865,7 @@ private PlanFragment visitPhysicalJoin(PlanFragment leftFragment, PlanFragment r List eqJoinConjuncts = joinExpr.eqJoinConjuncts; List otherJoinConjuncts = joinExpr.otherJoin; List conjuncts = joinExpr.conjuncts; + Map commonSlotMap = joinExpr.commonSubOperatorMap; setNullableForJoin(joinOperator, leftFragment, rightFragment, context); @@ -2880,6 +2883,7 @@ private PlanFragment visitPhysicalJoin(PlanFragment leftFragment, PlanFragment r joinNode.setUkfkProperty(joinProperty); } } + joinNode.setCommonSlotMap(commonSlotMap); // set skew join, this is used by runtime filter PhysicalHashJoinOperator physicalHashJoinOperator = (PhysicalHashJoinOperator) node; boolean isSkewJoin = physicalHashJoinOperator.getSkewColumn() != null; @@ -3414,23 +3418,8 @@ public PlanFragment visitPhysicalFilter(OptExpression optExpr, ExecPlan context) TupleDescriptor tupleDescriptor = context.getDescTbl().createTupleDescriptor(); - Map commonSubOperatorMap = Maps.newHashMap(); - if (filter.getPredicateCommonOperators() != null) { - for (Map.Entry entry : filter.getPredicateCommonOperators().entrySet()) { - Expr expr = ScalarOperatorToExpr.buildExecExpression(entry.getValue(), - new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr(), - filter.getPredicateCommonOperators())); + Map commonSubOperatorMap = buildCommonSubExprMap(filter.getPredicateCommonOperators(), context); - commonSubOperatorMap.put(new SlotId(entry.getKey().getId()), expr); - - SlotDescriptor slotDescriptor = - context.getDescTbl().addSlotDescriptor(tupleDescriptor, new SlotId(entry.getKey().getId())); - slotDescriptor.setIsNullable(expr.isNullable()); - slotDescriptor.setIsMaterialized(false); - slotDescriptor.setType(expr.getType()); - context.getColRefToExpr().put(entry.getKey(), new SlotRef(entry.getKey().toString(), slotDescriptor)); - } - } List predicates = Utils.extractConjuncts(filter.getPredicate()).stream() .map(d -> ScalarOperatorToExpr.buildExecExpression(d, @@ -3648,12 +3637,38 @@ static class JoinExprInfo { public final List eqJoinConjuncts; public final List otherJoin; public final List conjuncts; + public final Map commonSubOperatorMap; - public JoinExprInfo(List eqJoinConjuncts, List otherJoin, List conjuncts) { + public JoinExprInfo(List eqJoinConjuncts, List otherJoin, List conjuncts, + Map commonSubOperatorMap) { this.eqJoinConjuncts = eqJoinConjuncts; this.otherJoin = otherJoin; this.conjuncts = conjuncts; + this.commonSubOperatorMap = commonSubOperatorMap; } + + } + + private Map buildCommonSubExprMap( + Map commonSubOperators, ExecPlan context) { + Map commonSubExprMap = Maps.newHashMap(); + if (commonSubOperators != null && !commonSubOperators.isEmpty()) { + TupleDescriptor tupleDescriptor = context.getDescTbl().createTupleDescriptor(); + for (Map.Entry entry : commonSubOperators.entrySet()) { + Expr expr = ScalarOperatorToExpr.buildExecExpression(entry.getValue(), + new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr(), commonSubOperators)); + + commonSubExprMap.put(new SlotId(entry.getKey().getId()), expr); + + SlotDescriptor slotDescriptor = + context.getDescTbl().addSlotDescriptor(tupleDescriptor, new SlotId(entry.getKey().getId())); + slotDescriptor.setIsNullable(expr.isNullable()); + slotDescriptor.setIsMaterialized(false); + slotDescriptor.setType(expr.getType()); + context.getColRefToExpr().put(entry.getKey(), new SlotRef(entry.getKey().toString(), slotDescriptor)); + } + } + return commonSubExprMap; } private JoinExprInfo buildJoinExpr(OptExpression optExpr, ExecPlan context) { @@ -3699,13 +3714,18 @@ private JoinExprInfo buildJoinExpr(OptExpression optExpr, ExecPlan context) { List otherJoinConjuncts = otherJoin.stream().map(e -> ScalarOperatorToExpr.buildExecExpression(e, new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr()))) .collect(Collectors.toList()); + Map commonSubExprMap = Maps.newHashMap(); + if (optExpr.getOp() instanceof PhysicalJoinOperator) { + PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) optExpr.getOp(); + commonSubExprMap = buildCommonSubExprMap(joinOperator.getPredicateCommonOperators(), context); + } List predicates = Utils.extractConjuncts(predicate); List conjuncts = predicates.stream().map(e -> ScalarOperatorToExpr.buildExecExpression(e, new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr()))) .collect(Collectors.toList()); - return new JoinExprInfo(eqJoinConjuncts, otherJoinConjuncts, conjuncts); + return new JoinExprInfo(eqJoinConjuncts, otherJoinConjuncts, conjuncts, commonSubExprMap); } // TODO(murphy) consider state distribution diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/PushDownSubfieldHashJoinTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/PushDownSubfieldHashJoinTest.java index 64e42a508e0140..2089f7a52acbd1 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/PushDownSubfieldHashJoinTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/PushDownSubfieldHashJoinTest.java @@ -16,6 +16,8 @@ import com.starrocks.common.FeConstants; import com.starrocks.qe.ConnectContext; +import com.starrocks.sql.plan.PlanTestBase; +import com.starrocks.sql.plan.PlanTestNoneDBBase; import com.starrocks.statistic.StatsConstants; import com.starrocks.utframe.StarRocksAssert; import com.starrocks.utframe.UtFrameUtils; @@ -111,9 +113,15 @@ public void test() throws Exception { " | colocate: false, reason: \n" + " | equal join conjunct: 3: fk = 1: fk\n" + " | other predicates: CAST(array_sum(array_map( -> != 'A', " + - "if(array_length(array_filter(['A','B'], CAST([0,CAST((2: col_int = 1) AND " + - "(4: id IS NOT NULL) AS TINYINT)] AS ARRAY))) = 0, ['C'], " + - "array_filter(['A','B'], CAST([0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)] " + - "AS ARRAY))))) AS BOOLEAN)")); + "if(array_length(26: array_filter) = 0, ['C'], 26: array_filter))) AS BOOLEAN)\n" + + " | common sub expr:\n" + + " | : 2: col_int = 1\n" + + " | : 4: id IS NOT NULL\n" + + " | : (20: expr) AND (21: expr)\n" + + " | : CAST(22: expr AS TINYINT)\n" + + " | : [0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)]\n" + + " | : CAST([0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)] AS ARRAY)\n" + + " | : array_filter(['A','B'], 25: cast)")); + } } \ No newline at end of file diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/JoinPredicateExprReuseTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/JoinPredicateExprReuseTest.java new file mode 100644 index 00000000000000..0c3c33df416032 --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/optimizer/JoinPredicateExprReuseTest.java @@ -0,0 +1,130 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer; + +import com.starrocks.sql.plan.PlanTestBase; +import org.junit.jupiter.api.Test; + +public class JoinPredicateExprReuseTest extends PlanTestBase { + @Test + public void testHashJoin() throws Exception { + { + String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " + + "abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" + + " | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : abs(7: add)"); + } + + { + String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " + + "bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" + + " | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " + + "8: bit_shift_left IN (10, 20)\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : 7: add BITSHIFTLEFT 1"); + } + { + String sql = "select * from t0 right join t1 on t0.v1 = t1.v4 where " + + "abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" + + " | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : abs(7: add)"); + } + + { + String sql = "select * from t0 right join t1 on t0.v1 = t1.v4 where " + + "bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" + + " | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " + + "8: bit_shift_left IN (10, 20)\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : 7: add BITSHIFTLEFT 1"); + } + } + + @Test + public void testNestLoopJoin() throws Exception { + { + String sql = "select * from t0 left join t1 on t0.v1 > t1.v4 where " + + "abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" + + " | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : abs(7: add)"); + } + + { + String sql = "select * from t0 left join t1 on t0.v1 > t1.v4 where " + + "bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" + + " | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " + + "8: bit_shift_left IN (10, 20)\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : 7: add BITSHIFTLEFT 1"); + } + + { + String sql = "select * from t0 right join t1 on t0.v1 > t1.v4 and t0.v2 = t1.v5 where " + + "abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 2: v2 = 5: v5\n" + + " | other join predicates: 1: v1 > 4: v4\n" + + " | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : abs(7: add)"); + } + + { + String sql = "select * from t0 right join t1 on t0.v1 > t1.v4 where " + + "bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" + + " | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " + + "8: bit_shift_left IN (10, 20)\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : 7: add BITSHIFTLEFT 1"); + } + + { + String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " + + "abs(t0.v1 + t1.v4) > 5 and abs(t0.v1 + t1.v4) < 10"; + String plan = getFragmentPlan(sql); + assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" + + " | other predicates: 8: abs > 5, 8: abs < 10\n" + + " | common sub expr:\n" + + " | : 1: v1 + 4: v4\n" + + " | : abs(7: add)"); + } + } + +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java index b39f80fb2e0163..8d4b08d7f765ff 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java @@ -3326,12 +3326,11 @@ public void testPredicatePushDown() throws Exception { " | colocate: false, reason: \n" + " | equal join conjunct: 1: v1 = 5: v4\n" + " | equal join conjunct: 4: count = 8: count\n" + - " | other predicates: ((1: v1 != 1) AND (if(8: count = 1, 'a', 'b') = 'b')) OR ((1: v1 = 1) AND " + - "(if(8: count = 1, 'a', 'b') = 'b')), if(8: count = 1, 'a', 'b') = 'b'\n" + - " | \n" + - " |----5:EXCHANGE\n" + - " | \n" + - " 2:EXCHANGE"); + " | other predicates: ((1: v1 != 1) AND (31: expr)) OR ((1: v1 = 1) AND (31: expr)), 31: expr\n" + + " | common sub expr:\n" + + " | : 8: count = 1\n" + + " | : if(29: expr, 'a', 'b')\n" + + " | : 30: if = 'b'"); } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java index 182f14bcbee73c..3f0dd0685e34ca 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java @@ -811,10 +811,11 @@ public void testProject() throws Exception { // test input two string column sql = "select if(S_ADDRESS='kks', S_COMMENT, S_COMMENT) from supplier"; plan = getVerboseExplain(sql); - Assertions.assertTrue(plan.contains( - "9 <-> if[(DictDecode(10: S_ADDRESS, [ = 'kks']), DictDecode(11: S_COMMENT, []), " + - "DictDecode(11: S_COMMENT, [])); args: BOOLEAN,VARCHAR,VARCHAR; " + - "result: VARCHAR; args nullable: true; result nullable: true]")); + assertContains(plan, " | 9 <-> if[(DictDecode(10: S_ADDRESS, [ = 'kks']), " + + "[12: expr, VARCHAR(101), true], [12: expr, VARCHAR(101), true]); " + + "args: BOOLEAN,VARCHAR,VARCHAR; result: VARCHAR; args nullable: true; result nullable: true]\n" + + " | common expressions:\n" + + " | 12 <-> DictDecode(11: S_COMMENT, [])"); assertNotContains(plan, "DecodeNode"); // common expression reuse 3 @@ -825,7 +826,10 @@ public void testProject() throws Exception { // support(support(unsupport(Column), unsupport(Column))) sql = "select REVERSE(SUBSTR(LEFT(REVERSE(S_ADDRESS),INSTR(REVERSE(S_ADDRESS),'/')-1),5)) FROM supplier"; plan = getFragmentPlan(sql); - assertContains(plan, " : reverse(substr(left(DictDecode(10: S_ADDRESS, [reverse()])"); + assertContains(plan, " | : " + + "reverse(substr(left(11: expr, CAST(CAST(instr(11: expr, '/') AS BIGINT) - 1 AS INT)), 5))\n" + + " | common expressions:\n" + + " | : DictDecode(10: S_ADDRESS, [reverse()])"); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java index 41fa1453f200d7..2a340bf691744d 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java @@ -923,7 +923,10 @@ public void testProject() throws Exception { // support(support(unsupport(Column), unsupport(Column))) sql = "select REVERSE(SUBSTR(LEFT(REVERSE(S_ADDRESS),INSTR(REVERSE(S_ADDRESS),'/')-1),5)) FROM supplier"; plan = getFragmentPlan(sql); - assertContains(plan, " : reverse(substr(left(DictDecode(10: S_ADDRESS, [reverse()])"); + assertContains(plan, " | : reverse(substr(left(11: expr, " + + "CAST(CAST(instr(11: expr, '/') AS BIGINT) - 1 AS INT)), 5))\n" + + " | common expressions:\n" + + " | : DictDecode(10: S_ADDRESS, [reverse()])"); } @Test diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 6a9b3021051425..a2c46c0690bfc2 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -766,6 +766,8 @@ struct THashJoinNode { 56: optional bool late_materialization = false 57: optional bool enable_partition_hash_join = false 58: optional bool is_skew_join = false + + 59: optional map common_slot_map } struct TMergeJoinNode { @@ -805,6 +807,7 @@ struct TNestLoopJoinNode { 3: optional list join_conjuncts 4: optional string sql_join_conjuncts 5: optional bool interpolate_passthrough = false + 6: optional map common_slot_map } enum TAggregationOp { diff --git a/test/sql/test_join/R/test_predicate_expr_reuse b/test/sql/test_join/R/test_predicate_expr_reuse new file mode 100644 index 00000000000000..462b736b15e073 --- /dev/null +++ b/test/sql/test_join/R/test_predicate_expr_reuse @@ -0,0 +1,135 @@ +-- name: test_outer_join_predicate_expr_reuse +CREATE TABLE t0 ( + v1 INT, + v2 INT, + v3 VARCHAR(20) +) DUPLICATE KEY(v1) +DISTRIBUTED BY HASH(v1) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +CREATE TABLE t1 ( + v4 INT, + v5 INT, + v6 VARCHAR(20) +) DUPLICATE KEY(v4) +DISTRIBUTED BY HASH(v4) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +INSERT INTO t0 VALUES +(1, 10, 'a'), (2, 20, 'b'), (3, 30, 'c'), (4, 40, 'd'), (5, 50, 'e'), +(6, 60, 'f'), (7, 70, 'g'), (8, 80, 'h'), (9, 90, 'i'), (10, 100, 'j'), +(11, 110, 'a'), (12, 120, 'b'), (13, 130, 'c'), (14, 140, 'd'), (15, 150, 'e'), +(16, 160, 'f'), (17, 170, 'g'), (18, 180, 'h'), (19, 190, 'i'), (20, 200, 'j'), +(21, 210, 'a'), (22, 220, 'b'), (23, 230, 'c'), (24, 240, 'd'), (25, 250, 'e'), +(26, 260, 'f'), (27, 270, 'g'), (28, 280, 'h'), (29, 290, 'i'), (30, 300, 'j'), +(31, 310, 'a'), (32, 320, 'b'), (33, 330, 'c'), (34, 340, 'd'), (35, 350, 'e'), +(36, 360, 'f'), (37, 370, 'g'), (38, 380, 'h'), (39, 390, 'i'), (40, 400, 'j'), +(41, 410, 'a'), (42, 420, 'b'), (43, 430, 'c'), (44, 440, 'd'), (45, 450, 'e'), +(46, 460, 'f'), (47, 470, 'g'), (48, 480, 'h'), (49, 490, 'i'), (50, 500, 'j'), +(51, 510, 'a'), (52, 520, 'b'), (53, 530, 'c'), (54, 540, 'd'), (55, 550, 'e'), +(56, 560, 'f'), (57, 570, 'g'), (58, 580, 'h'), (59, 590, 'i'), (60, 600, 'j'), +(61, 610, 'a'), (62, 620, 'b'), (63, 630, 'c'), (64, 640, 'd'), (65, 650, 'e'), +(66, 660, 'f'), (67, 670, 'g'), (68, 680, 'h'), (69, 690, 'i'), (70, 700, 'j'), +(71, 710, 'a'), (72, 720, 'b'), (73, 730, 'c'), (74, 740, 'd'), (75, 750, 'e'), +(76, 760, 'f'), (77, 770, 'g'), (78, 780, 'h'), (79, 790, 'i'), (80, 800, 'j'), +(81, 810, 'a'), (82, 820, 'b'), (83, 830, 'c'), (84, 840, 'd'), (85, 850, 'e'), +(86, 860, 'f'), (87, 870, 'g'), (88, 880, 'h'), (89, 890, 'i'), (90, 900, 'j'), +(91, 910, 'a'), (92, 920, 'b'), (93, 930, 'c'), (94, 940, 'd'), (95, 950, 'e'), +(96, 960, 'f'), (97, 970, 'g'), (98, 980, 'h'), (99, 990, 'i'), (100, 1000, 'j'); +-- result: +-- !result +INSERT INTO t1 VALUES +(1, 15, 'x'), (2, 25, 'y'), (3, 35, 'z'), (4, 45, 'w'), (5, 55, 'v'), +(6, 65, 'u'), (7, 75, 't'), (8, 85, 's'), (9, 95, 'r'), (10, 105, 'q'), +(11, 115, 'x'), (12, 125, 'y'), (13, 135, 'z'), (14, 145, 'w'), (15, 155, 'v'), +(16, 165, 'u'), (17, 175, 't'), (18, 185, 's'), (19, 195, 'r'), (20, 205, 'q'), +(21, 215, 'x'), (22, 225, 'y'), (23, 235, 'z'), (24, 245, 'w'), (25, 255, 'v'), +(26, 265, 'u'), (27, 275, 't'), (28, 285, 's'), (29, 295, 'r'), (30, 305, 'q'), +(31, 315, 'x'), (32, 325, 'y'), (33, 335, 'z'), (34, 345, 'w'), (35, 355, 'v'), +(36, 365, 'u'), (37, 375, 't'), (38, 385, 's'), (39, 395, 'r'), (40, 405, 'q'), +(41, 415, 'x'), (42, 425, 'y'), (43, 435, 'z'), (44, 445, 'w'), (45, 455, 'v'), +(46, 465, 'u'), (47, 475, 't'), (48, 485, 's'), (49, 495, 'r'), (50, 505, 'q'), +(51, 515, 'x'), (52, 525, 'y'), (53, 535, 'z'), (54, 545, 'w'), (55, 555, 'v'), +(56, 565, 'u'), (57, 575, 't'), (58, 585, 's'), (59, 595, 'r'), (60, 605, 'q'), +(61, 615, 'x'), (62, 625, 'y'), (63, 635, 'z'), (64, 645, 'w'), (65, 655, 'v'), +(66, 665, 'u'), (67, 675, 't'), (68, 685, 's'), (69, 695, 'r'), (70, 705, 'q'), +(71, 715, 'x'), (72, 725, 'y'), (73, 735, 'z'), (74, 745, 'w'), (75, 755, 'v'), +(76, 765, 'u'), (77, 775, 't'), (78, 785, 's'), (79, 795, 'r'), (80, 805, 'q'), +(81, 815, 'x'), (82, 825, 'y'), (83, 835, 'z'), (84, 845, 'w'), (85, 855, 'v'), +(86, 865, 'u'), (87, 875, 't'), (88, 885, 's'), (89, 895, 'r'), (90, 905, 'q'), +(91, 915, 'x'), (92, 925, 'y'), (93, 935, 'z'), (94, 945, 'w'), (95, 955, 'v'), +(96, 965, 'u'), (97, 975, 't'), (98, 985, 's'), (99, 995, 'r'), (100, 1005, 'q'); +-- result: +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; +-- result: +1 +-- !result +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; +-- result: +1 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; +-- result: +6 +-- !result +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 AND t0.v2 = t1.v5 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10; +-- result: +2 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) + AND bit_shift_left(t0.v1 + t1.v4, 1) > 10 + AND bit_shift_left(t0.v1 + t1.v4, 1) < 20; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 +WHERE (abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10) + OR (abs(t0.v1 + t1.v4) > 15 AND abs(t0.v1 + t1.v4) < 20); +-- result: +44 +-- !result +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(bit_shift_left(t0.v1 + t1.v4, 1)) = abs(bit_shift_left(t0.v2 + t1.v5, 1)) + AND abs(bit_shift_left(t0.v1 + t1.v4, 1)) > 10; +-- result: +0 +-- !result +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE (t0.v1 + t1.v4) * 2 = (t0.v2 + t1.v5) * 2 + AND (t0.v1 + t1.v4) * 2 > 10 + AND (t0.v1 + t1.v4) * 2 < 100; +-- result: +0 +-- !result \ No newline at end of file diff --git a/test/sql/test_join/T/test_predicate_expr_reuse b/test/sql/test_join/T/test_predicate_expr_reuse new file mode 100644 index 00000000000000..40cf710449ccbd --- /dev/null +++ b/test/sql/test_join/T/test_predicate_expr_reuse @@ -0,0 +1,106 @@ +-- name: test_outer_join_predicate_expr_reuse +CREATE TABLE t0 ( + v1 INT, + v2 INT, + v3 VARCHAR(20) +) DUPLICATE KEY(v1) +DISTRIBUTED BY HASH(v1) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); + +CREATE TABLE t1 ( + v4 INT, + v5 INT, + v6 VARCHAR(20) +) DUPLICATE KEY(v4) +DISTRIBUTED BY HASH(v4) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT INTO t0 VALUES +(1, 10, 'a'), (2, 20, 'b'), (3, 30, 'c'), (4, 40, 'd'), (5, 50, 'e'), +(6, 60, 'f'), (7, 70, 'g'), (8, 80, 'h'), (9, 90, 'i'), (10, 100, 'j'), +(11, 110, 'a'), (12, 120, 'b'), (13, 130, 'c'), (14, 140, 'd'), (15, 150, 'e'), +(16, 160, 'f'), (17, 170, 'g'), (18, 180, 'h'), (19, 190, 'i'), (20, 200, 'j'), +(21, 210, 'a'), (22, 220, 'b'), (23, 230, 'c'), (24, 240, 'd'), (25, 250, 'e'), +(26, 260, 'f'), (27, 270, 'g'), (28, 280, 'h'), (29, 290, 'i'), (30, 300, 'j'), +(31, 310, 'a'), (32, 320, 'b'), (33, 330, 'c'), (34, 340, 'd'), (35, 350, 'e'), +(36, 360, 'f'), (37, 370, 'g'), (38, 380, 'h'), (39, 390, 'i'), (40, 400, 'j'), +(41, 410, 'a'), (42, 420, 'b'), (43, 430, 'c'), (44, 440, 'd'), (45, 450, 'e'), +(46, 460, 'f'), (47, 470, 'g'), (48, 480, 'h'), (49, 490, 'i'), (50, 500, 'j'), +(51, 510, 'a'), (52, 520, 'b'), (53, 530, 'c'), (54, 540, 'd'), (55, 550, 'e'), +(56, 560, 'f'), (57, 570, 'g'), (58, 580, 'h'), (59, 590, 'i'), (60, 600, 'j'), +(61, 610, 'a'), (62, 620, 'b'), (63, 630, 'c'), (64, 640, 'd'), (65, 650, 'e'), +(66, 660, 'f'), (67, 670, 'g'), (68, 680, 'h'), (69, 690, 'i'), (70, 700, 'j'), +(71, 710, 'a'), (72, 720, 'b'), (73, 730, 'c'), (74, 740, 'd'), (75, 750, 'e'), +(76, 760, 'f'), (77, 770, 'g'), (78, 780, 'h'), (79, 790, 'i'), (80, 800, 'j'), +(81, 810, 'a'), (82, 820, 'b'), (83, 830, 'c'), (84, 840, 'd'), (85, 850, 'e'), +(86, 860, 'f'), (87, 870, 'g'), (88, 880, 'h'), (89, 890, 'i'), (90, 900, 'j'), +(91, 910, 'a'), (92, 920, 'b'), (93, 930, 'c'), (94, 940, 'd'), (95, 950, 'e'), +(96, 960, 'f'), (97, 970, 'g'), (98, 980, 'h'), (99, 990, 'i'), (100, 1000, 'j'); + +INSERT INTO t1 VALUES +(1, 15, 'x'), (2, 25, 'y'), (3, 35, 'z'), (4, 45, 'w'), (5, 55, 'v'), +(6, 65, 'u'), (7, 75, 't'), (8, 85, 's'), (9, 95, 'r'), (10, 105, 'q'), +(11, 115, 'x'), (12, 125, 'y'), (13, 135, 'z'), (14, 145, 'w'), (15, 155, 'v'), +(16, 165, 'u'), (17, 175, 't'), (18, 185, 's'), (19, 195, 'r'), (20, 205, 'q'), +(21, 215, 'x'), (22, 225, 'y'), (23, 235, 'z'), (24, 245, 'w'), (25, 255, 'v'), +(26, 265, 'u'), (27, 275, 't'), (28, 285, 's'), (29, 295, 'r'), (30, 305, 'q'), +(31, 315, 'x'), (32, 325, 'y'), (33, 335, 'z'), (34, 345, 'w'), (35, 355, 'v'), +(36, 365, 'u'), (37, 375, 't'), (38, 385, 's'), (39, 395, 'r'), (40, 405, 'q'), +(41, 415, 'x'), (42, 425, 'y'), (43, 435, 'z'), (44, 445, 'w'), (45, 455, 'v'), +(46, 465, 'u'), (47, 475, 't'), (48, 485, 's'), (49, 495, 'r'), (50, 505, 'q'), +(51, 515, 'x'), (52, 525, 'y'), (53, 535, 'z'), (54, 545, 'w'), (55, 555, 'v'), +(56, 565, 'u'), (57, 575, 't'), (58, 585, 's'), (59, 595, 'r'), (60, 605, 'q'), +(61, 615, 'x'), (62, 625, 'y'), (63, 635, 'z'), (64, 645, 'w'), (65, 655, 'v'), +(66, 665, 'u'), (67, 675, 't'), (68, 685, 's'), (69, 695, 'r'), (70, 705, 'q'), +(71, 715, 'x'), (72, 725, 'y'), (73, 735, 'z'), (74, 745, 'w'), (75, 755, 'v'), +(76, 765, 'u'), (77, 775, 't'), (78, 785, 's'), (79, 795, 'r'), (80, 805, 'q'), +(81, 815, 'x'), (82, 825, 'y'), (83, 835, 'z'), (84, 845, 'w'), (85, 855, 'v'), +(86, 865, 'u'), (87, 875, 't'), (88, 885, 's'), (89, 895, 'r'), (90, 905, 'q'), +(91, 915, 'x'), (92, 925, 'y'), (93, 935, 'z'), (94, 945, 'w'), (95, 955, 'v'), +(96, 965, 'u'), (97, 975, 't'), (98, 985, 's'), (99, 995, 'r'), (100, 1005, 'q'); + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; + +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; + +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4 +WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20; + +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 AND t0.v2 = t1.v5 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5; + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10; + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) + AND bit_shift_left(t0.v1 + t1.v4, 1) > 10 + AND bit_shift_left(t0.v1 + t1.v4, 1) < 20; + +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 +WHERE (abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10) + OR (abs(t0.v1 + t1.v4) > 15 AND abs(t0.v1 + t1.v4) < 20); + +SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4 +WHERE abs(bit_shift_left(t0.v1 + t1.v4, 1)) = abs(bit_shift_left(t0.v2 + t1.v5, 1)) + AND abs(bit_shift_left(t0.v1 + t1.v4, 1)) > 10; + +SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4 +WHERE (t0.v1 + t1.v4) * 2 = (t0.v2 + t1.v5) * 2 + AND (t0.v1 + t1.v4) * 2 > 10 + AND (t0.v1 + t1.v4) * 2 < 100;