diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e327a9..2b71bcb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6073ec5..0491657 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 diff --git a/CHANGELOG.md b/CHANGELOG.md index f1e8c68..328261f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,32 @@ +## 1.14.2 - 2025-01-26 + +### Bug Fixes + +* Catch errors surfaced by `llm` cli and surface them as runtime errors. + +## 1.14.1 - 2025-01-25 + +### Bug Fixes + +* Capture stderr in addition to stdout when capturing output from `llm` cli. + +## 1.14.0 - 2025-01-22 + +### Features + +* Add LLM feature to ask an LLM to create a SQL query. + - This adds a new `\llm` special command + - eg: `\llm "Who is the largest customer based on revenue?"` + +### Bug Fixes + +* Fix the [windows path](https://github.com/dbcli/litecli/issues/187) shown in prompt to remove escaping. +* Fix a bug where if column name was same as table name it was [crashing](https://github.com/dbcli/litecli/issues/155) the autocompletion. + +### Internal + +* Change min required python version to 3.9+ + ## 1.13.2 - 2024-11-24 ### Internal diff --git a/README.md b/README.md index 81f0769..ac4b3b8 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ A command-line client for SQLite databases that has auto-completion and syntax highlighting. -![Completion](screenshots/litecli.png) -![CompletionGif](screenshots/litecli.gif) +![Completion](https://raw.githubusercontent.com/dbcli/litecli/refs/heads/main/screenshots/litecli.png) +![CompletionGif](https://raw.githubusercontent.com/dbcli/litecli/refs/heads/main/screenshots/litecli.gif) ## Installation diff --git a/litecli/main.py b/litecli/main.py index a0607ab..7e5a817 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -1,52 +1,50 @@ -from __future__ import unicode_literals -from __future__ import print_function +from __future__ import print_function, unicode_literals -import os -import sys -import traceback +import itertools import logging +import os +import re +import shutil +import sys import threading -from time import time +import traceback +from collections import namedtuple from datetime import datetime from io import open -from collections import namedtuple from sqlite3 import OperationalError, sqlite_version -import shutil +from time import time -from cli_helpers.tabular_output import TabularOutputFormatter -from cli_helpers.tabular_output import preprocessors import click import sqlparse +from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from prompt_toolkit.completion import DynamicCompleter -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import HasFocus, IsDone from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.history import FileHistory from prompt_toolkit.layout.processors import ( - HighlightMatchingBracketProcessor, ConditionalProcessor, + HighlightMatchingBracketProcessor, ) from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.history import FileHistory -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.shortcuts import CompleteStyle, PromptSession -from .packages.special.main import NO_QUERY -from .packages.prompt_utils import confirm, confirm_destructive_query -from .packages import special -from .sqlcompleter import SQLCompleter -from .clitoolbar import create_toolbar_tokens_func -from .clistyle import style_factory, style_factory_output -from .sqlexecute import SQLExecute +from .__init__ import __version__ from .clibuffer import cli_is_multiline +from .clistyle import style_factory, style_factory_output +from .clitoolbar import create_toolbar_tokens_func from .completion_refresher import CompletionRefresher from .config import config_location, ensure_dir_exists, get_config from .key_bindings import cli_bindings from .lexer import LiteCliLexer -from .__init__ import __version__ +from .packages import special from .packages.filepaths import dir_path_exists - -import itertools +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.special.main import NO_QUERY +from .sqlcompleter import SQLCompleter +from .sqlexecute import SQLExecute click.disable_unicode_literals_warning = True @@ -385,6 +383,47 @@ class LiteCli(object): def show_suggestion_tip(): return iterations < 2 + def output_res(res, start): + result_count = 0 + mutating = False + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if is_select(status) and cur and cur.rowcount > threshold: + self.echo( + "The result set has more than {} rows.".format(threshold), + fg="red", + ) + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") + break + + if self.auto_vertical_output: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) + + t = time() - start + try: + if result_count > 0: + self.echo("") + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + self.echo("Time: %0.03fs" % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or is_mutating(status) + return mutating + def one_iteration(text=None): if text is None: try: @@ -402,6 +441,24 @@ class LiteCli(object): self.echo(str(e), err=True, fg="red") return + if special.is_llm_command(text): + try: + start = time() + cur = self.sqlexecute.conn.cursor() + context, sql = special.handle_llm(text, cur) + if context: + click.echo(context) + text = self.prompt_app.prompt(default=sql) + except KeyboardInterrupt: + return + except special.FinishIteration as e: + return output_res(e.results, start) if e.results else None + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return + if not text.strip(): return @@ -415,9 +472,6 @@ class LiteCli(object): self.echo("Wise choice!") return - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating mutating = False try: @@ -434,44 +488,11 @@ class LiteCli(object): res = sqlexecute.run(text) self.formatter.query = text successful = True - result_count = 0 - for title, cur, headers, status in res: - logger.debug("headers: %r", headers) - logger.debug("rows: %r", cur) - logger.debug("status: %r", status) - threshold = 1000 - if is_select(status) and cur and cur.rowcount > threshold: - self.echo( - "The result set has more than {} rows.".format(threshold), - fg="red", - ) - if not confirm("Do you want to continue?"): - self.echo("Aborted!", err=True, fg="red") - break - - if self.auto_vertical_output: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = None - - formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) - - t = time() - start - try: - if result_count > 0: - self.echo("") - try: - self.output(formatted, status) - except KeyboardInterrupt: - pass - self.echo("Time: %0.03fs" % t) - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or is_mutating(status) special.unset_once_if_written() + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = output_res(res, start) special.unset_pipe_once_if_written() except EOFError as e: raise e @@ -735,20 +756,32 @@ class LiteCli(object): return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string): - self.logger.debug("Getting prompt") + self.logger.debug("Getting prompt %r", string) sqlexecute = self.sqlexecute now = datetime.now() - string = string.replace("\\d", sqlexecute.dbname or "(none)") - string = string.replace("\\f", os.path.basename(sqlexecute.dbname or "(none)")) - string = string.replace("\\n", "\n") - string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) - string = string.replace("\\m", now.strftime("%M")) - string = string.replace("\\P", now.strftime("%p")) - string = string.replace("\\R", now.strftime("%H")) - string = string.replace("\\r", now.strftime("%I")) - string = string.replace("\\s", now.strftime("%S")) - string = string.replace("\\_", " ") - return string + + # Prepare the replacements dictionary + replacements = { + r"\d": sqlexecute.dbname or "(none)", + r"\f": os.path.basename(sqlexecute.dbname or "(none)"), + r"\n": "\n", + r"\D": now.strftime("%a %b %d %H:%M:%S %Y"), + r"\m": now.strftime("%M"), + r"\P": now.strftime("%p"), + r"\R": now.strftime("%H"), + r"\r": now.strftime("%I"), + r"\s": now.strftime("%S"), + r"\_": " ", + } + # Compile a regex pattern that matches any of the keys in replacements + pattern = re.compile("|".join(re.escape(key) for key in replacements.keys())) + + # Define the replacement function + def replacer(match): + return replacements[match.group(0)] + + # Perform the substitution + return pattern.sub(replacer, string) def run_query(self, query, new_line=True): """Runs *query*.""" diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py index 05b70ac..2d9a033 100644 --- a/litecli/packages/completion_engine.py +++ b/litecli/packages/completion_engine.py @@ -118,6 +118,9 @@ def suggest_special(text): else: return [{"type": "table", "schema": []}] + if cmd in [".llm", ".ai", "\\llm", "\\ai"]: + return [{"type": "llm"}] + return [{"type": "keyword"}, {"type": "special"}] diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py index 5924d09..0338c36 100644 --- a/litecli/packages/special/__init__.py +++ b/litecli/packages/special/__init__.py @@ -12,3 +12,4 @@ def export(defn): from . import dbcommands from . import iocommands +from . import llm diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py index d779764..315f6c7 100644 --- a/litecli/packages/special/dbcommands.py +++ b/litecli/packages/special/dbcommands.py @@ -6,6 +6,7 @@ import sys import platform import shlex + from litecli import __version__ from litecli.packages.special import iocommands from .main import special_command, RAW_QUERY, PARSED_QUERY diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index eeba814..ec65672 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals -import os -import re + import locale import logging -import subprocess +import os +import re import shlex +import subprocess from io import open from time import sleep @@ -12,11 +13,11 @@ import click import sqlparse from configobj import ConfigObj +from ..prompt_utils import confirm_destructive_query from . import export -from .main import special_command, NO_QUERY, PARSED_QUERY from .favoritequeries import FavoriteQueries +from .main import NO_QUERY, PARSED_QUERY, special_command from .utils import handle_cd_command -from litecli.packages.prompt_utils import confirm_destructive_query use_expanded_output = False PAGER_ENABLED = True @@ -27,6 +28,8 @@ pipe_once_process = None written_to_pipe_once_process = False favoritequeries = FavoriteQueries(ConfigObj()) +log = logging.getLogger(__name__) + @export def set_favorite_queries(config): @@ -95,9 +98,6 @@ def is_expanded_output(): return use_expanded_output -_logger = logging.getLogger(__name__) - - @export def editor_command(command): """ diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py new file mode 100644 index 0000000..69375d7 --- /dev/null +++ b/litecli/packages/special/llm.py @@ -0,0 +1,336 @@ +import contextlib +import io +import logging +import os +import re +import shlex +import sys +from runpy import run_module +from typing import Optional, Tuple + +import click + +try: + import llm + from llm.cli import cli + + LLM_CLI_COMMANDS = list(cli.commands.keys()) + MODELS = {x.model_id: None for x in llm.get_models()} +except ImportError: + llm = None + cli = None + +from . import export +from .main import parse_special_command + +log = logging.getLogger(__name__) + + +def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True): + original_exe = sys.executable + original_args = sys.argv + + try: + sys.argv = [cmd] + list(args) + code = 0 + + if capture_output: + buffer = io.StringIO() + redirect = contextlib.ExitStack() + redirect.enter_context(contextlib.redirect_stdout(buffer)) + redirect.enter_context(contextlib.redirect_stderr(buffer)) + else: + redirect = contextlib.nullcontext() + + with redirect: + try: + run_module(cmd, run_name="__main__") + except SystemExit as e: + code = e.code + if code != 0 and raise_exception: + if capture_output: + raise RuntimeError(buffer.getvalue()) + else: + raise RuntimeError(f"Command {cmd} failed with exit code {code}.") + + if restart_cli and code == 0: + os.execv(original_exe, [original_exe] + original_args) + + if capture_output: + return code, buffer.getvalue() + else: + return code, "" + finally: + sys.argv = original_args + + +def build_command_tree(cmd): + """Recursively build a command tree for a Click app. + + Args: + cmd (click.Command or click.Group): The Click command/group to inspect. + + Returns: + dict: A nested dictionary representing the command structure. + """ + tree = {} + if isinstance(cmd, click.Group): + for name, subcmd in cmd.commands.items(): + if cmd.name == "models" and name == "default": + tree[name] = MODELS + else: + # Recursively build the tree for subcommands + tree[name] = build_command_tree(subcmd) + else: + # Leaf command with no subcommands + tree = None + return tree + + +# Generate the tree +COMMAND_TREE = build_command_tree(cli) + + +def get_completions(tokens, tree=COMMAND_TREE): + """Get autocompletions for the current command tokens. + + Args: + tree (dict): The command tree. + tokens (list): List of tokens (command arguments). + + Returns: + list: List of possible completions. + """ + for token in tokens: + if token.startswith("-"): + # Skip options (flags) + continue + if tree and token in tree: + tree = tree[token] + else: + # No completions available + return [] + + # Return possible completions (keys of the current tree level) + return list(tree.keys()) if tree else [] + + +@export +class FinishIteration(Exception): + def __init__(self, results=None): + self.results = results + + +USAGE = """ +Use an LLM to create SQL queries to answer questions from your database. +Examples: + +# Ask a question. +> \\llm 'Most visited urls?' + +# List available models +> \\llm models +gpt-4o +gpt-3.5-turbo +qwq + +# Change default model +> \\llm models default llama3 + +# Set api key (not required for local models) +> \\llm keys set openai sg-1234 +API key set for openai. + +# Install a model plugin +> \\llm install llm-ollama +llm-ollama installed. + +# Models directory +# https://llm.datasette.io/en/stable/plugins/directory.html +""" + +_SQL_CODE_FENCE = r"```sql\n(.*?)\n```" +PROMPT = """A SQLite database has the following schema: + +$db_schema + +Here is a sample row of data from each table: $sample_data + +Use the provided schema and the sample data to construct a SQL query that +can be run in SQLite3 to answer + +$question + +Explain the reason for choosing each table in the SQL query you have +written. Keep the explanation concise. +Finally include a sql query in a code fence such as this one: + +```sql +SELECT count(*) FROM table_name; +``` +""" + + +def initialize_llm(): + # Initialize the LLM library. + if click.confirm("This feature requires additional libraries. Install LLM library?", default=True): + click.echo("Installing LLM library. Please wait...") + run_external_cmd("pip", "install", "--quiet", "llm", restart_cli=True) + + +def ensure_litecli_template(replace=False): + """ + Create a template called litecli with the default prompt. + """ + if not replace: + # Check if it already exists. + code, _ = run_external_cmd("llm", "templates", "show", "litecli", capture_output=True, raise_exception=False) + if code == 0: # Template already exists. No need to create it. + return + + run_external_cmd("llm", PROMPT, "--save", "litecli") + return + + +@export +def handle_llm(text, cur) -> Tuple[str, Optional[str]]: + """This function handles the special command `\\llm`. + + If it deals with a question that results in a SQL query then it will return + the query. + If it deals with a subcommand like `models` or `keys` then it will raise + FinishIteration() which will be caught by the main loop AND print any + output that was supplied (or None). + """ + _, verbose, arg = parse_special_command(text) + + # LLM is not installed. + if llm is None: + initialize_llm() + raise FinishIteration(None) + + if not arg.strip(): # No question provided. Print usage and bail. + output = [(None, None, None, USAGE)] + raise FinishIteration(output) + + parts = shlex.split(arg) + + restart = False + # If the parts has `-c` then capture the output and check for fenced SQL. + # User is continuing a previous question. + # eg: \llm -m ollama -c "Show only the top 5 results" + if "-c" in parts: + capture_output = True + use_context = False + # If the parts has `pormpt` command without `-c` then use context to the prompt. + # \llm -m ollama prompt "Most visited urls?" + elif "prompt" in parts: # User might invoke prompt with an option flag in the first argument. + capture_output = True + use_context = True + elif "install" in parts or "uninstall" in parts: + capture_output = False + use_context = False + restart = True + # If the parts starts with any of the known LLM_CLI_COMMANDS then invoke + # the llm and don't capture output. This is to handle commands like `models` or `keys`. + elif parts[0] in LLM_CLI_COMMANDS: + capture_output = False + use_context = False + # If the parts doesn't have any known LLM_CLI_COMMANDS then the user is + # invoking a question. eg: \llm -m ollama "Most visited urls?" + elif not set(parts).intersection(LLM_CLI_COMMANDS): + capture_output = True + use_context = True + # User invoked llm with a question without `prompt` subcommand. Capture the + # output and check for fenced SQL. eg: \llm "Most visited urls?" + else: + capture_output = True + use_context = True + + if not use_context: + args = parts + if capture_output: + _, result = run_external_cmd("llm", *args, capture_output=capture_output) + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + output = [(None, None, None, result)] + raise FinishIteration(output) + + return result if verbose else "", sql + else: + run_external_cmd("llm", *args, restart_cli=restart) + raise FinishIteration(None) + + try: + ensure_litecli_template() + context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose) + if not verbose: + context = "" + return context, sql + except Exception as e: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(e) + + +@export +def is_llm_command(command) -> bool: + """ + Is this an llm/ai command? + """ + cmd, _, _ = parse_special_command(command) + return cmd in ("\\llm", "\\ai", ".llm", ".ai") + + +@export +def sql_using_llm(cur, question=None, verbose=False) -> Tuple[str, Optional[str]]: + schema_query = """ + SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL + ORDER BY tbl_name, type DESC, name + """ + tables_query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + sample_row_query = "SELECT * FROM {table} LIMIT 1" + log.debug(schema_query) + cur.execute(schema_query) + db_schema = "\n".join([x for (x,) in cur.fetchall()]) + + log.debug(tables_query) + cur.execute(tables_query) + sample_data = {} + for (table,) in cur.fetchall(): + sample_row = sample_row_query.format(table=table) + cur.execute(sample_row) + cols = [x[0] for x in cur.description] + row = cur.fetchone() + if row is None: # Skip empty tables + continue + sample_data[table] = list(zip(cols, row)) + + args = [ + "--template", + "litecli", + "--param", + "db_schema", + db_schema, + "--param", + "sample_data", + sample_data, + "--param", + "question", + question, + " ", # Dummy argument to prevent llm from waiting on stdin + ] + _, result = run_external_cmd("llm", *args, capture_output=True) + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + sql = "" + + return result, sql diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py index 49abdf0..9544811 100644 --- a/litecli/packages/special/main.py +++ b/litecli/packages/special/main.py @@ -152,5 +152,13 @@ def quit(*_args): arg_type=NO_QUERY, case_sensitive=True, ) +@special_command( + "\\llm", + "\\ai", + "Use LLM to construct a SQL query.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=(".ai", ".llm"), +) def stub(): raise NotImplementedError diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py index b154432..d6d21c7 100644 --- a/litecli/sqlcompleter.py +++ b/litecli/sqlcompleter.py @@ -9,6 +9,7 @@ from prompt_toolkit.completion import Completer, Completion from .packages.completion_engine import suggest_type from .packages.parseutils import last_word from .packages.special.iocommands import favoritequeries +from .packages.special import llm from .packages.filepaths import parse_path, complete_path, suggest_path _logger = logging.getLogger(__name__) @@ -529,6 +530,19 @@ class SQLCompleter(Completer): elif suggestion["type"] == "file_name": file_names = self.find_files(word_before_cursor) completions.extend(file_names) + elif suggestion["type"] == "llm": + if not word_before_cursor: + tokens = document.text.split()[1:] + else: + tokens = document.text.split()[1:-1] + possible_entries = llm.get_completions(tokens) + subcommands = self.find_matches( + word_before_cursor, + possible_entries, + start_only=False, + fuzzy=True, + ) + completions.extend(subcommands) return completions diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index 4277512..4f88764 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -32,7 +32,7 @@ class SQLExecute(object): table_columns_query = """ SELECT m.name as tableName, p.name as columnName FROM sqlite_master m - LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name + JOIN pragma_table_info((m.name)) p WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' ORDER BY tableName, columnName """ diff --git a/pyproject.toml b/pyproject.toml index 5caeb84..ba9a9a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "litecli" dynamic = ["version"] description = "CLI for SQLite Databases with auto-completion and syntax highlighting." readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.9" license = { text = "BSD" } authors = [{ name = "dbcli", email = "litecli-users@googlegroups.com" }] urls = { "homepage" = "https://github.com/dbcli/litecli" } @@ -14,6 +14,8 @@ dependencies = [ "prompt-toolkit>=3.0.3,<4.0.0", "pygments>=1.6", "sqlparse>=0.4.4", + "setuptools", # Required by llm commands to install models + "pip", ] [build-system] @@ -30,6 +32,8 @@ build-backend = "setuptools.build_meta" litecli = "litecli.main:cli" [project.optional-dependencies] +ai = ["llm"] + dev = [ "behave>=1.2.6", "coverage>=7.2.7", @@ -38,6 +42,7 @@ dev = [ "pytest-cov>=4.1.0", "tox>=4.8.0", "pdbpp>=0.10.3", + "llm>=0.19.0", ] [tool.setuptools.packages.find] diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py index b04e184..86053d1 100644 --- a/tests/test_completion_engine.py +++ b/tests/test_completion_engine.py @@ -357,6 +357,18 @@ def test_sub_select_multiple_col_name_completion(): ) +def test_suggested_multiple_column_names(): + suggestions = suggest_type("SELECT id, from users", "SELECT id, ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "users", None)]}, + {"type": "function", "schema": []}, + {"type": "alias", "aliases": ["users"]}, + {"type": "keyword"}, + ] + ) + + def test_sub_select_dot_col_name_completion(): suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.") assert sorted_dicts(suggestions) == sorted_dicts( diff --git a/tests/test_llm_special.py b/tests/test_llm_special.py new file mode 100644 index 0000000..2f3b010 --- /dev/null +++ b/tests/test_llm_special.py @@ -0,0 +1,162 @@ +import pytest +from unittest.mock import patch +from litecli.packages.special.llm import handle_llm, FinishIteration, USAGE + + +@patch("litecli.packages.special.llm.initialize_llm") +@patch("litecli.packages.special.llm.llm", new=None) +def test_llm_command_without_install(mock_initialize_llm, executor): + """ + Test that handle_llm initializes llm when it is None and raises FinishIteration. + """ + test_text = r"\llm" + cur_mock = executor + + with pytest.raises(FinishIteration) as exc_info: + handle_llm(test_text, cur_mock) + + mock_initialize_llm.assert_called_once() + assert exc_info.value.args[0] is None + + +@patch("litecli.packages.special.llm.llm") +def test_llm_command_without_args(mock_llm, executor): + r""" + Invoking \llm without any arguments should print the usage and raise + FinishIteration. + """ + assert mock_llm is not None + test_text = r"\llm" + cur_mock = executor + + with pytest.raises(FinishIteration) as exc_info: + handle_llm(test_text, cur_mock) + + assert exc_info.value.args[0] == [(None, None, None, USAGE)] + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.run_external_cmd") +def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor): + # Suppose the LLM returns some text without fenced SQL + mock_run_cmd.return_value = (0, "Hello, I have no SQL for you today.") + + test_text = r"\llm -c 'Something interesting?'" + + with pytest.raises(FinishIteration) as exc_info: + handle_llm(test_text, executor) + + # We expect no code fence => FinishIteration with that output + assert exc_info.value.args[0] == [(None, None, None, "Hello, I have no SQL for you today.")] + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.run_external_cmd") +def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor): + # The luscious SQL is inside triple backticks + return_text = "Here is your query:\n" "```sql\nSELECT * FROM table;\n```" + mock_run_cmd.return_value = (0, return_text) + + test_text = r"\llm -c 'Rewrite the SQL without CTE'" + + result, sql = handle_llm(test_text, executor) + + # We expect the function to return (result, sql), but result might be "" if verbose is not set + # By default, `verbose` is false unless text has something like \llm --verbose? + # The function code: return result if verbose else "", sql + # Our test_text doesn't set verbose => we expect "" for the returned context. + assert result == "" + assert sql == "SELECT * FROM table;" + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.run_external_cmd") +def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor): + """ + If the parts[0] is in LLM_CLI_COMMANDS, we do NOT capture output, we just call run_external_cmd + and then raise FinishIteration. + """ + # Let's assume 'models' is in LLM_CLI_COMMANDS + test_text = r"\llm models" + + with pytest.raises(FinishIteration) as exc_info: + handle_llm(test_text, executor) + + # We check that run_external_cmd was called with these arguments: + mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False) + # And the function should raise FinishIteration(None) + assert exc_info.value.args[0] is None + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.run_external_cmd") +def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor): + """ + If 'install' or 'uninstall' is in the parts, we do not capture output but restart the CLI. + """ + test_text = r"\llm install openai" + + with pytest.raises(FinishIteration) as exc_info: + handle_llm(test_text, executor) + + # We expect a restart + mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True) + assert exc_info.value.args[0] is None + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.ensure_litecli_template") +@patch("litecli.packages.special.llm.sql_using_llm") +def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + r""" + \llm prompt "some question" + Should use context, capture output, and call sql_using_llm. + """ + # Mock out the return from sql_using_llm + mock_sql_using_llm.return_value = ("context from LLM", "SELECT 1;") + + test_text = r"\llm prompt 'Magic happening here?'" + context, sql = handle_llm(test_text, executor) + + # ensure_litecli_template should be called + mock_ensure_template.assert_called_once() + # sql_using_llm should be called with question=arg, which is "prompt 'Magic happening here?'" + # Actually, the question is the entire "prompt 'Magic happening here?'" minus the \llm + # But in the function we do parse shlex.split. + mock_sql_using_llm.assert_called() + assert context == "" + assert sql == "SELECT 1;" + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.ensure_litecli_template") +@patch("litecli.packages.special.llm.sql_using_llm") +def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + """ + If arg doesn't contain any known command, it's treated as a question => capture output + context. + """ + mock_sql_using_llm.return_value = ("You have context!", "SELECT 2;") + + test_text = r"\llm 'Top 10 downloads by size.'" + context, sql = handle_llm(test_text, executor) + + mock_ensure_template.assert_called_once() + mock_sql_using_llm.assert_called() + assert context == "" + assert sql == "SELECT 2;" + + +@patch("litecli.packages.special.llm.llm") +@patch("litecli.packages.special.llm.ensure_litecli_template") +@patch("litecli.packages.special.llm.sql_using_llm") +def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + r""" + Invoking \llm+ returns the context and the SQL query. + """ + mock_sql_using_llm.return_value = ("Verbose context, oh yeah!", "SELECT 42;") + + test_text = r"\llm+ 'Top 10 downloads by size.'" + context, sql = handle_llm(test_text, executor) + + assert context == "Verbose context, oh yeah!" + assert sql == "SELECT 42;" diff --git a/tests/test_main.py b/tests/test_main.py index a8fa4ae..1c24da4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,6 +2,8 @@ import os from collections import namedtuple from textwrap import dedent import shutil +from datetime import datetime +from unittest.mock import patch import click from click.testing import CliRunner @@ -267,3 +269,64 @@ def test_startup_commands(executor): ] # implement tests on executions of the startupcommands + + +@patch("litecli.main.datetime") # Adjust if your module path is different +def test_get_prompt(mock_datetime): + # We'll freeze time at 2025-01-20 13:37:42 for comedic effect. + # Because "leet" times call for 13:37! + frozen_time = datetime(2025, 1, 20, 13, 37, 42) + mock_datetime.now.return_value = frozen_time + # Ensure `datetime` class is still accessible for strftime usage + mock_datetime.datetime = datetime + + # Instantiate and connect + lc = LiteCli() + lc.connect("/tmp/litecli_test.db") + + # 1. Test \d => full path to the DB + assert lc.get_prompt(r"\d") == "/tmp/litecli_test.db" + + # 2. Test \f => basename of the DB + # (because "f" stands for "filename", presumably!) + assert lc.get_prompt(r"\f") == "litecli_test.db" + + # 3. Test \_ => single space + assert lc.get_prompt(r"Hello\_World") == "Hello World" + + # 4. Test \n => newline + # Just to be sure we're only inserting a newline, + # we can check length or assert the presence of "\n". + expected = f"Line1{os.linesep}Line2" + assert lc.get_prompt(r"Line1\nLine2") == expected + + # 5. Test date/time placeholders (with frozen time): + # \D => e.g. 'Mon Jan 20 13:37:42 2025' + expected_date_str = frozen_time.strftime("%a %b %d %H:%M:%S %Y") + assert lc.get_prompt(r"\D") == expected_date_str + + # 6. Test \m => minutes + assert lc.get_prompt(r"\m") == "37" + + # 7. Test \P => AM/PM + # 13:37 is PM + assert lc.get_prompt(r"\P") == "PM" + + # 8. Test \R => 24-hour format hour + assert lc.get_prompt(r"\R") == "13" + + # 9. Test \r => 12-hour format hour + # 13:37 is 01 in 12-hour format + assert lc.get_prompt(r"\r") == "01" + + # 10. Test \s => seconds + assert lc.get_prompt(r"\s") == "42" + + # 11. Test when dbname is None => (none) + lc.connect(None) # Simulate no DB connection + assert lc.get_prompt(r"\d") == "(none)" + assert lc.get_prompt(r"\f") == "(none)" + + # 12. Windows path + lc.connect("C:\\Users\\litecli\\litecli_test.db") + assert lc.get_prompt(r"\d") == "C:\\Users\\litecli\\litecli_test.db" diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py index 2bdc84c..b1be9ac 100644 --- a/tests/test_sqlexecute.py +++ b/tests/test_sqlexecute.py @@ -38,13 +38,15 @@ def test_binary(executor): ## Failing in Travis for some unknown reason. -# @dbtest -# def test_table_and_columns_query(executor): -# run(executor, "create table a(x text, y text)") -# run(executor, "create table b(z text)") +@dbtest +def test_table_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + run(executor, "create table t(t text)") -# assert set(executor.tables()) == set([("a",), ("b",)]) -# assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")]) + assert set(executor.tables()) == set([("a",), ("b",), ("t",)]) + assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z"), ("t", "t")]) + assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z"), ("t", "t")]) @dbtest