Skip to content

Commit 976c1d0

Browse files
authored
Merge pull request #20 from nick-galluzzo/refactor/custom-embeddings-cli
refactor: Restructure embedding CLI
2 parents 97cdd28 + c4e8977 commit 976c1d0

21 files changed

+650
-476
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "notebookllama"
3-
version = "0.2.2.post1"
3+
version = "0.2.3"
44
description = "An OSS and LlamaCloud-backed alternative to NotebookLM"
55
readme = "README.md"
66
requires-python = ">=3.13"

tools/cli/config/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .models import EmbeddingConfig
2+
3+
__all__ = ["EmbeddingConfig"]

tools/cli/config/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
4+
5+
@dataclass
6+
class EmbeddingConfig:
7+
provider: str
8+
api_key: Optional[str] = None
9+
model: Optional[str] = None
10+
region: Optional[str] = None
11+
key_id: Optional[str] = None
12+
embedding_config: Optional[object] = None

tools/cli/embedding_app.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
from textual.app import App
3+
4+
from .config import EmbeddingConfig
5+
from .screens import InitialScreen
6+
7+
8+
class EmbeddingSetupApp(App):
9+
"""Main application for embedding configuration setup."""
10+
11+
CSS_PATH = "stylesheets/base.tcss"
12+
13+
def __init__(self):
14+
super().__init__()
15+
self.config = EmbeddingConfig(provider="")
16+
17+
def on_mount(self) -> None:
18+
self.push_screen(InitialScreen())
19+
20+
def handle_completion(self, config: EmbeddingConfig) -> None:
21+
self.exit(config)
22+
23+
def handle_default_setup(self) -> None:
24+
from llama_index.embeddings.openai import OpenAIEmbedding
25+
from llama_cloud import PipelineCreateEmbeddingConfig_OpenaiEmbedding
26+
27+
self.config.provider = "OpenAI"
28+
self.config.api_key = os.getenv("OPENAI_API_KEY")
29+
self.config.model = "text-embedding-3-small"
30+
31+
embed_model = OpenAIEmbedding(
32+
model=self.config.model, api_key=self.config.api_key
33+
)
34+
embedding_config = PipelineCreateEmbeddingConfig_OpenaiEmbedding(
35+
type="OPENAI_EMBEDDING",
36+
component=embed_model,
37+
)
38+
self.config = embedding_config
39+
40+
self.handle_completion(self.config)

tools/cli/screens/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .base import BaseScreen, ConfigurationScreen
2+
from .initial import InitialScreen
3+
from .embedding_provider import ProviderSelectScreen
4+
5+
__all__ = ["BaseScreen", "ConfigurationScreen", "InitialScreen", "ProviderSelectScreen"]

