import json
|
|
import logging
|
|
import re
|
|
from itertools import count, repeat, chain
|
|
import operator
|
|
from collections import namedtuple, defaultdict, OrderedDict
|
|
from cli_helpers.tabular_output import TabularOutputFormatter
|
|
from pgspecial.namedqueries import NamedQueries
|
|
from prompt_toolkit.completion import Completer, Completion, PathCompleter
|
|
from prompt_toolkit.document import Document
|
|
from .packages.sqlcompletion import (
|
|
FromClauseItem,
|
|
suggest_type,
|
|
Special,
|
|
Database,
|
|
Schema,
|
|
Table,
|
|
TableFormat,
|
|
Function,
|
|
Column,
|
|
View,
|
|
Keyword,
|
|
NamedQuery,
|
|
Datatype,
|
|
Alias,
|
|
Path,
|
|
JoinCondition,
|
|
Join,
|
|
)
|
|
from .packages.parseutils.meta import ColumnMetadata, ForeignKey
|
|
from .packages.parseutils.utils import last_word
|
|
from .packages.parseutils.tables import TableReference
|
|
from .packages.pgliterals.main import get_literals
|
|
from .packages.prioritization import PrevalenceCounter
|
|
from .config import load_config, config_location
|
|
|
|
_logger = logging.getLogger(__name__)

# A completion paired with the sort key used to rank it against other matches.
Match = namedtuple("Match", ["completion", "priority"])

# name: object name; schema: optional schema qualifier; meta: optional extra
# metadata (e.g. function metadata used for building arg-list suffixes).
_SchemaObject = namedtuple("SchemaObject", "name schema meta")


def SchemaObject(name, schema=None, meta=None):
    """Factory supplying default values for the SchemaObject namedtuple."""
    return _SchemaObject(name, schema, meta)
|
|
|
|
|
|
_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")


def Candidate(
    completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
):
    """Build a completion candidate.

    ``synonyms`` and ``display`` fall back to the completion text itself when
    omitted (or empty).
    """
    if not synonyms:
        synonyms = [completion]
    if not display:
        display = completion
    return _Candidate(completion, prio, meta, synonyms, prio2, display)
|
|
|
|
|
|
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")


def normalize_ref(ref):
    """Normalize a table reference for comparison.

    Already-quoted refs are returned unchanged; unquoted refs are lowercased
    and wrapped in double quotes, mirroring PostgreSQL's case-folding of
    unquoted identifiers.  (Converted from a lambda assignment - PEP 8 E731.)
    """
    return ref if ref[0] == '"' else '"' + ref.lower() + '"'
|
|
|
|
|
|
def generate_alias(tbl, alias_map=None):
    """Create an alias for table *tbl*.

    A mapping in *alias_map* wins outright.  Otherwise the alias is the
    concatenation of the table name's upper-case letters; if there are none,
    it is the first letter plus every letter that follows an underscore.

    param tbl - unescaped name of the table to alias
    """
    if alias_map and tbl in alias_map:
        return alias_map[tbl]
    upper_letters = [ch for ch in tbl if ch.isupper()]
    if upper_letters:
        return "".join(upper_letters)
    # Pair each character with the one before it (a virtual leading "_"
    # makes the first character count as underscore-preceded).
    initials = [
        ch
        for ch, preceding in zip(tbl, "_" + tbl)
        if preceding == "_" and ch != "_"
    ]
    return "".join(initials)
|
|
|
|
|
|
class InvalidMapFile(ValueError):
    """Raised when an alias-map file is missing or is not valid JSON."""


def load_alias_map_file(path):
    """Load and return the table-alias map stored as JSON at *path*.

    :param path: filesystem path of the JSON alias-map file
    :raises InvalidMapFile: if the file does not exist or contains invalid
        JSON.  The original exception is chained (``from err``) so the root
        cause stays visible in tracebacks.
    """
    try:
        with open(path) as fo:
            return json.load(fo)
    except FileNotFoundError as err:
        raise InvalidMapFile(
            f"Cannot read alias_map_file - {err.filename} does not exist"
        ) from err
    except json.JSONDecodeError as err:
        raise InvalidMapFile(
            f"Cannot read alias_map_file - {path} is not valid json"
        ) from err
|
|
|
|
|
|
class PGCompleter(Completer):
    """Smart SQL completion engine backed by live database metadata."""

    # keywords_tree: A dict mapping keywords to well known following keywords.
    # e.g. 'CREATE': ['TABLE', 'USER', ...],
    keywords_tree = get_literals("keywords", type_=dict)
    # Flat tuple of every known keyword (tree keys plus all follow-ups).
    keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
    # Hardcoded function and datatype names shipped with the package.
    functions = get_literals("functions")
    datatypes = get_literals("datatypes")
    # Identifiers that must be double-quoted when used as object names.
    reserved_words = set(get_literals("reserved"))
