| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004 |
- # --------------------------------------------------------
- # Based on yolov10
- # https://github.com/THU-MIG/yolov10/app.py
- # --------------------------------------------------------'
- import logging
- import tempfile
- import threading
- import cv2
- import gradio as gr
- import uvicorn
- import asyncio
- from fastapi import FastAPI
- from fastapi import status
- from fastapi.exceptions import RequestValidationError
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- from ultralytics import YOLO
- import os
- import glob
- import subprocess
- import signal
- import time
- from typing import Optional, Dict
- from concurrent.futures import ProcessPoolExecutor, Future
- from functools import partial
- import uuid
# Configure root logging: INFO level, "[timestamp] LEVEL - message" lines.
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Process pool (at most 4 workers) used to run the streaming inference tasks.
stream_executor = ProcessPoolExecutor(max_workers=4)

# Task registry keyed by task id. Each value has the form
# {"future": Future, "start_time": int (epoch ms), "source": str, "stream_url": str}.
stream_tasks: Dict[str, Dict] = {}

# Worker process pid recorded per task id.
stream_pids: Dict[str, int] = {}
def cleanup_completed_tasks():
    """Remove every finished task from the module-level registries.

    A task is considered finished when its stored future reports ``done()``.
    Its entry is deleted from ``stream_tasks`` and, when present, from
    ``stream_pids``; one line is printed for each removed task.
    """
    # Collect the finished ids first so the dict is never mutated mid-iteration.
    finished_ids = [
        tid for tid, info in stream_tasks.items() if info["future"].done()
    ]

    for tid in finished_ids:
        del stream_tasks[tid]
        stream_pids.pop(tid, None)
        print(f"已清理完成的任务: {tid}")
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    """Run YOLO detection on one image, or on every frame of a video.

    Args:
        image: Input image. When truthy it is used and ``video`` is ignored.
        video: Path of the input video file, read when ``image`` is falsy.
        model_id: Weights name/path passed to ``YOLO``.
        image_size: Inference size, forwarded as ``imgsz``.
        conf_threshold: Confidence threshold, forwarded as ``conf``.

    Returns:
        ``(annotated_image, None)`` for image input, where the plotted array
        has its last (channel) axis reversed, or ``(None, output_video_path)``
        for video input, where the path names a newly written ``.webm`` file
        that the caller owns.
    """
    import shutil

    model = YOLO(model_id)
    if image:
        results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
        annotated_image = results[0].plot()
        # Reverse the channel axis (presumably BGR -> RGB for display).
        return annotated_image[:, :, ::-1], None

    # NamedTemporaryFile creates the file atomically; tempfile.mktemp only
    # returns a name and is deprecated because another process can claim it
    # first. The source is opened first so a bad path leaves no temp file.
    with open(video, "rb") as src, \
            tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_in:
        video_path = tmp_in.name
        # Stream the copy in chunks instead of loading the whole video in memory.
        shutil.copyfileobj(src, tmp_in)
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_out:
        output_video_path = tmp_out.name

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(
            output_video_path,
            cv2.VideoWriter_fourcc(*'vp80'),
            fps,
            (frame_width, frame_height),
        )

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
            annotated_frame = results[0].plot()
            out.write(annotated_frame)
    finally:
        # Release the handles even when inference raises part-way through.
        cap.release()
        if out is not None:
            out.release()
        # The input copy is only needed while decoding; removal is best-effort.
        try:
            os.remove(video_path)
        except OSError:
            pass

    return None, output_video_path
def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
    """Image-only wrapper around ``yolov12_inference``.

    Calls ``yolov12_inference`` with no video and returns just the first
    element of its result pair (the annotated image).
    """
    outputs = yolov12_inference(image, None, model_path, image_size, conf_threshold)
    return outputs[0]
