| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- # --------------------------------------------------------
- # Based on yolov10
- # https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import logging
import os
import tempfile
import threading
from typing import Optional

import cv2
import gradio as gr
import uvicorn
from fastapi import FastAPI
from fastapi import status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO
# Configure the root logger: INFO level, "[timestamp] LEVEL - message" layout.
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s - %(message)s',
    level=logging.INFO,
)
def _copy_to_temp_webm(video):
    """Copy the file at ``video`` into a new temporary ``.webm`` file and return its path."""
    # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
    fd, copy_path = tempfile.mkstemp(suffix=".webm")
    try:
        with os.fdopen(fd, "wb") as dst, open(video, "rb") as src:
            # Copy in 1 MiB chunks so a large upload is never held in memory at once.
            for chunk in iter(lambda: src.read(1024 * 1024), b""):
                dst.write(chunk)
    except Exception:
        os.remove(copy_path)
        raise
    return copy_path


def _annotate_video(model, video, image_size, conf_threshold):
    """Run ``model`` on every frame of ``video`` and return the path of the annotated ``.webm``."""
    input_copy_path = _copy_to_temp_webm(video)
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(input_copy_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN) FPS; fall back to a sane frame rate
            # instead of handing an invalid value to the writer.
            fps = 30.0
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fd, output_video_path = tempfile.mkstemp(suffix=".webm")
        os.close(fd)  # VideoWriter reopens the file by path.
        out = cv2.VideoWriter(
            output_video_path,
            cv2.VideoWriter_fourcc(*'vp80'),
            fps,
            (frame_width, frame_height),
        )

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            out.write(results[0].plot())
    finally:
        # Release the OpenCV handles and drop the input copy even if inference failed.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
        try:
            os.remove(input_copy_path)
        except OSError:
            logging.warning("Could not remove temporary file %s", input_copy_path)
    return output_video_path


def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    """Run YOLOv12 detection on an image or a video and return the annotated result.

    Args:
        image: Image to analyse (the Gradio UI passes a PIL image), or ``None``
            to process ``video`` instead.
        video: Path of the video file to analyse; only used when ``image`` is ``None``.
        model_id: Weights name or path handed to ``YOLO``.
        image_size: Inference size, forwarded to ``predict`` as ``imgsz``.
        conf_threshold: Confidence threshold, forwarded to ``predict`` as ``conf``.

    Returns:
        ``(annotated_image, None)`` for image input, where ``annotated_image`` is
        the ``plot()`` array with its channel axis reversed, or
        ``(None, output_video_path)`` for video input, where the path points to
        a newly written ``.webm`` file.

    Raises:
        ValueError: If neither ``image`` nor ``video`` is given, or the video
            cannot be opened.
    """
    # Validate before loading the model so a missing input fails fast and clearly.
    if image is None and video is None:
        raise ValueError("Either an image or a video must be provided.")

    model = YOLO(model_id)
    # Explicit None check: truthiness is ambiguous for array-like inputs.
    if image is not None:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        return annotated_image[:, :, ::-1], None

    return None, _annotate_video(model, video, image_size, conf_threshold)
def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    """Image-only wrapper around ``yolov12_inference`` used by the example gallery.

    Calls ``yolov12_inference`` with no video and returns only the first
    element of its ``(image, video)`` result pair.
    """
    outputs = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return outputs[0]