tools/cli/screens/base.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from textual.app import ComposeResult
2+
from textual.containers import Container
3+
from textual.screen import Screen
4+
from textual.widgets import Label, Footer
5+
from textual.binding import Binding
6+
from textual.widgets import Input
7+
8+
9+
class BaseScreen(Screen):
10+
"""Base screen with common functionality for all screens."""
11+
12+
BINDINGS = [
13+
Binding("ctrl+q", "quit", "Exit", key_display="ctrl+q"),
14+
Binding("ctrl+d", "toggle_dark", "Toggle Dark Theme", key_display="ctrl+d"),
15+
]
16+
17+
def action_toggle_dark(self) -> None:
18+
self.app.theme = (
19+
"textual-dark" if self.app.theme == "textual-light" else "textual-light"
20+
)
21+
22+
def action_quit(self) -> None:
23+
self.app.exit()
24+
25+
def compose(self) -> ComposeResult:
26+
yield Container(
27+
Label(self.get_title(), classes="form-title"),
28+
*self.get_form_elements(),
29+
classes="form-container",
30+
)
31+
yield Footer()
32+
33+
def get_title(self) -> str:
34+
return "Base Screen"
35+
36+
def get_form_elements(self) -> list[ComposeResult]:
37+
return []
38+
39+
40+
class ConfigurationScreen(BaseScreen):
41+
"""Base screen provider configuration with submit functionality."""
42+
43+
BINDINGS = BaseScreen.BINDINGS + [
44+
Binding("shift+enter", "submit", "Submit"),
45+
]
46+
47+
def on_input_submitted(self, event: Input.Submitted) -> None:
48+
"""Catches the Enter key press and delegates the work."""
49+
self.process_submission()
50+
51+
def process_submission(self) -> None:
52+
"""
53+
To be implemented by each specific provider screen.
54+
This method contains the unique logic for validating and creating
55+
the embedding configuration.
56+
"""
57+
58+
raise NotImplementedError(
59+
"Each configuration screen must implement 'process_submission'"
60+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from textual import on
2+
from textual.widgets import Select
3+
4+
from .base import BaseScreen
5+
from .embedding_providers import (
6+
OpenAIEmbeddingScreen,
7+
BedrockEmbeddingScreen,
8+
AzureEmbeddingScreen,
9+
GeminiEmbeddingScreen,
10+
CohereEmbeddingScreen,
11+
HuggingFaceEmbeddingScreen,
12+
)
13+
14+
15+
class ProviderSelectScreen(BaseScreen):
16+
"""Screen for selecting embedding provider."""
17+
18+
def get_title(self) -> str:
19+
return "Select an embedding provider"
20+
21+
def get_form_elements(self) -> list:
22+
return [
23+
Select(
24+
options=[
25+
("OpenAI", "OpenAI"),
26+
("Cohere", "Cohere"),
27+
("Bedrock", "Bedrock"),
28+
("HuggingFace", "HuggingFace"),
29+
("Azure", "Azure"),
30+
("Gemini", "Gemini"),
31+
],
32+
prompt="Please select an embedding provider",
33+
classes="form-control",
34+
id="provider_select",
35+
)
36+
]
37+
38+
@on(Select.Changed, "#provider_select")
39+
def handle_selection(self, event: Select.Changed) -> None:
40+
from ..embedding_app import EmbeddingSetupApp
41+
42+
app = self.app
43+
if isinstance(app, EmbeddingSetupApp):
44+
app.config.provider = event.value
45+
self.handle_next()
46+
47+
def handle_next(self) -> None:
48+
from ..embedding_app import EmbeddingSetupApp
49+
50+
app = self.app
51+
if isinstance(app, EmbeddingSetupApp):
52+
provider_screens = {
53+
"OpenAI": OpenAIEmbeddingScreen,
54+
"Bedrock": BedrockEmbeddingScreen,
55+
"Azure": AzureEmbeddingScreen,
56+
"Gemini": GeminiEmbeddingScreen,
57+
"Cohere": CohereEmbeddingScreen,
58+
"HuggingFace": HuggingFaceEmbeddingScreen,
59+
}
60+
screen_class = provider_screens.get(app.config.provider)
61+
if screen_class:
62+
app.push_screen(screen_class())
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from .openai import OpenAIEmbeddingScreen
2+
from .bedrock import BedrockEmbeddingScreen
3+
from .azure import AzureEmbeddingScreen
4+
from .gemini import GeminiEmbeddingScreen
5+
from .cohere import CohereEmbeddingScreen
6+
from .huggingface import HuggingFaceEmbeddingScreen
7+
8+
__all__ = [
9+
"OpenAIEmbeddingScreen",
10+
"BedrockEmbeddingScreen",
11+
"AzureEmbeddingScreen",
12+
"GeminiEmbeddingScreen",
13+
"CohereEmbeddingScreen",
14+
"HuggingFaceEmbeddingScreen",
15+
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from textual.app import ComposeResult
2+
from textual.widgets import Input
3+
4+
from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel
5+
from llama_cloud import PipelineCreateEmbeddingConfig_AzureEmbedding
6+
7+
from ..base import ConfigurationScreen
8+
9+
10+
class AzureEmbeddingScreen(ConfigurationScreen):
11+
"""Configuration screen for Azure embeddings."""
12+
13+
def get_title(self) -> str:
14+
return "Azure Embedding Configuration"
15+
16+
def get_form_elements(self) -> list[ComposeResult]:
17+
return [
18+
Input(
19+
placeholder="API Key",
20+
password=True,
21+
id="api_key",
22+
classes="form-control",
23+
),
24+
Input(placeholder="Endpoint URL", id="endpoint", classes="form-control"),
25+
]
26+
27+
def process_submission(self) -> None:
28+
api_key = self.query_one("#api_key", Input).value
29+
endpoint = self.query_one("#endpoint", Input).value
30+
31+
if not all([api_key, endpoint]):
32+
self.notify("All fields are required.", severity="error")
33+
return
34+
35+
try:
36+
embed_model = AzureAIEmbeddingsModel(credential=api_key, endpoint=endpoint)
37+
embedding_config = PipelineCreateEmbeddingConfig_AzureEmbedding(
38+
type="AZURE_EMBEDDING", component=embed_model
39+
)
40+
41+
self.app.config = embedding_config
42+
self.app.handle_completion(self.app.config)
43+
except Exception as e:
44+
self.notify(f"Error: {e}", severity="error")
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from textual.app import ComposeResult
2+
from textual.widgets import Input, Select
3+
4+
from llama_index.embeddings.bedrock import BedrockEmbedding
5+
from llama_cloud import PipelineCreateEmbeddingConfig_BedrockEmbedding
6+
7+
from ..base import ConfigurationScreen
8+
9+
10+
class BedrockEmbeddingScreen(ConfigurationScreen):
11+
"""Configuration screen for Bedrock embeddings."""
12+
13+
def get_title(self) -> str:
14+
return "Bedrock Embedding Configuration"
15+
16+
def get_form_elements(self) -> list[ComposeResult]:
17+
model_options = []
18+
try:
19+
supported_models = BedrockEmbedding.list_supported_models()
20+
model_options = [
21+
(f"{provider.title()}: {model_id.split('.')[-1]}", model_id)
22+
for provider, models in supported_models.items()
23+
for model_id in models
24+
]
25+
except Exception as e:
26+
self.notify(
27+
f"Could not fetch Bedrock models: {e}", severity="error", timeout=10
28+
)
29+
30+
return [
31+
Select(
32+
options=model_options,
33+
prompt="Select Bedrock Model",
34+
id="model",
35+
classes="form-control",
36+
),
37+
Input(
38+
placeholder="Region (e.g., us-east-1)",
39+
id="region",
40+
classes="form-control",
41+
),
42+
Input(
43+
placeholder="Access Key ID (Optional)",
44+
id="access_key_id",
45+
classes="form-control",
46+
),
47+
Input(
48+
placeholder="Secret Access Key (Optional)",
49+
password=True,
50+
id="secret_access_key",
51+
classes="form-control",
52+
),
53+
]
54+
55+
def process_submission(self) -> None:
56+
model = self.query_one("#model", Select).value
57+
region = self.query_one("#region", Input).value
58+
access_key_id = self.query_one("#access_key_id", Input).value
59+
secret_access_key = self.query_one("#secret_access_key", Input).value
60+
61+
if not all([model, region]):
62+
self.notify("All fields are required.", severity="error")
63+
return
64+
65+
try:
66+
embed_model = BedrockEmbedding(
67+
model_name=model,
68+
region_name=region,
69+
aws_access_key_id=access_key_id,
70+
aws_secret_access_key=secret_access_key,
71+
)
72+
embedding_config = PipelineCreateEmbeddingConfig_BedrockEmbedding(
73+
type="BEDROCK_EMBEDDING", component=embed_model
74+
)
75+
76+
self.app.config = embedding_config
77+
self.app.handle_completion(self.app.config)
78+
except Exception as e:
79+
self.notify(f"Error: {e}", severity="error")

0 commit comments

Comments
 (0)