# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/blob/main/app.py
# --------------------------------------------------------
import logging
import os
import shutil
import tempfile
import threading
from typing import Optional

import cv2
import gradio as gr
import uvicorn
from fastapi import FastAPI
from fastapi import status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO
# Configure root logging: INFO level, "[timestamp] LEVEL - message" lines.
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s - %(message)s',
    level=logging.INFO,
)
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    """Run YOLO detection on either an image or a video.

    Args:
        image: Input image (the Gradio component feeding it uses
            ``type="pil"``). It takes precedence over ``video`` when not None.
        video: Path to an input video file. It is used when ``image`` is None.
        model_id: Weights passed to ``YOLO(...)``, e.g. ``"yolov12m.pt"``.
        image_size: Inference size, forwarded as ``imgsz``.
        conf_threshold: Confidence threshold, forwarded as ``conf``.

    Returns:
        ``(annotated_image, None)`` for image input, where the plotted result
        has its channel order reversed.
        ``(None, output_video_path)`` for video input, where the output is a
        VP8-encoded ``.webm`` file.
        ``(None, None)`` when neither input was provided.
    """
    # Nothing was uploaded: clear both outputs instead of failing on open(None).
    if image is None and not video:
        return None, None
    model = YOLO(model_id)
    if image is not None:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        # plot() output is written unchanged to cv2.VideoWriter for videos
        # (BGR), so reverse the channels for the numpy image output.
        return results[0].plot()[:, :, ::-1], None
    return None, _annotate_video(model, video, image_size, conf_threshold)


def _annotate_video(model, video, image_size, conf_threshold):
    """Annotate every frame of the video at ``video``; return the new .webm path."""
    # Work on a private copy of the upload. It is deleted once processing ends.
    with open(video, "rb") as src, tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as dst:
        shutil.copyfileobj(src, dst)
        input_path = dst.name
    # Reserve a unique output path atomically (mktemp is race-prone). The
    # handle is closed immediately so cv2.VideoWriter can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as reserved:
        output_video_path = reserved.name

    cap = cv2.VideoCapture(input_path)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # CAP_PROP_FPS may be 0 (or NaN) when the container reports no rate.
        # A non-positive rate can break the writer, so fall back to 30 fps.
        if not fps > 0:
            fps = 30.0
        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, frame_size)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            out.write(results[0].plot())
    finally:
        # Release handles even if inference fails, then drop the input copy.
        cap.release()
        if out is not None:
            out.release()
        os.remove(input_path)
    return output_video_path
def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    """Image-only wrapper around yolov12_inference, used by gr.Examples.

    Returns just the annotated image. The video slot of the result is discarded.
    """
    outputs = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return outputs[0]
