611 lines
18 KiB
Python
611 lines
18 KiB
Python
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
import logging
|
|
from re import compile, escape
|
|
from collections import Counter
|
|
|
|
from prompt_toolkit.completion import Completer, Completion
|
|
|
|
from .packages.completion_engine import suggest_type
|
|
from .packages.parseutils import last_word
|
|
from .packages.special.iocommands import favoritequeries
|
|
from .packages.filepaths import parse_path, complete_path, suggest_path
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SQLCompleter(Completer):
|
|
# SQL keywords and type names offered for keyword completion.
# Entries may contain several words (e.g. "ORDER BY"); __init__ splits each
# entry on whitespace to build ``reserved_words``, and reset_completions
# seeds ``all_completions`` from this list together with ``functions``.
keywords = [
    "ABORT",
    "ACTION",
    "ADD",
    "AFTER",
    "ALL",
    "ALTER",
    "ANALYZE",
    "AND",
    "AS",
    "ASC",
    "ATTACH",
    "AUTOINCREMENT",
    "BEFORE",
    "BEGIN",
    "BETWEEN",
    "BIGINT",
    "BLOB",
    "BOOLEAN",
    "BY",
    "CASCADE",
    "CASE",
    "CAST",
    "CHARACTER",
    "CHECK",
    "CLOB",
    "COLLATE",
    "COLUMN",
    "COMMIT",
    "CONFLICT",
    "CONSTRAINT",
    "CREATE",
    "CROSS",
    "CURRENT",
    "CURRENT_DATE",
    "CURRENT_TIME",
    "CURRENT_TIMESTAMP",
    "DATABASE",
    "DATE",
    "DATETIME",
    "DECIMAL",
    "DEFAULT",
    "DEFERRABLE",
    "DEFERRED",
    "DELETE",
    "DETACH",
    "DISTINCT",
    "DO",
    "DOUBLE PRECISION",
    "DOUBLE",
    "DROP",
    "EACH",
    "ELSE",
    "END",
    "ESCAPE",
    "EXCEPT",
    "EXCLUSIVE",
    "EXISTS",
    "EXPLAIN",
    "FAIL",
    "FILTER",
    "FLOAT",
    "FOLLOWING",
    "FOR",
    "FOREIGN",
    "FROM",
    "FULL",
    "GLOB",
    "GROUP",
    "HAVING",
    "IF",
    "IGNORE",
    "IMMEDIATE",
    "IN",
    "INDEX",
    "INDEXED",
    "INITIALLY",
    "INNER",
    "INSERT",
    "INSTEAD",
    "INT",
    "INT2",
    "INT8",
    "INTEGER",
    "INTERSECT",
    "INTO",
    "IS",
    "ISNULL",
    "JOIN",
    "KEY",
    "LEFT",
    "LIKE",
    "LIMIT",
    "MATCH",
    "MEDIUMINT",
    "NATIVE CHARACTER",
    "NATURAL",
    "NCHAR",
    "NO",
    "NOT",
    "NOTHING",
    "NULL",
    "NULLS FIRST",
    "NULLS LAST",
    "NUMERIC",
    "NVARCHAR",
    "OF",
    "OFFSET",
    "ON",
    "OR",
    "ORDER BY",
    "OUTER",
    "OVER",
    "PARTITION",
    "PLAN",
    "PRAGMA",
    "PRECEDING",
    "PRIMARY",
    "QUERY",
    "RAISE",
    "RANGE",
    "REAL",
    "RECURSIVE",
    "REFERENCES",
    "REGEXP",
    "REINDEX",
    "RELEASE",
    "RENAME",
    "REPLACE",
    "RESTRICT",
    "RIGHT",
    "ROLLBACK",
    "ROW",
    "ROWS",
    "SAVEPOINT",
    "SELECT",
    "SET",
    "SMALLINT",
    "TABLE",
    "TEMP",
    "TEMPORARY",
    "TEXT",
    "THEN",
    "TINYINT",
    "TO",
    "TRANSACTION",
    "TRIGGER",
    "UNBOUNDED",
    "UNION",
    "UNIQUE",
    "UNSIGNED BIG INT",
    "UPDATE",
    "USING",
    "VACUUM",
    "VALUES",
    "VARCHAR",
    "VARYING CHARACTER",
    "VIEW",
    "VIRTUAL",
    "WHEN",
    "WHERE",
    "WINDOW",
    "WITH",
    "WITHOUT",
]
|
|
|
|
# Hard-coded function names. get_completions offers them by prefix match
# when a "function" suggestion carries no schema qualifier, and escape_name
# quotes any identifier whose upper-cased form appears in this list.
functions = [
    "ABS",
    "AVG",
    "CHANGES",
    "CHAR",
    "COALESCE",
    "COUNT",
    "CUME_DIST",
    "DATE",
    "DATETIME",
    "DENSE_RANK",
    "GLOB",
    "GROUP_CONCAT",
    "HEX",
    "IFNULL",
    "INSTR",
    "JSON",
    "JSON_ARRAY",
    "JSON_ARRAY_LENGTH",
    "JSON_EACH",
    "JSON_EXTRACT",
    "JSON_GROUP_ARRAY",
    "JSON_GROUP_OBJECT",
    "JSON_INSERT",
    "JSON_OBJECT",
    "JSON_PATCH",
    "JSON_QUOTE",
    "JSON_REMOVE",
    "JSON_REPLACE",
    "JSON_SET",
    "JSON_TREE",
    "JSON_TYPE",
    "JSON_VALID",
    "JULIANDAY",
    "LAG",
    "LAST_INSERT_ROWID",
    "LENGTH",
    "LIKELIHOOD",
    "LIKELY",
    "LOAD_EXTENSION",
    "LOWER",
    "LTRIM",
    "MAX",
    "MIN",
    "NTILE",
    "NULLIF",
    "PERCENT_RANK",
    "PRINTF",
    "QUOTE",
    "RANDOM",
    "RANDOMBLOB",
    "RANK",
    "REPLACE",
    "ROUND",
    "ROW_NUMBER",
    "RTRIM",
    "SOUNDEX",
    "SQLITE_COMPILEOPTION_GET",
    "SQLITE_COMPILEOPTION_USED",
    "SQLITE_OFFSET",
    "SQLITE_SOURCE_ID",
    "SQLITE_VERSION",
    "STRFTIME",
    "SUBSTR",
    "SUM",
    "TIME",
    "TOTAL",
    "TOTAL_CHANGES",
    "TRIM",
]
|
|
|
|
def __init__(self, supported_formats=(), keyword_casing="auto"):
    """Initialise static completer state and reset the metadata caches.

    :param supported_formats: table format names offered for
        ``table_format`` suggestions
    :param keyword_casing: ``"upper"``, ``"lower"`` or ``"auto"``; any
        other value silently falls back to ``"auto"``
    """
    # Name the class explicitly: super(self.__class__, self) resolves to the
    # subclass when SQLCompleter is inherited from and recurses forever.
    super(SQLCompleter, self).__init__()

    # Multi-word keywords ("ORDER BY") contribute every word on its own.
    self.reserved_words = set()
    for keyword in self.keywords:
        self.reserved_words.update(keyword.split())

    # Identifiers matching this pattern can be used without quoting.
    self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$")

    self.special_commands = []
    self.table_formats = supported_formats
    if keyword_casing not in ("upper", "lower", "auto"):
        keyword_casing = "auto"
    self.keyword_casing = keyword_casing
    self.reset_completions()
