Skip to content

Commit 6c220d0

Browse files
[Enhancement] support expr reuse in outer join where predicates (#62139)
Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com>
1 parent b7380d3 commit 6c220d0

30 files changed: +643 additions, -62 deletions (lines changed)

be/src/exec/cross_join_node.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,13 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
8888
_build_runtime_filters.emplace_back(rf_desc);
8989
}
9090
}
91+
if (tnode.nestloop_join_node.__isset.common_slot_map) {
92+
for (const auto& [key, val] : tnode.nestloop_join_node.common_slot_map) {
93+
ExprContext* context;
94+
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
95+
_common_expr_ctxs.insert({key, context});
96+
}
97+
}
9198
return Status::OK();
9299
}
93100

@@ -608,10 +615,10 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> CrossJoinNode::_decompos
608615

609616
OpFactories left_ops = _children[0]->decompose_to_pipeline(context);
610617
// communication with CrossJoinRight through shared_data.
611-
auto left_factory =
612-
std::make_shared<ProbeFactory>(context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(),
613-
child(1)->row_desc(), _sql_join_conjuncts, std::move(_join_conjuncts),
614-
std::move(_conjunct_ctxs), std::move(cross_join_context), _join_op);
618+
auto left_factory = std::make_shared<ProbeFactory>(
619+
context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(), child(1)->row_desc(),
620+
_sql_join_conjuncts, std::move(_join_conjuncts), std::move(_conjunct_ctxs), std::move(_common_expr_ctxs),
621+
std::move(cross_join_context), _join_op);
615622
// Initialize OperatorFactory's fields involving runtime filters.
616623
this->init_runtime_filter_for_operator(left_factory.get(), context, rc_rf_probe_collector);
617624
if (!context->is_colocate_group()) {

be/src/exec/cross_join_node.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,8 @@ class CrossJoinNode final : public ExecNode {
128128

129129
std::vector<RuntimeFilterBuildDescriptor*> _build_runtime_filters;
130130
bool _interpolate_passthrough = false;
131+
132+
std::map<SlotId, ExprContext*> _common_expr_ctxs;
131133
};
132134

133135
} // namespace starrocks

be/src/exec/hash_join_node.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,14 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
126126
_build_equivalence_partition_expr_ctxs = _build_expr_ctxs;
127127
}
128128

129+
if (tnode.__isset.hash_join_node && tnode.hash_join_node.__isset.common_slot_map) {
130+
for (const auto& [key, val] : tnode.hash_join_node.common_slot_map) {
131+
ExprContext* context;
132+
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
133+
_common_expr_ctxs.insert({key, context});
134+
}
135+
}
136+
129137
RETURN_IF_ERROR(Expr::create_expr_trees(_pool, tnode.hash_join_node.other_join_conjuncts,
130138
&_other_join_conjunct_ctxs, state));
131139

@@ -484,7 +492,7 @@ pipeline::OpFactories HashJoinNode::_decompose_to_pipeline(pipeline::PipelineBui
484492
_other_join_conjunct_ctxs, _conjunct_ctxs, child(1)->row_desc(), child(0)->row_desc(),
485493
child(1)->type(), child(0)->type(), child(1)->conjunct_ctxs().empty(), _build_runtime_filters,
486494
_output_slots, _output_slots, context->degree_of_parallelism(), _distribution_mode,
487-
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join);
495+
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join, _common_expr_ctxs);
488496
auto hash_joiner_factory = std::make_shared<starrocks::pipeline::HashJoinerFactory>(param);
489497

490498
// Create a shared RefCountedRuntimeFilterCollector

