app.py

# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import gradio as gr
import cv2
import tempfile
from ultralytics import YOLO
import threading
from fastapi import FastAPI, status
from pydantic import BaseModel
import uvicorn
import logging
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError

# Configure logging format and level
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')

def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    model = YOLO(model_id)
    if image:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        # plot() returns a BGR array; reverse the channel axis so Gradio gets RGB
        return annotated_image[:, :, ::-1], None
    else:
        # Copy the uploaded video to a temporary file, then annotate it frame by frame
        video_path = tempfile.mktemp(suffix=".webm")
        with open(video_path, "wb") as f:
            with open(video, "rb") as g:
                f.write(g.read())

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_video_path = tempfile.mktemp(suffix=".webm")
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)

        cap.release()
        out.release()

        return None, output_video_path
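
# A minimal direct-call sketch (not part of the app flow); it assumes the bundled
# ultralytics sample image and locally available yolov12s.pt weights:
#   from PIL import Image
#   annotated, _ = yolov12_inference(Image.open("ultralytics/assets/bus.jpg"), None, "yolov12s.pt", 640, 0.25)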

def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return annotated_image

def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            if input_type == "Image":
                return yolov12_inference(image, None, model_id, image_size, conf_threshold)
            else:
                return yolov12_inference(None, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                [
                    "ultralytics/assets/bus.jpg",
                    "yolov12s.pt",
                    640,
                    0.25,
                ],
                [
                    "ultralytics/assets/zidane.jpg",
                    "yolov12x.pt",
                    640,
                    0.25,
                ],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[
                image,
                model_id,
                image_size,
                conf_threshold,
            ],
            outputs=[output_image],
            cache_examples='lazy',
        )

gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv12: Attention-Centric Real-Time Object Detectors
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()

def start_gradio():
    gradio_app.launch(server_name="0.0.0.0", server_port=7860)


# FastAPI section
app_fastapi = FastAPI()

class TrainParams(BaseModel):
    """
    Training parameters for the /yolov12/train endpoint; every field must be supplied by the client.
    """
    data: str          # path to the dataset config file
    epochs: int        # number of training epochs
    batch: int         # batch size
    imgsz: int         # input image size
    scale: float       # random-scale augmentation ratio
    mosaic: float      # mosaic augmentation probability
    mixup: float       # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str        # training device
    project: str       # project name
    name: str          # experiment name
    exist_ok: bool     # whether an existing run directory may be overwritten
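
# An illustrative request body (values are placeholders, not recommended defaults):
#   {"data": "coco128.yaml", "epochs": 100, "batch": 16, "imgsz": 640,
#    "scale": 0.5, "mosaic": 1.0, "mixup": 0.0, "copy_paste": 0.1,
#    "device": "0", "project": "runs/train", "name": "exp", "exist_ok": true}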

@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """
    RESTful POST endpoint: /yolov12/train
    Receives training parameters, runs YOLO training, and returns the outcome.
    Response format: {"code": 0 or 1, "msg": "success" or an error message, "result": save directory or None}
    """
    logging.info("Received /yolov12/train request")
    logging.info(f"Request params: {params}")
    try:
        model = YOLO("yolov12.yaml")  # substitute yolov12n.yaml etc. if available
        logging.info("Starting model training...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("Model training finished")
        # logging.info(f"Training results: {str(results)}")
        return {
            "code": 0,
            "msg": "success",
            "result": str(results.save_dir)
        }
    except Exception as e:
        logging.error(f"Exception during training: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
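
# Example invocation (a sketch; assumes the server runs on localhost:8000 as configured below):
#   curl -X POST http://localhost:8000/yolov12/train \
#        -H "Content-Type: application/json" \
#        -d '{"data": "coco128.yaml", "epochs": 10, "batch": 16, "imgsz": 640,
#             "scale": 0.5, "mosaic": 1.0, "mixup": 0.0, "copy_paste": 0.1,
#             "device": "0", "project": "runs/train", "name": "exp", "exist_ok": true}'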

# Global exception handler: return the unified response format when request validation fails
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    err_msg = f"Parameter validation failed: path={request.url.path}, errors={exc.errors()}"
    logging.error(err_msg)
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "code": 422,
            "msg": err_msg,
            "result": None
        }
    )
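
# Note that validation failures deliberately come back as HTTP 200 with the error
# encoded in the body, e.g.:
#   {"code": 422, "msg": "Parameter validation failed: ...", "result": null}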

if __name__ == "__main__":
    # Run Gradio in a daemon thread so uvicorn can block the main thread
    threading.Thread(target=start_gradio, daemon=True).start()
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
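
# Usage sketch (assuming the default ports are free):
#   $ python app.py
#   Gradio UI: http://localhost:7860
#   REST API:  http://localhost:8000/yolov12/train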