|
|
|
|
    def __init__(self, smart_completion=True, pgspecial=None, settings=None):
        """Create a completer.

        :param smart_completion: when False, get_completions falls back to
            naive prefix matching over all known words
        :param pgspecial: object exposing backslash-command metadata
            (``.commands``); may be None
        :param settings: dict of user configuration (casing, aliasing,
            arg-list styles, etc.)
        """
        super().__init__()
        self.smart_completion = smart_completion
        self.pgspecial = pgspecial
        self.prioritizer = PrevalenceCounter()
        settings = settings or {}
        # Templates used by _arg_list/_format_arg to render function args.
        self.signature_arg_style = settings.get(
            "signature_arg_style", "{arg_name} {arg_type}"
        )
        self.call_arg_style = settings.get(
            "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
        )
        self.call_arg_display_style = settings.get(
            "call_arg_display_style", "{arg_name}"
        )
        self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
        self.search_path_filter = settings.get("search_path_filter")
        self.generate_aliases = settings.get("generate_aliases")
        alias_map_file = settings.get("alias_map_file")
        if alias_map_file is not None:
            # May raise InvalidMapFile; callers surface it as a config error.
            self.alias_map = load_alias_map_file(alias_map_file)
        else:
            self.alias_map = None
        self.casing_file = settings.get("casing_file")
        # Columns whose default matches any of these patterns are skipped
        # when expanding "*" in an INSERT context (e.g. now()/nextval cols).
        self.insert_col_skip_patterns = [
            re.compile(pattern)
            for pattern in settings.get(
                "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
            )
        ]
        self.generate_casing_file = settings.get("generate_casing_file")
        self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
        self.asterisk_column_order = settings.get(
            "asterisk_column_order", "table_order"
        )

        keyword_casing = settings.get("keyword_casing", "upper").lower()
        if keyword_casing not in ("upper", "lower", "auto"):
            # Unrecognized config value: fall back to a safe default.
            keyword_casing = "upper"
        self.keyword_casing = keyword_casing
        # Names matching this pattern need no quoting (see escape_name).
        self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

        self.databases = []
        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
        self.search_path = []
        # {lowercased_name: PreferredCasing}, filled by extend_casing.
        self.casing = {}

        self.all_completions = set(self.keywords + self.functions)
