# app.py

# --------------------------------------------------------
# By Yunjie Tian
# Based on yolov10
# https://github.com/THU-MIG/yolov10
# --------------------------------------------------------
import gradio as gr
import cv2
import tempfile
from ultralytics import YOLO
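

# Gradio demo for YOLOv12: takes an image or a video, runs Ultralytics YOLO
# inference at a user-selected image size and confidence threshold, and
# returns the annotated result.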
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    model = YOLO(model_id)
    if image:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        # plot() returns a BGR array; reverse the channels to RGB for Gradio.
        return annotated_image[:, :, ::-1], None
    else:
        # Copy the uploaded video to a temporary file that OpenCV can read.
        video_path = tempfile.mktemp(suffix=".webm")
        with open(video_path, "wb") as f:
            with open(video, "rb") as g:
                f.write(g.read())

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Write annotated frames to a WebM file using the VP8 codec.
        output_video_path = tempfile.mktemp(suffix=".webm")
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))

        # Run detection frame by frame.
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)

        cap.release()
        out.release()
        return None, output_video_path


def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return annotated_image
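

# A minimal sketch of calling the inference helper directly, outside the UI
# (illustrative only; assumes the example asset and weights are available):
#
#   from PIL import Image
#   annotated, _ = yolov12_inference(
#       Image.open("ultralytics/assets/bus.jpg"), None, "yolov12n.pt", 640, 0.25
#   )
#   # `annotated` is an RGB numpy array of the image with drawn detections.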


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        # Show only the input and output components that match the selected type.
        def update_visibility(input_type):
            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            if input_type == "Image":
                return yolov12_inference(image, None, model_id, image_size, conf_threshold)
            else:
                return yolov12_inference(None, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )
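
        # Each click returns an (annotated_image, annotated_video) pair from
        # yolov12_inference; the slot for the inactive input type is None, so
        # the hidden output component is left blank.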

        gr.Examples(
            examples=[
                [
                    "ultralytics/assets/bus.jpg",
                    "yolov12s.pt",
                    640,
                    0.25,
                ],
                [
                    "ultralytics/assets/zidane.jpg",
                    "yolov12x.pt",
                    640,
                    0.25,
                ],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[
                image,
                model_id,
                image_size,
                conf_threshold,
            ],
            outputs=[output_image],
            cache_examples='lazy',
        )
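        # cache_examples='lazy' defers computing an example's output until it
        # is first selected, rather than at startup (supported in recent
        # Gradio releases).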


gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv12: Attention-Centric Real-Time Object Detectors
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2503.xxxxx' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()

if __name__ == '__main__':
    gradio_app.launch()
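    # A sketch of common launch options (assumption: defaults are used above),
    # e.g. to bind all interfaces or create a temporary public link:
    #   gradio_app.launch(server_name="0.0.0.0", share=True)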