
# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import logging
import tempfile
import threading
import cv2
import gradio as gr
import uvicorn
import asyncio
from fastapi import FastAPI
from fastapi import status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO
import os
import glob
import subprocess
import signal
import time
from typing import Optional, Dict
from concurrent.futures import ProcessPoolExecutor, Future
from functools import partial
import uuid

# Logging format and level
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')

# Process pool executor and task bookkeeping for streaming jobs
stream_executor = ProcessPoolExecutor(max_workers=4)
stream_tasks: Dict[str, Future] = {}
stream_pids: Dict[str, int] = {}  # worker process pid for each task

def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    model = YOLO(model_id)
    if image:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        return annotated_image[:, :, ::-1], None
    else:
        video_path = tempfile.mktemp(suffix=".webm")
        with open(video_path, "wb") as f:
            with open(video, "rb") as g:
                f.write(g.read())

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_video_path = tempfile.mktemp(suffix=".webm")
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)

        cap.release()
        out.release()

        return None, output_video_path


def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return annotated_image

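# Illustrative direct call (not executed here; the PIL import and asset path are assumptions):
#   from PIL import Image
#   annotated, _ = yolov12_inference(Image.open("ultralytics/assets/bus.jpg"), None, "yolov12n.pt", 640, 0.25)
# For an image input the first return value is an RGB numpy array; for a video input
# the second return value is the path to the annotated .webm file.
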
def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            if input_type == "Image":
                return yolov12_inference(image, None, model_id, image_size, conf_threshold)
            else:
                return yolov12_inference(None, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                [
                    "ultralytics/assets/bus.jpg",
                    "yolov12s.pt",
                    640,
                    0.25,
                ],
                [
                    "ultralytics/assets/zidane.jpg",
                    "yolov12x.pt",
                    640,
                    0.25,
                ],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[
                image,
                model_id,
                image_size,
                conf_threshold,
            ],
            outputs=[output_image],
            cache_examples='lazy',
        )

gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv12: Attention-Centric Real-Time Object Detectors
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()


def start_gradio():
    gradio_app.launch(server_name="0.0.0.0", server_port=7860)


# FastAPI application
app_fastapi = FastAPI()

class TrainParams(BaseModel):
    """
    Training parameters for the /yolov12/train endpoint; all fields must be supplied by the client.
    """
    model: str         # base model to start training from
    data: str          # path to the dataset config file
    epochs: int        # number of training epochs
    batch: int         # batch size
    imgsz: int         # input image size
    scale: float       # random-scale augmentation ratio
    mosaic: float      # mosaic augmentation probability
    mixup: float       # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str        # training device
    project: str       # project name
    name: str          # experiment name
    exist_ok: bool     # whether an existing run directory may be overwritten


@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """
    RESTful POST endpoint: /yolov12/train
    Receives training parameters, runs YOLO training, and returns the result.
    Response format: {"code": 0/1, "msg": "success"/error message, "result": save directory or None}
    """
    logging.info("Received /yolov12/train request")
    logging.info(f"Request params: {params}")
    try:
        # Pick the model config based on params.model
        if params.model.endswith('.pt'):
            # For a .pt checkpoint, use the matching .yaml config
            config_file = params.model.replace('.pt', '.yaml')
        else:
            # Otherwise fall back to the default config
            config_file = "yolov12.yaml"
        model = YOLO(config_file)
        model.load(params.model)

        logging.info("Starting model training...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("Model training finished")
        # logging.info(f"Training results: {str(results)}")
        return {
            "code": 0,
            "msg": "success",
            "result": str(results.save_dir)
        }
    except Exception as e:
        logging.error(f"Exception during training: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }

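# Illustrative request (the dataset, device, and run names are assumptions; adjust to your setup):
#   curl -X POST http://localhost:8000/yolov12/train \
#     -H "Content-Type: application/json" \
#     -d '{"model": "yolov12n.pt", "data": "coco128.yaml", "epochs": 10, "batch": 16,
#          "imgsz": 640, "scale": 0.5, "mosaic": 1.0, "mixup": 0.0, "copy_paste": 0.1,
#          "device": "0", "project": "runs/train", "name": "exp", "exist_ok": true}'
# On success the "result" field holds the run's save directory.
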
class PredictParams(BaseModel):
    """
    Prediction parameters for the /yolov12/predict endpoint, mirroring the arguments of YOLO predict.
    """
    model: str = "yolov12m.pt"        # model path
    source: Optional[str] = None      # input source (image/video path, URL, etc.)
    stream: bool = False              # stream mode
    conf: float = 0.25                # confidence threshold
    iou: float = 0.7                  # IoU threshold
    max_det: int = 300                # maximum number of detections
    imgsz: int = 640                  # input image size
    batch: int = 1                    # batch size
    device: str = ""                  # device
    show: bool = False                # show results
    save: bool = False                # save results
    save_txt: bool = False            # save labels as txt
    save_conf: bool = False           # save confidences
    save_crop: bool = False           # save cropped detections
    show_labels: bool = True          # draw labels
    show_conf: bool = True            # draw confidences
    show_boxes: bool = True           # draw bounding boxes
    line_width: Optional[int] = None  # line width
    vid_stride: int = 1               # video frame stride
    stream_buffer: bool = False       # stream buffer
    visualize: bool = False           # visualize features
    augment: bool = False             # test-time augmentation
    agnostic_nms: bool = False        # class-agnostic NMS
    classes: Optional[list] = None    # restrict detection to these classes
    retina_masks: bool = False        # high-resolution segmentation masks
    embed: Optional[list] = None      # layers to return feature embeddings from
    half: bool = False                # half precision
    dnn: bool = False                 # OpenCV DNN backend
    project: str = ""                 # project name
    name: str = ""                    # experiment name
    exist_ok: bool = False            # overwrite existing run directory
    verbose: bool = True              # verbose output


@app_fastapi.post("/yolov12/predict")
def yolov12_predict(params: PredictParams):
    """
    RESTful POST endpoint: /yolov12/predict
    Receives prediction parameters, runs YOLO inference, and returns the result.
    Response format: {"code": 0/1, "msg": "success"/error message, "result": path to the saved output file}
    """
    logging.info("Received /yolov12/predict request")
    logging.info(f"Request params: {params}")
    try:
        model = YOLO(params.model)
        logging.info("Starting model prediction...")

        # Build the prediction kwargs from all non-None fields except the model path
        predict_kwargs = {}
        for field, value in params.dict().items():
            if field not in ['model'] and value is not None:
                predict_kwargs[field] = value

        # Always save results to disk so the output file can be returned
        predict_kwargs['save'] = True
        results = model.predict(**predict_kwargs)
        logging.info("Model prediction finished")

        # Determine the save directory and the final output file name
        result = results[0]
        save_dir = result.save_dir if hasattr(result, 'save_dir') else None
        final_filename = None
        if save_dir:
            # Input file name without extension
            source = params.source
            base_name = None
            if source:
                base_name = os.path.splitext(os.path.basename(source))[0]

            # Supported output extensions
            exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
            matched_files = []
            for ext in exts:
                matched_files.extend(glob.glob(os.path.join(save_dir, ext)))

            # Sort by modification time and look for a file with the same base name as the input
            if base_name and matched_files:
                matched_files = sorted(matched_files, key=os.path.getmtime)
                for f in matched_files:
                    if os.path.splitext(os.path.basename(f))[0] == base_name:
                        final_filename = os.path.basename(f)
                        logging.info(f"Matched output by input file name: {final_filename}")
                        break

            # If no file with the same name was found, fall back to the most recent one
            if not final_filename and matched_files:
                latest_file = max(matched_files, key=os.path.getmtime)
                final_filename = os.path.basename(latest_file)
                logging.info(f"No same-name match, returning latest file: {final_filename}")

        return {
            "code": 0,
            "msg": "success",
            "result": os.path.join(str(save_dir), final_filename) if save_dir and final_filename else save_dir
        }
    except Exception as e:
        logging.error(f"Exception during prediction: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }

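# Illustrative request (the image path is an assumption):
#   curl -X POST http://localhost:8000/yolov12/predict \
#     -H "Content-Type: application/json" \
#     -d '{"model": "yolov12n.pt", "source": "ultralytics/assets/bus.jpg", "conf": 0.25, "imgsz": 640}'
# A successful response looks like {"code": 0, "msg": "success", "result": "<save_dir>/bus.jpg"}.
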
class StreamParams(BaseModel):
    """
    Parameters for the /yolov12/stream endpoint.
    model: path to the inference model
    source: pull address (e.g. an rtsp/http video stream)
    stream_url: push address (e.g. an rtmp publish URL)
    fps: output frame rate
    bitrate: output bitrate
    The remaining parameters match /yolov12/predict.
    """
    model: str = "yolov12m.pt"
    source: Optional[str] = None
    stream_url: Optional[str] = None
    fps: int = 25
    bitrate: str = "2000k"
    conf: float = 0.25
    iou: Optional[float] = 0.7
    imgsz: int = 640
    device: str = ""
    # Additional parameters can be added here as needed

def yolov12_stream_worker(params_dict, task_id):
    """
    Synchronous inference worker executed in the process pool.
    Supports termination via SIGTERM.
    """
    import os
    import sys
    import cv2
    import time
    import subprocess
    import signal
    from ultralytics import YOLO

    # Register a SIGTERM handler so the worker can exit cleanly
    def handle_sigterm(signum, frame):
        print(f"Task {task_id} received a termination signal, exiting")
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    model_path = params_dict['model']
    source = params_dict['source']
    stream_url = params_dict['stream_url']
    fps = params_dict.get('fps', 25)
    bitrate = params_dict.get('bitrate', '2000k')
    conf = params_dict.get('conf', 0.25)
    iou = params_dict.get('iou', 0.7)
    imgsz = params_dict.get('imgsz', 640)
    device = params_dict.get('device', '')

    # Handle to the ffmpeg process so it can be cleaned up on exit
    ffmpeg_process = None

    def cleanup_process():
        """Terminate the ffmpeg process if it is still running."""
        nonlocal ffmpeg_process
        if ffmpeg_process:
            try:
                ffmpeg_process.terminate()
                ffmpeg_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                ffmpeg_process.kill()
            except Exception as e:
                print(f"Error while cleaning up the ffmpeg process: {e}")

    try:
        model = YOLO(model_path)
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            return {"code": 1, "msg": f"Unable to open video stream: {source}", "result": None}

        # Query stream properties
        fps_cap = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Use the actual frame rate; fall back to the requested fps if it is unavailable
        output_fps = fps_cap if fps_cap > 0 else fps

        # Build the ffmpeg command
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{width}x{height}',
            '-r', str(output_fps),
            '-i', '-',  # read raw frames from stdin
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-b:v', bitrate,
            '-maxrate', bitrate,
            '-bufsize', '4000k',
            '-g', str(int(output_fps) * 2),  # GOP size
            '-f', 'flv' if stream_url.startswith('rtmp') else 'mpegts',
            '-y',  # overwrite output
            stream_url
        ]
        print(f"Task {task_id} starting ffmpeg command: {' '.join(ffmpeg_cmd)}")

        # Start the ffmpeg process
        ffmpeg_process = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            bufsize=0
        )

        # Give ffmpeg a moment to start
        time.sleep(1)
        if ffmpeg_process.poll() is not None:
            # ffmpeg exited unexpectedly
            stderr_output = ffmpeg_process.stderr.read().decode() if ffmpeg_process.stderr else "unknown error"
            return {"code": 1, "msg": f"ffmpeg failed to start: {stderr_output}", "result": None}

        frame_count = 0
        start_time = time.time()
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # YOLO inference
                try:
                    predict_kwargs = {
                        'source': frame,
                        'imgsz': imgsz,
                        'conf': conf,
                        'device': device
                    }
                    # Only pass iou when it is set
                    if iou is not None:
                        predict_kwargs['iou'] = iou
                    results = model.predict(**predict_kwargs)
                    annotated_frame = results[0].plot()
                except Exception as predict_error:
                    print(f"Task {task_id} YOLO inference error: {predict_error}")
                    # Fall back to the raw frame if inference fails
                    annotated_frame = frame

                # Feed the processed frame to ffmpeg
                try:
                    ffmpeg_process.stdin.write(annotated_frame.tobytes())
                    ffmpeg_process.stdin.flush()
                    frame_count += 1
                    # Report progress every 100 frames
                    if frame_count % 100 == 0:
                        elapsed_time = time.time() - start_time
                        current_fps = frame_count / elapsed_time
                        print(f"Task {task_id} processed {frame_count} frames, current FPS: {current_fps:.2f}")
                except IOError as e:
                    print(f"Task {task_id} error while writing to ffmpeg: {e}")
                    break
        except KeyboardInterrupt:
            print(f"Task {task_id} interrupted, stopping")
        finally:
            # Release resources
            cap.release()
            cleanup_process()

        elapsed_time = time.time() - start_time
        avg_fps = frame_count / elapsed_time if elapsed_time > 0 else 0
        print(f"Task {task_id} inference and streaming finished, frames processed: {frame_count}, average FPS: {avg_fps:.2f}")
        return {"code": 0, "msg": "success", "result": {"frames_processed": frame_count, "avg_fps": avg_fps}}
    except Exception as e:
        print(f"Task {task_id} raised an exception: {e}")
        cleanup_process()
        return {"code": 1, "msg": str(e), "result": None}

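# For reference, with the default parameters and an assumed 1920x1080, 25 fps source pushed
# to an RTMP target (URL is an assumption), the command built above expands roughly to:
#   ffmpeg -f rawvideo -pix_fmt bgr24 -s 1920x1080 -r 25.0 -i - \
#     -c:v libx264 -preset ultrafast -tune zerolatency -b:v 2000k -maxrate 2000k \
#     -bufsize 4000k -g 50 -f flv -y rtmp://localhost/live/yolo
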
@app_fastapi.post("/yolov12/stream")
async def yolov12_stream_async(params: StreamParams):
    """
    RESTful POST endpoint: /yolov12/stream
    Receives a pull URL and a push URL, runs YOLO inference on the incoming stream,
    and uses ffmpeg to push the annotated video to the push URL.
    Supports concurrent jobs and task cancellation.
    Response format: {"code": 0/1, "msg": "success"/error message, "result": task_id}
    """
    logging.info("Received /yolov12/stream request")
    logging.info(f"Request params: {params}")

    # Generate a unique task id
    task_id = str(uuid.uuid4())
    try:
        # Run the inference job asynchronously in the process pool
        loop = asyncio.get_running_loop()
        future = loop.run_in_executor(
            stream_executor, partial(yolov12_stream_worker, params.dict(), task_id)
        )
        stream_tasks[task_id] = future
        logging.info(f"Task {task_id} submitted to the process pool")
        return {
            "code": 0,
            "msg": "task submitted",
            "result": task_id
        }
    except Exception as e:
        logging.error(f"Exception while submitting the task: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }

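# Illustrative request (the stream URLs are assumptions):
#   curl -X POST http://localhost:8000/yolov12/stream \
#     -H "Content-Type: application/json" \
#     -d '{"model": "yolov12n.pt", "source": "rtsp://camera.local/stream",
#          "stream_url": "rtmp://localhost/live/yolo"}'
# The returned "result" field is the task_id used by the cancel/status/list endpoints below.
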
@app_fastapi.post("/yolov12/stream/cancel")
async def cancel_stream_task(task_id: str):
    """
    RESTful POST endpoint: /yolov12/stream/cancel
    Cancels the given inference task.
    Response format: {"code": 0/1, "msg": "success"/error message, "result": None}
    """
    logging.info(f"Received cancel request for task: {task_id}")
    future = stream_tasks.get(task_id)
    if not future:
        return {"code": 1, "msg": "task does not exist", "result": None}
    if future.done():
        return {"code": 1, "msg": "task already finished, cannot cancel", "result": None}
    try:
        # Attempt to cancel the task
        cancelled = future.cancel()
        if cancelled:
            logging.info(f"Task {task_id} cancelled")
            return {"code": 0, "msg": "task cancelled", "result": None}
        else:
            return {"code": 1, "msg": "task could not be cancelled (it may already be running)", "result": None}
    except Exception as e:
        logging.error(f"Exception while cancelling the task: {e}")
        return {"code": 1, "msg": str(e), "result": None}

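# Since task_id is a plain scalar parameter (not a Pydantic model), FastAPI reads it from
# the query string:
#   curl -X POST "http://localhost:8000/yolov12/stream/cancel?task_id=<task_id>"
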
@app_fastapi.get("/yolov12/stream/status")
async def get_stream_status(task_id: str):
    """
    RESTful GET endpoint: /yolov12/stream/status
    Returns the status of the given task.
    Response format: {"code": 0/1, "msg": "success"/error message, "result": {"status": ..., "result": ...}}
    """
    logging.info(f"Received status request for task: {task_id}")
    future = stream_tasks.get(task_id)
    if not future:
        return {"code": 1, "msg": "task does not exist", "result": None}
    try:
        if future.done():
            try:
                result = future.result()
                return {
                    "code": 0,
                    "msg": "completed",
                    "result": {"status": "completed", "result": result}
                }
            except Exception as e:
                return {
                    "code": 1,
                    "msg": f"task failed: {e}",
                    "result": {"status": "failed", "error": str(e)}
                }
        elif future.cancelled():
            return {
                "code": 0,
                "msg": "cancelled",
                "result": {"status": "cancelled"}
            }
        else:
            return {
                "code": 0,
                "msg": "running",
                "result": {"status": "running"}
            }
    except Exception as e:
        logging.error(f"Exception while querying the task status: {e}")
        return {"code": 1, "msg": str(e), "result": None}

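# Illustrative request:
#   curl "http://localhost:8000/yolov12/stream/status?task_id=<task_id>"
# While the worker is still processing frames the response is
# {"code": 0, "msg": "running", "result": {"status": "running"}}.
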
@app_fastapi.get("/yolov12/stream/list")
async def list_stream_tasks():
    """
    RESTful GET endpoint: /yolov12/stream/list
    Lists the status of all known tasks.
    Response format: {"code": 0, "msg": "success", "result": [{"task_id": ..., "status": ...}]}
    """
    logging.info("Received list-all-tasks request")
    try:
        tasks_info = []
        for task_id, future in stream_tasks.items():
            if future.done():
                if future.cancelled():
                    status = "cancelled"
                else:
                    try:
                        future.result()  # check whether the task raised
                        status = "completed"
                    except Exception:
                        status = "failed"
            else:
                status = "running"
            tasks_info.append({
                "task_id": task_id,
                "status": status
            })
        return {
            "code": 0,
            "msg": "success",
            "result": tasks_info
        }
    except Exception as e:
        logging.error(f"Exception while listing tasks: {e}")
        return {"code": 1, "msg": str(e), "result": None}

# Global exception handler: return a uniform payload when request validation fails
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    err_msg = f"Request validation failed: path={request.url.path}, errors={exc.errors()}"
    logging.error(err_msg)
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "code": 422,
            "msg": err_msg,
            "result": None
        }
    )


if __name__ == "__main__":
    threading.Thread(target=start_gradio, daemon=True).start()
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)