|
|
|
|
def escape_name(self, name):
|
|
if name and (
|
|
(not self.name_pattern.match(name))
|
|
or (name.upper() in self.reserved_words)
|
|
or (name.upper() in self.functions)
|
|
):
|
|
name = '"%s"' % name
|
|
|
|
return name
|
|
|
|
def escape_schema(self, name):
|
|
return "'{}'".format(self.unescape_name(name))
|
|
|
|
def unescape_name(self, name):
|
|
"""Unquote a string."""
|
|
if name and name[0] == '"' and name[-1] == '"':
|
|
name = name[1:-1]
|
|
|
|
return name
|
|
|
|
def escaped_names(self, names):
|
|
return [self.escape_name(name) for name in names]
|
|
|
|
def extend_database_names(self, databases):
|
|
self.databases.extend(databases)
|
|
|
|
def extend_keywords(self, additional_keywords):
|
|
self.keywords.extend(additional_keywords)
|
|
self.all_completions.update(additional_keywords)
|
|
|
|
def extend_schemata(self, schemata):
|
|
# schemata is a list of schema names
|
|
schemata = self.escaped_names(schemata)
|
|
metadata = self.dbmetadata["tables"]
|
|
for schema in schemata:
|
|
metadata[schema] = {}
|
|
|
|
# dbmetadata.values() are the 'tables' and 'functions' dicts
|
|
for metadata in self.dbmetadata.values():
|
|
for schema in schemata:
|
|
metadata[schema] = {}
|
|
|
|
self.all_completions.update(schemata)
|
|
|
|
def extend_casing(self, words):
|
|
"""extend casing data
|
|
|
|
:return:
|
|
"""
|
|
# casing should be a dict {lowercasename:PreferredCasingName}
|
|
self.casing = {word.lower(): word for word in words}
|
|
|
|
def extend_relations(self, data, kind):
|
|
"""extend metadata for tables or views.
|
|
|
|
:param data: list of (schema_name, rel_name) tuples
|
|
:param kind: either 'tables' or 'views'
|
|
|
|
:return:
|
|
|
|
"""
|
|
|
|
data = [self.escaped_names(d) for d in data]
|
|
|
|
# dbmetadata['tables']['schema_name']['table_name'] should be an
|
|
# OrderedDict {column_name:ColumnMetaData}.
|
|
metadata = self.dbmetadata[kind]
|
|
for schema, relname in data:
|
|
try:
|
|
metadata[schema][relname] = OrderedDict()
|
|
except KeyError:
|
|
_logger.error(
|
|
"%r %r listed in unrecognized schema %r", kind, relname, schema
|
|
)
|
|
self.all_completions.add(relname)
|
|
|
|
def extend_columns(self, column_data, kind):
|
|
"""extend column metadata.
|
|
|
|
:param column_data: list of (schema_name, rel_name, column_name,
|
|
column_type, has_default, default) tuples
|
|
:param kind: either 'tables' or 'views'
|
|
|
|
:return:
|
|
|
|
"""
|
|
metadata = self.dbmetadata[kind]
|
|
for schema, relname, colname, datatype, has_default, default in column_data:
|
|
(schema, relname, colname) = self.escaped_names([schema, relname, colname])
|
|
column = ColumnMetadata(
|
|
name=colname,
|
|
datatype=datatype,
|
|
has_default=has_default,
|
|
default=default,
|
|
)
|
|
metadata[schema][relname][colname] = column
|
|
self.all_completions.add(colname)
|
|
|
|
def extend_functions(self, func_data):
|
|
# func_data is a list of function metadata namedtuples
|
|
|
|
# dbmetadata['schema_name']['functions']['function_name'] should return
|
|
# the function metadata namedtuple for the corresponding function
|
|
metadata = self.dbmetadata["functions"]
|
|
|
|
for f in func_data:
|
|
schema, func = self.escaped_names([f.schema_name, f.func_name])
|
|
|
|
if func in metadata[schema]:
|
|
metadata[schema][func].append(f)
|
|
else:
|
|
metadata[schema][func] = [f]
|
|
|
|
self.all_completions.add(func)
|
|
|
|
self._refresh_arg_list_cache()
|
|
|
|
def _refresh_arg_list_cache(self):
|
|
# We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
|
|
# This is used when suggesting functions, to avoid the latency that would result
|
|
# if we'd recalculate the arg lists each time we suggest functions (in large DBs)
|
|
self._arg_list_cache = {
|
|
usage: {
|
|
meta: self._arg_list(meta, usage)
|
|
for sch, funcs in self.dbmetadata["functions"].items()
|
|
for func, metas in funcs.items()
|
|
for meta in metas
|
|
}
|
|
for usage in ("call", "call_display", "signature")
|
|
}
|
|
|
|
def extend_foreignkeys(self, fk_data):
|
|
# fk_data is a list of ForeignKey namedtuples, with fields
|
|
# parentschema, childschema, parenttable, childtable,
|
|
# parentcolumns, childcolumns
|
|
|
|
# These are added as a list of ForeignKey namedtuples to the
|
|
# ColumnMetadata namedtuple for both the child and parent
|
|
meta = self.dbmetadata["tables"]
|
|
|
|
for fk in fk_data:
|
|
e = self.escaped_names
|
|
parentschema, childschema = e([fk.parentschema, fk.childschema])
|
|
parenttable, childtable = e([fk.parenttable, fk.childtable])
|
|
childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
|
|
childcolmeta = meta[childschema][childtable][childcol]
|
|
parcolmeta = meta[parentschema][parenttable][parcol]
|
|
fk = ForeignKey(
|
|
parentschema, parenttable, parcol, childschema, childtable, childcol
|
|
)
|
|
childcolmeta.foreignkeys.append(fk)
|
|
parcolmeta.foreignkeys.append(fk)
|
|
|
|
def extend_datatypes(self, type_data):
|
|
# dbmetadata['datatypes'][schema_name][type_name] should store type
|
|
# metadata, such as composite type field names. Currently, we're not
|
|
# storing any metadata beyond typename, so just store None
|
|
meta = self.dbmetadata["datatypes"]
|
|
|
|
for t in type_data:
|
|
schema, type_name = self.escaped_names(t)
|
|
meta[schema][type_name] = None
|
|
self.all_completions.add(type_name)
|
|
|
|
def extend_query_history(self, text, is_init=False):
|
|
if is_init:
|
|
# During completer initialization, only load keyword preferences,
|
|
# not names
|
|
self.prioritizer.update_keywords(text)
|
|
else:
|
|
self.prioritizer.update(text)
|
|
|
|
def set_search_path(self, search_path):
|
|
self.search_path = self.escaped_names(search_path)
|
|
|
|
    def reset_completions(self):
        """Discard all collected metadata, returning to the initial state.

        Called before metadata is re-fetched (e.g. on refresh/reconnect).
        """
        self.databases = []
        self.special_commands = []
        self.search_path = []
        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
        # Back to just the hardcoded keywords and function names.
        self.all_completions = set(self.keywords + self.functions)
