import { DEFAULT_MODELS, ServiceProvider } from "../constant"; import { LLMModel } from "../client/api"; const CustomSeq = { val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts cache: new Map<string, number>(), next: (id: string) => { if (CustomSeq.cache.has(id)) { return CustomSeq.cache.get(id) as number; } else { let seq = CustomSeq.val++; CustomSeq.cache.set(id, seq); return seq; } }, }; const customProvider = (providerName: string) => ({ id: providerName.toLowerCase(), providerName: providerName, providerType: "custom", sorted: CustomSeq.next(providerName), }); /** * Sorts an array of models based on specified rules. * * First, sorted by provider; if the same, sorted by model */ const sortModelTable = (models: ReturnType<typeof collectModels>) => models.sort((a, b) => { if (a.provider && b.provider) { let cmp = a.provider.sorted - b.provider.sorted; return cmp === 0 ? a.sorted - b.sorted : cmp; } else { return a.sorted - b.sorted; } }); /** * get model name and provider from a formatted string, * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google` * @param modelWithProvider model name with provider separated by last `@` char, * @returns [model, provider] tuple, if no `@` char found, provider is undefined */ export function getModelProvider(modelWithProvider: string): [string, string?] { const [model, provider] = modelWithProvider.split(/@(?!.*@)/); return [model, provider]; } export function collectModelTable( models: readonly LLMModel[], customModels: string, ) { const modelTable: Record< string, { available: boolean; name: string; displayName: string; sorted: number; provider?: LLMModel["provider"]; // Marked as optional isDefault?: boolean; } > = {}; // default models models.forEach((m) => { // using <modelName>@<providerId> as fullName modelTable[`${m.name}@${m?.provider?.id}`] = { ...m, displayName: m.name, // 'provider' is copied over if it exists }; }); // server custom models customModels .split(",") .filter((v) => !!v && v.length > 0) .forEach((m) => { const available = !m.startsWith("-"); const nameConfig = m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; let [name, displayName] = nameConfig.split("="); // enable or disable all models if (name === "all") { Object.values(modelTable).forEach( (model) => (model.available = available), ); } else { // 1. find model by name, and set available value const [customModelName, customProviderName] = getModelProvider(name); let count = 0; for (const fullName in modelTable) { const [modelName, providerName] = getModelProvider(fullName); if ( customModelName == modelName && (customProviderName === undefined || customProviderName === providerName) ) { count += 1; modelTable[fullName]["available"] = available; // swap name and displayName for bytedance if (providerName === "bytedance") { [name, displayName] = [displayName, modelName]; modelTable[fullName]["name"] = name; } if (displayName) { modelTable[fullName]["displayName"] = displayName; } } } // 2. if model not exists, create new model with available value if (count === 0) { let [customModelName, customProviderName] = getModelProvider(name); const provider = customProvider( customProviderName || customModelName, ); // swap name and displayName for bytedance if (displayName && provider.providerName == "ByteDance") { [customModelName, displayName] = [displayName, customModelName]; } modelTable[`${customModelName}@${provider?.id}`] = { name: customModelName, displayName: displayName || customModelName, available, provider, // Use optional chaining sorted: CustomSeq.next(`${customModelName}@${provider?.id}`), }; } } }); return modelTable; } export function collectModelTableWithDefaultModel( models: readonly LLMModel[], customModels: string, defaultModel: string, ) { let modelTable = collectModelTable(models, customModels); if (defaultModel && defaultModel !== "") { if (defaultModel.includes("@")) { if (defaultModel in modelTable) { modelTable[defaultModel].isDefault = true; } } else { for (const key of Object.keys(modelTable)) { if ( modelTable[key].available && getModelProvider(key)[0] == defaultModel ) { modelTable[key].isDefault = true; break; } } } } return modelTable; } /** * Generate full model table. */ export function collectModels( models: readonly LLMModel[], customModels: string, ) { const modelTable = collectModelTable(models, customModels); let allModels = Object.values(modelTable); allModels = sortModelTable(allModels); return allModels; } export function collectModelsWithDefaultModel( models: readonly LLMModel[], customModels: string, defaultModel: string, ) { const modelTable = collectModelTableWithDefaultModel( models, customModels, defaultModel, ); let allModels = Object.values(modelTable); allModels = sortModelTable(allModels); return allModels; } export function isModelAvailableInServer( customModels: string, modelName: string, providerName: string, ) { const fullName = `${modelName}@${providerName}`; const modelTable = collectModelTable(DEFAULT_MODELS, customModels); return modelTable[fullName]?.available === false; } /** * Check if the model name is a GPT-4 related model * * @param modelName The name of the model to check * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini) */ export function isGPT4Model(modelName: string): boolean { return ( (modelName.startsWith("gpt-4") || modelName.startsWith("chatgpt-4o") || modelName.startsWith("o1")) && !modelName.startsWith("gpt-4o-mini") ); } /** * Checks if a model is not available on any of the specified providers in the server. * * @param {string} customModels - A string of custom models, comma-separated. * @param {string} modelName - The name of the model to check. * @param {string|string[]} providerNames - A string or array of provider names to check against. * * @returns {boolean} True if the model is not available on any of the specified providers, false otherwise. */ export function isModelNotavailableInServer( customModels: string, modelName: string, providerNames: string | string[], ): boolean { // Check DISABLE_GPT4 environment variable if ( process.env.DISABLE_GPT4 === "1" && isGPT4Model(modelName.toLowerCase()) ) { return true; } const modelTable = collectModelTable(DEFAULT_MODELS, customModels); const providerNamesArray = Array.isArray(providerNames) ? providerNames : [providerNames]; for (const providerName of providerNamesArray) { // if model provider is bytedance, use model config name to check if not avaliable if (providerName === ServiceProvider.ByteDance) { return !Object.values(modelTable).filter((v) => v.name === modelName)?.[0] ?.available; } const fullName = `${modelName}@${providerName.toLowerCase()}`; if (modelTable?.[fullName]?.available === true) return false; } return true; }