540 lines
15 KiB
Python
540 lines
15 KiB
Python
import os
|
|
import re
|
|
import locale
|
|
import logging
|
|
import subprocess
|
|
import shlex
|
|
from io import open
|
|
from time import sleep
|
|
|
|
import click
|
|
import pyperclip
|
|
import sqlparse
|
|
|
|
from . import export
|
|
from .main import special_command, NO_QUERY, PARSED_QUERY
|
|
from .favoritequeries import FavoriteQueries
|
|
from .delimitercommand import DelimiterCommand
|
|
from .utils import handle_cd_command
|
|
from mycli.packages.prompt_utils import confirm_destructive_query
|
|
|
|
# Module-level state shared by the special commands in this file; the setters
# and predicates below read and rebind these names via ``global``.

# Display toggles.
TIMING_ENABLED = False
use_expanded_output = False
PAGER_ENABLED = True

# Output redirection targets: the ``tee`` file, the one-shot ``\once`` file and
# the one-shot ``\|`` subprocess.  The two flags record whether the one-shot
# targets have received output yet.
tee_file = None
once_file = None
written_to_once_file = False
pipe_once_process = None
written_to_pipe_once_process = False

# Holds the active SQL statement delimiter.
delimiter_command = DelimiterCommand()
|
|
|
|
|
|
@export
def set_timing_enabled(val):
    """Turn command timing on or off according to ``val``."""
    global TIMING_ENABLED
    TIMING_ENABLED = val
|
|
|
|
|
|
@export
def set_pager_enabled(val):
    """Record whether query results should be sent through the pager."""
    global PAGER_ENABLED
    PAGER_ENABLED = val
|
|
|
|
|
|
@export
def is_pager_enabled():
    """Report whether query results are currently sent through the pager."""
    return PAGER_ENABLED
|
|
|
|
|
|
@export
@special_command(
    "pager",
    "\\P [command]",
    "Set PAGER. Print the query results via PAGER.",
    arg_type=PARSED_QUERY,
    aliases=("\\P",),
    case_sensitive=True,
)
def set_pager(arg, **_):
    """Enable paging, optionally choosing the pager command.

    When ``arg`` is given it is stored in the PAGER environment variable;
    otherwise the existing PAGER value (if any) is reported.  Paging is turned
    on in every case.
    """
    if arg:
        os.environ["PAGER"] = arg
        pager = arg
    else:
        pager = os.environ.get("PAGER")

    if pager is None:
        # No PAGER configured: click's echo_via_pager falls back to its default.
        msg = "Pager enabled."
    else:
        msg = "PAGER set to %s." % pager

    set_pager_enabled(True)
    return [(None, None, None, msg)]
|
|
|
|
|
|
@export
@special_command("nopager", "\\n", "Disable pager, print to stdout.", arg_type=NO_QUERY, aliases=("\\n",), case_sensitive=True)
def disable_pager():
    """Switch paging off so results are printed straight to stdout."""
    set_pager_enabled(False)
    status = "Pager disabled."
    return [(None, None, None, status)]
|
|
|
|
|
|
@special_command("\\timing", "\\t", "Toggle timing of commands.", arg_type=NO_QUERY, aliases=("\\t",), case_sensitive=True)
def toggle_timing():
    """Flip the timing flag and report its new state."""
    global TIMING_ENABLED
    TIMING_ENABLED = not TIMING_ENABLED
    state = "on." if TIMING_ENABLED else "off."
    return [(None, None, None, "Timing is " + state)]
|
|
|
|
|
|
@export
def is_timing_enabled():
    """Report whether command timing is currently switched on."""
    return TIMING_ENABLED
|
|
|
|
|
|
@export
def set_expanded_output(val):
    """Store whether results should use the expanded output layout."""
    global use_expanded_output
    use_expanded_output = val
|
|
|
|
|
|
@export
def is_expanded_output():
    """Report whether the expanded output layout is selected."""
    return use_expanded_output
|
|
|
|
|
|
# Logger named after this module.
# NOTE(review): not referenced in the visible portion of this module.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
@export
def editor_command(command):
    """Return True when ``command`` invokes the external editor.

    The ``\\e`` marker may appear as a prefix (``\\e filename``) or as a
    suffix (``SELECT * FROM \\e``), so both ends of the stripped text are
    checked.

    :param command: string
    """
    stripped = command.strip()
    return stripped.startswith("\\e") or stripped.endswith("\\e")
