# Repository web-page metadata captured together with this file:
# 1
# 0
# Fork 0
# litecli/litecli/main.py
# Daniel Baumann d4dff17dce
# Merging upstream version 1.15.0.
# Signed-off-by: Daniel Baumann <daniel@debian.org>
# 2025-03-17 07:31:48 +01:00
#
# 1037 lines
# 36 KiB
# Python

from __future__ import print_function, unicode_literals
import itertools
import logging
import os
import re
import shutil
import sys
import threading
import traceback
from collections import namedtuple
from datetime import datetime
from io import open
from sqlite3 import OperationalError, sqlite_version
from time import time
import click
import sqlparse
from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import DynamicCompleter
from prompt_toolkit.document import Document
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.filters import HasFocus, IsDone
from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.history import FileHistory
from prompt_toolkit.layout.processors import (
ConditionalProcessor,
HighlightMatchingBracketProcessor,
)
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
from .__init__ import __version__
from .clibuffer import cli_is_multiline
from .clistyle import style_factory, style_factory_output
from .clitoolbar import create_toolbar_tokens_func
from .completion_refresher import CompletionRefresher
from .config import config_location, ensure_dir_exists, get_config
from .key_bindings import cli_bindings
from .lexer import LiteCliLexer
from .packages import special
from .packages.filepaths import dir_path_exists
from .packages.prompt_utils import confirm, confirm_destructive_query
from .packages.special.main import NO_QUERY
from .sqlcompleter import SQLCompleter
from .sqlexecute import SQLExecute
# Legacy click setting; this module still imports unicode_literals from __future__.
click.disable_unicode_literals_warning = True

# History record appended after every executed query: the SQL text, whether
# execution finished without an exception, and whether any statement in it
# was classified as mutating by is_mutating().
Query = namedtuple("Query", "query successful mutating")

