|
|
@@ -3,7 +3,10 @@ from __future__ import annotations
|
|
|
import base64
|
|
|
import json
|
|
|
import mimetypes
|
|
|
+import re
|
|
|
+import sqlite3
|
|
|
import time
|
|
|
+import uuid
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
@@ -11,6 +14,8 @@ import psutil
|
|
|
from fastapi import HTTPException
|
|
|
|
|
|
from . import ai_service, settings_service, windows_automation
|
|
|
+from .automation import get_node_definitions, get_node_executor
|
|
|
+from .automation.context import WorkflowContext, WorkflowPaused
|
|
|
from .database import DATA_DIR, get_db
|
|
|
from .scanner import now_iso
|
|
|
from .schemas import (
|
|
|
@@ -23,6 +28,8 @@ from .schemas import (
|
|
|
AutomationVisionAnalyzeRequest,
|
|
|
AutomationWorkflowRunRequest,
|
|
|
AutomationWorkflowSaveRequest,
|
|
|
+ AutomationWorkflowPlanRequest,
|
|
|
+ AutomationWorkflowPlanContinueRequest,
|
|
|
)
|
|
|
|
|
|
|
|
|
@@ -626,69 +633,69 @@ def close_opened_programs(pids: list[int] | None = None) -> dict[str, Any]:
|
|
|
|
|
|
|
|
|
def save_workflow(payload: AutomationWorkflowSaveRequest) -> dict[str, Any]:
|
|
|
- """保存前端记录或手动编辑的自动化工作流和节点。"""
|
|
|
+ """保存 workflow/v1 工作流图。"""
|
|
|
now = now_iso()
|
|
|
- raw_json = payload.model_dump()
|
|
|
- with get_db() as conn:
|
|
|
- cursor = conn.execute(
|
|
|
- """
|
|
|
- INSERT INTO automation_workflows (name, description, raw_json, created_at, updated_at)
|
|
|
- VALUES (?, ?, ?, ?, ?)
|
|
|
- """,
|
|
|
- (payload.name.strip(), payload.description, json.dumps(raw_json, ensure_ascii=False), now, now),
|
|
|
- )
|
|
|
- workflow_id = cursor.lastrowid
|
|
|
- insert_workflow_nodes(conn, workflow_id, payload.nodes, now)
|
|
|
+ workflow_json = normalize_workflow_payload(payload)
|
|
|
+ workflow_key = normalize_workflow_key(payload.workflow_key)
|
|
|
+ try:
|
|
|
+ with get_db() as conn:
|
|
|
+ cursor = conn.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO automation_workflows (workflow_key, name, description, raw_json, created_at, updated_at)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?)
|
|
|
+ """,
|
|
|
+ (workflow_key, payload.name.strip(), payload.description, json.dumps(workflow_json, ensure_ascii=False), now, now),
|
|
|
+ )
|
|
|
+ workflow_id = cursor.lastrowid
|
|
|
+ conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,))
|
|
|
+ except sqlite3.IntegrityError as exc:
|
|
|
+ raise HTTPException(status_code=409, detail="Workflow key already exists") from exc
|
|
|
return get_workflow(workflow_id)
|
|
|
|
|
|
|
|
|
def update_workflow(workflow_id: int, payload: AutomationWorkflowSaveRequest) -> dict[str, Any]:
|
|
|
- """更新工作流基础信息和节点列表。"""
|
|
|
+ """更新 workflow/v1 工作流图。"""
|
|
|
now = now_iso()
|
|
|
- raw_json = payload.model_dump()
|
|
|
- with get_db() as conn:
|
|
|
- existing = conn.execute("SELECT id FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone()
|
|
|
- if not existing:
|
|
|
- raise HTTPException(status_code=404, detail="Automation workflow not found")
|
|
|
- conn.execute(
|
|
|
- """
|
|
|
- UPDATE automation_workflows
|
|
|
- SET name = ?, description = ?, raw_json = ?, updated_at = ?
|
|
|
- WHERE id = ?
|
|
|
- """,
|
|
|
- (payload.name.strip(), payload.description, json.dumps(raw_json, ensure_ascii=False), now, workflow_id),
|
|
|
- )
|
|
|
- conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,))
|
|
|
- insert_workflow_nodes(conn, workflow_id, payload.nodes, now)
|
|
|
+ workflow_json = normalize_workflow_payload(payload)
|
|
|
+ workflow_key = normalize_workflow_key(payload.workflow_key)
|
|
|
+ try:
|
|
|
+ with get_db() as conn:
|
|
|
+ existing = conn.execute("SELECT id FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone()
|
|
|
+ if not existing:
|
|
|
+ raise HTTPException(status_code=404, detail="Automation workflow not found")
|
|
|
+ conn.execute(
|
|
|
+ """
|
|
|
+ UPDATE automation_workflows
|
|
|
+ SET workflow_key = ?, name = ?, description = ?, raw_json = ?, updated_at = ?
|
|
|
+ WHERE id = ?
|
|
|
+ """,
|
|
|
+ (workflow_key, payload.name.strip(), payload.description, json.dumps(workflow_json, ensure_ascii=False), now, workflow_id),
|
|
|
+ )
|
|
|
+ conn.execute("DELETE FROM automation_workflow_nodes WHERE workflow_id = ?", (workflow_id,))
|
|
|
+ except sqlite3.IntegrityError as exc:
|
|
|
+ raise HTTPException(status_code=409, detail="Workflow key already exists") from exc
|
|
|
return get_workflow(workflow_id)
|
|
|
|
|
|
|
|
|
-def insert_workflow_nodes(conn, workflow_id: int, nodes: list[Any], now: str) -> None:
|
|
|
- """批量写入工作流节点。"""
|
|
|
- for index, node in enumerate(nodes, start=1):
|
|
|
- conn.execute(
|
|
|
- """
|
|
|
- INSERT INTO automation_workflow_nodes (
|
|
|
- workflow_id, node_index, node_key, node_type, screen_id, title,
|
|
|
- position_x, position_y, next_node_keys, config_json, created_at, updated_at
|
|
|
- )
|
|
|
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
- """,
|
|
|
- (
|
|
|
- workflow_id,
|
|
|
- index,
|
|
|
- node.node_key or f"node_{index}",
|
|
|
- node.node_type,
|
|
|
- node.screen_id,
|
|
|
- node.title,
|
|
|
- node.position_x,
|
|
|
- node.position_y,
|
|
|
- json.dumps(node.next_node_keys, ensure_ascii=False),
|
|
|
- json.dumps(node.config, ensure_ascii=False),
|
|
|
- now,
|
|
|
- now,
|
|
|
- ),
|
|
|
- )
|
|
|
+def normalize_workflow_payload(payload: AutomationWorkflowSaveRequest) -> dict[str, Any]:
|
|
|
+ """把请求模型转换为持久化的 workflow/v1 JSON。"""
|
|
|
+ workflow_json = payload.model_dump()
|
|
|
+ workflow_json["schema_version"] = "workflow/v1"
|
|
|
+ workflow_json["workflow_key"] = normalize_workflow_key(payload.workflow_key)
|
|
|
+ workflow_json["name"] = payload.name.strip()
|
|
|
+ workflow_json.setdefault("variables", {})
|
|
|
+ workflow_json.setdefault("settings", {})
|
|
|
+ workflow_json.setdefault("edges", [])
|
|
|
+ return workflow_json
|
|
|
+
|
|
|
+
|
|
|
+def normalize_workflow_key(value: str | None) -> str | None:
|
|
|
+ key = (value or "").strip()
|
|
|
+ if not key:
|
|
|
+ return None
|
|
|
+ if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_-]*", key):
|
|
|
+ raise HTTPException(status_code=400, detail="Workflow key can only contain letters, numbers, underscores, and hyphens")
|
|
|
+ return key
|
|
|
|
|
|
|
|
|
def list_workflows(page: int, page_size: int) -> dict[str, Any]:
|
|
|
@@ -698,33 +705,38 @@ def list_workflows(page: int, page_size: int) -> dict[str, Any]:
|
|
|
total = conn.execute("SELECT COUNT(*) AS total FROM automation_workflows").fetchone()["total"]
|
|
|
rows = conn.execute(
|
|
|
"""
|
|
|
- SELECT w.*, COUNT(n.id) AS node_count
|
|
|
- FROM automation_workflows w
|
|
|
- LEFT JOIN automation_workflow_nodes n ON n.workflow_id = w.id
|
|
|
- GROUP BY w.id
|
|
|
- ORDER BY w.updated_at DESC
|
|
|
+ SELECT *
|
|
|
+ FROM automation_workflows
|
|
|
+ ORDER BY updated_at DESC
|
|
|
LIMIT ? OFFSET ?
|
|
|
""",
|
|
|
(page_size, offset),
|
|
|
).fetchall()
|
|
|
- return {"items": rows, "total": total, "page": page, "page_size": page_size}
|
|
|
+ return {"items": [workflow_summary(row) for row in rows], "total": total, "page": page, "page_size": page_size}
|
|
|
|
|
|
|
|
|
def get_workflow(workflow_id: int) -> dict[str, Any]:
|
|
|
- """读取工作流详情和节点列表。"""
|
|
|
+ """读取 workflow/v1 工作流详情。"""
|
|
|
with get_db() as conn:
|
|
|
workflow = conn.execute("SELECT * FROM automation_workflows WHERE id = ?", (workflow_id,)).fetchone()
|
|
|
if not workflow:
|
|
|
raise HTTPException(status_code=404, detail="Automation workflow not found")
|
|
|
- nodes = conn.execute(
|
|
|
- "SELECT * FROM automation_workflow_nodes WHERE workflow_id = ? ORDER BY node_index ASC",
|
|
|
- (workflow_id,),
|
|
|
- ).fetchall()
|
|
|
- item = dict(workflow)
|
|
|
- item["nodes"] = [public_node(row) for row in nodes]
|
|
|
+ item = workflow_to_public(workflow)
|
|
|
return item
|
|
|
|
|
|
|
|
|
+def get_workflow_by_key(workflow_key: str) -> dict[str, Any]:
|
|
|
+ """按稳定 key 读取 workflow/v1 工作流详情。"""
|
|
|
+ key = normalize_workflow_key(workflow_key)
|
|
|
+ if not key:
|
|
|
+ raise HTTPException(status_code=400, detail="Workflow key is required")
|
|
|
+ with get_db() as conn:
|
|
|
+ workflow = conn.execute("SELECT * FROM automation_workflows WHERE workflow_key = ?", (key,)).fetchone()
|
|
|
+ if not workflow:
|
|
|
+ raise HTTPException(status_code=404, detail="Automation workflow not found")
|
|
|
+ return workflow_to_public(workflow)
|
|
|
+
|
|
|
+
|
|
|
def delete_workflow(workflow_id: int) -> dict[str, Any]:
|
|
|
"""删除工作流及其节点。"""
|
|
|
with get_db() as conn:
|
|
|
@@ -735,124 +747,269 @@ def delete_workflow(workflow_id: int) -> dict[str, Any]:
|
|
|
|
|
|
|
|
|
def run_workflow(workflow_id: int, payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
|
|
|
- """按数据库中保存的工作流节点和连线顺序在后端执行整个工作流。"""
|
|
|
+ """执行 workflow/v1 工作流图。"""
|
|
|
workflow = get_workflow(workflow_id)
|
|
|
defaults = settings_service.default_ai_params()
|
|
|
provider_id = payload.provider_id or defaults.get("provider_id")
|
|
|
model_id = payload.model_id or defaults.get("model_id")
|
|
|
temperature = payload.temperature if payload.temperature is not None else defaults.get("temperature", 0.1)
|
|
|
- nodes = ordered_workflow_nodes(workflow.get("nodes") or [])
|
|
|
- results: list[dict[str, Any]] = []
|
|
|
- opened_pids: list[int] = []
|
|
|
+ context = WorkflowContext(
|
|
|
+ workflow_id=workflow_id,
|
|
|
+ provider_id=provider_id,
|
|
|
+ model_id=model_id,
|
|
|
+ temperature=float(temperature),
|
|
|
+ variables=workflow_variables(workflow, payload.variables),
|
|
|
+ )
|
|
|
+ nodes = workflow.get("nodes") or []
|
|
|
+ edges = workflow.get("edges") or []
|
|
|
+ node_map = {node["id"]: node for node in nodes}
|
|
|
+ start_id = first_workflow_node_id(nodes, edges)
|
|
|
+ if not start_id:
|
|
|
+ return {"workflow_id": workflow_id, "status": "SUCCESS", "results": []}
|
|
|
|
|
|
- for node in nodes:
|
|
|
+ results: list[dict[str, Any]] = []
|
|
|
+ current_id: str | None = start_id
|
|
|
+ visited_steps = 0
|
|
|
+ max_steps = int(workflow.get("settings", {}).get("max_steps") or 100)
|
|
|
+
|
|
|
+ while current_id and visited_steps < max_steps:
|
|
|
+ visited_steps += 1
|
|
|
+ node = node_map.get(current_id)
|
|
|
+ if not node:
|
|
|
+ return {"workflow_id": workflow_id, "status": "FAILED", "detail": f"Missing node: {current_id}", "results": results}
|
|
|
try:
|
|
|
- result = execute_workflow_node(workflow_id, node, provider_id, model_id, temperature, opened_pids)
|
|
|
- opened_pids.extend(
|
|
|
- item["pid"]
|
|
|
- for item in result.get("new_processes", [])
|
|
|
- if item.get("pid") and item["pid"] not in opened_pids
|
|
|
- )
|
|
|
- results.append({"node": node, "status": "SUCCESS", "result": result})
|
|
|
+ resolved_inputs = resolve_node_inputs(node, edges, context)
|
|
|
+ outputs = execute_workflow_node(node, resolved_inputs, context)
|
|
|
+ context.outputs[node["id"]] = outputs
|
|
|
+ results.append({"node_id": node["id"], "node": node, "status": "SUCCESS", "inputs": resolved_inputs, "outputs": outputs})
|
|
|
+ if node.get("type") == "flow.end":
|
|
|
+ return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results, "outputs": context.outputs}
|
|
|
+ next_port = str(outputs.get("next_port") or "success")
|
|
|
+ current_id = next_control_node_id(node["id"], next_port, edges) or next_control_node_id(node["id"], "next", edges)
|
|
|
except HTTPException as exc:
|
|
|
- if not (isinstance(exc.detail, dict) and exc.detail.get("error")):
|
|
|
- record_error(
|
|
|
- action_type=node.get("node_type") or "workflow",
|
|
|
- message=str(exc.detail),
|
|
|
- screen_id=node.get("screen_id"),
|
|
|
- workflow_id=workflow_id,
|
|
|
- node_id=node.get("id"),
|
|
|
- )
|
|
|
- failure = {"node": node, "status": "FAILED", "detail": exc.detail}
|
|
|
+ failure = {
|
|
|
+ "node_id": node.get("id"),
|
|
|
+ "node": node,
|
|
|
+ "status": "FAILED",
|
|
|
+ "detail": exc.detail,
|
|
|
+ "artifacts": capture_failure_artifacts(context),
|
|
|
+ }
|
|
|
results.append(failure)
|
|
|
return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results}
|
|
|
- return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results}
|
|
|
+ except WorkflowPaused as exc:
|
|
|
+ paused = {"node_id": node.get("id"), "node": node, "status": "PAUSED", "detail": exc.payload}
|
|
|
+ results.append(paused)
|
|
|
+ return {"workflow_id": workflow_id, "status": "PAUSED", "paused": paused, "results": results}
|
|
|
+ except Exception as exc:
|
|
|
+ failure = {
|
|
|
+ "node_id": node.get("id"),
|
|
|
+ "node": node,
|
|
|
+ "status": "FAILED",
|
|
|
+ "detail": str(exc),
|
|
|
+ "artifacts": capture_failure_artifacts(context),
|
|
|
+ }
|
|
|
+ results.append(failure)
|
|
|
+ return {"workflow_id": workflow_id, "status": "FAILED", "failed": failure, "results": results}
|
|
|
+ if visited_steps >= max_steps:
|
|
|
+ return {"workflow_id": workflow_id, "status": "FAILED", "detail": f"Workflow exceeded max_steps={max_steps}", "results": results}
|
|
|
+ return {"workflow_id": workflow_id, "status": "SUCCESS", "results": results, "outputs": context.outputs}
|
|
|
|
|
|
|
|
|
-def ordered_workflow_nodes(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
|
- """根据节点连线得到执行顺序;没有连线时沿用节点序号。"""
|
|
|
- if not nodes:
|
|
|
- return []
|
|
|
- by_key = {node.get("node_key") or f"node_{node.get('node_index')}": node for node in nodes}
|
|
|
- targeted = {key for node in nodes for key in node.get("next_node_keys", [])}
|
|
|
- start_keys = [key for key in by_key if key not in targeted] or [next(iter(by_key))]
|
|
|
- ordered: list[dict[str, Any]] = []
|
|
|
- visited: set[str] = set()
|
|
|
-
|
|
|
- def visit(key: str) -> None:
|
|
|
- if key in visited or key not in by_key:
|
|
|
- return
|
|
|
- visited.add(key)
|
|
|
- node = by_key[key]
|
|
|
- ordered.append(node)
|
|
|
- for next_key in node.get("next_node_keys", []):
|
|
|
- visit(next_key)
|
|
|
-
|
|
|
- for key in start_keys:
|
|
|
- visit(key)
|
|
|
- for key in by_key:
|
|
|
- visit(key)
|
|
|
- return ordered
|
|
|
+def run_workflow_by_key(workflow_key: str, payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
|
|
|
+ workflow = get_workflow_by_key(workflow_key)
|
|
|
+ return run_workflow(int(workflow["id"]), payload)
|
|
|
|
|
|
|
|
|
def execute_workflow_node(
|
|
|
- workflow_id: int,
|
|
|
node: dict[str, Any],
|
|
|
- provider_id: int | None,
|
|
|
- model_id: int | None,
|
|
|
- temperature: float,
|
|
|
- opened_pids: list[int],
|
|
|
+ inputs: dict[str, Any],
|
|
|
+ context: WorkflowContext,
|
|
|
) -> dict[str, Any]:
|
|
|
- """把工作流节点配置转换成已有高层动作并执行。"""
|
|
|
- node_type = node.get("node_type")
|
|
|
- config = node.get("config") or {}
|
|
|
- base = {
|
|
|
- "screen_id": node.get("screen_id"),
|
|
|
- "provider_id": provider_id,
|
|
|
- "model_id": model_id,
|
|
|
- "temperature": temperature,
|
|
|
- "workflow_id": workflow_id,
|
|
|
- "node_id": node.get("id"),
|
|
|
+ """通过节点注册表执行 workflow/v1 节点。"""
|
|
|
+ try:
|
|
|
+ executor = get_node_executor(str(node.get("type") or ""))
|
|
|
+ except KeyError as exc:
|
|
|
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
+ return executor(node, inputs, context)
|
|
|
+
|
|
|
+
|
|
|
+def capture_failure_artifacts(context: WorkflowContext) -> dict[str, Any]:
|
|
|
+ """工作流失败时尽量保存一张当前屏幕截图,供前端询问用户。"""
|
|
|
+ artifacts: dict[str, Any] = {}
|
|
|
+ try:
|
|
|
+ screenshot = take_screenshot_file(error_dir(), "workflow_failure")
|
|
|
+ except Exception as exc:
|
|
|
+ artifacts["screenshot_error"] = str(exc)
|
|
|
+ return artifacts
|
|
|
+ artifacts["screenshot_path"] = screenshot.get("db_path") or screenshot.get("path")
|
|
|
+ artifacts["width"] = screenshot.get("width")
|
|
|
+ artifacts["height"] = screenshot.get("height")
|
|
|
+ context.runtime["current_screenshot_path"] = artifacts["screenshot_path"]
|
|
|
+ return artifacts
|
|
|
+
|
|
|
+
|
|
|
+def workflow_to_public(row: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ item = workflow_summary(row)
|
|
|
+ workflow_json = parse_workflow_json(row.get("raw_json"))
|
|
|
+ item.update(workflow_json)
|
|
|
+ item["id"] = row["id"]
|
|
|
+ item["workflow_key"] = row.get("workflow_key") or workflow_json.get("workflow_key")
|
|
|
+ item["created_at"] = row["created_at"]
|
|
|
+ item["updated_at"] = row["updated_at"]
|
|
|
+ item["node_count"] = len(item.get("nodes") or [])
|
|
|
+ item["edge_count"] = len(item.get("edges") or [])
|
|
|
+ return item
|
|
|
+
|
|
|
+
|
|
|
+def workflow_summary(row: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ workflow_json = parse_workflow_json(row.get("raw_json"))
|
|
|
+ return {
|
|
|
+ "id": row["id"],
|
|
|
+ "workflow_key": row.get("workflow_key") or workflow_json.get("workflow_key"),
|
|
|
+ "name": row["name"],
|
|
|
+ "description": row.get("description"),
|
|
|
+ "schema_version": workflow_json.get("schema_version") or "workflow/v1",
|
|
|
+ "node_count": len(workflow_json.get("nodes") or []),
|
|
|
+ "edge_count": len(workflow_json.get("edges") or []),
|
|
|
+ "created_at": row.get("created_at"),
|
|
|
+ "updated_at": row.get("updated_at"),
|
|
|
}
|
|
|
- if node_type == "mouse":
|
|
|
- return execute_mouse_action(
|
|
|
- AutomationMouseActionRequest(
|
|
|
- **base,
|
|
|
- x=int(config.get("x", 0)),
|
|
|
- y=int(config.get("y", 0)),
|
|
|
- mouse_action=config.get("mouse_action") or "click",
|
|
|
- )
|
|
|
- )
|
|
|
- if node_type == "keyboard":
|
|
|
- return execute_keyboard_action(AutomationKeyboardActionRequest(**base, keys=config.get("keys") or []))
|
|
|
- if node_type == "text_input":
|
|
|
- return execute_text_input(AutomationTextInputRequest(**base, text=str(config.get("text") or "")))
|
|
|
- if node_type == "start_program":
|
|
|
- return execute_start_program(
|
|
|
- AutomationStartProgramRequest(
|
|
|
- **base,
|
|
|
- command=str(config.get("command") or ""),
|
|
|
- cwd=config.get("cwd"),
|
|
|
- shell=bool(config.get("shell", True)),
|
|
|
- )
|
|
|
- )
|
|
|
- if node_type == "close_programs":
|
|
|
- return close_opened_programs(config.get("pids") or opened_pids)
|
|
|
- raise HTTPException(status_code=400, detail=f"Unsupported workflow node type: {node_type}")
|
|
|
|
|
|
|
|
|
-def public_node(row: dict[str, Any]) -> dict[str, Any]:
|
|
|
- """把工作流节点行转换为接口返回格式。"""
|
|
|
- item = dict(row)
|
|
|
+def parse_workflow_json(raw_json: str | None) -> dict[str, Any]:
|
|
|
+ if not raw_json:
|
|
|
+ return empty_workflow_json()
|
|
|
try:
|
|
|
- item["config"] = json.loads(item.pop("config_json") or "{}")
|
|
|
+ parsed = json.loads(raw_json)
|
|
|
except json.JSONDecodeError:
|
|
|
- item["config"] = {}
|
|
|
+ return empty_workflow_json()
|
|
|
+ if not isinstance(parsed, dict):
|
|
|
+ return empty_workflow_json()
|
|
|
+ parsed.setdefault("schema_version", "workflow/v1")
|
|
|
+ parsed.setdefault("variables", {})
|
|
|
+ parsed.setdefault("settings", {})
|
|
|
+ parsed.setdefault("nodes", [])
|
|
|
+ parsed.setdefault("edges", [])
|
|
|
+ return parsed
|
|
|
+
|
|
|
+
|
|
|
+def empty_workflow_json() -> dict[str, Any]:
|
|
|
+ return {"schema_version": "workflow/v1", "variables": {}, "settings": {}, "nodes": [], "edges": []}
|
|
|
+
|
|
|
+
|
|
|
+def workflow_variables(workflow: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ variables: dict[str, Any] = {}
|
|
|
+ for name, definition in (workflow.get("variables") or {}).items():
|
|
|
+ if isinstance(definition, dict):
|
|
|
+ variables[name] = definition.get("default")
|
|
|
+ else:
|
|
|
+ variables[name] = definition
|
|
|
+ variables.update(overrides or {})
|
|
|
+ return variables
|
|
|
+
|
|
|
+
|
|
|
+def first_workflow_node_id(nodes: list[dict[str, Any]], edges: list[dict[str, Any]]) -> str | None:
|
|
|
+ if not nodes:
|
|
|
+ return None
|
|
|
+ for node in nodes:
|
|
|
+ if node.get("type") == "flow.start":
|
|
|
+ return node.get("id")
|
|
|
+ targeted = {edge.get("target") for edge in edges if edge.get("kind") == "control"}
|
|
|
+ for node in nodes:
|
|
|
+ if node.get("id") not in targeted:
|
|
|
+ return node.get("id")
|
|
|
+ return nodes[0].get("id")
|
|
|
+
|
|
|
+
|
|
|
+def next_control_node_id(source_id: str, source_port: str, edges: list[dict[str, Any]]) -> str | None:
|
|
|
+ fallback = None
|
|
|
+ for edge in edges:
|
|
|
+ if edge.get("kind") != "control" or edge.get("source") != source_id:
|
|
|
+ continue
|
|
|
+ if edge.get("source_port") == source_port:
|
|
|
+ return edge.get("target")
|
|
|
+ if edge.get("source_port") in (None, "", "success", "next") and fallback is None:
|
|
|
+ fallback = edge.get("target")
|
|
|
+ return fallback
|
|
|
+
|
|
|
+
|
|
|
+def resolve_node_inputs(node: dict[str, Any], edges: list[dict[str, Any]], context: WorkflowContext) -> dict[str, Any]:
|
|
|
+ resolved: dict[str, Any] = {}
|
|
|
+ for key, value in (node.get("inputs") or {}).items():
|
|
|
+ resolved[key] = resolve_value_ref(value, context)
|
|
|
+ for edge in edges:
|
|
|
+ if edge.get("kind") != "data" or edge.get("target") != node.get("id"):
|
|
|
+ continue
|
|
|
+ source_outputs = context.outputs.get(edge.get("source") or "", {})
|
|
|
+ resolved[edge.get("target_port") or "value"] = source_outputs.get(edge.get("source_port") or "value")
|
|
|
+ return resolved
|
|
|
+
|
|
|
+
|
|
|
+def resolve_value_ref(value: Any, context: WorkflowContext) -> Any:
|
|
|
+ if not isinstance(value, dict) or "source" not in value:
|
|
|
+ return value
|
|
|
+ source = value.get("source")
|
|
|
+ if source == "literal":
|
|
|
+ return value.get("value")
|
|
|
+ if source == "variable":
|
|
|
+ return context.variables.get(value.get("name") or "")
|
|
|
+ if source == "node_output":
|
|
|
+ return context.outputs.get(value.get("node_id") or "", {}).get(value.get("output") or "")
|
|
|
+ if source == "runtime":
|
|
|
+ return context.runtime.get(value.get("name") or "")
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def list_workflow_node_definitions() -> dict[str, Any]:
|
|
|
+ """返回前端可用于生成节点库和属性表单的节点定义。"""
|
|
|
+ return {"schema_version": "workflow/v1", "items": get_node_definitions()}
|
|
|
+
|
|
|
+
|
|
|
+def plan_workflow(payload: AutomationWorkflowPlanRequest) -> dict[str, Any]:
|
|
|
+ """让 AI 根据用户需求和节点定义生成 workflow/v1 草稿。"""
|
|
|
+ provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature)
|
|
|
+ prompt = build_workflow_plan_prompt(payload.requirement)
|
|
|
+ ai_result = ai_service.chat(provider_id, model_id, prompt, temperature)
|
|
|
try:
|
|
|
- item["next_node_keys"] = json.loads(item.get("next_node_keys") or "[]")
|
|
|
- except json.JSONDecodeError:
|
|
|
- item["next_node_keys"] = []
|
|
|
- return item
|
|
|
+ parsed = json_from_ai(ai_result["content"])
|
|
|
+ except (json.JSONDecodeError, ValueError) as exc:
|
|
|
+ raise HTTPException(status_code=502, detail=f"AI workflow plan output is not valid JSON: {exc}") from exc
|
|
|
+ session_id = str(uuid.uuid4())
|
|
|
+ return {"session_id": session_id, "plan": parsed, "ai_raw_content": ai_result["content"]}
|
|
|
+
|
|
|
+
|
|
|
+def continue_workflow_plan(payload: AutomationWorkflowPlanContinueRequest) -> dict[str, Any]:
|
|
|
+ """继续一次 AI 工作流规划对话,返回新的草稿建议。"""
|
|
|
+ provider_id, model_id, temperature = resolve_ai_params(payload.provider_id, payload.model_id, payload.temperature)
|
|
|
+ prompt = build_workflow_plan_prompt(payload.user_message, session_id=payload.session_id)
|
|
|
+ ai_result = ai_service.chat(provider_id, model_id, prompt, temperature)
|
|
|
+ try:
|
|
|
+ parsed = json_from_ai(ai_result["content"])
|
|
|
+ except (json.JSONDecodeError, ValueError) as exc:
|
|
|
+ raise HTTPException(status_code=502, detail=f"AI workflow plan output is not valid JSON: {exc}") from exc
|
|
|
+ return {"session_id": payload.session_id, "plan": parsed, "ai_raw_content": ai_result["content"]}
|
|
|
+
|
|
|
+
|
|
|
+def build_workflow_plan_prompt(requirement: str, session_id: str | None = None) -> str:
|
|
|
+ node_defs = json.dumps(get_node_definitions(), ensure_ascii=False, indent=2)
|
|
|
+ return f"""请作为 Windows 自动化工作流规划器,根据用户需求生成 workflow/v1 JSON 草稿。
|
|
|
+
|
|
|
+要求:
|
|
|
+1. 只能使用节点定义列表中的 type。
|
|
|
+2. 输出严格 JSON 对象,不要 Markdown。
|
|
|
+3. JSON 字段必须包含 summary、questions、workflow。
|
|
|
+4. workflow 必须包含 schema_version、name、description、variables、settings、nodes、edges。
|
|
|
+5. 不确定的坐标或界面状态,优先添加 human.ask_user 节点或 screen.screenshot 节点。
|
|
|
+6. 控制流连线 kind 使用 control,数据连线 kind 使用 data。
|
|
|
+
|
|
|
+会话 ID:{session_id or "new"}
|
|
|
+
|
|
|
+用户需求:
|
|
|
+{requirement}
|
|
|
+
|
|
|
+可用节点定义:
|
|
|
+{node_defs}
|
|
|
+"""
|
|
|
|
|
|
|
|
|
def list_errors(page: int, page_size: int) -> dict[str, Any]:
|