1
+ // Copyright 2021-present StarRocks, Inc. All rights reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // https://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include " exprs/agg/combinator/agg_state_utils.h"
16
+
17
+ #include < fmt/format.h>
18
+
19
+ #include " exprs/agg/combinator/agg_state_combinator.h"
20
+ #include " exprs/agg/combinator/agg_state_combine.h"
21
+ #include " exprs/agg/combinator/agg_state_if.h"
22
+ #include " exprs/agg/combinator/agg_state_merge.h"
23
+ #include " exprs/agg/combinator/agg_state_union.h"
24
+ #include " exprs/agg/combinator/state_function.h"
25
+ #include " exprs/agg/combinator/state_merge_function.h"
26
+ #include " exprs/agg/combinator/state_union_function.h"
27
+
28
+ namespace starrocks {
29
+
30
+ bool AggStateUtils::is_count_function (const std::string& func_name) noexcept {
31
+ return func_name == FUNCTION_COUNT || func_name == (FUNCTION_COUNT + AGG_STATE_IF_SUFFIX) ||
32
+ func_name == (FUNCTION_COUNT + AGG_STATE_UNION_SUFFIX) ||
33
+ func_name == (FUNCTION_COUNT + AGG_STATE_MERGE_SUFFIX) ||
34
+ func_name == (FUNCTION_COUNT + AGG_STATE_COMBINE_SUFFIX);
35
+ }
36
+
37
+ // Get the aggregate state descriptor from the aggregate function.
38
+ const AggStateDesc* AggStateUtils::get_agg_state_desc (const AggregateFunction* agg_func) {
39
+ if (dynamic_cast <const AggStateUnion*>(agg_func)) {
40
+ auto * agg_state_union = down_cast<const AggStateUnion*>(agg_func);
41
+ return agg_state_union->get_agg_state_desc ();
42
+ } else if (dynamic_cast <const AggStateMerge*>(agg_func)) {
43
+ auto * agg_state_merge = down_cast<const AggStateMerge*>(agg_func);
44
+ return agg_state_merge->get_agg_state_desc ();
45
+ } else if (dynamic_cast <const AggStateCombine*>(agg_func)) {
46
+ auto * agg_state_merge = down_cast<const AggStateCombine*>(agg_func);
47
+ return agg_state_merge->get_agg_state_desc ();
48
+ } else if (dynamic_cast <const AggStateIf*>(agg_func)) {
49
+ auto * agg_state_if = down_cast<const AggStateIf*>(agg_func);
50
+ return agg_state_if->get_agg_state_desc ();
51
+ }
52
+ return nullptr ;
53
+ }
54
+
55
+ // Get the aggregate state function according to the agg_state_desc and function name.
56
+ // If the function is not an aggregate state function, return nullptr.
57
+ StatusOr<AggregateFunctionPtr> AggStateUtils::get_agg_state_function (const AggStateDesc& agg_state_desc,
58
+ const std::string& func_name,
59
+ const std::vector<TypeDescriptor>& arg_types) {
60
+ auto nested_func_name = agg_state_desc.get_func_name ();
61
+ bool is_merge_or_union = AggStateUtils::is_agg_state_merge (nested_func_name, func_name) ||
62
+ AggStateUtils::is_agg_state_union (nested_func_name, func_name);
63
+ if (is_merge_or_union && arg_types.size () != 1 ) {
64
+ return Status::InternalError (
65
+ fmt::format (" Invalid agg function plan: {} with (arg type {})" , func_name, arg_types.size ()));
66
+ }
67
+
68
+ if (AggStateUtils::is_agg_state_merge (nested_func_name, func_name)) {
69
+ // aggregate _merge combinator
70
+ auto * nested_func = AggStateDesc::get_agg_state_func (&agg_state_desc);
71
+ if (nested_func == nullptr ) {
72
+ return Status::InternalError (fmt::format (
73
+ " Merge combinator function {} fails to get the nested agg func: {}" , func_name, nested_func_name));
74
+ }
75
+ return std::make_shared<AggStateMerge>(std::move (agg_state_desc), nested_func);
76
+ } else if (AggStateUtils::is_agg_state_union (nested_func_name, func_name)) {
77
+ // aggregate _union combinator
78
+ auto * nested_func = AggStateDesc::get_agg_state_func (&agg_state_desc);
79
+ if (nested_func == nullptr ) {
80
+ return Status::InternalError (fmt::format (
81
+ " Union combinator function {} fails to get the nested agg func: {}" , func_name, nested_func_name));
82
+ }
83
+ return std::make_shared<AggStateUnion>(std::move (agg_state_desc), nested_func);
84
+ } else if (AggStateUtils::is_agg_state_combine (nested_func_name, func_name)) {
85
+ // aggregate _combine combinator
86
+ auto * nested_func = AggStateDesc::get_agg_state_func (&agg_state_desc);
87
+ if (nested_func == nullptr ) {
88
+ return Status::InternalError (
89
+ fmt::format (" Combine combinator function {} fails to get the nested agg func: {}" , func_name,
90
+ nested_func_name));
91
+ }
92
+ return std::make_shared<AggStateCombine>(std::move (agg_state_desc), nested_func);
93
+ } else if (AggStateUtils::is_agg_state_if (nested_func_name, func_name)) {
94
+ // aggregate _if combinator
95
+ auto * nested_func = AggStateDesc::get_agg_state_func (&agg_state_desc);
96
+ if (nested_func == nullptr ) {
97
+ return Status::InternalError (fmt::format (" if combinator function {} fails to get the nested agg func: {}" ,
98
+ func_name, nested_func_name));
99
+ }
100
+ return std::make_shared<AggStateIf>(std::move (agg_state_desc), nested_func);
101
+ } else {
102
+ return Status::InternalError (fmt::format (" Agg function combinator is not implemented: {}" , func_name));
103
+ }
104
+ }
105
+
106
+ // Get the aggregate state function according to the TAggStateDesc and function name.
107
+ // If the function is not an aggregate state function, return nullptr.
108
+ StateCombinatorPtr AggStateUtils::get_agg_state_function (const TAggStateDesc& desc, const std::string& func_name,
109
+ const TypeDescriptor& return_type,
110
+ std::vector<bool > arg_nullables) {
111
+ if (is_agg_state_function (func_name)) {
112
+ auto agg_state_desc = AggStateDesc::from_thrift (desc);
113
+ // For _state combinator function, it's created according to the agg_state_desc rather than fid.
114
+ return std::make_shared<StateFunction>(agg_state_desc, return_type, std::move (arg_nullables));
115
+ } else if (is_agg_state_union_function (func_name)) {
116
+ auto agg_state_desc = AggStateDesc::from_thrift (desc);
117
+ // For _state combinator function, it's created according to the agg_state_desc rather than fid.
118
+ return std::make_shared<StateUnionFunction>(agg_state_desc, return_type, std::move (arg_nullables));
119
+ } else if (is_agg_state_merge_function (func_name)) {
120
+ auto agg_state_desc = AggStateDesc::from_thrift (desc);
121
+ // For _state combinator function, it's created according to the agg_state_desc rather than fid.
122
+ return std::make_shared<StateMergeFunction>(agg_state_desc, return_type, std::move (arg_nullables));
123
+ } else {
124
+ return nullptr ;
125
+ }
126
+ }
127
+
128
+ } // namespace starrocks
0 commit comments