def app():
    """Build the detection demo UI inside a fresh ``gr.Blocks`` context.

    Creates the input widgets (image/video, input type, model, image size,
    confidence threshold), the output widgets, and wires three behaviours:
    toggling visibility when the input type changes, running inference on
    button click, and a lazily cached examples gallery.
    """
    with gr.Blocks():
        with gr.Row():
            # Inputs and inference settings.
            with gr.Column():
                image = gr.Image(type="pil", label="Image", visible=True)
                video = gr.Video(label="Video", visible=False)
                input_type = gr.Radio(
                    choices=["Image", "Video"],
                    value="Image",
                    label="Input Type",
                )
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yolov12n.pt",
                        "yolov12s.pt",
                        "yolov12m.pt",
                        "yolov12l.pt",
                        "yolov12x.pt",
                    ],
                    value="yolov12m.pt",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                yolov12_infer = gr.Button(value="Detect Objects")

            # Annotated results.
            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
                output_video = gr.Video(label="Annotated Video", visible=False)

        def update_visibility(input_type):
            """Show the image widgets for "Image" input, the video widgets otherwise."""
            image_mode = input_type == "Image"
            image = gr.update(visible=image_mode)
            video = gr.update(visible=not image_mode)
            output_image = gr.update(visible=image_mode)
            output_video = gr.update(visible=not image_mode)
            return image, video, output_image, output_video

        input_type.change(
            fn=update_visibility,
            inputs=[input_type],
            outputs=[image, video, output_image, output_video],
        )

        def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
            """Forward only the input matching the selected type to the model."""
            if input_type == "Image":
                video = None
            else:
                image = None
            return yolov12_inference(image, video, model_id, image_size, conf_threshold)

        yolov12_infer.click(
            fn=run_inference,
            inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video],
        )

        # Each example row is [image path, model, image size, confidence].
        example_rows = [
            ["ultralytics/assets/bus.jpg", "yolov12s.pt", 640, 0.25],
            ["ultralytics/assets/zidane.jpg", "yolov12x.pt", 640, 0.25],
        ]
        gr.Examples(
            examples=example_rows,
            fn=yolov12_inference_for_examples,
            inputs=[image, model_id, image_size, conf_threshold],
            outputs=[output_image],
            cache_examples='lazy',
        )
# Top-level Gradio page: a title, a links line, then the detection UI.
gradio_app = gr.Blocks()
with gradio_app:
    # Page title.
    gr.HTML(
        """
<h1 style='text-align: center'>
YOLOv12: Attention-Centric Real-Time Object Detectors
</h1>
""")
    # Links to the paper and the source repository.
    gr.HTML(
        """
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2502.12524' target='_blank'>arXiv</a> | <a href='https://github.com/sunsmarterjie/yolov12' target='_blank'>github</a>
</h3>
""")
    with gr.Row():
        with gr.Column():
            app()
def start_gradio():
    """Launch the Gradio page on all interfaces, port 7860."""
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    gradio_app.launch(**launch_options)
# FastAPI section: the application object the REST endpoints below register on.
app_fastapi = FastAPI()
class TrainParams(BaseModel):
    """Request body of ``POST /yolov12/train``.

    No field has a default, so the client must supply every one of them.
    """

    # Base model to start training from.
    model: str
    # Path of the dataset configuration file.
    data: str
    # Number of training epochs.
    epochs: int
    # Batch size.
    batch: int
    # Input image size.
    imgsz: int
    # Random-scale augmentation ratio.
    scale: float
    # Mosaic augmentation probability.
    mosaic: float
    # Mixup augmentation probability.
    mixup: float
    # Copy-paste augmentation probability.
    copy_paste: float
    # Training device.
    device: str
    # Project name.
    project: str
    # Experiment name.
    name: str
    # Whether an existing directory of the same name may be overwritten.
    exist_ok: bool
