Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: updated logger and model caching minor bugfix #release #895

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 48 additions & 42 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -168,30 +168,32 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
}, []);

useEffect(() => {
const providerSettings = getProviderSettings();
let parsedApiKeys: Record<string, string> | undefined = {};
if (typeof window !== 'undefined') {
const providerSettings = getProviderSettings();
let parsedApiKeys: Record<string, string> | undefined = {};

try {
parsedApiKeys = getApiKeysFromCookies();
setApiKeys(parsedApiKeys);
} catch (error) {
console.error('Error loading API keys from cookies:', error);
try {
parsedApiKeys = getApiKeysFromCookies();
setApiKeys(parsedApiKeys);
} catch (error) {
console.error('Error loading API keys from cookies:', error);

// Clear invalid cookie data
Cookies.remove('apiKeys');
// Clear invalid cookie data
Cookies.remove('apiKeys');
}
setIsModelLoading('all');
initializeModelList({ apiKeys: parsedApiKeys, providerSettings })
.then((modelList) => {
// console.log('Model List: ', modelList);
setModelList(modelList);
})
.catch((error) => {
console.error('Error initializing model list:', error);
})
.finally(() => {
setIsModelLoading(undefined);
});
}
setIsModelLoading('all');
initializeModelList({ apiKeys: parsedApiKeys, providerSettings })
.then((modelList) => {
console.log('Model List: ', modelList);
setModelList(modelList);
})
.catch((error) => {
console.error('Error initializing model list:', error);
})
.finally(() => {
setIsModelLoading(undefined);
});
}, [providerList]);

const onApiKeysChange = async (providerName: string, apiKey: string) => {
Expand Down Expand Up @@ -401,28 +403,32 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<rect className={classNames(styles.PromptShine)} x="48" y="24" width="70" height="1"></rect>
</svg>
<div>
<div className={isModelSettingsCollapsed ? 'hidden' : ''}>
<ModelSelector
key={provider?.name + ':' + modelList.length}
model={model}
setModel={setModel}
modelList={modelList}
provider={provider}
setProvider={setProvider}
providerList={providerList || (PROVIDER_LIST as ProviderInfo[])}
apiKeys={apiKeys}
modelLoading={isModelLoading}
/>
{(providerList || []).length > 0 && provider && (
<APIKeyManager
provider={provider}
apiKey={apiKeys[provider.name] || ''}
setApiKey={(key) => {
onApiKeysChange(provider.name, key);
}}
/>
<ClientOnly>
{() => (
<div className={isModelSettingsCollapsed ? 'hidden' : ''}>
<ModelSelector
key={provider?.name + ':' + modelList.length}
model={model}
setModel={setModel}
modelList={modelList}
provider={provider}
setProvider={setProvider}
providerList={providerList || (PROVIDER_LIST as ProviderInfo[])}
apiKeys={apiKeys}
modelLoading={isModelLoading}
/>
{(providerList || []).length > 0 && provider && (
<APIKeyManager
provider={provider}
apiKey={apiKeys[provider.name] || ''}
setApiKey={(key) => {
onApiKeysChange(provider.name, key);
}}
/>
)}
</div>
)}
</div>
</ClientOnly>
</div>
<FilePreview
files={uploadedFiles}
Expand Down
7 changes: 4 additions & 3 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ export const ChatImpl = memo(
});
useEffect(() => {
const prompt = searchParams.get('prompt');
console.log(prompt, searchParams, model, provider);

// console.log(prompt, searchParams, model, provider);

if (prompt) {
setSearchParams({});
Expand Down Expand Up @@ -289,14 +290,14 @@ export const ChatImpl = memo(

// reload();

const template = await selectStarterTemplate({
const { template, title } = await selectStarterTemplate({
message: messageInput,
model,
provider,
});

if (template !== 'blank') {
const temResp = await getTemplates(template);
const temResp = await getTemplates(template, title);

if (temResp) {
const { assistantMessage, userMessage } = temResp;
Expand Down
3 changes: 2 additions & 1 deletion app/components/settings/providers/ProvidersTab.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import type { IProviderConfig } from '~/types/model';
import { logStore } from '~/lib/stores/logs';

// Import a default fallback icon
import DefaultIcon from '/icons/Default.svg'; // Adjust the path as necessary
import { providerBaseUrlEnvKeys } from '~/utils/constants';

const DefaultIcon = '/icons/Default.svg'; // Adjust the path as necessary

export default function ProvidersTab() {
const { providers, updateProviderSettings, isLocalModel } = useSettings();
const [filteredProviders, setFilteredProviders] = useState<IProviderConfig[]>([]);
Expand Down
3 changes: 1 addition & 2 deletions app/entry.server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server';
import { renderHeadToString } from 'remix-island';
import { Head } from './root';
import { themeStore } from '~/lib/stores/theme';
import { initializeModelList } from '~/utils/constants';

export default async function handleRequest(
request: Request,
Expand All @@ -14,7 +13,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList({});
// await initializeModelList({});

const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
Expand Down
45 changes: 35 additions & 10 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { getSystemPrompt } from '~/lib/common/prompts/prompts';
import {
DEFAULT_MODEL,
DEFAULT_PROVIDER,
getModelList,
MODEL_REGEX,
MODIFICATIONS_TAG_NAME,
PROVIDER_LIST,
Expand All @@ -15,6 +14,8 @@ import ignore from 'ignore';
import type { IProviderSetting } from '~/types/model';
import { PromptLibrary } from '~/lib/common/prompt-library';
import { allowedHTMLElements } from '~/utils/markdown';
import { LLMManager } from '~/lib/modules/llm/manager';
import { createScopedLogger } from '~/utils/logger';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
Expand Down Expand Up @@ -142,6 +143,8 @@ function extractPropertiesFromMessage(message: Message): { model: string; provid
return { model, provider, content: cleanedContent };
}

const logger = createScopedLogger('stream-text');

export async function streamText(props: {
messages: Messages;
env: Env;
Expand All @@ -158,15 +161,10 @@ export async function streamText(props: {

let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER.name;
const MODEL_LIST = await getModelList({ apiKeys, providerSettings, serverEnv: serverEnv as any });
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, provider, content } = extractPropertiesFromMessage(message);

if (MODEL_LIST.find((m) => m.name === model)) {
currentModel = model;
}

currentModel = model;
currentProvider = provider;

return { ...message, content };
Expand All @@ -183,11 +181,36 @@ export async function streamText(props: {
return message;
});

const modelDetails = MODEL_LIST.find((m) => m.name === currentModel);
const provider = PROVIDER_LIST.find((p) => p.name === currentProvider) || DEFAULT_PROVIDER;
const staticModels = LLMManager.getInstance().getStaticModelListFromProvider(provider);
let modelDetails = staticModels.find((m) => m.name === currentModel);

if (!modelDetails) {
const modelsList = [
...(provider.staticModels || []),
...(await LLMManager.getInstance().getModelListFromProvider(provider, {
apiKeys,
providerSettings,
serverEnv: serverEnv as any,
})),
];

if (!modelsList.length) {
throw new Error(`No models found for provider ${provider.name}`);
}

const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
modelDetails = modelsList.find((m) => m.name === currentModel);

const provider = PROVIDER_LIST.find((p) => p.name === currentProvider) || DEFAULT_PROVIDER;
if (!modelDetails) {
// Fallback to first model
logger.warn(
`MODEL [${currentModel}] not found in provider [${provider.name}]. Falling back to first model. ${modelsList[0].name}`,
);
modelDetails = modelsList[0];
}
}

const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;

let systemPrompt =
PromptLibrary.getPropmtFromLibrary(promptId || 'default', {
Expand All @@ -201,6 +224,8 @@ export async function streamText(props: {
systemPrompt = `${systemPrompt}\n\n ${codeContext}`;
}

logger.info(`Sending llm call to ${provider.name} with model ${modelDetails.name}`);

return _streamText({
model: provider.getModelInstance({
model: currentModel,
Expand Down
52 changes: 52 additions & 0 deletions app/lib/modules/llm/base-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ export abstract class BaseProvider implements ProviderInfo {
abstract name: string;
abstract staticModels: ModelInfo[];
abstract config: ProviderConfig;
cachedDynamicModels?: {
cacheId: string;
models: ModelInfo[];
};

getApiKeyLink?: string;
labelForGetApiKey?: string;
Expand Down Expand Up @@ -49,6 +53,54 @@ export abstract class BaseProvider implements ProviderInfo {
apiKey,
};
}
getModelsFromCache(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}): ModelInfo[] | null {
if (!this.cachedDynamicModels) {
// console.log('no dynamic models',this.name);
return null;
}

const cacheKey = this.cachedDynamicModels.cacheId;
const generatedCacheKey = this.getDynamicModelsCacheKey(options);

if (cacheKey !== generatedCacheKey) {
// console.log('cache key mismatch',this.name,cacheKey,generatedCacheKey);
this.cachedDynamicModels = undefined;
return null;
}

return this.cachedDynamicModels.models;
}
getDynamicModelsCacheKey(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}) {
return JSON.stringify({
apiKeys: options.apiKeys?.[this.name],
providerSettings: options.providerSettings?.[this.name],
serverEnv: options.serverEnv,
});
}
storeDynamicModels(
options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
},
models: ModelInfo[],
) {
const cacheId = this.getDynamicModelsCacheKey(options);

// console.log('caching dynamic models',this.name,cacheId);
this.cachedDynamicModels = {
cacheId,
models,
};
}

// Declare the optional getDynamicModels method
getDynamicModels?(
Expand Down
Loading
Loading