app.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174
  1. # --------------------------------------------------------
  2. # Based on yolov10
  3. # https://github.com/THU-MIG/yolov10/app.py
  4. # --------------------------------------------------------'
  5. import logging
  6. import tempfile
  7. import threading
  8. import cv2
  9. import gradio as gr
  10. import uvicorn
  11. import asyncio
  12. from fastapi import FastAPI
  13. from fastapi import status
  14. from fastapi.exceptions import RequestValidationError
  15. from fastapi.responses import JSONResponse
  16. from pydantic import BaseModel
  17. from ultralytics import YOLO
  18. import os
  19. import glob
  20. import subprocess
  21. import signal
  22. import time
  23. from typing import Optional, Dict
  24. from concurrent.futures import ProcessPoolExecutor, Future
  25. from functools import partial
  26. import uuid
  # Logging: timestamped records at INFO level and above.
  logging.basicConfig(
      level=logging.INFO,
      format='[%(asctime)s] %(levelname)s - %(message)s',
  )

  # Process pool running the streaming workers, plus per-task bookkeeping.
  stream_executor = ProcessPoolExecutor(max_workers=4)
  # task_id -> {"future": Future, "start_time": int (ms), "source": str, "stream_url": str}
  stream_tasks: Dict[str, Dict] = {}
  # task_id -> PID of the pool worker process executing that task
  stream_pids: Dict[str, int] = {}
  def create_new_executor(max_workers=4):
      """Replace the global ``stream_executor`` with a fresh process pool.

      The old pool is shut down with ``wait=False``, so this call does not block
      on work that was already submitted to it.

      Args:
          max_workers: Number of worker processes in the new pool (default 4,
              the same size as the pool created at import time).
      """
      global stream_executor
      try:
          if stream_executor is not None:
              stream_executor.shutdown(wait=False)
      except Exception as exc:
          # Best effort: a broken pool may fail to shut down cleanly, but that
          # must not prevent creating its replacement.
          logging.warning(f"关闭旧进程池失败: {exc}")
      stream_executor = ProcessPoolExecutor(max_workers=max_workers)
      print("已重新创建进程池执行器")
  def cleanup_completed_tasks():
      """Forget finished stream tasks.

      Removes every task whose future is done from ``stream_tasks`` and
      ``stream_pids`` and deletes its ``/tmp/yolov12_task_<id>.pid`` file.
      """
      # Collect first: the dict must not change size while it is iterated.
      finished = [tid for tid, info in stream_tasks.items() if info["future"].done()]
      for task_id in finished:
          del stream_tasks[task_id]
          stream_pids.pop(task_id, None)
          pid_file = f"/tmp/yolov12_task_{task_id}.pid"
          try:
              os.remove(pid_file)
          except FileNotFoundError:
              pass  # already removed by the worker or another cleanup path
          except OSError as exc:
              logging.warning(f"删除PID文件失败 {pid_file}: {exc}")
          print(f"已清理完成的任务: {task_id}")
  def yolov12_inference(image, video, model_id, image_size, conf_threshold):
      """Run YOLO detection on one image, or on every frame of a video.

      Args:
          image: Input image (the UI passes a PIL image); used when not None.
          video: Path to an input video; used only when ``image`` is None.
          model_id: Weights file passed to ``YOLO``.
          image_size: Inference size (``imgsz``).
          conf_threshold: Confidence threshold (``conf``).

      Returns:
          ``(annotated_image, None)`` for an image (channels flipped from the
          BGR ``plot()`` output for display), ``(None, output_path)`` for a video
          (VP8 ``.webm`` in the temp dir), or ``(None, None)`` if no input given.
      """
      if image is None and video is None:
          # e.g. "Detect Objects" clicked with an empty input.
          return None, None

      model = YOLO(model_id)
      if image is not None:
          results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
          return results[0].plot()[:, :, ::-1], None

      # Read the uploaded file directly; copying it first only leaked a temp file.
      cap = cv2.VideoCapture(video)
      fps = cap.get(cv2.CAP_PROP_FPS)
      if fps <= 0:
          fps = 25.0  # source did not report a frame rate; VideoWriter needs one
      frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
      frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

      # mkstemp creates the file atomically (mktemp only returns a racy name).
      fd, output_video_path = tempfile.mkstemp(suffix=".webm")
      os.close(fd)
      out = cv2.VideoWriter(
          output_video_path,
          cv2.VideoWriter_fourcc(*'vp80'),
          fps,
          (frame_width, frame_height),
      )
      try:
          while cap.isOpened():
              ret, frame = cap.read()
              if not ret:
                  break
              results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
              out.write(results[0].plot())
      finally:
          cap.release()
          out.release()
      return None, output_video_path
  def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
      """Image-only adapter for ``gr.Examples``: return just the annotated image."""
      outputs = yolov12_inference(image, None, model_path, image_size, conf_threshold)
      return outputs[0]
  def app():
      """Build the detection UI: inputs and settings, annotated outputs, examples.

      Called from within the page-level ``gr.Blocks`` context.
      """
      model_choices = [
          "yolov12n.pt",
          "yolov12s.pt",
          "yolov12m.pt",
          "yolov12l.pt",
          "yolov12x.pt",
      ]
      example_rows = [
          ["ultralytics/assets/bus.jpg", "yolov12s.pt", 640, 0.25],
          ["ultralytics/assets/zidane.jpg", "yolov12x.pt", 640, 0.25],
      ]

      with gr.Blocks():
          with gr.Row():
              with gr.Column():
                  image = gr.Image(type="pil", label="Image", visible=True)
                  video = gr.Video(label="Video", visible=False)
                  input_type = gr.Radio(
                      label="Input Type",
                      choices=["Image", "Video"],
                      value="Image",
                  )
                  model_id = gr.Dropdown(
                      label="Model",
                      choices=model_choices,
                      value="yolov12m.pt",
                  )
                  image_size = gr.Slider(
                      label="Image Size", minimum=320, maximum=1280, step=32, value=640,
                  )
                  conf_threshold = gr.Slider(
                      label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25,
                  )
                  yolov12_infer = gr.Button(value="Detect Objects")

              with gr.Column():
                  output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                  output_video = gr.Video(label="Annotated Video", visible=False)

          def update_visibility(input_type):
              # Show the image widgets for "Image", the video widgets otherwise.
              is_image = input_type == "Image"
              return (
                  gr.update(visible=is_image),
                  gr.update(visible=not is_image),
                  gr.update(visible=is_image),
                  gr.update(visible=not is_image),
              )

          input_type.change(
              fn=update_visibility,
              inputs=[input_type],
              outputs=[image, video, output_image, output_video],
          )

          def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
              # Forward only the input that matches the selected mode.
              use_image = input_type == "Image"
              return yolov12_inference(
                  image if use_image else None,
                  None if use_image else video,
                  model_id,
                  image_size,
                  conf_threshold,
              )

          yolov12_infer.click(
              fn=run_inference,
              inputs=[image, video, model_id, image_size, conf_threshold, input_type],
              outputs=[output_image, output_video],
          )

          gr.Examples(
              examples=example_rows,
              fn=yolov12_inference_for_examples,
              inputs=[image, model_id, image_size, conf_threshold],
              outputs=[output_image],
              cache_examples='lazy',
          )
  # Page-level Gradio app: title, paper/code links, then the detection UI.
  with gr.Blocks() as gradio_app:
      gr.HTML(
          "\n"
          "<h1 style='text-align: center'>\n"
          "YOLOv12: Attention-Centric Real-Time Object Detectors\n"
          "</h1>\n"
      )
      gr.HTML(
          "\n"
          "<h3 style='text-align: center'>\n"
          "<a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a>"
          " | "
          "<a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>\n"
          "</h3>\n"
      )
      with gr.Row():
          with gr.Column():
              app()
  def start_gradio(server_name="0.0.0.0", server_port=7860):
      """Launch the Gradio UI.

      Args:
          server_name: Interface to bind; the default "0.0.0.0" listens on all
              interfaces.
          server_port: TCP port of the UI (default 7860).
      """
      gradio_app.launch(server_name=server_name, server_port=server_port)
  # ---- FastAPI: REST endpoints for training, prediction and stream tasks ----
  app_fastapi = FastAPI()
  class TrainParams(BaseModel):
      """
      Training parameters for the /yolov12/train endpoint; every field must be supplied by the client.
      """
      model: str  # base weights to start training from (e.g. a .pt file)
      data: str  # path to the dataset config file
      epochs: int  # number of training epochs
      batch: int  # batch size
      imgsz: int  # input image size
      scale: float  # random-scale augmentation factor
      mosaic: float  # mosaic augmentation probability
      mixup: float  # mixup augmentation probability
      copy_paste: float  # copy-paste augmentation probability
      device: str  # training device
      project: str  # project name (output root)
      name: str  # experiment (run) name
      exist_ok: bool  # whether an existing run directory of the same name may be reused
  @app_fastapi.post("/yolov12/train")
  def yolov12_train(params: TrainParams):
      """
      RESTful POST endpoint: /yolov12/train

      Build a YOLO model from a config file, load the ``params.model`` weights
      into it and train with the request parameters.

      The config is derived from the weights name: ``<path>.pt`` becomes
      ``<path>.yaml``; any other value falls back to ``yolov12.yaml``.

      Returns:
          {"code": 0, "msg": "success", "result": <run directory (str) or None>}
          on success, {"code": 1, "msg": <error>, "result": None} on failure.
      """
      logging.info("收到/yolov12/train训练请求")
      logging.info(f"请求参数: {params}")
      try:
          if params.model.endswith('.pt'):
              # Swap only the trailing suffix; str.replace would also rewrite a
              # '.pt' occurring earlier in the path.
              config_file = params.model[:-len('.pt')] + '.yaml'
          else:
              config_file = "yolov12.yaml"
          model = YOLO(config_file)
          model.load(params.model)
          logging.info("开始模型训练...")
          results = model.train(
              data=params.data,
              epochs=params.epochs,
              batch=params.batch,
              imgsz=params.imgsz,
              scale=params.scale,
              mosaic=params.mosaic,
              mixup=params.mixup,
              copy_paste=params.copy_paste,
              device=params.device,
              project=params.project,
              name=params.name,
              exist_ok=params.exist_ok,
          )
          logging.info("模型训练完成")
          # NOTE(review): depending on the ultralytics version/settings train()
          # can return None (no validation metrics); fall back to the trainer's
          # run directory instead of failing a finished training run.
          save_dir = getattr(results, "save_dir", None)
          if save_dir is None:
              save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
          return {
              "code": 0,
              "msg": "success",
              "result": str(save_dir) if save_dir is not None else None,
          }
      except Exception as e:
          logging.exception(f"训练过程发生异常: {e}")
          return {
              "code": 1,
              "msg": str(e),
              "result": None,
          }
  class PredictParams(BaseModel):
      """
      Prediction parameters for the /yolov12/predict endpoint; names mirror YOLO.predict() arguments.

      Fields defaulting to None are declared Optional so a client may also send
      an explicit null (pydantic v2 rejects null for a plain ``str``/``int``/``list``).
      """
      model: str = "yolov12m.pt"  # model weights path
      source: Optional[str] = None  # input source (image/video path, URL, ...)
      stream: bool = False  # streaming mode (predict() then returns a generator)
      conf: float = 0.25  # confidence threshold
      iou: float = 0.7  # IoU threshold for NMS
      max_det: int = 300  # maximum detections per image
      imgsz: int = 640  # inference image size
      batch: int = 1  # batch size
      device: str = ""  # device
      show: bool = False  # display results
      save: bool = False  # save results (the endpoint forces this on)
      save_txt: bool = False  # save results as txt files
      save_conf: bool = False  # include confidences in saved txt files
      save_crop: bool = False  # save cropped detections
      show_labels: bool = True  # draw labels
      show_conf: bool = True  # draw confidences
      show_boxes: bool = True  # draw bounding boxes
      line_width: Optional[int] = None  # box line width
      vid_stride: int = 1  # video frame stride
      stream_buffer: bool = False  # buffer stream frames
      visualize: bool = False  # visualize model features
      augment: bool = False  # test-time augmentation
      agnostic_nms: bool = False  # class-agnostic NMS
      classes: Optional[list] = None  # restrict to these class ids
      retina_masks: bool = False  # high-resolution segmentation masks
      embed: Optional[list] = None  # layers to return feature embeddings from
      half: bool = False  # half precision
      dnn: bool = False  # use OpenCV DNN
      project: str = ""  # project name
      name: str = ""  # experiment name
      exist_ok: bool = False  # allow reusing an existing output directory
      verbose: bool = True  # verbose output
  def _find_predict_output(save_dir, source):
      """Return the name of the media file in ``save_dir`` holding the prediction for ``source``.

      Prefers the most recently modified media file whose stem equals the stem
      of ``source``; otherwise returns the most recently modified media file.
      Returns None when ``save_dir`` contains no media file.
      """
      exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
      matched_files = [f for ext in exts for f in glob.glob(os.path.join(save_dir, ext))]
      if not matched_files:
          return None
      # Newest first: with exist_ok=True earlier runs share the directory, and
      # the file just written is the one the caller wants.
      matched_files.sort(key=os.path.getmtime, reverse=True)
      base_name = os.path.splitext(os.path.basename(source))[0] if source else None
      if base_name:
          for f in matched_files:
              if os.path.splitext(os.path.basename(f))[0] == base_name:
                  final_filename = os.path.basename(f)
                  logging.info(f"按输入文件名查找,返回文件: {final_filename}")
                  return final_filename
      final_filename = os.path.basename(matched_files[0])
      logging.info(f"未找到同名,返回最新文件: {final_filename}")
      return final_filename


  @app_fastapi.post("/yolov12/predict")
  def yolov12_predict(params: PredictParams):
      """
      RESTful POST endpoint: /yolov12/predict

      Run YOLO.predict() with the request parameters (``save`` is always forced
      on) and return the path of the saved output file.

      Returns:
          {"code": 0, "msg": "success", "result": "<save_dir>/<file name>"} on
          success, {"code": 1, "msg": <error>, "result": None} on failure.
      """
      logging.info("收到/yolov12/predict预测请求")
      logging.info(f"请求参数: {params}")
      try:
          model = YOLO(params.model)
          logging.info("开始模型预测...")
          # Forward every non-None field except the model path to predict().
          predict_kwargs = {
              field: value
              for field, value in params.dict().items()
              if field != 'model' and value is not None
          }
          # Always save, so there is an output file to return.
          predict_kwargs['save'] = True
          results = model.predict(**predict_kwargs)
          # With stream=True predict() returns a generator; consuming it runs
          # the prediction (and writes the saved files).
          if not isinstance(results, list):
              results = list(results)
          logging.info("模型预测完成")
          if not results:
              return {"code": 1, "msg": "预测未产生任何结果", "result": None}

          save_dir = getattr(results[0], 'save_dir', None)
          final_filename = _find_predict_output(save_dir, params.source) if save_dir else None
          if not final_filename:
              return {"code": 1, "msg": f"未找到预测输出文件: {save_dir}", "result": None}
          return {
              "code": 0,
              "msg": "success",
              "result": os.path.join(save_dir, final_filename),
          }
      except Exception as e:
          logging.exception(f"预测过程发生异常: {e}")
          return {
              "code": 1,
              "msg": str(e),
              "result": None,
          }
  class StreamParams(BaseModel):
      """
      Parameters for the /yolov12/stream endpoint.

      model: inference weights path
      source: pull URL (e.g. an RTSP/HTTP video stream)
      stream_url: push URL (e.g. an RTMP ingest URL); may be omitted when save_local is true
      fps: output frame rate, used when the source does not report one
      bitrate: output video bitrate
      save_local: write a local MP4 file instead of pushing (for debugging)
      skip_connectivity_test: skip the RTMP reachability check before pushing
      The remaining fields are passed to predict().

      source/stream_url are Optional so an explicit null is accepted by
      pydantic v2; the endpoint and worker validate them.
      """
      model: str = "yolov12m.pt"
      source: Optional[str] = None
      stream_url: Optional[str] = None
      fps: int = 25
      bitrate: str = "6000k"  # higher default bitrate
      conf: float = 0.25
      iou: Optional[float] = 0.7
      imgsz: int = 640
      device: str = ""
      save_local: bool = False  # save to a local file instead of pushing
      skip_connectivity_test: bool = True  # skip the connectivity test
      # More predict() parameters can be added here as needed.
  def _probe_push_url(stream_url, task_id):
      """Check that an RTMP push URL is reachable before streaming to it.

      Opens a TCP connection to the server (port 1935 when the URL has none),
      then runs ``ffprobe`` against the URL. Returns True only if both succeed.
      """
      import socket
      import subprocess
      from urllib.parse import urlsplit

      try:
          print(f"任务 {task_id} 开始测试推流地址: {stream_url}")
          parts = urlsplit(stream_url)
          server_host = parts.hostname or ''
          server_port = str(parts.port or 1935)
          print(f"任务 {task_id} 解析的服务器信息: {server_host}:{server_port}")
          try:
              # create_connection resolves the host and supports IPv4 and IPv6.
              with socket.create_connection((server_host, int(server_port)), timeout=5):
                  pass
          except socket.gaierror as net_error:
              print(f"任务 {task_id} 网络测试异常: {net_error}")
              return False
          except OSError:
              print(f"任务 {task_id} 网络连通性测试失败: {server_host}:{server_port}")
              return False
          print(f"任务 {task_id} 网络连通性测试成功: {server_host}:{server_port}")

          test_cmd = [
              'ffprobe',
              '-v', 'error',
              '-print_format', 'json',
              '-show_format',
              '-timeout', '5000000',  # 5 s timeout (microseconds)
              stream_url,
          ]
          print(f"任务 {task_id} 执行ffprobe命令: {' '.join(test_cmd)}")
          result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=10)
          if result.returncode == 0:
              print(f"任务 {task_id} ffprobe测试成功")
              return True
          print(f"任务 {task_id} ffprobe测试失败,返回码: {result.returncode}")
          print(f"任务 {task_id} ffprobe错误输出: {result.stderr}")
          return False
      except subprocess.TimeoutExpired:
          print(f"任务 {task_id} ffprobe测试超时")
          return False
      except Exception as e:
          print(f"任务 {task_id} 推流地址测试异常: {e}")
          return False


  def _build_ffmpeg_commands(width, height, output_fps, bitrate, output_format, output_file):
      """Return ``(simple_cmd, advanced_cmd)`` ffmpeg argument lists.

      Both read raw BGR frames of ``width`` x ``height`` from stdin at
      ``output_fps`` and encode them with libx264 into ``output_file`` using
      container ``output_format``. The simple command favours speed
      (``ultrafast``); the advanced one uses the ``medium`` preset with
      CBR-style rate control and a fixed GOP.
      """
      input_args = [
          'ffmpeg',
          '-f', 'rawvideo',
          '-pix_fmt', 'bgr24',
          '-s', f'{width}x{height}',
          '-r', str(output_fps),
          '-i', '-',  # frames arrive on stdin
      ]
      # For bgr24 input ffmpeg otherwise picks yuv444p for libx264, which
      # '-profile:v high' rejects and most players/stream servers cannot decode.
      # yuv420p needs even dimensions, hence the rounding scale filter.
      pixel_args = [
          '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2',
          '-pix_fmt', 'yuv420p',
      ]
      output_args = [
          '-loglevel', 'error',  # only report errors
          '-f', output_format,
          '-y',  # overwrite an existing output file
          output_file,
      ]
      gop = max(1, int(output_fps))
      simple_cmd = input_args + [
          '-c:v', 'libx264',
          '-preset', 'ultrafast',
          '-tune', 'zerolatency',
          '-b:v', bitrate,
      ] + pixel_args + output_args
      advanced_cmd = input_args + [
          '-c:v', 'libx264',
          '-preset', 'medium',  # balance quality and speed
          '-tune', 'zerolatency',
          '-profile:v', 'high',
          '-level', '4.1',
          '-b:v', bitrate,
          '-maxrate', bitrate,
          '-bufsize', '8000k',
          '-g', str(gop * 2),  # GOP of two frame-rate periods
          '-keyint_min', str(gop),  # minimum keyframe interval
          '-sc_threshold', '0',  # no extra keyframes on scene cuts
          '-bf', '3',  # B-frames
          '-refs', '6',  # reference frames
          '-x264opts', 'no-scenecut=1:nal-hrd=cbr:force-cfr=1',
          '-color_primaries', 'bt709',
          '-color_trc', 'bt709',
          '-colorspace', 'bt709',
      ] + pixel_args + output_args
      return simple_cmd, advanced_cmd


  def yolov12_stream_worker(params_dict, task_id):
      """Pull ``source``, run YOLO on every frame and push the annotated frames
      through an ffmpeg subprocess to ``stream_url`` (or a local MP4 file).

      Runs inside a ProcessPoolExecutor worker. The worker's PID is written to
      ``/tmp/yolov12_task_<task_id>.pid`` so the API process can send SIGTERM;
      on SIGTERM the worker releases the capture, stops ffmpeg and raises
      SystemExit.

      Args:
          params_dict: ``StreamParams`` as a dict (model, source, stream_url,
              fps, bitrate, conf, iou, imgsz, device, save_local,
              skip_connectivity_test).
          task_id: Task identifier used in log lines, the PID file name and
              the local output file name.

      Returns:
          ``{"code": 0, "msg": "success", "result": {"frames_processed": int,
          "avg_fps": float}}`` when the source ends, otherwise
          ``{"code": 1, "msg": <reason>, "result": None}``.
      """
      import os
      import signal
      import subprocess
      import time

      import cv2
      from ultralytics import YOLO

      current_pid = os.getpid()
      print(f"任务 {task_id} worker进程启动,PID: {current_pid}")
      # The API process reads this file to learn which pool worker to signal.
      pid_file = f"/tmp/yolov12_task_{task_id}.pid"
      try:
          with open(pid_file, 'w') as f:
              f.write(str(current_pid))
      except Exception as e:
          print(f"任务 {task_id} 写入PID文件失败: {e}")

      def handle_sigterm(signum, frame):
          # Raising SystemExit unwinds through the finally blocks below, which
          # release the capture and stop ffmpeg; this handler's own frame cannot
          # see those objects.
          print(f"任务 {task_id} 收到终止信号,准备退出")
          try:
              os.remove(pid_file)
          except OSError:
              pass
          raise SystemExit(0)

      model_path = params_dict['model']
      source = params_dict['source']
      stream_url = params_dict['stream_url']
      fps = params_dict.get('fps', 25)
      bitrate = params_dict.get('bitrate', '2000k')
      conf = params_dict.get('conf', 0.25)
      iou = params_dict.get('iou', 0.7)
      imgsz = params_dict.get('imgsz', 640)
      device = params_dict.get('device', '')
      save_local = params_dict.get('save_local', False)
      skip_connectivity_test = params_dict.get('skip_connectivity_test', False)

      # Bound before the try so the error report below never hits an unbound name.
      cap = None
      width = height = output_fps = None
      ffmpeg_process = None
      ffmpeg_cmd = simple_ffmpeg_cmd = advanced_ffmpeg_cmd = None
      max_retries = 3
      retry_count = 0

      def cleanup_process():
          """Stop the current ffmpeg process and return its stderr text.

          Closing stdin sends EOF so ffmpeg can flush and finalise the output;
          terminate()/kill() are fallbacks. Safe to call repeatedly.
          """
          nonlocal ffmpeg_process
          proc, ffmpeg_process = ffmpeg_process, None
          if proc is None:
              return ""
          try:
              try:
                  if proc.stdin and not proc.stdin.closed:
                      proc.stdin.close()
              except OSError:
                  pass  # broken pipe: ffmpeg is already gone
              try:
                  proc.wait(timeout=5)
              except subprocess.TimeoutExpired:
                  proc.terminate()
                  try:
                      proc.wait(timeout=5)
                  except subprocess.TimeoutExpired:
                      proc.kill()
                      proc.wait()
              # The process has exited, so reading stderr to EOF cannot block.
              if proc.stderr is None:
                  return ""
              with proc.stderr:
                  return proc.stderr.read().decode(errors="replace")
          except Exception as e:
              print(f"清理ffmpeg进程时出错: {e}")
              return ""

      def start_ffmpeg():
          """Start ffmpeg with ``ffmpeg_cmd``; return True on success.

          After a first failure of the simple command it switches to the
          advanced command; otherwise it retries up to ``max_retries`` times.
          """
          nonlocal ffmpeg_process, retry_count, ffmpeg_cmd
          while True:
              try:
                  print(f"任务 {task_id} 启动ffmpeg命令: {' '.join(ffmpeg_cmd)}")
                  ffmpeg_process = subprocess.Popen(
                      ffmpeg_cmd,
                      stdin=subprocess.PIPE,
                      stdout=subprocess.DEVNULL,  # output goes to the file/URL
                      stderr=subprocess.PIPE,
                  )
                  # Give ffmpeg a moment to fail fast on a bad command/URL.
                  time.sleep(1)
                  if ffmpeg_process.poll() is not None:
                      stderr_output = cleanup_process() or "未知错误"
                      raise Exception(f"ffmpeg启动失败: {stderr_output}")
                  retry_count = 0
                  return True
              except Exception as e:
                  print(f"任务 {task_id} 启动ffmpeg失败: {e}")
                  cleanup_process()
                  retry_count += 1
                  if retry_count == 1 and ffmpeg_cmd == simple_ffmpeg_cmd:
                      print(f"任务 {task_id} 简单配置失败,尝试使用高级配置")
                      ffmpeg_cmd = advanced_ffmpeg_cmd
                      continue
                  if retry_count < max_retries:
                      print(f"任务 {task_id} 尝试重试 ({retry_count}/{max_retries})")
                      time.sleep(2)
                      continue
                  print(f"任务 {task_id} ffmpeg启动失败,已达到最大重试次数")
                  return False

      previous_sigterm = signal.signal(signal.SIGTERM, handle_sigterm)
      try:
          if not source:
              return {"code": 1, "msg": "source(拉流地址)不能为空", "result": None}
          if not stream_url and not save_local:
              return {"code": 1, "msg": "stream_url(推流地址)不能为空,或设置save_local=true保存到本地", "result": None}

          # Only an RTMP push target that is actually pushed to gets probed.
          is_rtmp = bool(stream_url) and stream_url.startswith('rtmp://')
          if is_rtmp and not save_local and not skip_connectivity_test:
              print(f"任务 {task_id} 测试推流地址连通性...")
              if not _probe_push_url(stream_url, task_id):
                  print(f"任务 {task_id} 推流地址不可达: {stream_url}")
                  return {
                      "code": 1,
                      "msg": f"推流地址不可达: {stream_url}。可能的原因:1. 推流服务器未启动或不可访问 2. 网络连接问题 3. 防火墙阻止连接 4. 推流地址格式错误。建议解决方案:1. 检查推流服务器状态 2. 验证网络连接 3. 使用skip_connectivity_test=true跳过连通性测试 4. 使用save_local=true先保存到本地文件测试",
                      "result": None,
                  }
          elif skip_connectivity_test:
              print(f"任务 {task_id} 跳过连通性测试,直接开始推流")

          model = YOLO(model_path)
          cap = cv2.VideoCapture(source)
          if not cap.isOpened():
              return {"code": 1, "msg": f"无法打开视频流: {source}", "result": None}

          fps_cap = cap.get(cv2.CAP_PROP_FPS)
          width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
          height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
          # Prefer the source's own frame rate; fall back to the requested fps.
          output_fps = fps_cap if fps_cap > 0 else fps

          if save_local:
              output_file = f"output_{task_id}.mp4"
              output_format = 'mp4'
          else:
              output_file = stream_url
              output_format = 'flv' if stream_url.startswith('rtmp') else 'mpegts'

          simple_ffmpeg_cmd, advanced_ffmpeg_cmd = _build_ffmpeg_commands(
              width, height, output_fps, bitrate, output_format, output_file,
          )
          # Start with the fast configuration; start_ffmpeg() falls back to the
          # advanced one if that fails.
          ffmpeg_cmd = simple_ffmpeg_cmd
          if not start_ffmpeg():
              return {"code": 1, "msg": "ffmpeg启动失败,已达到最大重试次数", "result": None}

          # Per-frame predict() arguments that do not depend on the frame.
          predict_kwargs = {'imgsz': imgsz, 'conf': conf, 'device': device}
          if iou is not None:
              predict_kwargs['iou'] = iou

          frame_count = 0
          start_time = time.time()
          try:
              while cap.isOpened():
                  ret, frame = cap.read()
                  if not ret:
                      break

                  try:
                      results = model.predict(source=frame, **predict_kwargs)
                      annotated_frame = results[0].plot()
                  except Exception as predict_error:
                      print(f"任务 {task_id} YOLO推理出错: {predict_error}")
                      annotated_frame = frame  # fall back to the raw frame

                  # ffmpeg was given the raw frame size up front; a frame of any
                  # other size would desynchronise the byte stream.
                  if width > 0 and height > 0 and annotated_frame.shape[:2] != (height, width):
                      annotated_frame = cv2.resize(annotated_frame, (width, height))

                  try:
                      if ffmpeg_process.poll() is not None:
                          stderr_output = cleanup_process() or "未知错误"
                          print(f"任务 {task_id} ffmpeg进程已退出: {stderr_output}")
                          print(f"任务 {task_id} 尝试重启ffmpeg进程...")
                          if start_ffmpeg():
                              print(f"任务 {task_id} ffmpeg进程重启成功,继续处理")
                              continue
                          print(f"任务 {task_id} ffmpeg进程重启失败,停止处理")
                          break

                      ffmpeg_process.stdin.write(annotated_frame.tobytes())
                      ffmpeg_process.stdin.flush()
                      frame_count += 1
                      if frame_count % 100 == 0:
                          elapsed_time = time.time() - start_time
                          current_fps = frame_count / elapsed_time
                          print(f"任务 {task_id} 已处理 {frame_count} 帧,当前FPS: {current_fps:.2f}")
                  except OSError as e:
                      print(f"任务 {task_id} 写入ffmpeg时出错: {e}")
                      # Stop ffmpeg before collecting stderr: reading the stderr
                      # of a still-running process blocks until it exits.
                      stderr_output = cleanup_process()
                      if stderr_output:
                          print(f"任务 {task_id} ffmpeg错误输出: {stderr_output}")
                      print(f"任务 {task_id} 尝试重启ffmpeg进程...")
                      if start_ffmpeg():
                          print(f"任务 {task_id} ffmpeg进程重启成功,继续处理")
                          continue
                      print(f"任务 {task_id} ffmpeg进程重启失败,停止处理")
                      break
          except KeyboardInterrupt:
              print(f"任务 {task_id} 收到中断信号,停止处理")
          finally:
              cap.release()
              cleanup_process()

          elapsed_time = time.time() - start_time
          avg_fps = frame_count / elapsed_time if elapsed_time > 0 else 0
          print(f"任务 {task_id} 推理并推流完成,共处理帧数: {frame_count},平均FPS: {avg_fps:.2f}")
          return {"code": 0, "msg": "success", "result": {"frames_processed": frame_count, "avg_fps": avg_fps}}
      except Exception as e:
          print(f"任务 {task_id} 发生异常: {e}")
          cleanup_process()
          error_msg = str(e)
          if "Broken pipe" in error_msg:
              error_msg += " - 这通常表示推流地址不可达或网络连接问题,建议:1) 检查推流服务器是否运行 2) 检查网络连接 3) 使用save_local=true先保存到本地文件测试"
          elif "Connection refused" in error_msg:
              error_msg += " - 推流服务器拒绝连接,请检查推流地址是否正确"
          elif "Permission denied" in error_msg:
              error_msg += " - 权限不足,请检查推流地址的访问权限"
          print(f"任务 {task_id} 调试信息:")
          print(f" - 输入源: {source}")
          print(f" - 推流地址: {stream_url}")
          print(f" - 保存到本地: {save_local}")
          print(f" - 视频尺寸: {width}x{height}")
          print(f" - 帧率: {output_fps}")
          print(f" - 码率: {bitrate}")
          return {"code": 1, "msg": error_msg, "result": None}
      finally:
          # Idempotent final cleanup (also runs on SystemExit from SIGTERM).
          if cap is not None:
              cap.release()
          cleanup_process()
          # Pool workers are reused: do not leave this task's handler installed.
          if previous_sigterm is not None:
              signal.signal(signal.SIGTERM, previous_sigterm)
  @app_fastapi.post("/yolov12/stream")
  async def yolov12_stream_async(params: StreamParams):
      """
      RESTful POST endpoint: /yolov12/stream

      Submit a streaming task (pull ``source``, run YOLO, push the annotated
      video to ``stream_url`` with ffmpeg) to the process pool and return its
      task id without waiting for it to finish. Running tasks can be stopped
      through /yolov12/stream/cancel.

      Returns:
          {"code": 0, "msg": "任务已提交", "result": <task_id>} on success,
          {"code": 1, "msg": <error>, "result": None} otherwise.
      """
      logging.info("收到/yolov12/stream请求")
      logging.info(f"请求参数: {params}")
      # Reject requests the worker cannot serve before occupying a pool slot.
      if not params.source:
          return {"code": 1, "msg": "source(拉流地址)不能为空", "result": None}
      if not params.stream_url and not params.save_local:
          return {"code": 1, "msg": "stream_url(推流地址)不能为空,或设置save_local=true保存到本地", "result": None}

      task_id = str(uuid.uuid4())
      try:
          loop = asyncio.get_running_loop()
          ensure_executor_healthy()
          params_dict = params.dict()
          try:
              future = loop.run_in_executor(
                  stream_executor, partial(yolov12_stream_worker, params_dict, task_id)
              )
          except Exception as pool_error:
              # Submission fails if the pool is broken or shut down; retry once
              # on a fresh pool.
              logging.warning(f"进程池异常,尝试重新创建: {pool_error}")
              create_new_executor()
              future = loop.run_in_executor(
                  stream_executor, partial(yolov12_stream_worker, params_dict, task_id)
              )

          def cleanup_task(fut):
              """Done-callback: log the outcome and drop the task's bookkeeping."""
              try:
                  if fut.cancelled():
                      print(f"任务 {task_id} 已被取消")
                  else:
                      # result() re-raises an exception raised by the worker.
                      result = fut.result()
                      print(f"任务 {task_id} 已完成,结果: {result}")
              except asyncio.CancelledError:
                  print(f"任务 {task_id} 被取消")
              except BaseException as e:
                  # The worker may end with SystemExit (SIGTERM handler).
                  print(f"任务 {task_id} 执行异常: {e}")
              finally:
                  stream_tasks.pop(task_id, None)
                  stream_pids.pop(task_id, None)
                  try:
                      os.remove(f"/tmp/yolov12_task_{task_id}.pid")
                  except OSError:
                      pass  # already removed by the worker / cancel endpoint
                  print(f"已清理任务: {task_id}")

          future.add_done_callback(cleanup_task)
          stream_tasks[task_id] = {
              "future": future,
              "start_time": int(time.time() * 1000),  # epoch milliseconds
              "source": params.source,
              "stream_url": params.stream_url,
          }

          # Poll up to ~5 s (50 x 0.1 s) for the worker to publish its PID. A task
          # queued behind busy workers will not have one yet; the cancel endpoint
          # re-reads the PID file in that case.
          pid_file = f"/tmp/yolov12_task_{task_id}.pid"
          for _ in range(50):
              if os.path.exists(pid_file):
                  try:
                      with open(pid_file, 'r') as f:
                          worker_pid = int(f.read().strip())
                      stream_pids[task_id] = worker_pid
                      logging.info(f"任务 {task_id} worker进程PID: {worker_pid}")
                      break
                  except Exception as e:
                      # e.g. the worker has created but not yet written the file
                      logging.warning(f"读取PID文件失败: {e}")
              await asyncio.sleep(0.1)
          else:
              logging.warning(f"任务 {task_id} 无法获取worker进程PID")

          logging.info(f"任务 {task_id} 已提交到进程池")
          return {
              "code": 0,
              "msg": "任务已提交",
              "result": task_id,
          }
      except Exception as e:
          logging.exception(f"提交任务时发生异常: {e}")
          return {
              "code": 1,
              "msg": str(e),
              "result": None,
          }
  @app_fastapi.post("/yolov12/stream/cancel")
  async def cancel_stream_task(task_id: str):
      """
      RESTful POST endpoint: /yolov12/stream/cancel

      Cancel a stream task: cancel its future (which prevents a still-queued job
      from starting) and send SIGTERM to the pool worker running it, whose
      handler stops the stream.

      Returns:
          {"code": 0/1, "msg": "任务已取消" or <reason>, "result": None}
      """
      logging.info(f"收到取消任务请求: {task_id}")
      cleanup_completed_tasks()
      task_info = stream_tasks.get(task_id)
      if not task_info:
          return {"code": 1, "msg": "任务不存在", "result": None}
      if task_info["future"].done():
          return {"code": 1, "msg": "任务已完成,无法取消", "result": None}

      pid_file = f"/tmp/yolov12_task_{task_id}.pid"
      try:
          cancelled = task_info["future"].cancel()

          worker_pid = stream_pids.get(task_id)
          if worker_pid is None:
              # The submit endpoint waits only ~5 s for the PID; a task that
              # started later has it only in its PID file.
              try:
                  with open(pid_file) as f:
                      worker_pid = int(f.read().strip())
              except (OSError, ValueError):
                  worker_pid = None

          if worker_pid:
              try:
                  logging.info(f"终止worker进程: {worker_pid}")
                  # SIGTERM lets the worker's handler release its resources.
                  # No wait for the process to exit: a pool worker survives the
                  # SystemExit raised by that handler (the pool reports it as
                  # the task's exception), so its PID stays alive. Waiting with
                  # a blocking sleep would also stall the event loop.
                  os.kill(worker_pid, signal.SIGTERM)
              except OSError as e:
                  logging.warning(f"终止进程时出错: {e}")

          if not cancelled:
              return {"code": 1, "msg": "任务无法取消(可能正在运行)", "result": None}

          logging.info(f"任务 {task_id} 已取消")
          # Bookkeeping is left to the future's done-callback; only the PID file
          # is removed here.
          try:
              os.remove(pid_file)
          except FileNotFoundError:
              pass
          except OSError as e:
              logging.warning(f"删除PID文件失败 {pid_file}: {e}")
          return {"code": 0, "msg": "任务已取消", "result": None}
      except Exception as e:
          logging.exception(f"取消任务时发生异常: {e}")
          return {"code": 1, "msg": str(e), "result": None}
  868. @app_fastapi.get("/yolov12/stream/status")
  869. async def get_stream_status(task_id: str):
  870. """
  871. RESTful GET接口:/yolov12/stream/status
  872. 查询指定任务的状态
  873. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"status": "状态", "result": "结果"}}
  874. """
  875. logging.info(f"收到查询任务状态请求: {task_id}")
  876. # 先清理已完成的任务
  877. cleanup_completed_tasks()
  878. task_info = stream_tasks.get(task_id)
  879. if not task_info:
  880. return {"code": 1, "msg": "任务不存在", "result": None}
  881. try:
  882. if task_info["future"].done():
  883. try:
  884. # 检查是否被取消
  885. if task_info["future"].cancelled():
  886. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  887. return {
  888. "code": 0,
  889. "msg": "已取消",
  890. "result": {
  891. "status": "cancelled",
  892. "start_time": task_info["start_time"],
  893. "run_time": round(run_time, 2),
  894. "source": task_info["source"],
  895. "stream_url": task_info["stream_url"]
  896. }
  897. }
  898. else:
  899. result = task_info["future"].result()
  900. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  901. return {
  902. "code": 0,
  903. "msg": "已完成",
  904. "result": {
  905. "status": "completed",
  906. "result": result,
  907. "start_time": task_info["start_time"],
  908. "run_time": round(run_time, 2),
  909. "source": task_info["source"],
  910. "stream_url": task_info["stream_url"]
  911. }
  912. }
  913. except asyncio.CancelledError:
  914. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  915. return {
  916. "code": 0,
  917. "msg": "已取消",
  918. "result": {
  919. "status": "cancelled",
  920. "start_time": task_info["start_time"],
  921. "run_time": round(run_time, 2),
  922. "source": task_info["source"],
  923. "stream_url": task_info["stream_url"]
  924. }
  925. }
  926. except Exception as e:
  927. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  928. return {
  929. "code": 1,
  930. "msg": f"任务异常: {e}",
  931. "result": {
  932. "status": "failed",
  933. "error": str(e),
  934. "start_time": task_info["start_time"],
  935. "run_time": round(run_time, 2),
  936. "source": task_info["source"],
  937. "stream_url": task_info["stream_url"]
  938. }
  939. }
  940. elif task_info["future"].cancelled():
  941. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  942. return {
  943. "code": 0,
  944. "msg": "已取消",
  945. "result": {
  946. "status": "cancelled",
  947. "start_time": task_info["start_time"],
  948. "run_time": round(run_time, 2),
  949. "source": task_info["source"],
  950. "stream_url": task_info["stream_url"]
  951. }
  952. }
  953. else:
  954. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  955. return {
  956. "code": 0,
  957. "msg": "运行中",
  958. "result": {
  959. "status": "running",
  960. "start_time": task_info["start_time"],
  961. "run_time": round(run_time, 2),
  962. "source": task_info["source"],
  963. "stream_url": task_info["stream_url"]
  964. }
  965. }
  966. except Exception as e:
  967. logging.error(f"查询任务状态时发生异常: {e}")
  968. return {"code": 1, "msg": str(e), "result": None}
  969. @app_fastapi.get("/yolov12/stream/list")
  970. async def list_stream_tasks():
  971. """
  972. RESTful GET接口:/yolov12/stream/list
  973. 列出所有任务的状态
  974. 返回格式:{"code": 0, "msg": "success", "result": {"tasks": [{"task_id": "ID", "status": "状态"}]}}
  975. """
  976. logging.info("收到查询所有任务请求")
  977. # 先清理已完成的任务
  978. cleanup_completed_tasks()
  979. try:
  980. tasks_info = []
  981. for task_id, task_info in stream_tasks.items():
  982. if task_info["future"].done():
  983. try:
  984. if task_info["future"].cancelled():
  985. status = "cancelled"
  986. else:
  987. task_info["future"].result() # 检查是否有异常
  988. status = "completed"
  989. except asyncio.CancelledError:
  990. status = "cancelled"
  991. except:
  992. status = "failed"
  993. else:
  994. status = "running"
  995. # 计算运行时长
  996. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  997. tasks_info.append({
  998. "task_id": task_id,
  999. "status": status,
  1000. "start_time": task_info["start_time"],
  1001. "run_time": round(run_time, 2), # 运行时长(秒)
  1002. "source": task_info["source"], # 播流地址
  1003. "stream_url": task_info["stream_url"] # 推流地址
  1004. })
  1005. return {
  1006. "code": 0,
  1007. "msg": "success",
  1008. "result": tasks_info
  1009. }
  1010. except Exception as e:
  1011. logging.error(f"查询所有任务时发生异常: {e}")
  1012. return {"code": 1, "msg": str(e), "result": None}
  1013. # 全局异常处理器:参数校验失败时统一返回格式
  1014. @app_fastapi.exception_handler(RequestValidationError)
  1015. async def validation_exception_handler(request, exc):
  1016. err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
  1017. logging.error(err_msg)
  1018. return JSONResponse(
  1019. status_code=status.HTTP_200_OK,
  1020. content={
  1021. "code": 422,
  1022. "msg": err_msg,
  1023. "result": None
  1024. }
  1025. )
  1026. def check_executor_health():
  1027. """检查进程池执行器状态"""
  1028. global stream_executor
  1029. try:
  1030. # 尝试提交一个简单的测试任务
  1031. test_future = stream_executor.submit(lambda: 1)
  1032. test_future.result(timeout=1)
  1033. test_future.cancel()
  1034. return True
  1035. except Exception as e:
  1036. logging.warning(f"进程池健康检查失败: {e}")
  1037. return False
  1038. def ensure_executor_healthy():
  1039. """确保进程池执行器健康"""
  1040. if not check_executor_health():
  1041. logging.warning("进程池不健康,重新创建")
  1042. create_new_executor()
  1043. return False
  1044. return True
  1045. if __name__ == "__main__":
  1046. threading.Thread(target=start_gradio, daemon=True).start()
  1047. uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)