Skip to content

Commit 3408de0

Browse files
authored
Fix: parsing quoted built-in data types (#5810)
* Do not parse quoted built-in types into UDTs * PR feedback
1 parent e3cb076 commit 3408de0

File tree

3 files changed

+31
-16
lines changed

3 files changed

+31
-16
lines changed

sqlglot/parser.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections import defaultdict
88

99
from sqlglot import exp
10-
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
10+
from sqlglot.errors import ErrorLevel, ParseError, TokenError, concat_messages, merge_errors
1111
from sqlglot.helper import apply_index_offset, ensure_list, seq_get
1212
from sqlglot.time import format_time
1313
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -5298,28 +5298,30 @@ def _parse_types(
52985298
this: t.Optional[exp.Expression] = None
52995299
prefix = self._match_text_seq("SYSUDTLIB", ".")
53005300

5301-
if not self._match_set(self.TYPE_TOKENS):
5301+
if self._match_set(self.TYPE_TOKENS):
5302+
type_token = self._prev.token_type
5303+
else:
5304+
type_token = None
53025305
identifier = allow_identifiers and self._parse_id_var(
53035306
any_token=False, tokens=(TokenType.VAR,)
53045307
)
53055308
if isinstance(identifier, exp.Identifier):
5306-
tokens = self.dialect.tokenize(identifier.sql(dialect=self.dialect))
5307-
5308-
if len(tokens) != 1:
5309-
self.raise_error("Unexpected identifier", self._prev)
5309+
try:
5310+
tokens = self.dialect.tokenize(identifier.name)
5311+
except TokenError:
5312+
tokens = None
53105313

5311-
if tokens[0].token_type in self.TYPE_TOKENS:
5312-
self._prev = tokens[0]
5313-
elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
5314-
this = self._parse_user_defined_type(identifier)
5314+
if tokens and len(tokens) == 1 and tokens[0].token_type in self.TYPE_TOKENS:
5315+
type_token = tokens[0].token_type
53155316
else:
5316-
self._retreat(self._index - 1)
5317-
return None
5317+
if self.dialect.SUPPORTS_USER_DEFINED_TYPES:
5318+
this = self._parse_user_defined_type(identifier)
5319+
else:
5320+
self._retreat(self._index - 1)
5321+
return None
53185322
else:
53195323
return None
53205324

5321-
type_token = self._prev.token_type
5322-
53235325
if type_token == TokenType.PSEUDO_TYPE:
53245326
return self.expression(exp.PseudoType, this=self._prev.text.upper())
53255327

tests/dialects/test_postgres.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlglot import ParseError, UnsupportedError, exp, transpile
1+
from sqlglot import ParseError, UnsupportedError, exp, transpile, parse_one
22
from sqlglot.helper import logger as helper_logger
33
from tests.dialects.test_dialect import Validator
44

@@ -805,6 +805,12 @@ def test_postgres(self):
805805
)
806806
self.assertIsInstance(self.parse_one("id::UUID"), exp.Cast)
807807

808+
self.validate_identity('1::"int"', "CAST(1 AS INT)")
809+
assert parse_one('1::"int"', read="postgres").to.is_type(exp.DataType.Type.INT)
810+
811+
self.validate_identity('1::"udt"', 'CAST(1 AS "udt")')
812+
assert parse_one('1::"udt"', read="postgres").to.this == exp.DataType.Type.USERDEFINED
813+
808814
self.validate_identity(
809815
"COPY tbl (col1, col2) FROM 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
810816
)

tests/dialects/test_redshift.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlglot import exp, parse_one, transpile
1+
from sqlglot import exp, ParseError, parse_one, transpile
22
from tests.dialects.test_dialect import Validator
33

44

@@ -698,3 +698,10 @@ def test_analyze(self):
698698
self.validate_identity("ANALYZE VERBOSE TBL")
699699
self.validate_identity("ANALYZE TBL PREDICATE COLUMNS")
700700
self.validate_identity("ANALYZE TBL ALL COLUMNS")
701+
702+
def test_cast(self):
703+
self.validate_identity('1::"int"', "CAST(1 AS INTEGER)")
704+
assert parse_one('1::"int"', read="redshift").to.is_type(exp.DataType.Type.INT)
705+
706+
with self.assertRaises(ParseError):
707+
parse_one('1::"udt"', read="redshift")

0 commit comments

Comments
 (0)