|
|
|
|
    def find_matches(self, text, collection, mode="fuzzy", meta=None):
        """Find completion matches for the given text.

        Given the user's input text and a collection of available
        completions, find completions matching the last word of the
        text.

        `collection` can be either a list of strings or a list of Candidate
        namedtuples.
        `mode` can be either 'fuzzy', or 'strict'
            'fuzzy': fuzzy matching, ties broken by name prevalence
            'strict': start-only matching, ties broken by keyword prevalence

        Returns a list of Match namedtuples (prompt_toolkit Completion plus
        a sort-key priority tuple) for any matches found in the collection.
        """
        if not collection:
            return []
        # Suggestion-type ranking; later entries win when priorities tie.
        prio_order = [
            "keyword",
            "function",
            "view",
            "table",
            "datatype",
            "database",
            "schema",
            "column",
            "table alias",
            "join",
            "name join",
            "fk join",
            "table format",
        ]
        type_priority = prio_order.index(meta) if meta in prio_order else -1
        text = last_word(text, include="most_punctuations").lower()
        text_len = len(text)

        if text and text[0] == '"':
            # text starts with double quote; user is manually escaping a name
            # Match on everything that follows the double-quote. Note that
            # text_len is calculated before removing the quote, so the
            # Completion.position value is correct
            text = text[1:]

        if mode == "fuzzy":
            fuzzy = True
            priority_func = self.prioritizer.name_count
        else:
            fuzzy = False
            priority_func = self.prioritizer.keyword_count

        # Construct a `_match` function for either fuzzy or non-fuzzy matching
        # The match function returns a 2-tuple used for sorting the matches,
        # or None if the item doesn't match
        # Note: higher priority values mean more important, so use negative
        # signs to flip the direction of the tuple
        if fuzzy:
            # Subsequence regex: each typed char may be separated by anything.
            regex = ".*?".join(map(re.escape, text))
            pat = re.compile("(%s)" % regex)

            def _match(item):
                if item.lower()[: len(text) + 1] in (text, text + " "):
                    # Exact match of first word in suggestion
                    # This is to get exact alias matches to the top
                    # E.g. for input `e`, 'Entries E' should be on top
                    # (before e.g. `EndUsers EU`)
                    return float("Infinity"), -1
                r = pat.search(self.unescape_name(item.lower()))
                if r:
                    # Shorter match group and earlier start rank higher.
                    return -len(r.group()), -r.start()

        else:
            match_end_limit = len(text)

            def _match(item):
                match_point = item.lower().find(text, 0, match_end_limit)
                if match_point >= 0:
                    # Use negative infinity to force keywords to sort after all
                    # fuzzy matches
                    return -float("Infinity"), -match_point

        matches = []
        for cand in collection:
            if isinstance(cand, _Candidate):
                item, prio, display_meta, synonyms, prio2, display = cand
                if display_meta is None:
                    display_meta = meta
                syn_matches = (_match(x) for x in synonyms)
                # Nones need to be removed to avoid max() crashing in Python 3
                syn_matches = [m for m in syn_matches if m]
                sort_key = max(syn_matches) if syn_matches else None
            else:
                # Plain string candidate: defaults for everything else.
                item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
                sort_key = _match(cand)

            if sort_key:
                if display_meta and len(display_meta) > 50:
                    # Truncate meta-text to 50 characters, if necessary
                    display_meta = display_meta[:47] + "..."

                # Lexical order of items in the collection, used for
                # tiebreaking items with the same match group length and start
                # position. Since we use *higher* priority to mean "more
                # important," we use -ord(c) to prioritize "aa" > "ab" and end
                # with 1 to prioritize shorter strings (ie "user" > "users").
                # We first do a case-insensitive sort and then a
                # case-sensitive one as a tie breaker.
                # We also use the unescape_name to make sure quoted names have
                # the same priority as unquoted names.
                lexical_priority = (
                    tuple(
                        0 if c in " _" else -ord(c)
                        for c in self.unescape_name(item.lower())
                    )
                    + (1,)
                    + tuple(c for c in item)
                )

                item = self.case(item)
                display = self.case(display)
                priority = (
                    sort_key,
                    type_priority,
                    prio,
                    priority_func(item),
                    prio2,
                    lexical_priority,
                )
                matches.append(
                    Match(
                        completion=Completion(
                            text=item,
                            start_position=-text_len,
                            display_meta=display_meta,
                            display=display,
                        ),
                        priority=priority,
                    )
                )
        return matches
|
|
|
|
def case(self, word):
|
|
return self.casing.get(word, word)
|
|
|
|
    def get_completions(self, document, complete_event, smart_completion=None):
        """prompt_toolkit entry point: return completions for *document*.

        With smart completion disabled, every known word that starts with
        the word before the cursor is offered, alphabetically.  Otherwise
        suggest_type parses the SQL context and each resulting suggestion is
        dispatched to its matcher method.
        """
        word_before_cursor = document.get_word_before_cursor(WORD=True)

        if smart_completion is None:
            smart_completion = self.smart_completion

        # If smart_completion is off then match any word that starts with
        # 'word_before_cursor'.
        if not smart_completion:
            matches = self.find_matches(
                word_before_cursor, self.all_completions, mode="strict"
            )
            completions = [m.completion for m in matches]
            return sorted(completions, key=operator.attrgetter("text"))

        matches = []
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:
            suggestion_type = type(suggestion)
            _logger.debug("Suggestion type: %r", suggestion_type)

            # Map suggestion type to method
            # e.g. 'table' -> self.get_table_matches
            # (matchers are stored unbound, so self is passed explicitly)
            matcher = self.suggestion_matchers[suggestion_type]
            matches.extend(matcher(self, suggestion, word_before_cursor))

        # Sort matches so highest priorities are first
        matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)

        return [m.completion for m in matches]
