|
|
@@ -307,20 +307,6 @@ def yolov12_predict(params: PredictParams):
|
|
|
# 确保保存结果,并强制使用MP4格式
|
|
|
predict_kwargs['save'] = True
|
|
|
|
|
|
- # 如果输入是视频,强制设置输出格式为MP4
|
|
|
- source = params.source
|
|
|
- if source:
|
|
|
- import os
|
|
|
- source_ext = os.path.splitext(source)[1].lower()
|
|
|
- video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
|
|
|
-
|
|
|
- if source_ext in video_extensions:
|
|
|
- # 对于视频输入,设置项目名和实验名以确保输出路径
|
|
|
- if not predict_kwargs.get('project'):
|
|
|
- predict_kwargs['project'] = 'runs/detect'
|
|
|
- if not predict_kwargs.get('name'):
|
|
|
- predict_kwargs['name'] = 'predict'
|
|
|
-
|
|
|
results = model.predict(**predict_kwargs)
|
|
|
logging.info("模型预测完成")
|
|
|
|
|
|
@@ -334,90 +320,29 @@ def yolov12_predict(params: PredictParams):
|
|
|
import os
|
|
|
import glob
|
|
|
if os.path.exists(save_dir):
|
|
|
- # 检查输入源类型
|
|
|
+ # 获取输入文件名(不含扩展名)
|
|
|
source = params.source
|
|
|
+ base_name = None
|
|
|
if source:
|
|
|
- source_ext = os.path.splitext(source)[1].lower()
|
|
|
- video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv']
|
|
|
-
|
|
|
- # 如果输入是图片,返回图片文件
|
|
|
- if source_ext not in video_extensions:
|
|
|
- image_files = []
|
|
|
- for ext in ['*.jpg', '*.jpeg', '*.png']:
|
|
|
- image_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
-
|
|
|
- if image_files:
|
|
|
- latest_image = max(image_files, key=os.path.getmtime)
|
|
|
- final_filename = os.path.basename(latest_image)
|
|
|
- logging.info(f"输入为图片,返回图片文件: {final_filename}")
|
|
|
-
|
|
|
- # 如果输入是视频,检查并转换为MP4
|
|
|
- else:
|
|
|
- # 查找所有视频文件
|
|
|
- video_files = []
|
|
|
- for ext in ['*.avi', '*.webm', '*.mov']:
|
|
|
- video_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
-
|
|
|
- # 如果找到非MP4视频文件,转换为MP4
|
|
|
- for video_file in video_files:
|
|
|
- output_mp4 = video_file.rsplit('.', 1)[0] + '.mp4'
|
|
|
- try:
|
|
|
- # 直接用 OpenCV 尝试多种编码器
|
|
|
- import cv2
|
|
|
- cap = cv2.VideoCapture(video_file)
|
|
|
- fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
|
- try_codecs = ['mp4v', 'MJPG', 'H264'] # 优先mp4v
|
|
|
- out = None
|
|
|
- used_codec = None
|
|
|
- for codec in try_codecs:
|
|
|
- fourcc = cv2.VideoWriter_fourcc(*codec)
|
|
|
- out = cv2.VideoWriter(output_mp4, fourcc, fps, (width, height))
|
|
|
- if out.isOpened():
|
|
|
- used_codec = codec
|
|
|
- logging.info(f"使用编码器 {codec} 成功写入MP4")
|
|
|
- break
|
|
|
- else:
|
|
|
- out.release()
|
|
|
- out = None
|
|
|
- if out is None:
|
|
|
- raise RuntimeError("没有可用的 MP4 编码器,请安装 ffmpeg 或检查 OpenCV 编码支持。")
|
|
|
- while cap.isOpened():
|
|
|
- ret, frame = cap.read()
|
|
|
- if not ret:
|
|
|
- break
|
|
|
- out.write(frame)
|
|
|
- cap.release()
|
|
|
- out.release()
|
|
|
- os.remove(video_file)
|
|
|
- logging.info(f"使用OpenCV生成MP4: {output_mp4},编码器: {used_codec}")
|
|
|
- except Exception as cv_error:
|
|
|
- logging.error(f"OpenCV处理失败: {cv_error}")
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logging.error(f"转换视频格式时出错: {e}")
|
|
|
-
|
|
|
- # 获取MP4或WebM文件
|
|
|
- video_output_files = []
|
|
|
- for ext in ['*.mp4', '*.webm']:
|
|
|
- video_output_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
-
|
|
|
- if video_output_files:
|
|
|
- latest_video = max(video_output_files, key=os.path.getmtime)
|
|
|
- final_filename = os.path.basename(latest_video)
|
|
|
- logging.info(f"输入为视频,返回文件: {final_filename}")
|
|
|
-
|
|
|
- # 如果无法确定输入类型或未找到文件,返回最新文件
|
|
|
- if not final_filename:
|
|
|
- all_files = []
|
|
|
- for ext in ['*.jpg', '*.jpeg', '*.png', '*.mp4']:
|
|
|
- all_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
-
|
|
|
- if all_files:
|
|
|
- latest_file = max(all_files, key=os.path.getmtime)
|
|
|
- final_filename = os.path.basename(latest_file)
|
|
|
- logging.info(f"返回最新文件: {final_filename}")
|
|
|
+ base_name = os.path.splitext(os.path.basename(source))[0]
|
|
|
+ # 支持的扩展名
|
|
|
+ exts = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
|
|
|
+ matched_files = []
|
|
|
+ for ext in exts:
|
|
|
+ matched_files.extend(glob.glob(os.path.join(save_dir, ext)))
|
|
|
+ # 按时间排序,查找与输入文件同名的第一个文件
|
|
|
+ if base_name and matched_files:
|
|
|
+ matched_files = sorted(matched_files, key=os.path.getmtime)
|
|
|
+ for f in matched_files:
|
|
|
+ if os.path.splitext(os.path.basename(f))[0] == base_name:
|
|
|
+ final_filename = os.path.basename(f)
|
|
|
+ logging.info(f"按输入文件名查找,返回文件: {final_filename}")
|
|
|
+ break
|
|
|
+ # 如果没找到同名文件,返回最新文件
|
|
|
+ if not final_filename and matched_files:
|
|
|
+ latest_file = max(matched_files, key=os.path.getmtime)
|
|
|
+ final_filename = os.path.basename(latest_file)
|
|
|
+ logging.info(f"未找到同名,返回最新文件: {final_filename}")
|
|
|
|
|
|
return {
|
|
|
"code": 0,
|