Merging upstream version 10.2.9.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
3439d8569e
commit
2468c1121f
13 changed files with 91 additions and 26 deletions
|
@ -30,7 +30,7 @@ from sqlglot.parser import Parser
|
||||||
from sqlglot.schema import MappingSchema
|
from sqlglot.schema import MappingSchema
|
||||||
from sqlglot.tokens import Tokenizer, TokenType
|
from sqlglot.tokens import Tokenizer, TokenType
|
||||||
|
|
||||||
__version__ = "10.2.6"
|
__version__ = "10.2.9"
|
||||||
|
|
||||||
pretty = False
|
pretty = False
|
||||||
|
|
||||||
|
|
|
@ -250,6 +250,7 @@ class Hive(Dialect):
|
||||||
TYPE_MAPPING = {
|
TYPE_MAPPING = {
|
||||||
**generator.Generator.TYPE_MAPPING,
|
**generator.Generator.TYPE_MAPPING,
|
||||||
exp.DataType.Type.TEXT: "STRING",
|
exp.DataType.Type.TEXT: "STRING",
|
||||||
|
exp.DataType.Type.DATETIME: "TIMESTAMP",
|
||||||
exp.DataType.Type.VARBINARY: "BINARY",
|
exp.DataType.Type.VARBINARY: "BINARY",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -244,6 +244,7 @@ class Postgres(Dialect):
|
||||||
|
|
||||||
class Parser(parser.Parser):
|
class Parser(parser.Parser):
|
||||||
STRICT_CAST = False
|
STRICT_CAST = False
|
||||||
|
LATERAL_FUNCTION_AS_VIEW = True
|
||||||
|
|
||||||
FUNCTIONS = {
|
FUNCTIONS = {
|
||||||
**parser.Parser.FUNCTIONS,
|
**parser.Parser.FUNCTIONS,
|
||||||
|
|
|
@ -224,6 +224,12 @@ class TSQL(Dialect):
|
||||||
class Tokenizer(tokens.Tokenizer):
|
class Tokenizer(tokens.Tokenizer):
|
||||||
IDENTIFIERS = ['"', ("[", "]")]
|
IDENTIFIERS = ['"', ("[", "]")]
|
||||||
|
|
||||||
|
QUOTES = [
|
||||||
|
(prefix + quote, quote) if prefix else quote
|
||||||
|
for quote in ["'", '"']
|
||||||
|
for prefix in ["", "n", "N"]
|
||||||
|
]
|
||||||
|
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**tokens.Tokenizer.KEYWORDS,
|
**tokens.Tokenizer.KEYWORDS,
|
||||||
"BIT": TokenType.BOOLEAN,
|
"BIT": TokenType.BOOLEAN,
|
||||||
|
|
|
@ -3673,7 +3673,11 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def values(values, alias=None) -> Values:
|
def values(
|
||||||
|
values: t.Iterable[t.Tuple[t.Any, ...]],
|
||||||
|
alias: t.Optional[str] = None,
|
||||||
|
columns: t.Optional[t.Iterable[str]] = None,
|
||||||
|
) -> Values:
|
||||||
"""Build VALUES statement.
|
"""Build VALUES statement.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -3681,17 +3685,23 @@ def values(values, alias=None) -> Values:
|
||||||
"VALUES (1, '2')"
|
"VALUES (1, '2')"
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values (list[tuple[str | Expression]]): values statements that will be converted to SQL
|
values: values statements that will be converted to SQL
|
||||||
alias (str): optional alias
|
alias: optional alias
|
||||||
dialect (str): the dialect used to parse the input expression.
|
columns: Optional list of ordered column names. An alias is required when providing column names.
|
||||||
**opts: other options to use to parse the input expressions.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Values: the Values expression object
|
Values: the Values expression object
|
||||||
"""
|
"""
|
||||||
|
if columns and not alias:
|
||||||
|
raise ValueError("Alias is required when providing columns")
|
||||||
|
table_alias = (
|
||||||
|
TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
|
||||||
|
if columns
|
||||||
|
else TableAlias(this=to_identifier(alias) if alias else None)
|
||||||
|
)
|
||||||
return Values(
|
return Values(
|
||||||
expressions=[convert(tup) for tup in values],
|
expressions=[convert(tup) for tup in values],
|
||||||
alias=to_identifier(alias) if alias else None,
|
alias=table_alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -795,14 +795,16 @@ class Generator:
|
||||||
|
|
||||||
alias = expression.args["alias"]
|
alias = expression.args["alias"]
|
||||||
table = alias.name
|
table = alias.name
|
||||||
table = f" {table}" if table else table
|
|
||||||
columns = self.expressions(alias, key="columns", flat=True)
|
columns = self.expressions(alias, key="columns", flat=True)
|
||||||
columns = f" AS {columns}" if columns else ""
|
|
||||||
|
|
||||||
if expression.args.get("view"):
|
if expression.args.get("view"):
|
||||||
|
table = f" {table}" if table else table
|
||||||
|
columns = f" AS {columns}" if columns else ""
|
||||||
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
|
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
|
||||||
return f"{op_sql}{self.sep()}{this}{table}{columns}"
|
return f"{op_sql}{self.sep()}{this}{table}{columns}"
|
||||||
|
|
||||||
|
table = f" AS {table}" if table else table
|
||||||
|
columns = f"({columns})" if columns else ""
|
||||||
return f"LATERAL {this}{table}{columns}"
|
return f"LATERAL {this}{table}{columns}"
|
||||||
|
|
||||||
def limit_sql(self, expression: exp.Limit) -> str:
|
def limit_sql(self, expression: exp.Limit) -> str:
|
||||||
|
@ -889,8 +891,8 @@ class Generator:
|
||||||
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
|
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
|
||||||
return csv(
|
return csv(
|
||||||
*sqls,
|
*sqls,
|
||||||
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
|
|
||||||
*[self.sql(sql) for sql in expression.args.get("joins", [])],
|
*[self.sql(sql) for sql in expression.args.get("joins", [])],
|
||||||
|
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
|
||||||
self.sql(expression, "where"),
|
self.sql(expression, "where"),
|
||||||
self.sql(expression, "group"),
|
self.sql(expression, "group"),
|
||||||
self.sql(expression, "having"),
|
self.sql(expression, "having"),
|
||||||
|
|
|
@ -562,6 +562,7 @@ class Parser(metaclass=_Parser):
|
||||||
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
|
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
|
||||||
|
|
||||||
STRICT_CAST = True
|
STRICT_CAST = True
|
||||||
|
LATERAL_FUNCTION_AS_VIEW = False
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"error_level",
|
"error_level",
|
||||||
|
@ -1287,14 +1288,24 @@ class Parser(metaclass=_Parser):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not this:
|
if not this:
|
||||||
this = self._parse_function()
|
this = self._parse_function() or self._parse_id_var(any_token=False)
|
||||||
|
while self._match(TokenType.DOT):
|
||||||
table_alias = self._parse_id_var(any_token=False)
|
this = exp.Dot(
|
||||||
|
this=this,
|
||||||
|
expression=self._parse_function() or self._parse_id_var(any_token=False),
|
||||||
|
)
|
||||||
|
|
||||||
columns = None
|
columns = None
|
||||||
|
table_alias = None
|
||||||
|
if view or self.LATERAL_FUNCTION_AS_VIEW:
|
||||||
|
table_alias = self._parse_id_var(any_token=False)
|
||||||
if self._match(TokenType.ALIAS):
|
if self._match(TokenType.ALIAS):
|
||||||
columns = self._parse_csv(self._parse_id_var)
|
columns = self._parse_csv(self._parse_id_var)
|
||||||
elif self._match(TokenType.L_PAREN):
|
else:
|
||||||
|
self._match(TokenType.ALIAS)
|
||||||
|
table_alias = self._parse_id_var(any_token=False)
|
||||||
|
|
||||||
|
if self._match(TokenType.L_PAREN):
|
||||||
columns = self._parse_csv(self._parse_id_var)
|
columns = self._parse_csv(self._parse_id_var)
|
||||||
self._match_r_paren()
|
self._match_r_paren()
|
||||||
|
|
||||||
|
|
|
@ -237,12 +237,17 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
|
||||||
if table_:
|
if table_:
|
||||||
table_schema = self.find(table_, raise_on_missing=False)
|
table_schema = self.find(table_, raise_on_missing=False)
|
||||||
if table_schema:
|
if table_schema:
|
||||||
schema_type = table_schema.get(column_name).upper() # type: ignore
|
column_type = table_schema.get(column_name)
|
||||||
return self._convert_type(schema_type)
|
|
||||||
|
if isinstance(column_type, exp.DataType):
|
||||||
|
return column_type
|
||||||
|
elif isinstance(column_type, str):
|
||||||
|
return self._to_data_type(column_type.upper())
|
||||||
|
raise SchemaError(f"Unknown column type '{column_type}'")
|
||||||
return exp.DataType(this=exp.DataType.Type.UNKNOWN)
|
return exp.DataType(this=exp.DataType.Type.UNKNOWN)
|
||||||
raise SchemaError(f"Could not convert table '{table}'")
|
raise SchemaError(f"Could not convert table '{table}'")
|
||||||
|
|
||||||
def _convert_type(self, schema_type: str) -> exp.DataType:
|
def _to_data_type(self, schema_type: str) -> exp.DataType:
|
||||||
"""
|
"""
|
||||||
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
|
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
|
||||||
|
|
||||||
|
|
|
@ -496,7 +496,7 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS
|
||||||
f.value AS "Contact",
|
f.value AS "Contact",
|
||||||
f1.value['type'] AS "Type",
|
f1.value['type'] AS "Type",
|
||||||
f1.value['content'] AS "Details"
|
f1.value['content'] AS "Details"
|
||||||
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
|
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERAL FLATTEN(input => f.value['business']) AS f1""",
|
||||||
},
|
},
|
||||||
pretty=True,
|
pretty=True,
|
||||||
)
|
)
|
||||||
|
|
|
@ -371,13 +371,19 @@ class TestTSQL(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)",
|
"SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)",
|
||||||
write={
|
write={
|
||||||
"spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z",
|
"spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) AS y(z)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)",
|
"SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)",
|
||||||
write={
|
write={
|
||||||
"spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z",
|
"spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) AS y(z)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT t.x, y.z FROM x OUTER APPLY a.b.tvfTest(t.x)y(z)",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL a.b.TVFTEST(t.x) AS y(z)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -421,3 +427,17 @@ class TestTSQL(Validator):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
|
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_string(self):
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT N'test'",
|
||||||
|
write={"spark": "SELECT 'test'"},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT n'test'",
|
||||||
|
write={"spark": "SELECT 'test'"},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT '''test'''",
|
||||||
|
write={"spark": r"SELECT '\'test\''"},
|
||||||
|
)
|
||||||
|
|
|
@ -473,6 +473,12 @@ class TestBuild(unittest.TestCase):
|
||||||
(lambda: exp.values([("1", 2)]), "VALUES ('1', 2)"),
|
(lambda: exp.values([("1", 2)]), "VALUES ('1', 2)"),
|
||||||
(lambda: exp.values([("1", 2)], "alias"), "(VALUES ('1', 2)) AS alias"),
|
(lambda: exp.values([("1", 2)], "alias"), "(VALUES ('1', 2)) AS alias"),
|
||||||
(lambda: exp.values([("1", 2), ("2", 3)]), "VALUES ('1', 2), ('2', 3)"),
|
(lambda: exp.values([("1", 2), ("2", 3)]), "VALUES ('1', 2), ('2', 3)"),
|
||||||
|
(
|
||||||
|
lambda: exp.values(
|
||||||
|
[("1", 2, None), ("2", 3, None)], "alias", ["col1", "col2", "col3"]
|
||||||
|
),
|
||||||
|
"(VALUES ('1', 2, NULL), ('2', 3, NULL)) AS alias(col1, col2, col3)",
|
||||||
|
),
|
||||||
(lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"),
|
(lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"),
|
||||||
(lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"),
|
(lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"),
|
||||||
]:
|
]:
|
||||||
|
|
|
@ -85,7 +85,7 @@ class TestParser(unittest.TestCase):
|
||||||
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
|
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
|
||||||
"""SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""",
|
"""SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_command(self):
|
def test_command(self):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from sqlglot import exp, to_table
|
from sqlglot import exp, parse_one, to_table
|
||||||
from sqlglot.errors import SchemaError
|
from sqlglot.errors import SchemaError
|
||||||
from sqlglot.schema import MappingSchema, ensure_schema
|
from sqlglot.schema import MappingSchema, ensure_schema
|
||||||
|
|
||||||
|
@ -181,3 +181,6 @@ class TestSchema(unittest.TestCase):
|
||||||
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
|
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
|
||||||
exp.DataType.Type.VARCHAR,
|
exp.DataType.Type.VARCHAR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}})
|
||||||
|
self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue