| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- from __future__ import annotations
- import json
- import threading
- import time
- import uuid
- from typing import Any
- from fastapi import HTTPException
- from .database import get_db
- from .scanner import now_iso
- from .schemas import AutomationWorkflowRunRequest
# Statuses a workflow result may report that are stored verbatim as the task's
# final status; any other reported status is recorded as FAILED.
TERMINAL_STATUSES: set[str] = {"SUCCESS", "FAILED", "PAUSED"}

# Background-worker state shared by start_worker / stop_worker / _worker_loop.
_worker_guard = threading.Lock()  # serializes start_worker() calls
_worker_thread: threading.Thread | None = None  # the single consumer thread, once started
_stop_event = threading.Event()  # set by stop_worker() to end the consumer loop
_wake_event = threading.Event()  # set whenever new work may be claimable
def start_worker() -> None:
    """Start the single background consumer thread unless one is already alive.

    Before a new thread is started, tasks still marked RUNNING (left over when
    the service exited abnormally) are put back to QUEUED with ``started_at``
    cleared, and the runtime's ``active_task_id`` slot is released, so the new
    worker can pick those tasks up again. A task's existing ``error_message`` is
    kept; an empty or NULL one is replaced with a "re-queued after restart" note.
    """
    global _worker_thread
    with _worker_guard:
        current = _worker_thread
        if current is not None and current.is_alive():
            # NOTE(review): if stop_worker() was called and the old thread is
            # still finishing a task, this returns without clearing _stop_event,
            # so the old loop exits afterwards and no worker remains — confirm
            # whether stop/start cycles happen at runtime.
            return
        requeue_sql = """
            UPDATE automation_workflow_tasks
            SET status = 'QUEUED', started_at = NULL,
                error_message = CASE
                    WHEN error_message IS NULL OR error_message = '' THEN '服务重启后重新排队'
                    ELSE error_message
                END
            WHERE status = 'RUNNING'
        """
        release_slot_sql = (
            "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1"
        )
        with get_db() as conn:
            conn.execute(requeue_sql)
            conn.execute(release_slot_sql, (now_iso(),))
        _stop_event.clear()
        _worker_thread = threading.Thread(target=_worker_loop, name="workflow-task-worker", daemon=True)
        _worker_thread.start()
def stop_worker() -> None:
    """Ask the background thread to stop.

    The stop flag is raised first, then the wake event, so an idle worker
    blocked in its wait returns at once and sees the flag. A node that is
    already executing is not interrupted; because the thread is a daemon, it
    ends when the process exits.
    """
    for signal in (_stop_event, _wake_event):
        signal.set()
