# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import glob
import logging
import os
import subprocess
import tempfile
import threading
from typing import List, Optional

import cv2
import gradio as gr
import uvicorn
from fastapi import FastAPI, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO

# Set up the logging format and level
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')


def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    model = YOLO(model_id)
    if image:
        # Image input: run a single prediction and return the annotated frame (BGR -> RGB)
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        return annotated_image[:, :, ::-1], None
    else:
        # Video input: copy the upload to a temp file, then annotate it frame by frame
        video_path = tempfile.mktemp(suffix=".webm")
        with open(video_path, "wb") as f:
            with open(video, "rb") as g:
                f.write(g.read())

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_video_path = tempfile.mktemp(suffix=".webm")
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)

        cap.release()
        out.release()

        return None, output_video_path
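
# A minimal usage sketch for calling yolov12_inference directly, outside Gradio.
# The image path and weight file below are illustrative; any local image and any
# downloaded yolov12 checkpoint should behave the same way:
#
#   from PIL import Image
#   annotated, _ = yolov12_inference(Image.open("bus.jpg"), None, "yolov12n.pt", 640, 0.25)
#   # annotated is an RGB numpy array ready for display or saving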


def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    annotated_image, _ = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return annotated_image


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
            output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            if input_type == "Image":
                return yolov12_inference(image, None, model_id, image_size, conf_threshold)
            else:
                return yolov12_inference(None, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                [
                    "ultralytics/assets/bus.jpg",
                    "yolov12s.pt",
                    640,
                    0.25,
                ],
                [
                    "ultralytics/assets/zidane.jpg",
                    "yolov12x.pt",
                    640,
                    0.25,
                ],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[
                image,
                model_id,
                image_size,
                conf_threshold,
            ],
            outputs=[output_image],
            cache_examples='lazy',
        )


gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv12: Attention-Centric Real-Time Object Detectors
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()


def start_gradio():
    gradio_app.launch(server_name="0.0.0.0", server_port=7860)


# FastAPI section
app_fastapi = FastAPI()


class TrainParams(BaseModel):
    """
    Training parameters for the /yolov12/train endpoint; all fields must be supplied by the client.
    """
    model: str         # base model to fine-tune
    data: str          # path to the dataset config file
    epochs: int        # number of training epochs
    batch: int         # batch size
    imgsz: int         # input image size
    scale: float       # random-scale augmentation ratio
    mosaic: float      # mosaic augmentation probability
    mixup: float       # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str        # training device
    project: str       # project name
    name: str          # experiment name
    exist_ok: bool     # whether overwriting an existing directory of the same name is allowed
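
# An illustrative request body for /yolov12/train (values are examples only;
# the dataset config and device depend on the local setup):
#
#   {
#       "model": "yolov12n.pt",
#       "data": "coco8.yaml",
#       "epochs": 10,
#       "batch": 16,
#       "imgsz": 640,
#       "scale": 0.5,
#       "mosaic": 1.0,
#       "mixup": 0.0,
#       "copy_paste": 0.1,
#       "device": "0",
#       "project": "runs/train",
#       "name": "exp",
#       "exist_ok": true
#   }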
- @app_fastapi.post("/yolov12/train")
- def yolov12_train(params: TrainParams):
- """
- RESTful POST接口:/yolov12/train
- 接收训练参数,调用YOLO模型训练,并返回训练结果。
- 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 训练结果或None}
- """
- logging.info("收到/yolov12/train训练请求")
- logging.info(f"请求参数: {params}")
- try:
- # 根据params.model动态确定配置文件
- if params.model.endswith('.pt'):
- # 如果是.pt文件,将后缀替换为.yaml
- config_file = params.model.replace('.pt', '.yaml')
- else:
- # 如果不是.pt文件,使用默认配置
- config_file = "yolov12.yaml"
-
- model = YOLO(config_file)
- model.load(params.model)
- logging.info("开始模型训练...")
- results = model.train(
- data=params.data,
- epochs=params.epochs,
- batch=params.batch,
- imgsz=params.imgsz,
- scale=params.scale,
- mosaic=params.mosaic,
- mixup=params.mixup,
- copy_paste=params.copy_paste,
- device=params.device,
- project=params.project,
- name=params.name,
- exist_ok=params.exist_ok,
- )
- logging.info("模型训练完成")
- # logging.info(f"训练结果: {str(results)}")
- return {
- "code": 0,
- "msg": "success",
- "result": str(results.save_dir)
- }
- except Exception as e:
- logging.error(f"训练过程发生异常: {e}")
- return {
- "code": 1,
- "msg": str(e),
- "result": None
- }
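
# A sketch of exercising the endpoint with curl (assumes the server is running
# locally on port 8000 and a JSON body like the example above is in train.json):
#
#   curl -X POST http://localhost:8000/yolov12/train \
#        -H "Content-Type: application/json" \
#        -d @train.json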


class PredictParams(BaseModel):
    """
    Prediction parameters for the /yolov12/predict endpoint, mirroring the YOLO predict method.
    """
    model: str = "yolov12m.pt"            # model path
    source: Optional[str] = None          # input source (image/video path, URL, etc.)
    stream: bool = False                  # stream results
    conf: float = 0.25                    # confidence threshold
    iou: float = 0.7                      # IoU threshold
    max_det: int = 300                    # maximum number of detections
    imgsz: int = 640                      # input image size
    batch: int = 1                        # batch size
    device: str = ""                      # device
    show: bool = False                    # display results
    save: bool = False                    # save results
    save_txt: bool = False                # save txt files
    save_conf: bool = False               # save confidences
    save_crop: bool = False               # save cropped detections
    show_labels: bool = True              # draw labels
    show_conf: bool = True                # draw confidences
    show_boxes: bool = True               # draw bounding boxes
    line_width: Optional[int] = None      # bounding-box line width
    vid_stride: int = 1                   # video frame stride
    stream_buffer: bool = False           # stream buffering
    visualize: bool = False               # visualize features
    augment: bool = False                 # test-time augmentation
    agnostic_nms: bool = False            # class-agnostic NMS
    classes: Optional[List[int]] = None   # restrict to these class ids
    retina_masks: bool = False            # high-resolution segmentation masks
    embed: Optional[List[int]] = None     # layers to return feature vectors from
    half: bool = False                    # half precision
    dnn: bool = False                     # OpenCV DNN backend
    project: str = ""                     # project name
    name: str = ""                        # experiment name
    exist_ok: bool = False                # overwrite an existing directory
    verbose: bool = True                  # verbose output
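
# An illustrative request body for /yolov12/predict; omitted fields fall back
# to the defaults above (the source path is an example only):
#
#   {
#       "model": "yolov12m.pt",
#       "source": "ultralytics/assets/bus.jpg",
#       "conf": 0.25,
#       "imgsz": 640
#   }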
- @app_fastapi.post("/yolov12/predict")
- def yolov12_predict(params: PredictParams):
- """
- RESTful POST接口:/yolov12/predict
- 接收预测参数,调用YOLO模型进行预测,并返回预测结果。
- 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"save_dir": "保存目录", "filename": "文件名"}}
- """
- logging.info("收到/yolov12/predict预测请求")
- logging.info(f"请求参数: {params}")
- try:
- model = YOLO(params.model)
- logging.info("开始模型预测...")
-
- # 构建预测参数
- predict_kwargs = {}
- for field, value in params.dict().items():
- if field not in ['model'] and value is not None:
- predict_kwargs[field] = value
-
- # 确保保存结果
- predict_kwargs['save'] = True
-
- results = model.predict(**predict_kwargs)
- logging.info("模型预测完成")
-
- # 获取保存目录和最终文件名
- result = results[0]
- save_dir = result.save_dir if hasattr(result, 'save_dir') else None
-
- # 获取最终生成的文件名
- final_filename = None
- if save_dir:
- import os
- import glob
- if os.path.exists(save_dir):
- # 检查输入源类型
- source = params.source
- if source:
- source_ext = os.path.splitext(source)[1].lower()
- video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
-
- # 如果输入是图片,返回图片文件
- if source_ext not in video_extensions:
- image_files = []
- for ext in ['*.jpg', '*.jpeg', '*.png']:
- image_files.extend(glob.glob(os.path.join(save_dir, ext)))
-
- if image_files:
- latest_image = max(image_files, key=os.path.getmtime)
- final_filename = os.path.basename(latest_image)
- logging.info(f"输入为图片,返回图片文件: {final_filename}")
-
- # 如果输入是视频,检查并转换为MP4
- else:
- # 查找所有视频文件
- video_files = []
- for ext in ['*.avi', '*.webm', '*.mov']:
- video_files.extend(glob.glob(os.path.join(save_dir, ext)))
-
- # 如果找到非MP4视频文件,转换为MP4
- for video_file in video_files:
- output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
- try:
- import cv2
- cap = cv2.VideoCapture(video_file)
- fps = cap.get(cv2.CAP_PROP_FPS)
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
- # 尝试不同的MP4编码器(按兼容性排序)
- fourcc_options = ['avc1', 'H264', 'mp4v']
- out = None
-
- for fourcc in fourcc_options:
- try:
- fourcc_code = cv2.VideoWriter_fourcc(*fourcc)
- out = cv2.VideoWriter(output_mp4, fourcc_code, fps, (width, height))
- if out.isOpened():
- logging.info(f"使用编码器 {fourcc} 创建MP4文件")
- break
- except:
- continue
-
- if out and out.isOpened():
- while cap.isOpened():
- ret, frame = cap.read()
- if not ret:
- break
- out.write(frame)
-
- cap.release()
- out.release()
-
- # 使用ffmpeg进一步优化MP4文件(如果可用)
- try:
- import subprocess
- temp_mp4 = output_mp4 + '.temp.mp4'
- os.rename(output_mp4, temp_mp4)
-
- # 使用ffmpeg重新编码为H.264格式
- cmd = [
- 'ffmpeg', '-i', temp_mp4,
- '-c:v', 'libx264',
- '-preset', 'fast',
- '-crf', '23',
- '-y', output_mp4
- ]
-
- result = subprocess.run(cmd, capture_output=True, text=True)
- if result.returncode == 0:
- os.remove(temp_mp4)
- logging.info(f"使用ffmpeg优化MP4文件: {output_mp4}")
- else:
- # ffmpeg失败,恢复原文件
- os.rename(temp_mp4, output_mp4)
- logging.warning(f"ffmpeg优化失败,使用OpenCV生成的MP4: {output_mp4}")
-
- except (FileNotFoundError, subprocess.SubprocessError) as e:
- # ffmpeg不可用,使用OpenCV生成的MP4
- logging.warning(f"ffmpeg不可用,使用OpenCV生成的MP4: {output_mp4}")
-
- # 删除原文件
- os.remove(video_file)
- logging.info(f"视频已转换为MP4格式: {output_mp4}")
- else:
- # OpenCV编码器失败,尝试使用ffmpeg直接转换
- logging.warning(f"OpenCV编码器失败,尝试使用ffmpeg转换")
- try:
- import subprocess
- cmd = [
- 'ffmpeg', '-i', video_file,
- '-c:v', 'libx264',
- '-preset', 'fast',
- '-crf', '23',
- '-y', output_mp4
- ]
-
- result = subprocess.run(cmd, capture_output=True, text=True)
- if result.returncode == 0:
- os.remove(video_file)
- logging.info(f"使用ffmpeg直接转换MP4文件: {output_mp4}")
- else:
- logging.error(f"ffmpeg转换失败: {result.stderr}")
-
- except (FileNotFoundError, subprocess.SubprocessError) as e:
- logging.error(f"ffmpeg不可用,保持原格式: {e}")
-
- except Exception as e:
- logging.error(f"转换视频格式时出错: {e}")
-
- # 获取MP4文件
- mp4_files = glob.glob(os.path.join(save_dir, "*.mp4"))
- if mp4_files:
- latest_mp4 = max(mp4_files, key=os.path.getmtime)
- final_filename = os.path.basename(latest_mp4)
- logging.info(f"输入为视频,返回MP4文件: {final_filename}")
-
- # 如果无法确定输入类型或未找到文件,返回最新文件
- if not final_filename:
- all_files = []
- for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
- all_files.extend(glob.glob(os.path.join(save_dir, ext)))
-
- if all_files:
- latest_file = max(all_files, key=os.path.getmtime)
- final_filename = os.path.basename(latest_file)
- logging.info(f"返回最新文件: {final_filename}")
-
- return {
- "code": 0,
- "msg": "success",
- "result": save_dir+"/"+final_filename
- }
- except Exception as e:
- logging.error(f"预测过程发生异常: {e}")
- return {
- "code": 1,
- "msg": str(e),
- "result": None
- }
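
# A sketch of a predict request with curl (same assumptions as the train
# example: the server is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/yolov12/predict \
#        -H "Content-Type: application/json" \
#        -d '{"model": "yolov12m.pt", "source": "ultralytics/assets/bus.jpg"}'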


# Global exception handler: return a uniform response when request validation fails
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    err_msg = f"Request validation failed: path={request.url.path}, errors={exc.errors()}"
    logging.error(err_msg)
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "code": 422,
            "msg": err_msg,
            "result": None
        }
    )
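
# With this handler, a malformed request (e.g. a missing required field) still
# returns HTTP 200 with a structured body in the same shape as the endpoints:
#
#   {"code": 422, "msg": "Request validation failed: path=/yolov12/train, errors=[...]", "result": null}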
- if __name__ == "__main__":
- threading.Thread(target=start_gradio, daemon=True).start()
- uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)