浏览代码

直接返回mp4格式文件路径

xujunwei 5 月之前
父节点
当前提交
e7388f866c
共有 1 个文件被更改,包括 75 次插入3 次删除
  1. 75 3
      app.py

+ 75 - 3
app.py

@@ -290,7 +290,7 @@ def yolov12_predict(params: PredictParams):
     """
     """
     RESTful POST接口:/yolov12/predict
     RESTful POST接口:/yolov12/predict
     接收预测参数,调用YOLO模型进行预测,并返回预测结果。
     接收预测参数,调用YOLO模型进行预测,并返回预测结果。
-    返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 预测结果或None}
+    返回格式:{"code": 0/1, "msg": "success/错误原因", "result": {"save_dir": "保存目录", "filename": "文件名"}}
     """
     """
     logging.info("收到/yolov12/predict预测请求")
     logging.info("收到/yolov12/predict预测请求")
     logging.info(f"请求参数: {params}")
     logging.info(f"请求参数: {params}")
@@ -304,13 +304,85 @@ def yolov12_predict(params: PredictParams):
             if field not in ['model'] and value is not None:
             if field not in ['model'] and value is not None:
                 predict_kwargs[field] = value
                 predict_kwargs[field] = value
         
         
+        # 确保保存结果
+        predict_kwargs['save'] = True
+        
         results = model.predict(**predict_kwargs)
         results = model.predict(**predict_kwargs)
         logging.info("模型预测完成")
         logging.info("模型预测完成")
-        logging.info(f"预测结果: {str(results)}")
+        
+        # 获取保存目录和最终文件名
+        result = results[0]
+        save_dir = result.save_dir if hasattr(result, 'save_dir') else None
+        
+        # 获取最终生成的文件名并转换为MP4
+        final_filename = None
+        if save_dir:
+            import os
+            import glob
+            if os.path.exists(save_dir):
+                # 查找所有视频文件
+                video_files = []
+                for ext in ['*.avi', '*.webm', '*.mov']:
+                    video_files.extend(glob.glob(os.path.join(save_dir, ext)))
+                
+                # 如果找到非MP4视频文件,转换为MP4
+                for video_file in video_files:
+                    output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
+                    try:
+                        import cv2
+                        cap = cv2.VideoCapture(video_file)
+                        fps = cap.get(cv2.CAP_PROP_FPS)
+                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        
+                        # 尝试不同的MP4编码器
+                        fourcc_options = ['mp4v', 'avc1', 'H264']
+                        out = None
+                        
+                        for fourcc in fourcc_options:
+                            try:
+                                fourcc_code = cv2.VideoWriter_fourcc(*fourcc)
+                                out = cv2.VideoWriter(output_mp4, fourcc_code, fps, (width, height))
+                                if out.isOpened():
+                                    logging.info(f"使用编码器 {fourcc} 创建MP4文件")
+                                    break
+                            except:
+                                continue
+                        
+                        if out and out.isOpened():
+                            while cap.isOpened():
+                                ret, frame = cap.read()
+                                if not ret:
+                                    break
+                                out.write(frame)
+                            
+                            cap.release()
+                            out.release()
+                            
+                            # 删除原文件
+                            os.remove(video_file)
+                            logging.info(f"视频已转换为MP4格式: {output_mp4}")
+                        else:
+                            logging.warning(f"无法创建MP4编码器,保持原格式")
+                            
+                    except Exception as e:
+                        logging.error(f"转换视频格式时出错: {e}")
+                
+                # 获取所有文件(包括转换后的MP4)
+                all_files = []
+                for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
+                    all_files.extend(glob.glob(os.path.join(save_dir, ext)))
+                
+                if all_files:
+                    # 获取最新的文件(按修改时间排序)
+                    latest_file = max(all_files, key=os.path.getmtime)
+                    final_filename = os.path.basename(latest_file)
+                    logging.info(f"最终生成的文件: {final_filename}")
+        
         return {
         return {
             "code": 0,
             "code": 0,
             "msg": "success",
             "msg": "success",
-            "result": results[0].save_dir
+            "result": save_dir+"/"+final_filename
         }
         }
     except Exception as e:
     except Exception as e:
         logging.error(f"预测过程发生异常: {e}")
         logging.error(f"预测过程发生异常: {e}")