from __future__ import annotations import base64 import json import mimetypes import re import sqlite3 import time import uuid from pathlib import Path from typing import Any import psutil from fastapi import HTTPException from . import ai_service, settings_service, windows_automation from .automation import get_node_definitions, get_node_executor from .automation.context import WorkflowContext, WorkflowPaused from .database import DATA_DIR, get_db from .scanner import now_iso from .schemas import ( AutomationKeyboardActionRequest, AutomationMouseActionRequest, AutomationElementLocateRequest, AutomationScreenshotCaptureRequest, AutomationStartProgramRequest, AutomationTextInputRequest, AutomationVisionAnalyzeRequest, AutomationWorkflowRunRequest, AutomationWorkflowSaveRequest, AutomationWorkflowPlanRequest, AutomationWorkflowPlanContinueRequest, ) AUTOMATION_DIR = DATA_DIR / "automation" SCREEN_DIR = AUTOMATION_DIR / "screens" ERROR_DIR = AUTOMATION_DIR / "errors" RUNTIME_DIR = AUTOMATION_DIR / "runtime" OPENED_PROCESS_IDS: set[int] = set() SCREEN_ANALYZE_PROMPT = """请作为 AI 视觉自动化助手分析这张 Windows 屏幕截图,并严格只输出 JSON 对象。 输出字段: - interface_name:界面名称,简洁中文。 - description:界面描述,说明当前主要窗口或桌面内容。 - is_windows_desktop:boolean,截图是否处于 Windows 桌面。 - is_browser_webpage:boolean,截图是否为浏览器中的网页。 - elements:可操作元素数组。 元素字段: - name:元素名称。 - approximate_location:元素在界面中的大致位置文字描述,例如“窗口右上角”“左侧导航栏中部”“底部任务栏靠左”。不要输出具体坐标或百分比。 判断规则: 1. 如果截图位于 Windows 桌面,请识别桌面图标、开始菜单入口、任务栏应用、托盘区域等可操作元素。 2. 如果不是 Windows 桌面,也就是存在打开的前台窗口或全屏界面,只识别该前台窗口内的可操作元素,不要识别被遮挡的桌面元素。 3. 不要输出 Markdown,不要解释,只输出 JSON。 """ ELEMENT_LOCATE_PROMPT = """请作为 AI 视觉定位助手,在这张 Windows 屏幕截图中查找一个具体的可操作元素。 目标元素名称: {name} 目标元素大致位置描述: {approximate_location} 所在界面描述: {screen_description} 请严格只输出 JSON 对象,字段为: - has_element:boolean,图片中是否能找到该目标元素。 - x_percent:元素中心点 X 相对整张截图宽度的百分比,范围 0-100,可以保留 2 位小数。找不到时为 null。 - y_percent:元素中心点 Y 相对整张截图高度的百分比,范围 0-100,可以保留 2 位小数。找不到时为 null。 - reason:简短中文原因。 只定位这个目标元素,不要列出其他元素。不要输出 Markdown,不要解释,只输出 JSON。 """ SCREEN_COMPARE_PROMPT = """请作为 AI 视觉自动化校验器判断两张截图是否处于同一个目标界面。 图片1是当前实际屏幕截图。图片2是数据库中保存的目标界面截图。 目标界面描述如下: {description} 请严格只输出 JSON 对象,字段为: - is_match:boolean,图片1是否仍然处于目标界面。 - similarity:0 到 1 的数值,表示相似度。 - reason:简短中文原因。 判断时可以允许小的光标位置、时间、列表内容滚动或轻微刷新差异,但如果前台窗口、网页、弹窗、主要页面或应用已经不同,应返回 false。 """ def ensure_dirs() -> None: """确保自动化截图、错误截图和运行时目录存在。""" for path in [screen_dir(), error_dir(), runtime_dir()]: path.mkdir(parents=True, exist_ok=True) def screen_dir() -> Path: """根据系统设置获取已识别界面截图目录。""" return settings_service.resolve_data_path("automation_screen_path", "automation/screens") def error_dir() -> Path: """根据系统设置获取错误截图目录。""" return settings_service.resolve_data_path("automation_error_path", "automation/errors") def runtime_dir() -> Path: """根据系统设置获取临时截图目录。""" return settings_service.resolve_data_path("automation_runtime_path", "automation/runtime") def image_to_base64(path: str | Path) -> dict[str, str]: """读取图片文件并转为 AI 服务可接收的 base64 结构。""" file_path = stored_path(path) mime_type = mimetypes.guess_type(file_path.name)[0] or "image/png" return { "base64": base64.b64encode(file_path.read_bytes()).decode("ascii"), "mime_type": mime_type, } def json_from_ai(content: str) -> dict[str, Any]: """从 AI 输出中提取 JSON 对象,兼容模型误加代码块的情况。""" parsed = json.loads(ai_service.extract_json_text(content)) if not isinstance(parsed, dict): raise ValueError("AI output must be a JSON object") return parsed def take_screenshot_file(folder: Path, prefix: str) -> dict[str, Any]: """截取当前屏幕并保存为 PNG 文件,同时返回 base64 和分辨率信息。""" ensure_dirs() filename = f"{prefix}_{int(time.time() * 1000)}.png" path = folder / filename result = windows_automation.take_screenshot(str(path), include_base64=True) result["path"] = str(path) result["db_path"] = data_relative_path(path) return result def data_relative_path(path: str | Path) -> str: """把 data 目录下的文件路径转换为数据库保存用的相对路径。""" file_path = Path(path).resolve() try: return file_path.relative_to(DATA_DIR.resolve()).as_posix() except ValueError: return str(file_path) def stored_path(path: str | Path) -> Path: """把数据库中的相对路径还原成真实文件路径,同时兼容旧的绝对路径。""" file_path = Path(path) if file_path.is_absolute(): return file_path return (DATA_DIR / file_path).resolve() def resolve_ai_params( provider_id: int | None, model_id: int | None, temperature: float | None, ) -> tuple[int, int, float]: """合并请求参数和系统默认 AI 参数。""" defaults = settings_service.default_ai_params() resolved_provider = provider_id or defaults.get("provider_id") resolved_model = model_id or defaults.get("model_id") resolved_temperature = temperature if temperature is not None else defaults.get("temperature", 0.1) if not resolved_provider or not resolved_model: raise HTTPException(status_code=400, detail="AI provider and model are required. Configure system defaults or pass them explicitly.") return int(resolved_provider), int(resolved_model), float(resolved_temperature) def capture_screenshot(payload: AutomationScreenshotCaptureRequest) -> dict[str, Any]: """截取当前屏幕并返回给前端显示,不进行 AI 分析。""" if payload.save: screenshot = take_screenshot_file(runtime_dir(), "manual_screenshot") else: screenshot = windows_automation.take_screenshot(None, include_base64=True) screenshot["path"] = None screenshot["db_path"] = None return { "width": screenshot["width"], "height": screenshot["height"], "image_base64": screenshot["image_base64"], "mime_type": screenshot["mime_type"], "path": screenshot.get("db_path"), } def analyze_screen(payload: AutomationVisionAnalyzeRequest) -> dict[str, Any]: """截图当前屏幕,调用 AI 识别界面和可操作元素,并保存识别结果。""" provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature) screenshot = take_screenshot_file(screen_dir(), "screen") image = image_to_base64(screenshot["path"]) ai_result = ai_service.chat_with_images( provider_id, model_id, SCREEN_ANALYZE_PROMPT, [image], temperature, ) try: parsed = json_from_ai(ai_result["content"]) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException(status_code=502, detail=f"AI vision output is not valid JSON: {exc}") from exc width = int(screenshot["width"]) height = int(screenshot["height"]) elements = normalize_elements(parsed.get("elements"), width, height) now = now_iso() with get_db() as conn: cursor = conn.execute( """ INSERT INTO automation_screens ( interface_name, description, image_path, width, height, is_windows_desktop, is_browser_webpage, raw_ai_json, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( str(parsed.get("interface_name") or "未命名界面")[:160], parsed.get("description"), screenshot["db_path"], width, height, 1 if bool(parsed.get("is_windows_desktop")) else 0, 1 if bool(parsed.get("is_browser_webpage")) else 0, json.dumps(parsed, ensure_ascii=False), now, now, ), ) screen_id = cursor.lastrowid for index, element in enumerate(elements, start=1): conn.execute( """ INSERT INTO automation_screen_elements ( screen_id, element_index, name, x_percent, y_percent, x, y, approximate_location, is_located, raw_json, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( screen_id, index, element["name"], element["x_percent"], element["y_percent"], element["x"], element["y"], element["approximate_location"], 1 if element["is_located"] else 0, json.dumps(element.get("raw") or element, ensure_ascii=False), now, ), ) detail = get_screen(screen_id) detail["image_base64"] = screenshot["image_base64"] detail["mime_type"] = screenshot["mime_type"] detail["ai_raw_content"] = ai_result["content"] return detail def normalize_elements(raw_elements: Any, width: int, height: int) -> list[dict[str, Any]]: """规范化 AI 返回的可操作元素清单;初始分析阶段不要求坐标。""" if not isinstance(raw_elements, list): return [] result = [] for item in raw_elements: if not isinstance(item, dict): continue name = str(item.get("name") or f"元素 {len(result) + 1}")[:160] approximate_location = str(item.get("approximate_location") or item.get("location") or "未定位")[:300] x_percent = normalize_percent(item.get("x_percent")) if item.get("x_percent") is not None else 0.0 y_percent = normalize_percent(item.get("y_percent")) if item.get("y_percent") is not None else 0.0 is_located = item.get("x_percent") is not None and item.get("y_percent") is not None x = round(width * x_percent / 100) y = round(height * y_percent / 100) result.append( { "name": name, "x_percent": x_percent, "y_percent": y_percent, "x": max(0, min(width - 1, x)), "y": max(0, min(height - 1, y)), "approximate_location": approximate_location, "is_located": is_located, "raw": item, } ) return result def locate_element(screen_id: int, element_id: int, payload: AutomationElementLocateRequest) -> dict[str, Any]: """针对单个可操作元素调用 AI 精确定位,并更新该元素的像素坐标。""" provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature) screen = get_screen(screen_id) element = next((item for item in screen.get("elements", []) if item["id"] == element_id), None) if not element: raise HTTPException(status_code=404, detail="Automation screen element not found") prompt = ( ELEMENT_LOCATE_PROMPT .replace("{name}", element.get("name") or "") .replace("{approximate_location}", element.get("approximate_location") or "") .replace("{screen_description}", screen.get("description") or screen.get("interface_name") or "") ) ai_result = ai_service.chat_with_images( provider_id, model_id, prompt, [image_to_base64(screen["image_path"])], temperature, ) try: parsed = json_from_ai(ai_result["content"]) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException(status_code=502, detail=f"AI locate output is not valid JSON: {exc}") from exc if not bool(parsed.get("has_element")) or parsed.get("x_percent") is None or parsed.get("y_percent") is None: return {"located": False, "element": element, "ai_result": parsed, "ai_raw_content": ai_result["content"]} x_percent = normalize_percent(parsed.get("x_percent")) y_percent = normalize_percent(parsed.get("y_percent")) x = max(0, min(int(screen["width"]) - 1, round(int(screen["width"]) * x_percent / 100))) y = max(0, min(int(screen["height"]) - 1, round(int(screen["height"]) * y_percent / 100))) raw = {**parsed, "previous": element.get("raw_json")} with get_db() as conn: conn.execute( """ UPDATE automation_screen_elements SET x_percent = ?, y_percent = ?, x = ?, y = ?, is_located = 1, raw_json = ? WHERE id = ? AND screen_id = ? """, (x_percent, y_percent, x, y, json.dumps(raw, ensure_ascii=False), element_id, screen_id), ) updated = get_screen(screen_id, include_image=True) updated_element = next(item for item in updated["elements"] if item["id"] == element_id) return { "located": True, "element": updated_element, "screen": updated, "ai_result": parsed, "ai_raw_content": ai_result["content"], } def normalize_percent(value: Any) -> float: """规范化百分比数值,兼容模型偶尔输出 0-1 小数的情况。""" try: number = float(value) except (TypeError, ValueError): return 0.0 if 0 <= number <= 1: number *= 100 return max(0.0, min(100.0, round(number, 2))) def list_screens(page: int, page_size: int) -> dict[str, Any]: """分页查询已识别界面列表。""" offset = (page - 1) * page_size with get_db() as conn: total = conn.execute("SELECT COUNT(*) AS total FROM automation_screens").fetchone()["total"] rows = conn.execute( """ SELECT s.*, COUNT(e.id) AS element_count FROM automation_screens s LEFT JOIN automation_screen_elements e ON e.screen_id = s.id GROUP BY s.id ORDER BY s.created_at DESC LIMIT ? OFFSET ? """, (page_size, offset), ).fetchall() return {"items": [public_screen(row) for row in rows], "total": total, "page": page, "page_size": page_size} def get_screen(screen_id: int, include_image: bool = False) -> dict[str, Any]: """读取单个已识别界面的详情和可操作元素。""" with get_db() as conn: screen = conn.execute("SELECT * FROM automation_screens WHERE id = ?", (screen_id,)).fetchone() if not screen: raise HTTPException(status_code=404, detail="Automation screen not found") elements = conn.execute( "SELECT * FROM automation_screen_elements WHERE screen_id = ? ORDER BY element_index ASC", (screen_id,), ).fetchall() item = public_screen(screen) item["elements"] = [public_element(row) for row in elements] if include_image and stored_path(item["image_path"]).exists(): image = image_to_base64(item["image_path"]) item["image_base64"] = image["base64"] item["mime_type"] = image["mime_type"] return item def delete_screen(screen_id: int) -> dict[str, Any]: """删除已识别界面记录,图片文件保留用于审计。""" with get_db() as conn: cursor = conn.execute("DELETE FROM automation_screens WHERE id = ?", (screen_id,)) if cursor.rowcount == 0: raise HTTPException(status_code=404, detail="Automation screen not found") return {"deleted": cursor.rowcount} def public_screen(row: dict[str, Any]) -> dict[str, Any]: """把数据库中的界面行转换为接口返回格式。""" item = dict(row) item["is_windows_desktop"] = bool(item.get("is_windows_desktop")) item["is_browser_webpage"] = bool(item.get("is_browser_webpage")) return item def public_element(row: dict[str, Any]) -> dict[str, Any]: """把数据库中的元素行转换为接口返回格式。""" item = dict(row) item["is_located"] = bool(item.get("is_located")) return item def process_snapshot() -> dict[int, dict[str, Any]]: """获取当前进程快照,只用于自动化动作前后对比,不写入进程扫描表。""" snapshot: dict[int, dict[str, Any]] = {} for proc in psutil.process_iter(["pid", "name", "exe"]): try: snapshot[int(proc.info["pid"])] = { "pid": int(proc.info["pid"]), "name": proc.info.get("name"), "exe": proc.info.get("exe"), } except (psutil.Error, OSError, TypeError, ValueError): continue return snapshot def diff_new_processes(before: dict[int, dict[str, Any]], after: dict[int, dict[str, Any]]) -> list[dict[str, Any]]: """比较动作前后的进程快照,找出本次自动化动作新增的进程。""" new_items = [after[pid] for pid in sorted(set(after) - set(before))] OPENED_PROCESS_IDS.update(item["pid"] for item in new_items) return new_items def validate_screen_before_action( screen_id: int | None, provider_id: int | None, model_id: int | None, temperature: float, action_type: str, workflow_id: int | None = None, node_id: int | None = None, ) -> dict[str, Any] | None: """如果动作绑定了界面 ID,则先用 AI 判断当前屏幕是否仍处于目标界面。""" if screen_id is None: return None provider_id, model_id, temperature = resolve_ai_params(provider_id, model_id, temperature) target = get_screen(screen_id) current = take_screenshot_file(error_dir(), "compare_current") prompt = SCREEN_COMPARE_PROMPT.replace("{description}", target.get("description") or target.get("interface_name") or "") ai_result = ai_service.chat_with_images( provider_id, model_id, prompt, [image_to_base64(current["path"]), image_to_base64(target["image_path"])], temperature, ) try: parsed = json_from_ai(ai_result["content"]) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException(status_code=502, detail=f"AI compare output is not valid JSON: {exc}") from exc is_match = bool(parsed.get("is_match")) similarity = safe_float(parsed.get("similarity")) if not is_match: error = record_error( action_type=action_type, message=str(parsed.get("reason") or "界面对比失败,当前屏幕不是目标界面"), screen_id=screen_id, workflow_id=workflow_id, node_id=node_id, similarity=similarity, expected_image_path=target["image_path"], actual_image_path=current["db_path"], compare_result=parsed, ) raise HTTPException(status_code=409, detail={"message": error["message"], "error": error}) return parsed def safe_float(value: Any) -> float | None: """安全转换浮点数。""" try: return float(value) except (TypeError, ValueError): return None def record_error( action_type: str, message: str, screen_id: int | None = None, workflow_id: int | None = None, node_id: int | None = None, similarity: float | None = None, expected_image_path: str | None = None, actual_image_path: str | None = None, compare_result: dict[str, Any] | None = None, ) -> dict[str, Any]: """保存自动化错误记录,便于在错误记录菜单中回看。""" now = now_iso() with get_db() as conn: cursor = conn.execute( """ INSERT INTO automation_errors ( workflow_id, node_id, screen_id, action_type, message, similarity, expected_image_path, actual_image_path, compare_result_json, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( workflow_id, node_id, screen_id, action_type, message, similarity, expected_image_path, actual_image_path, json.dumps(compare_result or {}, ensure_ascii=False), now, ), ) row = conn.execute("SELECT * FROM automation_errors WHERE id = ?", (cursor.lastrowid,)).fetchone() return public_error(row) def execute_mouse_action(payload: AutomationMouseActionRequest) -> dict[str, Any]: """执行鼠标点击类动作,并记录动作前后新增进程。""" before = process_snapshot() compare = validate_screen_before_action( payload.screen_id, payload.provider_id, payload.model_id, payload.temperature, f"mouse_{payload.mouse_action}", payload.workflow_id, payload.node_id, ) action_map = {"click": "click", "double_click": "double_click", "right_click": "right_click"} result = windows_automation.mouse_action(action_map[payload.mouse_action], x=payload.x, y=payload.y) time.sleep(0.5) new_processes = diff_new_processes(before, process_snapshot()) return {"result": result, "compare": compare, "new_processes": new_processes} def execute_keyboard_action(payload: AutomationKeyboardActionRequest) -> dict[str, Any]: """执行键盘组合键动作,并记录动作前后新增进程。""" before = process_snapshot() compare = validate_screen_before_action( payload.screen_id, payload.provider_id, payload.model_id, payload.temperature, "keyboard", payload.workflow_id, payload.node_id, ) result = windows_automation.keyboard_action("hotkey" if len(payload.keys) > 1 else "press", key=payload.keys[0], keys=payload.keys) time.sleep(0.5) new_processes = diff_new_processes(before, process_snapshot()) return {"result": result, "compare": compare, "new_processes": new_processes} def execute_text_input(payload: AutomationTextInputRequest) -> dict[str, Any]: """通过剪贴板粘贴文本,避免直接模拟按键时中文输入不稳定。""" before = process_snapshot() compare = validate_screen_before_action( payload.screen_id, payload.provider_id, payload.model_id, payload.temperature, "text_input", payload.workflow_id, payload.node_id, ) try: import pyperclip except ImportError as exc: raise HTTPException(status_code=500, detail="pyperclip is not installed") from exc pyperclip.copy(payload.text) result = windows_automation.keyboard_action("hotkey", keys=["ctrl", "v"]) time.sleep(0.5) new_processes = diff_new_processes(before, process_snapshot()) return {"result": result, "compare": compare, "new_processes": new_processes} def execute_start_program(payload: AutomationStartProgramRequest) -> dict[str, Any]: """启动程序,并把动作后新增的进程记录为本次自动化打开的程序。""" before = process_snapshot() compare = validate_screen_before_action( payload.screen_id, payload.provider_id, payload.model_id, payload.temperature, "start_program", payload.workflow_id, payload.node_id, ) result = windows_automation.start_program(payload.command, payload.cwd, payload.shell) time.sleep(1) new_processes = diff_new_processes(before, process_snapshot()) if result.get("pid"): OPENED_PROCESS_IDS.add(int(result["pid"])) return {"result": result, "compare": compare, "new_processes": new_processes} def close_opened_programs(pids: list[int] | None = None) -> dict[str, Any]: """关闭本次自动化过程中记录的新进程。""" targets = sorted(set(pids or list(OPENED_PROCESS_IDS))) closed = [] for pid in targets: try: closed.append(windows_automation.stop_program(pid=pid)) OPENED_PROCESS_IDS.discard(pid) except HTTPException as exc: closed.append({"pid": pid, "error": exc.detail}) return {"action": "close_opened_programs", "items": closed} def save_workflow(payload: AutomationWorkflowSaveRequest) -> dict[str, Any]: """保存 workflow/v1 工作流图。""" now = now_iso() workflow_json = normalize_workflow_payload(payload) workflow_key = normalize_workflow_key(payload.workflow_key) try: with get_db() as conn: cursor = conn.execute( """ INSERT INTO automation_workflows (workflow_key, name, description, raw_json, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?) """, (workflow_key, payload.name.strip(), payload.description, json.dumps(workflow_json, ensure_ascii=False), now, now), ) workflow_id = cursor.lastrowid conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,)) except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="Workflow key already exists") from exc return get_workflow(workflow_id) def update_workflow(workflow_id: int, payload: AutomationWorkflowSaveRequest) -> dict[str, Any]: """更新 workflow/v1 工作流图。""" now = now_iso() workflow_json = normalize_workflow_payload(payload) workflow_key = normalize_workflow_key(payload.workflow_key) try: with get_db() as conn: existing = conn.execute("SELECT id FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone() if not existing: raise HTTPException(status_code=404, detail="Automation workflow not found") conn.execute( """ UPDATE automation_workflows SET workflow_key = ?, name = ?, description = ?, raw_json = ?, updated_at = ? WHERE id = ? """, (workflow_key, payload.name.strip(), payload.description, json.dumps(workflow_json, ensure_ascii=False), now, workflow_id), ) conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,)) except sqlite3.IntegrityError as exc: raise HTTPException(status_code=409, detail="Workflow key already exists") from exc return get_workflow(workflow_id) def normalize_workflow_payload(payload: AutomationWorkflowSaveRequest) -> dict[str, Any]: """把请求模型转换为持久化的 workflow/v1 JSON。""" workflow_json = payload.model_dump() workflow_json["schema_version"] = "workflow/v1" workflow_json["workflow_key"] = normalize_workflow_key(payload.workflow_key) workflow_json["name"] = payload.name.strip() workflow_json.setdefault("variables", {}) workflow_json.setdefault("settings", {}) workflow_json.setdefault("edges", []) return workflow_json def normalize_workflow_key(value: str | None) -> str | None: key = (value or "").strip() if not key: return None if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_-]*", key): raise HTTPException(status_code=400, detail="Workflow key can only contain letters, numbers, underscores, and hyphens") return key def list_workflows(page: int, page_size: int) -> dict[str, Any]: """分页查询自动化工作流列表。""" offset = (page - 1) * page_size with get_db() as conn: total = conn.execute("SELECT COUNT(*) AS total FROM automation_workflows").fetchone()["total"] rows = conn.execute( """ SELECT * FROM automation_workflows ORDER BY updated_at DESC LIMIT ? OFFSET ? """, (page_size, offset), ).fetchall() return {"items": [workflow_summary(row) for row in rows], "total": total, "page": page, "page_size": page_size} def get_workflow(workflow_id: int) -> dict[str, Any]: """读取 workflow/v1 工作流详情。""" with get_db() as conn: workflow = conn.execute("SELECT * FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone() if not workflow: raise HTTPException(status_code=404, detail="Automation workflow not found") item = workflow_to_public(workflow) return item def get_workflow_by_key(workflow_key: str) -> dict[str, Any]: """按稳定 key 读取 workflow/v1 工作流详情。""" key = normalize_workflow_key(workflow_key) if not key: raise HTTPException(status_code=400, detail="Workflow key is required") with get_db() as conn: workflow = conn.execute("SELECT * FROM automation_workflows WHERE workflow_key = ?", (key,)).fetchone() if not workflow: raise HTTPException(status_code=404, detail="Automation workflow not found") return workflow_to_public(workflow) def delete_workflow(workflow_id: int) -> dict[str, Any]: """删除工作流及其节点。""" with get_db() as conn: cursor = conn.execute("DELETE FROM automation_workflows WHERE id = ?", (workflow_id,)) if cursor.rowcount == 0: raise HTTPException(status_code=404, detail="Automation workflow not found") return {"deleted": cursor.rowcount} def run_workflow(workflow_id: int, payload: AutomationWorkflowRunRequest) -> dict[str, Any]: """执行 workflow/v1 工作流图。""" workflow = get_workflow(workflow_id) defaults = settings_service.default_ai_params() provider_id = payload.provider_id or defaults.get("provider_id") model_id = payload.model_id or defaults.get("model_id") temperature = payload.temperature if payload.temperature is not None else defaults.get("temperature", 0.1) context = WorkflowContext( workflow_id=workflow_id, provider_id=provider_id, model_id=model_id, temperature=float(temperature), variables=workflow_variables(workflow, payload.variables), ) nodes = workflow.get("nodes") or [] edges = workflow.get("edges") or [] node_map = {node["id"]: node for node in nodes} start_id = first_workflow_node_id(nodes, edges) if not start_id: return {"workflow_id": workflow_id, "status": "SUCCESS", "results": []} results: list[dict[str, Any]] = [] current_id: str | None = start_id visited_steps = 0 max_steps = int(workflow.get("settings", {}).get("max_steps") or 100) while current_id and visited_steps < max_steps: visited_steps += 1 node = node_map.get(current_id) if not node: return {"workflow_id": workflow_id, "status": "FAILED", "detail": f"Missing node: {current_id}", "results": results} try: resolved_inputs = resolve_node_inputs(node, edges, context) outputs = execute_workflow_node(node, resolved_inputs, context) context.outputs[node["id"]] = outputs results.append({"node_id": node["id"], "node": node, "status": "SUCCESS", "inputs": resolved_inputs, "outputs": outputs}) if node.get("type") == "flow.end": return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results, "outputs": context.outputs} next_port = str(outputs.get("next_port") or "success") current_id = next_control_node_id(node["id"], next_port, edges) or next_control_node_id(node["id"], "next", edges) except HTTPException as exc: failure = { "node_id": node.get("id"), "node": node, "status": "FAILED", "detail": exc.detail, "artifacts": capture_failure_artifacts(context), } results.append(failure) return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results} except WorkflowPaused as exc: paused = {"node_id": node.get("id"), "node": node, "status": "PAUSED", "detail": exc.payload} results.append(paused) return {"workflow_id": workflow_id, "status": "PAUSED", "paused": paused, "results": results} except Exception as exc: failure = { "node_id": node.get("id"), "node": node, "status": "FAILED", "detail": str(exc), "artifacts": capture_failure_artifacts(context), } results.append(failure) return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results} if visited_steps >= max_steps: return {"workflow_id": workflow_id, "status": "FAILED", "detail": f"Workflow exceeded max_steps={max_steps}", "results": results} return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results, "outputs": context.outputs} def run_workflow_by_key(workflow_key: str, payload: AutomationWorkflowRunRequest) -> dict[str, Any]: workflow = get_workflow_by_key(workflow_key) return run_workflow(int(workflow["id"]), payload) def execute_workflow_node( node: dict[str, Any], inputs: dict[str, Any], context: WorkflowContext, ) -> dict[str, Any]: """通过节点注册表执行 workflow/v1 节点。""" try: executor = get_node_executor(str(node.get("type") or "")) except KeyError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc return executor(node, inputs, context) def capture_failure_artifacts(context: WorkflowContext) -> dict[str, Any]: """工作流失败时尽量保存一张当前屏幕截图,供前端询问用户。""" artifacts: dict[str, Any] = {} try: screenshot = take_screenshot_file(error_dir(), "workflow_failure") except Exception as exc: artifacts["screenshot_error"] = str(exc) return artifacts artifacts["screenshot_path"] = screenshot.get("db_path") or screenshot.get("path") artifacts["width"] = screenshot.get("width") artifacts["height"] = screenshot.get("height") context.runtime["current_screenshot_path"] = artifacts["screenshot_path"] return artifacts def workflow_to_public(row: dict[str, Any]) -> dict[str, Any]: item = workflow_summary(row) workflow_json = parse_workflow_json(row.get("raw_json")) item.update(workflow_json) item["id"] = row["id"] item["workflow_key"] = row.get("workflow_key") or workflow_json.get("workflow_key") item["created_at"] = row["created_at"] item["updated_at"] = row["updated_at"] item["node_count"] = len(item.get("nodes") or []) item["edge_count"] = len(item.get("edges") or []) return item def workflow_summary(row: dict[str, Any]) -> dict[str, Any]: workflow_json = parse_workflow_json(row.get("raw_json")) return { "id": row["id"], "workflow_key": row.get("workflow_key") or workflow_json.get("workflow_key"), "name": row["name"], "description": row.get("description"), "schema_version": workflow_json.get("schema_version") or "workflow/v1", "node_count": len(workflow_json.get("nodes") or []), "edge_count": len(workflow_json.get("edges") or []), "created_at": row.get("created_at"), "updated_at": row.get("updated_at"), } def parse_workflow_json(raw_json: str | None) -> dict[str, Any]: if not raw_json: return empty_workflow_json() try: parsed = json.loads(raw_json) except json.JSONDecodeError: return empty_workflow_json() if not isinstance(parsed, dict): return empty_workflow_json() parsed.setdefault("schema_version", "workflow/v1") parsed.setdefault("variables", {}) parsed.setdefault("settings", {}) parsed.setdefault("nodes", []) parsed.setdefault("edges", []) return parsed def empty_workflow_json() -> dict[str, Any]: return {"schema_version": "workflow/v1", "variables": {}, "settings": {}, "nodes": [], "edges": []} def workflow_variables(workflow: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]: variables: dict[str, Any] = {} for name, definition in (workflow.get("variables") or {}).items(): if isinstance(definition, dict): variables[name] = definition.get("default") else: variables[name] = definition variables.update(overrides or {}) return variables def first_workflow_node_id(nodes: list[dict[str, Any]], edges: list[dict[str, Any]]) -> str | None: if not nodes: return None for node in nodes: if node.get("type") == "flow.start": return node.get("id") targeted = {edge.get("target") for edge in edges if edge.get("kind") == "control"} for node in nodes: if node.get("id") not in targeted: return node.get("id") return nodes[0].get("id") def next_control_node_id(source_id: str, source_port: str, edges: list[dict[str, Any]]) -> str | None: fallback = None for edge in edges: if edge.get("kind") != "control" or edge.get("source") != source_id: continue if edge.get("source_port") == source_port: return edge.get("target") if edge.get("source_port") in (None, "", "success", "next") and fallback is None: fallback = edge.get("target") return fallback def resolve_node_inputs(node: dict[str, Any], edges: list[dict[str, Any]], context: WorkflowContext) -> dict[str, Any]: resolved: dict[str, Any] = {} for key, value in (node.get("inputs") or {}).items(): resolved[key] = resolve_value_ref(value, context) for edge in edges: if edge.get("kind") != "data" or edge.get("target") != node.get("id"): continue source_outputs = context.outputs.get(edge.get("source") or "", {}) resolved[edge.get("target_port") or "value"] = source_outputs.get(edge.get("source_port") or "value") return resolved def resolve_value_ref(value: Any, context: WorkflowContext) -> Any: if not isinstance(value, dict) or "source" not in value: return value source = value.get("source") if source == "literal": return value.get("value") if source == "variable": return context.variables.get(value.get("name") or "") if source == "node_output": return context.outputs.get(value.get("node_id") or "", {}).get(value.get("output") or "") if source == "runtime": return context.runtime.get(value.get("name") or "") return None def list_workflow_node_definitions() -> dict[str, Any]: """返回前端可用于生成节点库和属性表单的节点定义。""" return {"schema_version": "workflow/v1", "items": get_node_definitions()} def plan_workflow(payload: AutomationWorkflowPlanRequest) -> dict[str, Any]: """让 AI 根据用户需求和节点定义生成 workflow/v1 草稿。""" provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature) prompt = build_workflow_plan_prompt(payload.requirement) ai_result = ai_service.chat(provider_id, model_id, prompt, temperature) try: parsed = json_from_ai(ai_result["content"]) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException(status_code=502, detail=f"AI workflow plan output is not valid JSON: {exc}") from exc session_id = str(uuid.uuid4()) return {"session_id": session_id, "plan": parsed, "ai_raw_content": ai_result["content"]} def continue_workflow_plan(payload: AutomationWorkflowPlanContinueRequest) -> dict[str, Any]: """继续一次 AI 工作流规划对话,返回新的草稿建议。""" provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature) prompt = build_workflow_plan_prompt(payload.user_message, session_id=payload.session_id) ai_result = ai_service.chat(provider_id, model_id, prompt, temperature) try: parsed = json_from_ai(ai_result["content"]) except (json.JSONDecodeError, ValueError) as exc: raise HTTPException(status_code=502, detail=f"AI workflow plan output is not valid JSON: {exc}") from exc return {"session_id": payload.session_id, "plan": parsed, "ai_raw_content": ai_result["content"]} def build_workflow_plan_prompt(requirement: str, session_id: str | None = None) -> str: node_defs = json.dumps(get_node_definitions(), ensure_ascii=False, indent=2) return f"""请作为 Windows 自动化工作流规划器,根据用户需求生成 workflow/v1 JSON 草稿。 要求: 1. 只能使用节点定义列表中的 type。 2. 输出严格 JSON 对象,不要 Markdown。 3. JSON 字段必须包含 summary、questions、workflow。 4. workflow 必须包含 schema_version、name、description、variables、settings、nodes、edges。 5. 不确定的坐标或界面状态,优先添加 human.ask_user 节点或 screen.screenshot 节点。 6. 控制流连线 kind 使用 control,数据连线 kind 使用 data。 会话 ID:{session_id or "new"} 用户需求: {requirement} 可用节点定义: {node_defs} """ def list_errors(page: int, page_size: int) -> dict[str, Any]: """分页查询自动化错误记录。""" offset = (page - 1) * page_size with get_db() as conn: total = conn.execute("SELECT COUNT(*) AS total FROM automation_errors").fetchone()["total"] rows = conn.execute( """ SELECT e.*, s.interface_name FROM automation_errors e LEFT JOIN automation_screens s ON s.id = e.screen_id ORDER BY e.created_at DESC LIMIT ? OFFSET ? """, (page_size, offset), ).fetchall() return {"items": [public_error(row) for row in rows], "total": total, "page": page, "page_size": page_size} def get_error(error_id: int, include_images: bool = False) -> dict[str, Any]: """读取单条自动化错误详情,可附带目标截图和实际截图。""" with get_db() as conn: row = conn.execute("SELECT * FROM automation_errors WHERE id = ?", (error_id,)).fetchone() if not row: raise HTTPException(status_code=404, detail="Automation error not found") item = public_error(row) if include_images: for key in ["expected_image_path", "actual_image_path"]: path = item.get(key) if path and stored_path(path).exists(): image = image_to_base64(path) item[key.replace("_path", "_base64")] = image["base64"] item[key.replace("_path", "_mime_type")] = image["mime_type"] return item def public_error(row: dict[str, Any]) -> dict[str, Any]: """把错误记录行转换为接口返回格式。""" item = dict(row) try: item["compare_result"] = json.loads(item.pop("compare_result_json") or "{}") except json.JSONDecodeError: item["compare_result"] = {} return item