|
|
|
|
def escape_name(self, name):
    """Return *name*, wrapped in backticks when it cannot be used bare.

    A non-empty name is quoted when it does not match ``name_pattern`` or
    when its upper-cased form is a reserved word or a built-in function
    name. Falsy names are returned unchanged.

    :param name: identifier to check
    :return: the identifier, quoted if necessary
    """
    if not name:
        return name

    if self.name_pattern.match(name):
        upper = name.upper()
        if upper not in self.reserved_words and upper not in self.functions:
            # Plain identifier that clashes with nothing: leave it alone.
            return name

    return "`%s`" % name
|
|
|
|
def unescape_name(self, name):
    """Strip one pair of matching identifier quotes from *name*.

    Double quotes and backticks (the style ``escape_name`` produces) are
    both recognised. The name must be at least two characters long so that
    a lone quote character is not mistaken for an empty quoted string.

    :param name: possibly quoted identifier
    :return: the identifier without its surrounding quotes
    """
    if (
        name
        and len(name) >= 2
        and name[0] == name[-1]
        and name[0] in ('"', "`")
    ):
        return name[1:-1]

    return name
|
|
|
|
def escaped_names(self, names):
    """Apply ``escape_name`` to every item of *names*.

    :param names: iterable of identifiers
    :return: list of escaped identifiers, in the same order
    """
    return list(map(self.escape_name, names))