def app():
    """Build the YOLOv12 detection UI inside the currently open Blocks context.

    Creates the inputs (image/video, model, image size, confidence) and the
    outputs (annotated image/video). It then wires up the input-type toggle,
    the "Detect Objects" button and the lazily cached image examples.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(choices=["Image", "Video"], value="Image", label="Input Type")
                model_id = gr.Dropdown(
                    label="Model",
                    # yolov12n.pt, yolov12s.pt, yolov12m.pt, yolov12l.pt, yolov12x.pt
                    choices=[f"yolov12{scale}.pt" for scale in "nsmlx"],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(label="Image Size", minimum=320, maximum=1280, step=32, value=640)
                conf_threshold = gr.Slider(
                    label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25
                )
                yolov12_infer = gr.Button(value="Detect Objects")
            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            # "Image" shows the image widgets; any other choice shows the video ones.
            show_image = input_type == "Image"
            return (
                gr.update(visible=show_image),
                gr.update(visible=not show_image),
                gr.update(visible=show_image),
                gr.update(visible=not show_image),
            )

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            # Forward only the input that matches the selected mode.
            if input_type != "Image":
                image = None
            else:
                video = None
            return yolov12_inference(image, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        gr.Examples(
            examples=[
                ["ultralytics/assets/bus.jpg", "yolov12s.pt", 640, 0.25],
                ["ultralytics/assets/zidane.jpg", "yolov12x.pt", 640, 0.25],
            ],
            fn=yolov12_inference_for_examples,
            inputs=[image, model_id, image_size, conf_threshold],
            outputs=[output_image],
            cache_examples='lazy',
        )
# Top-level Gradio page: a title banner, an empty HTML element, and the
# detection UI built by app() inside a row/column layout.
with gr.Blocks() as gradio_app:
    gr.HTML("\nYOLOv12: Attention-Centric Real-Time Object Detectors\n")
    gr.HTML("\n")
    with gr.Row():
        with gr.Column():
            app()
def start_gradio(server_name="0.0.0.0", server_port=7860):
    """Launch the Gradio UI (the __main__ block runs this on a daemon thread).

    Args:
        server_name: Interface to bind. The default "0.0.0.0" listens on all
            interfaces, so the UI is reachable from other machines.
        server_port: TCP port for the UI.
    """
    gradio_app.launch(server_name=server_name, server_port=server_port)
# FastAPI section: the REST application on which the endpoints and the
# validation-error handler below are registered.
app_fastapi = FastAPI()
class TrainParams(BaseModel):
    """
    Training parameters accepted by the /yolov12/train endpoint.
    Every field is required and must be supplied by the client.
    """
    model: str  # base weights to start training from (e.g. "yolov12n.pt")
    data: str  # path to the dataset configuration file
    epochs: int  # number of training epochs
    batch: int  # batch size
    imgsz: int  # input image size
    scale: float  # random-scale augmentation factor
    mosaic: float  # mosaic augmentation probability
    mixup: float  # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str  # training device
    project: str  # project name
    name: str  # experiment (run) name
    exist_ok: bool  # whether an existing directory with the same name may be overwritten
@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """
    RESTful POST endpoint: /yolov12/train

    Builds a model from the architecture YAML matching ``params.model``,
    loads the weights from ``params.model``, trains it with the supplied
    hyper-parameters and reports where the run was saved.

    Returns:
        {"code": 0, "msg": "success", "result": <save directory as str>} on
        success, or {"code": 1, "msg": <error message>, "result": None} if
        anything raises.
    """
    logging.info("收到/yolov12/train训练请求")
    logging.info(f"请求参数: {params}")
    try:
        if params.model.endswith('.pt'):
            # Swap only the trailing ".pt" for ".yaml". str.replace would also
            # rewrite any ".pt" earlier in the path (e.g. in a directory name).
            config_file = params.model[:-len('.pt')] + '.yaml'
        else:
            # Not a .pt file: fall back to the default config.
            config_file = "yolov12.yaml"
        model = YOLO(config_file)
        model.load(params.model)
        logging.info("开始模型训练...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("模型训练完成")
        # NOTE(review): model.train() can return None (e.g. multi-GPU DDP runs,
        # where training happens in a subprocess). In that case read the
        # directory from the trainer, so a finished run is not reported as a
        # failure.
        save_dir = getattr(results, "save_dir", None) or model.trainer.save_dir
        return {
            "code": 0,
            "msg": "success",
            "result": str(save_dir)
        }
    except Exception as e:
        # logging.exception also records the traceback, which logging.error drops.
        logging.exception(f"训练过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class PredictParams(BaseModel):
    """
    Prediction parameters accepted by the /yolov12/predict endpoint.
    Field names mirror the keyword arguments of YOLO ``predict``.

    Fields that default to None are declared Optional, so clients may also
    send an explicit JSON ``null`` for them (a plain ``str``/``int``/``list``
    annotation rejects ``null`` under pydantic v2).
    """
    model: str = "yolov12m.pt"  # model (weights) path
    source: Optional[str] = None  # input source (image/video path, URL, ...)
    stream: bool = False  # whether to process as a stream
    conf: float = 0.25  # confidence threshold
    iou: float = 0.7  # IoU threshold
    max_det: int = 300  # maximum number of detections
    imgsz: int = 640  # input image size
    batch: int = 1  # batch size
    device: str = ""  # device
    show: bool = False  # whether to display results
    save: bool = False  # whether to save results
    save_txt: bool = False  # whether to save results as txt files
    save_conf: bool = False  # whether to save confidences
    save_crop: bool = False  # whether to save cropped images
    show_labels: bool = True  # whether to draw labels
    show_conf: bool = True  # whether to draw confidences
    show_boxes: bool = True  # whether to draw bounding boxes
    line_width: Optional[int] = None  # line width
    vid_stride: int = 1  # video frame stride
    stream_buffer: bool = False  # stream buffering
    visualize: bool = False  # visualize features
    augment: bool = False  # test-time augmentation
    agnostic_nms: bool = False  # class-agnostic NMS
    classes: Optional[list] = None  # restrict to these classes
    retina_masks: bool = False  # high-resolution segmentation masks
    embed: Optional[list] = None  # feature-vector layers
    half: bool = False  # half precision
    dnn: bool = False  # OpenCV DNN
    project: str = ""  # project name
    name: str = ""  # experiment name
    exist_ok: bool = False  # whether an existing directory may be overwritten
    verbose: bool = True  # verbose output
@app_fastapi.post("/yolov12/predict")
def yolov12_predict(params: PredictParams):
    """
    RESTful POST endpoint: /yolov12/predict

    Loads ``params.model`` and calls ``predict`` with every other non-None
    field of ``params`` as a keyword argument.

    Returns:
        {"code": 0, "msg": "success", "result": <save_dir of the first result,
        or None if there are no results>} on success, or
        {"code": 1, "msg": <error message>, "result": None} if anything raises.
    """
    logging.info("收到/yolov12/predict预测请求")
    logging.info(f"请求参数: {params}")
    try:
        model = YOLO(params.model)
        logging.info("开始模型预测...")
        # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
        fields = params.model_dump() if hasattr(params, "model_dump") else params.dict()
        predict_kwargs = {
            field: value
            for field, value in fields.items()
            if field != "model" and value is not None
        }
        # With stream=True, predict() returns a lazy generator that cannot be
        # indexed. Materialize it so both modes can be handled uniformly.
        results = list(model.predict(**predict_kwargs))
        logging.info("模型预测完成")
        # Log only the count: video sources yield one result per frame, which
        # makes the full repr impractically long.
        logging.info(f"预测结果数量: {len(results)}")
        save_dir = results[0].save_dir if results else None
        return {
            "code": 0,
            "msg": "success",
            "result": save_dir
        }
    except Exception as e:
        # logging.exception also records the traceback, which logging.error drops.
        logging.exception(f"预测过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
# Global handler: report request-validation failures with the same
# {"code", "msg", "result"} envelope the endpoints use. The HTTP status stays
# 200 and the failure is signalled through "code": 422.
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Log the validation error and return it in the standard envelope."""
    path = request.url.path
    errors = exc.errors()
    message = f"参数校验失败: 路径={path}, 错误={errors}"
    logging.error(message)
    payload = {"code": 422, "msg": message, "result": None}
    return JSONResponse(content=payload, status_code=status.HTTP_200_OK)
if __name__ == "__main__":
    # The Gradio UI runs on a daemon thread so it goes away when the FastAPI
    # server, which occupies the main thread, exits.
    ui_thread = threading.Thread(target=start_gradio, daemon=True)
    ui_thread.start()
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)