xujunwei il y a 6 mois
Parent
commit
e47dafeaf6
2 fichiers modifiés avec 85 ajouts et 4 suppressions
  1. 80 2
      app.py
  2. 5 2
      train.py

+ 80 - 2
app.py

@@ -7,7 +7,14 @@ import gradio as gr
 import cv2
 import tempfile
 from ultralytics import YOLO
+import threading
+from fastapi import FastAPI
+from pydantic import BaseModel
+import uvicorn
+import logging
 
+# 设置日志格式和级别
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
 
 def yolov12_inference(image, video, model_id, image_size, conf_threshold):
     model = YOLO(model_id)
@@ -161,5 +168,76 @@ with gradio_app:
     with gr.Row():
         with gr.Column():
             app()
-if __name__ == '__main__':
-    gradio_app.launch()
+
+def start_gradio():
+    gradio_app.launch(server_name="0.0.0.0", server_port=7860)
+
+# FastAPI部分
+app_fastapi = FastAPI()
+
+class TrainParams(BaseModel):
+    """
+    用于接收/yolov12/train接口的训练参数,所有参数均需前端传入。
+    """
+    data: str           # 数据集配置文件路径
+    epochs: int         # 训练轮数
+    batch: int          # 批次大小
+    imgsz: int          # 输入图片尺寸
+    scale: float        # 随机缩放增强比例
+    mosaic: float       # mosaic数据增强概率
+    mixup: float        # mixup数据增强概率
+    copy_paste: float   # copy-paste数据增强概率
+    device: str         # 训练设备
+    project: str        # 工程名
+    name: str           # 实验名
+    exist_ok: bool      # 是否允许覆盖同名目录
+
+@app_fastapi.post("/yolov12/train")
+def yolov12_train(params: TrainParams):
+    """
+    RESTful POST接口:/yolov12/train
+    接收训练参数,调用YOLO模型训练,并返回训练结果。
+    返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 训练结果或None}
+    """
+    logging.info("收到/yolov12/train训练请求")
+    logging.info(f"请求参数: {params}")
+    try:
+        model = YOLO("yolov12.yaml")  # 如有yolov12n.yaml可替换
+        logging.info("开始模型训练...")
+        results = model.train(
+            data=params.data,
+            epochs=params.epochs,
+            batch=params.batch,
+            imgsz=params.imgsz,
+            scale=params.scale,
+            mosaic=params.mosaic,
+            mixup=params.mixup,
+            copy_paste=params.copy_paste,
+            device=params.device,
+            project=params.project,
+            name=params.name,
+            exist_ok=params.exist_ok,
+        )
+        logging.info("模型训练完成")
+        logging.info(f"训练结果: save_dir={results.save_dir}, metrics={results.metrics}, epoch={results.epoch}, best_fitness={getattr(results, 'best_fitness', None)}")
+        return {
+            "code": 0,
+            "msg": "success",
+            "result": {
+                "save_dir": str(results.save_dir),
+                "metrics": str(results.metrics),
+                "epoch": results.epoch,
+                "best_fitness": getattr(results, "best_fitness", None)
+            }
+        }
+    except Exception as e:
+        logging.error(f"训练过程发生异常: {e}")
+        return {
+            "code": 1,
+            "msg": str(e),
+            "result": None
+        }
+
+if __name__ == "__main__":
+    threading.Thread(target=start_gradio, daemon=True).start()
+    uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)

+ 5 - 2
train.py

@@ -7,8 +7,11 @@ if __name__ == "__main__":
 
     # Train the model
     results = model.train(
-    data='/Users/jsonxu/Desktop/yolov12/dataset/data.yaml',
-    epochs=100, 
+    data='/Users/jsonxu/code/ai/train/6017/data.yaml',
+    project='/Users/jsonxu/code/ai/train/6017', # 训练结果路径
+    name='result',# 训练结果名称
+    exist_ok=True, #如果为 True,则允许覆盖现有的项目/名称目录。这对迭代实验非常有用,无需手动清除之前的输出。
+    epochs=1,
     batch=16, 
     imgsz=640,
     scale=0.5,  # S:0.9; M:0.9; L:0.9; X:0.9