Adding upstream version 1.23.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f253096a15
commit
94e3fc38e7
93 changed files with 10761 additions and 0 deletions
89
mycli/AUTHORS
Normal file
89
mycli/AUTHORS
Normal file
|
@ -0,0 +1,89 @@
|
|||
Project Lead:
|
||||
-------------
|
||||
* Thomas Roten
|
||||
|
||||
|
||||
Core Developers:
|
||||
----------------
|
||||
|
||||
* Irina Truong
|
||||
* Matheus Rosa
|
||||
* Darik Gamble
|
||||
* Dick Marinus
|
||||
* Amjith Ramanujam
|
||||
|
||||
Contributors:
|
||||
-------------
|
||||
|
||||
* Steve Robbins
|
||||
* Shoma Suzuki
|
||||
* Daniel West
|
||||
* Scrappy Soft
|
||||
* Daniel Black
|
||||
* Jonathan Bruno
|
||||
* Casper Langemeijer
|
||||
* Jonathan Slenders
|
||||
* Artem Bezsmertnyi
|
||||
* Mikhail Borisov
|
||||
* Heath Naylor
|
||||
* Phil Cohen
|
||||
* spacewander
|
||||
* Adam Chainz
|
||||
* Johannes Hoff
|
||||
* Kacper Kwapisz
|
||||
* Lennart Weller
|
||||
* Martijn Engler
|
||||
* Terseus
|
||||
* Tyler Kuipers
|
||||
* William GARCIA
|
||||
* Yasuhiro Matsumoto
|
||||
* bjarnagin
|
||||
* jbruno
|
||||
* mrdeathless
|
||||
* Abirami P
|
||||
* John Sterling
|
||||
* Jialong Liu
|
||||
* Zhidong
|
||||
* Daniël van Eeden
|
||||
* zer09
|
||||
* cxbig
|
||||
* chainkite
|
||||
* Michał Górny
|
||||
* Terje Røsten
|
||||
* Ryan Smith
|
||||
* Klaus Wünschel
|
||||
* François Pietka
|
||||
* Colin Caine
|
||||
* Frederic Aoustin
|
||||
* caitinggui
|
||||
* ushuz
|
||||
* Zhaolong Zhu
|
||||
* Zhongyang Guan
|
||||
* Huachao Mao
|
||||
* QiaoHou Peng
|
||||
* Yang Zou
|
||||
* Angelo Lupo
|
||||
* Aljosha Papsch
|
||||
* Zane C. Bowers-Hadley
|
||||
* Mike Palandra
|
||||
* Georgy Frolov
|
||||
* Jonathan Lloyd
|
||||
* Nathan Huang
|
||||
* Jakub Boukal
|
||||
* Takeshi D. Itoh
|
||||
* laixintao
|
||||
* Zach DeCook
|
||||
* kevinhwang91
|
||||
* KITAGAWA Yasutaka
|
||||
* bitkeen
|
||||
* Morgan Mitchell
|
||||
* Massimiliano Torromeo
|
||||
* Roland Walker
|
||||
* xeron
|
||||
* 0xflotus
|
||||
* Seamile
|
||||
|
||||
Creator:
|
||||
--------
|
||||
|
||||
Amjith Ramanujam
|
31
mycli/SPONSORS
Normal file
31
mycli/SPONSORS
Normal file
|
@ -0,0 +1,31 @@
|
|||
Many thanks to the following Kickstarter backers.
|
||||
|
||||
* Tech Blue Software
|
||||
* jweiland.net
|
||||
|
||||
# Silver Sponsors
|
||||
|
||||
* Whitane Tech
|
||||
* Open Query Pty Ltd
|
||||
* Prathap Ramamurthy
|
||||
* Lincoln Loop
|
||||
|
||||
# Sponsors
|
||||
|
||||
* Nathan Taggart
|
||||
* Iryna Cherniavska
|
||||
* Sudaraka Wijesinghe
|
||||
* www.mysqlfanboy.com
|
||||
* Steve Robbins
|
||||
* Norbert Spichtig
|
||||
* orpharion bestheneme
|
||||
* Daniel Black
|
||||
* Anonymous
|
||||
* Magnus udd
|
||||
* Anonymous
|
||||
* Lewis Peckover
|
||||
* Cyrille Tabary
|
||||
* Heath Naylor
|
||||
* Ted Pennings
|
||||
* Chris Anderton
|
||||
* Jonathan Slenders
|
1
mycli/__init__.py
Normal file
1
mycli/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
__version__ = '1.23.2'
|
56
mycli/clibuffer.py
Normal file
56
mycli/clibuffer.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
from prompt_toolkit.enums import DEFAULT_BUFFER
|
||||
from prompt_toolkit.filters import Condition
|
||||
from prompt_toolkit.application import get_app
|
||||
from .packages.parseutils import is_open_quote
|
||||
from .packages import special
|
||||
|
||||
|
||||
def cli_is_multiline(mycli):
|
||||
@Condition
|
||||
def cond():
|
||||
doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
|
||||
|
||||
if not mycli.multi_line:
|
||||
return False
|
||||
else:
|
||||
return not _multiline_exception(doc.text)
|
||||
return cond
|
||||
|
||||
|
||||
def _multiline_exception(text):
|
||||
orig = text
|
||||
text = text.strip()
|
||||
|
||||
# Multi-statement favorite query is a special case. Because there will
|
||||
# be a semicolon separating statements, we can't consider semicolon an
|
||||
# EOL. Let's consider an empty line an EOL instead.
|
||||
if text.startswith('\\fs'):
|
||||
return orig.endswith('\n')
|
||||
|
||||
return (
|
||||
# Special Command
|
||||
text.startswith('\\') or
|
||||
|
||||
# Delimiter declaration
|
||||
text.lower().startswith('delimiter') or
|
||||
|
||||
# Ended with the current delimiter (usually a semi-column)
|
||||
text.endswith(special.get_current_delimiter()) or
|
||||
|
||||
text.endswith('\\g') or
|
||||
text.endswith('\\G') or
|
||||
text.endswith(r'\e') or
|
||||
text.endswith(r'\clip') or
|
||||
|
||||
# Exit doesn't need semi-column`
|
||||
(text == 'exit') or
|
||||
|
||||
# Quit doesn't need semi-column
|
||||
(text == 'quit') or
|
||||
|
||||
# To all teh vim fans out there
|
||||
(text == ':q') or
|
||||
|
||||
# just a plain enter without any text
|
||||
(text == '')
|
||||
)
|
152
mycli/clistyle.py
Normal file
152
mycli/clistyle.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
import logging
|
||||
|
||||
import pygments.styles
|
||||
from pygments.token import string_to_tokentype, Token
|
||||
from pygments.style import Style as PygmentsStyle
|
||||
from pygments.util import ClassNotFound
|
||||
from prompt_toolkit.styles.pygments import style_from_pygments_cls
|
||||
from prompt_toolkit.styles import merge_styles, Style
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
|
||||
TOKEN_TO_PROMPT_STYLE = {
|
||||
Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
|
||||
Token.Menu.Completions.Completion: 'completion-menu.completion',
|
||||
Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
|
||||
Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
|
||||
Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
|
||||
Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
|
||||
Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
|
||||
Token.SelectedText: 'selected',
|
||||
Token.SearchMatch: 'search',
|
||||
Token.SearchMatch.Current: 'search.current',
|
||||
Token.Toolbar: 'bottom-toolbar',
|
||||
Token.Toolbar.Off: 'bottom-toolbar.off',
|
||||
Token.Toolbar.On: 'bottom-toolbar.on',
|
||||
Token.Toolbar.Search: 'search-toolbar',
|
||||
Token.Toolbar.Search.Text: 'search-toolbar.text',
|
||||
Token.Toolbar.System: 'system-toolbar',
|
||||
Token.Toolbar.Arg: 'arg-toolbar',
|
||||
Token.Toolbar.Arg.Text: 'arg-toolbar.text',
|
||||
Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
|
||||
Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
|
||||
Token.Output.Header: 'output.header',
|
||||
Token.Output.OddRow: 'output.odd-row',
|
||||
Token.Output.EvenRow: 'output.even-row',
|
||||
Token.Output.Null: 'output.null',
|
||||
Token.Prompt: 'prompt',
|
||||
Token.Continuation: 'continuation',
|
||||
}
|
||||
|
||||
# reverse dict for cli_helpers, because they still expect Pygments tokens.
|
||||
PROMPT_STYLE_TO_TOKEN = {
|
||||
v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
|
||||
}
|
||||
|
||||
# all tokens that the Pygments MySQL lexer can produce
|
||||
OVERRIDE_STYLE_TO_TOKEN = {
|
||||
'sql.comment': Token.Comment,
|
||||
'sql.comment.multi-line': Token.Comment.Multiline,
|
||||
'sql.comment.single-line': Token.Comment.Single,
|
||||
'sql.comment.optimizer-hint': Token.Comment.Special,
|
||||
'sql.escape': Token.Error,
|
||||
'sql.keyword': Token.Keyword,
|
||||
'sql.datatype': Token.Keyword.Type,
|
||||
'sql.literal': Token.Literal,
|
||||
'sql.literal.date': Token.Literal.Date,
|
||||
'sql.symbol': Token.Name,
|
||||
'sql.quoted-schema-object': Token.Name.Quoted,
|
||||
'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape,
|
||||
'sql.constant': Token.Name.Constant,
|
||||
'sql.function': Token.Name.Function,
|
||||
'sql.variable': Token.Name.Variable,
|
||||
'sql.number': Token.Number,
|
||||
'sql.number.binary': Token.Number.Bin,
|
||||
'sql.number.float': Token.Number.Float,
|
||||
'sql.number.hex': Token.Number.Hex,
|
||||
'sql.number.integer': Token.Number.Integer,
|
||||
'sql.operator': Token.Operator,
|
||||
'sql.punctuation': Token.Punctuation,
|
||||
'sql.string': Token.String,
|
||||
'sql.string.double-quouted': Token.String.Double,
|
||||
'sql.string.escape': Token.String.Escape,
|
||||
'sql.string.single-quoted': Token.String.Single,
|
||||
'sql.whitespace': Token.Text,
|
||||
}
|
||||
|
||||
def parse_pygments_style(token_name, style_object, style_dict):
|
||||
"""Parse token type and style string.
|
||||
|
||||
:param token_name: str name of Pygments token. Example: "Token.String"
|
||||
:param style_object: pygments.style.Style instance to use as base
|
||||
:param style_dict: dict of token names and their styles, customized to this cli
|
||||
|
||||
"""
|
||||
token_type = string_to_tokentype(token_name)
|
||||
try:
|
||||
other_token_type = string_to_tokentype(style_dict[token_name])
|
||||
return token_type, style_object.styles[other_token_type]
|
||||
except AttributeError as err:
|
||||
return token_type, style_dict[token_name]
|
||||
|
||||
|
||||
def style_factory(name, cli_style):
|
||||
try:
|
||||
style = pygments.styles.get_style_by_name(name)
|
||||
except ClassNotFound:
|
||||
style = pygments.styles.get_style_by_name('native')
|
||||
|
||||
prompt_styles = []
|
||||
# prompt-toolkit used pygments tokens for styling before, switched to style
|
||||
# names in 2.0. Convert old token types to new style names, for backwards compatibility.
|
||||
for token in cli_style:
|
||||
if token.startswith('Token.'):
|
||||
# treat as pygments token (1.0)
|
||||
token_type, style_value = parse_pygments_style(
|
||||
token, style, cli_style)
|
||||
if token_type in TOKEN_TO_PROMPT_STYLE:
|
||||
prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
|
||||
prompt_styles.append((prompt_style, style_value))
|
||||
else:
|
||||
# we don't want to support tokens anymore
|
||||
logger.error('Unhandled style / class name: %s', token)
|
||||
else:
|
||||
# treat as prompt style name (2.0). See default style names here:
|
||||
# https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
|
||||
prompt_styles.append((token, cli_style[token]))
|
||||
|
||||
override_style = Style([('bottom-toolbar', 'noreverse')])
|
||||
return merge_styles([
|
||||
style_from_pygments_cls(style),
|
||||
override_style,
|
||||
Style(prompt_styles)
|
||||
])
|
||||
|
||||
|
||||
def style_factory_output(name, cli_style):
|
||||
try:
|
||||
style = pygments.styles.get_style_by_name(name).styles
|
||||
except ClassNotFound:
|
||||
style = pygments.styles.get_style_by_name('native').styles
|
||||
|
||||
for token in cli_style:
|
||||
if token.startswith('Token.'):
|
||||
token_type, style_value = parse_pygments_style(
|
||||
token, style, cli_style)
|
||||
style.update({token_type: style_value})
|
||||
elif token in PROMPT_STYLE_TO_TOKEN:
|
||||
token_type = PROMPT_STYLE_TO_TOKEN[token]
|
||||
style.update({token_type: cli_style[token]})
|
||||
elif token in OVERRIDE_STYLE_TO_TOKEN:
|
||||
token_type = OVERRIDE_STYLE_TO_TOKEN[token]
|
||||
style.update({token_type: cli_style[token]})
|
||||
else:
|
||||
# TODO: cli helpers will have to switch to ptk.Style
|
||||
logger.error('Unhandled style / class name: %s', token)
|
||||
|
||||
class OutputStyle(PygmentsStyle):
|
||||
default_style = ""
|
||||
styles = style
|
||||
|
||||
return OutputStyle
|
53
mycli/clitoolbar.py
Normal file
53
mycli/clitoolbar.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from prompt_toolkit.key_binding.vi_state import InputMode
|
||||
from prompt_toolkit.application import get_app
|
||||
from prompt_toolkit.enums import EditingMode
|
||||
from .packages import special
|
||||
|
||||
|
||||
def create_toolbar_tokens_func(mycli, show_fish_help):
|
||||
"""Return a function that generates the toolbar tokens."""
|
||||
def get_toolbar_tokens():
|
||||
result = []
|
||||
result.append(('class:bottom-toolbar', ' '))
|
||||
|
||||
if mycli.multi_line:
|
||||
delimiter = special.get_current_delimiter()
|
||||
result.append(
|
||||
(
|
||||
'class:bottom-toolbar',
|
||||
' ({} [{}] will end the line) '.format(
|
||||
'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter)
|
||||
))
|
||||
|
||||
if mycli.multi_line:
|
||||
result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
|
||||
else:
|
||||
result.append(('class:bottom-toolbar.off',
|
||||
'[F3] Multiline: OFF '))
|
||||
if mycli.prompt_app.editing_mode == EditingMode.VI:
|
||||
result.append((
|
||||
'class:botton-toolbar.on',
|
||||
'Vi-mode ({})'.format(_get_vi_mode())
|
||||
))
|
||||
|
||||
if show_fish_help():
|
||||
result.append(
|
||||
('class:bottom-toolbar', ' Right-arrow to complete suggestion'))
|
||||
|
||||
if mycli.completion_refresher.is_refreshing():
|
||||
result.append(
|
||||
('class:bottom-toolbar', ' Refreshing completions...'))
|
||||
|
||||
return result
|
||||
return get_toolbar_tokens
|
||||
|
||||
|
||||
def _get_vi_mode():
|
||||
"""Get the current vi mode for display."""
|
||||
return {
|
||||
InputMode.INSERT: 'I',
|
||||
InputMode.NAVIGATION: 'N',
|
||||
InputMode.REPLACE: 'R',
|
||||
InputMode.REPLACE_SINGLE: 'R',
|
||||
InputMode.INSERT_MULTIPLE: 'M',
|
||||
}[get_app().vi_state.input_mode]
|
6
mycli/compat.py
Normal file
6
mycli/compat.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
"""Platform and Python version compatibility support."""
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
WIN = sys.platform in ('win32', 'cygwin')
|
123
mycli/completion_refresher.py
Normal file
123
mycli/completion_refresher.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
import threading
|
||||
from .packages.special.main import COMMANDS
|
||||
from collections import OrderedDict
|
||||
|
||||
from .sqlcompleter import SQLCompleter
|
||||
from .sqlexecute import SQLExecute
|
||||
|
||||
class CompletionRefresher(object):
|
||||
|
||||
refreshers = OrderedDict()
|
||||
|
||||
def __init__(self):
|
||||
self._completer_thread = None
|
||||
self._restart_refresh = threading.Event()
|
||||
|
||||
def refresh(self, executor, callbacks, completer_options=None):
|
||||
"""Creates a SQLCompleter object and populates it with the relevant
|
||||
completion suggestions in a background thread.
|
||||
|
||||
executor - SQLExecute object, used to extract the credentials to connect
|
||||
to the database.
|
||||
callbacks - A function or a list of functions to call after the thread
|
||||
has completed the refresh. The newly created completion
|
||||
object will be passed in as an argument to each callback.
|
||||
completer_options - dict of options to pass to SQLCompleter.
|
||||
|
||||
"""
|
||||
if completer_options is None:
|
||||
completer_options = {}
|
||||
|
||||
if self.is_refreshing():
|
||||
self._restart_refresh.set()
|
||||
return [(None, None, None, 'Auto-completion refresh restarted.')]
|
||||
else:
|
||||
self._completer_thread = threading.Thread(
|
||||
target=self._bg_refresh,
|
||||
args=(executor, callbacks, completer_options),
|
||||
name='completion_refresh')
|
||||
self._completer_thread.setDaemon(True)
|
||||
self._completer_thread.start()
|
||||
return [(None, None, None,
|
||||
'Auto-completion refresh started in the background.')]
|
||||
|
||||
def is_refreshing(self):
|
||||
return self._completer_thread and self._completer_thread.is_alive()
|
||||
|
||||
def _bg_refresh(self, sqlexecute, callbacks, completer_options):
|
||||
completer = SQLCompleter(**completer_options)
|
||||
|
||||
# Create a new pgexecute method to popoulate the completions.
|
||||
e = sqlexecute
|
||||
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
|
||||
e.socket, e.charset, e.local_infile, e.ssl,
|
||||
e.ssh_user, e.ssh_host, e.ssh_port,
|
||||
e.ssh_password, e.ssh_key_filename)
|
||||
|
||||
# If callbacks is a single function then push it into a list.
|
||||
if callable(callbacks):
|
||||
callbacks = [callbacks]
|
||||
|
||||
while 1:
|
||||
for refresher in self.refreshers.values():
|
||||
refresher(completer, executor)
|
||||
if self._restart_refresh.is_set():
|
||||
self._restart_refresh.clear()
|
||||
break
|
||||
else:
|
||||
# Break out of while loop if the for loop finishes natually
|
||||
# without hitting the break statement.
|
||||
break
|
||||
|
||||
# Start over the refresh from the beginning if the for loop hit the
|
||||
# break statement.
|
||||
continue
|
||||
|
||||
for callback in callbacks:
|
||||
callback(completer)
|
||||
|
||||
def refresher(name, refreshers=CompletionRefresher.refreshers):
|
||||
"""Decorator to add the decorated function to the dictionary of
|
||||
refreshers. Any function decorated with a @refresher will be executed as
|
||||
part of the completion refresh routine."""
|
||||
def wrapper(wrapped):
|
||||
refreshers[name] = wrapped
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
@refresher('databases')
|
||||
def refresh_databases(completer, executor):
|
||||
completer.extend_database_names(executor.databases())
|
||||
|
||||
@refresher('schemata')
|
||||
def refresh_schemata(completer, executor):
|
||||
# schemata - In MySQL Schema is the same as database. But for mycli
|
||||
# schemata will be the name of the current database.
|
||||
completer.extend_schemata(executor.dbname)
|
||||
completer.set_dbname(executor.dbname)
|
||||
|
||||
@refresher('tables')
|
||||
def refresh_tables(completer, executor):
|
||||
completer.extend_relations(executor.tables(), kind='tables')
|
||||
completer.extend_columns(executor.table_columns(), kind='tables')
|
||||
|
||||
@refresher('users')
|
||||
def refresh_users(completer, executor):
|
||||
completer.extend_users(executor.users())
|
||||
|
||||
# @refresher('views')
|
||||
# def refresh_views(completer, executor):
|
||||
# completer.extend_relations(executor.views(), kind='views')
|
||||
# completer.extend_columns(executor.view_columns(), kind='views')
|
||||
|
||||
@refresher('functions')
|
||||
def refresh_functions(completer, executor):
|
||||
completer.extend_functions(executor.functions())
|
||||
|
||||
@refresher('special_commands')
|
||||
def refresh_special(completer, executor):
|
||||
completer.extend_special_commands(COMMANDS.keys())
|
||||
|
||||
@refresher('show_commands')
|
||||
def refresh_show_commands(completer, executor):
|
||||
completer.extend_show_items(executor.show_candidates())
|
286
mycli/config.py
Normal file
286
mycli/config.py
Normal file
|
@ -0,0 +1,286 @@
|
|||
import io
|
||||
import shutil
|
||||
from copy import copy
|
||||
from io import BytesIO, TextIOWrapper
|
||||
import logging
|
||||
import os
|
||||
from os.path import exists
|
||||
import struct
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
from configobj import ConfigObj, ConfigObjError
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
try:
|
||||
basestring
|
||||
except NameError:
|
||||
basestring = str
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def log(logger, level, message):
|
||||
"""Logs message to stderr if logging isn't initialized."""
|
||||
|
||||
if logger.parent.name != 'root':
|
||||
logger.log(level, message)
|
||||
else:
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
|
||||
def read_config_file(f, list_values=True):
|
||||
"""Read a config file.
|
||||
|
||||
*list_values* set to `True` is the default behavior of ConfigObj.
|
||||
Disabling it causes values to not be parsed for lists,
|
||||
(e.g. 'a,b,c' -> ['a', 'b', 'c']. Additionally, the config values are
|
||||
not unquoted. We are disabling list_values when reading MySQL config files
|
||||
so we can correctly interpret commas in passwords.
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(f, basestring):
|
||||
f = os.path.expanduser(f)
|
||||
|
||||
try:
|
||||
config = ConfigObj(f, interpolation=False, encoding='utf8',
|
||||
list_values=list_values)
|
||||
except ConfigObjError as e:
|
||||
log(logger, logging.ERROR, "Unable to parse line {0} of config file "
|
||||
"'{1}'.".format(e.line_number, f))
|
||||
log(logger, logging.ERROR, "Using successfully parsed config values.")
|
||||
return e.config
|
||||
except (IOError, OSError) as e:
|
||||
log(logger, logging.WARNING, "You don't have permission to read "
|
||||
"config file '{0}'.".format(e.filename))
|
||||
return None
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
|
||||
"""Get a list of configuration files that are included into config_path
|
||||
with !includedir directive.
|
||||
|
||||
"Normal" configs should be passed as file paths. The only exception
|
||||
is .mylogin which is decoded into a stream. However, it never
|
||||
contains include directives and so will be ignored by this
|
||||
function.
|
||||
|
||||
"""
|
||||
if not isinstance(config_file, str) or not os.path.isfile(config_file):
|
||||
return []
|
||||
included_configs = []
|
||||
|
||||
try:
|
||||
with open(config_file) as f:
|
||||
include_directives = filter(
|
||||
lambda s: s.startswith('!includedir'),
|
||||
f
|
||||
)
|
||||
dirs = map(lambda s: s.strip().split()[-1], include_directives)
|
||||
dirs = filter(os.path.isdir, dirs)
|
||||
for dir in dirs:
|
||||
for filename in os.listdir(dir):
|
||||
if filename.endswith('.cnf'):
|
||||
included_configs.append(os.path.join(dir, filename))
|
||||
except (PermissionError, UnicodeDecodeError):
|
||||
pass
|
||||
return included_configs
|
||||
|
||||
|
||||
def read_config_files(files, list_values=True):
|
||||
"""Read and merge a list of config files."""
|
||||
|
||||
config = ConfigObj(list_values=list_values)
|
||||
_files = copy(files)
|
||||
while _files:
|
||||
_file = _files.pop(0)
|
||||
_config = read_config_file(_file, list_values=list_values)
|
||||
|
||||
# expand includes only if we were able to parse config
|
||||
# (otherwise we'll just encounter the same errors again)
|
||||
if config is not None:
|
||||
_files = get_included_configs(_file) + _files
|
||||
if bool(_config) is True:
|
||||
config.merge(_config)
|
||||
config.filename = _config.filename
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def write_default_config(source, destination, overwrite=False):
|
||||
destination = os.path.expanduser(destination)
|
||||
if not overwrite and exists(destination):
|
||||
return
|
||||
|
||||
shutil.copyfile(source, destination)
|
||||
|
||||
|
||||
def get_mylogin_cnf_path():
|
||||
"""Return the path to the login path file or None if it doesn't exist."""
|
||||
mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE')
|
||||
|
||||
if mylogin_cnf_path is None:
|
||||
app_data = os.getenv('APPDATA')
|
||||
default_dir = os.path.join(app_data, 'MySQL') if app_data else '~'
|
||||
mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf')
|
||||
|
||||
mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path)
|
||||
|
||||
if exists(mylogin_cnf_path):
|
||||
logger.debug("Found login path file at '{0}'".format(mylogin_cnf_path))
|
||||
return mylogin_cnf_path
|
||||
return None
|
||||
|
||||
|
||||
def open_mylogin_cnf(name):
|
||||
"""Open a readable version of .mylogin.cnf.
|
||||
|
||||
Returns the file contents as a TextIOWrapper object.
|
||||
|
||||
:param str name: The pathname of the file to be opened.
|
||||
:return: the login path file or None
|
||||
"""
|
||||
|
||||
try:
|
||||
with open(name, 'rb') as f:
|
||||
plaintext = read_and_decrypt_mylogin_cnf(f)
|
||||
except (OSError, IOError, ValueError):
|
||||
logger.error('Unable to open login path file.')
|
||||
return None
|
||||
|
||||
if not isinstance(plaintext, BytesIO):
|
||||
logger.error('Unable to read login path file.')
|
||||
return None
|
||||
|
||||
return TextIOWrapper(plaintext)
|
||||
|
||||
|
||||
def read_and_decrypt_mylogin_cnf(f):
|
||||
"""Read and decrypt the contents of .mylogin.cnf.
|
||||
|
||||
This decryption algorithm mimics the code in MySQL's
|
||||
mysql_config_editor.cc.
|
||||
|
||||
The login key is 20-bytes of random non-printable ASCII.
|
||||
It is written to the actual login path file. It is used
|
||||
to generate the real key used in the AES cipher.
|
||||
|
||||
:param f: an I/O object opened in binary mode
|
||||
:return: the decrypted login path file
|
||||
:rtype: io.BytesIO or None
|
||||
"""
|
||||
|
||||
# Number of bytes used to store the length of ciphertext.
|
||||
MAX_CIPHER_STORE_LEN = 4
|
||||
|
||||
LOGIN_KEY_LEN = 20
|
||||
|
||||
# Move past the unused buffer.
|
||||
buf = f.read(4)
|
||||
|
||||
if not buf or len(buf) != 4:
|
||||
logger.error('Login path file is blank or incomplete.')
|
||||
return None
|
||||
|
||||
# Read the login key.
|
||||
key = f.read(LOGIN_KEY_LEN)
|
||||
|
||||
# Generate the real key.
|
||||
rkey = [0] * 16
|
||||
for i in range(LOGIN_KEY_LEN):
|
||||
try:
|
||||
rkey[i % 16] ^= ord(key[i:i+1])
|
||||
except TypeError:
|
||||
# ord() was unable to get the value of the byte.
|
||||
logger.error('Unable to generate login path AES key.')
|
||||
return None
|
||||
rkey = struct.pack('16B', *rkey)
|
||||
|
||||
# Create a decryptor object using the key.
|
||||
decryptor = _get_decryptor(rkey)
|
||||
|
||||
# Create a bytes buffer to hold the plaintext.
|
||||
plaintext = BytesIO()
|
||||
|
||||
while True:
|
||||
# Read the length of the ciphertext.
|
||||
len_buf = f.read(MAX_CIPHER_STORE_LEN)
|
||||
if len(len_buf) < MAX_CIPHER_STORE_LEN:
|
||||
break
|
||||
cipher_len, = struct.unpack("<i", len_buf)
|
||||
|
||||
# Read cipher_len bytes from the file and decrypt.
|
||||
cipher = f.read(cipher_len)
|
||||
plain = _remove_pad(decryptor.update(cipher))
|
||||
if plain is False:
|
||||
continue
|
||||
plaintext.write(plain)
|
||||
|
||||
if plaintext.tell() == 0:
|
||||
logger.error('No data successfully decrypted from login path file.')
|
||||
return None
|
||||
|
||||
plaintext.seek(0)
|
||||
return plaintext
|
||||
|
||||
|
||||
def str_to_bool(s):
|
||||
"""Convert a string value to its corresponding boolean value."""
|
||||
if isinstance(s, bool):
|
||||
return s
|
||||
elif not isinstance(s, basestring):
|
||||
raise TypeError('argument must be a string')
|
||||
|
||||
true_values = ('true', 'on', '1')
|
||||
false_values = ('false', 'off', '0')
|
||||
|
||||
if s.lower() in true_values:
|
||||
return True
|
||||
elif s.lower() in false_values:
|
||||
return False
|
||||
else:
|
||||
raise ValueError('not a recognized boolean value: %s'.format(s))
|
||||
|
||||
|
||||
def strip_matching_quotes(s):
|
||||
"""Remove matching, surrounding quotes from a string.
|
||||
|
||||
This is the same logic that ConfigObj uses when parsing config
|
||||
values.
|
||||
|
||||
"""
|
||||
if (isinstance(s, basestring) and len(s) >= 2 and
|
||||
s[0] == s[-1] and s[0] in ('"', "'")):
|
||||
s = s[1:-1]
|
||||
return s
|
||||
|
||||
|
||||
def _get_decryptor(key):
|
||||
"""Get the AES decryptor."""
|
||||
c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
|
||||
return c.decryptor()
|
||||
|
||||
|
||||
def _remove_pad(line):
|
||||
"""Remove the pad from the *line*."""
|
||||
pad_length = ord(line[-1:])
|
||||
try:
|
||||
# Determine pad length.
|
||||
pad_length = ord(line[-1:])
|
||||
except TypeError:
|
||||
# ord() was unable to get the value of the byte.
|
||||
logger.warning('Unable to remove pad.')
|
||||
return False
|
||||
|
||||
if pad_length > len(line) or len(set(line[-pad_length:])) != 1:
|
||||
# Pad length should be less than or equal to the length of the
|
||||
# plaintext. The pad should have a single unique byte.
|
||||
logger.warning('Invalid pad found in login path file.')
|
||||
return False
|
||||
|
||||
return line[:-pad_length]
|
85
mycli/key_bindings.py
Normal file
85
mycli/key_bindings.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import logging
|
||||
from prompt_toolkit.enums import EditingMode
|
||||
from prompt_toolkit.filters import completion_is_selected
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mycli_bindings(mycli):
|
||||
"""Custom key bindings for mycli."""
|
||||
kb = KeyBindings()
|
||||
|
||||
@kb.add('f2')
|
||||
def _(event):
|
||||
"""Enable/Disable SmartCompletion Mode."""
|
||||
_logger.debug('Detected F2 key.')
|
||||
mycli.completer.smart_completion = not mycli.completer.smart_completion
|
||||
|
||||
@kb.add('f3')
|
||||
def _(event):
|
||||
"""Enable/Disable Multiline Mode."""
|
||||
_logger.debug('Detected F3 key.')
|
||||
mycli.multi_line = not mycli.multi_line
|
||||
|
||||
@kb.add('f4')
|
||||
def _(event):
|
||||
"""Toggle between Vi and Emacs mode."""
|
||||
_logger.debug('Detected F4 key.')
|
||||
if mycli.key_bindings == "vi":
|
||||
event.app.editing_mode = EditingMode.EMACS
|
||||
mycli.key_bindings = "emacs"
|
||||
else:
|
||||
event.app.editing_mode = EditingMode.VI
|
||||
mycli.key_bindings = "vi"
|
||||
|
||||
@kb.add('tab')
|
||||
def _(event):
|
||||
"""Force autocompletion at cursor."""
|
||||
_logger.debug('Detected <Tab> key.')
|
||||
b = event.app.current_buffer
|
||||
if b.complete_state:
|
||||
b.complete_next()
|
||||
else:
|
||||
b.start_completion(select_first=True)
|
||||
|
||||
@kb.add('c-space')
|
||||
def _(event):
|
||||
"""
|
||||
Initialize autocompletion at cursor.
|
||||
|
||||
If the autocompletion menu is not showing, display it with the
|
||||
appropriate completions for the context.
|
||||
|
||||
If the menu is showing, select the next completion.
|
||||
"""
|
||||
_logger.debug('Detected <C-Space> key.')
|
||||
|
||||
b = event.app.current_buffer
|
||||
if b.complete_state:
|
||||
b.complete_next()
|
||||
else:
|
||||
b.start_completion(select_first=False)
|
||||
|
||||
@kb.add('enter', filter=completion_is_selected)
|
||||
def _(event):
|
||||
"""Makes the enter key work as the tab key only when showing the menu.
|
||||
|
||||
In other words, don't execute query when enter is pressed in
|
||||
the completion dropdown menu, instead close the dropdown menu
|
||||
(accept current selection).
|
||||
|
||||
"""
|
||||
_logger.debug('Detected enter key.')
|
||||
|
||||
event.current_buffer.complete_state = None
|
||||
b = event.app.current_buffer
|
||||
b.complete_state = None
|
||||
|
||||
@kb.add('escape', 'enter')
|
||||
def _(event):
|
||||
"""Introduces a line break regardless of multi-line mode or not."""
|
||||
_logger.debug('Detected alt-enter key.')
|
||||
event.app.current_buffer.insert_text('\n')
|
||||
|
||||
return kb
|
12
mycli/lexer.py
Normal file
12
mycli/lexer.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from pygments.lexer import inherit
|
||||
from pygments.lexers.sql import MySqlLexer
|
||||
from pygments.token import Keyword
|
||||
|
||||
|
||||
class MyCliLexer(MySqlLexer):
|
||||
"""Extends MySQL lexer to add keywords."""
|
||||
|
||||
tokens = {
|
||||
'root': [(r'\brepair\b', Keyword),
|
||||
(r'\boffset\b', Keyword), inherit],
|
||||
}
|
54
mycli/magic.py
Normal file
54
mycli/magic.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from .main import MyCli
|
||||
import sql.parse
|
||||
import sql.connection
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def load_ipython_extension(ipython):
|
||||
|
||||
# This is called via the ipython command '%load_ext mycli.magic'.
|
||||
|
||||
# First, load the sql magic if it isn't already loaded.
|
||||
if not ipython.find_line_magic('sql'):
|
||||
ipython.run_line_magic('load_ext', 'sql')
|
||||
|
||||
# Register our own magic.
|
||||
ipython.register_magic_function(mycli_line_magic, 'line', 'mycli')
|
||||
|
||||
def mycli_line_magic(line):
|
||||
_logger.debug('mycli magic called: %r', line)
|
||||
parsed = sql.parse.parse(line, {})
|
||||
conn = sql.connection.Connection(parsed['connection'])
|
||||
|
||||
try:
|
||||
# A corresponding mycli object already exists
|
||||
mycli = conn._mycli
|
||||
_logger.debug('Reusing existing mycli')
|
||||
except AttributeError:
|
||||
mycli = MyCli()
|
||||
u = conn.session.engine.url
|
||||
_logger.debug('New mycli: %r', str(u))
|
||||
|
||||
mycli.connect(u.database, u.host, u.username, u.port, u.password)
|
||||
conn._mycli = mycli
|
||||
|
||||
# For convenience, print the connection alias
|
||||
print('Connected: {}'.format(conn.name))
|
||||
|
||||
try:
|
||||
mycli.run_cli()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
if not mycli.query_history:
|
||||
return
|
||||
|
||||
q = mycli.query_history[-1]
|
||||
if q.mutating:
|
||||
_logger.debug('Mutating query detected -- ignoring')
|
||||
return
|
||||
|
||||
if q.successful:
|
||||
ipython = get_ipython()
|
||||
return ipython.run_cell_magic('sql', line, q.query)
|
1370
mycli/main.py
Executable file
1370
mycli/main.py
Executable file
File diff suppressed because it is too large
Load diff
153
mycli/myclirc
Normal file
153
mycli/myclirc
Normal file
|
@ -0,0 +1,153 @@
|
|||
# vi: ft=dosini
|
||||
[main]
|
||||
|
||||
# Enables context sensitive auto-completion. If this is disabled the all
|
||||
# possible completions will be listed.
|
||||
smart_completion = True
|
||||
|
||||
# Multi-line mode allows breaking up the sql statements into multiple lines. If
|
||||
# this is set to True, then the end of the statements must have a semi-colon.
|
||||
# If this is set to False then sql statements can't be split into multiple
|
||||
# lines. End of line (return) is considered as the end of the statement.
|
||||
multi_line = False
|
||||
|
||||
# Destructive warning mode will alert you before executing a sql statement
|
||||
# that may cause harm to the database such as "drop table", "drop database"
|
||||
# or "shutdown".
|
||||
destructive_warning = True
|
||||
|
||||
# log_file location.
|
||||
log_file = ~/.mycli.log
|
||||
|
||||
# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
|
||||
# and "DEBUG". "NONE" disables logging.
|
||||
log_level = INFO
|
||||
|
||||
# Log every query and its results to a file. Enable this by uncommenting the
|
||||
# line below.
|
||||
# audit_log = ~/.mycli-audit.log
|
||||
|
||||
# Timing of sql statments and table rendering.
|
||||
timing = True
|
||||
|
||||
# Table format. Possible values: ascii, double, github,
|
||||
# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html,
|
||||
# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv.
|
||||
# Recommended: ascii
|
||||
table_format = ascii
|
||||
|
||||
# Syntax coloring style. Possible values (many support the "-dark" suffix):
|
||||
# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
|
||||
# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
|
||||
# fruity.
|
||||
# Screenshots at http://mycli.net/syntax
|
||||
# Can be further modified in [colors]
|
||||
syntax_style = default
|
||||
|
||||
# Keybindings: Possible values: emacs, vi.
|
||||
# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
|
||||
# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
|
||||
key_bindings = emacs
|
||||
|
||||
# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
|
||||
wider_completion_menu = False
|
||||
|
||||
# MySQL prompt
|
||||
# \D - The full current date
|
||||
# \d - Database name
|
||||
# \h - Hostname of the server
|
||||
# \m - Minutes of the current time
|
||||
# \n - Newline
|
||||
# \P - AM/PM
|
||||
# \p - Port
|
||||
# \R - The current time, in 24-hour military time (0–23)
|
||||
# \r - The current time, standard 12-hour time (1–12)
|
||||
# \s - Seconds of the current time
|
||||
# \t - Product type (Percona, MySQL, MariaDB)
|
||||
# \A - DSN alias name (from the [alias_dsn] section)
|
||||
# \u - Username
|
||||
# \x1b[...m - insert ANSI escape sequence
|
||||
prompt = '\t \u@\h:\d> '
|
||||
prompt_continuation = '->'
|
||||
|
||||
# Skip intro info on startup and outro info on exit
|
||||
less_chatty = False
|
||||
|
||||
# Use alias from --login-path instead of host name in prompt
|
||||
login_path_as_host = False
|
||||
|
||||
# Cause result sets to be displayed vertically if they are too wide for the current window,
|
||||
# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
|
||||
auto_vertical_output = False
|
||||
|
||||
# keyword casing preference. Possible values "lower", "upper", "auto"
|
||||
keyword_casing = auto
|
||||
|
||||
# disabled pager on startup
|
||||
enable_pager = True
|
||||
|
||||
# Custom colors for the completion menu, toolbar, etc.
|
||||
[colors]
|
||||
completion-menu.completion.current = 'bg:#ffffff #000000'
|
||||
completion-menu.completion = 'bg:#008888 #ffffff'
|
||||
completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
|
||||
completion-menu.meta.completion = 'bg:#448888 #ffffff'
|
||||
completion-menu.multi-column-meta = 'bg:#aaffff #000000'
|
||||
scrollbar.arrow = 'bg:#003333'
|
||||
scrollbar = 'bg:#00aaaa'
|
||||
selected = '#ffffff bg:#6666aa'
|
||||
search = '#ffffff bg:#4444aa'
|
||||
search.current = '#ffffff bg:#44aa44'
|
||||
bottom-toolbar = 'bg:#222222 #aaaaaa'
|
||||
bottom-toolbar.off = 'bg:#222222 #888888'
|
||||
bottom-toolbar.on = 'bg:#222222 #ffffff'
|
||||
search-toolbar = 'noinherit bold'
|
||||
search-toolbar.text = 'nobold'
|
||||
system-toolbar = 'noinherit bold'
|
||||
arg-toolbar = 'noinherit bold'
|
||||
arg-toolbar.text = 'nobold'
|
||||
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
|
||||
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
|
||||
|
||||
# style classes for colored table output
|
||||
output.header = "#00ff5f bold"
|
||||
output.odd-row = ""
|
||||
output.even-row = ""
|
||||
output.null = "#808080"
|
||||
|
||||
# SQL syntax highlighting overrides
|
||||
# sql.comment = 'italic #408080'
|
||||
# sql.comment.multi-line = ''
|
||||
# sql.comment.single-line = ''
|
||||
# sql.comment.optimizer-hint = ''
|
||||
# sql.escape = 'border:#FF0000'
|
||||
# sql.keyword = 'bold #008000'
|
||||
# sql.datatype = 'nobold #B00040'
|
||||
# sql.literal = ''
|
||||
# sql.literal.date = ''
|
||||
# sql.symbol = ''
|
||||
# sql.quoted-schema-object = ''
|
||||
# sql.quoted-schema-object.escape = ''
|
||||
# sql.constant = '#880000'
|
||||
# sql.function = '#0000FF'
|
||||
# sql.variable = '#19177C'
|
||||
# sql.number = '#666666'
|
||||
# sql.number.binary = ''
|
||||
# sql.number.float = ''
|
||||
# sql.number.hex = ''
|
||||
# sql.number.integer = ''
|
||||
# sql.operator = '#666666'
|
||||
# sql.punctuation = ''
|
||||
# sql.string = '#BA2121'
|
||||
# sql.string.double-quouted = ''
|
||||
# sql.string.escape = 'bold #BB6622'
|
||||
# sql.string.single-quoted = ''
|
||||
# sql.whitespace = ''
|
||||
|
||||
# Favorite queries.
|
||||
[favorite_queries]
|
||||
|
||||
# Use the -d option to reference a DSN.
|
||||
# Special characters in passwords and other strings can be escaped with URL encoding.
|
||||
[alias_dsn]
|
||||
# example_dsn = mysql://[user[:password]@][host][:port][/dbname]
|
0
mycli/packages/__init__.py
Normal file
0
mycli/packages/__init__.py
Normal file
294
mycli/packages/completion_engine.py
Normal file
294
mycli/packages/completion_engine.py
Normal file
|
@ -0,0 +1,294 @@
|
|||
import os
|
||||
import sys
|
||||
import sqlparse
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils import last_word, extract_tables, find_prev_keyword
|
||||
from .special import parse_special_command
|
||||
|
||||
|
||||
def suggest_type(full_text, text_before_cursor):
|
||||
"""Takes the full_text that is typed so far and also the text before the
|
||||
cursor to suggest completion type and scope.
|
||||
|
||||
Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
|
||||
A scope for a column category will be a list of tables.
|
||||
"""
|
||||
|
||||
word_before_cursor = last_word(text_before_cursor,
|
||||
include='many_punctuations')
|
||||
|
||||
identifier = None
|
||||
|
||||
# here should be removed once sqlparse has been fixed
|
||||
try:
|
||||
# If we've partially typed a word then word_before_cursor won't be an empty
|
||||
# string. In that case we want to remove the partially typed string before
|
||||
# sending it to the sqlparser. Otherwise the last token will always be the
|
||||
# partially typed string which renders the smart completion useless because
|
||||
# it will always return the list of keywords as completion.
|
||||
if word_before_cursor:
|
||||
if word_before_cursor.endswith(
|
||||
'(') or word_before_cursor.startswith('\\'):
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
else:
|
||||
parsed = sqlparse.parse(
|
||||
text_before_cursor[:-len(word_before_cursor)])
|
||||
|
||||
# word_before_cursor may include a schema qualification, like
|
||||
# "schema_name.partial_name" or "schema_name.", so parse it
|
||||
# separately
|
||||
p = sqlparse.parse(word_before_cursor)[0]
|
||||
|
||||
if p.tokens and isinstance(p.tokens[0], Identifier):
|
||||
identifier = p.tokens[0]
|
||||
else:
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
except (TypeError, AttributeError):
|
||||
return [{'type': 'keyword'}]
|
||||
|
||||
if len(parsed) > 1:
|
||||
# Multiple statements being edited -- isolate the current one by
|
||||
# cumulatively summing statement lengths to find the one that bounds the
|
||||
# current position
|
||||
current_pos = len(text_before_cursor)
|
||||
stmt_start, stmt_end = 0, 0
|
||||
|
||||
for statement in parsed:
|
||||
stmt_len = len(str(statement))
|
||||
stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
|
||||
|
||||
if stmt_end >= current_pos:
|
||||
text_before_cursor = full_text[stmt_start:current_pos]
|
||||
full_text = full_text[stmt_start:]
|
||||
break
|
||||
|
||||
elif parsed:
|
||||
# A single statement
|
||||
statement = parsed[0]
|
||||
else:
|
||||
# The empty string
|
||||
statement = None
|
||||
|
||||
# Check for special commands and handle those separately
|
||||
if statement:
|
||||
# Be careful here because trivial whitespace is parsed as a statement,
|
||||
# but the statement won't have a first token
|
||||
tok1 = statement.token_first()
|
||||
if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
|
||||
return suggest_special(text_before_cursor)
|
||||
|
||||
last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''
|
||||
|
||||
return suggest_based_on_last_token(last_token, text_before_cursor,
|
||||
full_text, identifier)
|
||||
|
||||
|
||||
def suggest_special(text):
|
||||
text = text.lstrip()
|
||||
cmd, _, arg = parse_special_command(text)
|
||||
|
||||
if cmd == text:
|
||||
# Trying to complete the special command itself
|
||||
return [{'type': 'special'}]
|
||||
|
||||
if cmd in ('\\u', '\\r'):
|
||||
return [{'type': 'database'}]
|
||||
|
||||
if cmd in ('\\T'):
|
||||
return [{'type': 'table_format'}]
|
||||
|
||||
if cmd in ['\\f', '\\fs', '\\fd']:
|
||||
return [{'type': 'favoritequery'}]
|
||||
|
||||
if cmd in ['\\dt', '\\dt+']:
|
||||
return [
|
||||
{'type': 'table', 'schema': []},
|
||||
{'type': 'view', 'schema': []},
|
||||
{'type': 'schema'},
|
||||
]
|
||||
elif cmd in ['\\.', 'source']:
|
||||
return[{'type': 'file_name'}]
|
||||
|
||||
return [{'type': 'keyword'}, {'type': 'special'}]
|
||||
|
||||
|
||||
def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
|
||||
if isinstance(token, str):
|
||||
token_v = token.lower()
|
||||
elif isinstance(token, Comparison):
|
||||
# If 'token' is a Comparison type such as
|
||||
# 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
|
||||
# token.value on the comparison type will only return the lhs of the
|
||||
# comparison. In this case a.id. So we need to do token.tokens to get
|
||||
# both sides of the comparison and pick the last token out of that
|
||||
# list.
|
||||
token_v = token.tokens[-1].value.lower()
|
||||
elif isinstance(token, Where):
|
||||
# sqlparse groups all tokens from the where clause into a single token
|
||||
# list. This means that token.value may be something like
|
||||
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
|
||||
# suggestions in complicated where clauses correctly
|
||||
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
|
||||
return suggest_based_on_last_token(prev_keyword, text_before_cursor,
|
||||
full_text, identifier)
|
||||
else:
|
||||
token_v = token.value.lower()
|
||||
|
||||
is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])
|
||||
|
||||
if not token:
|
||||
return [{'type': 'keyword'}, {'type': 'special'}]
|
||||
elif token_v.endswith('('):
|
||||
p = sqlparse.parse(text_before_cursor)[0]
|
||||
|
||||
if p.tokens and isinstance(p.tokens[-1], Where):
|
||||
# Four possibilities:
|
||||
# 1 - Parenthesized clause like "WHERE foo AND ("
|
||||
# Suggest columns/functions
|
||||
# 2 - Function call like "WHERE foo("
|
||||
# Suggest columns/functions
|
||||
# 3 - Subquery expression like "WHERE EXISTS ("
|
||||
# Suggest keywords, in order to do a subquery
|
||||
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
|
||||
# Suggest columns/functions AND keywords. (If we wanted to be
|
||||
# really fancy, we could suggest only array-typed columns)
|
||||
|
||||
column_suggestions = suggest_based_on_last_token('where',
|
||||
text_before_cursor, full_text, identifier)
|
||||
|
||||
# Check for a subquery expression (cases 3 & 4)
|
||||
where = p.tokens[-1]
|
||||
idx, prev_tok = where.token_prev(len(where.tokens) - 1)
|
||||
|
||||
if isinstance(prev_tok, Comparison):
|
||||
# e.g. "SELECT foo FROM bar WHERE foo = ANY("
|
||||
prev_tok = prev_tok.tokens[-1]
|
||||
|
||||
prev_tok = prev_tok.value.lower()
|
||||
if prev_tok == 'exists':
|
||||
return [{'type': 'keyword'}]
|
||||
else:
|
||||
return column_suggestions
|
||||
|
||||
# Get the token before the parens
|
||||
idx, prev_tok = p.token_prev(len(p.tokens) - 1)
|
||||
if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
|
||||
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
||||
tables = extract_tables(full_text)
|
||||
|
||||
# suggest columns that are present in more than one table
|
||||
return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
|
||||
elif p.token_first().value.lower() == 'select':
|
||||
# If the lparen is preceeded by a space chances are we're about to
|
||||
# do a sub-select.
|
||||
if last_word(text_before_cursor,
|
||||
'all_punctuations').startswith('('):
|
||||
return [{'type': 'keyword'}]
|
||||
elif p.token_first().value.lower() == 'show':
|
||||
return [{'type': 'show'}]
|
||||
|
||||
# We're probably in a function argument list
|
||||
return [{'type': 'column', 'tables': extract_tables(full_text)}]
|
||||
elif token_v in ('set', 'order by', 'distinct'):
|
||||
return [{'type': 'column', 'tables': extract_tables(full_text)}]
|
||||
elif token_v == 'as':
|
||||
# Don't suggest anything for an alias
|
||||
return []
|
||||
elif token_v in ('show'):
|
||||
return [{'type': 'show'}]
|
||||
elif token_v in ('to',):
|
||||
p = sqlparse.parse(text_before_cursor)[0]
|
||||
if p.token_first().value.lower() == 'change':
|
||||
return [{'type': 'change'}]
|
||||
else:
|
||||
return [{'type': 'user'}]
|
||||
elif token_v in ('user', 'for'):
|
||||
return [{'type': 'user'}]
|
||||
elif token_v in ('select', 'where', 'having'):
|
||||
# Check for a table alias or schema qualification
|
||||
parent = (identifier and identifier.get_parent_name()) or []
|
||||
|
||||
tables = extract_tables(full_text)
|
||||
if parent:
|
||||
tables = [t for t in tables if identifies(parent, *t)]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'table', 'schema': parent},
|
||||
{'type': 'view', 'schema': parent},
|
||||
{'type': 'function', 'schema': parent}]
|
||||
else:
|
||||
aliases = [alias or table for (schema, table, alias) in tables]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'function', 'schema': []},
|
||||
{'type': 'alias', 'aliases': aliases},
|
||||
{'type': 'keyword'}]
|
||||
elif (token_v.endswith('join') and token.is_keyword) or (token_v in
|
||||
('copy', 'from', 'update', 'into', 'describe', 'truncate',
|
||||
'desc', 'explain')):
|
||||
schema = (identifier and identifier.get_parent_name()) or []
|
||||
|
||||
# Suggest tables from either the currently-selected schema or the
|
||||
# public schema if no schema has been specified
|
||||
suggest = [{'type': 'table', 'schema': schema}]
|
||||
|
||||
if not schema:
|
||||
# Suggest schemas
|
||||
suggest.insert(0, {'type': 'schema'})
|
||||
|
||||
# Only tables can be TRUNCATED, otherwise suggest views
|
||||
if token_v != 'truncate':
|
||||
suggest.append({'type': 'view', 'schema': schema})
|
||||
|
||||
return suggest
|
||||
|
||||
elif token_v in ('table', 'view', 'function'):
|
||||
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
|
||||
rel_type = token_v
|
||||
schema = (identifier and identifier.get_parent_name()) or []
|
||||
if schema:
|
||||
return [{'type': rel_type, 'schema': schema}]
|
||||
else:
|
||||
return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
|
||||
elif token_v == 'on':
|
||||
tables = extract_tables(full_text) # [(schema, table, alias), ...]
|
||||
parent = (identifier and identifier.get_parent_name()) or []
|
||||
if parent:
|
||||
# "ON parent.<suggestion>"
|
||||
# parent can be either a schema name or table alias
|
||||
tables = [t for t in tables if identifies(parent, *t)]
|
||||
return [{'type': 'column', 'tables': tables},
|
||||
{'type': 'table', 'schema': parent},
|
||||
{'type': 'view', 'schema': parent},
|
||||
{'type': 'function', 'schema': parent}]
|
||||
else:
|
||||
# ON <suggestion>
|
||||
# Use table alias if there is one, otherwise the table name
|
||||
aliases = [alias or table for (schema, table, alias) in tables]
|
||||
suggest = [{'type': 'alias', 'aliases': aliases}]
|
||||
|
||||
# The lists of 'aliases' could be empty if we're trying to complete
|
||||
# a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
|
||||
# In that case we just suggest all tables.
|
||||
if not aliases:
|
||||
suggest.append({'type': 'table', 'schema': parent})
|
||||
return suggest
|
||||
|
||||
elif token_v in ('use', 'database', 'template', 'connect'):
|
||||
# "\c <db", "use <db>", "DROP DATABASE <db>",
|
||||
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
||||
return [{'type': 'database'}]
|
||||
elif token_v == 'tableformat':
|
||||
return [{'type': 'table_format'}]
|
||||
elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
|
||||
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
|
||||
if prev_keyword:
|
||||
return suggest_based_on_last_token(
|
||||
prev_keyword, text_before_cursor, full_text, identifier)
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
return [{'type': 'keyword'}]
|
||||
|
||||
|
||||
def identifies(id, schema, table, alias):
|
||||
return id == alias or id == table or (
|
||||
schema and (id == schema + '.' + table))
|
106
mycli/packages/filepaths.py
Normal file
106
mycli/packages/filepaths.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import os
|
||||
import platform
|
||||
|
||||
|
||||
if os.name == "posix":
|
||||
if platform.system() == "Darwin":
|
||||
DEFAULT_SOCKET_DIRS = ("/tmp",)
|
||||
else:
|
||||
DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib")
|
||||
else:
|
||||
DEFAULT_SOCKET_DIRS = ()
|
||||
|
||||
|
||||
def list_path(root_dir):
|
||||
"""List directory if exists.
|
||||
|
||||
:param root_dir: str
|
||||
:return: list
|
||||
|
||||
"""
|
||||
res = []
|
||||
if os.path.isdir(root_dir):
|
||||
for name in os.listdir(root_dir):
|
||||
res.append(name)
|
||||
return res
|
||||
|
||||
|
||||
def complete_path(curr_dir, last_dir):
|
||||
"""Return the path to complete that matches the last entered component.
|
||||
|
||||
If the last entered component is ~, expanded path would not
|
||||
match, so return all of the available paths.
|
||||
|
||||
:param curr_dir: str
|
||||
:param last_dir: str
|
||||
:return: str
|
||||
|
||||
"""
|
||||
if not last_dir or curr_dir.startswith(last_dir):
|
||||
return curr_dir
|
||||
elif last_dir == '~':
|
||||
return os.path.join(last_dir, curr_dir)
|
||||
|
||||
|
||||
def parse_path(root_dir):
|
||||
"""Split path into head and last component for the completer.
|
||||
|
||||
Also return position where last component starts.
|
||||
|
||||
:param root_dir: str path
|
||||
:return: tuple of (string, string, int)
|
||||
|
||||
"""
|
||||
base_dir, last_dir, position = '', '', 0
|
||||
if root_dir:
|
||||
base_dir, last_dir = os.path.split(root_dir)
|
||||
position = -len(last_dir) if last_dir else 0
|
||||
return base_dir, last_dir, position
|
||||
|
||||
|
||||
def suggest_path(root_dir):
|
||||
"""List all files and subdirectories in a directory.
|
||||
|
||||
If the directory is not specified, suggest root directory,
|
||||
user directory, current and parent directory.
|
||||
|
||||
:param root_dir: string: directory to list
|
||||
:return: list
|
||||
|
||||
"""
|
||||
if not root_dir:
|
||||
return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]
|
||||
|
||||
if '~' in root_dir:
|
||||
root_dir = os.path.expanduser(root_dir)
|
||||
|
||||
if not os.path.exists(root_dir):
|
||||
root_dir, _ = os.path.split(root_dir)
|
||||
|
||||
return list_path(root_dir)
|
||||
|
||||
|
||||
def dir_path_exists(path):
|
||||
"""Check if the directory path exists for a given file.
|
||||
|
||||
For example, for a file /home/user/.cache/mycli/log, check if
|
||||
/home/user/.cache/mycli exists.
|
||||
|
||||
:param str path: The file path.
|
||||
:return: Whether or not the directory path exists.
|
||||
|
||||
"""
|
||||
return os.path.exists(os.path.dirname(path))
|
||||
|
||||
|
||||
def guess_socket_location():
|
||||
"""Try to guess the location of the default mysql socket file."""
|
||||
socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS)
|
||||
for directory in socket_dirs:
|
||||
for r, dirs, files in os.walk(directory, topdown=True):
|
||||
for filename in files:
|
||||
name, ext = os.path.splitext(filename)
|
||||
if name.startswith("mysql") and ext in ('.socket', '.sock'):
|
||||
return os.path.join(r, filename)
|
||||
dirs[:] = [d for d in dirs if d.startswith("mysql")]
|
||||
return None
|
28
mycli/packages/paramiko_stub/__init__.py
Normal file
28
mycli/packages/paramiko_stub/__init__.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
"""A module to import instead of paramiko when it is not available (to avoid
|
||||
checking for paramiko all over the place).
|
||||
|
||||
When paramiko is first envoked, it simply shuts down mycli, telling
|
||||
user they either have to install paramiko or should not use SSH
|
||||
features.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Paramiko:
|
||||
def __getattr__(self, name):
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
print(dedent("""
|
||||
To enable certain SSH features you need to install paramiko:
|
||||
|
||||
pip install paramiko
|
||||
|
||||
It is required for the following configuration options:
|
||||
--list-ssh-config
|
||||
--ssh-config-host
|
||||
--ssh-host
|
||||
"""))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
paramiko = Paramiko()
|
267
mycli/packages/parseutils.py
Normal file
267
mycli/packages/parseutils.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
import re
|
||||
import sqlparse
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||
|
||||
cleanup_regex = {
|
||||
# This matches only alphanumerics and underscores.
|
||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
||||
# This matches everything except spaces, parens, colon, and comma
|
||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
||||
# This matches everything except spaces, parens, colon, comma, and period
|
||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
}
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
r"""
|
||||
Find the last word in a sentence.
|
||||
|
||||
>>> last_word('abc')
|
||||
'abc'
|
||||
>>> last_word(' abc')
|
||||
'abc'
|
||||
>>> last_word('')
|
||||
''
|
||||
>>> last_word(' ')
|
||||
''
|
||||
>>> last_word('abc ')
|
||||
''
|
||||
>>> last_word('abc def')
|
||||
'def'
|
||||
>>> last_word('abc def ')
|
||||
''
|
||||
>>> last_word('abc def;')
|
||||
''
|
||||
>>> last_word('bac $def')
|
||||
'def'
|
||||
>>> last_word('bac $def', include='most_punctuations')
|
||||
'$def'
|
||||
>>> last_word('bac \def', include='most_punctuations')
|
||||
'\\\\def'
|
||||
>>> last_word('bac \def;', include='most_punctuations')
|
||||
'\\\\def;'
|
||||
>>> last_word('bac::def', include='most_punctuations')
|
||||
'def'
|
||||
"""
|
||||
|
||||
if not text: # Empty string
|
||||
return ''
|
||||
|
||||
if text[-1].isspace():
|
||||
return ''
|
||||
else:
|
||||
regex = cleanup_regex[include]
|
||||
matches = regex.search(text)
|
||||
if matches:
|
||||
return matches.group(0)
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
# This code is borrowed from sqlparse example script.
|
||||
# <url>
|
||||
def is_subselect(parsed):
|
||||
if not parsed.is_group:
|
||||
return False
|
||||
for item in parsed.tokens:
|
||||
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
|
||||
'UPDATE', 'CREATE', 'DELETE'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_from_part(parsed, stop_at_punctuation=True):
|
||||
tbl_prefix_seen = False
|
||||
for item in parsed.tokens:
|
||||
if tbl_prefix_seen:
|
||||
if is_subselect(item):
|
||||
for x in extract_from_part(item, stop_at_punctuation):
|
||||
yield x
|
||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
||||
return
|
||||
# An incomplete nested select won't be recognized correctly as a
|
||||
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
|
||||
# the second FROM to trigger this elif condition resulting in a
|
||||
# StopIteration. So we need to ignore the keyword if the keyword
|
||||
# FROM.
|
||||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
||||
# condition. So we need to ignore the keyword JOIN and its variants
|
||||
# INNER JOIN, FULL OUTER JOIN, etc.
|
||||
elif item.ttype is Keyword and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper().endswith('JOIN')):
|
||||
return
|
||||
else:
|
||||
yield item
|
||||
elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
|
||||
item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
|
||||
tbl_prefix_seen = True
|
||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||
# So this check here is necessary.
|
||||
elif isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
if (identifier.ttype is Keyword and
|
||||
identifier.value.upper() == 'FROM'):
|
||||
tbl_prefix_seen = True
|
||||
break
|
||||
|
||||
def extract_table_identifiers(token_stream):
|
||||
"""yields tuples of (schema_name, table_name, table_alias)"""
|
||||
|
||||
for item in token_stream:
|
||||
if isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
# Sometimes Keywords (such as FROM ) are classified as
|
||||
# identifiers which don't have the get_real_name() method.
|
||||
try:
|
||||
schema_name = identifier.get_parent_name()
|
||||
real_name = identifier.get_real_name()
|
||||
except AttributeError:
|
||||
continue
|
||||
if real_name:
|
||||
yield (schema_name, real_name, identifier.get_alias())
|
||||
elif isinstance(item, Identifier):
|
||||
real_name = item.get_real_name()
|
||||
schema_name = item.get_parent_name()
|
||||
|
||||
if real_name:
|
||||
yield (schema_name, real_name, item.get_alias())
|
||||
else:
|
||||
name = item.get_name()
|
||||
yield (None, name, item.get_alias() or name)
|
||||
elif isinstance(item, Function):
|
||||
yield (None, item.get_name(), item.get_name())
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
def extract_tables(sql):
|
||||
"""Extract the table names from an SQL statment.
|
||||
|
||||
Returns a list of (schema, table, alias) tuples
|
||||
|
||||
"""
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
return []
|
||||
|
||||
# INSERT statements must stop looking for tables at the sign of first
|
||||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
||||
# abc is the table name, but if we don't stop at the first lparen, then
|
||||
# we'll identify abc, col1 and col2 as table names.
|
||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
||||
return list(extract_table_identifiers(stream))
|
||||
|
||||
def find_prev_keyword(sql):
|
||||
""" Find the last sql keyword in an SQL statement
|
||||
|
||||
Returns the value of the last keyword, and the text of the query with
|
||||
everything after the last keyword stripped
|
||||
"""
|
||||
if not sql.strip():
|
||||
return None, ''
|
||||
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
flattened = list(parsed.flatten())
|
||||
|
||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
||||
|
||||
for t in reversed(flattened):
|
||||
if t.value == '(' or (t.is_keyword and (
|
||||
t.value.upper() not in logical_operators)):
|
||||
# Find the location of token t in the original parsed statement
|
||||
# We can't use parsed.token_index(t) because t may be a child token
|
||||
# inside a TokenList, in which case token_index thows an error
|
||||
# Minimal example:
|
||||
# p = sqlparse.parse('select * from foo where bar')
|
||||
# t = list(p.flatten())[-3] # The "Where" token
|
||||
# p.token_index(t) # Throws ValueError: not in list
|
||||
idx = flattened.index(t)
|
||||
|
||||
# Combine the string values of all tokens in the original list
|
||||
# up to and including the target keyword token t, to produce a
|
||||
# query string with everything after the keyword token removed
|
||||
text = ''.join(tok.value for tok in flattened[:idx+1])
|
||||
return t, text
|
||||
|
||||
return None, ''
|
||||
|
||||
|
||||
def query_starts_with(query, prefixes):
|
||||
"""Check if the query starts with any item from *prefixes*."""
|
||||
prefixes = [prefix.lower() for prefix in prefixes]
|
||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
|
||||
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||
|
||||
|
||||
def queries_start_with(queries, prefixes):
|
||||
"""Check if any queries start with any item from *prefixes*."""
|
||||
for query in sqlparse.split(queries):
|
||||
if query and query_starts_with(query, prefixes) is True:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def query_has_where_clause(query):
|
||||
"""Check if the query contains a where-clause."""
|
||||
return any(
|
||||
isinstance(token, sqlparse.sql.Where)
|
||||
for token_list in sqlparse.parse(query)
|
||||
for token in token_list
|
||||
)
|
||||
|
||||
|
||||
def is_destructive(queries):
|
||||
"""Returns if any of the queries in *queries* is destructive."""
|
||||
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
||||
for query in sqlparse.split(queries):
|
||||
if query:
|
||||
if query_starts_with(query, keywords) is True:
|
||||
return True
|
||||
elif query_starts_with(
|
||||
query, ['update']
|
||||
) is True and not query_has_where_clause(query):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
"""Returns true if the query contains an unclosed quote."""
|
||||
|
||||
# parsed can contain one or more semi-colon separated commands
|
||||
parsed = sqlparse.parse(sql)
|
||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sql = 'select * from (select t. from tabl t'
|
||||
print (extract_tables(sql))
|
||||
|
||||
|
||||
def is_dropping_database(queries, dbname):
|
||||
"""Determine if the query is dropping a specific database."""
|
||||
result = False
|
||||
if dbname is None:
|
||||
return False
|
||||
|
||||
def normalize_db_name(db):
|
||||
return db.lower().strip('`"')
|
||||
|
||||
dbname = normalize_db_name(dbname)
|
||||
|
||||
for query in sqlparse.parse(queries):
|
||||
keywords = [t for t in query.tokens if t.is_keyword]
|
||||
if len(keywords) < 2:
|
||||
continue
|
||||
if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in (
|
||||
"database",
|
||||
"schema",
|
||||
):
|
||||
database_token = next(
|
||||
(t for t in query.tokens if isinstance(t, Identifier)), None
|
||||
)
|
||||
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
|
||||
result = keywords[0].normalized == "DROP"
|
||||
else:
|
||||
return result
|
54
mycli/packages/prompt_utils.py
Normal file
54
mycli/packages/prompt_utils.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import sys
|
||||
import click
|
||||
from .parseutils import is_destructive
|
||||
|
||||
|
||||
class ConfirmBoolParamType(click.ParamType):
|
||||
name = 'confirmation'
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
if isinstance(value, bool):
|
||||
return bool(value)
|
||||
value = value.lower()
|
||||
if value in ('yes', 'y'):
|
||||
return True
|
||||
elif value in ('no', 'n'):
|
||||
return False
|
||||
self.fail('%s is not a valid boolean' % value, param, ctx)
|
||||
|
||||
def __repr__(self):
|
||||
return 'BOOL'
|
||||
|
||||
|
||||
BOOLEAN_TYPE = ConfirmBoolParamType()
|
||||
|
||||
|
||||
def confirm_destructive_query(queries):
|
||||
"""Check if the query is destructive and prompts the user to confirm.
|
||||
|
||||
Returns:
|
||||
* None if the query is non-destructive or we can't prompt the user.
|
||||
* True if the query is destructive and the user wants to proceed.
|
||||
* False if the query is destructive and the user doesn't want to proceed.
|
||||
|
||||
"""
|
||||
prompt_text = ("You're about to run a destructive command.\n"
|
||||
"Do you want to proceed? (y/n)")
|
||||
if is_destructive(queries) and sys.stdin.isatty():
|
||||
return prompt(prompt_text, type=BOOLEAN_TYPE)
|
||||
|
||||
|
||||
def confirm(*args, **kwargs):
|
||||
"""Prompt for confirmation (yes/no) and handle any abort exceptions."""
|
||||
try:
|
||||
return click.confirm(*args, **kwargs)
|
||||
except click.Abort:
|
||||
return False
|
||||
|
||||
|
||||
def prompt(*args, **kwargs):
|
||||
"""Prompt the user for input and handle any abort exceptions."""
|
||||
try:
|
||||
return click.prompt(*args, **kwargs)
|
||||
except click.Abort:
|
||||
return False
|
10
mycli/packages/special/__init__.py
Normal file
10
mycli/packages/special/__init__.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
__all__ = []
|
||||
|
||||
def export(defn):
|
||||
"""Decorator to explicitly mark functions that are exposed in a lib."""
|
||||
globals()[defn.__name__] = defn
|
||||
__all__.append(defn.__name__)
|
||||
return defn
|
||||
|
||||
from . import dbcommands
|
||||
from . import iocommands
|
157
mycli/packages/special/dbcommands.py
Normal file
157
mycli/packages/special/dbcommands.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
import logging
|
||||
import os
|
||||
import platform
|
||||
from mycli import __version__
|
||||
from mycli.packages.special import iocommands
|
||||
from mycli.packages.special.utils import format_uptime
|
||||
from .main import special_command, RAW_QUERY, PARSED_QUERY
|
||||
from pymysql import ProgrammingError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
|
||||
arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
|
||||
if arg:
|
||||
query = 'SHOW FIELDS FROM {0}'.format(arg)
|
||||
else:
|
||||
query = 'SHOW TABLES'
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
tables = cur.fetchall()
|
||||
status = ''
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
else:
|
||||
return [(None, None, None, '')]
|
||||
|
||||
if verbose and arg:
|
||||
query = 'SHOW CREATE TABLE {0}'.format(arg)
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
status = cur.fetchone()[1]
|
||||
|
||||
return [(None, tables, headers, status)]
|
||||
|
||||
@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
|
||||
def list_databases(cur, **_):
|
||||
query = 'SHOW DATABASES'
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
return [(None, cur, headers, '')]
|
||||
else:
|
||||
return [(None, None, None, '')]
|
||||
|
||||
@special_command('status', '\\s', 'Get status information from the server.',
|
||||
arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
|
||||
def status(cur, **_):
|
||||
query = 'SHOW GLOBAL STATUS;'
|
||||
log.debug(query)
|
||||
try:
|
||||
cur.execute(query)
|
||||
except ProgrammingError:
|
||||
# Fallback in case query fail, as it does with Mysql 4
|
||||
query = 'SHOW STATUS;'
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
status = dict(cur.fetchall())
|
||||
|
||||
query = 'SHOW GLOBAL VARIABLES;'
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
variables = dict(cur.fetchall())
|
||||
|
||||
# prepare in case keys are bytes, as with Python 3 and Mysql 4
|
||||
if (isinstance(list(variables)[0], bytes) and
|
||||
isinstance(list(status)[0], bytes)):
|
||||
variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
|
||||
in variables.items()}
|
||||
status = {k.decode('utf-8'): v.decode('utf-8') for k, v
|
||||
in status.items()}
|
||||
|
||||
# Create output buffers.
|
||||
title = []
|
||||
output = []
|
||||
footer = []
|
||||
|
||||
title.append('--------------')
|
||||
|
||||
# Output the mycli client information.
|
||||
implementation = platform.python_implementation()
|
||||
version = platform.python_version()
|
||||
client_info = []
|
||||
client_info.append('mycli {0},'.format(__version__))
|
||||
client_info.append('running on {0} {1}'.format(implementation, version))
|
||||
title.append(' '.join(client_info) + '\n')
|
||||
|
||||
# Build the output that will be displayed as a table.
|
||||
output.append(('Connection id:', cur.connection.thread_id()))
|
||||
|
||||
query = 'SELECT DATABASE(), USER();'
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
db, user = cur.fetchone()
|
||||
if db is None:
|
||||
db = ''
|
||||
|
||||
output.append(('Current database:', db))
|
||||
output.append(('Current user:', user))
|
||||
|
||||
if iocommands.is_pager_enabled():
|
||||
if 'PAGER' in os.environ:
|
||||
pager = os.environ['PAGER']
|
||||
else:
|
||||
pager = 'System default'
|
||||
else:
|
||||
pager = 'stdout'
|
||||
output.append(('Current pager:', pager))
|
||||
|
||||
output.append(('Server version:', '{0} {1}'.format(
|
||||
variables['version'], variables['version_comment'])))
|
||||
output.append(('Protocol version:', variables['protocol_version']))
|
||||
|
||||
if 'unix' in cur.connection.host_info.lower():
|
||||
host_info = cur.connection.host_info
|
||||
else:
|
||||
host_info = '{0} via TCP/IP'.format(cur.connection.host)
|
||||
|
||||
output.append(('Connection:', host_info))
|
||||
|
||||
query = ('SELECT @@character_set_server, @@character_set_database, '
|
||||
'@@character_set_client, @@character_set_connection LIMIT 1;')
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
charset = cur.fetchone()
|
||||
output.append(('Server characterset:', charset[0]))
|
||||
output.append(('Db characterset:', charset[1]))
|
||||
output.append(('Client characterset:', charset[2]))
|
||||
output.append(('Conn. characterset:', charset[3]))
|
||||
|
||||
if 'TCP/IP' in host_info:
|
||||
output.append(('TCP port:', cur.connection.port))
|
||||
else:
|
||||
output.append(('UNIX socket:', variables['socket']))
|
||||
|
||||
output.append(('Uptime:', format_uptime(status['Uptime'])))
|
||||
|
||||
# Print the current server statistics.
|
||||
stats = []
|
||||
stats.append('Connections: {0}'.format(status['Threads_connected']))
|
||||
if 'Queries' in status:
|
||||
stats.append('Queries: {0}'.format(status['Queries']))
|
||||
stats.append('Slow queries: {0}'.format(status['Slow_queries']))
|
||||
stats.append('Opens: {0}'.format(status['Opened_tables']))
|
||||
stats.append('Flush tables: {0}'.format(status['Flush_commands']))
|
||||
stats.append('Open tables: {0}'.format(status['Open_tables']))
|
||||
if 'Queries' in status:
|
||||
queries_per_second = int(status['Queries']) / int(status['Uptime'])
|
||||
stats.append('Queries per second avg: {:.3f}'.format(
|
||||
queries_per_second))
|
||||
stats = ' '.join(stats)
|
||||
footer.append('\n' + stats)
|
||||
|
||||
footer.append('--------------')
|
||||
return [('\n'.join(title), output, '', '\n'.join(footer))]
|
80
mycli/packages/special/delimitercommand.py
Normal file
80
mycli/packages/special/delimitercommand.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import re
|
||||
import sqlparse
|
||||
|
||||
|
||||
class DelimiterCommand(object):
|
||||
def __init__(self):
|
||||
self._delimiter = ';'
|
||||
|
||||
def _split(self, sql):
|
||||
"""Temporary workaround until sqlparse.split() learns about custom
|
||||
delimiters."""
|
||||
|
||||
placeholder = "\ufffc" # unicode object replacement character
|
||||
|
||||
if self._delimiter == ';':
|
||||
return sqlparse.split(sql)
|
||||
|
||||
# We must find a string that original sql does not contain.
|
||||
# Most likely, our placeholder is enough, but if not, keep looking
|
||||
while placeholder in sql:
|
||||
placeholder += placeholder[0]
|
||||
sql = sql.replace(';', placeholder)
|
||||
sql = sql.replace(self._delimiter, ';')
|
||||
|
||||
split = sqlparse.split(sql)
|
||||
|
||||
return [
|
||||
stmt.replace(';', self._delimiter).replace(placeholder, ';')
|
||||
for stmt in split
|
||||
]
|
||||
|
||||
def queries_iter(self, input):
|
||||
"""Iterate over queries in the input string."""
|
||||
|
||||
queries = self._split(input)
|
||||
while queries:
|
||||
for sql in queries:
|
||||
delimiter = self._delimiter
|
||||
sql = queries.pop(0)
|
||||
if sql.endswith(delimiter):
|
||||
trailing_delimiter = True
|
||||
sql = sql.strip(delimiter)
|
||||
else:
|
||||
trailing_delimiter = False
|
||||
|
||||
yield sql
|
||||
|
||||
# if the delimiter was changed by the last command,
|
||||
# re-split everything, and if we previously stripped
|
||||
# the delimiter, append it to the end
|
||||
if self._delimiter != delimiter:
|
||||
combined_statement = ' '.join([sql] + queries)
|
||||
if trailing_delimiter:
|
||||
combined_statement += delimiter
|
||||
queries = self._split(combined_statement)[1:]
|
||||
|
||||
def set(self, arg, **_):
|
||||
"""Change delimiter.
|
||||
|
||||
Since `arg` is everything that follows the DELIMITER token
|
||||
after sqlparse (it may include other statements separated by
|
||||
the new delimiter), we want to set the delimiter to the first
|
||||
word of it.
|
||||
|
||||
"""
|
||||
match = arg and re.search(r'[^\s]+', arg)
|
||||
if not match:
|
||||
message = 'Missing required argument, delimiter'
|
||||
return [(None, None, None, message)]
|
||||
|
||||
delimiter = match.group()
|
||||
if delimiter.lower() == 'delimiter':
|
||||
return [(None, None, None, 'Invalid delimiter "delimiter"')]
|
||||
|
||||
self._delimiter = delimiter
|
||||
return [(None, None, None, "Changed delimiter to {}".format(delimiter))]
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
return self._delimiter
|
63
mycli/packages/special/favoritequeries.py
Normal file
63
mycli/packages/special/favoritequeries.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
class FavoriteQueries(object):
|
||||
|
||||
section_name = 'favorite_queries'
|
||||
|
||||
usage = '''
|
||||
Favorite Queries are a way to save frequently used queries
|
||||
with a short name.
|
||||
Examples:
|
||||
|
||||
# Save a new favorite query.
|
||||
> \\fs simple select * from abc where a is not Null;
|
||||
|
||||
# List all favorite queries.
|
||||
> \\f
|
||||
╒════════╤═══════════════════════════════════════╕
|
||||
│ Name │ Query │
|
||||
╞════════╪═══════════════════════════════════════╡
|
||||
│ simple │ SELECT * FROM abc where a is not NULL │
|
||||
╘════════╧═══════════════════════════════════════╛
|
||||
|
||||
# Run a favorite query.
|
||||
> \\f simple
|
||||
╒════════╤════════╕
|
||||
│ a │ b │
|
||||
╞════════╪════════╡
|
||||
│ 日本語 │ 日本語 │
|
||||
╘════════╧════════╛
|
||||
|
||||
# Delete a favorite query.
|
||||
> \\fd simple
|
||||
simple: Deleted
|
||||
'''
|
||||
|
||||
# Class-level variable, for convenience to use as a singleton.
|
||||
instance = None
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return FavoriteQueries(config)
|
||||
|
||||
def list(self):
|
||||
return self.config.get(self.section_name, [])
|
||||
|
||||
def get(self, name):
|
||||
return self.config.get(self.section_name, {}).get(name, None)
|
||||
|
||||
def save(self, name, query):
|
||||
self.config.encoding = 'utf-8'
|
||||
if self.section_name not in self.config:
|
||||
self.config[self.section_name] = {}
|
||||
self.config[self.section_name][name] = query
|
||||
self.config.write()
|
||||
|
||||
def delete(self, name):
|
||||
try:
|
||||
del self.config[self.section_name][name]
|
||||
except KeyError:
|
||||
return '%s: Not Found.' % name
|
||||
self.config.write()
|
||||
return '%s: Deleted' % name
|
543
mycli/packages/special/iocommands.py
Normal file
543
mycli/packages/special/iocommands.py
Normal file
|
@ -0,0 +1,543 @@
|
|||
import os
|
||||
import re
|
||||
import locale
|
||||
import logging
|
||||
import subprocess
|
||||
import shlex
|
||||
from io import open
|
||||
from time import sleep
|
||||
|
||||
import click
|
||||
import pyperclip
|
||||
import sqlparse
|
||||
|
||||
from . import export
|
||||
from .main import special_command, NO_QUERY, PARSED_QUERY
|
||||
from .favoritequeries import FavoriteQueries
|
||||
from .delimitercommand import DelimiterCommand
|
||||
from .utils import handle_cd_command
|
||||
from mycli.packages.prompt_utils import confirm_destructive_query
|
||||
|
||||
TIMING_ENABLED = False
|
||||
use_expanded_output = False
|
||||
PAGER_ENABLED = True
|
||||
tee_file = None
|
||||
once_file = None
|
||||
written_to_once_file = False
|
||||
pipe_once_process = None
|
||||
written_to_pipe_once_process = False
|
||||
delimiter_command = DelimiterCommand()
|
||||
|
||||
|
||||
@export
|
||||
def set_timing_enabled(val):
|
||||
global TIMING_ENABLED
|
||||
TIMING_ENABLED = val
|
||||
|
||||
@export
|
||||
def set_pager_enabled(val):
|
||||
global PAGER_ENABLED
|
||||
PAGER_ENABLED = val
|
||||
|
||||
|
||||
@export
|
||||
def is_pager_enabled():
|
||||
return PAGER_ENABLED
|
||||
|
||||
@export
|
||||
@special_command('pager', '\\P [command]',
|
||||
'Set PAGER. Print the query results via PAGER.',
|
||||
arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
|
||||
def set_pager(arg, **_):
|
||||
if arg:
|
||||
os.environ['PAGER'] = arg
|
||||
msg = 'PAGER set to %s.' % arg
|
||||
set_pager_enabled(True)
|
||||
else:
|
||||
if 'PAGER' in os.environ:
|
||||
msg = 'PAGER set to %s.' % os.environ['PAGER']
|
||||
else:
|
||||
# This uses click's default per echo_via_pager.
|
||||
msg = 'Pager enabled.'
|
||||
set_pager_enabled(True)
|
||||
|
||||
return [(None, None, None, msg)]
|
||||
|
||||
@export
|
||||
@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
|
||||
arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
|
||||
def disable_pager():
|
||||
set_pager_enabled(False)
|
||||
return [(None, None, None, 'Pager disabled.')]
|
||||
|
||||
@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
|
||||
def toggle_timing():
|
||||
global TIMING_ENABLED
|
||||
TIMING_ENABLED = not TIMING_ENABLED
|
||||
message = "Timing is "
|
||||
message += "on." if TIMING_ENABLED else "off."
|
||||
return [(None, None, None, message)]
|
||||
|
||||
@export
|
||||
def is_timing_enabled():
|
||||
return TIMING_ENABLED
|
||||
|
||||
@export
|
||||
def set_expanded_output(val):
|
||||
global use_expanded_output
|
||||
use_expanded_output = val
|
||||
|
||||
@export
|
||||
def is_expanded_output():
|
||||
return use_expanded_output
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@export
|
||||
def editor_command(command):
|
||||
"""
|
||||
Is this an external editor command?
|
||||
:param command: string
|
||||
"""
|
||||
# It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
|
||||
# for both conditions.
|
||||
return command.strip().endswith('\\e') or command.strip().startswith('\\e')
|
||||
|
||||
@export
|
||||
def get_filename(sql):
|
||||
if sql.strip().startswith('\\e'):
|
||||
command, _, filename = sql.partition(' ')
|
||||
return filename.strip() or None
|
||||
|
||||
|
||||
@export
|
||||
def get_editor_query(sql):
|
||||
"""Get the query part of an editor command."""
|
||||
sql = sql.strip()
|
||||
|
||||
# The reason we can't simply do .strip('\e') is that it strips characters,
|
||||
# not a substring. So it'll strip "e" in the end of the sql also!
|
||||
# Ex: "select * from style\e" -> "select * from styl".
|
||||
pattern = re.compile(r'(^\\e|\\e$)')
|
||||
while pattern.search(sql):
|
||||
sql = pattern.sub('', sql)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
@export
|
||||
def open_external_editor(filename=None, sql=None):
|
||||
"""Open external editor, wait for the user to type in their query, return
|
||||
the query.
|
||||
|
||||
:return: list with one tuple, query as first element.
|
||||
|
||||
"""
|
||||
|
||||
message = None
|
||||
filename = filename.strip().split(' ', 1)[0] if filename else None
|
||||
|
||||
sql = sql or ''
|
||||
MARKER = '# Type your query above this line.\n'
|
||||
|
||||
# Populate the editor buffer with the partial sql (if available) and a
|
||||
# placeholder comment.
|
||||
query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
|
||||
filename=filename, extension='.sql')
|
||||
|
||||
if filename:
|
||||
try:
|
||||
with open(filename) as f:
|
||||
query = f.read()
|
||||
except IOError:
|
||||
message = 'Error reading file: %s.' % filename
|
||||
|
||||
if query is not None:
|
||||
query = query.split(MARKER, 1)[0].rstrip('\n')
|
||||
else:
|
||||
# Don't return None for the caller to deal with.
|
||||
# Empty string is ok.
|
||||
query = sql
|
||||
|
||||
return (query, message)
|
||||
|
||||
|
||||
@export
|
||||
def clip_command(command):
|
||||
"""Is this a clip command?
|
||||
|
||||
:param command: string
|
||||
|
||||
"""
|
||||
# It is possible to have `\clip` or `SELECT * FROM \clip`. So we check
|
||||
# for both conditions.
|
||||
return command.strip().endswith('\\clip') or command.strip().startswith('\\clip')
|
||||
|
||||
|
||||
@export
|
||||
def get_clip_query(sql):
|
||||
"""Get the query part of a clip command."""
|
||||
sql = sql.strip()
|
||||
|
||||
# The reason we can't simply do .strip('\clip') is that it strips characters,
|
||||
# not a substring. So it'll strip "c" in the end of the sql also!
|
||||
pattern = re.compile(r'(^\\clip|\\clip$)')
|
||||
while pattern.search(sql):
|
||||
sql = pattern.sub('', sql)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
@export
|
||||
def copy_query_to_clipboard(sql=None):
|
||||
"""Send query to the clipboard."""
|
||||
|
||||
sql = sql or ''
|
||||
message = None
|
||||
|
||||
try:
|
||||
pyperclip.copy(u'{sql}'.format(sql=sql))
|
||||
except RuntimeError as e:
|
||||
message = 'Error clipping query: %s.' % e.strerror
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
|
||||
def execute_favorite_query(cur, arg, **_):
|
||||
"""Returns (title, rows, headers, status)"""
|
||||
if arg == '':
|
||||
for result in list_favorite_queries():
|
||||
yield result
|
||||
|
||||
"""Parse out favorite name and optional substitution parameters"""
|
||||
name, _, arg_str = arg.partition(' ')
|
||||
args = shlex.split(arg_str)
|
||||
|
||||
query = FavoriteQueries.instance.get(name)
|
||||
if query is None:
|
||||
message = "No favorite query: %s" % (name)
|
||||
yield (None, None, None, message)
|
||||
else:
|
||||
query, arg_error = subst_favorite_query_args(query, args)
|
||||
if arg_error:
|
||||
yield (None, None, None, arg_error)
|
||||
else:
|
||||
for sql in sqlparse.split(query):
|
||||
sql = sql.rstrip(';')
|
||||
title = '> %s' % (sql)
|
||||
cur.execute(sql)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
yield (title, cur, headers, None)
|
||||
else:
|
||||
yield (title, None, None, None)
|
||||
|
||||
def list_favorite_queries():
|
||||
"""List of all favorite queries.
|
||||
Returns (title, rows, headers, status)"""
|
||||
|
||||
headers = ["Name", "Query"]
|
||||
rows = [(r, FavoriteQueries.instance.get(r))
|
||||
for r in FavoriteQueries.instance.list()]
|
||||
|
||||
if not rows:
|
||||
status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
|
||||
else:
|
||||
status = ''
|
||||
return [('', rows, headers, status)]
|
||||
|
||||
|
||||
def subst_favorite_query_args(query, args):
|
||||
"""replace positional parameters ($1...$N) in query."""
|
||||
for idx, val in enumerate(args):
|
||||
subst_var = '$' + str(idx + 1)
|
||||
if subst_var not in query:
|
||||
return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
|
||||
|
||||
query = query.replace(subst_var, val)
|
||||
|
||||
match = re.search(r'\$\d+', query)
|
||||
if match:
|
||||
return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
|
||||
|
||||
return [query, None]
|
||||
|
||||
@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
|
||||
def save_favorite_query(arg, **_):
|
||||
"""Save a new favorite query.
|
||||
Returns (title, rows, headers, status)"""
|
||||
|
||||
usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
name, _, query = arg.partition(' ')
|
||||
|
||||
# If either name or query is missing then print the usage and complain.
|
||||
if (not name) or (not query):
|
||||
return [(None, None, None,
|
||||
usage + 'Err: Both name and query are required.')]
|
||||
|
||||
FavoriteQueries.instance.save(name, query)
|
||||
return [(None, None, None, "Saved.")]
|
||||
|
||||
|
||||
@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
|
||||
def delete_favorite_query(arg, **_):
|
||||
"""Delete an existing favorite query."""
|
||||
usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
status = FavoriteQueries.instance.delete(arg)
|
||||
|
||||
return [(None, None, None, status)]
|
||||
|
||||
|
||||
@special_command('system', 'system [command]',
|
||||
'Execute a system shell commmand.')
|
||||
def execute_system_command(arg, **_):
|
||||
"""Execute a system shell command."""
|
||||
usage = "Syntax: system [command].\n"
|
||||
|
||||
if not arg:
|
||||
return [(None, None, None, usage)]
|
||||
|
||||
try:
|
||||
command = arg.strip()
|
||||
if command.startswith('cd'):
|
||||
ok, error_message = handle_cd_command(arg)
|
||||
if not ok:
|
||||
return [(None, None, None, error_message)]
|
||||
return [(None, None, None, '')]
|
||||
|
||||
args = arg.split(' ')
|
||||
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
response = output if not error else error
|
||||
|
||||
# Python 3 returns bytes. This needs to be decoded to a string.
|
||||
if isinstance(response, bytes):
|
||||
encoding = locale.getpreferredencoding(False)
|
||||
response = response.decode(encoding)
|
||||
|
||||
return [(None, None, None, response)]
|
||||
except OSError as e:
|
||||
return [(None, None, None, 'OSError: %s' % e.strerror)]
|
||||
|
||||
|
||||
def parseargfile(arg):
|
||||
if arg.startswith('-o '):
|
||||
mode = "w"
|
||||
filename = arg[3:]
|
||||
else:
|
||||
mode = 'a'
|
||||
filename = arg
|
||||
|
||||
if not filename:
|
||||
raise TypeError('You must provide a filename.')
|
||||
|
||||
return {'file': os.path.expanduser(filename), 'mode': mode}
|
||||
|
||||
|
||||
@special_command('tee', 'tee [-o] filename',
|
||||
'Append all results to an output file (overwrite using -o).')
|
||||
def set_tee(arg, **_):
|
||||
global tee_file
|
||||
|
||||
try:
|
||||
tee_file = open(**parseargfile(arg))
|
||||
except (IOError, OSError) as e:
|
||||
raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
|
||||
|
||||
return [(None, None, None, "")]
|
||||
|
||||
@export
|
||||
def close_tee():
|
||||
global tee_file
|
||||
if tee_file:
|
||||
tee_file.close()
|
||||
tee_file = None
|
||||
|
||||
|
||||
@special_command('notee', 'notee', 'Stop writing results to an output file.')
|
||||
def no_tee(arg, **_):
|
||||
close_tee()
|
||||
return [(None, None, None, "")]
|
||||
|
||||
@export
|
||||
def write_tee(output):
|
||||
global tee_file
|
||||
if tee_file:
|
||||
click.echo(output, file=tee_file, nl=False)
|
||||
click.echo(u'\n', file=tee_file, nl=False)
|
||||
tee_file.flush()
|
||||
|
||||
|
||||
@special_command('\\once', '\\o [-o] filename',
|
||||
'Append next result to an output file (overwrite using -o).',
|
||||
aliases=('\\o', ))
|
||||
def set_once(arg, **_):
|
||||
global once_file, written_to_once_file
|
||||
|
||||
try:
|
||||
once_file = open(**parseargfile(arg))
|
||||
except (IOError, OSError) as e:
|
||||
raise OSError("Cannot write to file '{}': {}".format(
|
||||
e.filename, e.strerror))
|
||||
written_to_once_file = False
|
||||
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
@export
|
||||
def write_once(output):
|
||||
global once_file, written_to_once_file
|
||||
if output and once_file:
|
||||
click.echo(output, file=once_file, nl=False)
|
||||
click.echo(u"\n", file=once_file, nl=False)
|
||||
once_file.flush()
|
||||
written_to_once_file = True
|
||||
|
||||
|
||||
@export
|
||||
def unset_once_if_written():
|
||||
"""Unset the once file, if it has been written to."""
|
||||
global once_file, written_to_once_file
|
||||
if written_to_once_file and once_file:
|
||||
once_file.close()
|
||||
once_file = None
|
||||
|
||||
|
||||
@special_command('\\pipe_once', '\\| command',
|
||||
'Send next result to a subprocess.',
|
||||
aliases=('\\|', ))
|
||||
def set_pipe_once(arg, **_):
|
||||
global pipe_once_process, written_to_pipe_once_process
|
||||
pipe_once_cmd = shlex.split(arg)
|
||||
if len(pipe_once_cmd) == 0:
|
||||
raise OSError("pipe_once requires a command")
|
||||
written_to_pipe_once_process = False
|
||||
pipe_once_process = subprocess.Popen(pipe_once_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=1,
|
||||
encoding='UTF-8',
|
||||
universal_newlines=True)
|
||||
return [(None, None, None, "")]
|
||||
|
||||
|
||||
@export
|
||||
def write_pipe_once(output):
|
||||
global pipe_once_process, written_to_pipe_once_process
|
||||
if output and pipe_once_process:
|
||||
try:
|
||||
click.echo(output, file=pipe_once_process.stdin, nl=False)
|
||||
click.echo(u"\n", file=pipe_once_process.stdin, nl=False)
|
||||
except (IOError, OSError) as e:
|
||||
pipe_once_process.terminate()
|
||||
raise OSError(
|
||||
"Failed writing to pipe_once subprocess: {}".format(e.strerror))
|
||||
written_to_pipe_once_process = True
|
||||
|
||||
|
||||
@export
|
||||
def unset_pipe_once_if_written():
|
||||
"""Unset the pipe_once cmd, if it has been written to."""
|
||||
global pipe_once_process, written_to_pipe_once_process
|
||||
if written_to_pipe_once_process:
|
||||
(stdout_data, stderr_data) = pipe_once_process.communicate()
|
||||
if len(stdout_data) > 0:
|
||||
print(stdout_data.rstrip(u"\n"))
|
||||
if len(stderr_data) > 0:
|
||||
print(stderr_data.rstrip(u"\n"))
|
||||
pipe_once_process = None
|
||||
written_to_pipe_once_process = False
|
||||
|
||||
|
||||
@special_command(
|
||||
'watch',
|
||||
'watch [seconds] [-c] query',
|
||||
'Executes the query every [seconds] seconds (by default 5).'
|
||||
)
|
||||
def watch_query(arg, **kwargs):
|
||||
usage = """Syntax: watch [seconds] [-c] query.
|
||||
* seconds: The interval at the query will be repeated, in seconds.
|
||||
By default 5.
|
||||
* -c: Clears the screen between every iteration.
|
||||
"""
|
||||
if not arg:
|
||||
yield (None, None, None, usage)
|
||||
return
|
||||
seconds = 5
|
||||
clear_screen = False
|
||||
statement = None
|
||||
while statement is None:
|
||||
arg = arg.strip()
|
||||
if not arg:
|
||||
# Oops, we parsed all the arguments without finding a statement
|
||||
yield (None, None, None, usage)
|
||||
return
|
||||
(current_arg, _, arg) = arg.partition(' ')
|
||||
try:
|
||||
seconds = float(current_arg)
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
if current_arg == '-c':
|
||||
clear_screen = True
|
||||
continue
|
||||
statement = '{0!s} {1!s}'.format(current_arg, arg)
|
||||
destructive_prompt = confirm_destructive_query(statement)
|
||||
if destructive_prompt is False:
|
||||
click.secho("Wise choice!")
|
||||
return
|
||||
elif destructive_prompt is True:
|
||||
click.secho("Your call!")
|
||||
cur = kwargs['cur']
|
||||
sql_list = [
|
||||
(sql.rstrip(';'), "> {0!s}".format(sql))
|
||||
for sql in sqlparse.split(statement)
|
||||
]
|
||||
old_pager_enabled = is_pager_enabled()
|
||||
while True:
|
||||
if clear_screen:
|
||||
click.clear()
|
||||
try:
|
||||
# Somewhere in the code the pager its activated after every yield,
|
||||
# so we disable it in every iteration
|
||||
set_pager_enabled(False)
|
||||
for (sql, title) in sql_list:
|
||||
cur.execute(sql)
|
||||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
yield (title, cur, headers, None)
|
||||
else:
|
||||
yield (title, None, None, None)
|
||||
sleep(seconds)
|
||||
except KeyboardInterrupt:
|
||||
# This prints the Ctrl-C character in its own line, which prevents
|
||||
# to print a line with the cursor positioned behind the prompt
|
||||
click.secho("", nl=True)
|
||||
return
|
||||
finally:
|
||||
set_pager_enabled(old_pager_enabled)
|
||||
|
||||
|
||||
@export
|
||||
@special_command('delimiter', None, 'Change SQL delimiter.')
|
||||
def set_delimiter(arg, **_):
|
||||
return delimiter_command.set(arg)
|
||||
|
||||
|
||||
@export
|
||||
def get_current_delimiter():
|
||||
return delimiter_command.current
|
||||
|
||||
|
||||
@export
|
||||
def split_queries(input):
|
||||
for query in delimiter_command.queries_iter(input):
|
||||
yield query
|
120
mycli/packages/special/main.py
Normal file
120
mycli/packages/special/main.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from . import export
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
NO_QUERY = 0
|
||||
PARSED_QUERY = 1
|
||||
RAW_QUERY = 2
|
||||
|
||||
SpecialCommand = namedtuple('SpecialCommand',
|
||||
['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
|
||||
'case_sensitive'])
|
||||
|
||||
COMMANDS = {}
|
||||
|
||||
@export
|
||||
class CommandNotFound(Exception):
|
||||
pass
|
||||
|
||||
@export
|
||||
def parse_special_command(sql):
|
||||
command, _, arg = sql.partition(' ')
|
||||
verbose = '+' in command
|
||||
command = command.strip().replace('+', '')
|
||||
return (command, verbose, arg.strip())
|
||||
|
||||
@export
|
||||
def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
|
||||
hidden=False, case_sensitive=False, aliases=()):
|
||||
def wrapper(wrapped):
|
||||
register_special_command(wrapped, command, shortcut, description,
|
||||
arg_type, hidden, case_sensitive, aliases)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
@export
|
||||
def register_special_command(handler, command, shortcut, description,
|
||||
arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
|
||||
cmd = command.lower() if not case_sensitive else command
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
|
||||
arg_type, hidden, case_sensitive)
|
||||
for alias in aliases:
|
||||
cmd = alias.lower() if not case_sensitive else alias
|
||||
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
|
||||
arg_type, case_sensitive=case_sensitive,
|
||||
hidden=True)
|
||||
|
||||
@export
|
||||
def execute(cur, sql):
|
||||
"""Execute a special command and return the results. If the special command
|
||||
is not supported a KeyError will be raised.
|
||||
"""
|
||||
command, verbose, arg = parse_special_command(sql)
|
||||
|
||||
if (command not in COMMANDS) and (command.lower() not in COMMANDS):
|
||||
raise CommandNotFound
|
||||
|
||||
try:
|
||||
special_cmd = COMMANDS[command]
|
||||
except KeyError:
|
||||
special_cmd = COMMANDS[command.lower()]
|
||||
if special_cmd.case_sensitive:
|
||||
raise CommandNotFound('Command not found: %s' % command)
|
||||
|
||||
# "help <SQL KEYWORD> is a special case. We want built-in help, not
|
||||
# mycli help here.
|
||||
if command == 'help' and arg:
|
||||
return show_keyword_help(cur=cur, arg=arg)
|
||||
|
||||
if special_cmd.arg_type == NO_QUERY:
|
||||
return special_cmd.handler()
|
||||
elif special_cmd.arg_type == PARSED_QUERY:
|
||||
return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
|
||||
elif special_cmd.arg_type == RAW_QUERY:
|
||||
return special_cmd.handler(cur=cur, query=sql)
|
||||
|
||||
@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
|
||||
def show_help(): # All the parameters are ignored.
|
||||
headers = ['Command', 'Shortcut', 'Description']
|
||||
result = []
|
||||
|
||||
for _, value in sorted(COMMANDS.items()):
|
||||
if not value.hidden:
|
||||
result.append((value.command, value.shortcut, value.description))
|
||||
return [(None, result, headers, None)]
|
||||
|
||||
def show_keyword_help(cur, arg):
|
||||
"""
|
||||
Call the built-in "show <command>", to display help for an SQL keyword.
|
||||
:param cur: cursor
|
||||
:param arg: string
|
||||
:return: list
|
||||
"""
|
||||
keyword = arg.strip('"').strip("'")
|
||||
query = "help '{0}'".format(keyword)
|
||||
log.debug(query)
|
||||
cur.execute(query)
|
||||
if cur.description and cur.rowcount > 0:
|
||||
headers = [x[0] for x in cur.description]
|
||||
return [(None, cur, headers, '')]
|
||||
else:
|
||||
return [(None, None, None, 'No help found for {0}.'.format(keyword))]
|
||||
|
||||
|
||||
@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
|
||||
@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
|
||||
def quit(*_args):
|
||||
raise EOFError
|
||||
|
||||
|
||||
@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
@special_command('\\G', '\\G', 'Display current query results vertically.',
|
||||
arg_type=NO_QUERY, case_sensitive=True)
|
||||
def stub():
|
||||
raise NotImplementedError
|
46
mycli/packages/special/utils.py
Normal file
46
mycli/packages/special/utils.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
def handle_cd_command(arg):
|
||||
"""Handles a `cd` shell command by calling python's os.chdir."""
|
||||
CD_CMD = 'cd'
|
||||
tokens = arg.split(CD_CMD + ' ')
|
||||
directory = tokens[-1] if len(tokens) > 1 else None
|
||||
if not directory:
|
||||
return False, "No folder name was provided."
|
||||
try:
|
||||
os.chdir(directory)
|
||||
subprocess.call(['pwd'])
|
||||
return True, None
|
||||
except OSError as e:
|
||||
return False, e.strerror
|
||||
|
||||
def format_uptime(uptime_in_seconds):
|
||||
"""Format number of seconds into human-readable string.
|
||||
|
||||
:param uptime_in_seconds: The server uptime in seconds.
|
||||
:returns: A human-readable string representing the uptime.
|
||||
|
||||
>>> uptime = format_uptime('56892')
|
||||
>>> print(uptime)
|
||||
15 hours 48 min 12 sec
|
||||
"""
|
||||
|
||||
m, s = divmod(int(uptime_in_seconds), 60)
|
||||
h, m = divmod(m, 60)
|
||||
d, h = divmod(h, 24)
|
||||
|
||||
uptime_values = []
|
||||
|
||||
for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')):
|
||||
if value == 0 and not uptime_values:
|
||||
# Don't include a value/unit if the unit isn't applicable to
|
||||
# the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
|
||||
continue
|
||||
elif value == 1 and unit.endswith('s'):
|
||||
# Remove the "s" if the unit is singular.
|
||||
unit = unit[:-1]
|
||||
uptime_values.append('{0} {1}'.format(value, unit))
|
||||
|
||||
uptime = ' '.join(uptime_values)
|
||||
return uptime
|
0
mycli/packages/tabular_output/__init__.py
Normal file
0
mycli/packages/tabular_output/__init__.py
Normal file
63
mycli/packages/tabular_output/sql_format.py
Normal file
63
mycli/packages/tabular_output/sql_format.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
"""Format adapter for sql."""
|
||||
|
||||
from cli_helpers.utils import filter_dict_by_key
|
||||
from mycli.packages.parseutils import extract_tables
|
||||
|
||||
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
|
||||
'sql-update-2', )
|
||||
|
||||
preprocessors = ()
|
||||
|
||||
|
||||
def escape_for_sql_statement(value):
|
||||
if isinstance(value, bytes):
|
||||
return f"X'{value.hex()}'"
|
||||
else:
|
||||
return formatter.mycli.sqlexecute.conn.escape(value)
|
||||
|
||||
|
||||
def adapter(data, headers, table_format=None, **kwargs):
|
||||
tables = extract_tables(formatter.query)
|
||||
if len(tables) > 0:
|
||||
table = tables[0]
|
||||
if table[0]:
|
||||
table_name = "{}.{}".format(*table[:2])
|
||||
else:
|
||||
table_name = table[1]
|
||||
else:
|
||||
table_name = "`DUAL`"
|
||||
if table_format == 'sql-insert':
|
||||
h = "`, `".join(headers)
|
||||
yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
|
||||
prefix = " "
|
||||
for d in data:
|
||||
values = ", ".join(escape_for_sql_statement(v)
|
||||
for i, v in enumerate(d))
|
||||
yield "{}({})".format(prefix, values)
|
||||
if prefix == " ":
|
||||
prefix = ", "
|
||||
yield ";"
|
||||
if table_format.startswith('sql-update'):
|
||||
s = table_format.split('-')
|
||||
keys = 1
|
||||
if len(s) > 2:
|
||||
keys = int(s[-1])
|
||||
for d in data:
|
||||
yield "UPDATE {} SET".format(table_name)
|
||||
prefix = " "
|
||||
for i, v in enumerate(d[keys:], keys):
|
||||
yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v))
|
||||
if prefix == " ":
|
||||
prefix = ", "
|
||||
f = "`{}` = {}"
|
||||
where = (f.format(headers[i], escape_for_sql_statement(
|
||||
d[i])) for i in range(keys))
|
||||
yield "WHERE {};".format(" AND ".join(where))
|
||||
|
||||
|
||||
def register_new_formatter(TabularOutputFormatter):
|
||||
global formatter
|
||||
formatter = TabularOutputFormatter
|
||||
for sql_format in supported_formats:
|
||||
TabularOutputFormatter.register_new_formatter(
|
||||
sql_format, adapter, preprocessors, {'table_format': sql_format})
|
435
mycli/sqlcompleter.py
Normal file
435
mycli/sqlcompleter.py
Normal file
|
@ -0,0 +1,435 @@
|
|||
import logging
|
||||
from re import compile, escape
|
||||
from collections import Counter
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
from .packages.completion_engine import suggest_type
|
||||
from .packages.parseutils import last_word
|
||||
from .packages.filepaths import parse_path, complete_path, suggest_path
|
||||
from .packages.special.favoritequeries import FavoriteQueries
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLCompleter(Completer):
|
||||
keywords = ['ACCESS', 'ADD', 'ALL', 'ALTER TABLE', 'AND', 'ANY', 'AS',
|
||||
'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN',
|
||||
'BIGINT', 'BINARY', 'BY', 'CASE', 'CHANGE MASTER TO', 'CHAR',
|
||||
'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT',
|
||||
'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT',
|
||||
'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT',
|
||||
'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP',
|
||||
'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT',
|
||||
'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION',
|
||||
'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN',
|
||||
'INCREMENT', 'INDEX', 'INSERT INTO', 'INT', 'INTEGER',
|
||||
'INTERVAL', 'INTO', 'IS', 'JOIN', 'KEY', 'LEFT', 'LEVEL',
|
||||
'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER',
|
||||
'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER',
|
||||
'OFFSET', 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER',
|
||||
'PASSWORD', 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST',
|
||||
'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET',
|
||||
'REVOKE', 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT',
|
||||
'SAVEPOINT', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SHOW',
|
||||
'SLAVE', 'SMALLINT', 'SMALLINT', 'START', 'STOP', 'TABLE',
|
||||
'THEN', 'TINYINT', 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE',
|
||||
'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', 'USE', 'USER',
|
||||
'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', 'WITH']
|
||||
|
||||
functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT',
|
||||
'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID',
|
||||
'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', 'UNIX_TIMESTAMP']
|
||||
|
||||
show_items = []
|
||||
|
||||
change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER',
|
||||
'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY',
|
||||
'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE',
|
||||
'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS',
|
||||
'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH',
|
||||
'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER',
|
||||
'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS']
|
||||
|
||||
users = []
|
||||
|
||||
def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'):
|
||||
super(self.__class__, self).__init__()
|
||||
self.smart_completion = smart_completion
|
||||
self.reserved_words = set()
|
||||
for x in self.keywords:
|
||||
self.reserved_words.update(x.split())
|
||||
self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$")
|
||||
|
||||
self.special_commands = []
|
||||
self.table_formats = supported_formats
|
||||
if keyword_casing not in ('upper', 'lower', 'auto'):
|
||||
keyword_casing = 'auto'
|
||||
self.keyword_casing = keyword_casing
|
||||
self.reset_completions()
|
||||
|
||||
def escape_name(self, name):
|
||||
if name and ((not self.name_pattern.match(name))
|
||||
or (name.upper() in self.reserved_words)
|
||||
or (name.upper() in self.functions)):
|
||||
name = '`%s`' % name
|
||||
|
||||
return name
|
||||
|
||||
def unescape_name(self, name):
|
||||
"""Unquote a string."""
|
||||
if name and name[0] == '"' and name[-1] == '"':
|
||||
name = name[1:-1]
|
||||
|
||||
return name
|
||||
|
||||
def escaped_names(self, names):
|
||||
return [self.escape_name(name) for name in names]
|
||||
|
||||
def extend_special_commands(self, special_commands):
|
||||
# Special commands are not part of all_completions since they can only
|
||||
# be at the beginning of a line.
|
||||
self.special_commands.extend(special_commands)
|
||||
|
||||
def extend_database_names(self, databases):
|
||||
self.databases.extend(databases)
|
||||
|
||||
def extend_keywords(self, additional_keywords):
|
||||
self.keywords.extend(additional_keywords)
|
||||
self.all_completions.update(additional_keywords)
|
||||
|
||||
def extend_show_items(self, show_items):
|
||||
for show_item in show_items:
|
||||
self.show_items.extend(show_item)
|
||||
self.all_completions.update(show_item)
|
||||
|
||||
def extend_change_items(self, change_items):
|
||||
for change_item in change_items:
|
||||
self.change_items.extend(change_item)
|
||||
self.all_completions.update(change_item)
|
||||
|
||||
def extend_users(self, users):
|
||||
for user in users:
|
||||
self.users.extend(user)
|
||||
self.all_completions.update(user)
|
||||
|
||||
def extend_schemata(self, schema):
|
||||
if schema is None:
|
||||
return
|
||||
metadata = self.dbmetadata['tables']
|
||||
metadata[schema] = {}
|
||||
|
||||
# dbmetadata.values() are the 'tables' and 'functions' dicts
|
||||
for metadata in self.dbmetadata.values():
|
||||
metadata[schema] = {}
|
||||
self.all_completions.update(schema)
|
||||
|
||||
def extend_relations(self, data, kind):
|
||||
"""Extend metadata for tables or views
|
||||
|
||||
:param data: list of (rel_name, ) tuples
|
||||
:param kind: either 'tables' or 'views'
|
||||
:return:
|
||||
"""
|
||||
# 'data' is a generator object. It can throw an exception while being
|
||||
# consumed. This could happen if the user has launched the app without
|
||||
# specifying a database name. This exception must be handled to prevent
|
||||
# crashing.
|
||||
try:
|
||||
data = [self.escaped_names(d) for d in data]
|
||||
except Exception:
|
||||
data = []
|
||||
|
||||
# dbmetadata['tables'][$schema_name][$table_name] should be a list of
|
||||
# column names. Default to an asterisk
|
||||
metadata = self.dbmetadata[kind]
|
||||
for relname in data:
|
||||
try:
|
||||
metadata[self.dbname][relname[0]] = ['*']
|
||||
except KeyError:
|
||||
_logger.error('%r %r listed in unrecognized schema %r',
|
||||
kind, relname[0], self.dbname)
|
||||
self.all_completions.add(relname[0])
|
||||
|
||||
def extend_columns(self, column_data, kind):
|
||||
"""Extend column metadata
|
||||
|
||||
:param column_data: list of (rel_name, column_name) tuples
|
||||
:param kind: either 'tables' or 'views'
|
||||
:return:
|
||||
"""
|
||||
# 'column_data' is a generator object. It can throw an exception while
|
||||
# being consumed. This could happen if the user has launched the app
|
||||
# without specifying a database name. This exception must be handled to
|
||||
# prevent crashing.
|
||||
try:
|
||||
column_data = [self.escaped_names(d) for d in column_data]
|
||||
except Exception:
|
||||
column_data = []
|
||||
|
||||
metadata = self.dbmetadata[kind]
|
||||
for relname, column in column_data:
|
||||
metadata[self.dbname][relname].append(column)
|
||||
self.all_completions.add(column)
|
||||
|
||||
def extend_functions(self, func_data):
|
||||
# 'func_data' is a generator object. It can throw an exception while
|
||||
# being consumed. This could happen if the user has launched the app
|
||||
# without specifying a database name. This exception must be handled to
|
||||
# prevent crashing.
|
||||
try:
|
||||
func_data = [self.escaped_names(d) for d in func_data]
|
||||
except Exception:
|
||||
func_data = []
|
||||
|
||||
# dbmetadata['functions'][$schema_name][$function_name] should return
|
||||
# function metadata.
|
||||
metadata = self.dbmetadata['functions']
|
||||
|
||||
for func in func_data:
|
||||
metadata[self.dbname][func[0]] = None
|
||||
self.all_completions.add(func[0])
|
||||
|
||||
def set_dbname(self, dbname):
|
||||
self.dbname = dbname
|
||||
|
||||
def reset_completions(self):
|
||||
self.databases = []
|
||||
self.users = []
|
||||
self.show_items = []
|
||||
self.dbname = ''
|
||||
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}}
|
||||
self.all_completions = set(self.keywords + self.functions)
|
||||
|
||||
@staticmethod
|
||||
def find_matches(text, collection, start_only=False, fuzzy=True, casing=None):
|
||||
"""Find completion matches for the given text.
|
||||
|
||||
Given the user's input text and a collection of available
|
||||
completions, find completions matching the last word of the
|
||||
text.
|
||||
|
||||
If `start_only` is True, the text will match an available
|
||||
completion only at the beginning. Otherwise, a completion is
|
||||
considered a match if the text appears anywhere within it.
|
||||
|
||||
yields prompt_toolkit Completion instances for any matches found
|
||||
in the collection of available completions.
|
||||
"""
|
||||
last = last_word(text, include='most_punctuations')
|
||||
text = last.lower()
|
||||
|
||||
completions = []
|
||||
|
||||
if fuzzy:
|
||||
regex = '.*?'.join(map(escape, text))
|
||||
pat = compile('(%s)' % regex)
|
||||
for item in sorted(collection):
|
||||
r = pat.search(item.lower())
|
||||
if r:
|
||||
completions.append((len(r.group()), r.start(), item))
|
||||
else:
|
||||
match_end_limit = len(text) if start_only else None
|
||||
for item in sorted(collection):
|
||||
match_point = item.lower().find(text, 0, match_end_limit)
|
||||
if match_point >= 0:
|
||||
completions.append((len(text), match_point, item))
|
||||
|
||||
if casing == 'auto':
|
||||
casing = 'lower' if last and last[-1].islower() else 'upper'
|
||||
|
||||
def apply_case(kw):
|
||||
if casing == 'upper':
|
||||
return kw.upper()
|
||||
return kw.lower()
|
||||
|
||||
return (Completion(z if casing is None else apply_case(z), -len(text))
|
||||
for x, y, z in sorted(completions))
|
||||
|
||||
def get_completions(self, document, complete_event, smart_completion=None):
|
||||
word_before_cursor = document.get_word_before_cursor(WORD=True)
|
||||
if smart_completion is None:
|
||||
smart_completion = self.smart_completion
|
||||
|
||||
# If smart_completion is off then match any word that starts with
|
||||
# 'word_before_cursor'.
|
||||
if not smart_completion:
|
||||
return self.find_matches(word_before_cursor, self.all_completions,
|
||||
start_only=True, fuzzy=False)
|
||||
|
||||
completions = []
|
||||
suggestions = suggest_type(document.text, document.text_before_cursor)
|
||||
|
||||
for suggestion in suggestions:
|
||||
|
||||
_logger.debug('Suggestion type: %r', suggestion['type'])
|
||||
|
||||
if suggestion['type'] == 'column':
|
||||
tables = suggestion['tables']
|
||||
_logger.debug("Completion column scope: %r", tables)
|
||||
scoped_cols = self.populate_scoped_cols(tables)
|
||||
if suggestion.get('drop_unique'):
|
||||
# drop_unique is used for 'tb11 JOIN tbl2 USING (...'
|
||||
# which should suggest only columns that appear in more than
|
||||
# one table
|
||||
scoped_cols = [
|
||||
col for (col, count) in Counter(scoped_cols).items()
|
||||
if count > 1 and col != '*'
|
||||
]
|
||||
|
||||
cols = self.find_matches(word_before_cursor, scoped_cols)
|
||||
completions.extend(cols)
|
||||
|
||||
elif suggestion['type'] == 'function':
|
||||
# suggest user-defined functions using substring matching
|
||||
funcs = self.populate_schema_objects(suggestion['schema'],
|
||||
'functions')
|
||||
user_funcs = self.find_matches(word_before_cursor, funcs)
|
||||
completions.extend(user_funcs)
|
||||
|
||||
# suggest hardcoded functions using startswith matching only if
|
||||
# there is no schema qualifier. If a schema qualifier is
|
||||
# present it probably denotes a table.
|
||||
# eg: SELECT * FROM users u WHERE u.
|
||||
if not suggestion['schema']:
|
||||
predefined_funcs = self.find_matches(word_before_cursor,
|
||||
self.functions,
|
||||
start_only=True,
|
||||
fuzzy=False,
|
||||
casing=self.keyword_casing)
|
||||
completions.extend(predefined_funcs)
|
||||
|
||||
elif suggestion['type'] == 'table':
|
||||
tables = self.populate_schema_objects(suggestion['schema'],
|
||||
'tables')
|
||||
tables = self.find_matches(word_before_cursor, tables)
|
||||
completions.extend(tables)
|
||||
|
||||
elif suggestion['type'] == 'view':
|
||||
views = self.populate_schema_objects(suggestion['schema'],
|
||||
'views')
|
||||
views = self.find_matches(word_before_cursor, views)
|
||||
completions.extend(views)
|
||||
|
||||
elif suggestion['type'] == 'alias':
|
||||
aliases = suggestion['aliases']
|
||||
aliases = self.find_matches(word_before_cursor, aliases)
|
||||
completions.extend(aliases)
|
||||
|
||||
elif suggestion['type'] == 'database':
|
||||
dbs = self.find_matches(word_before_cursor, self.databases)
|
||||
completions.extend(dbs)
|
||||
|
||||
elif suggestion['type'] == 'keyword':
|
||||
keywords = self.find_matches(word_before_cursor, self.keywords,
|
||||
start_only=True,
|
||||
fuzzy=False,
|
||||
casing=self.keyword_casing)
|
||||
completions.extend(keywords)
|
||||
|
||||
elif suggestion['type'] == 'show':
|
||||
show_items = self.find_matches(word_before_cursor,
|
||||
self.show_items,
|
||||
start_only=False,
|
||||
fuzzy=True,
|
||||
casing=self.keyword_casing)
|
||||
completions.extend(show_items)
|
||||
|
||||
elif suggestion['type'] == 'change':
|
||||
change_items = self.find_matches(word_before_cursor,
|
||||
self.change_items,
|
||||
start_only=False,
|
||||
fuzzy=True)
|
||||
completions.extend(change_items)
|
||||
elif suggestion['type'] == 'user':
|
||||
users = self.find_matches(word_before_cursor, self.users,
|
||||
start_only=False,
|
||||
fuzzy=True)
|
||||
completions.extend(users)
|
||||
|
||||
elif suggestion['type'] == 'special':
|
||||
special = self.find_matches(word_before_cursor,
|
||||
self.special_commands,
|
||||
start_only=True,
|
||||
fuzzy=False)
|
||||
completions.extend(special)
|
||||
elif suggestion['type'] == 'favoritequery':
|
||||
queries = self.find_matches(word_before_cursor,
|
||||
FavoriteQueries.instance.list(),
|
||||
start_only=False, fuzzy=True)
|
||||
completions.extend(queries)
|
||||
elif suggestion['type'] == 'table_format':
|
||||
formats = self.find_matches(word_before_cursor,
|
||||
self.table_formats,
|
||||
start_only=True, fuzzy=False)
|
||||
completions.extend(formats)
|
||||
elif suggestion['type'] == 'file_name':
|
||||
file_names = self.find_files(word_before_cursor)
|
||||
completions.extend(file_names)
|
||||
|
||||
return completions
|
||||
|
||||
def find_files(self, word):
|
||||
"""Yield matching directory or file names.
|
||||
|
||||
:param word:
|
||||
:return: iterable
|
||||
|
||||
"""
|
||||
base_path, last_path, position = parse_path(word)
|
||||
paths = suggest_path(word)
|
||||
for name in sorted(paths):
|
||||
suggestion = complete_path(name, last_path)
|
||||
if suggestion:
|
||||
yield Completion(suggestion, position)
|
||||
|
||||
def populate_scoped_cols(self, scoped_tbls):
|
||||
"""Find all columns in a set of scoped_tables
|
||||
:param scoped_tbls: list of (schema, table, alias) tuples
|
||||
:return: list of column names
|
||||
"""
|
||||
columns = []
|
||||
meta = self.dbmetadata
|
||||
|
||||
for tbl in scoped_tbls:
|
||||
# A fully qualified schema.relname reference or default_schema
|
||||
# DO NOT escape schema names.
|
||||
schema = tbl[0] or self.dbname
|
||||
relname = tbl[1]
|
||||
escaped_relname = self.escape_name(tbl[1])
|
||||
|
||||
# We don't know if schema.relname is a table or view. Since
|
||||
# tables and views cannot share the same name, we can check one
|
||||
# at a time
|
||||
try:
|
||||
columns.extend(meta['tables'][schema][relname])
|
||||
|
||||
# Table exists, so don't bother checking for a view
|
||||
continue
|
||||
except KeyError:
|
||||
try:
|
||||
columns.extend(meta['tables'][schema][escaped_relname])
|
||||
# Table exists, so don't bother checking for a view
|
||||
continue
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
columns.extend(meta['views'][schema][relname])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return columns
|
||||
|
||||
def populate_schema_objects(self, schema, obj_type):
|
||||
"""Returns list of tables or functions for a (optional) schema"""
|
||||
metadata = self.dbmetadata[obj_type]
|
||||
schema = schema or self.dbname
|
||||
|
||||
try:
|
||||
objects = metadata[schema].keys()
|
||||
except KeyError:
|
||||
# schema doesn't exist
|
||||
objects = []
|
||||
|
||||
return objects
|
322
mycli/sqlexecute.py
Normal file
322
mycli/sqlexecute.py
Normal file
|
@ -0,0 +1,322 @@
|
|||
import logging
|
||||
import pymysql
|
||||
import sqlparse
|
||||
from .packages import special
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.converters import (convert_datetime,
|
||||
convert_timedelta, convert_date, conversions,
|
||||
decoders)
|
||||
try:
|
||||
import paramiko
|
||||
except ImportError:
|
||||
from mycli.packages.paramiko_stub import paramiko
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
FIELD_TYPES = decoders.copy()
|
||||
FIELD_TYPES.update({
|
||||
FIELD_TYPE.NULL: type(None)
|
||||
})
|
||||
|
||||
class SQLExecute(object):
|
||||
|
||||
databases_query = '''SHOW DATABASES'''
|
||||
|
||||
tables_query = '''SHOW TABLES'''
|
||||
|
||||
version_query = '''SELECT @@VERSION'''
|
||||
|
||||
version_comment_query = '''SELECT @@VERSION_COMMENT'''
|
||||
version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
|
||||
|
||||
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
|
||||
|
||||
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
|
||||
|
||||
functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
|
||||
WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
|
||||
|
||||
table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
|
||||
where table_schema = '%s'
|
||||
order by table_name,ordinal_position'''
|
||||
|
||||
def __init__(self, database, user, password, host, port, socket, charset,
|
||||
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
|
||||
ssh_key_filename, init_command=None):
|
||||
self.dbname = database
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.socket = socket
|
||||
self.charset = charset
|
||||
self.local_infile = local_infile
|
||||
self.ssl = ssl
|
||||
self._server_type = None
|
||||
self.connection_id = None
|
||||
self.ssh_user = ssh_user
|
||||
self.ssh_host = ssh_host
|
||||
self.ssh_port = ssh_port
|
||||
self.ssh_password = ssh_password
|
||||
self.ssh_key_filename = ssh_key_filename
|
||||
self.init_command = init_command
|
||||
self.connect()
|
||||
|
||||
def connect(self, database=None, user=None, password=None, host=None,
|
||||
port=None, socket=None, charset=None, local_infile=None,
|
||||
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
|
||||
ssh_password=None, ssh_key_filename=None, init_command=None):
|
||||
db = (database or self.dbname)
|
||||
user = (user or self.user)
|
||||
password = (password or self.password)
|
||||
host = (host or self.host)
|
||||
port = (port or self.port)
|
||||
socket = (socket or self.socket)
|
||||
charset = (charset or self.charset)
|
||||
local_infile = (local_infile or self.local_infile)
|
||||
ssl = (ssl or self.ssl)
|
||||
ssh_user = (ssh_user or self.ssh_user)
|
||||
ssh_host = (ssh_host or self.ssh_host)
|
||||
ssh_port = (ssh_port or self.ssh_port)
|
||||
ssh_password = (ssh_password or self.ssh_password)
|
||||
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
|
||||
init_command = (init_command or self.init_command)
|
||||
_logger.debug(
|
||||
'Connection DB Params: \n'
|
||||
'\tdatabase: %r'
|
||||
'\tuser: %r'
|
||||
'\thost: %r'
|
||||
'\tport: %r'
|
||||
'\tsocket: %r'
|
||||
'\tcharset: %r'
|
||||
'\tlocal_infile: %r'
|
||||
'\tssl: %r'
|
||||
'\tssh_user: %r'
|
||||
'\tssh_host: %r'
|
||||
'\tssh_port: %r'
|
||||
'\tssh_password: %r'
|
||||
'\tssh_key_filename: %r'
|
||||
'\tinit_command: %r',
|
||||
db, user, host, port, socket, charset, local_infile, ssl,
|
||||
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
|
||||
init_command
|
||||
)
|
||||
conv = conversions.copy()
|
||||
conv.update({
|
||||
FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
|
||||
FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
|
||||
FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
|
||||
})
|
||||
|
||||
defer_connect = False
|
||||
|
||||
if ssh_host:
|
||||
defer_connect = True
|
||||
|
||||
client_flag = pymysql.constants.CLIENT.INTERACTIVE
|
||||
if init_command and len(list(special.split_queries(init_command))) > 1:
|
||||
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
|
||||
|
||||
conn = pymysql.connect(
|
||||
database=db, user=user, password=password, host=host, port=port,
|
||||
unix_socket=socket, use_unicode=True, charset=charset,
|
||||
autocommit=True, client_flag=client_flag,
|
||||
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
|
||||
defer_connect=defer_connect, init_command=init_command
|
||||
)
|
||||
|
||||
if ssh_host:
|
||||
client = paramiko.SSHClient()
|
||||
client.load_system_host_keys()
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
client.connect(
|
||||
ssh_host, ssh_port, ssh_user, ssh_password,
|
||||
key_filename=ssh_key_filename
|
||||
)
|
||||
chan = client.get_transport().open_channel(
|
||||
'direct-tcpip',
|
||||
(host, port),
|
||||
('0.0.0.0', 0),
|
||||
)
|
||||
conn.connect(chan)
|
||||
|
||||
if hasattr(self, 'conn'):
|
||||
self.conn.close()
|
||||
self.conn = conn
|
||||
# Update them after the connection is made to ensure that it was a
|
||||
# successful connection.
|
||||
self.dbname = db
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.socket = socket
|
||||
self.charset = charset
|
||||
self.ssl = ssl
|
||||
self.init_command = init_command
|
||||
# retrieve connection id
|
||||
self.reset_connection_id()
|
||||
|
||||
def run(self, statement):
|
||||
"""Execute the sql in the database and return the results. The results
|
||||
are a list of tuples. Each tuple has 4 values
|
||||
(title, rows, headers, status).
|
||||
"""
|
||||
|
||||
# Remove spaces and EOL
|
||||
statement = statement.strip()
|
||||
if not statement: # Empty string
|
||||
yield (None, None, None, None)
|
||||
|
||||
# Split the sql into separate queries and run each one.
|
||||
# Unless it's saving a favorite query, in which case we
|
||||
# want to save them all together.
|
||||
if statement.startswith('\\fs'):
|
||||
components = [statement]
|
||||
else:
|
||||
components = special.split_queries(statement)
|
||||
|
||||
for sql in components:
|
||||
# \G is treated specially since we have to set the expanded output.
|
||||
if sql.endswith('\\G'):
|
||||
special.set_expanded_output(True)
|
||||
sql = sql[:-2].strip()
|
||||
|
||||
cur = self.conn.cursor()
|
||||
try: # Special command
|
||||
_logger.debug('Trying a dbspecial command. sql: %r', sql)
|
||||
for result in special.execute(cur, sql):
|
||||
yield result
|
||||
except special.CommandNotFound: # Regular SQL
|
||||
_logger.debug('Regular sql statement. sql: %r', sql)
|
||||
cur.execute(sql)
|
||||
while True:
|
||||
yield self.get_result(cur)
|
||||
|
||||
# PyMySQL returns an extra, empty result set with stored
|
||||
# procedures. We skip it (rowcount is zero and no
|
||||
# description).
|
||||
if not cur.nextset() or (not cur.rowcount and cur.description is None):
|
||||
break
|
||||
|
||||
def get_result(self, cursor):
|
||||
"""Get the current result's data from the cursor."""
|
||||
title = headers = None
|
||||
|
||||
# cursor.description is not None for queries that return result sets,
|
||||
# e.g. SELECT or SHOW.
|
||||
if cursor.description is not None:
|
||||
headers = [x[0] for x in cursor.description]
|
||||
status = '{0} row{1} in set'
|
||||
else:
|
||||
_logger.debug('No rows in result.')
|
||||
status = 'Query OK, {0} row{1} affected'
|
||||
status = status.format(cursor.rowcount,
|
||||
'' if cursor.rowcount == 1 else 's')
|
||||
|
||||
return (title, cursor if cursor.description else None, headers, status)
|
||||
|
||||
def tables(self):
|
||||
"""Yields table names"""
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Tables Query. sql: %r', self.tables_query)
|
||||
cur.execute(self.tables_query)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def table_columns(self):
|
||||
"""Yields (table name, column name) pairs"""
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Columns Query. sql: %r', self.table_columns_query)
|
||||
cur.execute(self.table_columns_query % self.dbname)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def databases(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Databases Query. sql: %r', self.databases_query)
|
||||
cur.execute(self.databases_query)
|
||||
return [x[0] for x in cur.fetchall()]
|
||||
|
||||
def functions(self):
|
||||
"""Yields tuples of (schema_name, function_name)"""
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Functions Query. sql: %r', self.functions_query)
|
||||
cur.execute(self.functions_query % self.dbname)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def show_candidates(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Show Query. sql: %r', self.show_candidates_query)
|
||||
try:
|
||||
cur.execute(self.show_candidates_query)
|
||||
except pymysql.DatabaseError as e:
|
||||
_logger.error('No show completions due to %r', e)
|
||||
yield ''
|
||||
else:
|
||||
for row in cur:
|
||||
yield (row[0].split(None, 1)[-1], )
|
||||
|
||||
def users(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Users Query. sql: %r', self.users_query)
|
||||
try:
|
||||
cur.execute(self.users_query)
|
||||
except pymysql.DatabaseError as e:
|
||||
_logger.error('No user completions due to %r', e)
|
||||
yield ''
|
||||
else:
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def server_type(self):
|
||||
if self._server_type:
|
||||
return self._server_type
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug('Version Query. sql: %r', self.version_query)
|
||||
cur.execute(self.version_query)
|
||||
version = cur.fetchone()[0]
|
||||
if version[0] == '4':
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query_mysql4)
|
||||
cur.execute(self.version_comment_query_mysql4)
|
||||
version_comment = cur.fetchone()[1].lower()
|
||||
if isinstance(version_comment, bytes):
|
||||
# with python3 this query returns bytes
|
||||
version_comment = version_comment.decode('utf-8')
|
||||
else:
|
||||
_logger.debug('Version Comment. sql: %r',
|
||||
self.version_comment_query)
|
||||
cur.execute(self.version_comment_query)
|
||||
version_comment = cur.fetchone()[0].lower()
|
||||
|
||||
if 'mariadb' in version_comment:
|
||||
product_type = 'mariadb'
|
||||
elif 'percona' in version_comment:
|
||||
product_type = 'percona'
|
||||
else:
|
||||
product_type = 'mysql'
|
||||
|
||||
self._server_type = (product_type, version)
|
||||
return self._server_type
|
||||
|
||||
def get_connection_id(self):
|
||||
if not self.connection_id:
|
||||
self.reset_connection_id()
|
||||
return self.connection_id
|
||||
|
||||
def reset_connection_id(self):
|
||||
# Remember current connection id
|
||||
_logger.debug('Get current connection id')
|
||||
res = self.run('select connection_id()')
|
||||
for title, cur, headers, status in res:
|
||||
self.connection_id = cur.fetchone()[0]
|
||||
_logger.debug('Current connection id: %s', self.connection_id)
|
||||
|
||||
def change_db(self, db):
|
||||
self.conn.select_db(db)
|
||||
self.dbname = db
|
Loading…
Add table
Add a link
Reference in a new issue