|
|
@@ -237,6 +237,78 @@ def yolov12_train(params: TrainParams):
|
|
|
"result": None
|
|
|
}
|
|
|
|
|
|
+class PredictParams(BaseModel):
|
|
|
+ """
|
|
|
+ 用于接收/yolov12/predict接口的预测参数,与YOLO predict方法保持一致。
|
|
|
+ """
|
|
|
+ model: str = "yolov12m.pt" # 模型路径
|
|
|
+ source: str = None # 输入源(图片/视频路径、URL等)
|
|
|
+ stream: bool = False # 是否流式处理
|
|
|
+ conf: float = 0.25 # 置信度阈值
|
|
|
+ iou: float = 0.7 # IoU阈值
|
|
|
+ max_det: int = 300 # 最大检测数量
|
|
|
+ imgsz: int = 640 # 输入图片尺寸
|
|
|
+ batch: int = 1 # 批次大小
|
|
|
+ device: str = "" # 设备
|
|
|
+ show: bool = False # 是否显示结果
|
|
|
+ save: bool = False # 是否保存结果
|
|
|
+ save_txt: bool = False # 是否保存txt文件
|
|
|
+ save_conf: bool = False # 是否保存置信度
|
|
|
+ save_crop: bool = False # 是否保存裁剪图片
|
|
|
+ show_labels: bool = True # 是否显示标签
|
|
|
+ show_conf: bool = True # 是否显示置信度
|
|
|
+ show_boxes: bool = True # 是否显示边界框
|
|
|
+ line_width: int = None # 线条宽度
|
|
|
+ vid_stride: int = 1 # 视频帧步长
|
|
|
+ stream_buffer: bool = False # 流缓冲区
|
|
|
+ visualize: bool = False # 可视化特征
|
|
|
+ augment: bool = False # 数据增强
|
|
|
+ agnostic_nms: bool = False # 类别无关NMS
|
|
|
+ classes: list = None # 指定类别
|
|
|
+ retina_masks: bool = False # 高分辨率分割掩码
|
|
|
+ embed: list = None # 特征向量层
|
|
|
+ half: bool = False # 半精度
|
|
|
+ dnn: bool = False # OpenCV DNN
|
|
|
+ project: str = "" # 项目名
|
|
|
+ name: str = "" # 实验名
|
|
|
+ exist_ok: bool = False # 是否覆盖现有目录
|
|
|
+ verbose: bool = True # 详细输出
|
|
|
+
|
|
|
+@app_fastapi.post("/yolov12/predict")
|
|
|
+def yolov12_predict(params: PredictParams):
|
|
|
+ """
|
|
|
+ RESTful POST接口:/yolov12/predict
|
|
|
+ 接收预测参数,调用YOLO模型进行预测,并返回预测结果。
|
|
|
+ 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 预测结果或None}
|
|
|
+ """
|
|
|
+ logging.info("收到/yolov12/predict预测请求")
|
|
|
+ logging.info(f"请求参数: {params}")
|
|
|
+ try:
|
|
|
+ model = YOLO(params.model)
|
|
|
+ logging.info("开始模型预测...")
|
|
|
+
|
|
|
+ # 构建预测参数
|
|
|
+ predict_kwargs = {}
|
|
|
+ for field, value in params.dict().items():
|
|
|
+ if field not in ['model'] and value is not None:
|
|
|
+ predict_kwargs[field] = value
|
|
|
+
|
|
|
+ results = model.predict(**predict_kwargs)
|
|
|
+ logging.info("模型预测完成")
|
|
|
+ logging.info(f"预测结果: {str(results)}")
|
|
|
+ return {
|
|
|
+ "code": 0,
|
|
|
+ "msg": "success",
|
|
|
+ "result": results[0].save_dir
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"预测过程发生异常: {e}")
|
|
|
+ return {
|
|
|
+ "code": 1,
|
|
|
+ "msg": str(e),
|
|
|
+ "result": None
|
|
|
+ }
|
|
|
+
|
|
|
# 全局异常处理器:参数校验失败时统一返回格式
|
|
|
@app_fastapi.exception_handler(RequestValidationError)
|
|
|
async def validation_exception_handler(request, exc):
|