def app():
    """Build the detection demo UI inside a ``gr.Blocks`` context.

    Left column: image/video inputs, an input-type selector, model choice,
    image-size and confidence sliders, and the detect button. Right column:
    the annotated image/video outputs. Also wires the visibility toggle, the
    detect click handler, and two lazily cached image examples.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    label="Input Type",
                    choices=["Image", "Video"],
                    value="Image",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    # Expands to yolov12n.pt, yolov12s.pt, yolov12m.pt, yolov12l.pt, yolov12x.pt.
                    choices=[f"yolov12{scale}.pt" for scale in "nsmlx"],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size", minimum=320, maximum=1280, step=32, value=640
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25
                )
                detect_button = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            """Show the image widgets in "Image" mode and the video widgets otherwise."""
            show_image = input_type == "Image"
            # Order matches the outputs list: image, video, output_image, output_video.
            return (
                gr.update(visible=show_image),
                gr.update(visible=not show_image),
                gr.update(visible=show_image),
                gr.update(visible=not show_image),
            )

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            """Forward only the input that matches the selected input type."""
            if input_type == "Image":
                video = None
            else:
                image = None
            return yolov12_inference(image, video, model_id, image_size, conf_threshold)

        detect_button.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                ["ultralytics/assets/bus.jpg", "yolov12s.pt", 640, 0.25],
                ["ultralytics/assets/zidane.jpg", "yolov12x.pt", 640, 0.25],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[image, model_id, image_size, conf_threshold],
            outputs=[output_image],
            cache_examples='lazy',
        )
# Top-level Gradio page: title, paper/repository links, then the demo UI from app().
with gr.Blocks() as gradio_app:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    YOLOv12: Attention-Centric Real-Time Object Detectors
    </h1>
    """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()
def start_gradio():
    """Launch the Gradio UI on all interfaces, port 7860."""
    gradio_app.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )
# FastAPI section: REST application that hosts the /yolov12/* endpoints.
app_fastapi = FastAPI()
class TrainParams(BaseModel):
    """Request body of the /yolov12/train endpoint.

    No field declares a default, so the client has to supply every one.
    """

    # Base model to start training from.
    model: str
    # Path of the dataset configuration file.
    data: str
    # Number of training epochs.
    epochs: int
    # Batch size.
    batch: int
    # Input image size.
    imgsz: int
    # Random-scale augmentation ratio.
    scale: float
    # Probability of mosaic augmentation.
    mosaic: float
    # Probability of mixup augmentation.
    mixup: float
    # Probability of copy-paste augmentation.
    copy_paste: float
    # Training device.
    device: str
    # Project name.
    project: str
    # Experiment name.
    name: str
    # Whether an existing directory of the same name may be overwritten.
    exist_ok: bool
