xujunwei 6 месяцев назад
Родитель
Сommit
66e0ece643
1 измененных файлов с 9 добавлено и 1 удалено
  1. 9 1
      app.py

+ 9 - 1
app.py

@@ -207,7 +207,15 @@ def yolov12_train(params: TrainParams):
     logging.info("收到/yolov12/train训练请求")
     logging.info(f"请求参数: {params}")
     try:
-        model = YOLO("yolov12.yaml")  # 如有yolov12n.yaml可替换
+        # 根据params.model动态确定配置文件
+        if params.model.endswith('.pt'):
+            # 如果是.pt文件,将后缀替换为.yaml
+            config_file = params.model.replace('.pt', '.yaml')
+        else:
+            # 如果不是.pt文件,使用默认配置
+            config_file = "yolov12.yaml"
+        
+        model = YOLO(config_file)
         model.load(params.model)
         logging.info("开始模型训练...")
         results = model.train(