Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions be/src/exec/cross_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
_build_runtime_filters.emplace_back(rf_desc);
}
}
if (tnode.nestloop_join_node.__isset.common_slot_map) {
for (const auto& [key, val] : tnode.nestloop_join_node.common_slot_map) {
ExprContext* context;
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
_common_expr_ctxs.insert({key, context});
}
}
return Status::OK();
}

Expand Down Expand Up @@ -608,10 +615,10 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> CrossJoinNode::_decompos

OpFactories left_ops = _children[0]->decompose_to_pipeline(context);
// communication with CrossJoinRight through shared_data.
auto left_factory =
std::make_shared<ProbeFactory>(context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(),
child(1)->row_desc(), _sql_join_conjuncts, std::move(_join_conjuncts),
std::move(_conjunct_ctxs), std::move(cross_join_context), _join_op);
auto left_factory = std::make_shared<ProbeFactory>(
context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(), child(1)->row_desc(),
_sql_join_conjuncts, std::move(_join_conjuncts), std::move(_conjunct_ctxs), std::move(_common_expr_ctxs),
std::move(cross_join_context), _join_op);
// Initialize OperatorFactory's fields involving runtime filters.
this->init_runtime_filter_for_operator(left_factory.get(), context, rc_rf_probe_collector);
if (!context->is_colocate_group()) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exec/cross_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class CrossJoinNode final : public ExecNode {

std::vector<RuntimeFilterBuildDescriptor*> _build_runtime_filters;
bool _interpolate_passthrough = false;

std::map<SlotId, ExprContext*> _common_expr_ctxs;
};

} // namespace starrocks
10 changes: 9 additions & 1 deletion be/src/exec/hash_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
_build_equivalence_partition_expr_ctxs = _build_expr_ctxs;
}

if (tnode.__isset.hash_join_node && tnode.hash_join_node.__isset.common_slot_map) {
for (const auto& [key, val] : tnode.hash_join_node.common_slot_map) {
ExprContext* context;
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
_common_expr_ctxs.insert({key, context});
}
}

RETURN_IF_ERROR(Expr::create_expr_trees(_pool, tnode.hash_join_node.other_join_conjuncts,
&_other_join_conjunct_ctxs, state));

Expand Down Expand Up @@ -484,7 +492,7 @@ pipeline::OpFactories HashJoinNode::_decompose_to_pipeline(pipeline::PipelineBui
_other_join_conjunct_ctxs, _conjunct_ctxs, child(1)->row_desc(), child(0)->row_desc(),
child(1)->type(), child(0)->type(), child(1)->conjunct_ctxs().empty(), _build_runtime_filters,
_output_slots, _output_slots, context->degree_of_parallelism(), _distribution_mode,
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join);
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join, _common_expr_ctxs);
auto hash_joiner_factory = std::make_shared<starrocks::pipeline::HashJoinerFactory>(param);

