app.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import glob
import logging
import os
import shutil
import subprocess
import tempfile
import threading
from typing import Optional

import cv2
import gradio as gr
import uvicorn
from fastapi import FastAPI
from fastapi import status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO
# Configure the root logger: INFO level, "[timestamp] LEVEL - message" format.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
  19. def yolov12_inference(image, video, model_id, image_size, conf_threshold):
  20. model = YOLO(model_id)
  21. if image:
  22. results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
  23. annotated_image = results[0].plot()
  24. return annotated_image[:, :, ::-1], None
  25. else:
  26. video_path = tempfile.mktemp(suffix=".webm")
  27. with open(video_path, "wb") as f:
  28. with open(video, "rb") as g:
  29. f.write(g.read())
  30. cap = cv2.VideoCapture(video_path)
  31. fps = cap.get(cv2.CAP_PROP_FPS)
  32. frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  33. frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  34. output_video_path = tempfile.mktemp(suffix=".webm")
  35. out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))
  36. while cap.isOpened():
  37. ret, frame = cap.read()
  38. if not ret:
  39. break
  40. results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
  41. annotated_frame = results[0].plot()
  42. out.write(annotated_frame)
  43. cap.release()
  44. out.release()
  45. return None, output_video_path
  46. def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
  47. annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
  48. return annotated_image
  49. def app():
  50. with gr.Blocks():
  51. with gr.Row():
  52. with gr.Column():
  53. image = gr.Image(type="pil", label="Image", visible=True)
  54. video = gr.Video(label="Video", visible=False)
  55. input_type = gr.Radio(
  56. choices=["Image", "Video"],
  57. value="Image",
  58. label="Input Type",
  59. )
  60. model_id = gr.Dropdown(
  61. label="Model",
  62. choices=[
  63. "yolov12n.pt",
  64. "yolov12s.pt",
  65. "yolov12m.pt",
  66. "yolov12l.pt",
  67. "yolov12x.pt",
  68. ],
  69. value="yolov12m.pt",
  70. )
  71. image_size = gr.Slider(
  72. label="Image Size",
  73. minimum=320,
  74. maximum=1280,
  75. step=32,
  76. value=640,
  77. )
  78. conf_threshold = gr.Slider(
  79. label="Confidence Threshold",
  80. minimum=0.0,
  81. maximum=1.0,
  82. step=0.05,
  83. value=0.25,
  84. )
  85. yolov12_infer = gr.Button(value="Detect Objects")
  86. with gr.Column():
  87. output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
  88. output_video = gr.Video(label="Annotated Video", visible=False)
  89. def update_visibility(input_type):
  90. image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
  91. video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
  92. output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
  93. output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
  94. return image, video, output_image, output_video
  95. input_type.change(
  96. fn=update_visibility,
  97. inputs=[input_type],
  98. outputs=[image, video, output_image, output_video],
  99. )
  100. def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
  101. if input_type == "Image":
  102. return yolov12_inference(image, None, model_id, image_size, conf_threshold)
  103. else:
  104. return yolov12_inference(None, video, model_id, image_size, conf_threshold)
  105. yolov12_infer.click(
  106. fn=run_inference,
  107. inputs=[image, video, model_id, image_size, conf_threshold, input_type],
  108. outputs=[output_image, output_video],
  109. )
  110. gr.Examples(
  111. examples=[
  112. [
  113. "ultralytics/assets/bus.jpg",
  114. "yolov12s.pt",
  115. 640,
  116. 0.25,
  117. ],
  118. [
  119. "ultralytics/assets/zidane.jpg",
  120. "yolov12x.pt",
  121. 640,
  122. 0.25,
  123. ],
  124. ],
  125. fn=yolov12_inference_for_examples,
  126. inputs=[
  127. image,
  128. model_id,
  129. image_size,
  130. conf_threshold,
  131. ],
  132. outputs=[output_image],
  133. cache_examples='lazy',
  134. )
  135. gradio_app = gr.Blocks()
  136. with gradio_app:
  137. gr.HTML(
  138. """
  139. <h1 style='text-align: center'>
  140. YOLOv12: Attention-Centric Real-Time Object Detectors
  141. </h1>
  142. """)
  143. gr.HTML(
  144. """
  145. <h3 style='text-align: center'>
  146. <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
  147. </h3>
  148. """)
  149. with gr.Row():
  150. with gr.Column():
  151. app()
  152. def start_gradio():
  153. gradio_app.launch(server_name="0.0.0.0", server_port=7860)
