Skip to content

Commit b7380d3

Browse files
authored
[Feature] (IVM Part4) Support more agg combinator functions (#62122)
Signed-off-by: shuming.li <ming.moriarty@gmail.com>
1 parent ced0083 commit b7380d3

22 files changed

+2063
-136
lines changed

be/src/exprs/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(EXPR_FILES
3131
agg/factory/aggregate_resolver_utility.cpp
3232
agg/factory/aggregate_resolver_variance.cpp
3333
agg/factory/aggregate_resolver_window.cpp
34+
agg/combinator/agg_state_utils.cpp
3435
base64.cpp
3536
binary_functions.cpp
3637
expr_context.cpp
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright 2021-present StarRocks, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "column/nullable_column.h"
18+
#include "column/vectorized_fwd.h"
19+
#include "exprs/agg/aggregate.h"
20+
#include "exprs/agg/combinator/agg_state_combinator.h"
21+
#include "exprs/agg/combinator/agg_state_utils.h"
22+
#include "runtime/agg_state_desc.h"
23+
24+
namespace starrocks {
25+
// Empty type used by AggStateCombine as the first template argument of AggStateCombinator.
struct AggStateCombineState {};
26+
27+
// An aggregate combine combinator that combines aggregate inputs to compute intermediate results.
28+
// This combinator is equivalent to calling `{agg_func}_union({agg_func}_state(arg_types))` in SQL,
29+
// but with reduced function call overhead and memory allocation for better performance.
30+
// eg:
31+
// - SQL: sum_union(sum_state(col))
32+
// - This combinator: sum_combine(col)
33+
//
34+
// DESC: intermediate_type {agg_func}_combine(arg types)
35+
// input type : aggregate function's arg types
36+
// intermediate type : aggregate function's intermediate_type
37+
// return type : aggregate function's intermediate_type
38+
class AggStateCombine final : public AggStateCombinator<AggStateCombineState, AggStateCombine> {
39+
public:
40+
AggStateCombine(AggStateDesc agg_state_desc, const AggregateFunction* function)
41+
: AggStateCombinator(agg_state_desc, function) {
42+
DCHECK(_function != nullptr);
43+
}
44+
45+
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
46+
size_t row_num) const override {
47+
_function->update(ctx, columns, state, row_num);
48+
}
49+
50+
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
51+
_function->merge(ctx, column, state, row_num);
52+
}
53+
54+
void serialize_to_column([[maybe_unused]] FunctionContext* ctx, ConstAggDataPtr __restrict state,
55+
Column* to) const override {
56+
_serialize_to_column_nullable(ctx, state, to);
57+
}
58+
59+
void convert_to_serialize_format([[maybe_unused]] FunctionContext* ctx, const Columns& srcs, size_t chunk_size,
60+
ColumnPtr* dst) const override {
61+
DCHECK_EQ(1, srcs.size());
62+
*dst = srcs[0];
63+
}
64+
65+
void finalize_to_column(FunctionContext* ctx __attribute__((unused)), ConstAggDataPtr __restrict state,
66+
Column* to) const override {
67+
_serialize_to_column_nullable(ctx, state, to);
68+
}
69+
70+
std::string get_name() const override { return "agg_state_combine"; }
71+
72+
private:
73+
inline void _serialize_to_column_nullable(FunctionContext* ctx, ConstAggDataPtr __restrict state,
74+
Column* to) const {
75+
// `count` is a special case because `CountNullableAggregateFunction` is used to handle nullable column
76+
// and its serialize/finalize is meant to not nullable.
77+
if (_function->get_name() == AggStateUtils::FUNCTION_COUNT ||
78+
_function->get_name() == AggStateUtils::FUNCTION_COUNT_NULLABLE) {
79+
if (LIKELY(to->is_nullable())) {
80+
auto* nullable_column = down_cast<NullableColumn*>(to);
81+
_function->serialize_to_column(ctx, state, nullable_column->mutable_data_column());
82+
nullable_column->null_column_data().push_back(0);
83+
} else {
84+
_function->serialize_to_column(ctx, state, to);
85+
}
86+
} else {
87+
_function->serialize_to_column(ctx, state, to);
88+
}
89+
}
90+
};
91+
92+
} // namespace starrocks