|
|
|
|
|
|
@export
def get_filename(sql):
    """Return the file name passed to a ``\\e`` command, if any.

    ``\\e filename`` yields ``"filename"``.  A bare ``\\e``, or text that does
    not start with ``\\e``, yields ``None``.  Surrounding whitespace is
    ignored, so leading blanks before ``\\e`` no longer leak the ``\\e``
    marker itself into the returned name.

    :param sql: the raw command text.
    :return: the file name, or ``None``.
    """
    stripped = sql.strip()
    if not stripped.startswith("\\e"):
        return None
    # Partition the *stripped* text: partitioning the raw text would split on
    # a leading blank and return "\e filename" as the file name.
    _, _, filename = stripped.partition(" ")
    return filename.strip() or None
|
|
|
|
|
|
@export
def get_editor_query(sql):
    """Get the query part of an editor command.

    Every leading and trailing ``\\e`` marker is removed from the
    whitespace-stripped input.  A regex is used instead of ``str.strip``
    because ``strip`` removes a *set of characters*: it would also eat a final
    "e" that belongs to the query ("select * from style\\e" would turn into
    "select * from styl").
    """
    marker = re.compile(r"(^\\e|\\e$)")
    query = sql.strip()
    while marker.search(query) is not None:
        query = marker.sub("", query)
    return query
|
|
|
|
|
|
@export
def open_external_editor(filename=None, sql=None):
    """Open an external editor and return the query the user typed.

    The editor buffer is pre-filled with ``sql`` followed by a marker comment.
    Everything from the marker onward is dropped from the result.  When
    ``filename`` is given, only its first space-separated word is used; that
    file is edited directly, and its contents are read back afterwards.

    :return: a ``(query, message)`` tuple.  ``message`` is ``None`` unless the
             file could not be read.  When nothing comes back from the
             editor/file, ``query`` falls back to ``sql`` (or ``""``), so the
             caller never receives ``None``.
    """
    marker = "# Type your query above this line.\n"
    target = filename.strip().split(" ", 1)[0] if filename else None
    initial = sql or ""
    error = None

    edited = click.edit(
        "{sql}\n\n{marker}".format(sql=initial, marker=marker),
        filename=target,
        extension=".sql",
    )

    if target:
        try:
            with open(target) as handle:
                edited = handle.read()
        except IOError:
            error = "Error reading file: %s." % target

    if edited is None:
        # Hand back the original text rather than None; an empty string is fine.
        return (initial, error)
    return (edited.split(marker, 1)[0].rstrip("\n"), error)
|
|
|
|
|
|
@export
def clip_command(command):
    """Return True when ``command`` asks to copy the query to the clipboard.

    ``\\clip`` may appear as a prefix or as a suffix (``SELECT * FROM \\clip``),
    so both ends of the stripped text are checked.

    :param command: string
    """
    stripped = command.strip()
    return stripped.startswith("\\clip") or stripped.endswith("\\clip")
|
|
|
|
|
|
@export
def get_clip_query(sql):
    """Get the query part of a clip command.

    Every leading and trailing ``\\clip`` marker is removed from the
    whitespace-stripped input.  ``str.strip("\\clip")`` cannot be used,
    because it strips any of the characters ``\\ c l i p`` and would damage
    the end of the query.
    """
    marker = re.compile(r"(^\\clip|\\clip$)")
    query = sql.strip()
    while marker.search(query) is not None:
        query = marker.sub("", query)
    return query
|
|
|
|
|
|
@export
def copy_query_to_clipboard(sql=None):
    """Send query to the clipboard.

    :param sql: the query text; ``None`` is treated as an empty string.
    :return: ``None`` on success, otherwise an error message for the user.
    """
    sql = sql or ""
    try:
        pyperclip.copy("{sql}".format(sql=sql))
    except RuntimeError as e:
        # pyperclip's PyperclipException derives from RuntimeError.
        # RuntimeError has no ``strerror`` attribute (that is an OSError
        # attribute), so reading it raised AttributeError and hid the real
        # problem.  Report the exception text instead.
        return "Error clipping query: %s." % e
    return None
