| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- from __future__ import annotations
- import json
- import re
- import sqlite3
- from typing import Any
- import httpx
- from fastapi import HTTPException
- from pydantic import ValidationError
- from . import ai_gemini, ai_openai
- from .database import get_db
- from .scanner import now_iso
- from .schemas import (
- AiImportItem,
- AiModelCreate,
- AiModelUpdate,
- AiProviderCreate,
- AiProviderUpdate,
- )
def public_provider(row: dict[str, Any]) -> dict[str, Any]:
    """Return an API-safe copy of a provider row.

    The stored ``api_key`` is removed from the copy and replaced by an
    ``api_key_set`` flag that reports whether a non-empty key was stored.
    ``enabled`` is coerced to ``bool``. The input row is not modified.
    """
    public = dict(row)
    # Pop the secret first so it can never leak; only its presence is reported.
    key_present = bool(public.pop("api_key", None))
    public["enabled"] = bool(public["enabled"])
    public["api_key_set"] = key_present
    return public
def public_model(row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of a model row with ``is_default`` coerced to ``bool``.

    The input row is not modified.
    """
    # A repeated key in a dict display keeps its original position, so the
    # column order of the row is preserved.
    return {**dict(row), "is_default": bool(row["is_default"])}
def list_providers() -> list[dict[str, Any]]:
    """Return every AI provider ordered by name, with API keys redacted."""
    query = "SELECT * FROM ai_providers ORDER BY name ASC"
    with get_db() as conn:
        provider_rows = conn.execute(query).fetchall()
    return list(map(public_provider, provider_rows))
def create_provider(payload: AiProviderCreate) -> dict[str, Any]:
    """Insert a new AI provider and return its public (key-redacted) form.

    ``name`` is stripped. Blank ``base_url``/``api_key`` values are stored as
    NULL via ``clean_optional``. ``created_at`` and ``updated_at`` share one
    timestamp.

    Raises:
        HTTPException: 409 when the insert raises ``sqlite3.IntegrityError``
            (reported to the client as a duplicate provider name).
    """
    timestamp = now_iso()
    enabled_flag = 1 if payload.enabled else 0
    try:
        with get_db() as conn:
            insert = conn.execute(
                """
                INSERT INTO ai_providers (name, provider_type, base_url, api_key, enabled, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    payload.name.strip(),
                    payload.provider_type,
                    clean_optional(payload.base_url),
                    clean_optional(payload.api_key),
                    enabled_flag,
                    timestamp,
                    timestamp,
                ),
            )
            new_id = insert.lastrowid
            created = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (new_id,)).fetchone()
    except sqlite3.IntegrityError as exc:
        # The `with` is inside the try so errors raised on exit are translated too.
        raise HTTPException(status_code=409, detail="AI provider name already exists") from exc
    return public_provider(created)
def update_provider(provider_id: int, payload: AiProviderUpdate) -> dict[str, Any]:
    """Replace an AI provider's settings and return its public form.

    API-key handling:
      * ``payload.clear_api_key`` set   -> the stored key is removed;
      * non-blank ``payload.api_key``   -> it replaces the stored key (stripped);
      * ``None``/empty/whitespace key   -> the stored key is kept.

    Raises:
        HTTPException: 404 if no provider has ``provider_id``; 409 when the
            update raises ``sqlite3.IntegrityError`` (reported as a duplicate
            provider name).
    """
    now = now_iso()
    with get_db() as conn:
        existing = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
        if not existing:
            raise HTTPException(status_code=404, detail="AI provider not found")
        # Indexing works for both dict rows and sqlite3.Row; `.get` would not.
        api_key = existing["api_key"]
        # Decide on the *cleaned* value. Comparing the raw string to "" let a
        # whitespace-only key through, and clean_optional() then turned it into
        # None, which silently wiped the stored key.
        submitted_key = clean_optional(payload.api_key)
        if payload.clear_api_key:
            api_key = None
        elif submitted_key is not None:
            api_key = submitted_key
        try:
            conn.execute(
                """
                UPDATE ai_providers
                SET name = ?, provider_type = ?, base_url = ?, api_key = ?, enabled = ?, updated_at = ?
                WHERE id = ?
                """,
                (
                    payload.name.strip(),
                    payload.provider_type,
                    clean_optional(payload.base_url),
                    clean_optional(api_key),
                    1 if payload.enabled else 0,
                    now,
                    provider_id,
                ),
            )
        except sqlite3.IntegrityError as exc:
            raise HTTPException(status_code=409, detail="AI provider name already exists") from exc
        row = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
    return public_provider(row)
