app.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004
  1. # --------------------------------------------------------
  2. # Based on yolov10
  3. # https://github.com/THU-MIG/yolov10/app.py
  4. # --------------------------------------------------------'
  5. import logging
  6. import tempfile
  7. import threading
  8. import cv2
  9. import gradio as gr
  10. import uvicorn
  11. import asyncio
  12. from fastapi import FastAPI
  13. from fastapi import status
  14. from fastapi.exceptions import RequestValidationError
  15. from fastapi.responses import JSONResponse
  16. from pydantic import BaseModel
  17. from ultralytics import YOLO
  18. import os
  19. import glob
  20. import subprocess
  21. import signal
  22. import time
  23. from typing import Optional, Dict
  24. from concurrent.futures import ProcessPoolExecutor, Future
  25. from functools import partial
  26. import uuid
  27. # 设置日志格式和级别
  28. logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
  29. # 创建进程池执行器和任务管理
  30. stream_executor = ProcessPoolExecutor(max_workers=4)
  31. stream_tasks: Dict[str, Dict] = {} # 存储任务信息:{task_id: {"future": Future, "start_time": long, "source": str, "stream_url": str}}
  32. stream_pids: Dict[str, int] = {} # 记录每个任务的worker进程pid
  33. def cleanup_completed_tasks():
  34. """清理已完成的任务"""
  35. completed_tasks = []
  36. for task_id, task_info in stream_tasks.items():
  37. if task_info["future"].done():
  38. completed_tasks.append(task_id)
  39. for task_id in completed_tasks:
  40. del stream_tasks[task_id]
  41. if task_id in stream_pids:
  42. del stream_pids[task_id]
  43. print(f"已清理完成的任务: {task_id}")
  44. def yolov12_inference(image, video, model_id, image_size, conf_threshold):
  45. model = YOLO(model_id)
  46. if image:
  47. results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
  48. annotated_image = results[0].plot()
  49. return annotated_image[:, :, ::-1], None
  50. else:
  51. video_path = tempfile.mktemp(suffix=".webm")
  52. with open(video_path, "wb") as f:
  53. with open(video, "rb") as g:
  54. f.write(g.read())
  55. cap = cv2.VideoCapture(video_path)
  56. fps = cap.get(cv2.CAP_PROP_FPS)
  57. frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  58. frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  59. output_video_path = tempfile.mktemp(suffix=".webm")
  60. out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))
  61. while cap.isOpened():
  62. ret, frame = cap.read()
  63. if not ret:
  64. break
  65. results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
  66. annotated_frame = results[0].plot()
  67. out.write(annotated_frame)
  68. cap.release()
  69. out.release()
  70. return None, output_video_path
  71. def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
  72. annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
  73. return annotated_image
  74. def app():
  75. with gr.Blocks():
  76. with gr.Row():
  77. with gr.Column():
  78. image = gr.Image(type="pil", label="Image", visible=True)
  79. video = gr.Video(label="Video", visible=False)
  80. input_type = gr.Radio(
  81. choices=["Image", "Video"],
  82. value="Image",
  83. label="Input Type",
  84. )
  85. model_id = gr.Dropdown(
  86. label="Model",
  87. choices=[
  88. "yolov12n.pt",
  89. "yolov12s.pt",
  90. "yolov12m.pt",
  91. "yolov12l.pt",
  92. "yolov12x.pt",
  93. ],
  94. value="yolov12m.pt",
  95. )
  96. image_size = gr.Slider(
  97. label="Image Size",
  98. minimum=320,
  99. maximum=1280,
  100. step=32,
  101. value=640,
  102. )
  103. conf_threshold = gr.Slider(
  104. label="Confidence Threshold",
  105. minimum=0.0,
  106. maximum=1.0,
  107. step=0.05,
  108. value=0.25,
  109. )
  110. yolov12_infer = gr.Button(value="Detect Objects")
  111. with gr.Column():
  112. output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
  113. output_video = gr.Video(label="Annotated Video", visible=False)
  114. def update_visibility(input_type):
  115. image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
  116. video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
  117. output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
  118. output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
  119. return image, video, output_image, output_video
  120. input_type.change(
  121. fn=update_visibility,
  122. inputs=[input_type],
  123. outputs=[image, video, output_image, output_video],
  124. )
  125. def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
  126. if input_type == "Image":
  127. return yolov12_inference(image, None, model_id, image_size, conf_threshold)
  128. else:
  129. return yolov12_inference(None, video, model_id, image_size, conf_threshold)
  130. yolov12_infer.click(
  131. fn=run_inference,
  132. inputs=[image, video, model_id, image_size, conf_threshold, input_type],
  133. outputs=[output_image, output_video],
  134. )
  135. gr.Examples(
  136. examples=[
  137. [
  138. "ultralytics/assets/bus.jpg",
  139. "yolov12s.pt",
  140. 640,
  141. 0.25,
  142. ],
  143. [
  144. "ultralytics/assets/zidane.jpg",
  145. "yolov12x.pt",
  146. 640,
  147. 0.25,
  148. ],
  149. ],
  150. fn=yolov12_inference_for_examples,
  151. inputs=[
  152. image,
  153. model_id,
  154. image_size,
  155. conf_threshold,
  156. ],
  157. outputs=[output_image],
  158. cache_examples='lazy',
  159. )
  160. gradio_app = gr.Blocks()
  161. with gradio_app:
  162. gr.HTML(
  163. """
  164. <h1 style='text-align: center'>
  165. YOLOv12: Attention-Centric Real-Time Object Detectors
  166. </h1>
  167. """)
  168. gr.HTML(
  169. """
  170. <h3 style='text-align: center'>
  171. <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
  172. </h3>
  173. """)
  174. with gr.Row():
  175. with gr.Column():
  176. app()
  177. def start_gradio():
  178. gradio_app.launch(server_name="0.0.0.0", server_port=7860)
  179. # FastAPI部分
  180. app_fastapi = FastAPI()
  181. class TrainParams(BaseModel):
  182. """
  183. 用于接收/yolov12/train接口的训练参数,所有参数均需前端传入。
  184. """
  185. model: str # 训练底模
  186. data: str # 数据集配置文件路径
  187. epochs: int # 训练轮数
  188. batch: int # 批次大小
  189. imgsz: int # 输入图片尺寸
  190. scale: float # 随机缩放增强比例
  191. mosaic: float # mosaic数据增强概率
  192. mixup: float # mixup数据增强概率
  193. copy_paste: float # copy-paste数据增强概率
  194. device: str # 训练设备
  195. project: str # 工程名
  196. name: str # 实验名
  197. exist_ok: bool # 是否允许覆盖同名目录
  198. @app_fastapi.post("/yolov12/train")
  199. def yolov12_train(params: TrainParams):
  200. """
  201. RESTful POST接口:/yolov12/train
  202. 接收训练参数,调用YOLO模型训练,并返回训练结果。
  203. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 训练结果或None}
  204. """
  205. logging.info("收到/yolov12/train训练请求")
  206. logging.info(f"请求参数: {params}")
  207. try:
  208. # 根据params.model动态确定配置文件
  209. if params.model.endswith('.pt'):
  210. # 如果是.pt文件,将后缀替换为.yaml
  211. config_file = params.model.replace('.pt', '.yaml')
  212. else:
  213. # 如果不是.pt文件,使用默认配置
  214. config_file = "yolov12.yaml"
  215. model = YOLO(config_file)
  216. model.load(params.model)
  217. logging.info("开始模型训练...")
  218. results = model.train(
  219. data=params.data,
  220. epochs=params.epochs,
  221. batch=params.batch,
  222. imgsz=params.imgsz,
  223. scale=params.scale,
  224. mosaic=params.mosaic,
  225. mixup=params.mixup,
  226. copy_paste=params.copy_paste,
  227. device=params.device,
  228. project=params.project,
  229. name=params.name,
  230. exist_ok=params.exist_ok,
  231. )
  232. logging.info("模型训练完成")
  233. # logging.info(f"训练结果: {str(results)}")
  234. return {
  235. "code": 0,
  236. "msg": "success",
  237. "result": str(results.save_dir)
  238. }
  239. except Exception as e:
  240. logging.error(f"训练过程发生异常: {e}")
  241. return {
  242. "code": 1,
  243. "msg": str(e),
  244. "result": None
  245. }
  246. class PredictParams(BaseModel):
  247. """
  248. 用于接收/yolov12/predict接口的预测参数,与YOLO predict方法保持一致。
  249. """
  250. model: str = "yolov12m.pt" # 模型路径
  251. source: str = None # 输入源(图片/视频路径、URL等)
  252. stream: bool = False # 是否流式处理
  253. conf: float = 0.25 # 置信度阈值
  254. iou: float = 0.7 # IoU阈值
  255. max_det: int = 300 # 最大检测数量
  256. imgsz: int = 640 # 输入图片尺寸
  257. batch: int = 1 # 批次大小
  258. device: str = "" # 设备
  259. show: bool = False # 是否显示结果
  260. save: bool = False # 是否保存结果
  261. save_txt: bool = False # 是否保存txt文件
  262. save_conf: bool = False # 是否保存置信度
  263. save_crop: bool = False # 是否保存裁剪图片
  264. show_labels: bool = True # 是否显示标签
  265. show_conf: bool = True # 是否显示置信度
  266. show_boxes: bool = True # 是否显示边界框
  267. line_width: int = None # 线条宽度
  268. vid_stride: int = 1 # 视频帧步长
  269. stream_buffer: bool = False # 流缓冲区
  270. visualize: bool = False # 可视化特征
  271. augment: bool = False # 数据增强
  272. agnostic_nms: bool = False # 类别无关NMS
  273. classes: list = None # 指定类别
  274. retina_masks: bool = False # 高分辨率分割掩码
  275. embed: list = None # 特征向量层
  276. half: bool = False # 半精度
  277. dnn: bool = False # OpenCV DNN
  278. project: str = "" # 项目名
  279. name: str = "" # 实验名
  280. exist_ok: bool = False # 是否覆盖现有目录
  281. verbose: bool = True # 详细输出
  282. @app_fastapi.post("/yolov12/predict")
  283. def yolov12_predict(params: PredictParams):
  284. """
  285. RESTful POST接口:/yolov12/predict
  286. 接收预测参数,调用YOLO模型进行预测,并返回预测结果。
  287. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"save_dir": "保存目录", "filename": "文件名"}}
  288. """
  289. logging.info("收到/yolov12/predict预测请求")
  290. logging.info(f"请求参数: {params}")
  291. try:
  292. model = YOLO(params.model)
  293. logging.info("开始模型预测...")
  294. # 构建预测参数
  295. predict_kwargs = {}
  296. for field, value in params.dict().items():
  297. if field not in ['model'] and value is not None:
  298. predict_kwargs[field] = value
  299. # 确保保存结果,并强制使用MP4格式
  300. predict_kwargs['save'] = True
  301. results = model.predict(**predict_kwargs)
  302. logging.info("模型预测完成")
  303. # 获取保存目录和最终文件名
  304. result = results[0]
  305. save_dir = result.save_dir if hasattr(result, 'save_dir') else None
  306. # 获取最终生成的文件名
  307. final_filename = None
  308. if save_dir:
  309. # 获取输入文件名(不含扩展名)
  310. source = params.source
  311. base_name = None
  312. if source:
  313. base_name = os.path.splitext(os.path.basename(source))[0]
  314. # 支持的扩展名
  315. exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
  316. matched_files = []
  317. for ext in exts:
  318. matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
  319. # 按时间排序,查找与输入文件同名的第一个文件
  320. if base_name and matched_files:
  321. matched_files = sorted(matched_files, key=os.path.getmtime)
  322. for f in matched_files:
  323. if os.path.splitext(os.path.basename(f))[0] == base_name:
  324. final_filename = os.path.basename(f)
  325. logging.info(f"按输入文件名查找,返回文件: {final_filename}")
  326. break
  327. # 如果没找到同名文件,返回最新文件
  328. if not final_filename and matched_files:
  329. latest_file = max(matched_files, key=os.path.getmtime)
  330. final_filename = os.path.basename(latest_file)
  331. logging.info(f"未找到同名,返回最新文件: {final_filename}")
  332. return {
  333. "code": 0,
  334. "msg": "success",
  335. "result": save_dir+"/"+final_filename
  336. }
  337. except Exception as e:
  338. logging.error(f"预测过程发生异常: {e}")
  339. return {
  340. "code": 1,
  341. "msg": str(e),
  342. "result": None
  343. }
  344. class StreamParams(BaseModel):
  345. """
  346. 用于接收 /yolov12/stream 接口的参数。
  347. model: 推理模型路径
  348. source: 拉流地址(如rtsp/http视频流)
  349. stream_url: 推流地址(如rtmp推流地址)
  350. fps: 输出帧率
  351. bitrate: 输出码率
  352. save_local: 是否保存到本地文件(用于调试)
  353. 其他参数同 predict
  354. """
  355. model: str = "yolov12m.pt"
  356. source: str = None
  357. stream_url: str = None
  358. fps: int = 25
  359. bitrate: str = "6000k" # 提高默认码率
  360. conf: float = 0.25
  361. iou: Optional[float] = 0.7
  362. imgsz: int = 640
  363. device: str = ""
  364. save_local: bool = False # 是否保存到本地文件
  365. skip_connectivity_test: bool = True # 是否跳过连通性测试
  366. # 可根据需要补充更多参数
  367. def yolov12_stream_worker(params_dict, task_id):
  368. """
  369. 同步推理函数,在进程池中执行
  370. 支持信号终止
  371. """
  372. import os
  373. import cv2
  374. import time
  375. import subprocess
  376. import signal
  377. from ultralytics import YOLO
  378. # 注册SIGTERM信号处理器
  379. def handle_sigterm(signum, frame):
  380. print(f"任务 {task_id} 收到终止信号,准备退出")
  381. exit(0)
  382. signal.signal(signal.SIGTERM, handle_sigterm)
  383. model_path = params_dict['model']
  384. source = params_dict['source']
  385. stream_url = params_dict['stream_url']
  386. fps = params_dict.get('fps', 25)
  387. bitrate = params_dict.get('bitrate', '2000k')
  388. conf = params_dict.get('conf', 0.25)
  389. iou = params_dict.get('iou', 0.7)
  390. imgsz = params_dict.get('imgsz', 640)
  391. device = params_dict.get('device', '')
  392. # 全局变量用于存储进程引用
  393. ffmpeg_process = None
  394. max_retries = 3
  395. retry_count = 0
  396. def cleanup_process():
  397. """清理ffmpeg进程"""
  398. nonlocal ffmpeg_process
  399. if ffmpeg_process:
  400. try:
  401. ffmpeg_process.terminate()
  402. ffmpeg_process.wait(timeout=5)
  403. except subprocess.TimeoutExpired:
  404. ffmpeg_process.kill()
  405. except Exception as e:
  406. print(f"清理ffmpeg进程时出错: {e}")
  407. def start_ffmpeg():
  408. """启动ffmpeg进程"""
  409. nonlocal ffmpeg_process, retry_count, ffmpeg_cmd
  410. try:
  411. print(f"任务 {task_id} 启动ffmpeg命令: {' '.join(ffmpeg_cmd)}")
  412. ffmpeg_process = subprocess.Popen(
  413. ffmpeg_cmd,
  414. stdin=subprocess.PIPE,
  415. stdout=subprocess.PIPE,
  416. stderr=subprocess.PIPE,
  417. bufsize=0
  418. )
  419. # 等待ffmpeg启动
  420. time.sleep(1)
  421. if ffmpeg_process.poll() is not None:
  422. # ffmpeg进程异常退出
  423. stderr_output = ffmpeg_process.stderr.read().decode() if ffmpeg_process.stderr else "未知错误"
  424. raise Exception(f"ffmpeg启动失败: {stderr_output}")
  425. retry_count = 0 # 重置重试计数
  426. return True
  427. except Exception as e:
  428. print(f"任务 {task_id} 启动ffmpeg失败: {e}")
  429. retry_count += 1
  430. # 如果是第一次失败且当前使用的是简单配置,尝试切换到高级配置
  431. if retry_count == 1 and ffmpeg_cmd == simple_ffmpeg_cmd:
  432. print(f"任务 {task_id} 简单配置失败,尝试使用高级配置")
  433. ffmpeg_cmd = advanced_ffmpeg_cmd
  434. return start_ffmpeg()
  435. if retry_count < max_retries:
  436. print(f"任务 {task_id} 尝试重试 ({retry_count}/{max_retries})")
  437. time.sleep(2) # 等待2秒后重试
  438. return start_ffmpeg()
  439. else:
  440. print(f"任务 {task_id} ffmpeg启动失败,已达到最大重试次数")
  441. return False
  442. try:
  443. # 测试推流地址连通性
  444. def test_stream_url():
  445. """测试推流地址是否可达"""
  446. try:
  447. print(f"任务 {task_id} 开始测试推流地址: {stream_url}")
  448. # 解析RTMP URL
  449. if stream_url.startswith('rtmp://'):
  450. # 提取服务器地址和端口
  451. url_parts = stream_url.replace('rtmp://', '').split('/')
  452. server_part = url_parts[0]
  453. server_host = server_part.split(':')[0] if ':' in server_part else server_part
  454. server_port = server_part.split(':')[1] if ':' in server_part else '1935'
  455. print(f"任务 {task_id} 解析的服务器信息: {server_host}:{server_port}")
  456. # 先测试网络连通性
  457. import socket
  458. try:
  459. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  460. sock.settimeout(5)
  461. result = sock.connect_ex((server_host, int(server_port)))
  462. sock.close()
  463. if result != 0:
  464. print(f"任务 {task_id} 网络连通性测试失败: {server_host}:{server_port}")
  465. return False
  466. else:
  467. print(f"任务 {task_id} 网络连通性测试成功: {server_host}:{server_port}")
  468. except Exception as net_error:
  469. print(f"任务 {task_id} 网络测试异常: {net_error}")
  470. return False
  471. # 使用ffprobe测试推流地址
  472. test_cmd = [
  473. 'ffprobe',
  474. '-v', 'error',
  475. '-print_format', 'json',
  476. '-show_format',
  477. '-timeout', '5000000', # 5秒超时
  478. stream_url
  479. ]
  480. print(f"任务 {task_id} 执行ffprobe命令: {' '.join(test_cmd)}")
  481. result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=10)
  482. if result.returncode == 0:
  483. print(f"任务 {task_id} ffprobe测试成功")
  484. return True
  485. else:
  486. print(f"任务 {task_id} ffprobe测试失败,返回码: {result.returncode}")
  487. print(f"任务 {task_id} ffprobe错误输出: {result.stderr}")
  488. return False
  489. except subprocess.TimeoutExpired:
  490. print(f"任务 {task_id} ffprobe测试超时")
  491. return False
  492. except Exception as e:
  493. print(f"任务 {task_id} 推流地址测试异常: {e}")
  494. return False
  495. # 如果是RTMP推流且不是保存到本地,先测试连通性
  496. if (stream_url.startswith('rtmp://') and
  497. not params_dict.get('save_local', False) and
  498. not params_dict.get('skip_connectivity_test', False)):
  499. print(f"任务 {task_id} 测试推流地址连通性...")
  500. if not test_stream_url():
  501. print(f"任务 {task_id} 推流地址不可达: {stream_url}")
  502. return {
  503. "code": 1,
  504. "msg": f"推流地址不可达: {stream_url}。可能的原因:1. 推流服务器未启动或不可访问 2. 网络连接问题 3. 防火墙阻止连接 4. 推流地址格式错误。建议解决方案:1. 检查推流服务器状态 2. 验证网络连接 3. 使用skip_connectivity_test=true跳过连通性测试 4. 使用save_local=true先保存到本地文件测试",
  505. "result": None
  506. }
  507. elif params_dict.get('skip_connectivity_test', False):
  508. print(f"任务 {task_id} 跳过连通性测试,直接开始推流")
  509. model = YOLO(model_path)
  510. cap = cv2.VideoCapture(source)
  511. if not cap.isOpened():
  512. return {"code": 1, "msg": f"无法打开视频流: {source}", "result": None}
  513. # 获取视频流信息
  514. fps_cap = cap.get(cv2.CAP_PROP_FPS)
  515. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  516. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  517. # 使用实际帧率,如果获取不到则使用参数中的fps
  518. output_fps = fps_cap if fps_cap > 0 else fps
  519. # 构建ffmpeg命令 - 高质量编码,增加错误处理
  520. # 首先尝试简单配置,如果失败再使用复杂配置
  521. # 确定输出格式和文件
  522. if params_dict.get('save_local', False):
  523. # 保存到本地文件
  524. output_file = f"output_{task_id}.mp4"
  525. output_format = 'mp4'
  526. else:
  527. # 推流
  528. output_file = stream_url
  529. output_format = 'flv' if stream_url.startswith('rtmp') else 'mpegts'
  530. simple_ffmpeg_cmd = [
  531. 'ffmpeg',
  532. '-f', 'rawvideo',
  533. '-pix_fmt', 'bgr24',
  534. '-s', f'{width}x{height}',
  535. '-r', str(output_fps),
  536. '-i', '-', # 从stdin读取
  537. '-c:v', 'libx264',
  538. '-preset', 'ultrafast', # 使用最快的预设
  539. '-tune', 'zerolatency',
  540. '-b:v', bitrate,
  541. '-loglevel', 'error', # 只显示错误信息
  542. '-f', output_format,
  543. '-y', # 覆盖输出文件
  544. output_file
  545. ]
  546. advanced_ffmpeg_cmd = [
  547. 'ffmpeg',
  548. '-f', 'rawvideo',
  549. '-pix_fmt', 'bgr24',
  550. '-s', f'{width}x{height}',
  551. '-r', str(output_fps),
  552. '-i', '-', # 从stdin读取
  553. '-c:v', 'libx264',
  554. '-preset', 'medium', # 改为medium,平衡质量和速度
  555. '-tune', 'zerolatency',
  556. '-profile:v', 'high', # 使用high profile
  557. '-level', '4.1', # 设置编码级别
  558. '-b:v', bitrate,
  559. '-maxrate', bitrate,
  560. '-bufsize', '8000k', # 增加缓冲区大小
  561. '-g', str(int(output_fps) * 2), # GOP大小
  562. '-keyint_min', str(int(output_fps)), # 最小关键帧间隔
  563. '-sc_threshold', '0', # 禁用场景切换检测
  564. '-bf', '3', # B帧数量
  565. '-refs', '6', # 参考帧数量
  566. '-x264opts', 'no-scenecut=1:nal-hrd=cbr:force-cfr=1', # x264特定选项
  567. '-color_primaries', 'bt709', # 色彩空间
  568. '-color_trc', 'bt709', # 色彩传输特性
  569. '-colorspace', 'bt709', # 色彩空间
  570. '-loglevel', 'error', # 只显示错误信息
  571. '-f', output_format,
  572. '-y', # 覆盖输出文件
  573. output_file
  574. ]
  575. # 默认使用简单配置
  576. ffmpeg_cmd = simple_ffmpeg_cmd
  577. # 启动ffmpeg进程
  578. if not start_ffmpeg():
  579. return {"code": 1, "msg": "ffmpeg启动失败,已达到最大重试次数", "result": None}
  580. frame_count = 0
  581. start_time = time.time()
  582. try:
  583. while cap.isOpened():
  584. ret, frame = cap.read()
  585. if not ret:
  586. break
  587. # YOLO推理
  588. try:
  589. predict_kwargs = {
  590. 'source': frame,
  591. 'imgsz': imgsz,
  592. 'conf': conf,
  593. 'device': device
  594. }
  595. # 只有当iou不为None时才添加到参数中
  596. if iou is not None:
  597. predict_kwargs['iou'] = iou
  598. results = model.predict(**predict_kwargs)
  599. annotated_frame = results[0].plot()
  600. except Exception as predict_error:
  601. print(f"任务 {task_id} YOLO推理出错: {predict_error}")
  602. # 如果推理失败,使用原始帧
  603. annotated_frame = frame
  604. # 将处理后的帧写入ffmpeg
  605. try:
  606. # 检查ffmpeg进程是否还在运行
  607. if ffmpeg_process.poll() is not None:
  608. stderr_output = ffmpeg_process.stderr.read().decode() if ffmpeg_process.stderr else "未知错误"
  609. print(f"任务 {task_id} ffmpeg进程已退出: {stderr_output}")
  610. # 尝试重启ffmpeg进程
  611. print(f"任务 {task_id} 尝试重启ffmpeg进程...")
  612. cleanup_process()
  613. if start_ffmpeg():
  614. print(f"任务 {task_id} ffmpeg进程重启成功,继续处理")
  615. continue
  616. else:
  617. print(f"任务 {task_id} ffmpeg进程重启失败,停止处理")
  618. break
  619. ffmpeg_process.stdin.write(annotated_frame.tobytes())
  620. ffmpeg_process.stdin.flush()
  621. frame_count += 1
  622. # 每100帧输出一次进度
  623. if frame_count % 100 == 0:
  624. elapsed_time = time.time() - start_time
  625. current_fps = frame_count / elapsed_time
  626. print(f"任务 {task_id} 已处理 {frame_count} 帧,当前FPS: {current_fps:.2f}")
  627. except IOError as e:
  628. print(f"任务 {task_id} 写入ffmpeg时出错: {e}")
  629. # 尝试获取ffmpeg的错误输出
  630. try:
  631. if ffmpeg_process.stderr:
  632. stderr_output = ffmpeg_process.stderr.read().decode()
  633. print(f"任务 {task_id} ffmpeg错误输出: {stderr_output}")
  634. except:
  635. pass
  636. # 尝试重启ffmpeg进程
  637. print(f"任务 {task_id} 尝试重启ffmpeg进程...")
  638. cleanup_process()
  639. if start_ffmpeg():
  640. print(f"任务 {task_id} ffmpeg进程重启成功,继续处理")
  641. continue
  642. else:
  643. print(f"任务 {task_id} ffmpeg进程重启失败,停止处理")
  644. break
  645. except KeyboardInterrupt:
  646. print(f"任务 {task_id} 收到中断信号,停止处理")
  647. finally:
  648. # 清理资源
  649. cap.release()
  650. cleanup_process()
  651. elapsed_time = time.time() - start_time
  652. avg_fps = frame_count / elapsed_time if elapsed_time > 0 else 0
  653. print(f"任务 {task_id} 推理并推流完成,共处理帧数: {frame_count},平均FPS: {avg_fps:.2f}")
  654. return {"code": 0, "msg": "success", "result": {"frames_processed": frame_count, "avg_fps": avg_fps}}
  655. except Exception as e:
  656. print(f"任务 {task_id} 发生异常: {e}")
  657. cleanup_process()
  658. # 提供更详细的错误信息
  659. error_msg = str(e)
  660. if "Broken pipe" in error_msg:
  661. error_msg += " - 这通常表示推流地址不可达或网络连接问题,建议:1) 检查推流服务器是否运行 2) 检查网络连接 3) 使用save_local=true先保存到本地文件测试"
  662. elif "Connection refused" in error_msg:
  663. error_msg += " - 推流服务器拒绝连接,请检查推流地址是否正确"
  664. elif "Permission denied" in error_msg:
  665. error_msg += " - 权限不足,请检查推流地址的访问权限"
  666. # 记录详细的调试信息
  667. print(f"任务 {task_id} 调试信息:")
  668. print(f" - 输入源: {source}")
  669. print(f" - 推流地址: {stream_url}")
  670. print(f" - 保存到本地: {params_dict.get('save_local', False)}")
  671. print(f" - 视频尺寸: {width}x{height}")
  672. print(f" - 帧率: {output_fps}")
  673. print(f" - 码率: {bitrate}")
  674. return {"code": 1, "msg": error_msg, "result": None}
  675. @app_fastapi.post("/yolov12/stream")
  676. async def yolov12_stream_async(params: StreamParams):
  677. """
  678. RESTful POST接口:/yolov12/stream
  679. 接收视频拉流地址和推流地址,调用YOLO模型推理,使用ffmpeg将推理后的视频推送到推流地址。
  680. 支持并发推理和任务取消。
  681. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"task_id": "任务ID"}}
  682. """
  683. logging.info("收到/yolov12/stream请求")
  684. logging.info(f"请求参数: {params}")
  685. # 生成唯一任务ID
  686. task_id = str(uuid.uuid4())
  687. try:
  688. # 异步执行推理任务
  689. loop = asyncio.get_event_loop()
  690. future = loop.run_in_executor(
  691. stream_executor, partial(yolov12_stream_worker, params.dict(), task_id)
  692. )
  693. # 添加任务完成后的清理回调
  694. def cleanup_task(fut):
  695. try:
  696. # 获取任务结果(如果有异常会抛出)
  697. result = fut.result()
  698. print(f"任务 {task_id} 已完成,结果: {result}")
  699. except Exception as e:
  700. print(f"任务 {task_id} 执行异常: {e}")
  701. finally:
  702. # 清理任务
  703. if task_id in stream_tasks:
  704. del stream_tasks[task_id]
  705. if task_id in stream_pids:
  706. del stream_pids[task_id]
  707. print(f"已清理任务: {task_id}")
  708. future.add_done_callback(cleanup_task)
  709. stream_tasks[task_id] = {
  710. "future": future,
  711. "start_time": int(time.time() * 1000), # 毫秒级时间戳,long类型
  712. "source": params.source,
  713. "stream_url": params.stream_url
  714. }
  715. logging.info(f"任务 {task_id} 已提交到进程池")
  716. return {
  717. "code": 0,
  718. "msg": "任务已提交",
  719. "result": task_id
  720. }
  721. except Exception as e:
  722. logging.error(f"提交任务时发生异常: {e}")
  723. return {
  724. "code": 1,
  725. "msg": str(e),
  726. "result": None
  727. }
  728. @app_fastapi.post("/yolov12/stream/cancel")
  729. async def cancel_stream_task(task_id: str):
  730. """
  731. RESTful POST接口:/yolov12/stream/cancel
  732. 取消指定的推理任务
  733. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": None}
  734. """
  735. logging.info(f"收到取消任务请求: {task_id}")
  736. # 先清理已完成的任务
  737. cleanup_completed_tasks()
  738. task_info = stream_tasks.get(task_id)
  739. if not task_info:
  740. return {"code": 1, "msg": "任务不存在", "result": None}
  741. if task_info["future"].done():
  742. return {"code": 1, "msg": "任务已完成,无法取消", "result": None}
  743. try:
  744. # 尝试取消任务
  745. cancelled = task_info["future"].cancel()
  746. if cancelled:
  747. logging.info(f"任务 {task_id} 已取消")
  748. # 取消后立即清理
  749. if task_id in stream_tasks:
  750. del stream_tasks[task_id]
  751. if task_id in stream_pids:
  752. del stream_pids[task_id]
  753. return {"code": 0, "msg": "任务已取消", "result": None}
  754. else:
  755. return {"code": 1, "msg": "任务无法取消(可能正在运行)", "result": None}
  756. except Exception as e:
  757. logging.error(f"取消任务时发生异常: {e}")
  758. return {"code": 1, "msg": str(e), "result": None}
  759. @app_fastapi.get("/yolov12/stream/status")
  760. async def get_stream_status(task_id: str):
  761. """
  762. RESTful GET接口:/yolov12/stream/status
  763. 查询指定任务的状态
  764. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"status": "状态", "result": "结果"}}
  765. """
  766. logging.info(f"收到查询任务状态请求: {task_id}")
  767. # 先清理已完成的任务
  768. cleanup_completed_tasks()
  769. task_info = stream_tasks.get(task_id)
  770. if not task_info:
  771. return {"code": 1, "msg": "任务不存在", "result": None}
  772. try:
  773. if task_info["future"].done():
  774. try:
  775. result = task_info["future"].result()
  776. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  777. return {
  778. "code": 0,
  779. "msg": "已完成",
  780. "result": {
  781. "status": "completed",
  782. "result": result,
  783. "start_time": task_info["start_time"],
  784. "run_time": round(run_time, 2),
  785. "source": task_info["source"],
  786. "stream_url": task_info["stream_url"]
  787. }
  788. }
  789. except Exception as e:
  790. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  791. return {
  792. "code": 1,
  793. "msg": f"任务异常: {e}",
  794. "result": {
  795. "status": "failed",
  796. "error": str(e),
  797. "start_time": task_info["start_time"],
  798. "run_time": round(run_time, 2),
  799. "source": task_info["source"],
  800. "stream_url": task_info["stream_url"]
  801. }
  802. }
  803. elif task_info["future"].cancelled():
  804. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  805. return {
  806. "code": 0,
  807. "msg": "已取消",
  808. "result": {
  809. "status": "cancelled",
  810. "start_time": task_info["start_time"],
  811. "run_time": round(run_time, 2),
  812. "source": task_info["source"],
  813. "stream_url": task_info["stream_url"]
  814. }
  815. }
  816. else:
  817. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  818. return {
  819. "code": 0,
  820. "msg": "运行中",
  821. "result": {
  822. "status": "running",
  823. "start_time": task_info["start_time"],
  824. "run_time": round(run_time, 2),
  825. "source": task_info["source"],
  826. "stream_url": task_info["stream_url"]
  827. }
  828. }
  829. except Exception as e:
  830. logging.error(f"查询任务状态时发生异常: {e}")
  831. return {"code": 1, "msg": str(e), "result": None}
  832. @app_fastapi.get("/yolov12/stream/list")
  833. async def list_stream_tasks():
  834. """
  835. RESTful GET接口:/yolov12/stream/list
  836. 列出所有任务的状态
  837. 返回格式:{"code": 0, "msg": "success", "result": {"tasks": [{"task_id": "ID", "status": "状态"}]}}
  838. """
  839. logging.info("收到查询所有任务请求")
  840. # 先清理已完成的任务
  841. cleanup_completed_tasks()
  842. try:
  843. tasks_info = []
  844. for task_id, task_info in stream_tasks.items():
  845. if task_info["future"].done():
  846. if task_info["future"].cancelled():
  847. status = "cancelled"
  848. else:
  849. try:
  850. task_info["future"].result() # 检查是否有异常
  851. status = "completed"
  852. except:
  853. status = "failed"
  854. else:
  855. status = "running"
  856. # 计算运行时长
  857. run_time = (time.time() * 1000 - task_info["start_time"]) / 1000 # 转换为秒
  858. tasks_info.append({
  859. "task_id": task_id,
  860. "status": status,
  861. "start_time": task_info["start_time"],
  862. "run_time": round(run_time, 2), # 运行时长(秒)
  863. "source": task_info["source"], # 播流地址
  864. "stream_url": task_info["stream_url"] # 推流地址
  865. })
  866. return {
  867. "code": 0,
  868. "msg": "success",
  869. "result": tasks_info
  870. }
  871. except Exception as e:
  872. logging.error(f"查询所有任务时发生异常: {e}")
  873. return {"code": 1, "msg": str(e), "result": None}
  874. # 全局异常处理器:参数校验失败时统一返回格式
  875. @app_fastapi.exception_handler(RequestValidationError)
  876. async def validation_exception_handler(request, exc):
  877. err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
  878. logging.error(err_msg)
  879. return JSONResponse(
  880. status_code=status.HTTP_200_OK,
  881. content={
  882. "code": 422,
  883. "msg": err_msg,
  884. "result": None
  885. }
  886. )
  887. if __name__ == "__main__":
  888. threading.Thread(target=start_gradio, daemon=True).start()
  889. uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)