|
|
|
|
    def get_column_matches(self, suggestion, word_before_cursor):
        """Suggest column names for the tables in scope.

        Handles optional table qualification ("tbl.col") per the
        qualify_columns setting, expansion of "*" into an explicit column
        list, and the INSERT context where defaulted columns matching
        insert_col_skip_patterns are omitted.
        """
        tables = suggestion.table_refs
        do_qualify = (
            suggestion.qualifiable
            and {
                "always": True,
                "never": False,
                "if_more_than_one_table": len(tables) > 1,
            }[self.qualify_columns]
        )
        qualify = lambda col, tbl: (
            (tbl + "." + self.case(col)) if do_qualify else self.case(col)
        )
        _logger.debug("Completion column scope: %r", tables)
        scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)

        def make_cand(name, ref):
            # Fuzzy-match against both the column name and its alias form.
            synonyms = (name, generate_alias(self.case(name)))
            return Candidate(qualify(name, ref), 0, "column", synonyms)

        def flat_cols():
            return [
                make_cand(c.name, t.ref)
                for t, cols in scoped_cols.items()
                for c in cols
            ]

        if suggestion.require_last_table:
            # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
            # suggest only columns that appear in the last table and one more
            ltbl = tables[-1].ref
            other_tbl_cols = {
                c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
            }
            scoped_cols = {
                t: [col for col in cols if col.name in other_tbl_cols]
                for t, cols in scoped_cols.items()
                if t.ref == ltbl
            }
        lastword = last_word(word_before_cursor, include="most_punctuations")
        if lastword == "*":
            if suggestion.context == "insert":

                def filter(col):
                    # NOTE: shadows the builtin `filter` within this scope.
                    # Keep columns without defaults, or whose default does
                    # not match any skip pattern (e.g. now()/nextval).
                    if not col.has_default:
                        return True
                    return not any(
                        p.match(col.default) for p in self.insert_col_skip_patterns
                    )

                scoped_cols = {
                    t: [col for col in cols if filter(col)]
                    for t, cols in scoped_cols.items()
                }
            if self.asterisk_column_order == "alphabetic":
                for cols in scoped_cols.values():
                    cols.sort(key=operator.attrgetter("name"))
            if (
                lastword != word_before_cursor
                and len(tables) == 1
                and word_before_cursor[-len(lastword) - 1] == "."
            ):
                # User typed x.*; replicate "x." for all columns except the
                # first, which gets the original (as we only replace the "*"")
                sep = ", " + word_before_cursor[:-1]
                collist = sep.join(self.case(c.completion) for c in flat_cols())
            else:
                collist = ", ".join(
                    qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
                )

            # A single completion that replaces "*" with the full column list.
            return [
                Match(
                    completion=Completion(
                        collist, -1, display_meta="columns", display="*"
                    ),
                    priority=(1, 1, 1),
                )
            ]

        return self.find_matches(word_before_cursor, flat_cols(), meta="column")
|
|
|
|
    def alias(self, tbl, tbls):
        """Generate a unique table alias
        tbl - name of the table to alias, quoted if it needs to be
        tbls - TableReference iterable of tables already in query
        """
        tbl = self.case(tbl)
        tbls = {normalize_ref(t.ref) for t in tbls}
        if self.generate_aliases:
            tbl = generate_alias(self.unescape_name(tbl))
        if normalize_ref(tbl) not in tbls:
            return tbl
        elif tbl[0] == '"':
            # Quoted name: the disambiguating counter goes inside the quotes.
            aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
        else:
            aliases = (tbl + str(i) for i in count(2))
        # First numbered variant (tbl2, tbl3, ...) not already used in query.
        return next(a for a in aliases if normalize_ref(a) not in tbls)
|
|
|
|
    def get_join_matches(self, suggestion, word_before_cursor):
        """Suggest complete JOIN clauses ("tbl ON tbl.col = other.col")
        derived from foreign keys of the tables already in the query."""
        tbls = suggestion.table_refs
        cols = self.populate_scoped_cols(tbls)
        # Set up some data structures for efficient access
        qualified = {normalize_ref(t.ref): t.schema for t in tbls}
        ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
        refs = {normalize_ref(t.ref) for t in tbls}
        other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
        joins = []
        # Iterate over FKs in existing tables to find potential joins
        fks = (
            (fk, rtbl, rcol)
            for rtbl, rcols in cols.items()
            for rcol in rcols
            for fk in rcol.foreignkeys
        )
        col = namedtuple("col", "schema tbl col")
        for fk, rtbl, rcol in fks:
            right = col(rtbl.schema, rtbl.name, rcol.name)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            # The side of the FK not already in the query is the join target.
            left = child if parent == right else parent
            if suggestion.schema and left.schema != suggestion.schema:
                continue
            c = self.case
            if self.generate_aliases or normalize_ref(left.tbl) in refs:
                lref = self.alias(left.tbl, suggestion.table_refs)
                join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
                )
            else:
                join = "{0} ON {0}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col)
                )
            alias = generate_alias(self.case(left.tbl))
            synonyms = [
                join,
                "{0} ON {0}.{1} = {2}.{3}".format(
                    alias, c(left.col), rtbl.ref, c(right.col)
                ),
            ]
            # Schema-qualify if (1) new table in same schema as old, and old
            # is schema-qualified, or (2) new in other schema, except public
            if not suggestion.schema and (
                qualified[normalize_ref(rtbl.ref)]
                and left.schema == right.schema
                or left.schema not in (right.schema, "public")
            ):
                join = left.schema + "." + join
            # Joins against tables closer to the cursor, and against tables
            # not yet in the query, rank higher.
            prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
                0 if (left.schema, left.tbl) in other_tbls else 1
            )
            joins.append(Candidate(join, prio, "join", synonyms=synonyms))

        return self.find_matches(word_before_cursor, joins, meta="join")