# Absolute path of the directory that contains this module.
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
class LiteCli(object):
    """Interactive SQLite command-line client.

    Owns the loaded configuration, the ``SQLExecute`` connection wrapper,
    the SQL completer (replaced wholesale by the background
    ``CompletionRefresher``), output formatting, and the prompt_toolkit
    session driven by :meth:`run_cli`.
    """

    default_prompt = "\\d> "
    # Rendered prompts longer than this trigger the fallback in run_cli's
    # get_message() when the default prompt format is in use.
    max_len_prompt = 45

    def __init__(
        self,
        sqlexecute=None,
        prompt=None,
        logfile=None,
        auto_vertical_output=False,
        warn=None,
        liteclirc=None,
    ):
        """Load configuration and set up logging, completion and special commands.

        :param sqlexecute: an existing ``SQLExecute`` instance; normally
            ``None`` and created later by :meth:`connect`.
        :param prompt: prompt format string overriding the config file.
        :param logfile: open file used as the audit log. When ``None`` the
            ``audit_log`` config entry, if present, is opened instead.
        :param auto_vertical_output: force auto-vertical output on (the
            config value is OR-ed in).
        :param warn: override for ``destructive_warning``; ``None`` keeps the
            config value.
        :param liteclirc: path of the user config file.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile

        # Load config.
        c = self.config = get_config(liteclirc)
        self.multi_line = c["main"].as_bool("multi_line")
        self.key_bindings = c["main"]["key_bindings"]
        special.set_favorite_queries(self.config)
        self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
        self.formatter.litecli = self
        self.syntax_style = c["main"]["syntax_style"]
        self.less_chatty = c["main"].as_bool("less_chatty")
        self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
        self.cli_style = c["colors"]
        self.output_style = style_factory_output(self.syntax_style, self.cli_style)
        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
        self.autocompletion = c["main"].as_bool("autocompletion")
        c_dest_warning = c["main"].as_bool("destructive_warning")
        # An explicit --warn/--no-warn value wins over the config file.
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c["main"].as_bool("login_path_as_host")

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output")

        # Overwritten by configure_pager(); initialised here so output()
        # never reads a missing attribute.
        self.explicit_pager = False

        # audit log
        if self.logfile is None and "audit_log" in c["main"]:
            try:
                # UTF-8, matching the --logfile option's click.File encoding.
                self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a", encoding="utf-8")
            except (IOError, OSError):
                self.echo(
                    "Error: Unable to open the audit log file. Your queries will not be logged.",
                    err=True,
                    fg="red",
                )
                # False (rather than None) means "auditing was requested but
                # failed"; one_iteration() then warns after every query.
                self.logfile = False

        # Load startup commands.
        try:
            self.startup_commands = c["startup_commands"]
        except KeyError:
            # Redundant given the load_config() function that merges in the
            # standard config, but kept so that user config files without a
            # [startup_commands] section do not fail.
            self.startup_commands = None

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"]
        # Precedence: explicit argument, config file, then the class default.
        self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
        self.prompt_continuation_format = c["main"]["prompt_continuation"]
        keyword_casing = c["main"].get("keyword_casing", "auto")

        self.query_history = []

        # Initialize completer.
        self.completer = SQLCompleter(
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing,
        )
        # Guards self.completer, which _on_completions_refreshed() replaces
        # when a background refresh finishes.
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        self.prompt_app = None

    def register_special_commands(self):
        """Register the client-side commands (.open, rehash, .mode, .read, prompt)."""
        special.register_special_command(
            self.change_db,
            ".open",
            ".open",
            "Change to a new database.",
            aliases=("use", "\\u"),
        )
        special.register_special_command(
            self.refresh_completions,
            "rehash",
            "\\#",
            "Refresh auto-completions.",
            arg_type=NO_QUERY,
            aliases=("\\#",),
        )
        special.register_special_command(
            self.change_table_format,
            ".mode",
            "\\T",
            "Change the table format used to output results.",
            aliases=("tableformat", "\\T"),
            case_sensitive=True,
        )
        special.register_special_command(
            self.execute_from_file,
            ".read",
            "\\. filename",
            "Execute commands from file.",
            case_sensitive=True,
            aliases=("\\.", "source"),
        )
        special.register_special_command(
            self.change_prompt_format,
            "prompt",
            "\\R",
            "Change prompt format.",
            aliases=("\\R",),
            case_sensitive=True,
        )

    def change_table_format(self, arg, **_):
        """Switch the output table format; yields one status tuple.

        If the formatter rejects *arg* with ``ValueError``, the status lists
        every supported format instead.
        """
        try:
            self.formatter.format_name = arg
            yield (None, None, None, "Changed table format to {}".format(arg))
        except ValueError:
            msg = "Table format {} not recognized. Allowed formats:".format(arg)
            for table_type in self.formatter.supported_formats:
                msg += "\n\t{}".format(table_type)
            yield (None, None, None, msg)

    def change_db(self, arg, **_):
        """Connect to database *arg* (or call ``connect()`` bare when *arg* is None).

        Starts a completion refresh and yields a status tuple naming the
        database now in use.
        """
        if arg is None:
            self.sqlexecute.connect()
        else:
            self.sqlexecute.connect(database=arg)

        self.refresh_completions()
        yield (
            None,
            None,
            None,
            'You are now connected to database "%s"' % (self.sqlexecute.dbname),
        )

    def execute_from_file(self, arg, **_):
        """Run the SQL contained in file *arg* (``~`` is expanded).

        Returns a single status tuple on a missing argument, an I/O error,
        or a declined destructive-query confirmation; otherwise returns the
        results of ``sqlexecute.run``.
        """
        if not arg:
            message = "Missing required argument, filename."
            return [(None, None, None, message)]
        try:
            with open(os.path.expanduser(arg), encoding="utf-8") as f:
                query = f.read()
        except IOError as e:
            return [(None, None, None, str(e))]

        if self.destructive_warning and confirm_destructive_query(query) is False:
            message = "Wise choice. Command execution stopped."
            return [(None, None, None, message)]

        return self.sqlexecute.run(query)

    def change_prompt_format(self, arg, **_):
        """
        Change the prompt format.

        The raw format string is stored, not its expansion: get_message()
        expands it on every redraw, so time escapes (``\\D``, ``\\R``, ...)
        stay current and a database path containing backslashes is not
        expanded a second time.
        """
        if not arg:
            message = "Missing required argument, format."
            return [(None, None, None, message)]

        self.prompt_format = arg
        return [(None, None, None, "Changed prompt format to %s" % arg)]

    def initialize_logging(self):
        """Attach a handler for the ``litecli`` logger based on the config.

        ``log_file = default`` maps to the config directory. If that
        directory cannot be created, the file goes to the system temp
        directory. ``log_level = NONE`` installs a NullHandler.
        """
        log_file = self.config["main"]["log_file"]
        if log_file == "default":
            log_file = config_location() + "log"
        try:
            ensure_dir_exists(log_file)
        except OSError:
            # Unable to create log file, log to temp directory instead.
            # gettempdir() is portable, unlike a hard-coded "/tmp".
            import tempfile

            log_file = os.path.join(tempfile.gettempdir(), "litecli.log")

        log_level = self.config["main"]["log_level"]

        level_map = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
            log_level = "CRITICAL"
        elif dir_path_exists(log_file):
            handler = logging.FileHandler(log_file)
        else:
            self.echo(
                'Error: Unable to open the log file "{}".'.format(log_file),
                err=True,
                fg="red",
            )
            return

        formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) " "%(name)s %(levelname)s - %(message)s")

        handler.setFormatter(formatter)

        root_logger = logging.getLogger("litecli")
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug("Initializing litecli logging.")
        root_logger.debug("Log file %r.", log_file)

    def read_my_cnf_files(self, keys):
        """
        Look up *keys* in the ``[main]`` section of the loaded config.

        :param keys: iterable of option names to retrieve.
        :returns: dict mapping each key to its value, or ``None`` when the
            key is not present.
        """
        cnf = self.config

        sections = ["main"]

        def get(key):
            result = None
            for sect in cnf:
                if sect in sections and key in cnf[sect]:
                    result = cnf[sect][key]
            return result

        return {x: get(x) for x in keys}

    def connect(self, database=""):
        """Open *database* (falling back to the config value) or exit with status 1."""
        cnf = self.read_my_cnf_files(["database"])

        # Fall back to config values only if user did not specify a value.
        database = database or cnf["database"]

        # Connect to the database.
        def _connect():
            self.sqlexecute = SQLExecute(database)

        try:
            _connect()
        except Exception as e:  # Connecting to a database could fail.
            self.logger.debug("Database connection failed: %r.", e)
            self.logger.error("traceback: %r", traceback.format_exc())
            self.echo(str(e), err=True, fg="red")
            # sys.exit works even when the site-provided exit() builtin is absent.
            sys.exit(1)

    def handle_editor_command(self, text):
        R"""Editor command is any query that is prefixed or suffixed by a '\e'.
        The reason for a while loop is because a user might edit a query
        multiple times. For eg:

        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param text: Document
        :return: Document
        :raises RuntimeError: when the external editor reports an error.
        """

        while special.editor_command(text):
            filename = special.get_filename(text)
            query = special.get_editor_query(text) or self.get_last_query()
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            while True:
                try:
                    text = self.prompt_app.prompt(default=sql)
                    break
                except KeyboardInterrupt:
                    sql = ""

            continue
        return text

    def run_cli(self):
        """Run the interactive read-eval-print loop until EOF (Ctrl-D)."""
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()
        self.refresh_completions()

        history_file = self.config["main"]["history_file"]
        if history_file == "default":
            history_file = config_location() + "history"
        history_file = os.path.expanduser(history_file)
        if dir_path_exists(history_file):
            history = FileHistory(history_file)
        else:
            history = None
            self.echo(
                'Error: Unable to open the history file "{}". ' "Your query history will not be saved.".format(history_file),
                err=True,
                fg="red",
            )

        key_bindings = cli_bindings(self)

        if not self.less_chatty:
            print(f"LiteCli: {__version__} (SQLite: {sqlite_version})")
            print("GitHub: https://github.com/dbcli/litecli")

        def get_message():
            prompt = self.get_prompt(self.prompt_format)
            # NOTE(review): the fallback format below equals default_prompt,
            # so this branch currently re-renders the same prompt.
            if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
                prompt = self.get_prompt("\\d> ")
            # Let users embed ANSI escapes by writing a literal "\x1b".
            prompt = prompt.replace("\\x1b", "\x1b")
            return ANSI(prompt)

        def get_continuation(width, line_number, is_soft_wrap):
            continuation = " " * (width - 1) + " "
            return [("class:continuation", continuation)]

        def show_suggestion_tip():
            return iterations < 2

        def output_res(res, start):
            """Print every result tuple in *res*; return True if any was mutating."""
            result_count = 0
            mutating = False
            for title, cur, headers, status in res:
                logger.debug("headers: %r", headers)
                logger.debug("rows: %r", cur)
                logger.debug("status: %r", status)
                threshold = 1000
                if is_select(status) and cur and cur.rowcount > threshold:
                    self.echo(
                        "The result set has more than {} rows.".format(threshold),
                        fg="red",
                    )
                    if not confirm("Do you want to continue?"):
                        self.echo("Aborted!", err=True, fg="red")
                        break

                if self.auto_vertical_output:
                    max_width = self.prompt_app.output.get_size().columns
                else:
                    max_width = None

                formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width)

                t = time() - start
                try:
                    if result_count > 0:
                        self.echo("")
                    try:
                        self.output(formatted, status)
                    except KeyboardInterrupt:
                        pass
                    self.echo("Time: %0.03fs" % t)
                except KeyboardInterrupt:
                    pass

                # Time each subsequent result from when the previous one finished.
                start = time()
                result_count += 1
                mutating = mutating or is_mutating(status)
            return mutating

        def one_iteration(text=None):
            """Read (unless *text* is given), execute and print one query."""
            if text is None:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    return

                special.set_expanded_output(False)

                try:
                    text = self.handle_editor_command(text)
                except RuntimeError as e:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg="red")
                    return

                while special.is_llm_command(text):
                    try:
                        start = time()
                        cur = self.sqlexecute.conn and self.sqlexecute.conn.cursor()
                        context, sql, duration = special.handle_llm(text, cur)
                        if context:
                            click.echo("LLM Response:")
                            click.echo(context)
                            click.echo("---")
                        click.echo(f"Time: {duration:.2f} seconds")
                        text = self.prompt_app.prompt(default=sql)
                    except KeyboardInterrupt:
                        return
                    except special.FinishIteration as e:
                        return output_res(e.results, start) if e.results else None
                    except RuntimeError as e:
                        logger.error("sql: %r, error: %r", text, e)
                        logger.error("traceback: %r", traceback.format_exc())
                        self.echo(str(e), err=True, fg="red")
                        return

            if not text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo("Your call!")
                else:
                    self.echo("Wise choice!")
                    return

            mutating = False
            # Initialised before the try so the finally clause can always
            # read it, even if tee/audit-log writes fail before execution.
            successful = False

            try:
                logger.debug("sql: %r", text)

                special.write_tee(self.get_prompt(self.prompt_format) + text)
                if self.logfile:
                    self.logfile.write("\n# %s\n" % datetime.now())
                    self.logfile.write(text)
                    self.logfile.write("\n")

                start = time()
                res = sqlexecute.run(text)
                self.formatter.query = text
                successful = True
                special.unset_once_if_written()
                # Keep track of whether or not the query is mutating. In case
                # of a multi-statement query, the overall query is considered
                # mutating if any one of the component statements is mutating
                mutating = output_res(res, start)
                special.unset_pipe_once_if_written()
            except EOFError as e:
                raise e
            except KeyboardInterrupt:
                try:
                    sqlexecute.conn.interrupt()
                except Exception as e:
                    self.echo(
                        "Encountered error while cancelling query: {}".format(e),
                        err=True,
                        fg="red",
                    )
                else:
                    logger.debug("cancelled query")
                    self.echo("cancelled query", err=True, fg="red")
            except NotImplementedError:
                self.echo("Not Yet Implemented.", fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                # NOTE(review): these integer codes presumably come from a
                # MySQL client; sqlite3 errors usually carry a message string
                # here. The e.args guard avoids an IndexError on argless errors.
                if e.args and e.args[0] in (2003, 2006, 2013):
                    logger.debug("Attempting to reconnect.")
                    self.echo("Reconnecting...", fg="yellow")
                    try:
                        sqlexecute.connect()
                        logger.debug("Reconnected successfully.")
                        one_iteration(text)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug("Reconnect failed. e: %r", e)
                        self.echo(str(e), err=True, fg="red")
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg="red")
            except Exception as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg="red")
            else:
                # Refresh the table names and column names if necessary.
                if need_completion_refresh(text):
                    self.refresh_completions(reset=need_completion_reset(text))
            finally:
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.", err=True, fg="red")
            query = Query(text, successful, mutating)
            self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip)

        if self.wider_completion_menu:
            complete_style = CompleteStyle.MULTI_COLUMN
        else:
            complete_style = CompleteStyle.COLUMN

        if not self.autocompletion:
            complete_style = CompleteStyle.READLINE_LIKE

        with self._completer_lock:
            if self.key_bindings == "vi":
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            self.prompt_app = PromptSession(
                lexer=PygmentsLexer(LiteCliLexer),
                reserve_space_for_menu=self.get_reserved_space(),
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
                complete_style=complete_style,
                input_processors=[
                    ConditionalProcessor(
                        processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                        filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                    )
                ],
                tempfile_suffix=".sql",
                # Re-read self.completer each time so refreshed completers are used.
                completer=DynamicCompleter(lambda: self.completer),
                history=history,
                auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=True,
                multiline=cli_is_multiline(self),
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=editing_mode,
                search_ignore_case=True,
            )

        def startup_commands():
            """Run and print the [startup_commands] ``commands`` entry, if any."""
            if self.startup_commands:
                if "commands" in self.startup_commands:
                    if isinstance(self.startup_commands["commands"], str):
                        commands = [self.startup_commands["commands"]]
                    else:
                        commands = self.startup_commands["commands"]
                    for command in commands:
                        try:
                            res = sqlexecute.run(command)
                        except Exception as e:
                            click.echo(command)
                            self.echo(str(e), err=True, fg="red")
                        else:
                            click.echo(command)
                            for title, cur, headers, status in res:
                                if title == "dot command not implemented":
                                    self.echo(
                                        "The SQLite dot command '" + command.split(" ", 1)[0] + "' is not yet implemented.",
                                        fg="yellow",
                                    )
                                else:
                                    output = self.format_output(title, cur, headers)
                                    for line in output:
                                        self.echo(line)
                else:
                    self.echo(
                        "Could not read commands. The startup commands needs to be formatted as: \n commands = 'command1', 'command2', ...",
                        fg="yellow",
                    )

        try:
            startup_commands()
        except Exception as e:
            self.echo("Could not execute all startup commands: \n" + str(e), fg="yellow")

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            special.close_tee()
            if not self.less_chatty:
                self.echo("Goodbye!")

    def log_output(self, output):
        """Log the output in the audit log, if it's enabled."""
        if self.logfile:
            click.echo(output, file=self.logfile)

    def echo(self, s, **kwargs):
        """Print a message to stdout.

        The message will be logged in the audit log, if enabled.

        All keyword arguments are passed to click.secho().
        """
        self.log_output(s)
        click.secho(s, **kwargs)

    def get_output_margin(self, status=None):
        """Get the output margin (number of rows for the prompt, footer and
        timing message."""
        margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 2
        if status:
            margin += 1 + status.count("\n")

        return margin

    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.
        """
        if output:
            size = self.prompt_app.output.get_size()

            margin = self.get_output_margin(status)

            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)
                special.write_pipe_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for buffered_line in buf:
                                click.secho(buffered_line)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # sadly click.echo_via_pager doesn't accept generators
                    click.echo_via_pager("\n".join(buf))
                else:
                    for buffered_line in buf:
                        click.secho(buffered_line)

        if status:
            self.log_output(status)
            click.secho(status)

    def configure_pager(self):
        """Apply pager settings from the config and default the LESS env var."""
        # Provide sane defaults for less if they are empty.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-RXF"

        cnf = self.read_my_cnf_files(["pager", "skip-pager"])
        if cnf["pager"]:
            special.set_pager(cnf["pager"])
            self.explicit_pager = True
        else:
            self.explicit_pager = False

        if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
            special.disable_pager()

    def refresh_completions(self, reset=False):
        """Start a background completion refresh, optionally clearing first."""
        if reset:
            with self._completer_lock:
                self.completer.reset_completions()
        self.completion_refresher.refresh(
            self.sqlexecute,
            self._on_completions_refreshed,
            {
                "supported_formats": self.formatter.supported_formats,
                "keyword_casing": self.completer.keyword_casing,
            },
        )

        return [(None, None, None, "Auto-completion refresh started in the background.")]

    def _on_completions_refreshed(self, new_completer):
        """Swap the completer object in cli with the newly created completer."""
        with self._completer_lock:
            self.completer = new_completer

        if self.prompt_app:
            # After refreshing, redraw the CLI to clear the statusbar
            # "Refreshing completions..." indicator
            self.prompt_app.app.invalidate()

    def get_completions(self, text, cursor_positition):
        """Return completions for *text* at the given cursor position.

        (The misspelled parameter name is kept for keyword-argument callers.)
        """
        with self._completer_lock:
            return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)

    def get_prompt(self, string):
        """Expand the backslash escapes (``\\d``, ``\\f``, ``\\D`` ...) in *string*."""
        self.logger.debug("Getting prompt %r", string)
        sqlexecute = self.sqlexecute
        now = datetime.now()

        # Prepare the replacements dictionary
        replacements = {
            r"\d": sqlexecute.dbname or "(none)",
            r"\f": os.path.basename(sqlexecute.dbname or "(none)"),
            r"\n": "\n",
            r"\D": now.strftime("%a %b %d %H:%M:%S %Y"),
            r"\m": now.strftime("%M"),
            r"\P": now.strftime("%p"),
            r"\R": now.strftime("%H"),
            r"\r": now.strftime("%I"),
            r"\s": now.strftime("%S"),
            r"\_": " ",
        }

        # Compile a regex pattern that matches any of the keys in replacements
        pattern = re.compile("|".join(re.escape(key) for key in replacements.keys()))

        # Define the replacement function
        def replacer(match):
            return replacements[match.group(0)]

        # Perform the substitution
        return pattern.sub(replacer, string)

    def run_query(self, query, new_line=True):
        """Runs *query*."""
        results = self.sqlexecute.run(query)
        for result in results:
            title, cur, headers, status = result
            self.formatter.query = query
            output = self.format_output(title, cur, headers)
            for line in output:
                click.echo(line, nl=new_line)

    def format_output(self, title, cur, headers, expanded=False, max_width=None):
        """Return an iterable of formatted output lines for one result.

        :param expanded: force vertical output (also forced when the current
            format is ``vertical``).
        :param max_width: when given and the first table line is wider than
            this, the rows are re-rendered vertically.
        """
        expanded = expanded or self.formatter.format_name == "vertical"
        output = []

        output_kwargs = {
            "dialect": "unix",
            "disable_numparse": True,
            "preserve_whitespace": True,
            "preprocessors": (preprocessors.align_decimals,),
            "style": self.output_style,
        }

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, "description"):
                column_types = [str(col) for col in cur.description]

            if max_width is not None:
                # Materialise rows so they can be formatted a second time.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur,
                headers,
                format_name="vertical" if expanded else None,
                column_types=column_types,
                **output_kwargs,
            )

            if isinstance(formatted, str):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            # Peek at the first line to measure the width. The formatter may
            # produce no lines at all, so use a default instead of letting
            # next() raise StopIteration.
            first_line = next(formatted, None)
            if first_line is not None:
                formatted = itertools.chain([first_line], formatted)

            if (
                not expanded
                and max_width
                and headers
                and cur
                and first_line is not None
                and len(first_line) > max_width
            ):
                formatted = self.formatter.format_output(
                    cur,
                    headers,
                    format_name="vertical",
                    column_types=column_types,
                    **output_kwargs,
                )
                if isinstance(formatted, str):
                    formatted = iter(formatted.splitlines())

            output = itertools.chain(output, formatted)

        return output

    def get_reserved_space(self):
        """Get the number of lines to reserve for the completion menu."""
        reserved_space_ratio = 0.45
        max_reserved_space = 8
        _, height = shutil.get_terminal_size()
        return min(int(round(height * reserved_space_ratio)), max_reserved_space)

    def get_last_query(self):
        """Get the last query executed or None."""
        return self.query_history[-1][0] if self.query_history else None
def _set_batch_output_format(litecli, csv, table):
    """Choose the non-interactive output format: CSV, the configured table format, or TSV."""
    if csv:
        litecli.formatter.format_name = "csv"
    elif not table:
        litecli.formatter.format_name = "tsv"


@click.command()
@click.version_option(__version__, "-V", "--version")
@click.option("-D", "--database", "dbname", help="Database to use.")
@click.option(
    "-R",
    "--prompt",
    "prompt",
    help='Prompt format (Default: "{0}").'.format(LiteCli.default_prompt),
)
@click.option(
    "-l",
    "--logfile",
    type=click.File(mode="a", encoding="utf-8"),
    help="Log every query and its results to a file.",
)
@click.option(
    "--liteclirc",
    default=config_location() + "config",
    help="Location of liteclirc file.",
    type=click.Path(dir_okay=False),
)
@click.option(
    "--auto-vertical-output",
    is_flag=True,
    help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option("-t", "--table", is_flag=True, help="Display batch output in table format.")
@click.option("--csv", is_flag=True, help="Display batch output in CSV format.")
@click.option("--warn/--no-warn", default=None, help="Warn before running a destructive query.")
@click.option("-e", "--execute", type=str, help="Execute command and quit.")
@click.argument("database", default="", nargs=1)
def cli(
    database,
    dbname,
    prompt,
    logfile,
    auto_vertical_output,
    table,
    csv,
    warn,
    execute,
    liteclirc,
):
    """A SQLite terminal client with auto-completion and syntax highlighting.

    \b
    Examples:
    - litecli lite_database
    """
    # NOTE: the docstring above is click's --help text; the "\b" paragraph
    # marker must be preceded by a blank line for click to honour it.
    litecli = LiteCli(
        prompt=prompt,
        logfile=logfile,
        auto_vertical_output=auto_vertical_output,
        warn=warn,
        liteclirc=liteclirc,
    )

    # Choose which ever one has a valid value.
    database = database or dbname

    litecli.connect(database)

    litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database)

    # --execute argument
    if execute:
        try:
            _set_batch_output_format(litecli, csv, table)
            litecli.run_query(execute)
            sys.exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg="red")
            sys.exit(1)

    if sys.stdin.isatty():
        litecli.run_cli()
    else:
        # Non-interactive: read the whole script from stdin, then reattach
        # stdin to the terminal so confirmation prompts can still be answered.
        stdin = click.get_text_stream("stdin")
        stdin_text = stdin.read()

        try:
            sys.stdin = open("/dev/tty")
        except (FileNotFoundError, OSError):
            litecli.logger.warning("Unable to open TTY as stdin.")

        if litecli.destructive_warning and confirm_destructive_query(stdin_text) is False:
            sys.exit(0)
        try:
            _set_batch_output_format(litecli, csv, table)
            litecli.run_query(stdin_text, new_line=True)
            sys.exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg="red")
            sys.exit(1)
def need_completion_refresh(queries):
    """Determines if the completion needs a refresh by checking if the sql
    statement is an alter, create, drop or change db.

    Each statement produced by ``sqlparse.split`` is checked by its first
    word (case-insensitive) against: alter, create, drop, use, \\u, connect,
    \\r. Empty statements are skipped instead of ending the scan early.

    :returns: True if any statement matches, otherwise False.
    """
    for query in sqlparse.split(queries):
        words = query.split()
        if words and words[0].lower() in (
            "alter",
            "create",
            "use",
            "\\r",
            "\\u",
            "connect",
            "drop",
        ):
            return True
    return False
def need_completion_reset(queries):
    """Determines if the statement is a database switch such as 'use' or '\\u'.
    When a database is changed the existing completions must be reset before we
    start the completion refresh for the new database.

    Empty statements are skipped instead of ending the scan early.

    :returns: True if any statement starts with ``use`` or ``\\u``, otherwise False.
    """
    for query in sqlparse.split(queries):
        words = query.split()
        if words and words[0].lower() in ("use", "\\u"):
            return True
    return False
# First words of a status line that mark a statement as mutating.
_MUTATING_KEYWORDS = frozenset(
    (
        "insert",
        "update",
        "delete",
        "alter",
        "create",
        "drop",
        "replace",
        "truncate",
        "load",
    )
)


def is_mutating(status):
    """Determines if the statement is mutating based on the status.

    :param status: status text of an executed statement, or None/empty.
    :returns: True if the first word of *status* (case-insensitive) is a
        mutating keyword. Returns False for None, empty, or whitespace-only
        status.
    """
    if not status:
        return False
    words = status.split(None, 1)
    # A whitespace-only status splits to []; guard before indexing.
    return bool(words) and words[0].lower() in _MUTATING_KEYWORDS
def is_select(status):
    """Returns true if the first word in status is 'select' (case-insensitive).

    None, empty, and whitespace-only status values return False.
    """
    if not status:
        return False
    words = status.split(None, 1)
    # A whitespace-only status splits to []; guard before indexing.
    return bool(words) and words[0].lower() == "select"
if __name__ == "__main__":
cli()