Browse Source

直接返回mp4格式文件路径

xujunwei 5 months atrás
parent
commit
fe2bccb136
1 changed files with 80 additions and 54 deletions
  1. 80 54
      app.py

+ 80 - 54
app.py

@@ -314,70 +314,96 @@ def yolov12_predict(params: PredictParams):
         result = results[0]
         save_dir = result.save_dir if hasattr(result, 'save_dir') else None
         
-        # 获取最终生成的文件名并转换为MP4
+        # 获取最终生成的文件名
         final_filename = None
         if save_dir:
             import os
             import glob
             if os.path.exists(save_dir):
-                # 查找所有视频文件
-                video_files = []
-                for ext in ['*.avi', '*.webm', '*.mov']:
-                    video_files.extend(glob.glob(os.path.join(save_dir, ext)))
-                
-                # 如果找到非MP4视频文件,转换为MP4
-                for video_file in video_files:
-                    output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
-                    try:
-                        import cv2
-                        cap = cv2.VideoCapture(video_file)
-                        fps = cap.get(cv2.CAP_PROP_FPS)
-                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                # 检查输入源类型
+                source = params.source
+                if source:
+                    source_ext = os.path.splitext(source)[1].lower()
+                    video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
+                    
+                    # 如果输入是图片,返回图片文件
+                    if source_ext not in video_extensions:
+                        image_files = []
+                        for ext in ['*.jpg', '*.jpeg', '*.png']:
+                            image_files.extend(glob.glob(os.path.join(save_dir, ext)))
                         
-                        # 尝试不同的MP4编码器
-                        fourcc_options = ['mp4v', 'avc1', 'H264']
-                        out = None
+                        if image_files:
+                            latest_image = max(image_files, key=os.path.getmtime)
+                            final_filename = os.path.basename(latest_image)
+                            logging.info(f"输入为图片,返回图片文件: {final_filename}")
+                    
+                    # 如果输入是视频,检查并转换为MP4
+                    else:
+                        # 查找所有视频文件
+                        video_files = []
+                        for ext in ['*.avi', '*.webm', '*.mov']:
+                            video_files.extend(glob.glob(os.path.join(save_dir, ext)))
                         
-                        for fourcc in fourcc_options:
+                        # 如果找到非MP4视频文件,转换为MP4
+                        for video_file in video_files:
+                            output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
                             try:
-                                fourcc_code = cv2.VideoWriter_fourcc(*fourcc)
-                                out = cv2.VideoWriter(output_mp4, fourcc_code, fps, (width, height))
-                                if out.isOpened():
-                                    logging.info(f"使用编码器 {fourcc} 创建MP4文件")
-                                    break
-                            except:
-                                continue
+                                import cv2
+                                cap = cv2.VideoCapture(video_file)
+                                fps = cap.get(cv2.CAP_PROP_FPS)
+                                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                                
+                                # 尝试不同的MP4编码器
+                                fourcc_options = ['mp4v', 'avc1', 'H264']
+                                out = None
+                                
+                                for fourcc in fourcc_options:
+                                    try:
+                                        fourcc_code = cv2.VideoWriter_fourcc(*fourcc)
+                                        out = cv2.VideoWriter(output_mp4, fourcc_code, fps, (width, height))
+                                        if out.isOpened():
+                                            logging.info(f"使用编码器 {fourcc} 创建MP4文件")
+                                            break
+                                    except:
+                                        continue
+                                
+                                if out and out.isOpened():
+                                    while cap.isOpened():
+                                        ret, frame = cap.read()
+                                        if not ret:
+                                            break
+                                        out.write(frame)
+                                    
+                                    cap.release()
+                                    out.release()
+                                    
+                                    # 删除原文件
+                                    os.remove(video_file)
+                                    logging.info(f"视频已转换为MP4格式: {output_mp4}")
+                                else:
+                                    logging.warning(f"无法创建MP4编码器,保持原格式")
+                                    
+                            except Exception as e:
+                                logging.error(f"转换视频格式时出错: {e}")
                         
-                        if out and out.isOpened():
-                            while cap.isOpened():
-                                ret, frame = cap.read()
-                                if not ret:
-                                    break
-                                out.write(frame)
-                            
-                            cap.release()
-                            out.release()
-                            
-                            # 删除原文件
-                            os.remove(video_file)
-                            logging.info(f"视频已转换为MP4格式: {output_mp4}")
-                        else:
-                            logging.warning(f"无法创建MP4编码器,保持原格式")
-                            
-                    except Exception as e:
-                        logging.error(f"转换视频格式时出错: {e}")
-                
-                # 获取所有文件(包括转换后的MP4)
-                all_files = []
-                for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
-                    all_files.extend(glob.glob(os.path.join(save_dir, ext)))
+                        # 获取MP4文件
+                        mp4_files = glob.glob(os.path.join(save_dir, "*.mp4"))
+                        if mp4_files:
+                            latest_mp4 = max(mp4_files, key=os.path.getmtime)
+                            final_filename = os.path.basename(latest_mp4)
+                            logging.info(f"输入为视频,返回MP4文件: {final_filename}")
                 
-                if all_files:
-                    # 获取最新的文件(按修改时间排序)
-                    latest_file = max(all_files, key=os.path.getmtime)
-                    final_filename = os.path.basename(latest_file)
-                    logging.info(f"最终生成的文件: {final_filename}")
+                # 如果无法确定输入类型或未找到文件,返回最新文件
+                if not final_filename:
+                    all_files = []
+                    for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
+                        all_files.extend(glob.glob(os.path.join(save_dir, ext)))
+                    
+                    if all_files:
+                        latest_file = max(all_files, key=os.path.getmtime)
+                        final_filename = os.path.basename(latest_file)
+                        logging.info(f"返回最新文件: {final_filename}")
         
         return {
             "code": 0,