|
|
|
|
    def get_join_condition_matches(self, suggestion, word_before_cursor):
        """Suggest ON-clause conditions: foreign-key joins first, then
        joins on identically named columns of matching type."""
        col = namedtuple("col", "schema tbl col")
        tbls = self.populate_scoped_cols(suggestion.table_refs).items
        cols = [(t, c) for t, cs in tbls() for c in cs]
        try:
            # "Left" table: the qualifier the user typed, else the last table.
            lref = (suggestion.parent or suggestion.table_refs[-1]).ref
            ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
        except IndexError:  # The user typed an incorrect table qualifier
            return []
        conds, found_conds = [], set()

        def add_cond(lcol, rcol, rref, prio, meta):
            # Register a join condition once, deduplicated by rendered text.
            prefix = "" if suggestion.parent else ltbl.ref + "."
            case = self.case
            cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
            if cond not in found_conds:
                found_conds.add(cond)
                conds.append(Candidate(cond, prio + ref_prio[rref], meta))

        def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
            d = defaultdict(list)
            for pair in pairs:
                d[pair[0]].append(pair[1])
            return d

        # Tables that are closer to the cursor get higher prio
        ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
        # Map (schema, table, col) to tables
        coldict = list_dict(
            ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
        )
        # For each fk from the left table, generate a join condition if
        # the other table is also in the scope
        fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
        for fk, lcol in fks:
            left = col(ltbl.schema, ltbl.name, lcol)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left, right = (child, par) if left == child else (par, child)
            for rtbl in coldict[right]:
                add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
        # For name matching, use a {(colname, coltype): TableReference} dict
        coltyp = namedtuple("coltyp", "name datatype")
        col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
        # Find all name-match join conditions
        for c in (coltyp(c.name, c.datatype) for c in lcols):
            for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
                # Integer-typed name matches are likelier to be real keys.
                prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
                add_cond(c.name, c.name, rtbl.ref, prio, "name join")

        return self.find_matches(word_before_cursor, conds, meta="join")
