| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715 |
- # --------------------------------------------------------
- # Based on yolov10
- # https://github.com/THU-MIG/yolov10/app.py
- # --------------------------------------------------------'
- import logging
- import tempfile
- import threading
- import cv2
- import gradio as gr
- import uvicorn
- import asyncio
- from fastapi import FastAPI
- from fastapi import status
- from fastapi.exceptions import RequestValidationError
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- from ultralytics import YOLO
- import os
- import glob
- import subprocess
- import signal
- import time
- from typing import Optional, Dict
- from concurrent.futures import ProcessPoolExecutor, Future
- from functools import partial
- import uuid
# Configure root logging: INFO level with a "[timestamp] LEVEL - message" format.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
# Process pool that executes streaming tasks, plus registries keyed by task id.
stream_executor = ProcessPoolExecutor(max_workers=4)
# NOTE(review): annotated as concurrent.futures.Future, but /yolov12/stream stores
# the object returned by loop.run_in_executor here - confirm the intended type.
stream_tasks: Dict[str, Future] = {}
# Meant to record the worker process pid of each task; nothing in the visible
# code writes to it.
stream_pids: Dict[str, int] = {}
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    """Run YOLOv12 detection on one image, or on every frame of a video.

    Args:
        image: Input image for the image branch; pass None to select the
            video branch.
        video: Path of the input video file (used only when image is None).
        model_id: Model weights name or path handed to ``YOLO``.
        image_size: Inference size forwarded as ``imgsz``.
        conf_threshold: Confidence threshold forwarded as ``conf``.

    Returns:
        ``(annotated_image, None)`` for an image, where the plotted array has
        its channel axis reversed, or ``(None, output_video_path)`` for a
        video, where the path points to a newly written ``.webm`` file that
        the caller owns.
    """
    import shutil  # local import: only this function needs it

    model = YOLO(model_id)
    if image is not None:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        # Reverse the channel axis (BGR -> RGB, assuming plot() returns BGR as
        # documented by Ultralytics).
        return annotated_image[:, :, ::-1], None

    video_path = None
    cap = None
    out = None
    try:
        # NamedTemporaryFile creates the file atomically; tempfile.mktemp only
        # returns a name and is deprecated because of the resulting race.
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_in:
            video_path = tmp_in.name
            with open(video, "rb") as src:
                shutil.copyfileobj(src, tmp_in)

        cap = cv2.VideoCapture(video_path)
        # Some containers report 0 FPS; fall back so the writer stays valid.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_out:
            output_video_path = tmp_out.name
        out = cv2.VideoWriter(
            output_video_path,
            cv2.VideoWriter_fourcc(*'vp80'),
            fps,
            (frame_width, frame_height),
        )

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            out.write(results[0].plot())
    finally:
        # Release the capture/writer even when inference raises mid-video.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if video_path is not None:
            try:
                os.remove(video_path)
            except OSError:
                # Best-effort cleanup of the temporary input copy.
                pass

    return None, output_video_path
def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    """Image-only wrapper around ``yolov12_inference`` for the Gradio examples.

    Calls ``yolov12_inference`` with no video and returns only the annotated
    image, discarding the (unused) video slot of the result pair.
    """
    annotated, _unused_video = yolov12_inference(
        image, None, model_path, image_size, conf_threshold
    )
    return annotated