// Create a shared RefCountedRuntimeFilterCollector
Expand Down
2 changes: 2 additions & 0 deletions be/src/exec/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class HashJoinNode final : public ExecNode {
bool _probe_eos = false; // probe table scan finished;
size_t _runtime_join_filter_pushdown_limit = 1024000;

std::map<SlotId, ExprContext*> _common_expr_ctxs;

RuntimeProfile::Counter* _build_timer = nullptr;
RuntimeProfile::Counter* _build_ht_timer = nullptr;
RuntimeProfile::Counter* _copy_right_table_chunk_timer = nullptr;
Expand Down
12 changes: 12 additions & 0 deletions be/src/exec/hash_joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "pipeline/hashjoin/hash_joiner_fwd.h"
#include "runtime/current_thread.h"
#include "simd/simd.h"
#include "storage/chunk_helper.h"
#include "util/runtime_profile.h"

namespace starrocks {
Expand Down Expand Up @@ -73,6 +74,7 @@ HashJoiner::HashJoiner(const HashJoinerParam& param)
_probe_expr_ctxs(param._probe_expr_ctxs),
_other_join_conjunct_ctxs(param._other_join_conjunct_ctxs),
_conjunct_ctxs(param._conjunct_ctxs),
_common_expr_ctxs(param._common_expr_ctxs),
_build_row_descriptor(param._build_row_descriptor),
_probe_row_descriptor(param._probe_row_descriptor),
_build_node_type(param._build_node_type),
Expand Down Expand Up @@ -158,6 +160,11 @@ void HashJoiner::_init_hash_table_param(HashTableParam* param, RuntimeState* sta
param->column_view_concat_rows_limit = state->column_view_concat_rows_limit();
param->column_view_concat_bytes_limit = state->column_view_concat_bytes_limit();
std::set<SlotId> predicate_slots;
for (const auto& [slot_id, ctx] : _common_expr_ctxs) {
std::vector<SlotId> expr_slots;
ctx->root()->get_slot_ids(&expr_slots);
predicate_slots.insert(expr_slots.begin(), expr_slots.end());
}
for (ExprContext* expr_context : _conjunct_ctxs) {
std::vector<SlotId> expr_slots;
expr_context->root()->get_slot_ids(&expr_slots);
Expand Down Expand Up @@ -388,6 +395,9 @@ Status HashJoiner::_calc_filter_for_other_conjunct(ChunkPtr* chunk, Filter& filt
hit_all = false;
filter.assign((*chunk)->num_rows(), 1);

CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());

for (auto* ctx : _other_join_conjunct_ctxs) {
ASSIGN_OR_RETURN(ColumnPtr column, ctx->evaluate((*chunk).get()))
size_t true_count = ColumnHelper::count_true_with_notnull(column);
Expand Down Expand Up @@ -516,6 +526,8 @@ Status HashJoiner::_process_other_conjunct(ChunkPtr* chunk, JoinHashTable& hash_

Status HashJoiner::_process_where_conjunct(ChunkPtr* chunk) {
SCOPED_TIMER(probe_metrics().where_conjunct_evaluate_timer);
CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
return ExecNode::eval_conjuncts(_conjunct_ctxs, (*chunk).get());
}

Expand Down
8 changes: 6 additions & 2 deletions be/src/exec/hash_joiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct HashJoinerParam {
bool build_conjunct_ctxs_is_empty, std::list<RuntimeFilterBuildDescriptor*> build_runtime_filters,
std::set<SlotId> build_output_slots, std::set<SlotId> probe_output_slots, size_t max_dop,
const TJoinDistributionMode::type distribution_mode, bool enable_late_materialization,
bool enable_partition_hash_join, bool is_skew_join)
bool enable_partition_hash_join, bool is_skew_join,
const std::map<SlotId, ExprContext*>& common_expr_ctxs)
: _pool(pool),
_hash_join_node(hash_join_node),
_is_null_safes(std::move(is_null_safes)),
Expand All @@ -92,7 +93,8 @@ struct HashJoinerParam {
_distribution_mode(distribution_mode),
_enable_late_materialization(enable_late_materialization),
_enable_partition_hash_join(enable_partition_hash_join),
_is_skew_join(is_skew_join) {}
_is_skew_join(is_skew_join),
_common_expr_ctxs(common_expr_ctxs) {}

HashJoinerParam(HashJoinerParam&&) = default;
HashJoinerParam(HashJoinerParam&) = default;
Expand Down Expand Up @@ -120,6 +122,7 @@ struct HashJoinerParam {
const bool _enable_late_materialization;
const bool _enable_partition_hash_join;
const bool _is_skew_join;
const std::map<SlotId, ExprContext*> _common_expr_ctxs;
};

inline bool could_short_circuit(TJoinOp::type join_type) {
Expand Down Expand Up @@ -439,6 +442,7 @@ class HashJoiner final : public pipeline::ContextWithDependency {
const std::vector<ExprContext*>& _other_join_conjunct_ctxs;
// Conjuncts in Join followed by a filter predicate, usually in Where and Having.
const std::vector<ExprContext*>& _conjunct_ctxs;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
const RowDescriptor& _build_row_descriptor;
const RowDescriptor& _probe_row_descriptor;
const TPlanNodeType::type _build_node_type;
Expand Down
3 changes: 3 additions & 0 deletions be/src/exec/pipeline/hashjoin/hash_joiner_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ namespace starrocks::pipeline {
Status HashJoinerFactory::prepare(RuntimeState* state) {
RETURN_IF_ERROR(Expr::prepare(_param._build_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._probe_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._other_join_conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._build_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._probe_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._other_join_conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._conjunct_ctxs, state));
return Status::OK();
}

void HashJoinerFactory::close(RuntimeState* state) {
Expr::close(_param._common_expr_ctxs, state);
Expr::close(_param._conjunct_ctxs, state);
Expr::close(_param._other_join_conjunct_ctxs, state);
Expr::close(_param._probe_expr_ctxs, state);
Expand Down
34 changes: 28 additions & 6 deletions be/src/exec/pipeline/nljoin/nljoin_probe_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "runtime/current_thread.h"
#include "runtime/descriptors.h"
#include "simd/simd.h"
#include "storage/chunk_helper.h"

namespace starrocks::pipeline {

Expand All @@ -30,6 +31,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
const std::string& sql_join_conjuncts,
const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context)
: OperatorWithDependency(factory, id, "nestloop_join_probe", plan_node_id, false, driver_sequence),
Expand All @@ -39,6 +41,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
_sql_join_conjuncts(sql_join_conjuncts),
_join_conjuncts(join_conjuncts),
_conjunct_ctxs(conjunct_ctxs),
_common_expr_ctxs(common_expr_ctxs),
_cross_join_context(cross_join_context) {}

Status NLJoinProbeOperator::prepare(RuntimeState* state) {
Expand Down Expand Up @@ -309,6 +312,9 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk
// for null-aware left anti join, join_conjunct[0] is on-predicate
// others are other-conjuncts
// process on conjuncts
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());

{
ASSIGN_OR_RETURN(ColumnPtr column, _join_conjuncts[0]->evaluate(chunk.get()));
size_t num_false = ColumnHelper::count_false_with_notnull(column);
Expand Down Expand Up @@ -354,6 +360,8 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk

Status NLJoinProbeOperator::_probe_for_inner_join(const ChunkPtr& chunk) {
if (!_join_conjuncts.empty() && chunk && !chunk->is_empty()) {
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), nullptr, true));
}
return Status::OK();
Expand All @@ -374,7 +382,10 @@ Status NLJoinProbeOperator::_probe_for_other_join(const ChunkPtr& chunk) {
if (_is_null_aware_left_anti_join()) {
RETURN_IF_ERROR(_eval_nullaware_anti_conjuncts(chunk, &filter));
} else {
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), &filter, apply_filter));
chunk->check_or_die();
}
DCHECK(!!filter);
// The filter has not been assigned if no rows matched
Expand Down Expand Up @@ -652,8 +663,11 @@ Status NLJoinProbeOperator::_permute_right_join(size_t chunk_size) {
}
}
permute_rows += chunk->num_rows();

RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
{
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
}
RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
match_flag_index += cur_chunk_size;
}
Expand Down Expand Up @@ -703,7 +717,11 @@ StatusOr<ChunkPtr> NLJoinProbeOperator::_pull_chunk_for_other_join(size_t chunk_
ASSIGN_OR_RETURN(ChunkPtr chunk, _permute_chunk_for_other_join(chunk_size));
DCHECK(chunk);
RETURN_IF_ERROR(_probe_for_other_join(chunk));
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
{
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
}

RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
if (ChunkPtr res = _output_accumulator.pull()) {
Expand Down Expand Up @@ -800,9 +818,9 @@ void NLJoinProbeOperatorFactory::_init_row_desc() {
}

OperatorPtr NLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
return std::make_shared<NLJoinProbeOperator>(this, _id, _plan_node_id, driver_sequence, _join_op,
_sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, _col_types,
_probe_column_count, _cross_join_context);
return std::make_shared<NLJoinProbeOperator>(
this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs,
_common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context);
}

Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
Expand All @@ -812,6 +830,9 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
_cross_join_context->ref();

_init_row_desc();

RETURN_IF_ERROR(Expr::prepare(_common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_join_conjuncts, state));
RETURN_IF_ERROR(Expr::open(_join_conjuncts, state));
RETURN_IF_ERROR(Expr::prepare(_conjunct_ctxs, state));
Expand All @@ -821,6 +842,7 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
}

void NLJoinProbeOperatorFactory::close(RuntimeState* state) {
Expr::close(_common_expr_ctxs, state);
Expr::close(_join_conjuncts, state);
Expr::close(_conjunct_ctxs, state);

Expand Down
5 changes: 5 additions & 0 deletions be/src/exec/pipeline/nljoin/nljoin_probe_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency {
NLJoinProbeOperator(OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence,
TJoinOp::type join_op, const std::string& sql_join_conjuncts,
const std::vector<ExprContext*>& join_conjuncts, const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context);

Expand Down Expand Up @@ -115,6 +116,7 @@ class NLJoinProbeOperator final : public OperatorWithDependency {
const std::vector<ExprContext*>& _join_conjuncts;

const std::vector<ExprContext*>& _conjunct_ctxs;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
const std::shared_ptr<NLJoinContext>& _cross_join_context;

bool _input_finished = false;
Expand Down Expand Up @@ -147,6 +149,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc,
std::string sql_join_conjuncts, std::vector<ExprContext*>&& join_conjuncts,
std::vector<ExprContext*>&& conjunct_ctxs,
std::map<SlotId, ExprContext*>&& common_expr_ctxs,
std::shared_ptr<NLJoinContext>&& cross_join_context, TJoinOp::type join_op)
: OperatorWithDependencyFactory(id, "cross_join_left", plan_node_id),
_join_op(join_op),
Expand All @@ -155,6 +158,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
_sql_join_conjuncts(std::move(sql_join_conjuncts)),
_join_conjuncts(std::move(join_conjuncts)),
_conjunct_ctxs(std::move(conjunct_ctxs)),
_common_expr_ctxs(std::move(common_expr_ctxs)),
_cross_join_context(std::move(cross_join_context)) {}

~NLJoinProbeOperatorFactory() override = default;
Expand All @@ -178,6 +182,7 @@ class NLJoinProbeOperatorFactory final : public OperatorWithDependencyFactory {
std::string _sql_join_conjuncts;
std::vector<ExprContext*> _join_conjuncts;
std::vector<ExprContext*> _conjunct_ctxs;
std::map<SlotId, ExprContext*> _common_expr_ctxs;

std::shared_ptr<NLJoinContext> _cross_join_context;
};
Expand Down
Loading
Loading