|
|
|
|
def extend_special_commands(self, special_commands):
    """Register additional special commands.

    They are kept apart from ``all_completions`` because a special command
    can only appear at the beginning of a line.

    :param special_commands: iterable of command names
    """
    known_commands = self.special_commands
    known_commands.extend(special_commands)
|
|
|
|
def extend_database_names(self, databases):
    """Append *databases* to the list of known database names.

    :param databases: iterable of database names
    """
    known_databases = self.databases
    known_databases.extend(databases)
|
|
|
|
def extend_keywords(self, additional_keywords):
    """Add extra keywords to this completer.

    :param additional_keywords: iterable of keyword strings
    """
    # Materialise first so a generator is not exhausted by the first use.
    additional_keywords = list(additional_keywords)

    # Rebind instead of extending in place: ``keywords`` starts out as a
    # list on the class, so an in-place extend would leak the additions
    # into every other instance.
    self.keywords = self.keywords + additional_keywords
    self.all_completions.update(additional_keywords)
|
|
|
|
def extend_schemata(self, schema):
    """Register *schema* as an empty namespace in every metadata category.

    :param schema: schema name, or ``None`` to do nothing
    """
    if schema is None:
        return

    # dbmetadata.values() are the per-category dicts ('tables', 'views',
    # 'functions'); each gets a fresh, empty entry for the schema.
    for metadata in self.dbmetadata.values():
        metadata[schema] = {}

    # add(), not update(): update() iterates its argument, which for a
    # string would register every single character as a completion.
    self.all_completions.add(schema)
|
|
|
|
def extend_relations(self, data, kind):
    """Record relation names (tables or views) for the current database.

    :param data: iterable of (rel_name, ) tuples
    :param kind: either 'tables' or 'views'
    :return:
    """
    # 'data' may be a generator that raises while it is consumed, e.g. when
    # the app was launched without a database name. Treat that as "no
    # relations" instead of crashing.
    try:
        relations = [self.escaped_names(row) for row in data]
    except Exception:
        relations = []

    # dbmetadata[kind][$schema_name][$relation_name] holds the list of
    # column names; it starts out as just the asterisk.
    metadata = self.dbmetadata[kind]
    for row in relations:
        relname = row[0]
        try:
            metadata[self.dbname][relname] = ["*"]
        except KeyError:
            _logger.error(
                "%r %r listed in unrecognized schema %r",
                kind,
                relname,
                self.dbname,
            )
        self.all_completions.add(relname)
|
|
|
|
def extend_columns(self, column_data, kind):
    """Extend column metadata

    :param column_data: list of (rel_name, column_name) tuples
    :param kind: either 'tables' or 'views'
    :return:
    """
    # 'column_data' is a generator object. It can throw an exception while
    # being consumed. This could happen if the user has launched the app
    # without specifying a database name. This exception must be handled to
    # prevent crashing.
    try:
        column_data = [self.escaped_names(d) for d in column_data]
    except Exception:
        column_data = []

    metadata = self.dbmetadata[kind]
    for relname, column in column_data:
        try:
            metadata[self.dbname][relname].append(column)
        except KeyError:
            # The schema or relation was never registered. Log and carry on
            # rather than aborting the whole refresh, in the same way
            # extend_relations treats an unknown schema.
            _logger.error(
                "%r column %r listed for unrecognized relation %r in schema %r",
                kind,
                column,
                relname,
                self.dbname,
            )
        self.all_completions.add(column)
|
|
|
|
def extend_functions(self, func_data):
    """Record user-defined function names for the current database.

    :param func_data: iterable of rows whose first item is the function name
    :return:
    """
    # 'func_data' is a generator object. It can throw an exception while
    # being consumed. This could happen if the user has launched the app
    # without specifying a database name. This exception must be handled to
    # prevent crashing.
    try:
        func_data = [self.escaped_names(d) for d in func_data]
    except Exception:
        func_data = []

    # dbmetadata['functions'][$schema_name][$function_name] should return
    # function metadata.
    metadata = self.dbmetadata["functions"]

    for func in func_data:
        try:
            metadata[self.dbname][func[0]] = None
        except KeyError:
            # The current schema was never registered. Log and carry on
            # rather than aborting the refresh, in the same way
            # extend_relations treats an unknown schema.
            _logger.error(
                "function %r listed in unrecognized schema %r",
                func[0],
                self.dbname,
            )
        self.all_completions.add(func[0])