be/src/exprs/agg/combinator/agg_state_merge.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ struct AggStateMergeState {};
2323

2424
// An aggregate merge combinator that merges aggregate intermediate states to compute the final result of aggregate function.
2525
//
26-
// DESC: return_type {agg_func}_merge(immediate_type)
27-
// input type : aggregate function's immediate_type
28-
// intermediate type : aggregate function's immediate_type
26+
// DESC: return_type {agg_func}_merge(intermediate_type)
27+
// input type : aggregate function's intermediate_type
28+
// intermediate type : aggregate function's intermediate_type
2929
// return type : aggregate function's return type
3030
class AggStateMerge final : public AggStateCombinator<AggStateMergeState, AggStateMerge> {
3131
public:

be/src/exprs/agg/combinator/agg_state_union.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
namespace starrocks {
2222
struct AggStateUnionState {};
2323

24-
// An aggregate union combinator that combines intermediate states to compute the immediate result of aggregate function.
24+
// An aggregate union combinator that combines intermediate states to compute the intermediate result of aggregate function.
2525
//
26-
// DESC: immediate_type {agg_func}_union(immediate_type)
27-
// input type : aggregate function's immediate_type
28-
// intermediate type : aggregate function's immediate_type
29-
// return type : aggregate function's immediate_type
26+
// DESC: intermediate_type {agg_func}_union(intermediate_type)
27+
// input type : aggregate function's intermediate_type
28+
// intermediate type : aggregate function's intermediate_type
29+
// return type : aggregate function's intermediate_type
3030
class AggStateUnion final : public AggStateCombinator<AggStateUnionState, AggStateUnion> {
3131
public:
3232
AggStateUnion(AggStateDesc agg_state_desc, const AggregateFunction* function)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2021-present StarRocks, Inc. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "exprs/agg/combinator/agg_state_utils.h"
16+
17+
#include <fmt/format.h>
18+
19+
#include "exprs/agg/combinator/agg_state_combinator.h"
20+
#include "exprs/agg/combinator/agg_state_combine.h"
21+
#include "exprs/agg/combinator/agg_state_if.h"
22+
#include "exprs/agg/combinator/agg_state_merge.h"
23+
#include "exprs/agg/combinator/agg_state_union.h"
24+
#include "exprs/agg/combinator/state_function.h"
25+
#include "exprs/agg/combinator/state_merge_function.h"
26+
#include "exprs/agg/combinator/state_union_function.h"
27+
28+
namespace starrocks {
29+
30+
bool AggStateUtils::is_count_function(const std::string& func_name) noexcept {
31+
return func_name == FUNCTION_COUNT || func_name == (FUNCTION_COUNT + AGG_STATE_IF_SUFFIX) ||
32+
func_name == (FUNCTION_COUNT + AGG_STATE_UNION_SUFFIX) ||
33+
func_name == (FUNCTION_COUNT + AGG_STATE_MERGE_SUFFIX) ||
34+
func_name == (FUNCTION_COUNT + AGG_STATE_COMBINE_SUFFIX);
35+
}
36+
37+
// Get the aggregate state descriptor from the aggregate function.
38+
const AggStateDesc* AggStateUtils::get_agg_state_desc(const AggregateFunction* agg_func) {
39+
if (dynamic_cast<const AggStateUnion*>(agg_func)) {
40+
auto* agg_state_union = down_cast<const AggStateUnion*>(agg_func);
41+
return agg_state_union->get_agg_state_desc();
42+
} else if (dynamic_cast<const AggStateMerge*>(agg_func)) {
43+
auto* agg_state_merge = down_cast<const AggStateMerge*>(agg_func);
44+
return agg_state_merge->get_agg_state_desc();
45+
} else if (dynamic_cast<const AggStateCombine*>(agg_func)) {
46+
auto* agg_state_merge = down_cast<const AggStateCombine*>(agg_func);
47+
return agg_state_merge->get_agg_state_desc();
48+
} else if (dynamic_cast<const AggStateIf*>(agg_func)) {
49+
auto* agg_state_if = down_cast<const AggStateIf*>(agg_func);
50+
return agg_state_if->get_agg_state_desc();
51+
}
52+
return nullptr;
53+
}
54+
55+
// Get the aggregate state function according to the agg_state_desc and function name.
56+
// If the function is not an aggregate state function, return nullptr.
57+
StatusOr<AggregateFunctionPtr> AggStateUtils::get_agg_state_function(const AggStateDesc& agg_state_desc,
58+
const std::string& func_name,
59+
const std::vector<TypeDescriptor>& arg_types) {
60+
auto nested_func_name = agg_state_desc.get_func_name();
61+
bool is_merge_or_union = AggStateUtils::is_agg_state_merge(nested_func_name, func_name) ||
62+
AggStateUtils::is_agg_state_union(nested_func_name, func_name);
63+
if (is_merge_or_union && arg_types.size() != 1) {
64+
return Status::InternalError(
65+
fmt::format("Invalid agg function plan: {} with (arg type {})", func_name, arg_types.size()));
66+
}
67+
68+
if (AggStateUtils::is_agg_state_merge(nested_func_name, func_name)) {
69+
// aggregate _merge combinator
70+
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
71+
if (nested_func == nullptr) {
72+
return Status::InternalError(fmt::format(
73+
"Merge combinator function {} fails to get the nested agg func: {}", func_name, nested_func_name));
74+
}
75+
return std::make_shared<AggStateMerge>(std::move(agg_state_desc), nested_func);
76+
} else if (AggStateUtils::is_agg_state_union(nested_func_name, func_name)) {
77+
// aggregate _union combinator
78+
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
79+
if (nested_func == nullptr) {
80+
return Status::InternalError(fmt::format(
81+
"Union combinator function {} fails to get the nested agg func: {}", func_name, nested_func_name));
82+
}
83+
return std::make_shared<AggStateUnion>(std::move(agg_state_desc), nested_func);
84+
} else if (AggStateUtils::is_agg_state_combine(nested_func_name, func_name)) {
85+
// aggregate _combine combinator
86+
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
87+
if (nested_func == nullptr) {
88+
return Status::InternalError(
89+
fmt::format("Combine combinator function {} fails to get the nested agg func: {}", func_name,
90+
nested_func_name));
91+
}
92+
return std::make_shared<AggStateCombine>(std::move(agg_state_desc), nested_func);
93+
} else if (AggStateUtils::is_agg_state_if(nested_func_name, func_name)) {
94+
// aggregate _if combinator
95+
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
96+
if (nested_func == nullptr) {
97+
return Status::InternalError(fmt::format("if combinator function {} fails to get the nested agg func: {}",
98+
func_name, nested_func_name));
99+
}
100+
return std::make_shared<AggStateIf>(std::move(agg_state_desc), nested_func);
101+
} else {
102+
return Status::InternalError(fmt::format("Agg function combinator is not implemented: {}", func_name));
103+
}
104+
}
105+
106+
// Get the aggregate state function according to the TAggStateDesc and function name.
107+
// If the function is not an aggregate state function, return nullptr.
108+
StateCombinatorPtr AggStateUtils::get_agg_state_function(const TAggStateDesc& desc, const std::string& func_name,
109+
const TypeDescriptor& return_type,
110+
std::vector<bool> arg_nullables) {
111+
if (is_agg_state_function(func_name)) {
112+
auto agg_state_desc = AggStateDesc::from_thrift(desc);
113+
// For _state combinator function, it's created according to the agg_state_desc rather than fid.
114+
return std::make_shared<StateFunction>(agg_state_desc, return_type, std::move(arg_nullables));
115+
} else if (is_agg_state_union_function(func_name)) {
116+
auto agg_state_desc = AggStateDesc::from_thrift(desc);
117+
// For _state combinator function, it's created according to the agg_state_desc rather than fid.
118+
return std::make_shared<StateUnionFunction>(agg_state_desc, return_type, std::move(arg_nullables));
119+
} else if (is_agg_state_merge_function(func_name)) {
120+
auto agg_state_desc = AggStateDesc::from_thrift(desc);
121+
// For _state combinator function, it's created according to the agg_state_desc rather than fid.
122+
return std::make_shared<StateMergeFunction>(agg_state_desc, return_type, std::move(arg_nullables));
123+
} else {
124+
return nullptr;
125+
}
126+
}
127+
128+
} // namespace starrocks

0 commit comments

Comments
 (0)