Skip to content

Commit ae4f84e

Browse files
[Enhancement] Support complex expressions in FILTER clause and add boolean type validation for aggregate functions
Signed-off-by: stephen <stephen5217@163.com>
1 parent 4f55da0 commit ae4f84e

File tree

4 files changed

+88
-2
lines changed

4 files changed

+88
-2
lines changed

fe/fe-core/src/main/java/com/starrocks/sql/analyzer/FunctionAnalyzer.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,17 @@ public static void analyze(FunctionCallExpr functionCallExpr) {
214214
FunctionName argFuncNameWithoutIf =
215215
new FunctionName(AggStateUtils.getAggFuncNameOfCombinator(fnName.getFunction()));
216216
FunctionParams params = functionCallExpr.getParams();
217+
218+
// Validate that the condition parameter (last parameter) is boolean type
219+
if (!params.exprs().isEmpty()) {
220+
Expr conditionExpr = params.exprs().get(params.exprs().size() - 1);
221+
if (!conditionExpr.getType().isBoolean()) {
222+
throw new SemanticException(String.format(
223+
"The condition expression in %s function must be boolean type, but got %s",
224+
fnName.getFunction(), conditionExpr.getType().toSql()), functionCallExpr.getPos());
225+
}
226+
}
227+
217228
FunctionParams functionParamsWithOutIf =
218229
new FunctionParams(params.isStar(), params.exprs().subList(0, params.exprs().size() - 1),
219230
params.getExprsNames() == null ? null :

fe/fe-core/src/main/java/com/starrocks/sql/parser/AstBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7672,7 +7672,7 @@ public ParseNode visitAggregationFunctionCall(StarRocksParser.AggregationFunctio
76727672
if (isCountFunc && isDistinct) {
76737673
throw new ParsingException("Aggregation filter does not support COUNT DISTINCT");
76747674
}
7675-
Expr booleanExpr = (Expr) visit(context.filter().booleanExpression());
7675+
Expr booleanExpr = (Expr) visit(context.filter().expression());
76767676
functionName = functionName + FunctionSet.AGG_STATE_IF_SUFFIX;
76777677
exprs.add(booleanExpr);
76787678

fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3058,4 +3058,79 @@ public void testGroupByCompressedKey() throws Exception {
30583058
plan = getThriftPlan(sql);
30593059
assertContains(plan, "group_by_min_max:[TExpr(");
30603060
}
3061+
3062+
@Test
3063+
public void testAggregateFilterSyntax() throws Exception {
3064+
// Test basic FILTER syntax with boolean expression
3065+
String sql = "select count(*) filter (where v1 > 5) from t0";
3066+
String plan = getFragmentPlan(sql);
3067+
assertContains(plan, "count_if");
3068+
3069+
// Test FILTER with complex boolean expression
3070+
sql = "select sum(v2) filter (where v1 > 5 and v2 < 10) from t0";
3071+
plan = getFragmentPlan(sql);
3072+
assertContains(plan, "sum_if");
3073+
3074+
// Test FILTER with logical operators
3075+
sql = "select avg(v3) filter (where v1 = 1 or v2 = 2) from t0";
3076+
plan = getFragmentPlan(sql);
3077+
assertContains(plan, "avg_if");
3078+
3079+
// Test FILTER with NOT operator
3080+
sql = "select max(v1) filter (where not (v2 > 10)) from t0";
3081+
plan = getFragmentPlan(sql);
3082+
assertContains(plan, "max_if");
3083+
}
3084+
3085+
@Test
3086+
public void testAggregateFilterBooleanTypeValidation() throws Exception {
3087+
// Test that non-boolean expressions in FILTER throw semantic exceptions
3088+
String sql = "select count(*) filter (where v1) from t0";
3089+
try {
3090+
getFragmentPlan(sql);
3091+
Assertions.fail("Expected semantic exception for non-boolean filter condition");
3092+
} catch (Exception e) {
3093+
assertContains(e.getMessage(), "The condition expression in count_if function must be boolean type, but got bigint");
3094+
}
3095+
3096+
// Test that string expressions in FILTER throw semantic exceptions
3097+
sql = "select sum(v2) filter (where 'true') from t0";
3098+
try {
3099+
getFragmentPlan(sql);
3100+
Assertions.fail("Expected semantic exception for string filter condition");
3101+
} catch (Exception e) {
3102+
assertContains(e.getMessage(), "The condition expression in sum_if function must be boolean type, but got varchar");
3103+
}
3104+
}
3105+
3106+
@Test
3107+
public void testAggIfFunctionBooleanTypeValidation() throws Exception {
3108+
// Test sum_if with correct boolean condition
3109+
String sql = "select sum_if(v2, v1 > 5) from t0";
3110+
String plan = getFragmentPlan(sql);
3111+
assertContains(plan, "sum_if");
3112+
3113+
// Test count_if with correct boolean condition
3114+
sql = "select count_if(v1 > 0 and v2 < 100) from t0";
3115+
plan = getFragmentPlan(sql);
3116+
assertContains(plan, "count_if");
3117+
3118+
// Test that non-boolean condition in sum_if throws exception
3119+
sql = "select sum_if(v2, v1) from t0";
3120+
try {
3121+
getFragmentPlan(sql);
3122+
Assertions.fail("Expected semantic exception for non-boolean condition in sum_if");
3123+
} catch (Exception e) {
3124+
assertContains(e.getMessage(), "The condition expression in sum_if function must be boolean type, but got bigint");
3125+
}
3126+
3127+
// Test that numeric condition in count_if throws exception
3128+
sql = "select count_if(v2) from t0";
3129+
try {
3130+
getFragmentPlan(sql);
3131+
Assertions.fail("Expected semantic exception for non-boolean condition in count_if");
3132+
} catch (Exception e) {
3133+
assertContains(e.getMessage(), "The condition expression in count_if function must be boolean type, but got bigint");
3134+
}
3135+
}
30613136
}

fe/fe-grammar/src/main/antlr/com/starrocks/grammar/StarRocks.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2750,7 +2750,7 @@ whenClause
27502750
;
27512751

27522752
filter
2753-
: FILTER '(' WHERE booleanExpression ')'
2753+
: FILTER '(' WHERE expression ')'
27542754
;
27552755

27562756
over

0 commit comments

Comments
 (0)