|
|
|
|
def set_dbname(self, dbname):
    """Remember *dbname* as the current database.

    The stored name is the schema key used when no explicit schema is
    given to the metadata lookups.

    :param dbname: name of the database now in use
    """
    self.dbname = dbname
|
|
|
|
def reset_completions(self):
    """Forget all database-specific state.

    Clears the database list, the current database name and the metadata
    dictionaries, and rebuilds ``all_completions`` from the keyword and
    function lists only.
    """
    self.dbmetadata = {"tables": {}, "views": {}, "functions": {}}
    self.dbname = ""
    self.databases = []
    self.all_completions = set(self.keywords).union(self.functions)
|
|
|
|
@staticmethod
def find_matches(
    text,
    collection,
    start_only=False,
    fuzzy=True,
    casing=None,
    punctuations="most_punctuations",
):
    """Find completion matches for the given text.

    The last word of ``text`` (as split by ``last_word`` with the given
    ``punctuations`` setting) is compared case-insensitively against every
    item of ``collection``.

    With ``fuzzy`` set, an item matches when the characters of the word
    appear in it in order, possibly with other characters in between.
    Otherwise the word must appear verbatim; ``start_only`` additionally
    restricts that to the very beginning of the item.

    ``casing`` may be ``None`` (keep the item as is), ``"upper"``,
    ``"lower"`` or ``"auto"`` (lower-case when the last typed character is
    lower-case, upper-case otherwise).

    Returns a generator of prompt_toolkit Completion instances, ordered by
    match length, then match position, then item.
    """
    last = last_word(text, include=punctuations)
    needle = last.lower()

    if fuzzy:
        # Lazily match each typed character in order: "a.*?b.*?c".
        pattern = compile("(%s)" % ".*?".join(map(escape, needle)))
        hits = []
        for candidate in sorted(collection):
            found = pattern.search(candidate.lower())
            if found:
                hits.append((len(found.group()), found.start(), candidate))
    else:
        # Limiting the search window to len(needle) forces a prefix match.
        window_end = len(needle) if start_only else None
        located = (
            (candidate.lower().find(needle, 0, window_end), candidate)
            for candidate in sorted(collection)
        )
        hits = [
            (len(needle), position, candidate)
            for position, candidate in located
            if position >= 0
        ]

    if casing == "auto":
        casing = "lower" if last and last[-1].islower() else "upper"

    def render(word):
        if casing is None:
            return word
        return word.upper() if casing == "upper" else word.lower()

    return (
        Completion(render(candidate), -len(needle))
        for _length, _position, candidate in sorted(hits)
    )