def app():
    """Build the detection UI: inputs, outputs, event wiring and examples.

    Creates the widgets inside a ``gr.Blocks`` context, toggles image/video
    widgets when the input type changes, and routes the "Detect Objects"
    button to ``yolov12_inference``.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            # Image widgets are shown in "Image" mode, video widgets otherwise.
            image_mode = input_type == "Image"
            return (
                gr.update(visible=image_mode),
                gr.update(visible=not image_mode),
                gr.update(visible=image_mode),
                gr.update(visible=not image_mode),
            )

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            # Blank out whichever input does not belong to the selected mode.
            if input_type == "Image":
                video = None
            else:
                image = None
            return yolov12_inference(image, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        example_rows = [
            ["ultralytics/assets/bus.jpg", "yolov12s.pt", 640, 0.25],
            ["ultralytics/assets/zidane.jpg", "yolov12x.pt", 640, 0.25],
        ]
        gr.Examples(
            examples=example_rows,
            fn=yolov12_inference_for_examples,
            inputs=[image, model_id, image_size, conf_threshold],
            outputs=[output_image],
            cache_examples='lazy',
        )
# Top-level Gradio page: title, links to the paper/repository, then the
# detection UI built by app().
gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    YOLOv12: Attention-Centric Real-Time Object Detectors
    </h1>
    """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
        </h3>
        """)
    with gr.Row():
        with gr.Column():
            app()
def start_gradio():
    """Launch the Gradio page on all interfaces, port 7860."""
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    gradio_app.launch(**launch_options)
# FastAPI section: application object that the REST endpoints below register on.
app_fastapi = FastAPI()
class TrainParams(BaseModel):
    """
    Request body of the /yolov12/train endpoint.

    Every field is required (no defaults), so the client must supply all of them.
    """
    model: str  # base weights to start training from
    data: str  # path of the dataset configuration file
    epochs: int  # number of training epochs
    batch: int  # batch size
    imgsz: int  # input image size
    scale: float  # random-scale augmentation ratio
    mosaic: float  # mosaic augmentation probability
    mixup: float  # mixup augmentation probability
    copy_paste: float  # copy-paste augmentation probability
    device: str  # training device
    project: str  # project name
    name: str  # experiment name
    exist_ok: bool  # whether an existing directory of the same name may be reused
@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """
    RESTful POST endpoint: /yolov12/train

    Builds a YOLO model from the config matching ``params.model``, loads the
    weights, trains with the supplied parameters and reports the output
    directory.

    Returns:
        ``{"code": 0, "msg": "success", "result": "<save_dir>"}`` on success,
        ``{"code": 1, "msg": "<error text>", "result": None}`` on any failure.
    """
    logging.info("收到/yolov12/train训练请求")
    logging.info(f"请求参数: {params}")
    try:
        # Derive the model config from the weights name.
        if params.model.endswith('.pt'):
            # Swap only the trailing suffix; str.replace would also rewrite a
            # '.pt' occurring earlier in the path.
            config_file = params.model[:-len('.pt')] + '.yaml'
        else:
            # Not a .pt file: fall back to the default config.
            config_file = "yolov12.yaml"

        model = YOLO(config_file)
        model.load(params.model)
        logging.info("开始模型训练...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("模型训练完成")
        return {
            "code": 0,
            "msg": "success",
            "result": str(results.save_dir)
        }
    except Exception as e:
        # logging.exception keeps the traceback, which the bare message loses.
        logging.exception(f"训练过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class PredictParams(BaseModel):
    """
    Request body of the /yolov12/predict endpoint.

    Field names mirror the keyword arguments passed on to ``YOLO.predict``.
    Fields whose default is None are declared Optional so that an explicit
    ``null`` in the request is accepted; the endpoint drops None values before
    calling ``predict``.
    """
    model: str = "yolov12m.pt"  # model weights path
    source: Optional[str] = None  # input source (image/video path, URL, ...)
    stream: bool = False  # stream results
    conf: float = 0.25  # confidence threshold
    iou: float = 0.7  # IoU threshold
    max_det: int = 300  # maximum number of detections
    imgsz: int = 640  # input image size
    batch: int = 1  # batch size
    device: str = ""  # device
    show: bool = False  # display results
    save: bool = False  # save results
    save_txt: bool = False  # save results as txt files
    save_conf: bool = False  # save confidences
    save_crop: bool = False  # save cropped detections
    show_labels: bool = True  # draw labels
    show_conf: bool = True  # draw confidences
    show_boxes: bool = True  # draw bounding boxes
    line_width: Optional[int] = None  # line width
    vid_stride: int = 1  # video frame stride
    stream_buffer: bool = False  # stream buffering
    visualize: bool = False  # visualize features
    augment: bool = False  # test-time augmentation
    agnostic_nms: bool = False  # class-agnostic NMS
    classes: Optional[list] = None  # restrict to these classes
    retina_masks: bool = False  # high-resolution segmentation masks
    embed: Optional[list] = None  # layers to take feature vectors from
    half: bool = False  # half precision
    dnn: bool = False  # OpenCV DNN
    project: str = ""  # project name
    name: str = ""  # experiment name
    exist_ok: bool = False  # allow reusing an existing directory
    verbose: bool = True  # verbose output
def _find_saved_output(save_dir, source):
    """Return the basename of the saved prediction file in save_dir, or None.

    Prefers the oldest file whose stem equals the stem of ``source``; otherwise
    falls back to the most recently modified image/video file.
    """
    patterns = ('*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov')
    # Escape the directory so characters such as '[' are not read as wildcards.
    safe_dir = glob.escape(save_dir)
    candidates = []
    for pattern in patterns:
        candidates.extend(glob.glob(os.path.join(safe_dir, pattern)))
    if not candidates:
        return None

    base_name = os.path.splitext(os.path.basename(source))[0] if source else None
    if base_name:
        for path in sorted(candidates, key=os.path.getmtime):
            if os.path.splitext(os.path.basename(path))[0] == base_name:
                filename = os.path.basename(path)
                logging.info(f"按输入文件名查找,返回文件: {filename}")
                return filename

    filename = os.path.basename(max(candidates, key=os.path.getmtime))
    logging.info(f"未找到同名,返回最新文件: {filename}")
    return filename


@app_fastapi.post("/yolov12/predict")
def yolov12_predict(params: PredictParams):
    """
    RESTful POST endpoint: /yolov12/predict

    Runs ``YOLO.predict`` with the supplied parameters (saving is always
    enabled) and reports where the annotated output was written.

    Returns:
        ``{"code": 0, "msg": "success", "result": "<save_dir>/<filename>"}``
        on success, ``{"code": 1, "msg": "<error text>", "result": None}`` when
        prediction fails or no saved output can be located.
    """
    logging.info("收到/yolov12/predict预测请求")
    logging.info(f"请求参数: {params}")
    try:
        model = YOLO(params.model)
        logging.info("开始模型预测...")

        # Forward every non-None field except the model path itself.
        predict_kwargs = {
            field: value
            for field, value in params.dict().items()
            if field != 'model' and value is not None
        }
        # Always save, since the response points at the saved file.
        predict_kwargs['save'] = True

        # list() also drains the generator returned when stream=True, which
        # cannot be indexed directly.
        results = list(model.predict(**predict_kwargs))
        logging.info("模型预测完成")

        if not results:
            return {"code": 1, "msg": "预测未返回任何结果", "result": None}

        save_dir = getattr(results[0], 'save_dir', None)
        if not save_dir:
            return {"code": 1, "msg": "预测结果没有保存目录", "result": None}
        save_dir = str(save_dir)

        final_filename = _find_saved_output(save_dir, params.source)
        if final_filename is None:
            # Report the real problem instead of a str + None TypeError.
            return {
                "code": 1,
                "msg": f"在保存目录中未找到输出文件: {save_dir}",
                "result": None
            }

        return {
            "code": 0,
            "msg": "success",
            "result": save_dir + "/" + final_filename
        }
    except Exception as e:
        logging.exception(f"预测过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class StreamParams(BaseModel):
    """
    Request body of the /yolov12/stream endpoint.

    model: inference model path
    source: address to pull the video from (e.g. an rtsp/http stream)
    stream_url: address to push the annotated video to (e.g. an rtmp URL)
    fps: output frame rate
    bitrate: output bitrate
    The remaining fields have the same meaning as in the predict endpoint.

    ``source`` and ``stream_url`` default to None and are therefore declared
    Optional, so an explicit ``null`` is accepted by validation.
    """
    model: str = "yolov12m.pt"
    source: Optional[str] = None
    stream_url: Optional[str] = None
    fps: int = 25
    bitrate: str = "2000k"
    conf: float = 0.25
    iou: Optional[float] = 0.7
    imgsz: int = 640
    device: str = ""
    # More parameters can be added here as needed.
def yolov12_stream_worker(params_dict, task_id):
    """Pull a video stream, annotate every frame with YOLO and push it via ffmpeg.

    Synchronous function meant to run inside a process-pool worker. Frames are
    read with OpenCV, annotated, and written as raw BGR bytes to the stdin of
    an ffmpeg child process that encodes and pushes them to ``stream_url``.

    Args:
        params_dict: Plain dict with the keys ``model``, ``source`` and
            ``stream_url`` plus the optional ``fps``, ``bitrate``, ``conf``,
            ``iou``, ``imgsz`` and ``device``.
        task_id: Identifier used only to tag log output.

    Returns:
        ``{"code": 0, "msg": "success", "result": {"frames_processed": int,
        "avg_fps": float}}`` when the stream ends normally, otherwise
        ``{"code": 1, "msg": "<error text>", "result": None}``.
    """
    import signal
    import subprocess
    import sys
    import tempfile
    import time

    import cv2
    from ultralytics import YOLO

    model_path = params_dict['model']
    source = params_dict['source']
    stream_url = params_dict['stream_url']
    fps = params_dict.get('fps', 25)
    bitrate = params_dict.get('bitrate', '2000k')
    conf = params_dict.get('conf', 0.25)
    iou = params_dict.get('iou', 0.7)
    imgsz = params_dict.get('imgsz', 640)
    device = params_dict.get('device', '')

    # Fail early with a clear message instead of an AttributeError later on.
    if not source or not stream_url:
        return {"code": 1, "msg": "source 和 stream_url 不能为空", "result": None}

    def handle_sigterm(signum, frame):
        print(f"任务 {task_id} 收到终止信号,准备退出")
        # sys.exit rather than the site-provided exit(); the SystemExit unwinds
        # through the finally block below so ffmpeg and the capture are freed.
        sys.exit(0)

    previous_handler = signal.signal(signal.SIGTERM, handle_sigterm)

    cap = None
    ffmpeg_process = None
    ffmpeg_log = None

    def cleanup_process():
        """Stop the ffmpeg child: close stdin first, escalate if it lingers."""
        if ffmpeg_process is None:
            return
        try:
            if ffmpeg_process.stdin and not ffmpeg_process.stdin.closed:
                # EOF on stdin lets ffmpeg finish the stream on its own.
                ffmpeg_process.stdin.close()
        except OSError as e:
            print(f"清理ffmpeg进程时出错: {e}")
        try:
            ffmpeg_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            ffmpeg_process.terminate()
            try:
                ffmpeg_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                ffmpeg_process.kill()
                ffmpeg_process.wait()
        except Exception as e:
            print(f"清理ffmpeg进程时出错: {e}")

    try:
        model = YOLO(model_path)
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            return {"code": 1, "msg": f"无法打开视频流: {source}", "result": None}

        # Stream properties reported by the capture.
        fps_cap = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Prefer the real frame rate; fall back to the requested fps.
        output_fps = fps_cap if fps_cap > 0 else fps

        ffmpeg_cmd = [
            'ffmpeg',
            '-hide_banner',
            '-loglevel', 'error',  # keep the captured log small on long runs
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{width}x{height}',
            '-r', str(output_fps),
            '-i', '-',  # read frames from stdin
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-b:v', bitrate,
            '-maxrate', bitrate,
            '-bufsize', '4000k',
            '-g', str(int(output_fps) * 2),  # GOP size
            '-f', 'flv' if stream_url.startswith('rtmp') else 'mpegts',
            '-y',  # overwrite the output
            stream_url
        ]

        print(f"任务 {task_id} 启动ffmpeg命令: {' '.join(ffmpeg_cmd)}")

        # Undrained stdout/stderr pipes can fill up and stall ffmpeg, so
        # discard stdout and capture stderr in a temporary file instead.
        ffmpeg_log = tempfile.TemporaryFile()
        ffmpeg_process = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=ffmpeg_log,
            bufsize=0
        )

        # Give ffmpeg a moment to start, then check that it is still alive.
        time.sleep(1)

        if ffmpeg_process.poll() is not None:
            ffmpeg_log.seek(0)
            stderr_output = ffmpeg_log.read().decode(errors='replace') or "未知错误"
            return {"code": 1, "msg": f"ffmpeg启动失败: {stderr_output}", "result": None}

        frame_count = 0
        start_time = time.time()

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                try:
                    predict_kwargs = {
                        'source': frame,
                        'imgsz': imgsz,
                        'conf': conf,
                        'device': device
                    }
                    # Only pass iou when it was actually provided.
                    if iou is not None:
                        predict_kwargs['iou'] = iou

                    results = model.predict(**predict_kwargs)
                    annotated_frame = results[0].plot()
                except Exception as predict_error:
                    print(f"任务 {task_id} YOLO推理出错: {predict_error}")
                    # Inference failed: forward the unannotated frame.
                    annotated_frame = frame

                try:
                    ffmpeg_process.stdin.write(annotated_frame.tobytes())
                    ffmpeg_process.stdin.flush()
                    frame_count += 1

                    # Progress report every 100 frames.
                    if frame_count % 100 == 0:
                        elapsed_time = time.time() - start_time
                        current_fps = frame_count / elapsed_time
                        print(f"任务 {task_id} 已处理 {frame_count} 帧,当前FPS: {current_fps:.2f}")
                except OSError as e:
                    # Includes BrokenPipeError when ffmpeg has gone away.
                    print(f"任务 {task_id} 写入ffmpeg时出错: {e}")
                    break
        except KeyboardInterrupt:
            print(f"任务 {task_id} 收到中断信号,停止处理")

        elapsed_time = time.time() - start_time
        avg_fps = frame_count / elapsed_time if elapsed_time > 0 else 0

        print(f"任务 {task_id} 推理并推流完成,共处理帧数: {frame_count},平均FPS: {avg_fps:.2f}")
        return {"code": 0, "msg": "success", "result": {"frames_processed": frame_count, "avg_fps": avg_fps}}

    except Exception as e:
        print(f"任务 {task_id} 发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
    finally:
        # Runs on every exit path, including early returns and SystemExit.
        if cap is not None:
            cap.release()
        cleanup_process()
        if ffmpeg_log is not None:
            ffmpeg_log.close()
        # The pool reuses this process; do not leave a handler bound to a
        # finished task installed.
        if previous_handler is not None:
            signal.signal(signal.SIGTERM, previous_handler)
@app_fastapi.post("/yolov12/stream")
async def yolov12_stream_async(params: StreamParams):
    """
    RESTful POST endpoint: /yolov12/stream

    Submits a pull-infer-push job (``yolov12_stream_worker``) to the process
    pool and returns immediately; the job is tracked in ``stream_tasks`` under
    a freshly generated id.

    Returns:
        ``{"code": 0, "msg": "任务已提交", "result": "<task_id>"}`` once the job
        is submitted, ``{"code": 1, "msg": "<error text>", "result": None}`` if
        submission fails.
    """
    logging.info("收到/yolov12/stream请求")
    logging.info(f"请求参数: {params}")

    # Unique id under which the job can later be queried or cancelled.
    task_id = str(uuid.uuid4())

    try:
        running_loop = asyncio.get_running_loop()
        job = running_loop.run_in_executor(
            stream_executor, yolov12_stream_worker, params.dict(), task_id
        )
        stream_tasks[task_id] = job
        logging.info(f"任务 {task_id} 已提交到进程池")
    except Exception as e:
        logging.error(f"提交任务时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}

    return {"code": 0, "msg": "任务已提交", "result": task_id}
@app_fastapi.post("/yolov12/stream/cancel")
async def cancel_stream_task(task_id: str):
    """
    RESTful POST endpoint: /yolov12/stream/cancel

    Requests cancellation of the task registered under ``task_id``.

    Returns:
        ``{"code": 0/1, "msg": "<status or error text>", "result": None}``.

    NOTE(review): the registry holds the future returned by
    ``loop.run_in_executor``; cancelling it presumably does not stop a worker
    process that has already started running - confirm, and consider
    signalling the worker instead.
    """
    logging.info(f"收到取消任务请求: {task_id}")

    def reply(code, msg):
        # Every response of this endpoint shares the same envelope.
        return {"code": code, "msg": msg, "result": None}

    future = stream_tasks.get(task_id)
    if not future:
        return reply(1, "任务不存在")
    if future.done():
        return reply(1, "任务已完成,无法取消")

    try:
        if not future.cancel():
            return reply(1, "任务无法取消(可能正在运行)")
        logging.info(f"任务 {task_id} 已取消")
        return reply(0, "任务已取消")
    except Exception as e:
        logging.error(f"取消任务时发生异常: {e}")
        return reply(1, str(e))
@app_fastapi.get("/yolov12/stream/status")
async def get_stream_status(task_id: str):
    """
    RESTful GET endpoint: /yolov12/stream/status

    Reports the state of the task registered under ``task_id``.

    Returns:
        ``{"code": 0/1, "msg": "<state or error text>", "result": {...}}``
        where ``result["status"]`` is one of ``"running"``, ``"cancelled"``,
        ``"completed"`` (with the worker's ``result``) or ``"failed"`` (with
        ``error``); ``result`` is None for an unknown task id.
    """
    logging.info(f"收到查询任务状态请求: {task_id}")

    future = stream_tasks.get(task_id)
    if not future:
        return {"code": 1, "msg": "任务不存在", "result": None}

    try:
        # Check cancellation first: a cancelled future also reports done(),
        # and asking it for a result raises CancelledError.
        if future.cancelled():
            return {
                "code": 0,
                "msg": "已取消",
                "result": {"status": "cancelled"}
            }
        if not future.done():
            return {
                "code": 0,
                "msg": "运行中",
                "result": {"status": "running"}
            }

        # exception() hands back the stored error instead of raising it, so a
        # worker failure of any exception type is reported as "failed".
        error = future.exception()
        if error is not None:
            return {
                "code": 1,
                "msg": f"任务异常: {error}",
                "result": {"status": "failed", "error": str(error)}
            }
        return {
            "code": 0,
            "msg": "已完成",
            "result": {"status": "completed", "result": future.result()}
        }
    except Exception as e:
        logging.error(f"查询任务状态时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
@app_fastapi.get("/yolov12/stream/list")
async def list_stream_tasks():
    """
    RESTful GET endpoint: /yolov12/stream/list

    Lists every registered task with its current state.

    Returns:
        ``{"code": 0, "msg": "success", "result": [{"task_id": "<id>",
        "status": "running" | "cancelled" | "completed" | "failed"}, ...]}``,
        or ``{"code": 1, "msg": "<error text>", "result": None}`` on error.
    """
    logging.info("收到查询所有任务请求")

    def describe(future):
        if not future.done():
            return "running"
        if future.cancelled():
            return "cancelled"
        # Inspect the stored exception instead of calling result() under a
        # bare except, which would also swallow unrelated errors.
        return "failed" if future.exception() is not None else "completed"

    try:
        # Iterate over a snapshot so a concurrent registration cannot change
        # the dict during iteration.
        tasks_info = [
            {"task_id": task_id, "status": describe(future)}
            for task_id, future in list(stream_tasks.items())
        ]
        return {
            "code": 0,
            "msg": "success",
            "result": tasks_info
        }
    except Exception as e:
        logging.error(f"查询所有任务时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
# Global exception handler: request-validation failures are returned in the
# same {"code", "msg", "result"} envelope as every other response.
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Log a request-validation failure and answer HTTP 200 with code 422 in the body."""
    err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
    logging.error(err_msg)
    envelope = {"code": 422, "msg": err_msg, "result": None}
    return JSONResponse(status_code=status.HTTP_200_OK, content=envelope)
if __name__ == "__main__":
    # Serve the Gradio UI from a daemon thread so it ends with the main process.
    threading.Thread(target=start_gradio, daemon=True).start()
    # uvicorn.run is called in the main thread to serve the REST API on port 8000.
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
|