def delete_provider(provider_id: int) -> dict[str, Any]:
    """Delete the AI provider with ``provider_id``.

    Returns:
        ``{"deleted": n}``, where ``n`` is the cursor's ``rowcount``.

    Raises:
        HTTPException: 404 if the DELETE affected no rows.
    """
    with get_db() as conn:
        removed = conn.execute("DELETE FROM ai_providers WHERE id = ?", (provider_id,)).rowcount
        if removed == 0:
            raise HTTPException(status_code=404, detail="AI provider not found")
    return {"deleted": removed}
def list_models(provider_id: int | None = None) -> list[dict[str, Any]]:
    """List AI models joined with their provider's name and type.

    Args:
        provider_id: Restrict the result to this provider's models. ``None``
            (the default) lists the models of every provider.

    Returns:
        Rows ordered by provider name, then default model first, then model
        name. ``is_default`` is coerced to ``bool``.
    """
    # Compare against None explicitly: only an absent id means "no filter".
    # A truthiness test would also drop the filter for an id of 0.
    if provider_id is not None:
        where, params = "WHERE m.provider_id = ?", [provider_id]
    else:
        where, params = "", []
    with get_db() as conn:
        # Only the constant `where` fragment is interpolated; the id itself is bound.
        rows = conn.execute(
            f"""
            SELECT m.*, p.name AS provider_name, p.provider_type
            FROM ai_models m
            JOIN ai_providers p ON p.id = m.provider_id
            {where}
            ORDER BY p.name ASC, m.is_default DESC, m.name ASC
            """,
            params,
        ).fetchall()
    return [public_model(row) for row in rows]
