1
0
Fork 0
litecli/litecli/sqlexecute.py
Daniel Baumann 0912fc1528
Adding upstream version 1.5.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-09 17:37:21 +01:00

212 lines
7.1 KiB
Python

import logging
import sqlite3
import uuid
from contextlib import closing
from sqlite3 import OperationalError
import sqlparse
import os.path
from .packages import special
# Logger for this module, named after the module itself.
_logger = logging.getLogger(__name__)
class SQLExecute(object):
    """Wrapper around a single :mod:`sqlite3` connection.

    ``self.conn`` is either None (no database given yet) or an open
    connection created with ``isolation_level=None`` (autocommit).
    :meth:`run` executes user input; the catalog helpers (:meth:`tables`,
    :meth:`table_columns`, :meth:`databases`, ...) yield metadata from the
    current connection and yield nothing when there is no connection.
    """

    databases_query = """
        PRAGMA database_list
    """

    tables_query = """
        SELECT name
        FROM sqlite_master
        WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
        ORDER BY 1
    """

    table_columns_query = """
        SELECT m.name as tableName, p.name as columnName
        FROM sqlite_master m
        LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name
        WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%'
        ORDER BY tableName, columnName
    """

    # NOTE(review): this is a MySQL-style INFORMATION_SCHEMA query; SQLite has
    # no INFORMATION_SCHEMA, so executing it fails on a plain SQLite database.
    # functions() logs that failure and yields nothing.
    functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
    WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''

    # No SHOW-style query is defined for SQLite. While this is None,
    # show_candidates() yields nothing; a subclass may assign a query.
    show_candidates_query = None

    def __init__(self, database):
        """Remember *database* and connect to it if it is non-empty.

        With a falsy *database* the instance stays unconnected
        (``self.conn is None``); :meth:`connect` can be called later.
        """
        self.dbname = database
        self._server_type = None
        self.connection_id = None
        self.conn = None
        if not database:
            _logger.debug("Database is not specified. Skip connection.")
            return
        self.connect()

    def connect(self, database=None):
        """Open a connection to *database* (default: ``self.dbname``).

        ``~`` in the path is expanded. Any previous connection is closed only
        after the new one opened, and ``self.dbname`` is updated only on
        success.

        Raises:
            Exception: if the directory that would contain the database file
                does not exist.
            sqlite3.Error: if :func:`sqlite3.connect` itself fails.
        """
        db = database or self.dbname
        # Log the name actually used, not the (possibly None) argument.
        _logger.debug("Connection DB Params: \n" "\tdatabase: %r", db)
        db_name = os.path.expanduser(db)
        db_dir_name = os.path.dirname(os.path.abspath(db_name))
        if not os.path.exists(db_dir_name):
            raise Exception("Path does not exist: {}".format(db_dir_name))
        # isolation_level=None: sqlite3 autocommit mode.
        conn = sqlite3.connect(database=db_name, isolation_level=None)
        if self.conn:
            self.conn.close()
        self.conn = conn
        # Update them after the connection is made to ensure that it was a
        # successful connection.
        self.dbname = db
        self.reset_connection_id()

    def run(self, statement):
        """Execute *statement* and yield one result tuple per component.

        Each yielded item is ``(title, rows, headers, status)``. A statement
        that is empty after stripping yields a single
        ``(None, None, None, None)`` and nothing else.

        Raises:
            OperationalError: if there is no connection and the statement is
                not one of the commands that may run without one.
        """
        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None)
            return

        # Split the sql into separate queries and run each one.
        # Unless it's saving a favorite query, in which case we
        # want to save them all together.
        if statement.startswith("\\fs"):
            components = [statement]
        else:
            components = sqlparse.split(statement)

        for sql in components:
            # Remove trailing semicolons.
            sql = sql.rstrip(";")

            # \G is treated specially since we have to set the expanded output.
            if sql.endswith("\\G"):
                special.set_expanded_output(True)
                sql = sql[:-2].strip()

            if not self.conn and not (
                sql.startswith(
                    (".open", "\\u", "\\?", "\\q", "help", "exit", "quit")
                )
                or sql.lower().startswith("use")
            ):
                _logger.debug(
                    "Not connected to database. Will not run statement: %s.", sql
                )
                raise OperationalError("Not connected to database.")

            cur = self.conn.cursor() if self.conn else None
            try:  # Special command
                _logger.debug("Trying a dbspecial command. sql: %r", sql)
                for result in special.execute(cur, sql):
                    yield result
            except special.CommandNotFound:  # Regular SQL
                if cur is None:
                    # Passed the no-connection prefix check but is not a
                    # special command: report it like the check above rather
                    # than failing with AttributeError on None.
                    raise OperationalError("Not connected to database.")
                _logger.debug("Regular sql statement. sql: %r", sql)
                cur.execute(sql)
                yield self.get_result(cur)

    def get_result(self, cursor):
        """Summarise the current result of *cursor*.

        Returns ``(title, rows, headers, status)``; ``title`` is always None.
        When the statement produced a result set, ``rows`` is the list of all
        fetched rows and ``headers`` the column names. Otherwise both are None
        and ``status`` reports the affected-row count.
        """
        title = headers = None
        # cursor.description is not None for queries that return result sets,
        # e.g. SELECT.
        if cursor.description is not None:
            headers = [x[0] for x in cursor.description]
            status = "{0} row{1} in set"
            rows = list(cursor)
            rowcount = len(rows)
        else:
            _logger.debug("No rows in result.")
            status = "Query OK, {0} row{1} affected"
            # sqlite3 reports -1 when a row count does not apply.
            rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount
            rows = None
        status = status.format(rowcount, "" if rowcount == 1 else "s")
        return (title, rows, headers, status)

    def tables(self):
        """Yield ``(name,)`` rows for every user table and view."""
        if not self.conn:
            return
        with closing(self.conn.cursor()) as cur:
            _logger.debug("Tables Query. sql: %r", self.tables_query)
            cur.execute(self.tables_query)
            for row in cur:
                yield row

    def table_columns(self):
        """Yield ``(table_name, column_name)`` rows for user tables and views."""
        if not self.conn:
            return
        with closing(self.conn.cursor()) as cur:
            _logger.debug("Columns Query. sql: %r", self.table_columns_query)
            cur.execute(self.table_columns_query)
            for row in cur:
                yield row

    def databases(self):
        """Yield the name of each database attached to the connection."""
        if not self.conn:
            return
        with closing(self.conn.cursor()) as cur:
            _logger.debug("Databases Query. sql: %r", self.databases_query)
            for row in cur.execute(self.databases_query):
                # PRAGMA database_list rows are (seq, name, file).
                yield row[1]

    def functions(self):
        """Yield rows produced by :attr:`functions_query` for ``self.dbname``.

        If the query cannot be executed (it targets INFORMATION_SCHEMA,
        which SQLite lacks), the error is logged and nothing is yielded.
        """
        if not self.conn:
            return
        with closing(self.conn.cursor()) as cur:
            _logger.debug("Functions Query. sql: %r", self.functions_query)
            try:
                cur.execute(self.functions_query % self.dbname)
            except sqlite3.DatabaseError as e:
                _logger.debug("No function completions due to %r", e)
                return
            for row in cur:
                yield row

    def show_candidates(self):
        """Yield ``(candidate,)`` tuples from :attr:`show_candidates_query`.

        Yields nothing when no query is configured or there is no connection;
        yields a single ``""`` if the query raises a database error.
        """
        if not self.conn or self.show_candidates_query is None:
            _logger.debug("No show_candidates_query available; skipping.")
            return
        with closing(self.conn.cursor()) as cur:
            _logger.debug("Show Query. sql: %r", self.show_candidates_query)
            try:
                cur.execute(self.show_candidates_query)
            except sqlite3.DatabaseError as e:
                _logger.error("No show completions due to %r", e)
                yield ""
            else:
                for row in cur:
                    # Keep everything after the first whitespace-separated word.
                    yield (row[0].split(None, 1)[-1],)

    def server_type(self):
        """Return the fixed ``("sqlite3", "3")`` server descriptor."""
        self._server_type = ("sqlite3", "3")
        return self._server_type

    def get_connection_id(self):
        """Return the connection id, generating one first if unset."""
        if not self.connection_id:
            self.reset_connection_id()
        return self.connection_id

    def reset_connection_id(self):
        """Assign a fresh random id to the current connection.

        SQLite has no server-side connection id, so a random UUID4 is used
        as a locally unique stand-in.
        """
        _logger.debug("Get current connection id")
        self.connection_id = uuid.uuid4()
        _logger.debug("Current connection id: %s", self.connection_id)