Skip to content

Commit 3b4e463

Browse files
feat(optimizer)!: annotate types for Snowflake TRIM function
1 parent: b128339 · commit: 3b4e463

File tree

4 files changed: +92 / -67 lines changed

sqlglot/dialects/snowflake.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1362,9 +1362,10 @@ class Generator(generator.Generator):
13621362

13631363
TYPE_MAPPING = {
13641364
**generator.Generator.TYPE_MAPPING,
1365+
exp.DataType.Type.BIGDECIMAL: "DOUBLE",
13651366
exp.DataType.Type.NESTED: "OBJECT",
13661367
exp.DataType.Type.STRUCT: "OBJECT",
1367-
exp.DataType.Type.BIGDECIMAL: "DOUBLE",
1368+
exp.DataType.Type.TEXT: "VARCHAR",
13681369
}
13691370

13701371
TOKEN_MAPPING = {

tests/dialects/test_dialect.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,7 @@ def test_cast(self):
209209
"postgres": "CAST(a AS TEXT)",
210210
"presto": "CAST(a AS VARCHAR)",
211211
"redshift": "CAST(a AS VARCHAR(MAX))",
212-
"snowflake": "CAST(a AS TEXT)",
212+
"snowflake": "CAST(a AS VARCHAR)",
213213
"spark": "CAST(a AS STRING)",
214214
"starrocks": "CAST(a AS STRING)",
215215
"tsql": "CAST(a AS VARCHAR(MAX))",
@@ -293,7 +293,7 @@ def test_cast(self):
293293
"postgres": "CAST(a AS TEXT)",
294294
"presto": "CAST(a AS VARCHAR)",
295295
"redshift": "CAST(a AS VARCHAR(MAX))",
296-
"snowflake": "CAST(a AS TEXT)",
296+
"snowflake": "CAST(a AS VARCHAR)",
297297
"spark": "CAST(a AS STRING)",
298298
"starrocks": "CAST(a AS STRING)",
299299
"tsql": "CAST(a AS VARCHAR(MAX))",

tests/dialects/test_snowflake.py

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,10 @@ def test_snowflake(self):
8585
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
8686
self.validate_identity("ALTER TABLE a SWAP WITH b")
8787
self.validate_identity("SELECT MATCH_CONDITION")
88-
self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
88+
self.validate_identity(
89+
"SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t",
90+
"SELECT * REPLACE (CAST(col AS VARCHAR) AS scol) FROM t",
91+
)
8992
self.validate_identity("1 /* /* */")
9093
self.validate_identity("TO_TIMESTAMP(col, fmt)")
9194
self.validate_identity("SELECT TO_CHAR(CAST('12:05:05' AS TIME))")
@@ -178,7 +181,7 @@ def test_snowflake(self):
178181
)
179182
self.validate_identity(
180183
"SELECT a:from::STRING, a:from || ' test' ",
181-
"SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
184+
"SELECT CAST(GET_PATH(a, 'from') AS VARCHAR), GET_PATH(a, 'from') || ' test'",
182185
)
183186
self.validate_identity(
184187
"SELECT a:select",
@@ -187,7 +190,7 @@ def test_snowflake(self):
187190
self.validate_identity("x:from", "GET_PATH(x, 'from')")
188191
self.validate_identity(
189192
"value:values::string::int",
190-
"CAST(CAST(GET_PATH(value, 'values') AS TEXT) AS INT)",
193+
"CAST(CAST(GET_PATH(value, 'values') AS VARCHAR) AS INT)",
191194
)
192195
self.validate_identity(
193196
"""SELECT GET_PATH(PARSE_JSON('{"y": [{"z": 1}]}'), 'y[0]:z')""",
@@ -2638,25 +2641,26 @@ def test_swap(self):
26382641
assert isinstance(ast.args["actions"][0], exp.SwapTable)
26392642

26402643
def test_try_cast(self):
2641-
self.validate_identity("SELECT TRY_CAST(x AS DOUBLE)")
2642-
self.validate_identity("SELECT TRY_CAST(FOO() AS TEXT)")
2643-
2644-
self.validate_all("TRY_CAST('foo' AS TEXT)", read={"hive": "CAST('foo' AS STRING)"})
2645-
self.validate_all("CAST(5 + 5 AS TEXT)", read={"hive": "CAST(5 + 5 AS STRING)"})
2644+
self.validate_all("TRY_CAST('foo' AS VARCHAR)", read={"hive": "CAST('foo' AS STRING)"})
2645+
self.validate_all("CAST(5 + 5 AS VARCHAR)", read={"hive": "CAST(5 + 5 AS STRING)"})
26462646
self.validate_all(
2647-
"CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)",
2647+
"CAST(TRY_CAST('2020-01-01' AS DATE) AS VARCHAR)",
26482648
read={
26492649
"hive": "CAST(CAST('2020-01-01' AS DATE) AS STRING)",
2650-
"snowflake": "CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)",
2650+
"snowflake": "CAST(TRY_CAST('2020-01-01' AS DATE) AS VARCHAR)",
26512651
},
26522652
)
26532653
self.validate_all(
2654-
"TRY_CAST('val' AS TEXT)",
2654+
"TRY_CAST('val' AS VARCHAR)",
26552655
read={
26562656
"hive": "CAST('val' AS STRING)",
2657-
"snowflake": "TRY_CAST('val' AS TEXT)",
2657+
"snowflake": "TRY_CAST('val' AS VARCHAR)",
26582658
},
26592659
)
2660+
self.validate_identity("SELECT TRY_CAST(x AS DOUBLE)")
2661+
self.validate_identity(
2662+
"SELECT TRY_CAST(FOO() AS TEXT)", "SELECT TRY_CAST(FOO() AS VARCHAR)"
2663+
)
26602664

26612665
from sqlglot.optimizer.annotate_types import annotate_types
26622666

@@ -2670,7 +2674,7 @@ def test_try_cast(self):
26702674

26712675
expression = annotate_types(expression, schema={"t": {"x": value_type}})
26722676
self.assertEqual(
2673-
expression.sql(dialect="snowflake"), f"SELECT {func}(t.x AS TEXT) FROM t"
2677+
expression.sql(dialect="snowflake"), f"SELECT {func}(t.x AS VARCHAR) FROM t"
26742678
)
26752679

26762680
def test_copy(self):

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 71 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -1531,10 +1531,6 @@ STRING;
15311531
-- Snowflake
15321532
--------------------------------------
15331533

1534-
# dialect: snowflake
1535-
LEAST(x::DECIMAL(18, 2));
1536-
DECIMAL(18, 2);
1537-
15381534
# dialect: snowflake
15391535
CHARINDEX('world', 'hello world');
15401536
INT;
@@ -1543,18 +1539,6 @@ INT;
15431539
CHARINDEX('world', 'hello world', 1);
15441540
INT;
15451541

1546-
# dialect: snowflake
1547-
REVERSE('Hello, world!');
1548-
VARCHAR;
1549-
1550-
# dialect: snowflake
1551-
REVERSE(tbl.str_col);
1552-
TEXT;
1553-
1554-
# dialect: snowflake
1555-
REVERSE(tbl.bin_col);
1556-
BINARY;
1557-
15581542
# dialect: snowflake
15591543
CONCAT('Hello', 'World!');
15601544
VARCHAR;
@@ -1591,10 +1575,6 @@ BOOLEAN;
15911575
CONTAINS(tbl.bin_col, NULL);
15921576
BOOLEAN;
15931577

1594-
# dialect: snowflake
1595-
REVERSE(NULL);
1596-
VARCHAR;
1597-
15981578
# dialect: snowflake
15991579
ENDSWITH('hello world', 'world');
16001580
BOOLEAN;
@@ -1612,53 +1592,61 @@ ENDSWITH(tbl.bin_col, NULL);
16121592
BOOLEAN;
16131593

16141594
# dialect: snowflake
1615-
STARTSWITH('hello world', 'hello');
1616-
BOOLEAN;
1595+
LEAST(x::DECIMAL(18, 2));
1596+
DECIMAL(18, 2);
16171597

16181598
# dialect: snowflake
1619-
STARTSWITH(tbl.str_col, 'test');
1620-
BOOLEAN;
1599+
LEFT('hello world', 5);
1600+
VARCHAR;
16211601

16221602
# dialect: snowflake
1623-
STARTSWITH(tbl.bin_col, tbl.bin_col);
1624-
BOOLEAN;
1603+
LEFT(tbl.str_col, 3);
1604+
STRING;
16251605

16261606
# dialect: snowflake
1627-
STARTSWITH(tbl.bin_col, NULL);
1628-
BOOLEAN;
1607+
LEFT(tbl.bin_col, 3);
1608+
BINARY;
16291609

16301610
# dialect: snowflake
1631-
SUBSTR('hello world', 1, 5);
1632-
VARCHAR;
1611+
LEFT(tbl.bin_col, NULL);
1612+
BINARY;
16331613

16341614
# dialect: snowflake
1635-
SUBSTR(tbl.str_col, 1, 3);
1636-
STRING;
1615+
LEN(tbl.str_col);
1616+
INT;
16371617

16381618
# dialect: snowflake
1639-
SUBSTR(tbl.bin_col, 1, 3);
1640-
BINARY;
1619+
LEN(tbl.bin_col);
1620+
INT;
16411621

16421622
# dialect: snowflake
1643-
SUBSTR(tbl.str_col, NULL);
1644-
STRING;
1623+
LENGTH(tbl.str_col);
1624+
INT;
16451625

16461626
# dialect: snowflake
1647-
LEFT('hello world', 5);
1627+
LENGTH(tbl.bin_col);
1628+
INT;
1629+
1630+
# dialect: snowflake
1631+
LOWER(tbl.str_col);
16481632
VARCHAR;
16491633

16501634
# dialect: snowflake
1651-
LEFT(tbl.str_col, 3);
1652-
STRING;
1635+
REVERSE('Hello, world!');
1636+
VARCHAR;
16531637

16541638
# dialect: snowflake
1655-
LEFT(tbl.bin_col, 3);
1656-
BINARY;
1639+
REVERSE(tbl.str_col);
1640+
VARCHAR;
16571641

16581642
# dialect: snowflake
1659-
LEFT(tbl.bin_col, NULL);
1643+
REVERSE(tbl.bin_col);
16601644
BINARY;
16611645

1646+
# dialect: snowflake
1647+
REVERSE(NULL);
1648+
VARCHAR;
1649+
16621650
# dialect: snowflake
16631651
RIGHT('hello world', 5);
16641652
VARCHAR;
@@ -1676,23 +1664,55 @@ RIGHT(tbl.str_col, NULL);
16761664
STRING;
16771665

16781666
# dialect: snowflake
1679-
LENGTH(tbl.str_col);
1680-
INT;
1667+
STARTSWITH('hello world', 'hello');
1668+
BOOLEAN;
16811669

16821670
# dialect: snowflake
1683-
LENGTH(tbl.bin_col);
1684-
INT;
1671+
STARTSWITH(tbl.str_col, 'test');
1672+
BOOLEAN;
16851673

16861674
# dialect: snowflake
1687-
LEN(tbl.str_col);
1688-
INT;
1675+
STARTSWITH(tbl.bin_col, tbl.bin_col);
1676+
BOOLEAN;
16891677

16901678
# dialect: snowflake
1691-
LEN(tbl.bin_col);
1692-
INT;
1679+
STARTSWITH(tbl.bin_col, NULL);
1680+
BOOLEAN;
16931681

16941682
# dialect: snowflake
1695-
LOWER(tbl.str_col);
1683+
SUBSTR('hello world', 1, 5);
1684+
VARCHAR;
1685+
1686+
# dialect: snowflake
1687+
SUBSTR(tbl.str_col, 1, 3);
1688+
STRING;
1689+
1690+
# dialect: snowflake
1691+
SUBSTR(tbl.bin_col, 1, 3);
1692+
BINARY;
1693+
1694+
# dialect: snowflake
1695+
SUBSTR(tbl.str_col, NULL);
1696+
STRING;
1697+
1698+
# dialect: snowflake
1699+
TRIM('hello world');
1700+
VARCHAR;
1701+
1702+
# dialect: snowflake
1703+
TRIM('hello world', 'hello');
1704+
VARCHAR;
1705+
1706+
# dialect: snowflake
1707+
TRIM(tbl.str_col);
1708+
VARCHAR;
1709+
1710+
# dialect: snowflake
1711+
TRIM(tbl.str_col, tbl.str_col);
1712+
VARCHAR;
1713+
1714+
# dialect: snowflake
1715+
TRIM(NULL);
16961716
VARCHAR;
16971717

16981718
--------------------------------------

0 commit comments

Comments
 (0)