be/src/exec/hash_join_node.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,8 @@ class HashJoinNode final : public ExecNode {
140140
bool _probe_eos = false; // probe table scan finished;
141141
size_t _runtime_join_filter_pushdown_limit = 1024000;
142142

143+
std::map<SlotId, ExprContext*> _common_expr_ctxs;
144+
143145
RuntimeProfile::Counter* _build_timer = nullptr;
144146
RuntimeProfile::Counter* _build_ht_timer = nullptr;
145147
RuntimeProfile::Counter* _copy_right_table_chunk_timer = nullptr;

be/src/exec/hash_joiner.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
#include "pipeline/hashjoin/hash_joiner_fwd.h"
3434
#include "runtime/current_thread.h"
3535
#include "simd/simd.h"
36+
#include "storage/chunk_helper.h"
3637
#include "util/runtime_profile.h"
3738

3839
namespace starrocks {
@@ -73,6 +74,7 @@ HashJoiner::HashJoiner(const HashJoinerParam& param)
7374
_probe_expr_ctxs(param._probe_expr_ctxs),
7475
_other_join_conjunct_ctxs(param._other_join_conjunct_ctxs),
7576
_conjunct_ctxs(param._conjunct_ctxs),
77+
_common_expr_ctxs(param._common_expr_ctxs),
7678
_build_row_descriptor(param._build_row_descriptor),
7779
_probe_row_descriptor(param._probe_row_descriptor),
7880
_build_node_type(param._build_node_type),
@@ -158,6 +160,11 @@ void HashJoiner::_init_hash_table_param(HashTableParam* param, RuntimeState* sta
158160
param->column_view_concat_rows_limit = state->column_view_concat_rows_limit();
159161
param->column_view_concat_bytes_limit = state->column_view_concat_bytes_limit();
160162
std::set<SlotId> predicate_slots;
163+
for (const auto& [slot_id, ctx] : _common_expr_ctxs) {
164+
std::vector<SlotId> expr_slots;
165+
ctx->root()->get_slot_ids(&expr_slots);
166+
predicate_slots.insert(expr_slots.begin(), expr_slots.end());
167+
}
161168
for (ExprContext* expr_context : _conjunct_ctxs) {
162169
std::vector<SlotId> expr_slots;
163170
expr_context->root()->get_slot_ids(&expr_slots);
@@ -388,6 +395,9 @@ Status HashJoiner::_calc_filter_for_other_conjunct(ChunkPtr* chunk, Filter& filt
388395
hit_all = false;
389396
filter.assign((*chunk)->num_rows(), 1);
390397

398+
CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
399+
RETURN_IF_ERROR(guard.evaluate());
400+
391401
for (auto* ctx : _other_join_conjunct_ctxs) {
392402
ASSIGN_OR_RETURN(ColumnPtr column, ctx->evaluate((*chunk).get()))
393403
size_t true_count = ColumnHelper::count_true_with_notnull(column);
@@ -516,6 +526,8 @@ Status HashJoiner::_process_other_conjunct(ChunkPtr* chunk, JoinHashTable& hash_
516526

517527
Status HashJoiner::_process_where_conjunct(ChunkPtr* chunk) {
518528
SCOPED_TIMER(probe_metrics().where_conjunct_evaluate_timer);
529+
CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
530+
RETURN_IF_ERROR(guard.evaluate());
519531
return ExecNode::eval_conjuncts(_conjunct_ctxs, (*chunk).get());
520532
}
521533

be/src/exec/hash_joiner.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,8 @@ struct HashJoinerParam {
7272
bool build_conjunct_ctxs_is_empty, std::list<RuntimeFilterBuildDescriptor*> build_runtime_filters,
7373
std::set<SlotId> build_output_slots, std::set<SlotId> probe_output_slots, size_t max_dop,
7474
const TJoinDistributionMode::type distribution_mode, bool enable_late_materialization,
75-
bool enable_partition_hash_join, bool is_skew_join)
75+
bool enable_partition_hash_join, bool is_skew_join,
76+
const std::map<SlotId, ExprContext*>& common_expr_ctxs)
7677
: _pool(pool),
7778
_hash_join_node(hash_join_node),
7879
_is_null_safes(std::move(is_null_safes)),
@@ -92,7 +93,8 @@ struct HashJoinerParam {
9293
_distribution_mode(distribution_mode),
9394
_enable_late_materialization(enable_late_materialization),
9495
_enable_partition_hash_join(enable_partition_hash_join),
95-
_is_skew_join(is_skew_join) {}
96+
_is_skew_join(is_skew_join),
97+
_common_expr_ctxs(common_expr_ctxs) {}
9698

9799
HashJoinerParam(HashJoinerParam&&) = default;
98100
HashJoinerParam(HashJoinerParam&) = default;
@@ -120,6 +122,7 @@ struct HashJoinerParam {
120122
const bool _enable_late_materialization;
121123
const bool _enable_partition_hash_join;
122124
const bool _is_skew_join;
125+
const std::map<SlotId, ExprContext*> _common_expr_ctxs;
123126
};
124127

125128
inline bool could_short_circuit(TJoinOp::type join_type) {
@@ -439,6 +442,7 @@ class HashJoiner final : public pipeline::ContextWithDependency {
439442
const std::vector<ExprContext*>& _other_join_conjunct_ctxs;
440443
// Conjuncts in Join followed by a filter predicate, usually in Where and Having.
441444
const std::vector<ExprContext*>& _conjunct_ctxs;
445+
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
442446
const RowDescriptor& _build_row_descriptor;
443447
const RowDescriptor& _probe_row_descriptor;
444448
const TPlanNodeType::type _build_node_type;

be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,19 @@ namespace starrocks::pipeline {
1919
Status HashJoinerFactory::prepare(RuntimeState* state) {
2020
RETURN_IF_ERROR(Expr::prepare(_param._build_expr_ctxs, state));
2121
RETURN_IF_ERROR(Expr::prepare(_param._probe_expr_ctxs, state));
22+
RETURN_IF_ERROR(Expr::prepare(_param._common_expr_ctxs, state));
2223
RETURN_IF_ERROR(Expr::prepare(_param._other_join_conjunct_ctxs, state));
2324
RETURN_IF_ERROR(Expr::prepare(_param._conjunct_ctxs, state));
2425
RETURN_IF_ERROR(Expr::open(_param._build_expr_ctxs, state));
2526
RETURN_IF_ERROR(Expr::open(_param._probe_expr_ctxs, state));
27+
RETURN_IF_ERROR(Expr::open(_param._common_expr_ctxs, state));
2628
RETURN_IF_ERROR(Expr::open(_param._other_join_conjunct_ctxs, state));
2729
RETURN_IF_ERROR(Expr::open(_param._conjunct_ctxs, state));
2830
return Status::OK();
2931
}
3032

3133
void HashJoinerFactory::close(RuntimeState* state) {
34+
Expr::close(_param._common_expr_ctxs, state);
3235
Expr::close(_param._conjunct_ctxs, state);
3336
Expr::close(_param._other_join_conjunct_ctxs, state);
3437
Expr::close(_param._probe_expr_ctxs, state);

be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include "runtime/current_thread.h"
2323
#include "runtime/descriptors.h"
2424
#include "simd/simd.h"
25+
#include "storage/chunk_helper.h"
2526

2627
namespace starrocks::pipeline {
2728

@@ -30,6 +31,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
3031
const std::string& sql_join_conjuncts,
3132
const std::vector<ExprContext*>& join_conjuncts,
3233
const std::vector<ExprContext*>& conjunct_ctxs,
34+
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
3335
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
3436
const std::shared_ptr<NLJoinContext>& cross_join_context)
3537
: OperatorWithDependency(factory, id, "nestloop_join_probe", plan_node_id, false, driver_sequence),
@@ -39,6 +41,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
3941
_sql_join_conjuncts(sql_join_conjuncts),
4042
_join_conjuncts(join_conjuncts),
4143
_conjunct_ctxs(conjunct_ctxs),
44+
_common_expr_ctxs(common_expr_ctxs),
4245
_cross_join_context(cross_join_context) {}
4346

4447
Status NLJoinProbeOperator::prepare(RuntimeState* state) {
@@ -309,6 +312,9 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk
309312
// for null-aware left anti join, join_conjunct[0] is on-predicate
310313
// others are other-conjuncts
311314
// process on conjuncts
315+
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
316+
RETURN_IF_ERROR(guard.evaluate());
317+
312318
{
313319
ASSIGN_OR_RETURN(ColumnPtr column, _join_conjuncts[0]->evaluate(chunk.get()));
314320
size_t num_false = ColumnHelper::count_false_with_notnull(column);
@@ -354,6 +360,8 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk
354360

355361
Status NLJoinProbeOperator::_probe_for_inner_join(const ChunkPtr& chunk) {
356362
if (!_join_conjuncts.empty() && chunk && !chunk->is_empty()) {
363+
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
364+
RETURN_IF_ERROR(guard.evaluate());
357365
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), nullptr, true));
358366
}
359367
return Status::OK();
@@ -374,7 +382,10 @@ Status NLJoinProbeOperator::_probe_for_other_join(const ChunkPtr& chunk) {
374382
if (_is_null_aware_left_anti_join()) {
375383
RETURN_IF_ERROR(_eval_nullaware_anti_conjuncts(chunk, &filter));
376384
} else {
385+
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
386+
RETURN_IF_ERROR(guard.evaluate());
377387
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), &filter, apply_filter));
388+
chunk->check_or_die();
378389
}
379390
DCHECK(!!filter);
380391
// The filter has not been assigned if no rows matched
@@ -652,8 +663,11 @@ Status NLJoinProbeOperator::_permute_right_join(size_t chunk_size) {
652663
}
653664
}
654665
permute_rows += chunk->num_rows();
655-
656-
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
666+
{
667+
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
668+
RETURN_IF_ERROR(guard.evaluate());
669+
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
670+
}
657671
RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
658672
match_flag_index += cur_chunk_size;
659673
}
@@ -703,7 +717,11 @@ StatusOr<ChunkPtr> NLJoinProbeOperator::_pull_chunk_for_other_join(size_t chunk_
703717
ASSIGN_OR_RETURN(ChunkPtr chunk, _permute_chunk_for_other_join(chunk_size));
704718
DCHECK(chunk);
705719
RETURN_IF_ERROR(_probe_for_other_join(chunk));
706-
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
720+
{
721+
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
722+
RETURN_IF_ERROR(guard.evaluate());
723+
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
724+
}
707725

708726
RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
709727
if (ChunkPtr res = _output_accumulator.pull()) {
@@ -800,9 +818,9 @@ void NLJoinProbeOperatorFactory::_init_row_desc() {
800818
}
801819

802820
OperatorPtr NLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
803-
return std::make_shared<NLJoinProbeOperator>(this, _id, _plan_node_id, driver_sequence, _join_op,
804-
_sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, _col_types,
805-
_probe_column_count, _cross_join_context);
821+
return std::make_shared<NLJoinProbeOperator>(
822+
this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs,
823+
_common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context);
806824
}
807825

808826
Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
@@ -812,6 +830,9 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
812830
_cross_join_context->ref();
813831

814832
_init_row_desc();
833+
834+
RETURN_IF_ERROR(Expr::prepare(_common_expr_ctxs, state));
835+
RETURN_IF_ERROR(Expr::open(_common_expr_ctxs, state));
815836
RETURN_IF_ERROR(Expr::prepare(_join_conjuncts, state));
816837
RETURN_IF_ERROR(Expr::open(_join_conjuncts, state));
817838
RETURN_IF_ERROR(Expr::prepare(_conjunct_ctxs, state));
@@ -821,6 +842,7 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
821842
}
822843

823844
void NLJoinProbeOperatorFactory::close(RuntimeState* state) {
845+
Expr::close(_common_expr_ctxs, state);
824846
Expr::close(_join_conjuncts, state);
825847
Expr::close(_conjunct_ctxs, state);
826848

be/src/exec/pipeline/nljoin/nljoin_probe_operator.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency {
3737
NLJoinProbeOperator(OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence,
3838
TJoinOp::type join_op, const std::string& sql_join_conjuncts,
3939
const std::vector<ExprContext*>& join_conjuncts, const std::vector<ExprContext*>& conjunct_ctxs,
40+
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
4041
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
4142
const std::shared_ptr<NLJoinContext>& cross_join_context);
4243

@@ -115,6 +116,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency {
115116
const std::vector<ExprContext*>& _join_conjuncts;
116117

117118
const std::vector<ExprContext*>& _conjunct_ctxs;
119+
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
118120
const std::shared_ptr<NLJoinContext>& _cross_join_context;
119121

120122
bool _input_finished = false;
@@ -147,6 +149,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
147149
const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc,
148150
std::string sql_join_conjuncts, std::vector<ExprContext*>&& join_conjuncts,
149151
std::vector<ExprContext*>&& conjunct_ctxs,
152+
std::map<SlotId, ExprContext*>&& common_expr_ctxs,
150153
std::shared_ptr<NLJoinContext>&& cross_join_context, TJoinOp::type join_op)
151154
: OperatorWithDependencyFactory(id, "cross_join_left", plan_node_id),
152155
_join_op(join_op),
@@ -155,6 +158,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
155158
_sql_join_conjuncts(std::move(sql_join_conjuncts)),
156159
_join_conjuncts(std::move(join_conjuncts)),
157160
_conjunct_ctxs(std::move(conjunct_ctxs)),
161+
_common_expr_ctxs(std::move(common_expr_ctxs)),
158162
_cross_join_context(std::move(cross_join_context)) {}
159163

160164
~NLJoinProbeOperatorFactory() override = default;
@@ -178,6 +182,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
178182
std::string _sql_join_conjuncts;
179183
std::vector<ExprContext*> _join_conjuncts;
180184
std::vector<ExprContext*> _conjunct_ctxs;
185+
std::map<SlotId, ExprContext*> _common_expr_ctxs;
181186

182187
std::shared_ptr<NLJoinContext> _cross_join_context;
183188
};

0 commit comments

Comments
 (0)