Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supports fetching models during configuration initialization #1161

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: supports fetching models during configuration initialization
  • Loading branch information
sigoden committed Feb 9, 2025
commit d316127a99606b011edb4a0fc1ddba4505479969
123 changes: 92 additions & 31 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use crate::{
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use indexmap::IndexMap;
use inquire::{required, Select, Text};
use inquire::{
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
Expand All @@ -23,6 +25,7 @@ lazy_static::lazy_static! {
pub static ref ALL_PROVIDER_MODELS: Vec<ProviderModels> = {
Config::loal_models_override().ok().unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
};
static ref EMBEDDING_MODEL_RE: Regex = Regex::new(r"(^(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap();
static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(?<!\\)/").unwrap();
}

Expand Down Expand Up @@ -331,16 +334,29 @@ pub struct RerankResult {

pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>);

pub fn create_config(prompts: &[PromptAction], client: &str) -> Result<(String, Value)> {
pub async fn create_config(
prompts: &[PromptAction<'static>],
client: &str,
) -> Result<(String, Value)> {
let mut config = json!({
"type": client,
});
let model = set_client_config(prompts, &mut config, client)?;
for (key, desc, help_message) in prompts {
let env_name = format!("{client}_{key}").to_ascii_uppercase();
let required = std::env::var(&env_name).is_err();
let value = prompt_input_string(desc, required, *help_message)?;
if !value.is_empty() {
config[key] = value.into();
}
}
let model = set_client_models_config(&mut config, client).await?;
let clients = json!(vec![config]);
Ok((model, clients))
}

pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(String, Value)>> {
pub async fn create_openai_compatible_client_config(
client: &str,
) -> Result<Option<(String, Value)>> {
let api_base = super::OPENAI_COMPATIBLE_PROVIDERS
.into_iter()
.find(|(name, _)| client == *name)
Expand Down Expand Up @@ -371,7 +387,7 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result<Option<(St
config["api_key"] = api_key.into();
}

let model = set_client_models_config(&mut config, &name)?;
let model = set_client_models_config(&mut config, &name).await?;
let clients = json!(vec![config]);
Ok(Some((model, clients)))
}
Expand Down Expand Up @@ -512,23 +528,7 @@ pub fn json_str_from_map<'a>(
map.get(field_name).and_then(|v| v.as_str())
}

fn set_client_config(
list: &[PromptAction],
client_config: &mut Value,
client: &str,
) -> Result<String> {
for (key, desc, help_message) in list {
let env_name = format!("{client}_{key}").to_ascii_uppercase();
let required = std::env::var(&env_name).is_err();
let value = prompt_input_string(desc, required, *help_message)?;
if !value.is_empty() {
client_config[key] = value.into();
}
}
set_client_models_config(client_config, client)
}

fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
let models: Vec<String> = provider
.models
Expand All @@ -539,13 +539,46 @@ fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<S
let model_name = select_model(models)?;
return Ok(format!("{client}:{model_name}"));
}

let model_names = prompt_input_string(
"LLM models",
true,
Some("Separated by commas, e.g. llama3.3,qwen2.5"),
)?;
let model_names = model_names
let mut model_names = vec![];
if let (Some(true), Some(api_base), api_key) = (
client_config["type"]
.as_str()
.map(|v| v == OpenAICompatibleClient::NAME),
client_config["api_base"].as_str(),
client_config["api_key"]
.as_str()
.map(|v| v.to_string())
.or_else(|| {
let env_name = format!("{client}_api_key").to_ascii_uppercase();
std::env::var(&env_name).ok()
}),
) {
if let Ok(fetched_models) = abortable_run_with_spinner(
fetch_models(api_base, api_key.as_deref()),
"Fetching models",
create_abort_signal(),
)
.await
{
model_names = MultiSelect::new("LLM models (required):", fetched_models)
.with_validator(|list: &[ListOption<&String>]| {
if list.is_empty() {
Ok(Validation::Invalid(
"At least one item must be selected".into(),
))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;
}
}
if model_names.is_empty() {
model_names = prompt_input_string(
"LLM models",
true,
Some("Separated by commas, e.g. llama3.3,qwen2.5"),
)?
.split(',')
.filter_map(|v| {
let v = v.trim();
Expand All @@ -556,10 +589,38 @@ fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<S
}
})
.collect::<Vec<_>>();
}
if model_names.is_empty() {
bail!("No models");
}
let models: Vec<Value> = model_names.iter().map(|v| json!({"name": v})).collect();
let models: Vec<Value> = model_names
.iter()
.map(|v| {
let l = v.to_lowercase();
if l.contains("rank") {
json!({
"name": v,
"type": "reranker",
})
} else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
json!({
"name": v,
"type": "embedding",
"default_chunk_size": 1000,
"max_batch_size": 16
})
} else if v.contains("vision") {
json!({
"name": v,
"supports_vision": true
})
} else {
json!({
"name": v,
})
}
})
.collect();
client_config["models"] = models.into();
let model_name = select_model(model_names)?;
Ok(format!("{client}:{model_name}"))
Expand All @@ -572,7 +633,7 @@ fn select_model(model_names: Vec<String>) -> Result<String> {
let model = if model_names.len() == 1 {
model_names[0].clone()
} else {
Select::new("Select model:", model_names).prompt()?
Select::new("Default Model (required):", model_names).prompt()?
};
Ok(model)
}
Expand Down
6 changes: 3 additions & 3 deletions src/client/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ macro_rules! register_client {
client_types
}

pub fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
$(
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
return create_config(&$client::PROMPTS, $client::NAME)
return create_config(&$client::PROMPTS, $client::NAME).await
}
)+
if let Some(ret) = create_openai_compatible_client_config(client)? {
if let Some(ret) = create_openai_compatible_client_config(client).await? {
return Ok(ret);
}
anyhow::bail!("Unknown client '{}'", client)
Expand Down
8 changes: 4 additions & 4 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ impl Default for Config {
pub type GlobalConfig = Arc<RwLock<Config>>;

impl Config {
pub fn init(working_mode: WorkingMode, info_flag: bool) -> Result<Self> {
pub async fn init(working_mode: WorkingMode, info_flag: bool) -> Result<Self> {
let config_path = Self::config_file();
let mut config = if !config_path.exists() {
match env::var(get_env_name("provider"))
Expand All @@ -252,7 +252,7 @@ impl Config {
Some(v) => Self::load_dynamic(&v)?,
None => {
if *IS_STDOUT_TERMINAL {
create_config_file(&config_path)?;
create_config_file(&config_path).await?;
}
Self::load_from_file(&config_path)?
}
Expand Down Expand Up @@ -2604,7 +2604,7 @@ impl AssertState {
}
}

fn create_config_file(config_path: &Path) -> Result<()> {
async fn create_config_file(config_path: &Path) -> Result<()> {
let ans = Confirm::new("No config file, create a new one?")
.with_default(true)
.prompt()?;
Expand All @@ -2615,7 +2615,7 @@ fn create_config_file(config_path: &Path) -> Result<()> {
let client = Select::new("API Provider (required):", list_client_types()).prompt()?;

let mut config = serde_json::json!({});
let (model, clients_config) = create_client_config(client)?;
let (model, clients_config) = create_client_config(client).await?;
config["model"] = model.into();
config[CLIENTS_FIELD] = clients_config;

Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async fn main() -> Result<()> {
|| cli.list_macros
|| cli.list_sessions;
setup_logger(working_mode.is_serve())?;
let config = Arc::new(RwLock::new(Config::init(working_mode, info_flag)?));
let config = Arc::new(RwLock::new(Config::init(working_mode, info_flag).await?));
if let Err(err) = run(config, cli, text).await {
render_error(err);
std::process::exit(1);
Expand Down
27 changes: 26 additions & 1 deletion src/utils/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const USER_AGENT: &str = "curl/8.6.0";

lazy_static::lazy_static! {
static ref CLIENT: Result<reqwest::Client> = {
let builder = reqwest::ClientBuilder::new().timeout(Duration::from_secs(30));
let builder = reqwest::ClientBuilder::new().timeout(Duration::from_secs(16));
let client = builder.build()?;
Ok(client)
};
Expand Down Expand Up @@ -158,6 +158,31 @@ pub async fn fetch_with_loaders(
Ok(result)
}

/// Fetches the available model ids from an OpenAI-compatible `/models` endpoint.
///
/// `api_base` may carry a trailing slash (it is trimmed before the path is
/// appended); `api_key`, when present, is sent as a bearer token. Returns the
/// `id` of every entry under the response's `data` array, or an error when the
/// request fails or the list comes back empty.
pub async fn fetch_models(api_base: &str, api_key: Option<&str>) -> Result<Vec<String>> {
    // The shared HTTP client is built lazily and may itself have failed to build.
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let url = format!("{}/models", api_base.trim_end_matches('/'));
    let mut request_builder = client.get(url);
    if let Some(key) = api_key {
        request_builder = request_builder.bearer_auth(key);
    }
    let body: Value = request_builder.send().await?.json().await?;
    // Expected response shape: { "data": [ { "id": "<model>" }, ... ] }.
    // Entries without a string "id" are silently skipped.
    let mut names: Vec<String> = Vec::new();
    if let Some(entries) = body.get("data").and_then(|v| v.as_array()) {
        for entry in entries {
            if let Some(id) = entry.get("id").and_then(|v| v.as_str()) {
                names.push(id.to_string());
            }
        }
    }
    if names.is_empty() {
        bail!("No models")
    }
    Ok(names)
}

#[derive(Debug, Clone, Default)]
pub struct CrawlOptions {
extract: Option<String>,
Expand Down