ai_service.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from __future__ import annotations
  2. import json
  3. import re
  4. import sqlite3
  5. from typing import Any
  6. import httpx
  7. from fastapi import HTTPException
  8. from pydantic import ValidationError
  9. from . import ai_gemini, ai_openai
  10. from .database import get_db
  11. from .scanner import now_iso
  12. from .schemas import (
  13. AiImportItem,
  14. AiModelCreate,
  15. AiModelUpdate,
  16. AiProviderCreate,
  17. AiProviderUpdate,
  18. )
  19. def public_provider(row: dict[str, Any]) -> dict[str, Any]:
  20. item = dict(row)
  21. item["enabled"] = bool(item["enabled"])
  22. item["api_key_set"] = bool(item.get("api_key"))
  23. item.pop("api_key", None)
  24. return item
  25. def public_model(row: dict[str, Any]) -> dict[str, Any]:
  26. item = dict(row)
  27. item["is_default"] = bool(item["is_default"])
  28. return item
  29. def list_providers() -> list[dict[str, Any]]:
  30. with get_db() as conn:
  31. rows = conn.execute("SELECT * FROM ai_providers ORDER BY name ASC").fetchall()
  32. return [public_provider(row) for row in rows]
  33. def create_provider(payload: AiProviderCreate) -> dict[str, Any]:
  34. now = now_iso()
  35. try:
  36. with get_db() as conn:
  37. cursor = conn.execute(
  38. """
  39. INSERT INTO ai_providers (name, provider_type, base_url, api_key, enabled, created_at, updated_at)
  40. VALUES (?, ?, ?, ?, ?, ?, ?)
  41. """,
  42. (
  43. payload.name.strip(),
  44. payload.provider_type,
  45. clean_optional(payload.base_url),
  46. clean_optional(payload.api_key),
  47. 1 if payload.enabled else 0,
  48. now,
  49. now,
  50. ),
  51. )
  52. row = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (cursor.lastrowid,)).fetchone()
  53. except sqlite3.IntegrityError as exc:
  54. raise HTTPException(status_code=409, detail="AI provider name already exists") from exc
  55. return public_provider(row)
  56. def update_provider(provider_id: int, payload: AiProviderUpdate) -> dict[str, Any]:
  57. now = now_iso()
  58. with get_db() as conn:
  59. existing = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
  60. if not existing:
  61. raise HTTPException(status_code=404, detail="AI provider not found")
  62. api_key = existing.get("api_key")
  63. if payload.clear_api_key:
  64. api_key = None
  65. elif payload.api_key is not None and payload.api_key != "":
  66. api_key = payload.api_key
  67. try:
  68. conn.execute(
  69. """
  70. UPDATE ai_providers
  71. SET name = ?, provider_type = ?, base_url = ?, api_key = ?, enabled = ?, updated_at = ?
  72. WHERE id = ?
  73. """,
  74. (
  75. payload.name.strip(),
  76. payload.provider_type,
  77. clean_optional(payload.base_url),
  78. clean_optional(api_key),
  79. 1 if payload.enabled else 0,
  80. now,
  81. provider_id,
  82. ),
  83. )
  84. except sqlite3.IntegrityError as exc:
  85. raise HTTPException(status_code=409, detail="AI provider name already exists") from exc
  86. row = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
  87. return public_provider(row)
  88. def delete_provider(provider_id: int) -> dict[str, Any]:
  89. with get_db() as conn:
  90. cursor = conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,))
  91. if cursor.rowcount == 0:
  92. raise HTTPException(status_code=404, detail="AI provider not found")
  93. return {"deleted": cursor.rowcount}
  94. def list_models(provider_id: int | None = None) -> list[dict[str, Any]]:
  95. where = "WHERE m.provider_id = ?" if provider_id else ""
  96. params = [provider_id] if provider_id else []
  97. with get_db() as conn:
  98. rows = conn.execute(
  99. f"""
  100. SELECT m.*, p.name AS provider_name, p.provider_type
  101. FROM ai_models m
  102. JOIN ai_providers p ON p.id = m.provider_id
  103. {where}
  104. ORDER BY p.name ASC, m.is_default DESC, m.name ASC
  105. """,
  106. params,
  107. ).fetchall()
  108. return [public_model(row) for row in rows]
  109. def create_model(payload: AiModelCreate) -> dict[str, Any]:
  110. now = now_iso()
  111. with get_db() as conn:
  112. ensure_provider_exists(conn, payload.provider_id)
  113. try:
  114. cursor = conn.execute(
  115. """
  116. INSERT INTO ai_models (provider_id, name, display_name, is_default, created_at, updated_at)
  117. VALUES (?, ?, ?, ?, ?, ?)
  118. """,
  119. (
  120. payload.provider_id,
  121. payload.name.strip(),
  122. clean_optional(payload.display_name),
  123. 1 if payload.is_default else 0,
  124. now,
  125. now,
  126. ),
  127. )
  128. except sqlite3.IntegrityError as exc:
  129. raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc
  130. if payload.is_default:
  131. clear_other_default_models(conn, payload.provider_id, cursor.lastrowid)
  132. row = get_model_row(conn, cursor.lastrowid)
  133. return public_model(row)
  134. def update_model(model_id: int, payload: AiModelUpdate) -> dict[str, Any]:
  135. now = now_iso()
  136. with get_db() as conn:
  137. ensure_provider_exists(conn, payload.provider_id)
  138. if not conn.execute("SELECT id FROM ai_models WHERE id = ?", (model_id,)).fetchone():
  139. raise HTTPException(status_code=404, detail="AI model not found")
  140. try:
  141. conn.execute(
  142. """
  143. UPDATE ai_models
  144. SET provider_id = ?, name = ?, display_name = ?, is_default = ?, updated_at = ?
  145. WHERE id = ?
  146. """,
  147. (
  148. payload.provider_id,
  149. payload.name.strip(),
  150. clean_optional(payload.display_name),
  151. 1 if payload.is_default else 0,
  152. now,
  153. model_id,
  154. ),
  155. )
  156. except sqlite3.IntegrityError as exc:
  157. raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc
  158. if payload.is_default:
  159. clear_other_default_models(conn, payload.provider_id, model_id)
  160. row = get_model_row(conn, model_id)
  161. return public_model(row)
  162. def delete_model(model_id: int) -> dict[str, Any]:
  163. with get_db() as conn:
  164. cursor = conn.execute("DELETE FROM ai_models WHERE id = ?", (model_id,))
  165. if cursor.rowcount == 0:
  166. raise HTTPException(status_code=404, detail="AI model not found")
  167. return {"deleted": cursor.rowcount}
  168. def chat(provider_id: int, model_id: int, prompt: str, temperature: float) -> dict[str, Any]:
  169. provider, model = get_provider_and_model(provider_id, model_id)
  170. try:
  171. if provider["provider_type"] in {"OPENAI", "OPENAI_COMPATIBLE"}:
  172. result = ai_openai.chat(provider, model, prompt, temperature)
  173. elif provider["provider_type"] == "GOOGLE_GEMINI":
  174. result = ai_gemini.chat(provider, model, prompt, temperature)
  175. else:
  176. raise HTTPException(status_code=400, detail="Unsupported AI provider type")
  177. except httpx.HTTPStatusError as exc:
  178. detail = exc.response.text[:1000] if exc.response is not None else str(exc)
  179. raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc
  180. except httpx.HTTPError as exc:
  181. raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc
  182. return {
  183. "provider": public_provider(provider),
  184. "model": public_model(model),
  185. **result,
  186. }
  187. def chat_with_images(
  188. provider_id: int,
  189. model_id: int,
  190. prompt: str,
  191. images: list[dict[str, str]],
  192. temperature: float,
  193. ) -> dict[str, Any]:
  194. provider, model = get_provider_and_model(provider_id, model_id)
  195. try:
  196. if provider["provider_type"] in {"OPENAI", "OPENAI_COMPATIBLE"}:
  197. result = ai_openai.chat_with_images(provider, model, prompt, images, temperature)
  198. elif provider["provider_type"] == "GOOGLE_GEMINI":
  199. result = ai_gemini.chat_with_images(provider, model, prompt, images, temperature)
  200. else:
  201. raise HTTPException(status_code=400, detail="Unsupported AI provider type")
  202. except httpx.HTTPStatusError as exc:
  203. detail = exc.response.text[:1000] if exc.response is not None else str(exc)
  204. raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc
  205. except httpx.HTTPError as exc:
  206. raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc
  207. return {
  208. "provider": public_provider(provider),
  209. "model": public_model(model),
  210. **result,
  211. }
  212. def parse_ai_items(content: str) -> list[dict[str, Any]]:
  213. parsed = json.loads(extract_json_text(content))
  214. items = parsed.get("items") if isinstance(parsed, dict) else parsed
  215. if not isinstance(items, list):
  216. raise ValueError("AI output must be a JSON array or an object containing items")
  217. validated = []
  218. for item in items:
  219. validated.append(AiImportItem.model_validate(item).model_dump())
  220. return validated
  221. def extract_json_text(content: str) -> str:
  222. text = content.strip()
  223. fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
  224. if fenced:
  225. text = fenced.group(1).strip()
  226. if text.startswith("[") or text.startswith("{"):
  227. return text
  228. start_candidates = [index for index in [text.find("["), text.find("{")] if index >= 0]
  229. if not start_candidates:
  230. return text
  231. start = min(start_candidates)
  232. end = max(text.rfind("]"), text.rfind("}"))
  233. return text[start : end + 1] if end > start else text[start:]
  234. def get_provider_and_model(provider_id: int, model_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
  235. with get_db() as conn:
  236. provider = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
  237. model = conn.execute("SELECT * FROM ai_models WHERE id = ?", (model_id,)).fetchone()
  238. if not provider:
  239. raise HTTPException(status_code=404, detail="AI provider not found")
  240. if not provider["enabled"]:
  241. raise HTTPException(status_code=400, detail="AI provider is disabled")
  242. if not model or model["provider_id"] != provider_id:
  243. raise HTTPException(status_code=400, detail="AI model does not belong to this provider")
  244. return provider, model
  245. def ensure_provider_exists(conn, provider_id: int) -> None:
  246. if not conn.execute("SELECT id FROM ai_providers WHERE id = ?", (provider_id,)).fetchone():
  247. raise HTTPException(status_code=404, detail="AI provider not found")
  248. def get_model_row(conn, model_id: int) -> dict[str, Any]:
  249. row = conn.execute(
  250. """
  251. SELECT m.*, p.name AS provider_name, p.provider_type
  252. FROM ai_models m
  253. JOIN ai_providers p ON p.id = m.provider_id
  254. WHERE m.id = ?
  255. """,
  256. (model_id,),
  257. ).fetchone()
  258. if not row:
  259. raise HTTPException(status_code=404, detail="AI model not found")
  260. return row
  261. def clear_other_default_models(conn, provider_id: int, model_id: int) -> None:
  262. conn.execute(
  263. "UPDATE ai_models SET is_default = 0, updated_at = ? WHERE provider_id = ? AND id <> ?",
  264. (now_iso(), provider_id, model_id),
  265. )
  266. def clean_optional(value: str | None) -> str | None:
  267. if value is None:
  268. return None
  269. stripped = value.strip()
  270. return stripped or None