Browse Source

直接返回mp4格式文件路径-实时视频

xujunwei 5 tháng trước cách đây
mục cha
commit
8f85883444
1 tập tin đã thay đổi với 83 bổ sung27 xóa
  1. 83 27
      app.py

+ 83 - 27
app.py

@@ -6,7 +6,6 @@
 import logging
 import tempfile
 import threading
-
 import cv2
 import gradio as gr
 import uvicorn
@@ -317,32 +316,29 @@ def yolov12_predict(params: PredictParams):
         # 获取最终生成的文件名
         final_filename = None
         if save_dir:
-            import os
-            import glob
-            if os.path.exists(save_dir):
-                # 获取输入文件名(不含扩展名)
-                source = params.source
-                base_name = None
-                if source:
-                    base_name = os.path.splitext(os.path.basename(source))[0]
-                # 支持的扩展名
-                exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
-                matched_files = []
-                for ext in exts:
-                    matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                # 按时间排序,查找与输入文件同名的第一个文件
-                if base_name and matched_files:
-                    matched_files = sorted(matched_files, key=os.path.getmtime)
-                    for f in matched_files:
-                        if os.path.splitext(os.path.basename(f))[0] == base_name:
-                            final_filename = os.path.basename(f)
-                            logging.info(f"按输入文件名查找,返回文件: {final_filename}")
-                            break
-                # 如果没找到同名文件,返回最新文件
-                if not final_filename and matched_files:
-                    latest_file = max(matched_files, key=os.path.getmtime)
-                    final_filename = os.path.basename(latest_file)
-                    logging.info(f"未找到同名,返回最新文件: {final_filename}")
+            # 获取输入文件名(不含扩展名)
+            source = params.source
+            base_name = None
+            if source:
+                base_name = os.path.splitext(os.path.basename(source))[0]
+            # 支持的扩展名
+            exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
+            matched_files = []
+            for ext in exts:
+                matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
+            # 按时间排序,查找与输入文件同名的第一个文件
+            if base_name and matched_files:
+                matched_files = sorted(matched_files, key=os.path.getmtime)
+                for f in matched_files:
+                    if os.path.splitext(os.path.basename(f))[0] == base_name:
+                        final_filename = os.path.basename(f)
+                        logging.info(f"按输入文件名查找,返回文件: {final_filename}")
+                        break
+            # 如果没找到同名文件,返回最新文件
+            if not final_filename and matched_files:
+                latest_file = max(matched_files, key=os.path.getmtime)
+                final_filename = os.path.basename(latest_file)
+                logging.info(f"未找到同名,返回最新文件: {final_filename}")
         
         return {
             "code": 0,
@@ -357,6 +353,66 @@ def yolov12_predict(params: PredictParams):
             "result": None
         }
 
+class StreamParams(BaseModel):
+    """
+    用于接收 /yolov12/stream 接口的参数。
+    model: 推理模型路径
+    source: 拉流地址(如rtsp/http视频流)
+    stream_url: 推流地址(如rtmp推流地址)
+    其他参数同 predict
+    """
+    model: str = "yolov12m.pt"
+    source: str = None
+    stream_url: str = None
+    conf: float = 0.25
+    iou: float = 0.7
+    imgsz: int = 640
+    device: str = ""
+    # 可根据需要补充更多参数
+
+@app_fastapi.post("/yolov12/stream")
+def yolov12_stream(params: StreamParams):
+    """
+    RESTful POST接口:/yolov12/stream
+    接收视频拉流地址和推流地址,调用YOLO模型推理,将推理后的视频推送到推流地址。
+    返回格式:{"code": 0/1, "msg": "success/错误原因", "result": None}
+    """
+    import cv2
+    import logging
+    logging.info("收到/yolov12/stream请求")
+    logging.info(f"请求参数: {params}")
+    try:
+        model = YOLO(params.model)
+        cap = cv2.VideoCapture(params.source)
+        if not cap.isOpened():
+            return {"code": 1, "msg": f"无法打开视频流: {params.source}", "result": None}
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        # 推流地址通常为rtmp/rtsp等
+        fourcc = cv2.VideoWriter_fourcc(*'flv1') if params.stream_url.startswith('rtmp') else cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(params.stream_url, fourcc, fps if fps > 0 else 25, (width, height))
+        if not out.isOpened():
+            cap.release()
+            return {"code": 1, "msg": f"无法打开推流地址: {params.stream_url}", "result": None}
+        frame_count = 0
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+            # 推理
+            results = model.predict(source=frame, imgsz=params.imgsz, conf=params.conf, iou=params.iou, device=params.device)
+            annotated_frame = results[0].plot()
+            out.write(annotated_frame)
+            frame_count += 1
+        cap.release()
+        out.release()
+        logging.info(f"推理并推流完成,共处理帧数: {frame_count}")
+        return {"code": 0, "msg": "success", "result": None}
+    except Exception as e:
+        logging.error(f"/yolov12/stream 发生异常: {e}")
+        return {"code": 1, "msg": str(e), "result": None}
+
 # 全局异常处理器:参数校验失败时统一返回格式
 @app_fastapi.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):