|
|
|
|
|
|
@special_command("\\f", "\\f [name [args..]]", "List or execute favorite queries.", arg_type=PARSED_QUERY, case_sensitive=True)
def execute_favorite_query(cur, arg, **_):
    """List favorite queries, or run the one named in ``arg``.

    With an empty ``arg``, every saved favorite is listed.  Otherwise ``arg``
    is ``name [args..]``:

    - the favorite is looked up;
    - ``$1..$N`` are replaced with the shell-style-split arguments;
    - each resulting SQL statement is executed on ``cur``.

    Yields (title, rows, headers, status) tuples.
    """
    if arg == "":
        yield from list_favorite_queries()
        # Listing is the whole job.  Falling through would also look up an
        # empty favorite name and report "No favorite query: ".
        return

    # Split the favorite's name from its optional substitution parameters.
    name, _, arg_str = arg.partition(" ")
    try:
        args = shlex.split(arg_str)
    except ValueError as e:
        # For example, an unbalanced quote in the arguments.
        yield (None, None, None, "Error parsing arguments for favorite query %s: %s" % (name, e))
        return

    query = FavoriteQueries.instance.get(name)
    if query is None:
        message = "No favorite query: %s" % (name)
        yield (None, None, None, message)
        return

    query, arg_error = subst_favorite_query_args(query, args)
    if arg_error:
        yield (None, None, None, arg_error)
        return

    for sql in sqlparse.split(query):
        sql = sql.rstrip(";")
        title = "> %s" % (sql)
        cur.execute(sql)
        if cur.description:
            headers = [x[0] for x in cur.description]
            yield (title, cur, headers, None)
        else:
            yield (title, None, None, None)
|
|
|
|
|
|
def list_favorite_queries():
    """Build a listing of every saved favorite query.

    Returns a one-element list holding a (title, rows, headers, status) tuple.
    When nothing is saved, the status carries the usage help instead of an
    empty string."""
    favorites = FavoriteQueries.instance
    rows = [(name, favorites.get(name)) for name in favorites.list()]
    status = "" if rows else "\nNo favorite queries found." + favorites.usage
    return [("", rows, ["Name", "Query"], status)]
|
|
|
|
|
|
def subst_favorite_query_args(query, args):
    """Replace positional parameters ($1...$N) in query.

    Each ``$<k>`` placeholder whose digits are exactly ``k`` (for
    ``1 <= k <= len(args)``) is replaced by ``args[k - 1]`` in a single pass.
    As a result:

    - ``$1`` no longer clobbers the prefix of ``$10``;
    - argument values that themselves contain ``$<digits>`` are inserted
      verbatim instead of being substituted again.

    :param query: the favorite query text.
    :param args: list of replacement strings.
    :return: ``[substituted_query, None]`` on success, or ``[None, message]``
             when an argument has no matching placeholder, or a placeholder
             has no matching argument.
    """
    placeholder = re.compile(r"\$(\d+)")
    replacements = {str(position): value for position, value in enumerate(args, start=1)}
    present = set(placeholder.findall(query))

    # Every supplied argument must be used by the query.
    for position in range(1, len(args) + 1):
        subst_var = "$" + str(position)
        if str(position) not in present:
            return [None, "query does not have substitution parameter " + subst_var + ":\n " + query]

    # Placeholders in the original query that no argument fills.
    unmatched = [m.group(0) for m in placeholder.finditer(query) if m.group(1) not in replacements]

    # A callable replacement inserts each value literally (re.sub does not
    # process backslashes or group references in it).
    substituted = placeholder.sub(lambda m: replacements.get(m.group(1), m.group(0)), query)

    if unmatched:
        return [None, "missing substitution for " + unmatched[0] + " in query:\n " + substituted]
    return [substituted, None]
|
|
|
|
|
|
@special_command("\\fs", "\\fs name query", "Save a favorite query.")
def save_favorite_query(arg, **_):
    """Store a query under a name, both taken from ``arg`` ("name query").

    Returns (title, rows, headers, status)"""
    favorites = FavoriteQueries.instance
    usage = "Syntax: \\fs name query.\n\n" + favorites.usage
    if not arg:
        return [(None, None, None, usage)]

    name, _, query = arg.partition(" ")
    # Both parts are mandatory; otherwise show the usage plus an error line.
    if not (name and query):
        return [(None, None, None, usage + "Err: Both name and query are required.")]

    favorites.save(name, query)
    return [(None, None, None, "Saved.")]
