xujunwei преди 6 месеца
родител
ревизия
5fa046220b
променени са 1 файла, в които са добавени 72 реда и са изтрити 0 реда
  1. 72 0
      app.py

+ 72 - 0
app.py

@@ -237,6 +237,78 @@ def yolov12_train(params: TrainParams):
             "result": None
         }
 
+class PredictParams(BaseModel):
+    """
+    用于接收/yolov12/predict接口的预测参数,与YOLO predict方法保持一致。
+    """
+    model: str = "yolov12m.pt"  # 模型路径
+    source: str = None  # 输入源(图片/视频路径、URL等)
+    stream: bool = False  # 是否流式处理
+    conf: float = 0.25  # 置信度阈值
+    iou: float = 0.7  # IoU阈值
+    max_det: int = 300  # 最大检测数量
+    imgsz: int = 640  # 输入图片尺寸
+    batch: int = 1  # 批次大小
+    device: str = ""  # 设备
+    show: bool = False  # 是否显示结果
+    save: bool = False  # 是否保存结果
+    save_txt: bool = False  # 是否保存txt文件
+    save_conf: bool = False  # 是否保存置信度
+    save_crop: bool = False  # 是否保存裁剪图片
+    show_labels: bool = True  # 是否显示标签
+    show_conf: bool = True  # 是否显示置信度
+    show_boxes: bool = True  # 是否显示边界框
+    line_width: int = None  # 线条宽度
+    vid_stride: int = 1  # 视频帧步长
+    stream_buffer: bool = False  # 流缓冲区
+    visualize: bool = False  # 可视化特征
+    augment: bool = False  # 数据增强
+    agnostic_nms: bool = False  # 类别无关NMS
+    classes: list = None  # 指定类别
+    retina_masks: bool = False  # 高分辨率分割掩码
+    embed: list = None  # 特征向量层
+    half: bool = False  # 半精度
+    dnn: bool = False  # OpenCV DNN
+    project: str = ""  # 项目名
+    name: str = ""  # 实验名
+    exist_ok: bool = False  # 是否覆盖现有目录
+    verbose: bool = True  # 详细输出
+
+@app_fastapi.post("/yolov12/predict")
+def yolov12_predict(params: PredictParams):
+    """
+    RESTful POST接口:/yolov12/predict
+    接收预测参数,调用YOLO模型进行预测,并返回预测结果。
+    返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 预测结果或None}
+    """
+    logging.info("收到/yolov12/predict预测请求")
+    logging.info(f"请求参数: {params}")
+    try:
+        model = YOLO(params.model)
+        logging.info("开始模型预测...")
+        
+        # 构建预测参数
+        predict_kwargs = {}
+        for field, value in params.dict().items():
+            if field not in ['model'] and value is not None:
+                predict_kwargs[field] = value
+        
+        results = model.predict(**predict_kwargs)
+        logging.info("模型预测完成")
+        logging.info(f"预测结果: {str(results)}")
+        return {
+            "code": 0,
+            "msg": "success",
+            "result": results[0].save_dir
+        }
+    except Exception as e:
+        logging.error(f"预测过程发生异常: {e}")
+        return {
+            "code": 1,
+            "msg": str(e),
+            "result": None
+        }
+
 # 全局异常处理器:参数校验失败时统一返回格式
 @app_fastapi.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):