Skip to content

Commit 4fee270

Browse files
committed
fix: resolve ty optional import errors
1 parent 6b11a4c commit 4fee270

File tree

7 files changed

+55
-33
lines changed

7 files changed

+55
-33
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88

99
- Avoid completion refresh crashes when no database is connected.
1010

11+
### Internal
12+
13+
- Clean up ty type-checking for optional sqlean/llm imports.
14+
1115
## 1.18.0
1216

1317
### Internal

litecli/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# type: ignore
21
from __future__ import annotations
32

43
import importlib.metadata

litecli/main.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@
1111
from collections import namedtuple
1212
from datetime import datetime
1313
from io import open
14-
15-
try:
16-
from sqlean import OperationalError, sqlite_version
17-
except ImportError:
18-
from sqlite3 import OperationalError, sqlite_version
1914
from time import time
2015
from typing import Any, Generator, Iterable, cast
2116

@@ -51,6 +46,15 @@
5146
from .sqlcompleter import SQLCompleter
5247
from .sqlexecute import SQLExecute
5348

49+
try:
50+
import sqlean as _sqlite3
51+
except ImportError:
52+
import sqlite3 as _sqlite3
53+
54+
_sqlite3 = cast(Any, _sqlite3)
55+
OperationalError = _sqlite3.OperationalError
56+
sqlite_version = _sqlite3.sqlite_version
57+
5458
# Query tuples are used for maintaining history
5559
Query = namedtuple("Query", ["query", "successful", "mutating"])
5660

litecli/packages/special/llm.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import contextlib
4+
import importlib
45
import io
56
import logging
67
import os
@@ -15,32 +16,40 @@
1516
import click
1617

1718
try:
18-
import llm
19-
20-
LLM_IMPORTED = True
19+
import llm as llm_module
2120
except ImportError:
22-
llm = None
23-
LLM_IMPORTED = False
21+
llm_module = None
2422

2523
try:
26-
from llm.cli import cli
27-
28-
LLM_CLI_IMPORTED = True
24+
llm_cli_module = importlib.import_module("llm.cli")
2925
except ImportError:
30-
cli = None
31-
LLM_CLI_IMPORTED = False
26+
llm_cli_module = None
3227

3328
from . import export
3429
from .main import Verbosity, parse_special_command
3530
from .types import DBCursor
3631

32+
LLM_IMPORTED = llm_module is not None
33+
34+
cli: click.Command | None
35+
if llm_cli_module is not None:
36+
llm_cli = getattr(llm_cli_module, "cli", None)
37+
cli = llm_cli if isinstance(llm_cli, click.Command) else None
38+
else:
39+
cli = None
40+
41+
LLM_CLI_IMPORTED = cli is not None
42+
3743
log = logging.getLogger(__name__)
3844

3945
LLM_TEMPLATE_NAME = "litecli-llm-template"
40-
LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) if LLM_CLI_IMPORTED else []
46+
LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) if isinstance(cli, click.Group) else []
4147
# Mapping of model_id to None used for completion tree leaves.
42-
# the file name is llm.py and module name is llm, hence ty is complaining that get_models is missing.
43-
MODELS: dict[str, None] = {x.model_id: None for x in llm.get_models()} if LLM_IMPORTED else {} # type: ignore[attr-defined]
48+
if llm_module is not None:
49+
get_models = getattr(llm_module, "get_models", None)
50+
MODELS: dict[str, None] = {x.model_id: None for x in get_models()} if callable(get_models) else {}
51+
else:
52+
MODELS = {}
4453

4554

4655
def run_external_cmd(
@@ -124,7 +133,7 @@ def build_command_tree(cmd: click.Command) -> dict[str, Any] | None:
124133

125134

126135
# Generate the tree
127-
COMMAND_TREE: dict[str, Any] | None = build_command_tree(cli) if LLM_CLI_IMPORTED else {}
136+
COMMAND_TREE: dict[str, Any] | None = build_command_tree(cli) if cli is not None else {}
128137

129138

130139
def get_completions(tokens: list[str], tree: dict[str, Any] | None = COMMAND_TREE) -> list[str]:

litecli/sqlexecute.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
from __future__ import annotations
22

33
import logging
4+
import os.path
45
from contextlib import closing
5-
from typing import Any, Generator, Iterable
6+
from typing import Any, Generator, Iterable, cast
7+
from urllib.parse import urlparse
8+
9+
import sqlparse
610

711
try:
8-
import sqlean as sqlite3
9-
from sqlean import OperationalError
12+
import sqlean as _sqlite3
1013

11-
sqlite3.extensions.enable_all()
14+
_sqlite3.extensions.enable_all()
1215
except ImportError:
13-
import sqlite3
14-
from sqlite3 import OperationalError
15-
import os.path
16-
from urllib.parse import urlparse
17-
18-
import sqlparse
16+
import sqlite3 as _sqlite3
1917

2018
from litecli.packages import special
2119
from litecli.packages.special.utils import check_if_sqlitedotcommand
2220

21+
sqlite3 = cast(Any, _sqlite3)
22+
OperationalError = sqlite3.OperationalError
23+
2324
_logger = logging.getLogger(__name__)
2425

2526
# FIELD_TYPES = decoders.copy()

tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def stub_terminal_size():
243243
shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment]
244244
lc = LiteCli()
245245
assert isinstance(lc.get_reserved_space(), int)
246-
shutil.get_terminal_size = old_func # type: ignore[assignment]
246+
shutil.get_terminal_size = old_func
247247

248248

249249
@dbtest

tests/test_sqlexecute.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# coding=UTF-8
22

33
import os
4+
from typing import Any, cast
45

56
import pytest
67

78
from .utils import assert_result_equal, dbtest, is_expanded_output, run, set_expanded_output
89

910
try:
10-
from sqlean import OperationalError, ProgrammingError
11+
import sqlean as _sqlite3
1112
except ImportError:
12-
from sqlite3 import OperationalError, ProgrammingError
13+
import sqlite3 as _sqlite3
14+
15+
_sqlite3 = cast(Any, _sqlite3)
16+
OperationalError = _sqlite3.OperationalError
17+
ProgrammingError = _sqlite3.ProgrammingError
1318

1419

1520
@dbtest

0 commit comments

Comments
 (0)