|
|
|
|
|
|
@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
def delete_favorite_query(arg, **_):
    """Remove the favorite named ``arg`` and report the store's response."""
    favorites = FavoriteQueries.instance
    if arg:
        return [(None, None, None, favorites.delete(arg))]
    return [(None, None, None, "Syntax: \\fd name.\n\n" + favorites.usage)]
|
|
|
|
|
|
@special_command("system", "system [command]", "Execute a system shell commmand.")
def execute_system_command(arg, **_):
    """Execute a system shell command.

    ``cd <dir>`` is delegated to ``handle_cd_command``.  Anything else is run
    as a subprocess without a shell, and its stdout is returned as the status
    text.  If the command wrote anything to stderr, stderr is returned instead.

    :param arg: the command line.
    :return: a one-element list holding a (title, rows, headers, status) tuple.
    """
    usage = "Syntax: system [command].\n"

    if not arg:
        return [(None, None, None, usage)]

    try:
        command = arg.strip()
        # Match the whole first word: startswith("cd") also captured commands
        # such as "cdrecord" and routed them to the cd handler.
        if command.split(None, 1)[:1] == ["cd"]:
            ok, error_message = handle_cd_command(arg)
            if not ok:
                return [(None, None, None, error_message)]
            return [(None, None, None, "")]

        # shlex keeps quoted arguments together and ignores repeated blanks;
        # arg.split(" ") turned those blanks into empty arguments.  POSIX
        # escaping is disabled on Windows so backslashes in paths survive.
        try:
            argv = shlex.split(command, posix=(os.name != "nt"))
        except ValueError as e:
            return [(None, None, None, "Error parsing command: %s" % e)]
        if not argv:
            return [(None, None, None, usage)]

        process = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = process.communicate()
        response = output if not error else error

        # Python 3 returns bytes.  Decode with the locale's encoding, replacing
        # undecodable bytes rather than failing on them.
        if isinstance(response, bytes):
            encoding = locale.getpreferredencoding(False)
            response = response.decode(encoding, errors="replace")

        return [(None, None, None, response)]
    except OSError as e:
        return [(None, None, None, "OSError: %s" % e.strerror)]
|
|
|
|
|
|
def parseargfile(arg):
    """Parse ``[-o] filename`` into keyword arguments for ``open``.

    A leading ``-o`` selects overwrite mode ("w"); otherwise the file is
    opened for appending ("a").  Blanks around the file name are ignored, and
    ``~`` is expanded.

    :param arg: the raw argument string.
    :return: ``{"file": <path>, "mode": "w" | "a"}``.
    :raises TypeError: if no file name is given (including a bare ``-o``).
    """
    if arg == "-o" or arg.startswith("-o "):
        mode = "w"
        filename = arg[3:]
    else:
        mode = "a"
        filename = arg

    # Tolerate extra blanks, e.g. "-o   out.txt", instead of opening " out.txt".
    filename = filename.strip()

    if not filename:
        raise TypeError("You must provide a filename.")

    return {"file": os.path.expanduser(filename), "mode": mode}
|
|
|
|
|
|
@special_command("tee", "tee [-o] filename", "Append all results to an output file (overwrite using -o).")
def set_tee(arg, **_):
    """Start copying all results to the file described by ``arg``.

    ``arg`` is ``[-o] filename`` (parsed by ``parseargfile``).  Any tee file
    that is already open is closed once the new one has been opened
    successfully, so running ``tee`` again does not leak the previous handle.

    :raises OSError: if the file cannot be opened.
    """
    global tee_file

    try:
        new_file = open(**parseargfile(arg))
    except (IOError, OSError) as e:
        raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) from e

    # Only drop the old target after the new one is known to be usable.
    if tee_file:
        tee_file.close()
    tee_file = new_file

    return [(None, None, None, "")]