# FastAPI section: REST application that the route handlers below attach to.
app_fastapi = FastAPI()
  156. @app_fastapi.get("/")
  157. def read_root():
  158. """根路径,返回摄像头检测页面"""
  159. return FileResponse("camera_detect.html")
  160. @app_fastapi.get("/camera")
  161. def camera_page():
  162. """摄像头检测页面"""
  163. return FileResponse("camera_detect.html")
  164. @app_fastapi.get("/test-mp4")
  165. def test_mp4_page():
  166. """MP4播放测试页面"""
  167. return FileResponse("test_browser_mp4.html")
class TrainParams(BaseModel):
    """
    Training parameters accepted by the /yolov12/train endpoint.

    No field has a default, so the client must supply every one of them.
    """
    model: str  # base model to train from
    data: str  # path to the dataset configuration file
    epochs: int  # number of training epochs
    batch: int  # batch size
    imgsz: int  # input image size
    scale: float  # random-scale augmentation ratio
    mosaic: float  # mosaic augmentation probability
    mixup: float  # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str  # training device
    project: str  # project name
    name: str  # experiment name
    exist_ok: bool  # whether an existing directory of the same name may be reused
  185. @app_fastapi.post("/yolov12/train")
  186. def yolov12_train(params: TrainParams):
  187. """
  188. RESTful POST接口:/yolov12/train
  189. 接收训练参数,调用YOLO模型训练,并返回训练结果。
  190. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 训练结果或None}
  191. """
  192. logging.info("收到/yolov12/train训练请求")
  193. logging.info(f"请求参数: {params}")
  194. try:
  195. # 根据params.model动态确定配置文件
  196. if params.model.endswith('.pt'):
  197. # 如果是.pt文件,将后缀替换为.yaml
  198. config_file = params.model.replace('.pt', '.yaml')
  199. else:
  200. # 如果不是.pt文件,使用默认配置
  201. config_file = "yolov12.yaml"
  202. model = YOLO(config_file)
  203. model.load(params.model)
  204. logging.info("开始模型训练...")
  205. results = model.train(
  206. data=params.data,
  207. epochs=params.epochs,
  208. batch=params.batch,
  209. imgsz=params.imgsz,
  210. scale=params.scale,
  211. mosaic=params.mosaic,
  212. mixup=params.mixup,
  213. copy_paste=params.copy_paste,
  214. device=params.device,
  215. project=params.project,
  216. name=params.name,
  217. exist_ok=params.exist_ok,
  218. )
  219. logging.info("模型训练完成")
  220. # logging.info(f"训练结果: {str(results)}")
  221. return {
  222. "code": 0,
  223. "msg": "success",
  224. "result": str(results.save_dir)
  225. }
  226. except Exception as e:
  227. logging.error(f"训练过程发生异常: {e}")
  228. return {
  229. "code": 1,
  230. "msg": str(e),
  231. "result": None
  232. }
  233. class PredictParams(BaseModel):
  234. """
  235. 用于接收/yolov12/predict接口的预测参数,与YOLO predict方法保持一致。
  236. """
  237. model: str = "yolov12m.pt" # 模型路径
  238. source: str = None # 输入源(图片/视频路径、URL等)
  239. stream: bool = False # 是否流式处理
  240. conf: float = 0.25 # 置信度阈值
  241. iou: float = 0.7 # IoU阈值
  242. max_det: int = 300 # 最大检测数量
  243. imgsz: int = 640 # 输入图片尺寸
  244. batch: int = 1 # 批次大小
  245. device: str = "" # 设备
  246. show: bool = False # 是否显示结果
  247. save: bool = False # 是否保存结果
  248. save_txt: bool = False # 是否保存txt文件
  249. save_conf: bool = False # 是否保存置信度
  250. save_crop: bool = False # 是否保存裁剪图片
  251. show_labels: bool = True # 是否显示标签
  252. show_conf: bool = True # 是否显示置信度
  253. show_boxes: bool = True # 是否显示边界框
  254. line_width: int = None # 线条宽度
  255. vid_stride: int = 1 # 视频帧步长
  256. stream_buffer: bool = False # 流缓冲区
  257. visualize: bool = False # 可视化特征
  258. augment: bool = False # 数据增强
  259. agnostic_nms: bool = False # 类别无关NMS
  260. classes: list = None # 指定类别
  261. retina_masks: bool = False # 高分辨率分割掩码
  262. embed: list = None # 特征向量层
  263. half: bool = False # 半精度
  264. dnn: bool = False # OpenCV DNN
  265. project: str = "" # 项目名
  266. name: str = "" # 实验名
  267. exist_ok: bool = False # 是否覆盖现有目录
  268. verbose: bool = True # 详细输出
  269. @app_fastapi.post("/yolov12/predict")
  270. def yolov12_predict(params: PredictParams):
  271. """
  272. RESTful POST接口:/yolov12/predict
  273. 接收预测参数,调用YOLO模型进行预测,并返回预测结果。
  274. 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"save_dir": "保存目录", "filename": "文件名"}}
  275. """
  276. logging.info("收到/yolov12/predict预测请求")
  277. logging.info(f"请求参数: {params}")
  278. try:
  279. model = YOLO(params.model)
  280. logging.info("开始模型预测...")
  281. # 构建预测参数
  282. predict_kwargs = {}
  283. for field, value in params.dict().items():
  284. if field not in ['model'] and value is not None:
  285. predict_kwargs[field] = value
  286. # 确保保存结果,并强制使用MP4格式
  287. predict_kwargs['save'] = True
  288. # 如果输入是视频,强制设置输出格式为MP4
  289. source = params.source
  290. if source:
  291. import os
  292. source_ext = os.path.splitext(source)[1].lower()
  293. video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
  294. if source_ext in video_extensions:
  295. # 对于视频输入,设置项目名和实验名以确保输出路径
  296. if not predict_kwargs.get('project'):
  297. predict_kwargs['project'] = 'runs/detect'
  298. if not predict_kwargs.get('name'):
  299. predict_kwargs['name'] = 'predict'
  300. results = model.predict(**predict_kwargs)
  301. logging.info("模型预测完成")
  302. # 获取保存目录和最终文件名
  303. result = results[0]
  304. save_dir = result.save_dir if hasattr(result, 'save_dir') else None
  305. # 获取最终生成的文件名
  306. final_filename = None
  307. if save_dir:
  308. import os
  309. import glob
  310. if os.path.exists(save_dir):
  311. # 检查输入源类型
  312. source = params.source
  313. if source:
  314. source_ext = os.path.splitext(source)[1].lower()
  315. video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
  316. # 如果输入是图片,返回图片文件
  317. if source_ext not in video_extensions:
  318. image_files = []
  319. for ext in ['*.jpg', '*.jpeg', '*.png']:
  320. image_files.extend(glob.glob(os.path.join(save_dir, ext)))
  321. if image_files:
  322. latest_image = max(image_files, key=os.path.getmtime)
  323. final_filename = os.path.basename(latest_image)
  324. logging.info(f"输入为图片,返回图片文件: {final_filename}")
  325. # 如果输入是视频,检查并转换为MP4
  326. else:
  327. # 查找所有视频文件
  328. video_files = []
  329. for ext in ['*.avi', '*.webm', '*.mov']:
  330. video_files.extend(glob.glob(os.path.join(save_dir, ext)))
  331. # 如果找到非MP4视频文件,尝试转换为MP4,失败则转换为WebM
  332. for video_file in video_files:
  333. output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
  334. output_webm = video_file.rsplit('.', 1)[0] + '.webm'
  335. try:
  336. import subprocess
  337. # 使用ffmpeg转换为浏览器兼容的MP4格式(简化参数)
  338. cmd = [
  339. 'ffmpeg', '-i', video_file,
  340. '-c:v', 'libx264', # H.264编码器
  341. '-preset', 'ultrafast', # 最快编码
  342. '-crf', '28', # 稍低质量但更兼容
  343. '-pix_fmt', 'yuv420p', # 兼容性像素格式
  344. '-y', output_mp4
  345. ]
  346. logging.info(f"使用ffmpeg转换视频: {video_file} -> {output_mp4}")
  347. result = subprocess.run(cmd, capture_output=True, text=True)
  348. if result.returncode == 0:
  349. # 验证生成的MP4文件
  350. verify_cmd = [
  351. 'ffprobe', '-v', 'quiet',
  352. '-select_streams', 'v:0',
  353. '-show_entries', 'stream=codec_name',
  354. '-of', 'csv=p=0', output_mp4
  355. ]
  356. verify_result = subprocess.run(verify_cmd, capture_output=True, text=True)
  357. if verify_result.returncode == 0 and 'h264' in verify_result.stdout.lower():
  358. os.remove(video_file)
  359. logging.info(f"✓ 成功转换为H.264 MP4: {output_mp4}")
  360. else:
  361. logging.warning(f"生成的MP4可能不是H.264编码,尝试转换为WebM")
  362. # 尝试转换为WebM
  363. webm_cmd = [
  364. 'ffmpeg', '-i', video_file,
  365. '-c:v', 'libvpx',
  366. '-crf', '30',
  367. '-b:v', '0',
  368. '-y', output_webm
  369. ]
  370. webm_result = subprocess.run(webm_cmd, capture_output=True, text=True)
  371. if webm_result.returncode == 0:
  372. os.remove(video_file)
  373. logging.info(f"✓ 成功转换为WebM: {output_webm}")
  374. else:
  375. logging.error(f"WebM转换也失败: {webm_result.stderr}")
  376. else:
  377. logging.error(f"MP4转换失败,尝试转换为WebM: {result.stderr}")
  378. # 尝试转换为WebM
  379. webm_cmd = [
  380. 'ffmpeg', '-i', video_file,
  381. '-c:v', 'libvpx',
  382. '-crf', '30',
  383. '-b:v', '0',
  384. '-y', output_webm
  385. ]
  386. webm_result = subprocess.run(webm_cmd, capture_output=True, text=True)
  387. if webm_result.returncode == 0:
  388. os.remove(video_file)
  389. logging.info(f"✓ 成功转换为WebM: {output_webm}")
  390. else:
  391. logging.error(f"WebM转换也失败: {webm_result.stderr}")
  392. # 如果ffmpeg失败,尝试使用OpenCV作为备选方案
  393. logging.info("尝试使用OpenCV作为备选方案")
  394. try:
  395. import cv2
  396. cap = cv2.VideoCapture(video_file)
  397. fps = cap.get(cv2.CAP_PROP_FPS)
  398. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  399. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  400. # 尝试使用H264编码器
  401. fourcc = cv2.VideoWriter_fourcc(*'H264')
  402. out = cv2.VideoWriter(output_mp4, fourcc, fps, (width, height))
  403. if out.isOpened():
  404. while cap.isOpened():
  405. ret, frame = cap.read()
  406. if not ret:
  407. break
  408. out.write(frame)
  409. cap.release()
  410. out.release()
  411. os.remove(video_file)
  412. logging.info(f"使用OpenCV H264编码器生成MP4: {output_mp4}")
  413. else:
  414. logging.error("OpenCV H264编码器也无法使用")
  415. except Exception as cv_error:
  416. logging.error(f"OpenCV备选方案也失败: {cv_error}")
  417. except (FileNotFoundError, subprocess.SubprocessError) as e:
  418. logging.error(f"ffmpeg不可用: {e}")
  419. # 如果ffmpeg不可用,尝试使用OpenCV
  420. try:
  421. import cv2
  422. cap = cv2.VideoCapture(video_file)
  423. fps = cap.get(cv2.CAP_PROP_FPS)
  424. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  425. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  426. # 强制使用H264编码器,如果失败则使用XVID
  427. fourcc_options = ['H264', 'avc1']
  428. out = None
  429. for fourcc in fourcc_options:
  430. try:
  431. codec = cv2.VideoWriter_fourcc(*fourcc)
  432. out = cv2.VideoWriter(output_mp4, codec, fps, (width, height))
  433. if out.isOpened():
  434. logging.info(f"使用OpenCV编码器 {fourcc}")
  435. break
  436. except:
  437. continue
  438. # 如果H264和avc1都失败,使用XVID生成AVI,然后用ffmpeg转换
  439. if not out or not out.isOpened():
  440. logging.warning("H264编码器不可用,使用XVID生成AVI")
  441. temp_avi = video_file.rsplit('.', 1)[0] + '_temp.avi'
  442. xvid_codec = cv2.VideoWriter_fourcc(*'XVID')
  443. out = cv2.VideoWriter(temp_avi, xvid_codec, fps, (width, height))
  444. if out.isOpened():
  445. while cap.isOpened():
  446. ret, frame = cap.read()
  447. if not ret:
  448. break
  449. out.write(frame)
  450. cap.release()
  451. out.release()
  452. # 尝试用ffmpeg将AVI转换为MP4
  453. try:
  454. cmd = [
  455. 'ffmpeg', '-i', temp_avi,
  456. '-c:v', 'libx264',
  457. '-preset', 'ultrafast',
  458. '-crf', '28',
  459. '-pix_fmt', 'yuv420p',
  460. '-y', output_mp4
  461. ]
  462. result = subprocess.run(cmd, capture_output=True, text=True)
  463. if result.returncode == 0:
  464. os.remove(temp_avi)
  465. logging.info(f"通过AVI中转成功生成MP4: {output_mp4}")
  466. else:
  467. logging.error(f"AVI转MP4失败: {result.stderr}")
  468. except Exception as e:
  469. logging.error(f"ffmpeg不可用,保持AVI格式: {e}")
  470. if out and out.isOpened():
  471. while cap.isOpened():
  472. ret, frame = cap.read()
  473. if not ret:
  474. break
  475. out.write(frame)
  476. cap.release()
  477. out.release()
  478. os.remove(video_file)
  479. logging.info(f"使用OpenCV生成MP4: {output_mp4}")
  480. else:
  481. logging.error("所有编码器都无法使用")
  482. except Exception as cv_error:
  483. logging.error(f"OpenCV处理失败: {cv_error}")
  484. except Exception as e:
  485. logging.error(f"转换视频格式时出错: {e}")
  486. # 获取MP4或WebM文件
  487. video_output_files = []
  488. for ext in ['*.mp4', '*.webm']:
  489. video_output_files.extend(glob.glob(os.path.join(save_dir, ext)))
  490. if video_output_files:
  491. latest_video = max(video_output_files, key=os.path.getmtime)
  492. final_filename = os.path.basename(latest_video)
  493. logging.info(f"输入为视频,返回文件: {final_filename}")
  494. # 如果无法确定输入类型或未找到文件,返回最新文件
  495. if not final_filename:
  496. all_files = []
  497. for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
  498. all_files.extend(glob.glob(os.path.join(save_dir, ext)))
  499. if all_files:
  500. latest_file = max(all_files, key=os.path.getmtime)
  501. final_filename = os.path.basename(latest_file)
  502. logging.info(f"返回最新文件: {final_filename}")
  503. return {
  504. "code": 0,
  505. "msg": "success",
  506. "result": save_dir+"/"+final_filename
  507. }
  508. except Exception as e:
  509. logging.error(f"预测过程发生异常: {e}")
  510. return {
  511. "code": 1,
  512. "msg": str(e),
  513. "result": None
  514. }
  515. # 全局异常处理器:参数校验失败时统一返回格式
  516. @app_fastapi.exception_handler(RequestValidationError)
  517. async def validation_exception_handler(request, exc):
  518. err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
  519. logging.error(err_msg)
  520. return JSONResponse(
  521. status_code=status.HTTP_200_OK,
  522. content={
  523. "code": 422,
  524. "msg": err_msg,
  525. "result": None
  526. }
  527. )
  528. if __name__ == "__main__":
  529. threading.Thread(target=start_gradio, daemon=True).start()
  530. uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)