@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """REST POST endpoint ``/yolov12/train``: train a YOLO model.

    Builds the model from a YAML config derived from ``params.model``, loads
    ``params.model`` into it, and trains with the remaining request fields.

    Args:
        params: Training parameters sent by the client.

    Returns:
        ``{"code": 0, "msg": "success", "result": <save directory as str>}``
        on success, or ``{"code": 1, "msg": <error text>, "result": None}``
        when any exception is raised during training.
    """
    logging.info("收到/yolov12/train训练请求")
    logging.info(f"请求参数: {params}")
    try:
        # Derive the architecture config from the weights name.
        weights_suffix = '.pt'
        if params.model.endswith(weights_suffix):
            # Swap only the trailing ".pt"; str.replace would also rewrite any
            # ".pt" occurring earlier in the path (e.g. "runs.pt/yolov12n.pt").
            config_file = params.model[:-len(weights_suffix)] + '.yaml'
        else:
            # Not a .pt file: fall back to the default config.
            config_file = "yolov12.yaml"

        model = YOLO(config_file)
        model.load(params.model)
        logging.info("开始模型训练...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("模型训练完成")
        return {
            "code": 0,
            "msg": "success",
            "result": str(results.save_dir)
        }
    except Exception as e:
        # Endpoint boundary: report the failure in the uniform response shape,
        # and log with the traceback so the cause can be diagnosed.
        logging.exception(f"训练过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class PredictParams(BaseModel):
    """Request body of the /yolov12/predict endpoint.

    Field names mirror the keyword arguments of the YOLO ``predict`` method.
    Every field has a default; fields left as ``None`` are not forwarded to
    ``predict`` by the endpoint.
    """

    model: str = "yolov12m.pt"  # Model path.
    # Fields defaulting to None are declared Optional so the annotation matches
    # the default and an explicit JSON null is accepted.
    source: Optional[str] = None  # Input source (image/video path, URL, ...).
    stream: bool = False  # Whether to process as a stream.
    conf: float = 0.25  # Confidence threshold.
    iou: float = 0.7  # IoU threshold.
    max_det: int = 300  # Maximum number of detections.
    imgsz: int = 640  # Input image size.
    batch: int = 1  # Batch size.
    device: str = ""  # Device.
    show: bool = False  # Whether to display results.
    save: bool = False  # Whether to save results.
    save_txt: bool = False  # Whether to save txt files.
    save_conf: bool = False  # Whether to save confidences.
    save_crop: bool = False  # Whether to save cropped images.
    show_labels: bool = True  # Whether to draw labels.
    show_conf: bool = True  # Whether to draw confidences.
    show_boxes: bool = True  # Whether to draw bounding boxes.
    line_width: Optional[int] = None  # Line width.
    vid_stride: int = 1  # Video frame stride.
    stream_buffer: bool = False  # Stream buffering.
    visualize: bool = False  # Visualize features.
    augment: bool = False  # Augmentation.
    agnostic_nms: bool = False  # Class-agnostic NMS.
    classes: Optional[list] = None  # Restrict to the given classes.
    retina_masks: bool = False  # High-resolution segmentation masks.
    embed: Optional[list] = None  # Feature-vector layers.
    half: bool = False  # Half precision.
    dnn: bool = False  # OpenCV DNN.
    project: str = ""  # Project name.
    name: str = ""  # Experiment name.
    exist_ok: bool = False  # Whether to overwrite an existing directory.
    verbose: bool = True  # Verbose output.
@app_fastapi.post("/yolov12/predict")
def yolov12_predict(params: PredictParams):
    """REST POST endpoint ``/yolov12/predict``: run YOLO prediction.

    Loads ``params.model`` and forwards every other non-``None`` request
    field to ``predict`` as a keyword argument.

    Args:
        params: Prediction parameters sent by the client.

    Returns:
        ``{"code": 0, "msg": "success", "result": <save_dir of the first result>}``
        on success, or ``{"code": 1, "msg": <error text>, "result": None}``
        when any exception is raised (including when no result is produced).
    """
    logging.info("收到/yolov12/predict预测请求")
    logging.info(f"请求参数: {params}")
    try:
        model = YOLO(params.model)
        logging.info("开始模型预测...")

        # Forward everything except the model path; None means "not specified".
        predict_kwargs = {
            field: value
            for field, value in params.dict().items()
            if field != 'model' and value is not None
        }

        # With stream=True predict() returns a lazy generator; materialise it so
        # inference actually runs and the first result can be indexed below.
        results = list(model.predict(**predict_kwargs))
        logging.info("模型预测完成")
        logging.info(f"预测结果: {str(results)}")
        return {
            "code": 0,
            "msg": "success",
            "result": results[0].save_dir
        }
    except Exception as e:
        # Endpoint boundary: report the failure in the uniform response shape,
        # and log with the traceback so the cause can be diagnosed.
        logging.exception(f"预测过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
# Global exception handler: request-validation failures are returned in the
# same {"code", "msg", "result"} shape as the endpoints.
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Log a request-validation failure and answer HTTP 200 with body code 422."""
    err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
    logging.error(err_msg)
    payload = {
        "code": 422,
        "msg": err_msg,
        "result": None
    }
    return JSONResponse(content=payload, status_code=status.HTTP_200_OK)
if __name__ == "__main__":
    # Serve the Gradio UI from a daemon thread, then run the FastAPI app
    # with uvicorn in the main thread.
    gradio_thread = threading.Thread(target=start_gradio, daemon=True)
    gradio_thread.start()
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
|