ai.ts 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import { generateObject, NoObjectGeneratedError } from "ai";
  2. import { createOpenAI } from "@ai-sdk/openai";
  3. import { createGoogleGenerativeAI } from "@ai-sdk/google";
  4. import { z } from "zod";
  5. import { config, AIProvider } from "./config";
  /** Returns a promise that settles after `ms` milliseconds (timer-based delay helper). */
  const sleep = (ms: number) =>
    new Promise((wake) => {
      setTimeout(wake, ms);
    });
  /**
   * Describes one structured-generation task handled by `generateWithAI`.
   *
   * @typeParam T - Shape of the object the schema validates and the call returns.
   * @typeParam I - Type of the raw input handed to `promptBuilder` (defaults to `string`).
   */
  export interface GenerationConfig<T, I = string> {
    /** Label for the task; it is embedded in the error thrown when every provider fails. */
    taskName: string;

    /** Text forwarded as the `system` message of the generation request. */
    systemPrompt: string;
    /** Turns the caller's input into the `prompt` string of the generation request. */
    promptBuilder: (input: I) => string;

    /** Zod schema forwarded as the `schema` option of the generation request. */
    schema: z.ZodSchema<T>;

    /** Sampling temperature forwarded unchanged to the generation request. */
    temperature: number;
    /** Token limit forwarded unchanged as `maxTokens`. */
    maxTokens: number;
  }
  /**
   * Builds the SDK client factory for a single provider entry.
   *
   * A provider whose `type` is `'google'` gets a Google Generative AI client; every
   * other type gets an OpenAI client created with `compatibility: "compatible"`.
   * Both receive the provider's API key and base URL.
   *
   * @param provider - Provider settings (type, API key, base URL).
   * @returns The client factory; call it with a model id to obtain a model handle.
   */
  const createAIClient = (provider: AIProvider) =>
    provider.type === 'google'
      ? createGoogleGenerativeAI({
          apiKey: provider.apiKey,
          baseURL: provider.baseUrl,
        })
      : createOpenAI({
          apiKey: provider.apiKey,
          baseURL: provider.baseUrl,
          compatibility: "compatible",
        });
  /**
   * Extracts the JSON payload from a model reply that wraps it in a Markdown code fence.
   *
   * The first fenced block is returned when one is present, whether or not it is tagged
   * `json`, with LF or CRLF line endings, and with or without prose around it.
   * When no complete fence is found, the text falls back to removing the first
   * "```json\n" and the first "\n```" occurrences. The input string is never mutated.
   */
  const stripJsonCodeFence = (text: string): string => {
    const fenced = /```(?:json)?[ \t]*\r?\n([\s\S]*?)\r?\n?[ \t]*```/i.exec(text);
    const inner = fenced?.[1];
    if (inner !== undefined) {
      return inner;
    }
    return text.replace('```json\n', '').replace('\n```', '');
  };

  /**
   * Generic structured-generation entry point with provider fallback.
   *
   * Providers from `config.PROVIDERS` are tried in order. Every provider except the
   * first may be skipped at random according to its `skipProbability`. Each attempted
   * provider gets `retryCount` attempts (default 1) with a linearly growing pause
   * (200 ms, 400 ms, ...) between them. The first successful object is returned.
   *
   * @param input - Raw input passed to `generationConfig.promptBuilder`.
   * @param generationConfig - Prompt, schema and sampling settings for this task.
   * @returns The object produced by the first successful provider attempt.
   * @throws Error when no provider is configured, or when every attempted provider
   *   fails; the message then contains `taskName` and the last error seen.
   */
  export async function generateWithAI<T, I = string>(
    input: I,
    generationConfig: GenerationConfig<T, I>
  ): Promise<T> {
    const providers = config.PROVIDERS;
    if (providers.length === 0) {
      console.error("没有配置 API Key");
      throw new Error("没有配置 API Key");
    }

    let lastError: unknown = null;

    // Walk through every configured provider in order.
    for (let providerIndex = 0; providerIndex < providers.length; providerIndex++) {
      const provider = providers[providerIndex];

      // Randomly skip this provider; the first provider is never skipped.
      if (providerIndex > 0 && Math.random() < (provider.skipProbability ?? 0)) {
        console.log(`跳过提供商: ${provider.name} (跳过概率: ${provider.skipProbability})`);
        continue;
      }

      const retryCount = provider.retryCount ?? 1;
      console.log(`使用提供商: ${provider.name},模型: ${provider.model},重试次数: ${retryCount}`);

      // Retry loop for the current provider.
      for (let attempt = 0; attempt < retryCount; attempt++) {
        try {
          console.log(`提供商 ${provider.name} 第 ${attempt + 1}/${retryCount} 次尝试`);
          const llm = createAIClient(provider);
          const generateOptions = {
            model: llm(provider.model),
            system: generationConfig.systemPrompt,
            prompt: generationConfig.promptBuilder(input),
            schema: generationConfig.schema,
            temperature: generationConfig.temperature,
            maxTokens: generationConfig.maxTokens,
            // `maxRetries` is the SDK's retry option. The previous `retryCount` key is
            // not a generateObject option, so the SDK's default retry count applied
            // on top of this loop.
            maxRetries: 1,
            mode: provider.mode || 'auto',
            // Only JSON mode gets a repair hook: strip a Markdown fence around the reply.
            experimental_repairText: provider.mode === 'json'
              ? async (options: { text: string }) => stripJsonCodeFence(options.text)
              : undefined,
          };
          const { object } = await generateObject(generateOptions);
          console.log(`提供商 ${provider.name} 第 ${attempt + 1} 次尝试成功`);
          return object as T;
        } catch (error) {
          lastError = error;
          console.error(`提供商 ${provider.name} 第 ${attempt + 1} 次尝试失败:`, error);

          // Dump the extra diagnostics the SDK attaches when no object was produced.
          if (NoObjectGeneratedError.isInstance(error)) {
            console.log("NoObjectGeneratedError 详情:");
            console.log("Cause:", error.cause);
            console.log("Text:", error.text);
            console.log("Response:", error.response);
            console.log("Usage:", error.usage);
            console.log("Finish Reason:", error.finishReason);
          }

          // Unless this was the last attempt, wait before retrying.
          if (attempt < retryCount - 1) {
            const waitTime = (attempt + 1) * 200; // pause grows with each attempt
            console.log(`等待 ${waitTime} 毫秒后重试...`);
            await sleep(waitTime);
          }
        }
      }

      console.log(`提供商 ${provider.name} 所有尝试都失败了`);
    }

    console.error("所有提供商都失败了:", lastError);
    throw new Error(`${generationConfig.taskName}失败: ${lastError}`);
  }