import unittest

from sqlglot import exp, parse_one, to_table
from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema, ensure_schema


class TestSchema(unittest.TestCase):
    """Tests for sqlglot's mapping-based schema (``MappingSchema`` / ``ensure_schema``).

    Covers column lookup at table, db and catalog nesting levels, column type
    resolution, per-dialect identifier normalization, nesting-level validation
    and the ``has_column`` / ``find`` accessors.
    """

    def assert_column_names(self, schema, *table_results):
        """Assert ``schema.column_names`` for each ``(table, expected_columns)`` pair.

        Each pair runs in its own ``subTest`` so one mismatch does not hide the others.
        """
        for table, result in table_results:
            with self.subTest(f"{table} -> {result}"):
                self.assertEqual(schema.column_names(to_table(table)), result)

    def assert_column_names_raises(self, schema, *tables):
        """Assert that looking up the columns of each table raises ``SchemaError``."""
        for table in tables:
            with self.subTest(table):
                with self.assertRaises(SchemaError):
                    schema.column_names(to_table(table))

    def assert_column_names_empty(self, schema, *tables):
        """Assert that each table resolves to an empty column list."""
        for table in tables:
            with self.subTest(table):
                self.assertEqual(schema.column_names(to_table(table)), [])

    def assert_column_types(self, schema, expected, *table_columns):
        """Assert ``schema.get_column_type(table, column).this == expected`` for each pair.

        ``table`` and ``column`` may be strings or expressions; every pair runs in
        its own ``subTest`` so all failing combinations are reported.
        """
        for table, column in table_columns:
            with self.subTest(table=table, column=column):
                self.assertEqual(schema.get_column_type(table, column).this, expected)

    def assert_schema_error(self, message, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ``SchemaError`` with exactly ``message``."""
        with self.assertRaises(SchemaError) as ctx:
            func(*args, **kwargs)
        self.assertEqual(str(ctx.exception), message)

    def test_schema(self):
        """Flat ``{table: {column: type}}`` mapping.

        Tables resolve by name even when extra qualifiers are supplied, and
        unknown tables yield no columns.
        """
        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "uint64",
                    "c": "uint64",
                },
            },
        )

        self.assert_column_names(
            schema,
            ("x", ["a"]),
            ("y", ["b", "c"]),
            ("z.x", ["a"]),
            ("z.x.y", ["b", "c"]),
        )

        self.assert_column_names_empty(
            schema,
            "z",
            "z.z",
            "z.z.z",
        )

    def test_schema_db(self):
        """Two-level ``{db: {table: {column: type}}}`` mapping.

        A bare table name resolves when only one db contains it (``y``) and raises
        when several do (``x``). An extra catalog qualifier is ignored, and an
        unknown db yields no columns.
        """
        schema = ensure_schema(
            {
                "d1": {
                    "x": {
                        "a": "uint64",
                    },
                    "y": {
                        "b": "uint64",
                    },
                },
                "d2": {
                    "x": {
                        "c": "uint64",
                    },
                },
            },
        )

        self.assert_column_names(
            schema,
            ("d1.x", ["a"]),
            ("d2.x", ["c"]),
            ("y", ["b"]),
            ("d1.y", ["b"]),
            ("z.d1.y", ["b"]),
        )

        self.assert_column_names_raises(
            schema,
            "x",
        )

        self.assert_column_names_empty(
            schema,
            "z.x",
            "z.y",
        )

    def test_schema_catalog(self):
        """Three-level ``{catalog: {db: {table: {column: type}}}}`` mapping.

        Partially qualified names resolve only when they are unambiguous across
        catalogs/dbs; ambiguous ones raise and unknown ones yield no columns.
        """
        schema = ensure_schema(
            {
                "c1": {
                    "d1": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "b": "uint64",
                        },
                        "z": {
                            "c": "uint64",
                        },
                    },
                },
                "c2": {
                    "d1": {
                        "y": {
                            "d": "uint64",
                        },
                        "z": {
                            "e": "uint64",
                        },
                    },
                    "d2": {
                        "z": {
                            "f": "uint64",
                        },
                    },
                },
            }
        )

        self.assert_column_names(
            schema,
            ("x", ["a"]),
            ("d1.x", ["a"]),
            ("c1.d1.x", ["a"]),
            ("c1.d1.y", ["b"]),
            ("c1.d1.z", ["c"]),
            ("c2.d1.y", ["d"]),
            ("c2.d1.z", ["e"]),
            ("d2.z", ["f"]),
            ("c2.d2.z", ["f"]),
        )

        # d1.y and d1.z exist under both c1 and c2, so these are ambiguous.
        self.assert_column_names_raises(
            schema,
            "y",
            "z",
            "d1.y",
            "d1.z",
        )

        self.assert_column_names_empty(
            schema,
            "q",
            "d2.x",
            "a.b.c",
        )

    def test_schema_add_table_with_and_without_mapping(self):
        """``add_table`` registers/updates columns; calling it without a mapping keeps existing ones."""
        schema = MappingSchema()
        schema.add_table("test")
        self.assertEqual(schema.column_names("test"), [])
        schema.add_table("test", {"x": "string"})
        self.assertEqual(schema.column_names("test"), ["x"])
        schema.add_table("test", {"x": "string", "y": "int"})
        self.assertEqual(schema.column_names("test"), ["x", "y"])
        # Re-adding without a column mapping must not wipe the known columns.
        schema.add_table("test")
        self.assertEqual(schema.column_names("test"), ["x", "y"])

    def test_schema_get_column_type(self):
        """``get_column_type`` accepts string or expression tables/columns at every nesting level."""
        varchar = exp.DataType.Type.VARCHAR

        # Unquoted names are normalized, so "a"/"B" match the mapping's "A"/"b".
        self.assert_column_types(
            MappingSchema({"A": {"b": "varchar"}}),
            varchar,
            ("a", "B"),
            (exp.table_("a"), exp.column("b")),
            ("a", exp.column("b")),
            (exp.table_("a"), "b"),
        )

        self.assert_column_types(
            MappingSchema({"a": {"b": {"c": "varchar"}}}),
            varchar,
            (exp.table_("b", db="a"), exp.column("c")),
            (exp.table_("b", db="a"), "c"),
        )

        self.assert_column_types(
            MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}),
            varchar,
            (exp.table_("c", db="b", catalog="a"), exp.column("d")),
            (exp.table_("c", db="b", catalog="a"), "d"),
        )

        # Column types may also be given as pre-built DataType expressions.
        self.assert_column_types(
            MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}),
            exp.DataType.Type.INT,
            ("foo", "bar"),
        )

    def test_schema_normalization(self):
        """Mapping keys and lookups are normalized per dialect; normalization can be disabled or overridden."""
        schema = MappingSchema(
            schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
            dialect="clickhouse",
        )

        table_z = exp.table_("z", db="y", catalog="x")
        table_w = exp.table_("w", db="y", catalog="x")

        self.assertEqual(schema.column_names(table_z), ["a", "B"])
        self.assertEqual(schema.column_names(table_w), ["c"])

        schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse")
        self.assertEqual(schema.column_names(exp.table_("x")), ["y"])

        # Check that add_table normalizes both the table and the column names to be added / updated
        schema = MappingSchema()
        schema.add_table("Foo", {"SomeColumn": "INT", '"SomeColumn"': "DOUBLE"})
        self.assertEqual(schema.column_names(exp.table_("fOO")), ["somecolumn", "SomeColumn"])

        # Check that names are normalized to uppercase for Snowflake
        schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake")
        self.assertEqual(schema.column_names(exp.table_("x")), ["FOO", "bLa"])

        # Check that switching off the normalization logic works as expected
        schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake")
        self.assertEqual(schema.column_names(exp.table_("x")), ["foo"])

        # Check that the correct dialect is used when calling schema methods.
        # T-SQL identifiers are case-insensitive by default, so the quoted `[Fo]` key is
        # normalized to `fo`, which a quoted ClickHouse `fo` lookup must then find.
        schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql")
        # Pin the expected columns: comparing the two lookups alone would also pass if
        # both of them missed and returned [].
        self.assertEqual(schema.column_names("[Fo]"), ["x"])
        self.assertEqual(
            schema.column_names("[Fo]"), schema.column_names("`fo`", dialect="clickhouse")
        )

        # Check that all column identifiers are normalized to lowercase for BigQuery, even quoted
        # ones. Also, ensure that tables aren't normalized, since they're case-sensitive by default.
        schema = MappingSchema(schema={"Foo": {"`BaR`": "int"}}, dialect="bigquery")
        self.assertEqual(schema.column_names("Foo"), ["bar"])
        self.assertEqual(schema.column_names("foo"), [])

        # Check that the schema's normalization setting can be overridden
        schema = MappingSchema(schema={"X": {"y": "int"}}, normalize=False, dialect="snowflake")
        self.assertEqual(schema.column_names("x", normalize=True), ["y"])

    def test_same_number_of_qualifiers(self):
        """Every table in a schema must use the same number of qualifiers (nesting level)."""
        schema = MappingSchema({"x": {"y": {"c1": "int"}}})
        self.assert_schema_error(
            "Table z must match the schema's nesting level: 2.",
            schema.add_table,
            "z",
            {"c2": "int"},
        )

        # On an initially empty schema, the first add_table call fixes the nesting level.
        schema = MappingSchema()
        schema.add_table("x.y", {"c1": "int"})
        self.assert_schema_error(
            "Table z must match the schema's nesting level: 2.",
            schema.add_table,
            "z",
            {"c2": "int"},
        )

        # Mixed depths are rejected at construction time; the first table seen sets the level.
        self.assert_schema_error(
            "Table z must match the schema's nesting level: 2.",
            MappingSchema,
            {"x": {"y": {"c1": "int"}}, "z": {"c2": "int"}},
        )

        self.assert_schema_error(
            "Table tbl2 must match the schema's nesting level: 3.",
            MappingSchema,
            {
                "catalog": {
                    "db": {"tbl": {"col": "a"}},
                },
                "tbl2": {"col": "b"},
            },
        )

        self.assert_schema_error(
            "Table catalog.db.tbl must match the schema's nesting level: 1.",
            MappingSchema,
            {
                "tbl2": {"col": "b"},
                "catalog": {
                    "db": {"tbl": {"col": "a"}},
                },
            },
        )

    def test_has_column(self):
        """``has_column`` reports whether a column exists in a known table."""
        schema = MappingSchema({"x": {"c": "int"}})
        self.assertTrue(schema.has_column("x", exp.column("c")))
        self.assertFalse(schema.has_column("x", exp.column("k")))

    def test_find(self):
        """``find`` returns the raw column mapping, or DataType values when ``ensure_data_types`` is set."""
        schema = MappingSchema({"x": {"c": "int"}})
        found = schema.find(to_table("x"))
        self.assertEqual(found, {"c": "int"})
        found = schema.find(to_table("x"), ensure_data_types=True)
        self.assertEqual(found, {"c": exp.DataType.build("int")})