|
|
|
|
    def get_function_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest function names, filtered by usage context.

        In a FROM clause only set-returning-style candidates are offered
        (no aggregates/window functions); elsewhere aliasing is disabled.
        """
        if suggestion.usage == "from":
            # Only suggest functions allowed in FROM clause

            def filt(f):
                return (
                    not f.is_aggregate
                    and not f.is_window
                    and not f.is_extension
                    and (
                        f.is_public
                        or f.schema_name in self.search_path
                        or f.schema_name == suggestion.schema
                    )
                )

        else:
            alias = False

            def filt(f):
                return not f.is_extension and (
                    f.is_public or f.schema_name == suggestion.schema
                )

        # Which arg-list suffix to render: full signature, none (special
        # commands), or a call-style suffix by default.
        arg_mode = {"signature": "signature", "special": None}.get(
            suggestion.usage, "call"
        )

        # Function overloading means we may have multiple functions of the same
        # name at this point, so keep unique names only
        all_functions = self.populate_functions(suggestion.schema, filt)
        funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}

        matches = self.find_matches(word_before_cursor, funcs, meta="function")

        if not suggestion.schema and not suggestion.usage:
            # also suggest hardcoded functions using startswith matching
            predefined_funcs = self.find_matches(
                word_before_cursor, self.functions, mode="strict", meta="function"
            )
            matches.extend(predefined_funcs)

        return matches
|
|
|
|
def get_schema_matches(self, suggestion, word_before_cursor):
|
|
schema_names = self.dbmetadata["tables"].keys()
|
|
|
|
# Unless we're sure the user really wants them, hide schema names
|
|
# starting with pg_, which are mostly temporary schemas
|
|
if not word_before_cursor.startswith("pg_"):
|
|
schema_names = [s for s in schema_names if not s.startswith("pg_")]
|
|
|
|
if suggestion.quoted:
|
|
schema_names = [self.escape_schema(s) for s in schema_names]
|
|
|
|
return self.find_matches(word_before_cursor, schema_names, meta="schema")
|
|
|
|
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
|
|
alias = self.generate_aliases
|
|
s = suggestion
|
|
t_sug = Table(s.schema, s.table_refs, s.local_tables)
|
|
v_sug = View(s.schema, s.table_refs)
|
|
f_sug = Function(s.schema, s.table_refs, usage="from")
|
|
return (
|
|
self.get_table_matches(t_sug, word_before_cursor, alias)
|
|
+ self.get_view_matches(v_sug, word_before_cursor, alias)
|
|
+ self.get_function_matches(f_sug, word_before_cursor, alias)
|
|
)
|
|
|
|
def _arg_list(self, func, usage):
|
|
"""Returns a an arg list string, e.g. `(_foo:=23)` for a func.
|
|
|
|
:param func is a FunctionMetadata object
|
|
:param usage is 'call', 'call_display' or 'signature'
|
|
|
|
"""
|
|
template = {
|
|
"call": self.call_arg_style,
|
|
"call_display": self.call_arg_display_style,
|
|
"signature": self.signature_arg_style,
|
|
}[usage]
|
|
args = func.args()
|
|
if not template:
|
|
return "()"
|
|
elif usage == "call" and len(args) < 2:
|
|
return "()"
|
|
elif usage == "call" and func.has_variadic():
|
|
return "()"
|
|
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
|
|
max_arg_len = max(len(a.name) for a in args) if multiline else 0
|
|
args = (
|
|
self._format_arg(template, arg, arg_num + 1, max_arg_len)
|
|
for arg_num, arg in enumerate(args)
|
|
)
|
|
if multiline:
|
|
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
|
|
else:
|
|
return "(" + ", ".join(a for a in args if a) + ")"
|
|
|
|
def _format_arg(self, template, arg, arg_num, max_arg_len):
|
|
if not template:
|
|
return None
|
|
if arg.has_default:
|
|
arg_default = "NULL" if arg.default is None else arg.default
|
|
# Remove trailing ::(schema.)type
|
|
arg_default = arg_default_type_strip_regex.sub("", arg_default)
|
|
else:
|
|
arg_default = ""
|
|
return template.format(
|
|
max_arg_len=max_arg_len,
|
|
arg_name=self.case(arg.name),
|
|
arg_num=arg_num,
|
|
arg_type=arg.datatype,
|
|
arg_default=arg_default,
|
|
)
|
|
|
|
    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
        """Returns a Candidate namedtuple.

        :param tbl is a SchemaObject
        :param do_alias: whether to suffix an alias (e.g. "tablename t")
        :param arg_mode determines what type of arg list to suffix for functions.
            Possible values: call, signature

        """
        cased_tbl = self.case(tbl.name)
        # NOTE: `alias` is bound only when do_alias is true, and only used
        # (via maybe_alias) in that same case.
        if do_alias:
            alias = self.alias(cased_tbl, suggestion.table_refs)
        # Fuzzy-match against both the name and its generated alias.
        synonyms = (cased_tbl, generate_alias(cased_tbl))
        maybe_alias = (" " + alias) if do_alias else ""
        maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
        # Functions get an argument-list suffix (from the precomputed cache);
        # tables/views pass arg_mode=None and get none.
        suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
        if arg_mode == "call":
            display_suffix = self._arg_list_cache["call_display"][tbl.meta]
        elif arg_mode == "signature":
            display_suffix = self._arg_list_cache["signature"][tbl.meta]
        else:
            display_suffix = ""
        item = maybe_schema + cased_tbl + suffix + maybe_alias
        display = maybe_schema + cased_tbl + display_suffix + maybe_alias
        # Unqualified candidates rank ahead of schema-qualified ones.
        prio2 = 0 if tbl.schema else 1
        return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
|
|
|
|
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
|
|
tables = self.populate_schema_objects(suggestion.schema, "tables")
|
|
tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)
|
|
|
|
# Unless we're sure the user really wants them, don't suggest the
|
|
# pg_catalog tables that are implicitly on the search path
|
|
if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
|
|
tables = [t for t in tables if not t.name.startswith("pg_")]
|
|
tables = [self._make_cand(t, alias, suggestion) for t in tables]
|
|
return self.find_matches(word_before_cursor, tables, meta="table")
|
|
|
|
def get_table_formats(self, _, word_before_cursor):
|
|
formats = TabularOutputFormatter().supported_formats
|
|
return self.find_matches(word_before_cursor, formats, meta="table format")
|
|
|
|
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
|
|
views = self.populate_schema_objects(suggestion.schema, "views")
|
|
|
|
if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
|
|
views = [v for v in views if not v.name.startswith("pg_")]
|
|
views = [self._make_cand(v, alias, suggestion) for v in views]
|
|
return self.find_matches(word_before_cursor, views, meta="view")
|
|
|
|
def get_alias_matches(self, suggestion, word_before_cursor):
|
|
aliases = suggestion.aliases
|
|
return self.find_matches(word_before_cursor, aliases, meta="table alias")
|
|
|
|
def get_database_matches(self, _, word_before_cursor):
|
|
return self.find_matches(word_before_cursor, self.databases, meta="database")
|
|
|
|
def get_keyword_matches(self, suggestion, word_before_cursor):
|
|
keywords = self.keywords_tree.keys()
|
|
# Get well known following keywords for the last token. If any, narrow
|
|
# candidates to this list.
|
|
next_keywords = self.keywords_tree.get(suggestion.last_token, [])
|
|
if next_keywords:
|
|
keywords = next_keywords
|
|
|
|
casing = self.keyword_casing
|
|
if casing == "auto":
|
|
if word_before_cursor and word_before_cursor[-1].islower():
|
|
casing = "lower"
|
|
else:
|
|
casing = "upper"
|
|
|
|
if casing == "upper":
|
|
keywords = [k.upper() for k in keywords]
|
|
else:
|
|
keywords = [k.lower() for k in keywords]
|
|
|
|
return self.find_matches(
|
|
word_before_cursor, keywords, mode="strict", meta="keyword"
|
|
)
|
|
|
|
def get_path_matches(self, _, word_before_cursor):
    """Yield filesystem path completions for the word before the cursor."""
    path_completer = PathCompleter(expanduser=True)
    doc = Document(
        text=word_before_cursor, cursor_position=len(word_before_cursor)
    )
    for completion in path_completer.get_completions(doc, None):
        yield Match(completion=completion, priority=(0,))
|
|
|
|
def get_special_matches(self, _, word_before_cursor):
    """Suggest pgspecial backslash commands; empty when pgspecial is unavailable."""
    if not self.pgspecial:
        return []

    commands = self.pgspecial.commands
    candidates = [
        Candidate(name, 0, commands[name].description) for name in commands
    ]
    return self.find_matches(word_before_cursor, candidates, mode="strict")
|
|
|
|
def get_datatype_matches(self, suggestion, word_before_cursor):
    """Suggest datatypes: user-defined types, plus built-in type names when
    no schema qualification was given."""
    custom_types = [
        self._make_cand(obj, False, suggestion)
        for obj in self.populate_schema_objects(suggestion.schema, "datatypes")
    ]
    matches = self.find_matches(word_before_cursor, custom_types, meta="datatype")

    if not suggestion.schema:
        # Unqualified: also offer the hardcoded built-in types.
        builtin = self.find_matches(
            word_before_cursor, self.datatypes, mode="strict", meta="datatype"
        )
        matches.extend(builtin)

    return matches
|
|
|
|
def get_namedquery_matches(self, _, word_before_cursor):
    """Suggest saved named queries."""
    saved = NamedQueries.instance.list()
    return self.find_matches(word_before_cursor, saved, meta="named query")
|
|
|
|
# Dispatch table: maps each suggestion type (produced by
# packages.sqlcompletion.suggest_type) to the matcher that generates
# completion candidates for it. Entries are looked up by exact type.
suggestion_matchers = {
    FromClauseItem: get_from_clause_item_matches,
    JoinCondition: get_join_condition_matches,
    Join: get_join_matches,
    Column: get_column_matches,
    Function: get_function_matches,
    Schema: get_schema_matches,
    Table: get_table_matches,
    TableFormat: get_table_formats,
    View: get_view_matches,
    Alias: get_alias_matches,
    Database: get_database_matches,
    Keyword: get_keyword_matches,
    Special: get_special_matches,
    Datatype: get_datatype_matches,
    NamedQuery: get_namedquery_matches,
    Path: get_path_matches,
}
|
|
|
|
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
    """Find all columns in a set of scoped_tables.

    :param scoped_tbls: list of TableReference namedtuples
    :param local_tbls: tuple(TableMetadata) — CTEs/local tables that
        shadow database tables of the same name
    :return: OrderedDict {TableReference: [ColumnMetadata]}
    """
    ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
    columns = OrderedDict()
    meta = self.dbmetadata

    def addcols(schema, rel, alias, reltype, cols):
        # Accumulate columns per reference; the same relation may appear
        # more than once (e.g. self-joins with different aliases).
        tbl = TableReference(schema, rel, alias, reltype == "functions")
        if tbl not in columns:
            columns[tbl] = []
        columns[tbl].extend(cols)

    for tbl in scoped_tbls:
        # Local tables should shadow database tables
        if tbl.schema is None and normalize_ref(tbl.name) in ctes:
            cols = ctes[normalize_ref(tbl.name)]
            # BUG FIX: alias/reltype arguments were previously swapped
            # (addcols(None, tbl.name, "CTE", tbl.alias, cols)), which
            # registered CTE columns under the literal alias "CTE".
            addcols(None, tbl.name, tbl.alias, "CTE", cols)
            continue
        schemas = [tbl.schema] if tbl.schema else self.search_path
        for schema in schemas:
            relname = self.escape_name(tbl.name)
            schema = self.escape_name(schema)
            if tbl.is_function:
                # Return column names from a set-returning function
                # Get an array of FunctionMetadata objects
                functions = meta["functions"].get(schema, {}).get(relname)
                for func in functions or []:
                    # func is a FunctionMetadata object
                    cols = func.fields()
                    addcols(schema, relname, tbl.alias, "functions", cols)
            else:
                for reltype in ("tables", "views"):
                    cols = meta[reltype].get(schema, {}).get(relname)
                    if cols:
                        cols = cols.values()
                        addcols(schema, relname, tbl.alias, reltype, cols)
                        # Stop at the first relation type that matches.
                        break

    return columns
|
|
|
|
def _get_schemas(self, obj_typ, schema):
    """Return the schemas from which to suggest ``obj_typ`` objects.

    :param schema: the schema qualification typed by the user (if any)
    """
    metadata = self.dbmetadata[obj_typ]
    if schema:
        escaped = self.escape_name(schema)
        return [escaped] if escaped in metadata else []
    if self.search_path_filter:
        return self.search_path
    return metadata.keys()
|
|
|
|
def _maybe_schema(self, schema, parent):
    """Return ``schema`` as a qualifier, or None when it would be redundant
    (the user already typed a parent qualification, or the schema is on
    the search path)."""
    if parent or schema in self.search_path:
        return None
    return schema
|
|
|
|
def populate_schema_objects(self, schema, obj_type):
    """Return a list of SchemaObjects representing tables or views.

    :param schema: the schema qualification typed by the user (if any)
    """
    objects = []
    for sch in self._get_schemas(obj_type, schema):
        qualifier = self._maybe_schema(schema=sch, parent=schema)
        for name in self.dbmetadata[obj_type][sch].keys():
            objects.append(SchemaObject(name=name, schema=qualifier))
    return objects
|
|
|
|
def populate_functions(self, schema, filter_func):
    """Return a list of function SchemaObjects.

    :param filter_func: a function that accepts a FunctionMetadata
        namedtuple and returns a boolean indicating whether that
        function should be kept or discarded
    """
    # Because of multiple dispatch (overloading), a single name can map
    # to several FunctionMetadata entries — hence the inner iteration
    # over ``metas``.
    funcs = []
    for sch in self._get_schemas("functions", schema):
        qualifier = self._maybe_schema(schema=sch, parent=schema)
        for func_name, metas in self.dbmetadata["functions"][sch].items():
            funcs.extend(
                SchemaObject(name=func_name, schema=qualifier, meta=m)
                for m in metas
                if filter_func(m)
            )
    return funcs
|