|
@@ -7,7 +7,14 @@ import gradio as gr
|
|
|
import cv2
|
|
import cv2
|
|
|
import tempfile
|
|
import tempfile
|
|
|
from ultralytics import YOLO
|
|
from ultralytics import YOLO
|
|
|
|
|
+import threading
|
|
|
|
|
+from fastapi import FastAPI
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+import uvicorn
|
|
|
|
|
+import logging
|
|
|
|
|
|
|
|
|
|
+# 设置日志格式和级别
|
|
|
|
|
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
|
|
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
|
|
|
model = YOLO(model_id)
|
|
model = YOLO(model_id)
|
|
@@ -161,5 +168,76 @@ with gradio_app:
|
|
|
with gr.Row():
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
with gr.Column():
|
|
|
app()
|
|
app()
|
|
|
-if __name__ == '__main__':
|
|
|
|
|
- gradio_app.launch()
|
|
|
|
|
|
|
+
|
|
|
|
|
+def start_gradio():
|
|
|
|
|
+ gradio_app.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
+
|
|
|
|
|
+# FastAPI部分
|
|
|
|
|
+app_fastapi = FastAPI()
|
|
|
|
|
+
|
|
|
|
|
+class TrainParams(BaseModel):
|
|
|
|
|
+ """
|
|
|
|
|
+ 用于接收/yolov12/train接口的训练参数,所有参数均需前端传入。
|
|
|
|
|
+ """
|
|
|
|
|
+ data: str # 数据集配置文件路径
|
|
|
|
|
+ epochs: int # 训练轮数
|
|
|
|
|
+ batch: int # 批次大小
|
|
|
|
|
+ imgsz: int # 输入图片尺寸
|
|
|
|
|
+ scale: float # 随机缩放增强比例
|
|
|
|
|
+ mosaic: float # mosaic数据增强概率
|
|
|
|
|
+ mixup: float # mixup数据增强概率
|
|
|
|
|
+ copy_paste: float # copy-paste数据增强概率
|
|
|
|
|
+ device: str # 训练设备
|
|
|
|
|
+ project: str # 工程名
|
|
|
|
|
+ name: str # 实验名
|
|
|
|
|
+ exist_ok: bool # 是否允许覆盖同名目录
|
|
|
|
|
+
|
|
|
|
|
+@app_fastapi.post("/yolov12/train")
|
|
|
|
|
+def yolov12_train(params: TrainParams):
|
|
|
|
|
+ """
|
|
|
|
|
+ RESTful POST接口:/yolov12/train
|
|
|
|
|
+ 接收训练参数,调用YOLO模型训练,并返回训练结果。
|
|
|
|
|
+ 返回格式:{"code": 0/1, "msg": "success/错误原因", "result": 训练结果或None}
|
|
|
|
|
+ """
|
|
|
|
|
+ logging.info("收到/yolov12/train训练请求")
|
|
|
|
|
+ logging.info(f"请求参数: {params}")
|
|
|
|
|
+ try:
|
|
|
|
|
+ model = YOLO("yolov12.yaml") # 如有yolov12n.yaml可替换
|
|
|
|
|
+ logging.info("开始模型训练...")
|
|
|
|
|
+ results = model.train(
|
|
|
|
|
+ data=params.data,
|
|
|
|
|
+ epochs=params.epochs,
|
|
|
|
|
+ batch=params.batch,
|
|
|
|
|
+ imgsz=params.imgsz,
|
|
|
|
|
+ scale=params.scale,
|
|
|
|
|
+ mosaic=params.mosaic,
|
|
|
|
|
+ mixup=params.mixup,
|
|
|
|
|
+ copy_paste=params.copy_paste,
|
|
|
|
|
+ device=params.device,
|
|
|
|
|
+ project=params.project,
|
|
|
|
|
+ name=params.name,
|
|
|
|
|
+ exist_ok=params.exist_ok,
|
|
|
|
|
+ )
|
|
|
|
|
+ logging.info("模型训练完成")
|
|
|
|
|
+ logging.info(f"训练结果: save_dir={results.save_dir}, metrics={results.metrics}, epoch={results.epoch}, best_fitness={getattr(results, 'best_fitness', None)}")
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": 0,
|
|
|
|
|
+ "msg": "success",
|
|
|
|
|
+ "result": {
|
|
|
|
|
+ "save_dir": str(results.save_dir),
|
|
|
|
|
+ "metrics": str(results.metrics),
|
|
|
|
|
+ "epoch": results.epoch,
|
|
|
|
|
+ "best_fitness": getattr(results, "best_fitness", None)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logging.error(f"训练过程发生异常: {e}")
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": 1,
|
|
|
|
|
+ "msg": str(e),
|
|
|
|
|
+ "result": None
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ threading.Thread(target=start_gradio, daemon=True).start()
|
|
|
|
|
+ uvicorn.run(app_fastapi, host="0.0.0.0", port=8000)
|