Explorar el Código

直接返回mp4格式文件路径-还原

xujunwei hace 5 meses
padre
commit
96bb875ab8
Se han modificado 1 ficheros con 21 adiciones y 96 borrados
  1. 21 96
      app.py

+ 21 - 96
app.py

@@ -307,20 +307,6 @@ def yolov12_predict(params: PredictParams):
         # 确保保存结果,并强制使用MP4格式
         predict_kwargs['save'] = True
         
-        # 如果输入是视频,强制设置输出格式为MP4
-        source = params.source
-        if source:
-            import os
-            source_ext = os.path.splitext(source)[1].lower()
-            video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
-            
-            if source_ext in video_extensions:
-                # 对于视频输入,设置项目名和实验名以确保输出路径
-                if not predict_kwargs.get('project'):
-                    predict_kwargs['project'] = 'runs/detect'
-                if not predict_kwargs.get('name'):
-                    predict_kwargs['name'] = 'predict'
-        
         results = model.predict(**predict_kwargs)
         logging.info("模型预测完成")
         
@@ -334,90 +320,29 @@ def yolov12_predict(params: PredictParams):
             import os
             import glob
             if os.path.exists(save_dir):
-                # 检查输入源类型
+                # 获取输入文件名(不含扩展名)
                 source = params.source
+                base_name = None
                 if source:
-                    source_ext = os.path.splitext(source)[1].lower()
-                    video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
-                    
-                    # 如果输入是图片,返回图片文件
-                    if source_ext not in video_extensions:
-                        image_files = []
-                        for ext in ['*.jpg', '*.jpeg', '*.png']:
-                            image_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                        
-                        if image_files:
-                            latest_image = max(image_files, key=os.path.getmtime)
-                            final_filename = os.path.basename(latest_image)
-                            logging.info(f"输入为图片,返回图片文件: {final_filename}")
-                    
-                    # 如果输入是视频,检查并转换为MP4
-                    else:
-                        # 查找所有视频文件
-                        video_files = []
-                        for ext in ['*.avi', '*.webm', '*.mov']:
-                            video_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                        
-                        # 如果找到非MP4视频文件,转换为MP4
-                        for video_file in video_files:
-                            output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
-                            try:
-                                # 直接用 OpenCV 尝试多种编码器
-                                import cv2
-                                cap = cv2.VideoCapture(video_file)
-                                fps = cap.get(cv2.CAP_PROP_FPS)
-                                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                                try_codecs = ['mp4v', 'MJPG', 'H264']  # 优先mp4v
-                                out = None
-                                used_codec = None
-                                for codec in try_codecs:
-                                    fourcc = cv2.VideoWriter_fourcc(*codec)
-                                    out = cv2.VideoWriter(output_mp4, fourcc, fps, (width, height))
-                                    if out.isOpened():
-                                        used_codec = codec
-                                        logging.info(f"使用编码器 {codec} 成功写入MP4")
-                                        break
-                                    else:
-                                        out.release()
-                                        out = None
-                                if out is None:
-                                    raise RuntimeError("没有可用的 MP4 编码器,请安装 ffmpeg 或检查 OpenCV 编码支持。")
-                                while cap.isOpened():
-                                    ret, frame = cap.read()
-                                    if not ret:
-                                        break
-                                    out.write(frame)
-                                cap.release()
-                                out.release()
-                                os.remove(video_file)
-                                logging.info(f"使用OpenCV生成MP4: {output_mp4},编码器: {used_codec}")
-                            except Exception as cv_error:
-                                logging.error(f"OpenCV处理失败: {cv_error}")
-                                
-                            except Exception as e:
-                                logging.error(f"转换视频格式时出错: {e}")
-                        
-                        # 获取MP4或WebM文件
-                        video_output_files = []
-                        for ext in ['*.mp4', '*.webm']:
-                            video_output_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                        
-                        if video_output_files:
-                            latest_video = max(video_output_files, key=os.path.getmtime)
-                            final_filename = os.path.basename(latest_video)
-                            logging.info(f"输入为视频,返回文件: {final_filename}")
-                
-                # 如果无法确定输入类型或未找到文件,返回最新文件
-                if not final_filename:
-                    all_files = []
-                    for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
-                        all_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                    
-                    if all_files:
-                        latest_file = max(all_files, key=os.path.getmtime)
-                        final_filename = os.path.basename(latest_file)
-                        logging.info(f"返回最新文件: {final_filename}")
+                    base_name = os.path.splitext(os.path.basename(source))[0]
+                # 支持的扩展名
+                exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
+                matched_files = []
+                for ext in exts:
+                    matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
+                # 按时间排序,查找与输入文件同名的第一个文件
+                if base_name and matched_files:
+                    matched_files = sorted(matched_files, key=os.path.getmtime)
+                    for f in matched_files:
+                        if os.path.splitext(os.path.basename(f))[0] == base_name:
+                            final_filename = os.path.basename(f)
+                            logging.info(f"按输入文件名查找,返回文件: {final_filename}")
+                            break
+                # 如果没找到同名文件,返回最新文件
+                if not final_filename and matched_files:
+                    latest_file = max(matched_files, key=os.path.getmtime)
+                    final_filename = os.path.basename(latest_file)
+                    logging.info(f"未找到同名,返回最新文件: {final_filename}")
         
         return {
             "code": 0,