|
@@ -10,15 +10,17 @@ from typing import Any
|
|
|
import psutil
|
|
import psutil
|
|
|
from fastapi import HTTPException
|
|
from fastapi import HTTPException
|
|
|
|
|
|
|
|
-from . import ai_service, windows_automation
|
|
|
|
|
|
|
+from . import ai_service, settings_service, windows_automation
|
|
|
from .database import DATA_DIR, get_db
|
|
from .database import DATA_DIR, get_db
|
|
|
from .scanner import now_iso
|
|
from .scanner import now_iso
|
|
|
from .schemas import (
|
|
from .schemas import (
|
|
|
AutomationKeyboardActionRequest,
|
|
AutomationKeyboardActionRequest,
|
|
|
AutomationMouseActionRequest,
|
|
AutomationMouseActionRequest,
|
|
|
|
|
+ AutomationScreenshotCaptureRequest,
|
|
|
AutomationStartProgramRequest,
|
|
AutomationStartProgramRequest,
|
|
|
AutomationTextInputRequest,
|
|
AutomationTextInputRequest,
|
|
|
AutomationVisionAnalyzeRequest,
|
|
AutomationVisionAnalyzeRequest,
|
|
|
|
|
+ AutomationWorkflowRunRequest,
|
|
|
AutomationWorkflowSaveRequest,
|
|
AutomationWorkflowSaveRequest,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -66,13 +68,28 @@ SCREEN_COMPARE_PROMPT = """请作为 AI 视觉自动化校验器判断两张截
|
|
|
|
|
|
|
|
def ensure_dirs() -> None:
|
|
def ensure_dirs() -> None:
|
|
|
"""确保自动化截图、错误截图和运行时目录存在。"""
|
|
"""确保自动化截图、错误截图和运行时目录存在。"""
|
|
|
- for path in [SCREEN_DIR, ERROR_DIR, RUNTIME_DIR]:
|
|
|
|
|
|
|
+ for path in [screen_dir(), error_dir(), runtime_dir()]:
|
|
|
path.mkdir(parents=True, exist_ok=True)
|
|
path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def screen_dir() -> Path:
|
|
|
|
|
+ """根据系统设置获取已识别界面截图目录。"""
|
|
|
|
|
+ return settings_service.resolve_data_path("automation_screen_path", "automation/screens")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def error_dir() -> Path:
|
|
|
|
|
+ """根据系统设置获取错误截图目录。"""
|
|
|
|
|
+ return settings_service.resolve_data_path("automation_error_path", "automation/errors")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def runtime_dir() -> Path:
|
|
|
|
|
+ """根据系统设置获取临时截图目录。"""
|
|
|
|
|
+ return settings_service.resolve_data_path("automation_runtime_path", "automation/runtime")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def image_to_base64(path: str | Path) -> dict[str, str]:
|
|
def image_to_base64(path: str | Path) -> dict[str, str]:
|
|
|
"""读取图片文件并转为 AI 服务可接收的 base64 结构。"""
|
|
"""读取图片文件并转为 AI 服务可接收的 base64 结构。"""
|
|
|
- file_path = Path(path)
|
|
|
|
|
|
|
+ file_path = stored_path(path)
|
|
|
mime_type = mimetypes.guess_type(file_path.name)[0] or "image/png"
|
|
mime_type = mimetypes.guess_type(file_path.name)[0] or "image/png"
|
|
|
return {
|
|
return {
|
|
|
"base64": base64.b64encode(file_path.read_bytes()).decode("ascii"),
|
|
"base64": base64.b64encode(file_path.read_bytes()).decode("ascii"),
|
|
@@ -95,19 +112,70 @@ def take_screenshot_file(folder: Path, prefix: str) -> dict[str, Any]:
|
|
|
path = folder / filename
|
|
path = folder / filename
|
|
|
result = windows_automation.take_screenshot(str(path), include_base64=True)
|
|
result = windows_automation.take_screenshot(str(path), include_base64=True)
|
|
|
result["path"] = str(path)
|
|
result["path"] = str(path)
|
|
|
|
|
+ result["db_path"] = data_relative_path(path)
|
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def data_relative_path(path: str | Path) -> str:
|
|
|
|
|
+ """把 data 目录下的文件路径转换为数据库保存用的相对路径。"""
|
|
|
|
|
+ file_path = Path(path).resolve()
|
|
|
|
|
+ try:
|
|
|
|
|
+ return file_path.relative_to(DATA_DIR.resolve()).as_posix()
|
|
|
|
|
+ except ValueError:
|
|
|
|
|
+ return str(file_path)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def stored_path(path: str | Path) -> Path:
|
|
|
|
|
+ """把数据库中的相对路径还原成真实文件路径,同时兼容旧的绝对路径。"""
|
|
|
|
|
+ file_path = Path(path)
|
|
|
|
|
+ if file_path.is_absolute():
|
|
|
|
|
+ return file_path
|
|
|
|
|
+ return (DATA_DIR / file_path).resolve()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def resolve_ai_params(
|
|
|
|
|
+ provider_id: int | None,
|
|
|
|
|
+ model_id: int | None,
|
|
|
|
|
+ temperature: float | None,
|
|
|
|
|
+) -> tuple[int, int, float]:
|
|
|
|
|
+ """合并请求参数和系统默认 AI 参数。"""
|
|
|
|
|
+ defaults = settings_service.default_ai_params()
|
|
|
|
|
+ resolved_provider = provider_id or defaults.get("provider_id")
|
|
|
|
|
+ resolved_model = model_id or defaults.get("model_id")
|
|
|
|
|
+ resolved_temperature = temperature if temperature is not None else defaults.get("temperature", 0.1)
|
|
|
|
|
+ if not resolved_provider or not resolved_model:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="AI provider and model are required. Configure system defaults or pass them explicitly.")
|
|
|
|
|
+ return int(resolved_provider), int(resolved_model), float(resolved_temperature)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def capture_screenshot(payload: AutomationScreenshotCaptureRequest) -> dict[str, Any]:
|
|
|
|
|
+ """截取当前屏幕并返回给前端显示,不进行 AI 分析。"""
|
|
|
|
|
+ if payload.save:
|
|
|
|
|
+ screenshot = take_screenshot_file(runtime_dir(), "manual_screenshot")
|
|
|
|
|
+ else:
|
|
|
|
|
+ screenshot = windows_automation.take_screenshot(None, include_base64=True)
|
|
|
|
|
+ screenshot["path"] = None
|
|
|
|
|
+ screenshot["db_path"] = None
|
|
|
|
|
+ return {
|
|
|
|
|
+ "width": screenshot["width"],
|
|
|
|
|
+ "height": screenshot["height"],
|
|
|
|
|
+ "image_base64": screenshot["image_base64"],
|
|
|
|
|
+ "mime_type": screenshot["mime_type"],
|
|
|
|
|
+ "path": screenshot.get("db_path"),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def analyze_screen(payload: AutomationVisionAnalyzeRequest) -> dict[str, Any]:
|
|
def analyze_screen(payload: AutomationVisionAnalyzeRequest) -> dict[str, Any]:
|
|
|
"""截图当前屏幕,调用 AI 识别界面和可操作元素,并保存识别结果。"""
|
|
"""截图当前屏幕,调用 AI 识别界面和可操作元素,并保存识别结果。"""
|
|
|
- screenshot = take_screenshot_file(SCREEN_DIR, "screen")
|
|
|
|
|
|
|
+ provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature)
|
|
|
|
|
+ screenshot = take_screenshot_file(screen_dir(), "screen")
|
|
|
image = image_to_base64(screenshot["path"])
|
|
image = image_to_base64(screenshot["path"])
|
|
|
ai_result = ai_service.chat_with_images(
|
|
ai_result = ai_service.chat_with_images(
|
|
|
- payload.provider_id,
|
|
|
|
|
- payload.model_id,
|
|
|
|
|
|
|
+ provider_id,
|
|
|
|
|
+ model_id,
|
|
|
SCREEN_ANALYZE_PROMPT,
|
|
SCREEN_ANALYZE_PROMPT,
|
|
|
[image],
|
|
[image],
|
|
|
- payload.temperature,
|
|
|
|
|
|
|
+ temperature,
|
|
|
)
|
|
)
|
|
|
try:
|
|
try:
|
|
|
parsed = json_from_ai(ai_result["content"])
|
|
parsed = json_from_ai(ai_result["content"])
|
|
@@ -130,7 +198,7 @@ def analyze_screen(payload: AutomationVisionAnalyzeRequest) -> dict[str, Any]:
|
|
|
(
|
|
(
|
|
|
str(parsed.get("interface_name") or "未命名界面")[:160],
|
|
str(parsed.get("interface_name") or "未命名界面")[:160],
|
|
|
parsed.get("description"),
|
|
parsed.get("description"),
|
|
|
- screenshot["path"],
|
|
|
|
|
|
|
+ screenshot["db_path"],
|
|
|
width,
|
|
width,
|
|
|
height,
|
|
height,
|
|
|
1 if bool(parsed.get("is_windows_desktop")) else 0,
|
|
1 if bool(parsed.get("is_windows_desktop")) else 0,
|
|
@@ -236,7 +304,7 @@ def get_screen(screen_id: int, include_image: bool = False) -> dict[str, Any]:
|
|
|
).fetchall()
|
|
).fetchall()
|
|
|
item = public_screen(screen)
|
|
item = public_screen(screen)
|
|
|
item["elements"] = [public_element(row) for row in elements]
|
|
item["elements"] = [public_element(row) for row in elements]
|
|
|
- if include_image and Path(item["image_path"]).exists():
|
|
|
|
|
|
|
+ if include_image and stored_path(item["image_path"]).exists():
|
|
|
image = image_to_base64(item["image_path"])
|
|
image = image_to_base64(item["image_path"])
|
|
|
item["image_base64"] = image["base64"]
|
|
item["image_base64"] = image["base64"]
|
|
|
item["mime_type"] = image["mime_type"]
|
|
item["mime_type"] = image["mime_type"]
|
|
@@ -300,11 +368,10 @@ def validate_screen_before_action(
|
|
|
"""如果动作绑定了界面 ID,则先用 AI 判断当前屏幕是否仍处于目标界面。"""
|
|
"""如果动作绑定了界面 ID,则先用 AI 判断当前屏幕是否仍处于目标界面。"""
|
|
|
if screen_id is None:
|
|
if screen_id is None:
|
|
|
return None
|
|
return None
|
|
|
- if provider_id is None or model_id is None:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="provider_id and model_id are required when screen_id is provided")
|
|
|
|
|
|
|
+ provider_id, model_id, temperature = resolve_ai_params(provider_id, model_id, temperature)
|
|
|
|
|
|
|
|
target = get_screen(screen_id)
|
|
target = get_screen(screen_id)
|
|
|
- current = take_screenshot_file(RUNTIME_DIR, "compare_current")
|
|
|
|
|
|
|
+ current = take_screenshot_file(error_dir(), "compare_current")
|
|
|
prompt = SCREEN_COMPARE_PROMPT.replace("{description}", target.get("description") or target.get("interface_name") or "")
|
|
prompt = SCREEN_COMPARE_PROMPT.replace("{description}", target.get("description") or target.get("interface_name") or "")
|
|
|
ai_result = ai_service.chat_with_images(
|
|
ai_result = ai_service.chat_with_images(
|
|
|
provider_id,
|
|
provider_id,
|
|
@@ -329,7 +396,7 @@ def validate_screen_before_action(
|
|
|
node_id=node_id,
|
|
node_id=node_id,
|
|
|
similarity=similarity,
|
|
similarity=similarity,
|
|
|
expected_image_path=target["image_path"],
|
|
expected_image_path=target["image_path"],
|
|
|
- actual_image_path=current["path"],
|
|
|
|
|
|
|
+ actual_image_path=current["db_path"],
|
|
|
compare_result=parsed,
|
|
compare_result=parsed,
|
|
|
)
|
|
)
|
|
|
raise HTTPException(status_code=409, detail={"message": error["message"], "error": error})
|
|
raise HTTPException(status_code=409, detail={"message": error["message"], "error": error})
|
|
@@ -520,16 +587,21 @@ def insert_workflow_nodes(conn, workflow_id: int, nodes: list[Any], now: str) ->
|
|
|
conn.execute(
|
|
conn.execute(
|
|
|
"""
|
|
"""
|
|
|
INSERT INTO automation_workflow_nodes (
|
|
INSERT INTO automation_workflow_nodes (
|
|
|
- workflow_id, node_index, node_type, screen_id, title, config_json, created_at, updated_at
|
|
|
|
|
|
|
+ workflow_id, node_index, node_key, node_type, screen_id, title,
|
|
|
|
|
+ position_x, position_y, next_node_keys, config_json, created_at, updated_at
|
|
|
)
|
|
)
|
|
|
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
|
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
""",
|
|
""",
|
|
|
(
|
|
(
|
|
|
workflow_id,
|
|
workflow_id,
|
|
|
index,
|
|
index,
|
|
|
|
|
+ node.node_key or f"node_{index}",
|
|
|
node.node_type,
|
|
node.node_type,
|
|
|
node.screen_id,
|
|
node.screen_id,
|
|
|
node.title,
|
|
node.title,
|
|
|
|
|
+ node.position_x,
|
|
|
|
|
+ node.position_y,
|
|
|
|
|
+ json.dumps(node.next_node_keys, ensure_ascii=False),
|
|
|
json.dumps(node.config, ensure_ascii=False),
|
|
json.dumps(node.config, ensure_ascii=False),
|
|
|
now,
|
|
now,
|
|
|
now,
|
|
now,
|
|
@@ -580,6 +652,113 @@ def delete_workflow(workflow_id: int) -> dict[str, Any]:
|
|
|
return {"deleted": cursor.rowcount}
|
|
return {"deleted": cursor.rowcount}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def run_workflow(workflow_id: int, payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
|
|
|
|
|
+ """按数据库中保存的工作流节点和连线顺序在后端执行整个工作流。"""
|
|
|
|
|
+ workflow = get_workflow(workflow_id)
|
|
|
|
|
+ defaults = settings_service.default_ai_params()
|
|
|
|
|
+ provider_id = payload.provider_id or defaults.get("provider_id")
|
|
|
|
|
+ model_id = payload.model_id or defaults.get("model_id")
|
|
|
|
|
+ temperature = payload.temperature if payload.temperature is not None else defaults.get("temperature", 0.1)
|
|
|
|
|
+ nodes = ordered_workflow_nodes(workflow.get("nodes") or [])
|
|
|
|
|
+ results: list[dict[str, Any]] = []
|
|
|
|
|
+ opened_pids: list[int] = []
|
|
|
|
|
+
|
|
|
|
|
+ for node in nodes:
|
|
|
|
|
+ try:
|
|
|
|
|
+ result = execute_workflow_node(workflow_id, node, provider_id, model_id, temperature, opened_pids)
|
|
|
|
|
+ opened_pids.extend(
|
|
|
|
|
+ item["pid"]
|
|
|
|
|
+ for item in result.get("new_processes", [])
|
|
|
|
|
+ if item.get("pid") and item["pid"] not in opened_pids
|
|
|
|
|
+ )
|
|
|
|
|
+ results.append({"node": node, "status": "SUCCESS", "result": result})
|
|
|
|
|
+ except HTTPException as exc:
|
|
|
|
|
+ if not (isinstance(exc.detail, dict) and exc.detail.get("error")):
|
|
|
|
|
+ record_error(
|
|
|
|
|
+ action_type=node.get("node_type") or "workflow",
|
|
|
|
|
+ message=str(exc.detail),
|
|
|
|
|
+ screen_id=node.get("screen_id"),
|
|
|
|
|
+ workflow_id=workflow_id,
|
|
|
|
|
+ node_id=node.get("id"),
|
|
|
|
|
+ )
|
|
|
|
|
+ failure = {"node": node, "status": "FAILED", "detail": exc.detail}
|
|
|
|
|
+ results.append(failure)
|
|
|
|
|
+ return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results}
|
|
|
|
|
+ return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def ordered_workflow_nodes(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
|
|
|
+ """根据节点连线得到执行顺序;没有连线时沿用节点序号。"""
|
|
|
|
|
+ if not nodes:
|
|
|
|
|
+ return []
|
|
|
|
|
+ by_key = {node.get("node_key") or f"node_{node.get('node_index')}": node for node in nodes}
|
|
|
|
|
+ targeted = {key for node in nodes for key in node.get("next_node_keys", [])}
|
|
|
|
|
+ start_keys = [key for key in by_key if key not in targeted] or [next(iter(by_key))]
|
|
|
|
|
+ ordered: list[dict[str, Any]] = []
|
|
|
|
|
+ visited: set[str] = set()
|
|
|
|
|
+
|
|
|
|
|
+ def visit(key: str) -> None:
|
|
|
|
|
+ if key in visited or key not in by_key:
|
|
|
|
|
+ return
|
|
|
|
|
+ visited.add(key)
|
|
|
|
|
+ node = by_key[key]
|
|
|
|
|
+ ordered.append(node)
|
|
|
|
|
+ for next_key in node.get("next_node_keys", []):
|
|
|
|
|
+ visit(next_key)
|
|
|
|
|
+
|
|
|
|
|
+ for key in start_keys:
|
|
|
|
|
+ visit(key)
|
|
|
|
|
+ for key in by_key:
|
|
|
|
|
+ visit(key)
|
|
|
|
|
+ return ordered
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def execute_workflow_node(
|
|
|
|
|
+ workflow_id: int,
|
|
|
|
|
+ node: dict[str, Any],
|
|
|
|
|
+ provider_id: int | None,
|
|
|
|
|
+ model_id: int | None,
|
|
|
|
|
+ temperature: float,
|
|
|
|
|
+ opened_pids: list[int],
|
|
|
|
|
+) -> dict[str, Any]:
|
|
|
|
|
+ """把工作流节点配置转换成已有高层动作并执行。"""
|
|
|
|
|
+ node_type = node.get("node_type")
|
|
|
|
|
+ config = node.get("config") or {}
|
|
|
|
|
+ base = {
|
|
|
|
|
+ "screen_id": node.get("screen_id"),
|
|
|
|
|
+ "provider_id": provider_id,
|
|
|
|
|
+ "model_id": model_id,
|
|
|
|
|
+ "temperature": temperature,
|
|
|
|
|
+ "workflow_id": workflow_id,
|
|
|
|
|
+ "node_id": node.get("id"),
|
|
|
|
|
+ }
|
|
|
|
|
+ if node_type == "mouse":
|
|
|
|
|
+ return execute_mouse_action(
|
|
|
|
|
+ AutomationMouseActionRequest(
|
|
|
|
|
+ **base,
|
|
|
|
|
+ x=int(config.get("x", 0)),
|
|
|
|
|
+ y=int(config.get("y", 0)),
|
|
|
|
|
+ mouse_action=config.get("mouse_action") or "click",
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ if node_type == "keyboard":
|
|
|
|
|
+ return execute_keyboard_action(AutomationKeyboardActionRequest(**base, keys=config.get("keys") or []))
|
|
|
|
|
+ if node_type == "text_input":
|
|
|
|
|
+ return execute_text_input(AutomationTextInputRequest(**base, text=str(config.get("text") or "")))
|
|
|
|
|
+ if node_type == "start_program":
|
|
|
|
|
+ return execute_start_program(
|
|
|
|
|
+ AutomationStartProgramRequest(
|
|
|
|
|
+ **base,
|
|
|
|
|
+ command=str(config.get("command") or ""),
|
|
|
|
|
+ cwd=config.get("cwd"),
|
|
|
|
|
+ shell=bool(config.get("shell", True)),
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ if node_type == "close_programs":
|
|
|
|
|
+ return close_opened_programs(config.get("pids") or opened_pids)
|
|
|
|
|
+ raise HTTPException(status_code=400, detail=f"Unsupported workflow node type: {node_type}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def public_node(row: dict[str, Any]) -> dict[str, Any]:
|
|
def public_node(row: dict[str, Any]) -> dict[str, Any]:
|
|
|
"""把工作流节点行转换为接口返回格式。"""
|
|
"""把工作流节点行转换为接口返回格式。"""
|
|
|
item = dict(row)
|
|
item = dict(row)
|
|
@@ -587,6 +766,10 @@ def public_node(row: dict[str, Any]) -> dict[str, Any]:
|
|
|
item["config"] = json.loads(item.pop("config_json") or "{}")
|
|
item["config"] = json.loads(item.pop("config_json") or "{}")
|
|
|
except json.JSONDecodeError:
|
|
except json.JSONDecodeError:
|
|
|
item["config"] = {}
|
|
item["config"] = {}
|
|
|
|
|
+ try:
|
|
|
|
|
+ item["next_node_keys"] = json.loads(item.get("next_node_keys") or "[]")
|
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
|
+ item["next_node_keys"] = []
|
|
|
return item
|
|
return item
|
|
|
|
|
|
|
|
|
|
|
|
@@ -618,7 +801,7 @@ def get_error(error_id: int, include_images: bool = False) -> dict[str, Any]:
|
|
|
if include_images:
|
|
if include_images:
|
|
|
for key in ["expected_image_path", "actual_image_path"]:
|
|
for key in ["expected_image_path", "actual_image_path"]:
|
|
|
path = item.get(key)
|
|
path = item.get(key)
|
|
|
- if path and Path(path).exists():
|
|
|
|
|
|
|
+ if path and stored_path(path).exists():
|
|
|
image = image_to_base64(path)
|
|
image = image_to_base64(path)
|
|
|
item[key.replace("_path", "_base64")] = image["base64"]
|
|
item[key.replace("_path", "_base64")] = image["base64"]
|
|
|
item[key.replace("_path", "_mime_type")] = image["mime_type"]
|
|
item[key.replace("_path", "_mime_type")] = image["mime_type"]
|