|
|
@@ -6,7 +6,6 @@
|
|
|
import logging
|
|
|
import tempfile
|
|
|
import threading
|
|
|
-
|
|
|
import cv2
|
|
|
import gradio as gr
|
|
|
import uvicorn
|
|
|
@@ -317,32 +316,29 @@ def yolov12_predict(params: PredictParams):
|
|
|
# 获取最终生成的文件名
|
|
|
final_filename = None
|
|
|
if save_dir:
|
|
|
- import os
|
|
|
- import glob
|
|
|
- if os.path.exists(save_dir):
|
|
|
- # 获取输入文件名(不含扩展名)
|
|
|
- source = params.source
|
|
|
- base_name = None
|
|
|
- if source:
|
|
|
- base_name = os.path.splitext(os.path.basename(source))[0]
|
|
|
- # 支持的扩展名
|
|
|
- exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
|
|
|
- matched_files = []
|
|
|
- for ext in exts:
|
|
|
- matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
- # 按时间排序,查找与输入文件同名的第一个文件
|
|
|
- if base_name and matched_files:
|
|
|
- matched_files = sorted(matched_files, key=os.path.getmtime)
|
|
|
- for f in matched_files:
|
|
|
- if os.path.splitext(os.path.basename(f))[0] == base_name:
|
|
|
- final_filename = os.path.basename(f)
|
|
|
- logging.info(f"按输入文件名查找,返回文件: {final_filename}")
|
|
|
- break
|
|
|
- # 如果没找到同名文件,返回最新文件
|
|
|
- if not final_filename and matched_files:
|
|
|
- latest_file = max(matched_files, key=os.path.getmtime)
|
|
|
- final_filename = os.path.basename(latest_file)
|
|
|
- logging.info(f"未找到同名,返回最新文件: {final_filename}")
|
|
|
+ # 获取输入文件名(不含扩展名)
|
|
|
+ source = params.source
|
|
|
+ base_name = None
|
|
|
+ if source:
|
|
|
+ base_name = os.path.splitext(os.path.basename(source))[0]
|
|
|
+ # 支持的扩展名
|
|
|
+ exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
|
|
|
+ matched_files = []
|
|
|
+ for ext in exts:
|
|
|
+ matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
+ # 按时间排序,查找与输入文件同名的第一个文件
|
|
|
+ if base_name and matched_files:
|
|
|
+ matched_files = sorted(matched_files, key=os.path.getmtime)
|
|
|
+ for f in matched_files:
|
|
|
+ if os.path.splitext(os.path.basename(f))[0] == base_name:
|
|
|
+ final_filename = os.path.basename(f)
|
|
|
+ logging.info(f"按输入文件名查找,返回文件: {final_filename}")
|
|
|
+ break
|
|
|
+ # 如果没找到同名文件,返回最新文件
|
|
|
+ if not final_filename and matched_files:
|
|
|
+ latest_file = max(matched_files, key=os.path.getmtime)
|
|
|
+ final_filename = os.path.basename(latest_file)
|
|
|
+ logging.info(f"未找到同名,返回最新文件: {final_filename}")
|
|
|
|
|
|
return {
|
|
|
"code": 0,
|
|
|
@@ -357,6 +353,66 @@ def yolov12_predict(params: PredictParams):
|
|
|
"result": None
|
|
|
}
|
|
|
|
|
|
+class StreamParams(BaseModel):
|
|
|
+ """
|
|
|
+ 用于接收 /yolov12/stream 接口的参数。
|
|
|
+ model: 推理模型路径
|
|
|
+ source: 拉流地址(如rtsp/http视频流)
|
|
|
+ stream_url: 推流地址(如rtmp推流地址)
|
|
|
+ 其他参数同 predict
|
|
|
+ """
|
|
|
+ model: str = "yolov12m.pt"
|
|
|
+ source: str = None
|
|
|
+ stream_url: str = None
|
|
|
+ conf: float = 0.25
|
|
|
+ iou: float = 0.7
|
|
|
+ imgsz: int = 640
|
|
|
+ device: str = ""
|
|
|
+ # 可根据需要补充更多参数
|
|
|
+
|
|
|
+@app_fastapi.post("/yolov12/stream")
|
|
|
+def yolov12_stream(params: StreamParams):
|
|
|
+ """
|
|
|
+ RESTful POST接口:/yolov12/stream
|
|
|
+ 接收视频拉流地址和推流地址,调用YOLO模型推理,将推理后的视频推送到推流地址。
|
|
|
+ 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": None}
|
|
|
+ """
|
|
|
+ import cv2
|
|
|
+ import logging
|
|
|
+ logging.info("收到/yolov12/stream请求")
|
|
|
+ logging.info(f"请求参数: {params}")
|
|
|
+ try:
|
|
|
+ model = YOLO(params.model)
|
|
|
+ cap = cv2.VideoCapture(params.source)
|
|
|
+ if not cap.isOpened():
|
|
|
+ return {"code": 1, "msg": f"无法打开视频流: {params.source}", "result": None}
|
|
|
+ fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
|
+ # 推流地址通常为rtmp/rtsp等
|
|
|
+ fourcc = cv2.VideoWriter_fourcc(*'flv1') if params.stream_url.startswith('rtmp') else cv2.VideoWriter_fourcc(*'mp4v')
|
|
|
+ out = cv2.VideoWriter(params.stream_url, fourcc, fps if fps > 0 else 25, (width, height))
|
|
|
+ if not out.isOpened():
|
|
|
+ cap.release()
|
|
|
+ return {"code": 1, "msg": f"无法打开推流地址: {params.stream_url}", "result": None}
|
|
|
+ frame_count = 0
|
|
|
+ while cap.isOpened():
|
|
|
+ ret, frame = cap.read()
|
|
|
+ if not ret:
|
|
|
+ break
|
|
|
+ # 推理
|
|
|
+ results = model.predict(source=frame, imgsz=params.imgsz, conf=params.conf, iou=params.iou, device=params.device)
|
|
|
+ annotated_frame = results[0].plot()
|
|
|
+ out.write(annotated_frame)
|
|
|
+ frame_count += 1
|
|
|
+ cap.release()
|
|
|
+ out.release()
|
|
|
+ logging.info(f"推理并推流完成,共处理帧数: {frame_count}")
|
|
|
+ return {"code": 0, "msg": "success", "result": None}
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"/yolov12/stream 发生异常: {e}")
|
|
|
+ return {"code": 1, "msg": str(e), "result": None}
|
|
|
+
|
|
|
# 全局异常处理器:参数校验失败时统一返回格式
|
|
|
@app_fastapi.exception_handler(RequestValidationError)
|
|
|
async def validation_exception_handler(request, exc):
|