|
|
|
|
|
|
@export
def close_tee():
    """Close the active tee file, if there is one, and forget it."""
    global tee_file
    if not tee_file:
        return
    tee_file.close()
    tee_file = None
|
|
|
|
|
|
@special_command("notee", "notee", "Stop writing results to an output file.")
def no_tee(arg, **_):
    """Stop teeing results; ``arg`` is accepted but ignored."""
    close_tee()
    status = ""
    return [(None, None, None, status)]
|
|
|
|
|
|
@export
def write_tee(output):
    """Append ``output`` plus a newline to the tee file, when one is open."""
    if not tee_file:
        return
    for chunk in (output, "\n"):
        click.echo(chunk, file=tee_file, nl=False)
    tee_file.flush()
|
|
|
|
|
|
@special_command("\\once", "\\o [-o] filename", "Append next result to an output file (overwrite using -o).", aliases=("\\o",))
def set_once(arg, **_):
    """Send only the next result to the file described by ``arg``.

    ``arg`` is ``[-o] filename`` (parsed by ``parseargfile``).  A previously
    set ``\\once`` file is closed once the new one has been opened, so
    replacing an unused target does not leak its handle.

    :raises OSError: if the file cannot be opened.
    """
    global once_file, written_to_once_file

    try:
        new_file = open(**parseargfile(arg))
    except (IOError, OSError) as e:
        raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) from e

    if once_file:
        once_file.close()
    once_file = new_file
    written_to_once_file = False

    return [(None, None, None, "")]
|
|
|
|
|
|
@export
def write_once(output):
    """Write ``output`` plus a newline to the pending ``\\once`` file.

    Does nothing unless there is output and a ``\\once`` file is pending.
    After a write, the file is marked as used.
    """
    global written_to_once_file
    if not (output and once_file):
        return
    click.echo(output, file=once_file, nl=False)
    click.echo("\n", file=once_file, nl=False)
    once_file.flush()
    written_to_once_file = True
|
|
|
|
|
|
@export
def unset_once_if_written():
    """Close and drop the ``\\once`` file once it has received output."""
    global once_file
    if not (written_to_once_file and once_file):
        return
    once_file.close()
    once_file = None
|
|
|
|
|
|
@special_command("\\pipe_once", "\\| command", "Send next result to a subprocess.", aliases=("\\|",))
def set_pipe_once(arg, **_):
    """Start ``arg`` as a subprocess that will receive the next result.

    ``arg`` is split shell-style.  The process gets piped stdin, stdout and
    stderr in UTF-8 line-buffered text mode.  A previously started ``\\|``
    process that is still pending is terminated and reaped first, so it is
    not leaked.

    :raises OSError: if ``arg`` contains no command.
    """
    global pipe_once_process, written_to_pipe_once_process

    pipe_once_cmd = shlex.split(arg)
    if len(pipe_once_cmd) == 0:
        raise OSError("pipe_once requires a command")

    if pipe_once_process is not None:
        # The earlier process was never consumed by unset_pipe_once_if_written.
        # Stop it, then let communicate() close its pipes and wait for it.
        stale = pipe_once_process
        stale.terminate()
        stale.communicate()

    written_to_pipe_once_process = False
    pipe_once_process = subprocess.Popen(
        pipe_once_cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        bufsize=1,
        encoding="UTF-8",
        universal_newlines=True,
    )
    return [(None, None, None, "")]
|
|
|
|
|
|
@export
def write_pipe_once(output):
    """Feed ``output`` plus a newline to the pending ``\\|`` subprocess.

    Does nothing unless there is output and a subprocess is pending.

    :raises OSError: if writing to the subprocess fails.  The subprocess is
        terminated and detached, so later results are not sent to the dead
        pipe again.
    """
    global pipe_once_process, written_to_pipe_once_process
    if output and pipe_once_process:
        try:
            click.echo(output, file=pipe_once_process.stdin, nl=False)
            click.echo("\n", file=pipe_once_process.stdin, nl=False)
        except (IOError, OSError) as e:
            pipe_once_process.terminate()
            # Detach the failed process; otherwise every following result
            # would be written to the broken pipe and fail again.
            pipe_once_process = None
            raise OSError("Failed writing to pipe_once subprocess: {}".format(e.strerror)) from e
        written_to_pipe_once_process = True