def enqueue_workflow_task(workflow: dict[str, Any], payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
    """Store a snapshot of the workflow and run request as a new QUEUED task.

    The whole ``workflow`` dict is serialized into the task row; the worker
    later executes from that snapshot rather than re-reading the workflow.
    The worker is woken once the row is written.

    Returns:
        The public view of the newly created task, as produced by
        ``get_workflow_task``.

    Raises:
        HTTPException: 400 when the workflow has no non-blank ``workflow_key``.
    """
    key = str(workflow.get("workflow_key") or "").strip()
    if not key:
        raise HTTPException(status_code=400, detail="Workflow key is required before execution")

    new_task_id = str(uuid.uuid4())
    queued_at = now_iso()
    request_blob = json.dumps(payload.model_dump(), ensure_ascii=False)
    snapshot_blob = json.dumps(workflow, ensure_ascii=False)
    display_name = str(workflow.get("name") or key)

    insert_sql = """
        INSERT INTO automation_workflow_tasks (
            id, workflow_id, workflow_key, workflow_name, status,
            request_json, workflow_snapshot_json, created_at
        ) VALUES (?, ?, ?, ?, 'QUEUED', ?, ?, ?)
    """
    with get_db() as conn:
        conn.execute(
            insert_sql,
            (new_task_id, workflow.get("id"), key, display_name, request_blob, snapshot_blob, queued_at),
        )
    _wake_event.set()
    return get_workflow_task(new_task_id)
def list_workflow_tasks(page: int, page_size: int, status: str | None = None) -> dict[str, Any]:
    """Return one page of workflow task history, newest first.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: maximum rows per page; negative values are treated as 0.
        status: optional status filter. It is trimmed and upper-cased before
            matching (stored statuses are upper-case); a blank value means
            "no filter".

    Returns:
        ``{"items", "total", "page", "page_size"}``. ``items`` are public task
        views without request/result payloads; ``total`` counts every task
        matching the filter; ``page``/``page_size`` echo the values actually used.
    """
    page = max(page, 1)
    # SQLite treats a negative LIMIT as "no upper bound", which would silently
    # return the whole table instead of a page.
    page_size = max(page_size, 0)

    clauses: list[str] = []
    params: list[Any] = []
    normalized_status = (status or "").strip().upper()
    if normalized_status:
        clauses.append("status = ?")
        params.append(normalized_status)
    # Only fixed SQL fragments are interpolated; user values stay parameterized.
    where_sql = "WHERE " + " AND ".join(clauses) if clauses else ""
    offset = (page - 1) * page_size

    with get_db() as conn:
        total = conn.execute(
            f"SELECT COUNT(*) AS total FROM automation_workflow_tasks {where_sql}",
            params,
        ).fetchone()["total"]
        # `id` breaks ties between rows with equal created_at values; without a
        # unique tiebreak the row order is unspecified and consecutive pages
        # could repeat or skip tasks.
        rows = conn.execute(
            f"""
            SELECT * FROM automation_workflow_tasks
            {where_sql}
            ORDER BY created_at DESC, id DESC
            LIMIT ? OFFSET ?
            """,
            [*params, page_size, offset],
        ).fetchall()
    return {
        "items": [task_to_public(row, include_payload=False) for row in rows],
        "total": total,
        "page": page,
        "page_size": page_size,
    }
def get_workflow_task(task_id: str) -> dict[str, Any]:
    """Load one task by id, including its decoded request/result/return data.

    Raises:
        HTTPException: 404 when no task with ``task_id`` exists.
    """
    lookup_sql = "SELECT * FROM automation_workflow_tasks WHERE id = ?"
    with get_db() as conn:
        found = conn.execute(lookup_sql, (task_id,)).fetchone()
    if not found:
        raise HTTPException(status_code=404, detail="Workflow task not found")
    return task_to_public(found, include_payload=True)
def task_to_public(row: dict[str, Any], include_payload: bool) -> dict[str, Any]:
    """Shape a task row for API responses.

    Always returns identity, status and timestamp fields plus
    ``queue_position``, which is an int for QUEUED tasks and ``None``
    otherwise. When ``include_payload`` is true, the decoded ``request``
    (always a dict) and ``result``/``return_data`` are added. The latter two
    are ``None`` when missing or not valid JSON.
    """
    public: dict[str, Any] = {
        "id": row["id"],
        "workflow_id": row.get("workflow_id"),
        "workflow_key": row["workflow_key"],
        "workflow_name": row["workflow_name"],
        "status": row["status"],
        "error_message": row.get("error_message"),
        "created_at": row["created_at"],
        "started_at": row.get("started_at"),
        "finished_at": row.get("finished_at"),
        # Only waiting tasks have a place in line; each lookup runs its own query.
        "queue_position": (
            queue_position(row["id"], row["created_at"]) if row["status"] == "QUEUED" else None
        ),
    }
    if include_payload:
        public.update(
            request=parse_json_object(row.get("request_json")),
            result=parse_json_value(row.get("result_json")),
            return_data=parse_json_value(row.get("return_data_json")),
        )
    return public
def queue_position(task_id: str, created_at: str) -> int:
    """Count QUEUED tasks that sort at or before the given task.

    Ordering is ``(created_at, id)`` ascending, the same order the worker
    claims tasks in. For a task that is itself QUEUED, the count includes it,
    so the result is its 1-based place in line.
    """
    position_sql = """
        SELECT COUNT(*) AS total
        FROM automation_workflow_tasks
        WHERE status = 'QUEUED'
          AND (created_at < ? OR (created_at = ? AND id <= ?))
    """
    with get_db() as conn:
        ahead_or_same = conn.execute(position_sql, (created_at, created_at, task_id)).fetchone()["total"]
    return int(ahead_or_same)
def parse_json_object(value: str | None) -> dict[str, Any]:
    """Decode ``value`` as JSON, keeping it only if it is an object; otherwise return ``{}``."""
    decoded = parse_json_value(value)
    if isinstance(decoded, dict):
        return decoded
    return {}
def parse_json_value(value: str | None) -> Any:
    """Decode a stored JSON text column.

    Returns ``None`` when ``value`` is empty/``None`` or is not valid JSON;
    otherwise returns the decoded value (of any JSON type).
    """
    if value:
        try:
            decoded = json.loads(value)
        except json.JSONDecodeError:
            decoded = None
        return decoded
    return None
def _worker_loop() -> None:
    """Consume queued tasks one at a time until ``stop_worker()`` is called.

    After a task finishes, the loop immediately tries to claim the next one.
    When nothing can be claimed, it waits up to 1 second for ``_wake_event``
    before polling again.

    An exception from claiming or executing is logged, and the loop backs off
    through the same 1-second wait instead of letting the exception end the
    thread. This is the only consumer thread, so an uncaught error here would
    stop the queue until restart.
    """
    import logging

    logger = logging.getLogger(__name__)
    while not _stop_event.is_set():
        try:
            task = _claim_next_task()
            if task:
                _execute_task(task)
                continue
        except Exception:
            # NOTE(review): if the failure happened after a claim (e.g. while
            # persisting the outcome), active_task_id may still be set; claims
            # then keep returning None until start_worker() clears it.
            logger.exception("Workflow task worker iteration failed")
        _wake_event.wait(1)
        _wake_event.clear()
def _claim_next_task() -> dict[str, Any] | None:
    """Claim the oldest QUEUED task for execution, or return ``None``.

    The work runs after ``BEGIN IMMEDIATE``, which takes SQLite's write lock up
    front. Checking the runtime slot, picking the task, and marking it RUNNING
    therefore cannot interleave with another writer, and at most one task
    holds ``active_task_id``.

    Returns ``None`` when a task already holds the slot or nothing is queued.
    Otherwise returns the claimed row with ``status``/``started_at`` updated
    to what was written; the task's ``error_message`` is cleared in the DB.
    Committing or closing is left to ``get_db()`` (not visible here).
    """
    with get_db() as conn:
        conn.execute("BEGIN IMMEDIATE")
        # NOTE(review): if the runtime row with id = 1 is missing, this check
        # passes and the slot UPDATE below touches nothing — presumably the row
        # is seeded with the schema; confirm.
        slot = conn.execute("SELECT active_task_id FROM automation_workflow_runtime WHERE id = 1").fetchone()
        if slot and slot.get("active_task_id"):
            return None  # another task is still running

        oldest_sql = """
            SELECT * FROM automation_workflow_tasks
            WHERE status = 'QUEUED'
            ORDER BY created_at ASC, id ASC
            LIMIT 1
        """
        candidate = conn.execute(oldest_sql).fetchone()
        if not candidate:
            return None

        claimed_at = now_iso()
        candidate_id = candidate["id"]
        conn.execute(
            "UPDATE automation_workflow_tasks SET status = 'RUNNING', started_at = ?, error_message = NULL WHERE id = ?",
            (claimed_at, candidate_id),
        )
        conn.execute(
            "UPDATE automation_workflow_runtime SET active_task_id = ?, updated_at = ? WHERE id = 1",
            (candidate_id, claimed_at),
        )
        # Mirror the written state on the returned row.
        candidate["status"] = "RUNNING"
        candidate["started_at"] = claimed_at
        return candidate
def _execute_task(task: dict[str, Any]) -> None:
    """Run one claimed task and persist its final state.

    The workflow and request are rebuilt from the JSON stored on the task row
    and run through ``automation_service.execute_workflow``. One
    ``BEGIN IMMEDIATE`` transaction then writes the outcome and releases the
    runtime's ``active_task_id`` slot, and the worker is woken to look for
    the next task.

    A reported status outside ``TERMINAL_STATUSES`` is stored as FAILED. Any
    exception, including one raised by ``workflow_return_data`` after a
    successful run, marks the task FAILED with the exception text as
    ``error_message``. A result that cannot be JSON-serialized also marks the
    task FAILED instead of aborting the write.
    """
    # Imported lazily — presumably to avoid a circular import; confirm.
    from . import automation_service

    result: dict[str, Any] | None = None
    return_data: Any = None
    error_message: str | None = None
    status = "FAILED"
    try:
        workflow = parse_json_object(task.get("workflow_snapshot_json"))
        payload = AutomationWorkflowRunRequest.model_validate(parse_json_object(task.get("request_json")))
        result = automation_service.execute_workflow(workflow, payload)
        raw_status = str(result.get("status") or "FAILED").upper()
        status = raw_status if raw_status in TERMINAL_STATUSES else "FAILED"
        if status == "FAILED":
            error_message = workflow_failure_message(result)
        return_data = automation_service.workflow_return_data(workflow, result)
    except Exception as exc:
        # Reset status as well: otherwise an exception raised after a SUCCESS
        # result (e.g. in workflow_return_data) would leave the task marked
        # SUCCESS while its stored result and error_message describe a failure.
        status = "FAILED"
        error_message = str(exc)
        return_data = None
        result = {"status": "FAILED", "detail": error_message}

    try:
        result_json = json.dumps(result, ensure_ascii=False)
        return_data_json = json.dumps(return_data, ensure_ascii=False) if return_data is not None else None
    except (TypeError, ValueError) as exc:
        # Without this, the exception would skip the write below, leaving the
        # task RUNNING and active_task_id set, which blocks every later claim.
        status = "FAILED"
        error_message = f"Workflow result is not JSON serializable: {exc}"
        result_json = json.dumps({"status": "FAILED", "detail": error_message}, ensure_ascii=False)
        return_data_json = None

    finished_at = now_iso()
    with get_db() as conn:
        conn.execute("BEGIN IMMEDIATE")
        conn.execute(
            """
            UPDATE automation_workflow_tasks
            SET status = ?, result_json = ?, return_data_json = ?, error_message = ?, finished_at = ?
            WHERE id = ?
            """,
            (status, result_json, return_data_json, error_message, finished_at, task["id"]),
        )
        conn.execute(
            "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1",
            (finished_at,),
        )
    _wake_event.set()
def workflow_failure_message(result: dict[str, Any]) -> str | None:
    """Pick the most specific failure description from a workflow result.

    Checks sources in order and uses the first truthy one:

    1. ``result["failed"]["detail"]``, when ``failed`` is a dict.
    2. ``result["detail"]``.
    3. The generic "Workflow execution failed".

    The chosen value is returned as a string.
    """
    failed_node = result.get("failed")
    node_detail = failed_node.get("detail") if isinstance(failed_node, dict) else None
    if node_detail:
        return str(node_detail)
    return str(result.get("detail") or "Workflow execution failed")
|