@app_fastapi.post("/yolov12/train")
def yolov12_train(params: TrainParams):
    """Handle ``POST /yolov12/train``: train a YOLO model with the given parameters.

    The model is built from a ``.yaml`` configuration derived from
    ``params.model`` (same path with the ``.pt`` suffix swapped, or
    ``yolov12.yaml`` when the name does not end in ``.pt``), the weights in
    ``params.model`` are loaded into it, and ``train`` is run.

    Args:
        params: Training parameters supplied by the client.

    Returns:
        ``{"code": 0, "msg": "success", "result": <save_dir as str>}`` on
        success, or ``{"code": 1, "msg": <error text>, "result": None}`` when
        any exception is raised.
    """
    logging.info("收到/yolov12/train训练请求")
    logging.info(f"请求参数: {params}")
    try:
        if params.model.endswith('.pt'):
            # Swap only the trailing suffix; str.replace would also rewrite an
            # earlier ".pt" elsewhere in the path (e.g. in a directory name).
            config_file = params.model[:-len('.pt')] + '.yaml'
        else:
            # Not a .pt file: fall back to the default configuration.
            config_file = "yolov12.yaml"

        model = YOLO(config_file)
        model.load(params.model)
        logging.info("开始模型训练...")
        results = model.train(
            data=params.data,
            epochs=params.epochs,
            batch=params.batch,
            imgsz=params.imgsz,
            scale=params.scale,
            mosaic=params.mosaic,
            mixup=params.mixup,
            copy_paste=params.copy_paste,
            device=params.device,
            project=params.project,
            name=params.name,
            exist_ok=params.exist_ok,
        )
        logging.info("模型训练完成")
        return {
            "code": 0,
            "msg": "success",
            "result": str(results.save_dir)
        }
    except Exception as e:
        # logging.exception keeps the traceback alongside the message.
        logging.exception(f"训练过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class PredictParams(BaseModel):
    """Request body of ``POST /yolov12/predict``.

    Except for ``model``, the fields are forwarded as keyword arguments to
    ``YOLO.predict``. Fields whose default is ``None`` are declared
    ``Optional`` so that an explicit ``null`` from the client validates.
    """

    model: str = "yolov12m.pt"         # Model path.
    source: Optional[str] = None       # Input source (image/video path, URL, ...).
    stream: bool = False               # Whether to process as a stream.
    conf: float = 0.25                 # Confidence threshold.
    iou: float = 0.7                   # IoU threshold.
    max_det: int = 300                 # Maximum number of detections.
    imgsz: int = 640                   # Input image size.
    batch: int = 1                     # Batch size.
    device: str = ""                   # Device.
    show: bool = False                 # Whether to display results.
    save: bool = False                 # Whether to save results.
    save_txt: bool = False             # Whether to save txt files.
    save_conf: bool = False            # Whether to save confidences.
    save_crop: bool = False            # Whether to save cropped images.
    show_labels: bool = True           # Whether to draw labels.
    show_conf: bool = True             # Whether to draw confidences.
    show_boxes: bool = True            # Whether to draw bounding boxes.
    line_width: Optional[int] = None   # Line width.
    vid_stride: int = 1                # Video frame stride.
    stream_buffer: bool = False        # Stream buffering.
    visualize: bool = False            # Feature visualisation.
    augment: bool = False              # Augmentation.
    agnostic_nms: bool = False         # Class-agnostic NMS.
    classes: Optional[list] = None     # Restrict to these classes.
    retina_masks: bool = False         # High-resolution segmentation masks.
    embed: Optional[list] = None       # Feature-vector layers.
    half: bool = False                 # Half precision.
    dnn: bool = False                  # OpenCV DNN.
    project: str = ""                  # Project name.
    name: str = ""                     # Experiment name.
    exist_ok: bool = False             # Whether to overwrite an existing directory.
    verbose: bool = True               # Verbose output.
def _find_prediction_output(save_dir, source):
    """Return the basename of the saved prediction file in ``save_dir``, or None.

    Looks at image/video files directly inside ``save_dir``. The oldest file
    whose stem equals the stem of ``source`` wins; when none matches, the most
    recently modified file is returned.
    """
    patterns = ['*.jpg', '*.jpeg', '*.png', '*.mp4', '*.webm', '*.avi', '*.mov']
    candidates = []
    for pattern in patterns:
        candidates.extend(glob.glob(os.path.join(save_dir, pattern)))
    if not candidates:
        return None

    base_name = os.path.splitext(os.path.basename(source))[0] if source else None
    if base_name:
        for path in sorted(candidates, key=os.path.getmtime):
            if os.path.splitext(os.path.basename(path))[0] == base_name:
                final_filename = os.path.basename(path)
                logging.info(f"按输入文件名查找,返回文件: {final_filename}")
                return final_filename

    final_filename = os.path.basename(max(candidates, key=os.path.getmtime))
    logging.info(f"未找到同名,返回最新文件: {final_filename}")
    return final_filename


@app_fastapi.post("/yolov12/predict")
def yolov12_predict(params: PredictParams):
    """Handle ``POST /yolov12/predict``: run prediction and locate the saved output.

    Every non-``None`` field of ``params`` except ``model`` is forwarded to
    ``YOLO.predict``; ``save`` is always forced to ``True`` so an output file
    exists to report.

    Args:
        params: Prediction parameters supplied by the client.

    Returns:
        ``{"code": 0, "msg": "success", "result": "<save_dir>/<filename>"}``
        on success, or ``{"code": 1, "msg": <reason>, "result": None}`` when
        prediction raises or no saved output file can be found.
    """
    logging.info("收到/yolov12/predict预测请求")
    logging.info(f"请求参数: {params}")
    try:
        model = YOLO(params.model)
        logging.info("开始模型预测...")

        predict_kwargs = {
            field: value
            for field, value in params.dict().items()
            if field != 'model' and value is not None
        }
        # Always save, otherwise there is no output file to return.
        predict_kwargs['save'] = True

        # With stream=True predict() yields results lazily; list() both drives
        # the inference to completion and makes the results indexable.
        results = list(model.predict(**predict_kwargs))
        logging.info("模型预测完成")

        save_dir = getattr(results[0], 'save_dir', None) if results else None
        final_filename = None
        if save_dir:
            final_filename = _find_prediction_output(str(save_dir), params.source)

        if not save_dir or not final_filename:
            # Report the real problem instead of failing on str + None below.
            msg = "预测完成,但未找到保存的结果文件"
            logging.error(msg)
            return {
                "code": 1,
                "msg": msg,
                "result": None
            }

        return {
            "code": 0,
            "msg": "success",
            "result": f"{save_dir}/{final_filename}"
        }
    except Exception as e:
        logging.exception(f"预测过程发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
class StreamParams(BaseModel):
    """Request body of ``POST /yolov12/stream``.

    ``source`` and ``stream_url`` default to ``None`` and are therefore
    declared ``Optional`` so that an explicit ``null`` validates.
    More fields can be added as needed.
    """

    model: str = "yolov12m.pt"           # Inference model path.
    source: Optional[str] = None         # Pull address (e.g. rtsp/http video stream).
    stream_url: Optional[str] = None     # Push address (e.g. rtmp publish URL).
    fps: int = 25                        # Output frame rate.
    bitrate: str = "6000k"               # Output bitrate (raised default).
    conf: float = 0.25                   # Confidence threshold.
    iou: Optional[float] = 0.7           # IoU threshold.
    imgsz: int = 640                     # Input image size.
    device: str = ""                     # Device.
    save_local: bool = False             # Save to a local file instead (for debugging).
    skip_connectivity_test: bool = True  # Skip the push-address connectivity test.
def yolov12_stream_worker(params_dict, task_id):
    """Pull a video stream, run YOLO on each frame and push the result via ffmpeg.

    Runs synchronously; intended to be executed in a pool worker process.
    Frames are read with OpenCV, annotated with the model's predictions and
    piped as raw ``bgr24`` frames into an ``ffmpeg`` subprocess that encodes
    them with libx264 and writes to ``stream_url`` or, when ``save_local`` is
    set, to ``output_<task_id>.mp4``. A SIGTERM handler is installed so the
    task can be stopped from outside.

    Args:
        params_dict: Plain dict of stream parameters. ``model`` and ``source``
            must be present; ``stream_url`` is needed unless ``save_local`` is
            true. Optional keys and their fallbacks: ``fps`` (25),
            ``bitrate`` ('2000k'), ``conf`` (0.25), ``iou`` (0.7; ``None``
            omits it), ``imgsz`` (640), ``device`` (''), ``save_local``
            (False), ``skip_connectivity_test`` (False).
        task_id: Identifier used in log lines and in the local output name.

    Returns:
        ``{"code": 0, "msg": "success", "result": {"frames_processed": int,
        "avg_fps": float}}`` on success, otherwise
        ``{"code": 1, "msg": <reason>, "result": None}``.
    """
    import cv2
    import time
    import socket
    import subprocess
    import signal
    from ultralytics import YOLO

    def handle_sigterm(signum, frame):
        print(f"任务 {task_id} 收到终止信号,准备退出")
        # Raise SystemExit directly: the exit() builtin is injected by the
        # site module and is not guaranteed to exist in every interpreter.
        raise SystemExit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    model_path = params_dict['model']
    source = params_dict['source']
    stream_url = params_dict.get('stream_url')
    fps = params_dict.get('fps', 25)
    bitrate = params_dict.get('bitrate', '2000k')
    conf = params_dict.get('conf', 0.25)
    iou = params_dict.get('iou', 0.7)
    imgsz = params_dict.get('imgsz', 640)
    device = params_dict.get('device', '')
    save_local = params_dict.get('save_local', False)
    skip_connectivity_test = params_dict.get('skip_connectivity_test', False)

    # State shared with the nested helpers. Everything the error handler
    # reports is bound up front so a failure early in the setup cannot turn
    # into a NameError that hides the real exception.
    ffmpeg_process = None
    ffmpeg_cmd = None
    simple_ffmpeg_cmd = None
    advanced_ffmpeg_cmd = None
    max_retries = 3
    retry_count = 0
    cap = None
    width = None
    height = None
    output_fps = None

    def read_ffmpeg_stderr():
        """Return ffmpeg's stderr text, but only once the process has exited."""
        if ffmpeg_process is None or ffmpeg_process.stderr is None:
            return "未知错误"
        if ffmpeg_process.poll() is None:
            # Reading the pipe of a live process would block until it exits.
            return "未知错误"
        return ffmpeg_process.stderr.read().decode(errors="replace")

    def cleanup_process():
        """Stop the ffmpeg subprocess, escalating from terminate to kill."""
        if not ffmpeg_process:
            return
        try:
            if ffmpeg_process.stdin and not ffmpeg_process.stdin.closed:
                try:
                    ffmpeg_process.stdin.close()
                except OSError:
                    # The pipe may already be broken; nothing left to flush.
                    pass
            ffmpeg_process.terminate()
            ffmpeg_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            ffmpeg_process.kill()
            # Reap the killed process so it does not linger as a zombie.
            ffmpeg_process.wait()
        except Exception as e:
            print(f"清理ffmpeg进程时出错: {e}")

    def start_ffmpeg():
        """Start ffmpeg with the current command, retrying on failure.

        The first failure of the simple command switches to the advanced
        command; later failures are retried after a 2 s pause until
        ``max_retries`` is reached. Returns True once ffmpeg is running.
        """
        nonlocal ffmpeg_process, retry_count, ffmpeg_cmd
        while True:
            try:
                print(f"任务 {task_id} 启动ffmpeg命令: {' '.join(ffmpeg_cmd)}")
                ffmpeg_process = subprocess.Popen(
                    ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    # stdout is never read; a pipe nobody drains could fill
                    # up and stall ffmpeg, so discard it.
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.PIPE,
                    bufsize=0
                )
                # Give ffmpeg a moment to fail fast on bad arguments/outputs.
                time.sleep(1)
                if ffmpeg_process.poll() is not None:
                    raise Exception(f"ffmpeg启动失败: {read_ffmpeg_stderr()}")
                retry_count = 0
                return True
            except Exception as e:
                print(f"任务 {task_id} 启动ffmpeg失败: {e}")
                retry_count += 1
                if retry_count == 1 and ffmpeg_cmd == simple_ffmpeg_cmd:
                    print(f"任务 {task_id} 简单配置失败,尝试使用高级配置")
                    ffmpeg_cmd = advanced_ffmpeg_cmd
                    continue
                if retry_count < max_retries:
                    print(f"任务 {task_id} 尝试重试 ({retry_count}/{max_retries})")
                    time.sleep(2)
                    continue
                print(f"任务 {task_id} ffmpeg启动失败,已达到最大重试次数")
                return False

    def restart_ffmpeg():
        """Tear down the current ffmpeg and start a new one; return success."""
        print(f"任务 {task_id} 尝试重启ffmpeg进程...")
        cleanup_process()
        if start_ffmpeg():
            print(f"任务 {task_id} ffmpeg进程重启成功,继续处理")
            return True
        print(f"任务 {task_id} ffmpeg进程重启失败,停止处理")
        return False

    def test_stream_url():
        """Check that the push address is reachable (TCP connect, then ffprobe)."""
        try:
            print(f"任务 {task_id} 开始测试推流地址: {stream_url}")
            if stream_url.startswith('rtmp://'):
                # Extract host and port; 1935 is used when no port is given.
                server_part = stream_url.replace('rtmp://', '').split('/')[0]
                server_host = server_part.split(':')[0] if ':' in server_part else server_part
                server_port = server_part.split(':')[1] if ':' in server_part else '1935'
                print(f"任务 {task_id} 解析的服务器信息: {server_host}:{server_port}")
                try:
                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                        sock.settimeout(5)
                        result = sock.connect_ex((server_host, int(server_port)))
                    if result != 0:
                        print(f"任务 {task_id} 网络连通性测试失败: {server_host}:{server_port}")
                        return False
                    print(f"任务 {task_id} 网络连通性测试成功: {server_host}:{server_port}")
                except Exception as net_error:
                    print(f"任务 {task_id} 网络测试异常: {net_error}")
                    return False

            test_cmd = [
                'ffprobe',
                '-v', 'error',
                '-print_format', 'json',
                '-show_format',
                '-timeout', '5000000',  # 5 s, in microseconds
                stream_url
            ]
            print(f"任务 {task_id} 执行ffprobe命令: {' '.join(test_cmd)}")
            result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                print(f"任务 {task_id} ffprobe测试成功")
                return True
            print(f"任务 {task_id} ffprobe测试失败,返回码: {result.returncode}")
            print(f"任务 {task_id} ffprobe错误输出: {result.stderr}")
            return False
        except subprocess.TimeoutExpired:
            print(f"任务 {task_id} ffprobe测试超时")
            return False
        except Exception as e:
            print(f"任务 {task_id} 推流地址测试异常: {e}")
            return False

    try:
        if source is None or source == "":
            return {"code": 1, "msg": f"无法打开视频流: {source}", "result": None}
        if not save_local and not stream_url:
            # Without a push address there is nowhere to send the output.
            return {"code": 1, "msg": "未提供推流地址(stream_url),且未启用save_local", "result": None}

        # Only probe RTMP targets, and only when really pushing. save_local is
        # checked first so a missing stream_url is never dereferenced.
        if (not save_local and not skip_connectivity_test
                and stream_url.startswith('rtmp://')):
            print(f"任务 {task_id} 测试推流地址连通性...")
            if not test_stream_url():
                print(f"任务 {task_id} 推流地址不可达: {stream_url}")
                return {
                    "code": 1,
                    "msg": f"推流地址不可达: {stream_url}。可能的原因:1. 推流服务器未启动或不可访问 2. 网络连接问题 3. 防火墙阻止连接 4. 推流地址格式错误。建议解决方案:1. 检查推流服务器状态 2. 验证网络连接 3. 使用skip_connectivity_test=true跳过连通性测试 4. 使用save_local=true先保存到本地文件测试",
                    "result": None
                }
        elif skip_connectivity_test:
            print(f"任务 {task_id} 跳过连通性测试,直接开始推流")

        model = YOLO(model_path)
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            cap.release()
            return {"code": 1, "msg": f"无法打开视频流: {source}", "result": None}

        fps_cap = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Prefer the rate reported by the source; fall back to the requested fps.
        output_fps = fps_cap if fps_cap > 0 else fps

        if save_local:
            output_file = f"output_{task_id}.mp4"
            output_format = 'mp4'
        else:
            output_file = stream_url
            output_format = 'flv' if stream_url.startswith('rtmp') else 'mpegts'

        # Fast, minimal encoder settings; tried first.
        simple_ffmpeg_cmd = [
            'ffmpeg',
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{width}x{height}',
            '-r', str(output_fps),
            '-i', '-',  # read frames from stdin
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-b:v', bitrate,
            '-loglevel', 'error',
            '-f', output_format,
            '-y',  # overwrite the output file
            output_file
        ]

        # Fully specified encoder settings; used if the simple command fails.
        advanced_ffmpeg_cmd = [
            'ffmpeg',
            '-f', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{width}x{height}',
            '-r', str(output_fps),
            '-i', '-',  # read frames from stdin
            '-c:v', 'libx264',
            '-preset', 'medium',
            '-tune', 'zerolatency',
            '-profile:v', 'high',
            '-level', '4.1',
            '-b:v', bitrate,
            '-maxrate', bitrate,
            '-bufsize', '8000k',
            '-g', str(int(output_fps) * 2),       # GOP size
            '-keyint_min', str(int(output_fps)),  # minimum keyframe interval
            '-sc_threshold', '0',                 # disable scene-cut detection
            '-bf', '3',                           # B-frames
            '-refs', '6',                         # reference frames
            '-x264opts', 'no-scenecut=1:nal-hrd=cbr:force-cfr=1',
            '-color_primaries', 'bt709',
            '-color_trc', 'bt709',
            '-colorspace', 'bt709',
            '-loglevel', 'error',
            '-f', output_format,
            '-y',  # overwrite the output file
            output_file
        ]

        ffmpeg_cmd = simple_ffmpeg_cmd

        if not start_ffmpeg():
            # Do not leak the capture when the encoder never came up.
            cap.release()
            return {"code": 1, "msg": "ffmpeg启动失败,已达到最大重试次数", "result": None}

        frame_count = 0
        start_time = time.time()

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                try:
                    predict_kwargs = {
                        'source': frame,
                        'imgsz': imgsz,
                        'conf': conf,
                        'device': device
                    }
                    # Pass iou only when set, so None falls back to the
                    # model's own default.
                    if iou is not None:
                        predict_kwargs['iou'] = iou

                    results = model.predict(**predict_kwargs)
                    annotated_frame = results[0].plot()
                except Exception as predict_error:
                    print(f"任务 {task_id} YOLO推理出错: {predict_error}")
                    # Keep the stream alive: forward the raw frame instead.
                    annotated_frame = frame

                try:
                    if ffmpeg_process.poll() is not None:
                        stderr_output = read_ffmpeg_stderr()
                        print(f"任务 {task_id} ffmpeg进程已退出: {stderr_output}")
                        # The current frame is dropped on restart.
                        if restart_ffmpeg():
                            continue
                        break

                    ffmpeg_process.stdin.write(annotated_frame.tobytes())
                    ffmpeg_process.stdin.flush()
                    frame_count += 1

                    # Progress line every 100 frames.
                    if frame_count % 100 == 0:
                        elapsed_time = time.time() - start_time
                        current_fps = frame_count / elapsed_time
                        print(f"任务 {task_id} 已处理 {frame_count} 帧,当前FPS: {current_fps:.2f}")

                except IOError as e:
                    print(f"任务 {task_id} 写入ffmpeg时出错: {e}")
                    # Stop ffmpeg first so that reading its stderr cannot block.
                    cleanup_process()
                    try:
                        stderr_output = read_ffmpeg_stderr()
                        print(f"任务 {task_id} ffmpeg错误输出: {stderr_output}")
                    except Exception as read_error:
                        print(f"任务 {task_id} ffmpeg错误输出: {read_error}")

                    if restart_ffmpeg():
                        continue
                    break

        except KeyboardInterrupt:
            print(f"任务 {task_id} 收到中断信号,停止处理")
        finally:
            cap.release()
            cleanup_process()

        elapsed_time = time.time() - start_time
        avg_fps = frame_count / elapsed_time if elapsed_time > 0 else 0

        print(f"任务 {task_id} 推理并推流完成,共处理帧数: {frame_count},平均FPS: {avg_fps:.2f}")
        return {"code": 0, "msg": "success", "result": {"frames_processed": frame_count, "avg_fps": avg_fps}}

    except Exception as e:
        print(f"任务 {task_id} 发生异常: {e}")
        if cap is not None:
            cap.release()
        cleanup_process()

        # Append a hint for the most common failure causes.
        error_msg = str(e)
        if "Broken pipe" in error_msg:
            error_msg += " - 这通常表示推流地址不可达或网络连接问题,建议:1) 检查推流服务器是否运行 2) 检查网络连接 3) 使用save_local=true先保存到本地文件测试"
        elif "Connection refused" in error_msg:
            error_msg += " - 推流服务器拒绝连接,请检查推流地址是否正确"
        elif "Permission denied" in error_msg:
            error_msg += " - 权限不足,请检查推流地址的访问权限"

        # Debug context; width/height/fps are None when setup failed early.
        print(f"任务 {task_id} 调试信息:")
        print(f" - 输入源: {source}")
        print(f" - 推流地址: {stream_url}")
        print(f" - 保存到本地: {params_dict.get('save_local', False)}")
        print(f" - 视频尺寸: {width}x{height}")
        print(f" - 帧率: {output_fps}")
        print(f" - 码率: {bitrate}")

        return {"code": 1, "msg": error_msg, "result": None}
@app_fastapi.post("/yolov12/stream")
async def yolov12_stream_async(params: StreamParams):
    """Handle ``POST /yolov12/stream``: start a streaming inference task.

    Submits ``yolov12_stream_worker`` to the process pool under a fresh UUID
    and records the task in ``stream_tasks``. The endpoint returns as soon as
    the task is submitted; it does not wait for the stream to finish.

    Args:
        params: Stream parameters supplied by the client.

    Returns:
        ``{"code": 0, "msg": "任务已提交", "result": <task_id str>}`` on
        success, or ``{"code": 1, "msg": <error text>, "result": None}`` when
        submission raises.
    """
    logging.info("收到/yolov12/stream请求")
    logging.info(f"请求参数: {params}")

    task_id = str(uuid.uuid4())

    try:
        # Keep the executor's own future in the registry: its cancel() returns
        # False once the worker is running, so callers are never told that a
        # still-running stream was cancelled. (The asyncio wrapper returned by
        # run_in_executor reports a successful cancel in that situation.)
        worker_future = stream_executor.submit(
            yolov12_stream_worker, params.dict(), task_id
        )
        # Loop-bound mirror of the worker future, used only so the cleanup
        # callback runs on the event-loop thread.
        loop_future = asyncio.wrap_future(worker_future)

        def cleanup_task(fut):
            """Log the outcome and drop the task from the registries."""
            try:
                if fut.cancelled():
                    # result() on a cancelled future raises CancelledError,
                    # which `except Exception` below would not catch.
                    print(f"任务 {task_id} 已取消")
                else:
                    result = fut.result()
                    print(f"任务 {task_id} 已完成,结果: {result}")
            except Exception as e:
                print(f"任务 {task_id} 执行异常: {e}")
            finally:
                stream_tasks.pop(task_id, None)
                stream_pids.pop(task_id, None)
                print(f"已清理任务: {task_id}")

        loop_future.add_done_callback(cleanup_task)
        stream_tasks[task_id] = {
            "future": worker_future,
            "start_time": int(time.time() * 1000),  # epoch milliseconds
            "source": params.source,
            "stream_url": params.stream_url
        }

        logging.info(f"任务 {task_id} 已提交到进程池")

        return {
            "code": 0,
            "msg": "任务已提交",
            "result": task_id
        }

    except Exception as e:
        logging.exception(f"提交任务时发生异常: {e}")
        return {
            "code": 1,
            "msg": str(e),
            "result": None
        }
@app_fastapi.post("/yolov12/stream/cancel")
async def cancel_stream_task(task_id: str):
    """Handle ``POST /yolov12/stream/cancel``: cancel one inference task.

    Finished tasks are purged first, then ``cancel()`` is called on the
    task's stored future. On success the task is removed from both
    registries.

    Returns:
        ``{"code": 0, "msg": "任务已取消", "result": None}`` when the future
        accepted the cancellation, otherwise ``{"code": 1, "msg": <reason>,
        "result": None}``.
    """
    logging.info(f"收到取消任务请求: {task_id}")

    # Purge finished tasks before looking the requested one up.
    cleanup_completed_tasks()

    entry = stream_tasks.get(task_id)
    if not entry:
        return {"code": 1, "msg": "任务不存在", "result": None}

    future = entry["future"]
    if future.done():
        return {"code": 1, "msg": "任务已完成,无法取消", "result": None}

    try:
        # NOTE(review): a True return only says the future accepted the
        # cancellation; whether an already-running worker process actually
        # stops depends on the kind of future the submit endpoint stored.
        # Confirm against that endpoint.
        if not future.cancel():
            return {"code": 1, "msg": "任务无法取消(可能正在运行)", "result": None}

        logging.info(f"任务 {task_id} 已取消")
        # Drop the task from both registries right away.
        stream_tasks.pop(task_id, None)
        stream_pids.pop(task_id, None)
        return {"code": 0, "msg": "任务已取消", "result": None}
    except Exception as e:
        logging.error(f"取消任务时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
@app_fastapi.get("/yolov12/stream/status")
async def get_stream_status(task_id: str):
    """Handle ``GET /yolov12/stream/status``: report the state of one task.

    Returns:
        ``{"code": 1, "msg": "任务不存在", "result": None}`` for an unknown
        task. Otherwise ``result`` is a dict with ``status`` (``running``,
        ``completed``, ``failed`` or ``cancelled``), ``start_time`` (epoch
        ms), ``run_time`` (seconds, 2 decimals), ``source`` and
        ``stream_url``, plus ``result`` for completed tasks or ``error`` for
        failed ones. ``code`` is 1 for failed tasks and 0 for the other states.
    """
    logging.info(f"收到查询任务状态请求: {task_id}")

    # NOTE(review): finished tasks are purged before the lookup, so a task
    # that has already finished is reported as nonexistent rather than as
    # completed/failed. Confirm this ordering is intended.
    cleanup_completed_tasks()

    task_info = stream_tasks.get(task_id)
    if not task_info:
        return {"code": 1, "msg": "任务不存在", "result": None}

    def build_result(status, **extra):
        """Assemble the common result payload for the given status."""
        run_time = (time.time() * 1000 - task_info["start_time"]) / 1000  # seconds
        return {
            "status": status,
            **extra,
            "start_time": task_info["start_time"],
            "run_time": round(run_time, 2),
            "source": task_info["source"],
            "stream_url": task_info["stream_url"],
        }

    try:
        future = task_info["future"]

        # Check cancellation before done(): a cancelled future is also done,
        # and its result() raises CancelledError, which for asyncio futures
        # is not an Exception subclass on Python 3.8+.
        if future.cancelled():
            return {"code": 0, "msg": "已取消", "result": build_result("cancelled")}

        if future.done():
            try:
                outcome = future.result()
            except Exception as e:
                return {
                    "code": 1,
                    "msg": f"任务异常: {e}",
                    "result": build_result("failed", error=str(e)),
                }
            return {
                "code": 0,
                "msg": "已完成",
                "result": build_result("completed", result=outcome),
            }

        return {"code": 0, "msg": "运行中", "result": build_result("running")}
    except Exception as e:
        logging.error(f"查询任务状态时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
@app_fastapi.get("/yolov12/stream/list")
async def list_stream_tasks():
    """Handle ``GET /yolov12/stream/list``: list every registered task.

    Returns:
        ``{"code": 0, "msg": "success", "result": [<task>, ...]}`` where each
        task is a dict with ``task_id``, ``status`` (``running``,
        ``completed``, ``failed`` or ``cancelled``), ``start_time`` (epoch
        ms), ``run_time`` (seconds, 2 decimals), ``source`` and
        ``stream_url``; or ``{"code": 1, "msg": <error text>, "result": None}``
        when building the list raises.
    """
    logging.info("收到查询所有任务请求")

    # NOTE(review): finished tasks are purged before listing, so in practice
    # the list contains the tasks still running at that moment. Confirm this
    # ordering is intended.
    cleanup_completed_tasks()
    try:
        tasks_info = []
        for task_id, task_info in stream_tasks.items():
            future = task_info["future"]
            if not future.done():
                status = "running"
            elif future.cancelled():
                status = "cancelled"
            else:
                try:
                    future.result()  # raises if the task failed
                    status = "completed"
                except Exception:
                    # Narrowed from a bare except so interpreter-level
                    # signals (KeyboardInterrupt, SystemExit) still propagate.
                    status = "failed"

            run_time = (time.time() * 1000 - task_info["start_time"]) / 1000  # seconds

            tasks_info.append({
                "task_id": task_id,
                "status": status,
                "start_time": task_info["start_time"],
                "run_time": round(run_time, 2),
                "source": task_info["source"],          # pull address
                "stream_url": task_info["stream_url"]   # push address
            })

        return {
            "code": 0,
            "msg": "success",
            "result": tasks_info
        }
    except Exception as e:
        logging.error(f"查询所有任务时发生异常: {e}")
        return {"code": 1, "msg": str(e), "result": None}
# Global handler: request-validation failures use the same response envelope
# as the endpoints.
@app_fastapi.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Log a validation failure and answer HTTP 200 with ``code`` 422 in the body."""
    err_msg = f"参数校验失败: 路径={request.url.path}, 错误={exc.errors()}"
    logging.error(err_msg)
    body = {"code": 422, "msg": err_msg, "result": None}
    return JSONResponse(status_code=status.HTTP_200_OK, content=body)
if __name__ == "__main__":
    # Run the Gradio page in a daemon thread so it stops with the process,
    # and serve the FastAPI app in the main thread on port 8000.
    gradio_thread = threading.Thread(target=start_gradio, daemon=True)
    gradio_thread.start()
    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
|