|
1 | 1 | import z from "zod" |
2 | 2 | import fuzzysort from "fuzzysort" |
3 | 3 | import { Config } from "../config/config" |
4 | | -import { mapValues, mergeDeep, sortBy } from "remeda" |
| 4 | +import { mapValues, mergeDeep, omit, pickBy, sortBy } from "remeda" |
5 | 5 | import { NoSuchModelError, type Provider as SDK } from "ai" |
6 | 6 | import { Log } from "../util/log" |
7 | 7 | import { BunProc } from "../bun" |
@@ -405,16 +405,6 @@ export namespace Provider { |
405 | 405 | }, |
406 | 406 | } |
407 | 407 |
|
408 | | - export const Variant = z |
409 | | - .object({ |
410 | | - disabled: z.boolean(), |
411 | | - }) |
412 | | - .catchall(z.any()) |
413 | | - .meta({ |
414 | | - ref: "Variant", |
415 | | - }) |
416 | | - export type Variant = z.infer<typeof Variant> |
417 | | - |
418 | 408 | export const Model = z |
419 | 409 | .object({ |
420 | 410 | id: z.string(), |
@@ -478,7 +468,7 @@ export namespace Provider { |
478 | 468 | options: z.record(z.string(), z.any()), |
479 | 469 | headers: z.record(z.string(), z.string()), |
480 | 470 | release_date: z.string(), |
481 | | - variants: z.record(z.string(), Variant).optional(), |
| 471 | + variants: z.record(z.string(), z.record(z.string(), z.any())).optional(), |
482 | 472 | }) |
483 | 473 | .meta({ |
484 | 474 | ref: "Model", |
@@ -561,7 +551,7 @@ export namespace Provider { |
561 | 551 | variants: {}, |
562 | 552 | } |
563 | 553 |
|
564 | | - m.variants = mapValues(ProviderTransform.variants(m), (v) => ({ disabled: false, ...v })) |
| 554 | + m.variants = mapValues(ProviderTransform.variants(m), (v) => v) |
565 | 555 |
|
566 | 556 | return m |
567 | 557 | } |
@@ -697,7 +687,13 @@ export namespace Provider { |
697 | 687 | headers: mergeDeep(existingModel?.headers ?? {}, model.headers ?? {}), |
698 | 688 | family: model.family ?? existingModel?.family ?? "", |
699 | 689 | release_date: model.release_date ?? existingModel?.release_date ?? "", |
| 690 | + variants: {}, |
700 | 691 | } |
| 692 | + const merged = mergeDeep(ProviderTransform.variants(parsedModel), model.variants ?? {}) |
| 693 | + parsedModel.variants = mapValues( |
| 694 | + pickBy(merged, (v) => !v.disabled), |
| 695 | + (v) => omit(v, ["disabled"]), |
| 696 | + ) |
701 | 697 | parsed.models[modelID] = parsedModel |
702 | 698 | } |
703 | 699 | database[providerID] = parsed |
@@ -822,6 +818,16 @@ export namespace Provider { |
822 | 818 | (configProvider?.whitelist && !configProvider.whitelist.includes(modelID)) |
823 | 819 | ) |
824 | 820 | delete provider.models[modelID] |
| 821 | + |
| 822 | + // Filter out disabled variants from config |
| 823 | + const configVariants = configProvider?.models?.[modelID]?.variants |
| 824 | + if (configVariants && model.variants) { |
| 825 | + const merged = mergeDeep(model.variants, configVariants) |
| 826 | + model.variants = mapValues( |
| 827 | + pickBy(merged, (v) => !v.disabled), |
| 828 | + (v) => omit(v, ["disabled"]), |
| 829 | + ) |
| 830 | + } |
825 | 831 | } |
826 | 832 |
|
827 | 833 | if (Object.keys(provider.models).length === 0) { |
|
0 commit comments