|
|
@@ -207,7 +207,15 @@ def yolov12_train(params: TrainParams):
|
|
|
logging.info("收到/yolov12/train训练请求")
|
|
|
logging.info(f"请求参数: {params}")
|
|
|
try:
|
|
|
- model = YOLO("yolov12.yaml") # 如有yolov12n.yaml可替换
|
|
|
+ # 根据params.model动态确定配置文件
|
|
|
+ if params.model.endswith('.pt'):
|
|
|
+ # 如果是.pt文件,将后缀替换为.yaml
|
|
|
+ config_file = params.model.replace('.pt', '.yaml')
|
|
|
+ else:
|
|
|
+ # 如果不是.pt文件,使用默认配置
|
|
|
+ config_file = "yolov12.yaml"
|
|
|
+
|
|
|
+ model = YOLO(config_file)
|
|
|
model.load(params.model)
|
|
|
logging.info("开始模型训练...")
|
|
|
results = model.train(
|