app.py

# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import gradio as gr
import cv2
import tempfile
from ultralytics import YOLO

def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    """Run YOLOv12 on an image or a video; return (annotated_image, annotated_video_path)."""
    model = YOLO(model_id)
    if image:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        # plot() returns a BGR array; reverse the channel axis to get RGB for Gradio.
        return annotated_image[:, :, ::-1], None
    else:
        # Copy the uploaded video to a temporary file, then annotate it frame by frame.
        video_path = tempfile.mktemp(suffix=".webm")
        with open(video_path, "wb") as f:
            with open(video, "rb") as g:
                f.write(g.read())

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_video_path = tempfile.mktemp(suffix=".webm")
        # VP8 codec so the .webm output plays back in the browser.
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)

        cap.release()
        out.release()

        return None, output_video_path

def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return annotated_image

def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            # Show the image widgets for "Image" input and the video widgets otherwise.
            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            if input_type == "Image":
                return yolov12_inference(image, None, model_id, image_size, conf_threshold)
            else:
                return yolov12_inference(None, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                [
                    "ultralytics/assets/bus.jpg",
                    "yolov12s.pt",
                    640,
                    0.25,
                ],
                [
                    "ultralytics/assets/zidane.jpg",
                    "yolov12x.pt",
                    640,
                    0.25,
                ],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[
                image,
                model_id,
                image_size,
                conf_threshold,
            ],
            outputs=[output_image],
            cache_examples='lazy',
        )

gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv12: Attention-Centric Real-Time Object Detectors
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()

if __name__ == '__main__':
    gradio_app.launch()
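
For quick sanity checks outside the Gradio UI, the inference helper can also be called directly. A minimal sketch, assuming app.py is importable from the repository root and that the yolov12n.pt weights file listed in the model dropdown is present locally (or otherwise resolvable by ultralytics):

    from PIL import Image
    from app import yolov12_inference

    # Sample image bundled with ultralytics (also used in the Gradio examples above).
    img = Image.open("ultralytics/assets/bus.jpg")
    annotated, _ = yolov12_inference(img, None, "yolov12n.pt", 640, 0.25)
    print(annotated.shape)  # (H, W, 3) RGB numpy array produced by results[0].plot()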