from __future__ import annotations import json import threading import time import uuid from typing import Any from fastapi import HTTPException from .database import get_db from .scanner import now_iso from .schemas import AutomationWorkflowRunRequest TERMINAL_STATUSES = {"SUCCESS", "FAILED", "PAUSED"} _worker_thread: threading.Thread | None = None _worker_guard = threading.Lock() _stop_event = threading.Event() _wake_event = threading.Event() def start_worker() -> None: """启动唯一的后台消费线程,并恢复服务异常退出时遗留的任务。""" global _worker_thread with _worker_guard: if _worker_thread and _worker_thread.is_alive(): return with get_db() as conn: conn.execute( """ UPDATE automation_workflow_tasks SET status = 'QUEUED', started_at = NULL, error_message = CASE WHEN error_message IS NULL OR error_message = '' THEN '服务重启后重新排队' ELSE error_message END WHERE status = 'RUNNING' """ ) conn.execute( "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1", (now_iso(),), ) _stop_event.clear() _worker_thread = threading.Thread(target=_worker_loop, name="workflow-task-worker", daemon=True) _worker_thread.start() def stop_worker() -> None: """通知后台线程停止;正在执行的节点会在当前进程退出时结束。""" _stop_event.set() _wake_event.set() def enqueue_workflow_task(workflow: dict[str, Any], payload: AutomationWorkflowRunRequest) -> dict[str, Any]: """保存任务快照并加入全局串行队列。""" workflow_key = str(workflow.get("workflow_key") or "").strip() if not workflow_key: raise HTTPException(status_code=400, detail="Workflow key is required before execution") task_id = str(uuid.uuid4()) created_at = now_iso() request_json = json.dumps(payload.model_dump(), ensure_ascii=False) snapshot_json = json.dumps(workflow, ensure_ascii=False) with get_db() as conn: conn.execute( """ INSERT INTO automation_workflow_tasks ( id, workflow_id, workflow_key, workflow_name, status, request_json, workflow_snapshot_json, created_at ) VALUES (?, ?, ?, ?, 'QUEUED', ?, ?, ?) """, ( task_id, workflow.get("id"), workflow_key, str(workflow.get("name") or workflow_key), request_json, snapshot_json, created_at, ), ) _wake_event.set() return get_workflow_task(task_id) def list_workflow_tasks(page: int, page_size: int, status: str | None = None) -> dict[str, Any]: """分页读取 workflow 异步任务历史。""" clauses: list[str] = [] params: list[Any] = [] if status: clauses.append("status = ?") params.append(status.upper()) where_sql = "WHERE " + " AND ".join(clauses) if clauses else "" offset = (page - 1) * page_size with get_db() as conn: total = conn.execute( f"SELECT COUNT(*) AS total FROM automation_workflow_tasks {where_sql}", params, ).fetchone()["total"] rows = conn.execute( f""" SELECT * FROM automation_workflow_tasks {where_sql} ORDER BY created_at DESC LIMIT ? OFFSET ? """, [*params, page_size, offset], ).fetchall() return { "items": [task_to_public(row, include_payload=False) for row in rows], "total": total, "page": page, "page_size": page_size, } def get_workflow_task(task_id: str) -> dict[str, Any]: with get_db() as conn: row = conn.execute("SELECT * FROM automation_workflow_tasks WHERE id = ?", (task_id,)).fetchone() if not row: raise HTTPException(status_code=404, detail="Workflow task not found") return task_to_public(row, include_payload=True) def task_to_public(row: dict[str, Any], include_payload: bool) -> dict[str, Any]: item = { "id": row["id"], "workflow_id": row.get("workflow_id"), "workflow_key": row["workflow_key"], "workflow_name": row["workflow_name"], "status": row["status"], "error_message": row.get("error_message"), "created_at": row["created_at"], "started_at": row.get("started_at"), "finished_at": row.get("finished_at"), } if row["status"] == "QUEUED": item["queue_position"] = queue_position(row["id"], row["created_at"]) else: item["queue_position"] = None if include_payload: item["request"] = parse_json_object(row.get("request_json")) item["result"] = parse_json_value(row.get("result_json")) item["return_data"] = parse_json_value(row.get("return_data_json")) return item def queue_position(task_id: str, created_at: str) -> int: with get_db() as conn: count = conn.execute( """ SELECT COUNT(*) AS total FROM automation_workflow_tasks WHERE status = 'QUEUED' AND (created_at < ? OR (created_at = ? AND id <= ?)) """, (created_at, created_at, task_id), ).fetchone()["total"] return int(count) def parse_json_object(value: str | None) -> dict[str, Any]: parsed = parse_json_value(value) return parsed if isinstance(parsed, dict) else {} def parse_json_value(value: str | None) -> Any: if not value: return None try: return json.loads(value) except json.JSONDecodeError: return None def _worker_loop() -> None: while not _stop_event.is_set(): task = _claim_next_task() if not task: _wake_event.wait(1) _wake_event.clear() continue _execute_task(task) def _claim_next_task() -> dict[str, Any] | None: """通过 SQLite 写锁领取任务,确保任何时刻只有一个全局活动任务。""" with get_db() as conn: conn.execute("BEGIN IMMEDIATE") runtime = conn.execute( "SELECT active_task_id FROM automation_workflow_runtime WHERE id = 1" ).fetchone() if runtime and runtime.get("active_task_id"): return None task = conn.execute( """ SELECT * FROM automation_workflow_tasks WHERE status = 'QUEUED' ORDER BY created_at ASC, id ASC LIMIT 1 """ ).fetchone() if not task: return None started_at = now_iso() conn.execute( "UPDATE automation_workflow_tasks SET status = 'RUNNING', started_at = ?, error_message = NULL WHERE id = ?", (started_at, task["id"]), ) conn.execute( "UPDATE automation_workflow_runtime SET active_task_id = ?, updated_at = ? WHERE id = 1", (task["id"], started_at), ) task["status"] = "RUNNING" task["started_at"] = started_at return task def _execute_task(task: dict[str, Any]) -> None: from . import automation_service result: dict[str, Any] | None = None error_message: str | None = None status = "FAILED" try: workflow = parse_json_object(task.get("workflow_snapshot_json")) payload = AutomationWorkflowRunRequest.model_validate(parse_json_object(task.get("request_json"))) result = automation_service.execute_workflow(workflow, payload) raw_status = str(result.get("status") or "FAILED").upper() status = raw_status if raw_status in TERMINAL_STATUSES else "FAILED" if status == "FAILED": error_message = workflow_failure_message(result) return_data = automation_service.workflow_return_data(workflow, result) except Exception as exc: error_message = str(exc) return_data = None result = {"status": "FAILED", "detail": error_message} finished_at = now_iso() with get_db() as conn: conn.execute("BEGIN IMMEDIATE") conn.execute( """ UPDATE automation_workflow_tasks SET status = ?, result_json = ?, return_data_json = ?, error_message = ?, finished_at = ? WHERE id = ? """, ( status, json.dumps(result, ensure_ascii=False), json.dumps(return_data, ensure_ascii=False) if return_data is not None else None, error_message, finished_at, task["id"], ), ) conn.execute( "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1", (finished_at,), ) _wake_event.set() def workflow_failure_message(result: dict[str, Any]) -> str | None: failed = result.get("failed") if isinstance(failed, dict) and failed.get("detail"): return str(failed["detail"]) return str(result.get("detail") or "Workflow execution failed")