|
|
|
|
def get_completions(self, document, complete_event):
    """Return the completions that apply at the cursor position.

    ``suggest_type`` decides which kinds of objects make sense at the
    cursor; every suggestion is then resolved against the matching source
    (metadata, keyword list, special commands, ...) via ``find_matches``.

    :param document: prompt_toolkit document being edited
    :param complete_event: not used by this implementation
    :return: list of Completion instances
    """
    word_before_cursor = document.get_word_before_cursor(WORD=True)
    completions = []
    suggestions = suggest_type(document.text, document.text_before_cursor)

    # Suggestion types that are answered from the per-schema metadata.
    schema_object_kinds = {"table": "tables", "view": "views"}

    for suggestion in suggestions:
        kind = suggestion["type"]
        _logger.debug("Suggestion type: %r", kind)

        if kind == "column":
            tables = suggestion["tables"]
            _logger.debug("Completion column scope: %r", tables)
            scoped_cols = self.populate_scoped_cols(tables)
            if suggestion.get("drop_unique"):
                # drop_unique is used for 'tbl1 JOIN tbl2 USING (...',
                # which should suggest only columns that appear in more
                # than one table.
                occurrences = Counter(scoped_cols)
                scoped_cols = [
                    col
                    for col, seen in occurrences.items()
                    if seen > 1 and col != "*"
                ]
            completions.extend(self.find_matches(word_before_cursor, scoped_cols))

        elif kind == "function":
            # User-defined functions are matched fuzzily.
            funcs = self.populate_schema_objects(suggestion["schema"], "functions")
            completions.extend(self.find_matches(word_before_cursor, funcs))

            # Hard-coded functions are matched by prefix, and only without
            # a schema qualifier: a qualifier most likely denotes a table,
            # e.g. SELECT * FROM users u WHERE u.
            if not suggestion["schema"]:
                completions.extend(
                    self.find_matches(
                        word_before_cursor,
                        self.functions,
                        start_only=True,
                        fuzzy=False,
                        casing=self.keyword_casing,
                    )
                )

        elif kind in schema_object_kinds:
            relations = self.populate_schema_objects(
                suggestion["schema"], schema_object_kinds[kind]
            )
            completions.extend(self.find_matches(word_before_cursor, relations))

        elif kind == "alias":
            completions.extend(
                self.find_matches(word_before_cursor, suggestion["aliases"])
            )

        elif kind == "database":
            completions.extend(
                self.find_matches(word_before_cursor, self.databases)
            )

        elif kind == "keyword":
            completions.extend(
                self.find_matches(
                    word_before_cursor,
                    self.keywords,
                    start_only=True,
                    fuzzy=False,
                    casing=self.keyword_casing,
                    punctuations="many_punctuations",
                )
            )

        elif kind == "special":
            completions.extend(
                self.find_matches(
                    word_before_cursor,
                    self.special_commands,
                    start_only=True,
                    fuzzy=False,
                    punctuations="many_punctuations",
                )
            )

        elif kind == "favoritequery":
            completions.extend(
                self.find_matches(
                    word_before_cursor,
                    favoritequeries.list(),
                    start_only=False,
                    fuzzy=True,
                )
            )

        elif kind == "table_format":
            completions.extend(
                self.find_matches(
                    word_before_cursor,
                    self.table_formats,
                    start_only=True,
                    fuzzy=False,
                )
            )

        elif kind == "file_name":
            completions.extend(self.find_files(word_before_cursor))

    return completions
|
|
|
|
def find_files(self, word):
    """Yield completions for directory or file names matching *word*.

    :param word: partial path typed so far
    :return: generator of Completion instances, in sorted path order
    """
    # Only the last path component and the cursor offset are needed here.
    _base_path, last_path, position = parse_path(word)

    for candidate in sorted(suggest_path(word)):
        completed = complete_path(candidate, last_path)
        if not completed:
            continue
        yield Completion(completed, position)
|
|
|
|
def populate_scoped_cols(self, scoped_tbls):
    """Find all columns in a set of scoped_tables

    :param scoped_tbls: list of (schema, table, alias) tuples
    :return: list of column names
    """
    columns = []
    meta = self.dbmetadata

    def lookup(kind, schema, name):
        # Column list for schema.name in the given category, or None.
        try:
            return meta[kind][schema][name]
        except KeyError:
            return None

    for tbl in scoped_tbls:
        # A fully qualified schema.relname reference or default_schema
        # DO NOT escape schema names.
        schema = tbl[0] or self.dbname
        relname = tbl[1]

        # Relation names are stored escaped, so a name that needs quoting
        # is only found under its escaped form. Try both spellings.
        candidates = (relname, self.escape_name(relname))

        # We don't know if schema.relname is a table or view. Since tables
        # and views cannot share the same name, check tables first and
        # stop at the first hit.
        found = None
        for kind in ("tables", "views"):
            for name in candidates:
                found = lookup(kind, schema, name)
                if found is not None:
                    break
            if found is not None:
                break

        if found is not None:
            columns.extend(found)

    return columns
|
|
|
|
def populate_schema_objects(self, schema, obj_type):
    """Return the names of one kind of object in a schema.

    :param schema: schema name; a falsy value selects the current database
    :param obj_type: key into ``dbmetadata``, e.g. 'tables' or 'functions'
    :return: the names registered for that schema, or an empty list when
        the schema is unknown
    """
    by_schema = self.dbmetadata[obj_type]
    target = schema or self.dbname

    try:
        return by_schema[target].keys()
    except KeyError:
        # schema doesn't exist
        return []
|