|
|
|
|
|
|
@export
def unset_pipe_once_if_written():
    """Collect and print the ``\\|`` subprocess output once it has been fed.

    Waits for the process, prints its stdout and then its stderr (each without
    trailing newlines, skipped when empty), and resets the pipe_once state.
    """
    global pipe_once_process, written_to_pipe_once_process
    if not written_to_pipe_once_process:
        return
    stdout_data, stderr_data = pipe_once_process.communicate()
    for data in (stdout_data, stderr_data):
        if data:
            print(data.rstrip("\n"))
    pipe_once_process = None
    written_to_pipe_once_process = False
|
|
|
|
|
|
@special_command("watch", "watch [seconds] [-c] query", "Executes the query every [seconds] seconds (by default 5).")
def watch_query(arg, **kwargs):
    """Repeatedly run a query, yielding its results on every pass.

    ``arg`` is ``[seconds] [-c] query``.  Leading tokens are consumed in any
    order:

    - a token that parses as a float sets the interval (default 5);
    - ``-c`` clears the screen before each pass.

    The first other token starts the statement.  sqlparse splits the statement
    into individual SQL statements, each executed on ``kwargs["cur"]`` and
    yielded as (title, rows, headers, status).  Passes repeat until Ctrl-C.
    The pager is disabled while watching and the previous setting is restored
    afterwards.
    """
    usage = """Syntax: watch [seconds] [-c] query.
* seconds: The interval at the query will be repeated, in seconds.
By default 5.
* -c: Clears the screen between every iteration.
"""
    if not arg:
        yield (None, None, None, usage)
        return
    seconds = 5
    clear_screen = False
    statement = None
    # Consume option tokens until the first token that is neither a number
    # nor "-c"; that token and everything after it form the statement.
    while statement is None:
        arg = arg.strip()
        if not arg:
            # Every token was consumed as an option; no statement was given.
            yield (None, None, None, usage)
            return
        (current_arg, _, arg) = arg.partition(" ")
        try:
            # A numeric token sets the interval.
            # NOTE(review): float() also accepts "-1", "nan" and "inf", which
            # sleep() below would reject; TODO confirm the intended handling.
            seconds = float(current_arg)
            continue
        except ValueError:
            pass
        if current_arg == "-c":
            clear_screen = True
            continue
        statement = "{0!s} {1!s}".format(current_arg, arg)
    # False: the user declined a destructive query.  True: the user confirmed.
    # Any other value (presumably None when the query is not destructive;
    # TODO confirm) proceeds without a message.
    destructive_prompt = confirm_destructive_query(statement)
    if destructive_prompt is False:
        click.secho("Wise choice!")
        return
    elif destructive_prompt is True:
        click.secho("Your call!")
    cur = kwargs["cur"]
    # Pair each statement (trailing ";" removed) with its display title.
    sql_list = [(sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)]
    old_pager_enabled = is_pager_enabled()
    while True:
        if clear_screen:
            click.clear()
        try:
            # The pager is re-enabled somewhere after every yield, so it is
            # disabled again at the start of each pass.
            set_pager_enabled(False)
            for sql, title in sql_list:
                cur.execute(sql)
                if cur.description:
                    headers = [x[0] for x in cur.description]
                    yield (title, cur, headers, None)
                else:
                    yield (title, None, None, None)
            sleep(seconds)
        except KeyboardInterrupt:
            # Emit a newline so the "^C" ends up on its own line and the next
            # prompt does not start right after it.
            click.secho("", nl=True)
            return
        finally:
            # Restore the caller's pager setting after every pass and whenever
            # the generator exits.
            set_pager_enabled(old_pager_enabled)
|
|
|
|
|
|
@export
@special_command("delimiter", None, "Change SQL delimiter.")
def set_delimiter(arg, **_):
    """Change the SQL statement delimiter via the shared DelimiterCommand."""
    result = delimiter_command.set(arg)
    return result
|
|
|
|
|
|
@export
def get_current_delimiter():
    """Return the SQL statement delimiter currently in effect."""
    current = delimiter_command.current
    return current
|
|
|
|
|
|
@export
def split_queries(input):
    """Yield each query in ``input``, split on the active delimiter."""
    yield from delimiter_command.queries_iter(input)
|