# workflow_task_service.py
  1. from __future__ import annotations
  2. import json
  3. import threading
  4. import time
  5. import uuid
  6. from typing import Any
  7. from fastapi import HTTPException
  8. from .database import get_db
  9. from .scanner import now_iso
  10. from .schemas import AutomationWorkflowRunRequest
  11. TERMINAL_STATUSES = {"SUCCESS", "FAILED", "PAUSED"}
  12. _worker_thread: threading.Thread | None = None
  13. _worker_guard = threading.Lock()
  14. _stop_event = threading.Event()
  15. _wake_event = threading.Event()
  16. def start_worker() -> None:
  17. """启动唯一的后台消费线程,并恢复服务异常退出时遗留的任务。"""
  18. global _worker_thread
  19. with _worker_guard:
  20. if _worker_thread and _worker_thread.is_alive():
  21. return
  22. with get_db() as conn:
  23. conn.execute(
  24. """
  25. UPDATE automation_workflow_tasks
  26. SET status = 'QUEUED', started_at = NULL,
  27. error_message = CASE
  28. WHEN error_message IS NULL OR error_message = '' THEN '服务重启后重新排队'
  29. ELSE error_message
  30. END
  31. WHERE status = 'RUNNING'
  32. """
  33. )
  34. conn.execute(
  35. "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1",
  36. (now_iso(),),
  37. )
  38. _stop_event.clear()
  39. _worker_thread = threading.Thread(target=_worker_loop, name="workflow-task-worker", daemon=True)
  40. _worker_thread.start()
  41. def stop_worker() -> None:
  42. """通知后台线程停止;正在执行的节点会在当前进程退出时结束。"""
  43. _stop_event.set()
  44. _wake_event.set()
  45. def enqueue_workflow_task(workflow: dict[str, Any], payload: AutomationWorkflowRunRequest) -> dict[str, Any]:
  46. """保存任务快照并加入全局串行队列。"""
  47. workflow_key = str(workflow.get("workflow_key") or "").strip()
  48. if not workflow_key:
  49. raise HTTPException(status_code=400, detail="Workflow key is required before execution")
  50. task_id = str(uuid.uuid4())
  51. created_at = now_iso()
  52. request_json = json.dumps(payload.model_dump(), ensure_ascii=False)
  53. snapshot_json = json.dumps(workflow, ensure_ascii=False)
  54. with get_db() as conn:
  55. conn.execute(
  56. """
  57. INSERT INTO automation_workflow_tasks (
  58. id, workflow_id, workflow_key, workflow_name, status,
  59. request_json, workflow_snapshot_json, created_at
  60. ) VALUES (?, ?, ?, ?, 'QUEUED', ?, ?, ?)
  61. """,
  62. (
  63. task_id,
  64. workflow.get("id"),
  65. workflow_key,
  66. str(workflow.get("name") or workflow_key),
  67. request_json,
  68. snapshot_json,
  69. created_at,
  70. ),
  71. )
  72. _wake_event.set()
  73. return get_workflow_task(task_id)
  74. def list_workflow_tasks(page: int, page_size: int, status: str | None = None) -> dict[str, Any]:
  75. """分页读取 workflow 异步任务历史。"""
  76. clauses: list[str] = []
  77. params: list[Any] = []
  78. if status:
  79. clauses.append("status = ?")
  80. params.append(status.upper())
  81. where_sql = "WHERE " + " AND ".join(clauses) if clauses else ""
  82. offset = (page - 1) * page_size
  83. with get_db() as conn:
  84. total = conn.execute(
  85. f"SELECT COUNT(*) AS total FROM automation_workflow_tasks {where_sql}",
  86. params,
  87. ).fetchone()["total"]
  88. rows = conn.execute(
  89. f"""
  90. SELECT * FROM automation_workflow_tasks
  91. {where_sql}
  92. ORDER BY created_at DESC
  93. LIMIT ? OFFSET ?
  94. """,
  95. [*params, page_size, offset],
  96. ).fetchall()
  97. return {
  98. "items": [task_to_public(row, include_payload=False) for row in rows],
  99. "total": total,
  100. "page": page,
  101. "page_size": page_size,
  102. }
  103. def get_workflow_task(task_id: str) -> dict[str, Any]:
  104. with get_db() as conn:
  105. row = conn.execute("SELECT * FROM automation_workflow_tasks WHERE id = ?", (task_id,)).fetchone()
  106. if not row:
  107. raise HTTPException(status_code=404, detail="Workflow task not found")
  108. return task_to_public(row, include_payload=True)
  109. def task_to_public(row: dict[str, Any], include_payload: bool) -> dict[str, Any]:
  110. item = {
  111. "id": row["id"],
  112. "workflow_id": row.get("workflow_id"),
  113. "workflow_key": row["workflow_key"],
  114. "workflow_name": row["workflow_name"],
  115. "status": row["status"],
  116. "error_message": row.get("error_message"),
  117. "created_at": row["created_at"],
  118. "started_at": row.get("started_at"),
  119. "finished_at": row.get("finished_at"),
  120. }
  121. if row["status"] == "QUEUED":
  122. item["queue_position"] = queue_position(row["id"], row["created_at"])
  123. else:
  124. item["queue_position"] = None
  125. if include_payload:
  126. item["request"] = parse_json_object(row.get("request_json"))
  127. item["result"] = parse_json_value(row.get("result_json"))
  128. item["return_data"] = parse_json_value(row.get("return_data_json"))
  129. return item
  130. def queue_position(task_id: str, created_at: str) -> int:
  131. with get_db() as conn:
  132. count = conn.execute(
  133. """
  134. SELECT COUNT(*) AS total
  135. FROM automation_workflow_tasks
  136. WHERE status = 'QUEUED'
  137. AND (created_at < ? OR (created_at = ? AND id <= ?))
  138. """,
  139. (created_at, created_at, task_id),
  140. ).fetchone()["total"]
  141. return int(count)
  142. def parse_json_object(value: str | None) -> dict[str, Any]:
  143. parsed = parse_json_value(value)
  144. return parsed if isinstance(parsed, dict) else {}
  145. def parse_json_value(value: str | None) -> Any:
  146. if not value:
  147. return None
  148. try:
  149. return json.loads(value)
  150. except json.JSONDecodeError:
  151. return None
  152. def _worker_loop() -> None:
  153. while not _stop_event.is_set():
  154. task = _claim_next_task()
  155. if not task:
  156. _wake_event.wait(1)
  157. _wake_event.clear()
  158. continue
  159. _execute_task(task)
  160. def _claim_next_task() -> dict[str, Any] | None:
  161. """通过 SQLite 写锁领取任务,确保任何时刻只有一个全局活动任务。"""
  162. with get_db() as conn:
  163. conn.execute("BEGIN IMMEDIATE")
  164. runtime = conn.execute(
  165. "SELECT active_task_id FROM automation_workflow_runtime WHERE id = 1"
  166. ).fetchone()
  167. if runtime and runtime.get("active_task_id"):
  168. return None
  169. task = conn.execute(
  170. """
  171. SELECT * FROM automation_workflow_tasks
  172. WHERE status = 'QUEUED'
  173. ORDER BY created_at ASC, id ASC
  174. LIMIT 1
  175. """
  176. ).fetchone()
  177. if not task:
  178. return None
  179. started_at = now_iso()
  180. conn.execute(
  181. "UPDATE automation_workflow_tasks SET status = 'RUNNING', started_at = ?, error_message = NULL WHERE id = ?",
  182. (started_at, task["id"]),
  183. )
  184. conn.execute(
  185. "UPDATE automation_workflow_runtime SET active_task_id = ?, updated_at = ? WHERE id = 1",
  186. (task["id"], started_at),
  187. )
  188. task["status"] = "RUNNING"
  189. task["started_at"] = started_at
  190. return task
  191. def _execute_task(task: dict[str, Any]) -> None:
  192. from . import automation_service
  193. result: dict[str, Any] | None = None
  194. error_message: str | None = None
  195. status = "FAILED"
  196. try:
  197. workflow = parse_json_object(task.get("workflow_snapshot_json"))
  198. payload = AutomationWorkflowRunRequest.model_validate(parse_json_object(task.get("request_json")))
  199. result = automation_service.execute_workflow(workflow, payload)
  200. raw_status = str(result.get("status") or "FAILED").upper()
  201. status = raw_status if raw_status in TERMINAL_STATUSES else "FAILED"
  202. if status == "FAILED":
  203. error_message = workflow_failure_message(result)
  204. return_data = automation_service.workflow_return_data(workflow, result)
  205. except Exception as exc:
  206. error_message = str(exc)
  207. return_data = None
  208. result = {"status": "FAILED", "detail": error_message}
  209. finished_at = now_iso()
  210. with get_db() as conn:
  211. conn.execute("BEGIN IMMEDIATE")
  212. conn.execute(
  213. """
  214. UPDATE automation_workflow_tasks
  215. SET status = ?, result_json = ?, return_data_json = ?, error_message = ?, finished_at = ?
  216. WHERE id = ?
  217. """,
  218. (
  219. status,
  220. json.dumps(result, ensure_ascii=False),
  221. json.dumps(return_data, ensure_ascii=False) if return_data is not None else None,
  222. error_message,
  223. finished_at,
  224. task["id"],
  225. ),
  226. )
  227. conn.execute(
  228. "UPDATE automation_workflow_runtime SET active_task_id = NULL, updated_at = ? WHERE id = 1",
  229. (finished_at,),
  230. )
  231. _wake_event.set()
  232. def workflow_failure_message(result: dict[str, Any]) -> str | None:
  233. failed = result.get("failed")
  234. if isinstance(failed, dict) and failed.get("detail"):
  235. return str(failed["detail"])
  236. return str(result.get("detail") or "Workflow execution failed")