def create_model(payload: AiModelCreate) -> dict[str, Any]:
    """Register a model under an existing provider and return it.

    When ``payload.is_default`` is set, the other models of the same provider
    have their default flag cleared using the same connection.

    Raises:
        HTTPException: 404 if the provider does not exist; 409 when the insert
            raises ``sqlite3.IntegrityError`` (reported as a duplicate model
            for this provider).
    """
    timestamp = now_iso()
    with get_db() as conn:
        ensure_provider_exists(conn, payload.provider_id)
        try:
            new_id = conn.execute(
                """
                INSERT INTO ai_models (provider_id, name, display_name, is_default, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    payload.provider_id,
                    payload.name.strip(),
                    clean_optional(payload.display_name),
                    1 if payload.is_default else 0,
                    timestamp,
                    timestamp,
                ),
            ).lastrowid
        except sqlite3.IntegrityError as exc:
            raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc
        if payload.is_default:
            clear_other_default_models(conn, payload.provider_id, new_id)
        created = get_model_row(conn, new_id)
    return public_model(created)
def update_model(model_id: int, payload: AiModelUpdate) -> dict[str, Any]:
    """Replace a model's settings (it may move to another provider) and return it.

    When ``payload.is_default`` is set, the other models of the target provider
    have their default flag cleared.

    Raises:
        HTTPException: 404 if the target provider or the model does not exist;
            409 when the update raises ``sqlite3.IntegrityError`` (reported as a
            duplicate model for this provider).
    """
    timestamp = now_iso()
    target_provider = payload.provider_id
    with get_db() as conn:
        ensure_provider_exists(conn, target_provider)
        model_exists = conn.execute("SELECT id FROM ai_models WHERE id = ?", (model_id,)).fetchone()
        if not model_exists:
            raise HTTPException(status_code=404, detail="AI model not found")
        try:
            conn.execute(
                """
                UPDATE ai_models
                SET provider_id = ?, name = ?, display_name = ?, is_default = ?, updated_at = ?
                WHERE id = ?
                """,
                (
                    target_provider,
                    payload.name.strip(),
                    clean_optional(payload.display_name),
                    1 if payload.is_default else 0,
                    timestamp,
                    model_id,
                ),
            )
        except sqlite3.IntegrityError as exc:
            raise HTTPException(status_code=409, detail="AI model already exists for this provider") from exc
        if payload.is_default:
            clear_other_default_models(conn, target_provider, model_id)
        updated = get_model_row(conn, model_id)
    return public_model(updated)
def delete_model(model_id: int) -> dict[str, Any]:
    """Delete the AI model with ``model_id``.

    Returns:
        ``{"deleted": n}``, where ``n`` is the cursor's ``rowcount``.

    Raises:
        HTTPException: 404 if the DELETE affected no rows.
    """
    with get_db() as conn:
        removed = conn.execute("DELETE FROM ai_models WHERE id = ?", (model_id,)).rowcount
        if removed == 0:
            raise HTTPException(status_code=404, detail="AI model not found")
    return {"deleted": removed}
def chat(provider_id: int, model_id: int, prompt: str, temperature: float) -> dict[str, Any]:
    """Send a text prompt to the configured provider/model and return the reply.

    Dispatch is on ``provider_type``: ``OPENAI``/``OPENAI_COMPATIBLE`` use
    ``ai_openai.chat`` and ``GOOGLE_GEMINI`` uses ``ai_gemini.chat``.

    Returns:
        ``{"provider": ..., "model": ..., **backend_result}``. Keys from the
        backend result override ``provider``/``model`` on collision.

    Raises:
        HTTPException: from ``get_provider_and_model`` (404/400); 400 for an
            unsupported provider type; 502 when the upstream HTTP call fails.
    """
    provider, model = get_provider_and_model(provider_id, model_id)
    kind = provider["provider_type"]
    try:
        if kind == "GOOGLE_GEMINI":
            result = ai_gemini.chat(provider, model, prompt, temperature)
        elif kind in {"OPENAI", "OPENAI_COMPATIBLE"}:
            result = ai_openai.chat(provider, model, prompt, temperature)
        else:
            raise HTTPException(status_code=400, detail="Unsupported AI provider type")
    except httpx.HTTPError as exc:
        if isinstance(exc, httpx.HTTPStatusError):
            # Forward a bounded slice of the upstream error body to the client.
            detail = exc.response.text[:1000] if exc.response is not None else str(exc)
            raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc
        raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc
    response = {"provider": public_provider(provider), "model": public_model(model)}
    response.update(result)
    return response
def chat_with_images(
    provider_id: int,
    model_id: int,
    prompt: str,
    images: list[dict[str, str]],
    temperature: float,
) -> dict[str, Any]:
    """Send a prompt plus images to the configured provider/model.

    Dispatch is on ``provider_type``: ``OPENAI``/``OPENAI_COMPATIBLE`` use
    ``ai_openai.chat_with_images`` and ``GOOGLE_GEMINI`` uses
    ``ai_gemini.chat_with_images``. ``images`` is passed through unchanged.
    Its dict schema is defined by those backends (not visible here).

    Returns:
        ``{"provider": ..., "model": ..., **backend_result}``. Keys from the
        backend result override ``provider``/``model`` on collision.

    Raises:
        HTTPException: from ``get_provider_and_model`` (404/400); 400 for an
            unsupported provider type; 502 when the upstream HTTP call fails.
    """
    provider, model = get_provider_and_model(provider_id, model_id)
    kind = provider["provider_type"]
    try:
        if kind == "GOOGLE_GEMINI":
            result = ai_gemini.chat_with_images(provider, model, prompt, images, temperature)
        elif kind in {"OPENAI", "OPENAI_COMPATIBLE"}:
            result = ai_openai.chat_with_images(provider, model, prompt, images, temperature)
        else:
            raise HTTPException(status_code=400, detail="Unsupported AI provider type")
    except httpx.HTTPError as exc:
        if isinstance(exc, httpx.HTTPStatusError):
            # Forward a bounded slice of the upstream error body to the client.
            detail = exc.response.text[:1000] if exc.response is not None else str(exc)
            raise HTTPException(status_code=502, detail=f"AI provider returned an error: {detail}") from exc
        raise HTTPException(status_code=502, detail=f"AI provider request failed: {exc}") from exc
    response = {"provider": public_provider(provider), "model": public_model(model)}
    response.update(result)
    return response
def parse_ai_items(content: str) -> list[dict[str, Any]]:
    """Parse raw model output into validated import items.

    The JSON is located with ``extract_json_text``. It must be either an array
    of items or an object whose ``items`` key holds that array. Each item is
    validated through ``AiImportItem`` and returned as a plain dict.

    Raises:
        json.JSONDecodeError: if the extracted text is not valid JSON.
        ValueError: if no list of items is found.
        pydantic.ValidationError: if an item fails ``AiImportItem`` validation.
    """
    payload = json.loads(extract_json_text(content))
    if isinstance(payload, dict):
        payload = payload.get("items")
    if not isinstance(payload, list):
        raise ValueError("AI output must be a JSON array or an object containing items")
    return [AiImportItem.model_validate(entry).model_dump() for entry in payload]
def extract_json_text(content: str) -> str:
    """Extract the JSON payload from free-form model output.

    Steps:
      1. Strip the text and, if it contains a Markdown code fence
         (```` ``` ```` or ```` ```json ````), keep only the fenced body.
      2. Starting at the first ``[`` or ``{``, decode exactly one JSON value
         and return just that slice. Surrounding prose, such as a trailing
         "Hope this helps!", is discarded.
      3. If that value does not decode, fall back to the span from the first
         opening bracket to the last closing bracket, so that the caller's
         ``json.loads`` reports a meaningful error.

    Text with no opening bracket is returned stripped but otherwise unchanged.
    """
    text = content.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1).strip()
    start_candidates = [index for index in (text.find("["), text.find("{")) if index >= 0]
    if not start_candidates:
        return text
    start = min(start_candidates)
    candidate = text[start:]
    try:
        # raw_decode stops at the end of the first complete value and ignores
        # whatever follows. Returning the whole text whenever it merely
        # *started* with a bracket made json.loads fail on trailing prose.
        _, length = json.JSONDecoder().raw_decode(candidate)
    except json.JSONDecodeError:
        pass
    else:
        return candidate[:length]
    end = max(text.rfind("]"), text.rfind("}"))
    return text[start : end + 1] if end > start else text[start:]
def get_provider_and_model(provider_id: int, model_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
    """Load a provider and one of its models, ready for a chat call.

    Returns:
        The raw ``ai_providers`` and ``ai_models`` rows (unredacted).

    Raises:
        HTTPException: 404 if the provider does not exist; 400 if it is
            disabled; 400 if the model is missing or belongs to a different
            provider.
    """
    with get_db() as conn:
        provider = conn.execute("SELECT * FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
        model = conn.execute("SELECT * FROM ai_models WHERE id = ?", (model_id,)).fetchone()
    if not provider:
        raise HTTPException(status_code=404, detail="AI provider not found")
    if not provider["enabled"]:
        raise HTTPException(status_code=400, detail="AI provider is disabled")
    model_matches = bool(model) and model["provider_id"] == provider_id
    if not model_matches:
        raise HTTPException(status_code=400, detail="AI model does not belong to this provider")
    return provider, model
def ensure_provider_exists(conn, provider_id: int) -> None:
    """Raise a 404 ``HTTPException`` unless ``ai_providers`` has a row with this id.

    ``conn`` is an open database connection; only ``conn.execute`` is used.
    """
    found = conn.execute("SELECT id FROM ai_providers WHERE id = ?", (provider_id,)).fetchone()
    if found:
        return
    raise HTTPException(status_code=404, detail="AI provider not found")
def get_model_row(conn, model_id: int) -> dict[str, Any]:
    """Fetch one model joined with its provider's ``name`` and ``provider_type``.

    The provider name is exposed as ``provider_name``. The row is returned in
    whatever form ``conn``'s row factory produces.

    Raises:
        HTTPException: 404 if no model with ``model_id`` joins to a provider.
    """
    query = """
        SELECT m.*, p.name AS provider_name, p.provider_type
        FROM ai_models m
        JOIN ai_providers p ON p.id = m.provider_id
        WHERE m.id = ?
    """
    found = conn.execute(query, (model_id,)).fetchone()
    if found:
        return found
    raise HTTPException(status_code=404, detail="AI model not found")
def clear_other_default_models(conn, provider_id: int, model_id: int) -> None:
    """Unset ``is_default`` on every model of ``provider_id`` except ``model_id``.

    ``updated_at`` is refreshed on every matched row, including rows that
    were already non-default.
    """
    stamp = now_iso()
    sql = "UPDATE ai_models SET is_default = 0, updated_at = ? WHERE provider_id = ? AND id <> ?"
    conn.execute(sql, (stamp, provider_id, model_id))
- def clean_optional(value: str | None) -> str | None:
- if value is None:
- return None
- stripped = value.strip()
- return stripped or None
|