田运杰 10 meses atrás
pai
commit
213b890b2f

+ 1 - 0
ultralytics/engine/__init__.py

@@ -0,0 +1 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

Diferenças do arquivo suprimidas por serem muito extensas
+ 1476 - 0
ultralytics/engine/exporter.py


Diferenças do arquivo suprimidas por serem muito extensas
+ 1173 - 0
ultralytics/engine/model.py


+ 408 - 0
ultralytics/engine/predictor.py

@@ -0,0 +1,408 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
+
+Usage - sources:
+    $ yolo mode=predict model=yolov8n.pt source=0                               # webcam
+                                                img.jpg                         # image
+                                                vid.mp4                         # video
+                                                screen                          # screenshot
+                                                path/                           # directory
+                                                list.txt                        # list of images
+                                                list.streams                    # list of streams
+                                                'path/*.jpg'                    # glob
+                                                'https://youtu.be/LNwODJXcvt4'  # YouTube
+                                                'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP, TCP stream
+
+Usage - formats:
+    $ yolo mode=predict model=yolov8n.pt                 # PyTorch
+                              yolov8n.torchscript        # TorchScript
+                              yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
+                              yolov8n_openvino_model     # OpenVINO
+                              yolov8n.engine             # TensorRT
+                              yolov8n.mlpackage          # CoreML (macOS-only)
+                              yolov8n_saved_model        # TensorFlow SavedModel
+                              yolov8n.pb                 # TensorFlow GraphDef
+                              yolov8n.tflite             # TensorFlow Lite
+                              yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
+                              yolov8n_paddle_model       # PaddlePaddle
+                              yolov8n.mnn                # MNN
+                              yolov8n_ncnn_model         # NCNN
+"""
+
+import platform
+import re
+import threading
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data import load_inference_source
+from ultralytics.data.augment import LetterBox, classify_transforms
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
+from ultralytics.utils.checks import check_imgsz, check_imshow
+from ultralytics.utils.files import increment_path
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+
+STREAM_WARNING = """
+WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
+errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
+
+Example:
+    results = model(source=..., stream=True)  # generator of Results objects
+    for r in results:
+        boxes = r.boxes  # Boxes object for bbox outputs
+        masks = r.masks  # Masks object for segment masks outputs
+        probs = r.probs  # Class probabilities for classification outputs
+"""
+
+
+class BasePredictor:
+    """
+    BasePredictor.
+
+    A base class for creating predictors.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the predictor.
+        save_dir (Path): Directory to save results.
+        done_warmup (bool): Whether the predictor has finished setup.
+        model (nn.Module): Model used for prediction.
+        data (dict): Data configuration.
+        device (torch.device): Device used for prediction.
+        dataset (Dataset): Dataset used for prediction.
+        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initializes the BasePredictor class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        self.args = get_cfg(cfg, overrides)
+        self.save_dir = get_save_dir(self.args)
+        if self.args.conf is None:
+            self.args.conf = 0.25  # default conf=0.25
+        self.done_warmup = False
+        if self.args.show:
+            self.args.show = check_imshow(warn=True)
+
+        # Usable if setup is done
+        self.model = None
+        self.data = self.args.data  # data_dict
+        self.imgsz = None
+        self.device = None
+        self.dataset = None
+        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
+        self.plotted_img = None
+        self.source_type = None
+        self.seen = 0
+        self.windows = []
+        self.batch = None
+        self.results = None
+        self.transforms = None
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        self.txt_path = None
+        self._lock = threading.Lock()  # for automatic thread-safe inference
+        callbacks.add_integration_callbacks(self)
+
+    def preprocess(self, im):
+        """
+        Prepares input image before inference.
+
+        Args:
+            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
+        """
+        not_tensor = not isinstance(im, torch.Tensor)
+        if not_tensor:
+            im = np.stack(self.pre_transform(im))
+            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
+            im = np.ascontiguousarray(im)  # contiguous
+            im = torch.from_numpy(im)
+
+        im = im.to(self.device)
+        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
+        if not_tensor:
+            im /= 255  # 0 - 255 to 0.0 - 1.0
+        return im
+
+    def inference(self, im, *args, **kwargs):
+        """Runs inference on a given image using the specified model and arguments."""
+        visualize = (
+            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
+            if self.args.visualize and (not self.source_type.tensor)
+            else False
+        )
+        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
+
+    def pre_transform(self, im):
+        """
+        Pre-transform input image before inference.
+
+        Args:
+            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
+
+        Returns:
+            (list): A list of transformed images.
+        """
+        same_shapes = len({x.shape for x in im}) == 1
+        letterbox = LetterBox(
+            self.imgsz,
+            auto=same_shapes and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
+            stride=self.model.stride,
+        )
+        return [letterbox(image=x) for x in im]
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions for an image and returns them."""
+        return preds
+
+    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
+        """Performs inference on an image or stream."""
+        self.stream = stream
+        if stream:
+            return self.stream_inference(source, model, *args, **kwargs)
+        else:
+            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one
+
+    def predict_cli(self, source=None, model=None):
+        """
+        Method used for Command Line Interface (CLI) prediction.
+
+        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
+        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
+        generator without storing results.
+
+        Note:
+            Do not modify this function or remove the generator. The generator ensures that no outputs are
+            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
+        """
+        gen = self.stream_inference(source, model)
+        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
+            pass
+
+    def setup_source(self, source):
+        """Sets up source and inference mode."""
+        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
+        self.transforms = (
+            getattr(
+                self.model.model,
+                "transforms",
+                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
+            )
+            if self.args.task == "classify"
+            else None
+        )
+        self.dataset = load_inference_source(
+            source=source,
+            batch=self.args.batch,
+            vid_stride=self.args.vid_stride,
+            buffer=self.args.stream_buffer,
+        )
+        self.source_type = self.dataset.source_type
+        if not getattr(self, "stream", True) and (
+            self.source_type.stream
+            or self.source_type.screenshot
+            or len(self.dataset) > 1000  # many images
+            or any(getattr(self.dataset, "video_flag", [False]))
+        ):  # videos
+            LOGGER.warning(STREAM_WARNING)
+        self.vid_writer = {}
+
+    @smart_inference_mode()
+    def stream_inference(self, source=None, model=None, *args, **kwargs):
+        """Streams real-time inference on camera feed and saves results to file."""
+        if self.args.verbose:
+            LOGGER.info("")
+
+        # Setup model
+        if not self.model:
+            self.setup_model(model)
+
+        with self._lock:  # for thread-safe inference
+            # Setup source every time predict is called
+            self.setup_source(source if source is not None else self.args.source)
+
+            # Check if save_dir/ label file exists
+            if self.args.save or self.args.save_txt:
+                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+
+            # Warmup model
+            if not self.done_warmup:
+                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
+                self.done_warmup = True
+
+            self.seen, self.windows, self.batch = 0, [], None
+            profilers = (
+                ops.Profile(device=self.device),
+                ops.Profile(device=self.device),
+                ops.Profile(device=self.device),
+            )
+            self.run_callbacks("on_predict_start")
+            for self.batch in self.dataset:
+                self.run_callbacks("on_predict_batch_start")
+                paths, im0s, s = self.batch
+
+                # Preprocess
+                with profilers[0]:
+                    im = self.preprocess(im0s)
+
+                # Inference
+                with profilers[1]:
+                    preds = self.inference(im, *args, **kwargs)
+                    if self.args.embed:
+                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
+                        continue
+
+                # Postprocess
+                with profilers[2]:
+                    self.results = self.postprocess(preds, im, im0s)
+                self.run_callbacks("on_predict_postprocess_end")
+
+                # Visualize, save, write results
+                n = len(im0s)
+                for i in range(n):
+                    self.seen += 1
+                    self.results[i].speed = {
+                        "preprocess": profilers[0].dt * 1e3 / n,
+                        "inference": profilers[1].dt * 1e3 / n,
+                        "postprocess": profilers[2].dt * 1e3 / n,
+                    }
+                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
+                        s[i] += self.write_results(i, Path(paths[i]), im, s)
+
+                # Print batch results
+                if self.args.verbose:
+                    LOGGER.info("\n".join(s))
+
+                self.run_callbacks("on_predict_batch_end")
+                yield from self.results
+
+        # Release assets
+        for v in self.vid_writer.values():
+            if isinstance(v, cv2.VideoWriter):
+                v.release()
+
+        # Print final results
+        if self.args.verbose and self.seen:
+            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
+            LOGGER.info(
+                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
+                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
+            )
+        if self.args.save or self.args.save_txt or self.args.save_crop:
+            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
+            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
+            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
+        self.run_callbacks("on_predict_end")
+
+    def setup_model(self, model, verbose=True):
+        """Initialize YOLO model with given parameters and set it to evaluation mode."""
+        self.model = AutoBackend(
+            weights=model or self.args.model,
+            device=select_device(self.args.device, verbose=verbose),
+            dnn=self.args.dnn,
+            data=self.args.data,
+            fp16=self.args.half,
+            batch=self.args.batch,
+            fuse=True,
+            verbose=verbose,
+        )
+
+        self.device = self.model.device  # update device
+        self.args.half = self.model.fp16  # update half
+        self.model.eval()
+
+    def write_results(self, i, p, im, s):
+        """Write inference results to a file or directory."""
+        string = ""  # print string
+        if len(im.shape) == 3:
+            im = im[None]  # expand for batch dim
+        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
+            string += f"{i}: "
+            frame = self.dataset.count
+        else:
+            match = re.search(r"frame (\d+)/", s[i])
+            frame = int(match[1]) if match else None  # 0 if frame undetermined
+
+        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
+        string += "{:g}x{:g} ".format(*im.shape[2:])
+        result = self.results[i]
+        result.save_dir = self.save_dir.__str__()  # used in other locations
+        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"
+
+        # Add predictions to image
+        if self.args.save or self.args.show:
+            self.plotted_img = result.plot(
+                line_width=self.args.line_width,
+                boxes=self.args.show_boxes,
+                conf=self.args.show_conf,
+                labels=self.args.show_labels,
+                im_gpu=None if self.args.retina_masks else im[i],
+            )
+
+        # Save results
+        if self.args.save_txt:
+            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
+        if self.args.save_crop:
+            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
+        if self.args.show:
+            self.show(str(p))
+        if self.args.save:
+            self.save_predicted_images(str(self.save_dir / p.name), frame)
+
+        return string
+
+    def save_predicted_images(self, save_path="", frame=0):
+        """Save video predictions as mp4 at specified path."""
+        im = self.plotted_img
+
+        # Save videos and streams
+        if self.dataset.mode in {"stream", "video"}:
+            fps = self.dataset.fps if self.dataset.mode == "video" else 30
+            frames_path = f"{save_path.split('.', 1)[0]}_frames/"
+            if save_path not in self.vid_writer:  # new video
+                if self.args.save_frames:
+                    Path(frames_path).mkdir(parents=True, exist_ok=True)
+                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
+                self.vid_writer[save_path] = cv2.VideoWriter(
+                    filename=str(Path(save_path).with_suffix(suffix)),
+                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
+                    fps=fps,  # integer required, floats produce error in MP4 codec
+                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
+                )
+
+            # Save video
+            self.vid_writer[save_path].write(im)
+            if self.args.save_frames:
+                cv2.imwrite(f"{frames_path}{frame}.jpg", im)
+
+        # Save images
+        else:
+            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support
+
+    def show(self, p=""):
+        """Display an image in a window using the OpenCV imshow function."""
+        im = self.plotted_img
+        if platform.system() == "Linux" and p not in self.windows:
+            self.windows.append(p)
+            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
+            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
+        cv2.imshow(p, im)
+        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond
+
+    def run_callbacks(self, event: str):
+        """Runs all registered callbacks for a specific event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def add_callback(self, event: str, func):
+        """Add callback."""
+        self.callbacks[event].append(func)

Diferenças do arquivo suprimidas por serem muito extensas
+ 1740 - 0
ultralytics/engine/results.py


+ 820 - 0
ultralytics/engine/trainer.py

@@ -0,0 +1,820 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Train a model on a dataset.
+
+Usage:
+    $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
+"""
+
+import gc
+import math
+import os
+import subprocess
+import time
+import warnings
+from copy import copy, deepcopy
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import distributed as dist
+from torch import nn, optim
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    LOCAL_RANK,
+    LOGGER,
+    RANK,
+    TQDM,
+    __version__,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_save,
+)
+from ultralytics.utils.autobatch import check_train_batch_size
+from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
+from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
+from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.torch_utils import (
+    TORCH_2_4,
+    EarlyStopping,
+    ModelEMA,
+    autocast,
+    convert_optimizer_state_dict_to_fp16,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+    torch_distributed_zero_first,
+)
+
+
+class BaseTrainer:
+    """
+    A base class for creating trainers.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the trainer.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to the last checkpoint.
+        best (Path): Path to the best checkpoint.
+        save_period (int): Save checkpoint every x epochs (disabled if < 1).
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        trainset (torch.utils.data.Dataset): Training dataset.
+        testset (torch.utils.data.Dataset): Testing dataset.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        resume (bool): Resume training from a checkpoint.
+        lf (nn.Module): Loss function.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Total loss value.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initializes the BaseTrainer class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        self.args = get_cfg(cfg, overrides)
+        self.check_resume(overrides)
+        self.device = select_device(self.args.device, self.args.batch)
+        self.validator = None
+        self.metrics = None
+        self.plots = {}
+        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
+
+        # Dirs
+        self.save_dir = get_save_dir(self.args)
+        self.args.name = self.save_dir.name  # update name for loggers
+        self.wdir = self.save_dir / "weights"  # weights dir
+        if RANK in {-1, 0}:
+            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+            self.args.save_dir = str(self.save_dir)
+            yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
+        self.save_period = self.args.save_period
+
+        self.batch_size = self.args.batch
+        self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
+        self.start_epoch = 0
+        if RANK == -1:
+            print_args(vars(self.args))
+
+        # Device
+        if self.device.type in {"cpu", "mps"}:
+            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
+
+        # Model and Dataset
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
+        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
+            self.trainset, self.testset = self.get_dataset()
+        self.ema = None
+
+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
+        # Epoch level metrics
+        self.best_fitness = None
+        self.fitness = None
+        self.loss = None
+        self.tloss = None
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
+        self.plot_idx = [0, 1, 2]
+
+        # HUB
+        self.hub_session = None
+
+        # Callbacks
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        if RANK in {-1, 0}:
+            callbacks.add_integration_callbacks(self)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """Overrides the existing callbacks with the given callback."""
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        """Run all existing callbacks associated with a particular event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def train(self):
+        """Allow device='', device=None on Multi-GPU systems to default to device=0."""
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        # Run subprocess if DDP training, else train normally
+        if world_size > 1 and "LOCAL_RANK" not in os.environ:
+            # Argument checks
+            if self.args.rect:
+                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
+                self.args.rect = False
+            if self.args.batch < 1.0:
+                LOGGER.warning(
+                    "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                    "default 'batch=16'"
+                )
+                self.args.batch = 16
+
+            # Command
+            cmd, file = generate_ddp_command(world_size, self)
+            try:
+                LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
+                subprocess.run(cmd, check=True)
+            except Exception as e:
+                raise e
+            finally:
+                ddp_cleanup(self, str(file))
+
+        else:
+            self._do_train(world_size)
+
+    def _setup_scheduler(self):
+        """Initialize training learning rate scheduler."""
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
+    def _setup_ddp(self, world_size):
+        """Initializes and sets the DistributedDataParallel parameters for training."""
+        torch.cuda.set_device(RANK)
+        self.device = torch.device("cuda", RANK)
+        # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        dist.init_process_group(
+            backend="nccl" if dist.is_nccl_available() else "gloo",
+            timeout=timedelta(seconds=10800),  # 3 hours
+            rank=RANK,
+            world_size=world_size,
+        )
+
+    def _setup_train(self, world_size):
+        """Builds dataloaders and optimizer on correct rank process."""
+        # Model
+        self.run_callbacks("on_pretrain_routine_start")
+        ckpt = self.setup_model()
+        self.model = self.model.to(self.device)
+        self.set_model_attributes()
+
+        # Freeze layers
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        for k, v in self.model.named_parameters():
+            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+            if any(x in k for x in freeze_layer_names):
+                LOGGER.info(f"Freezing layer '{k}'")
+                v.requires_grad = False
+            elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
+                LOGGER.info(
+                    f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
+                v.requires_grad = True
+
+        # Check AMP
+        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
+        if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
+            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
+            self.amp = torch.tensor(check_amp(self.model), device=self.device)
+            callbacks.default_callbacks = callbacks_backup  # restore callbacks
+        if RANK > -1 and world_size > 1:  # DDP
+            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
+        self.amp = bool(self.amp)  # as boolean
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
+        if world_size > 1:
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
+            self.set_model_attributes()  # set again after DDP wrapper
+
+        # Check imgsz
+        gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
+        self.stride = gs  # for multiscale training
+
+        # Batch size
+        if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = self.auto_batch()
+
+        # Dataloaders
+        batch_size = self.batch_size // max(world_size, 1)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
+        if RANK in {-1, 0}:
+            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+            self.test_loader = self.get_dataloader(
+                self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
+            )
+            self.validator = self.get_validator()
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
+            self.ema = ModelEMA(self.model)
+            if self.args.plots:
+                self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
+        self.optimizer = self.build_optimizer(
+            model=self.model,
+            name=self.args.optimizer,
+            lr=self.args.lr0,
+            momentum=self.args.momentum,
+            decay=weight_decay,
+            iterations=iterations,
+        )
+        # Scheduler
+        self._setup_scheduler()
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.run_callbacks("on_pretrain_routine_end")
+
+    def _do_train(self, world_size=1):
+        """Train completed, evaluate and plot if specified by arguments."""
+        if world_size > 1:
+            self._setup_ddp(world_size)
+        self._setup_train(world_size)
+
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
+        last_opt_step = -1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
+        self.run_callbacks("on_train_start")
+        LOGGER.info(
+            f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+            f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
+            f"Logging results to {colorstr('bold', self.save_dir)}\n"
+            f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+        )
+        if self.args.close_mosaic:
+            base_idx = (self.epochs - self.args.close_mosaic) * nb
+            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
+        epoch = self.start_epoch
+        self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
+        while True:
+            self.epoch = epoch
+            self.run_callbacks("on_train_epoch_start")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
+
+            self.model.train()
+            if RANK != -1:
+                self.train_loader.sampler.set_epoch(epoch)
+            pbar = enumerate(self.train_loader)
+            # Update dataloader attributes (optional)
+            if epoch == (self.epochs - self.args.close_mosaic):
+                self._close_dataloader_mosaic()
+                self.train_loader.reset()
+
+            if RANK in {-1, 0}:
+                LOGGER.info(self.progress_string())
+                pbar = TQDM(enumerate(self.train_loader), total=nb)
+            self.tloss = None
+            for i, batch in pbar:
+                self.run_callbacks("on_train_batch_start")
+                # Warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x["lr"] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                        )
+                        if "momentum" in x:
+                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
+                # Forward
+                with autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    self.loss, self.loss_items = self.model(batch)
+                    if RANK != -1:
+                        self.loss *= world_size
+                    self.tloss = (
+                        (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
+                    )
+
+                # Backward
+                self.scaler.scale(self.loss).backward()
+
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+                if ni - last_opt_step >= self.accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni
+
+                    # Timed stopping
+                    if self.args.time:
+                        self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                        if RANK != -1:  # if DDP training
+                            broadcast_list = [self.stop if RANK == 0 else None]
+                            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                            self.stop = broadcast_list[0]
+                        if self.stop:  # training time exceeded
+                            break
+
+                # Log
+                if RANK in {-1, 0}:
+                    loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
+                    pbar.set_description(
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_length))
+                        % (
+                            f"{epoch + 1}/{self.epochs}",
+                            f"{self._get_memory():.3g}G",  # (GB) GPU memory util
+                            *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
+                            batch["cls"].shape[0],  # batch size, i.e. 8
+                            batch["img"].shape[-1],  # imgsz, i.e 640
+                        )
+                    )
+                    self.run_callbacks("on_batch_end")
+                    if self.args.plots and ni in self.plot_idx:
+                        self.plot_training_samples(batch, ni)
+
+                self.run_callbacks("on_train_batch_end")
+
+            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.run_callbacks("on_train_epoch_end")
+            if RANK in {-1, 0}:
+                final_epoch = epoch + 1 >= self.epochs
+                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
+
+                # Validation
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                    self.metrics, self.fitness = self.validate()
+                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
+                if self.args.time:
+                    self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
+
+                # Save model
+                if self.args.save or final_epoch:
+                    self.save_model()
+                    self.run_callbacks("on_model_save")
+
+            # Scheduler
+            t = time.time()
+            self.epoch_time = t - self.epoch_time_start
+            self.epoch_time_start = t
+            if self.args.time:
+                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                self._setup_scheduler()
+                self.scheduler.last_epoch = self.epoch  # do not move
+                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+            self.run_callbacks("on_fit_epoch_end")
+            self._clear_memory()
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
+            epoch += 1
+
+        if RANK in {-1, 0}:
+            # Do final val with best.pt
+            seconds = time.time() - self.train_time_start
+            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+            self.final_eval()
+            if self.args.plots:
+                self.plot_metrics()
+            self.run_callbacks("on_train_end")
+        self._clear_memory()
+        self.run_callbacks("teardown")
+
+    def auto_batch(self, max_num_obj=0):
+        """Get batch size by calculating memory occupation of model."""
+        return check_train_batch_size(
+            model=self.model,
+            imgsz=self.args.imgsz,
+            amp=self.amp,
+            batch=self.batch_size,
+            max_num_obj=max_num_obj,
+        )  # returns batch size
+
+    def _get_memory(self):
+        """Get accelerator memory utilization in GB."""
+        if self.device.type == "mps":
+            memory = torch.mps.driver_allocated_memory()
+        elif self.device.type == "cpu":
+            memory = 0
+        else:
+            memory = torch.cuda.memory_reserved()
+        return memory / 1e9
+
+    def _clear_memory(self):
+        """Clear accelerator memory on different platforms."""
+        gc.collect()
+        if self.device.type == "mps":
+            torch.mps.empty_cache()
+        elif self.device.type == "cpu":
+            return
+        else:
+            torch.cuda.empty_cache()
+
+    def read_results_csv(self):
+        """Read results.csv into a dict using pandas."""
+        import pandas as pd  # scope for faster 'import ultralytics'
+
+        return pd.read_csv(self.csv).to_dict(orient="list")
+
+    def save_model(self):
+        """Save model training checkpoints with additional metadata."""
+        import io
+
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": self.read_results_csv(),
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
+        if self.best_fitness == self.fitness:
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
+        if (self.save_period > 0) and (self.epoch % self.save_period == 0):
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
+        # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
+        #    (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt)  # save mosaic checkpoint
+
+    def get_dataset(self):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            }:
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
+        return data["train"], data.get("val") or data.get("test")
+
+    def setup_model(self):
+        """Load/create/download model for any task."""
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        cfg, weights = self.model, None
+        ckpt = None
+        if str(self.model).endswith(".pt"):
+            weights, ckpt = attempt_load_one_weight(self.model)
+            cfg = weights.yaml
+        elif isinstance(self.args.pretrained, (str, Path)):
+            weights, _ = attempt_load_one_weight(self.args.pretrained)
+        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
+        return ckpt
+
+    def optimizer_step(self):
+        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
+        self.scaler.unscale_(self.optimizer)  # unscale gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)
+
+    def preprocess_batch(self, batch):
+        """Allows custom preprocessing model inputs and ground truths depending on task type."""
+        return batch
+
+    def validate(self):
+        """
+        Runs validation on test set using self.validator.
+
+        The returned dict is expected to contain "fitness" key.
+        """
+        metrics = self.validator(self)
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        if not self.best_fitness or self.best_fitness < fitness:
+            self.best_fitness = fitness
+        return metrics, fitness
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Get model and raise NotImplementedError for loading cfg files."""
+        raise NotImplementedError("This task trainer doesn't support loading cfg files")
+
+    def get_validator(self):
+        """Returns a NotImplementedError when the get_validator function is called."""
+        raise NotImplementedError("get_validator function not implemented in trainer")
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Returns dataloader derived from torch.data.Dataloader."""
+        raise NotImplementedError("get_dataloader function not implemented in trainer")
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in trainer")
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Note:
+            This is not needed for classification but necessary for segmentation & detection
+        """
+        return {"loss": loss_items} if loss_items is not None else ["loss"]
+
+    def set_model_attributes(self):
+        """To set or update model parameters before training."""
+        self.model.names = self.data["names"]
+
+    def build_targets(self, preds, targets):
+        """Builds target tensors for training YOLO model."""
+        pass
+
+    def progress_string(self):
+        """Returns a string describing training progress."""
+        return ""
+
+    # TODO: may need to put these following functions into callback
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples during YOLO training."""
+        pass
+
+    def plot_training_labels(self):
+        """Plots training labels for YOLO model."""
+        pass
+
+    def save_metrics(self, metrics):
+        """Saves training metrics to a CSV file."""
+        keys, vals = list(metrics.keys()), list(metrics.values())
+        n = len(metrics) + 2  # number of cols
+        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        t = time.time() - self.train_time_start
+        with open(self.csv, "a") as f:
+            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+
+    def plot_metrics(self):
+        """Plot and display metrics visually."""
+        pass
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)."""
+        path = Path(name)
+        self.plots[path] = {"data": data, "timestamp": time.time()}
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO model."""
+        ckpt = {}
+        for f in self.last, self.best:
+            if f.exists():
+                if f is self.last:
+                    ckpt = strip_optimizer(f)
+                elif f is self.best:
+                    k = "train_results"  # update best.pt train_metrics from last.pt
+                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+
+    def check_resume(self, overrides):
+        """Check if resume checkpoint exists and update arguments accordingly."""
+        resume = self.args.resume
+        if resume:
+            try:
+                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
+                last = Path(check_file(resume) if exists else get_latest_run())
+
+                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
+                ckpt_args = attempt_load_weights(last).args
+                if not Path(ckpt_args["data"]).exists():
+                    ckpt_args["data"] = self.args.data
+
+                resume = True
+                self.args = get_cfg(ckpt_args)
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in (
+                    "imgsz",
+                    "batch",
+                    "device",
+                    "close_mosaic",
+                ):  # allow arg updates to reduce memory or update device on resume
+                    if k in overrides:
+                        setattr(self.args, k, overrides[k])
+
+            except Exception as e:
+                raise FileNotFoundError(
+                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                    "i.e. 'yolo train resume model=path/to/last.pt'"
+                ) from e
+        self.resume = resume
+
+    def resume_training(self, ckpt):
+        """Resume YOLO training from given epoch and best fitness."""
+        if ckpt is None or not self.resume:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
+            best_fitness = ckpt["best_fitness"]
+        if self.ema and ckpt.get("ema"):
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
+            self.ema.updates = ckpt["updates"]
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt["epoch"]  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+        if start_epoch > (self.epochs - self.args.close_mosaic):
+            self._close_dataloader_mosaic()
+
+    def _close_dataloader_mosaic(self):
+        """Update dataloaders to stop using mosaic augmentation."""
+        if hasattr(self.train_loader.dataset, "mosaic"):
+            self.train_loader.dataset.mosaic = False
+        if hasattr(self.train_loader.dataset, "close_mosaic"):
+            LOGGER.info("Closing dataloader mosaic")
+            self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
+
+    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+        """
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
+        weight decay, and number of iterations.
+
+        Args:
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
+
+        Returns:
+            (torch.optim.Optimizer): The constructed optimizer.
+        """
+        g = [], [], []  # optimizer parameter groups
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == "auto":
+            LOGGER.info(
+                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+            )
+            nc = getattr(model, "nc", 10)  # number of classes
+            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
+            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
+
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f"{module_name}.{param_name}" if module_name else param_name
+                if "bias" in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn):  # weight (no decay)
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
+
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        name = {x.lower(): x for x in optimizers}.get(name.lower())
+        if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        elif name == "RMSProp":
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+        elif name == "SGD":
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+        else:
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
+                "Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
+            )
+
+        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
+        )
+        return optimizer

+ 242 - 0
ultralytics/engine/tuner.py

@@ -0,0 +1,242 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
+segmentation, image classification, pose estimation, and multi-object tracking.
+
+Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
+that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
+where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
+
+Example:
+    Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+    ```python
+    from ultralytics import YOLO
+
+    model = YOLO("yolo11n.pt")
+    model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
+    ```
+"""
+
+import random
+import shutil
+import subprocess
+import time
+
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
+from ultralytics.utils.plotting import plot_tune_results
+
+
+class Tuner:
+    """
+    Class responsible for hyperparameter tuning of YOLO models.
+
+    The class evolves YOLO model hyperparameters over a given number of iterations
+    by mutating them according to the search space and retraining the model to evaluate their performance.
+
+    Attributes:
+        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
+        tune_dir (Path): Directory where evolution logs and results will be saved.
+        tune_csv (Path): Path to the CSV file where evolution logs are saved.
+
+    Methods:
+        _mutate(hyp: dict) -> dict:
+            Mutates the given hyperparameters within the bounds specified in `self.space`.
+
+        __call__():
+            Executes the hyperparameter evolution across multiple iterations.
+
+    Example:
+        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+        ```python
+        from ultralytics import YOLO
+
+        model = YOLO("yolo11n.pt")
+        model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
+        ```
+
+        Tune with custom search space.
+        ```python
+        from ultralytics import YOLO
+
+        model = YOLO("yolo11n.pt")
+        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
+        ```
+    """
+
+    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
+        """
+        Initialize the Tuner with configurations.
+
+        Args:
+            args (dict, optional): Configuration for hyperparameter evolution.
+        """
+        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
+            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
+            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
+            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
+            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
+            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
+            "box": (1.0, 20.0),  # box loss gain
+            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
+            "dfl": (0.4, 6.0),  # dfl loss gain
+            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
+            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
+            "translate": (0.0, 0.9),  # image translation (+/- fraction)
+            "scale": (0.0, 0.95),  # image scale (+/- gain)
+            "shear": (0.0, 10.0),  # image shear (+/- deg)
+            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+            "flipud": (0.0, 1.0),  # image flip up-down (probability)
+            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
+            "bgr": (0.0, 1.0),  # image channel bgr (probability)
+            "mosaic": (0.0, 1.0),  # image mixup (probability)
+            "mixup": (0.0, 1.0),  # image mixup (probability)
+            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
+        }
+        self.args = get_cfg(overrides=args)
+        self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
+        self.args.name = None  # reset to not affect training directory
+        self.tune_csv = self.tune_dir / "tune_results.csv"
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        self.prefix = colorstr("Tuner: ")
+        callbacks.add_integration_callbacks(self)
+        LOGGER.info(
+            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
+        )
+
+    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
+        """
+        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.
+
+        Args:
+            parent (str): Parent selection method: 'single' or 'weighted'.
+            n (int): Number of parents to consider.
+            mutation (float): Probability of a parameter mutation in any given iteration.
+            sigma (float): Standard deviation for Gaussian random number generator.
+
+        Returns:
+            (dict): A dictionary containing mutated hyperparameters.
+        """
+        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
+            # Select parent(s)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
+            fitness = x[:, 0]  # first column
+            n = min(n, len(x))  # number of previous results to consider
+            x = x[np.argsort(-fitness)][:n]  # top n mutations
+            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
+            if parent == "single" or len(x) == 1:
+                # x = x[random.randint(0, n - 1)]  # random selection
+                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
+            elif parent == "weighted":
+                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
+
+            # Mutate
+            r = np.random  # method
+            r.seed(int(time.time()))
+            g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
+            ng = len(self.space)
+            v = np.ones(ng)
+            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
+                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
+            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
+        else:
+            hyp = {k: getattr(self.args, k) for k in self.space.keys()}
+
+        # Constrain to limits
+        for k, v in self.space.items():
+            hyp[k] = max(hyp[k], v[0])  # lower limit
+            hyp[k] = min(hyp[k], v[1])  # upper limit
+            hyp[k] = round(hyp[k], 5)  # significant digits
+
+        return hyp
+
+    def __call__(self, model=None, iterations=10, cleanup=True):
+        """
+        Executes the hyperparameter evolution process when the Tuner instance is called.
+
+        This method iterates through the number of iterations, performing the following steps in each iteration:
+        1. Load the existing hyperparameters or initialize new ones.
+        2. Mutate the hyperparameters using the `mutate` method.
+        3. Train a YOLO model with the mutated hyperparameters.
+        4. Log the fitness score and mutated hyperparameters to a CSV file.
+
+        Args:
+           model (Model): A pre-initialized YOLO model to be used for training.
+           iterations (int): The number of generations to run the evolution for.
+           cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
+
+        Note:
+           The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
+           Ensure this path is set correctly in the Tuner instance.
+        """
+        t0 = time.time()
+        best_save_dir, best_metrics = None, None
+        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
+        for i in range(iterations):
+            # Mutate hyperparameters
+            mutated_hyp = self._mutate()
+            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
+
+            metrics = {}
+            train_args = {**vars(self.args), **mutated_hyp}
+            save_dir = get_save_dir(get_cfg(train_args))
+            weights_dir = save_dir / "weights"
+            try:
+                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
+                cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
+                return_code = subprocess.run(" ".join(cmd), check=True, shell=True).returncode
+                ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
+                metrics = torch.load(ckpt_file)["train_metrics"]
+                assert return_code == 0, "training failed"
+
+            except Exception as e:
+                LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")
+
+            # Save results and mutated_hyp to CSV
+            fitness = metrics.get("fitness", 0.0)
+            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
+            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+            with open(self.tune_csv, "a") as f:
+                f.write(headers + ",".join(map(str, log_row)) + "\n")
+
+            # Get best results
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
+            fitness = x[:, 0]  # first column
+            best_idx = fitness.argmax()
+            best_is_current = best_idx == i
+            if best_is_current:
+                best_save_dir = save_dir
+                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
+                for ckpt in weights_dir.glob("*.pt"):
+                    shutil.copy2(ckpt, self.tune_dir / "weights")
+            elif cleanup:
+                shutil.rmtree(weights_dir, ignore_errors=True)  # remove iteration weights/ dir to reduce storage space
+
+            # Plot tune results
+            plot_tune_results(self.tune_csv)
+
+            # Save and print tune results
+            header = (
+                f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n"
+                f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
+                f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
+                f"{self.prefix}Best fitness metrics are {best_metrics}\n"
+                f"{self.prefix}Best fitness model is {best_save_dir}\n"
+                f"{self.prefix}Best fitness hyperparameters are printed below.\n"
+            )
+            LOGGER.info("\n" + header)
+            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
+            yaml_save(
+                self.tune_dir / "best_hyperparameters.yaml",
+                data=data,
+                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
+            )
+            yaml_print(self.tune_dir / "best_hyperparameters.yaml")

+ 341 - 0
ultralytics/engine/validator.py

@@ -0,0 +1,341 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Check a model's accuracy on a test or val split of a dataset.
+
+Usage:
+    $ yolo mode=val model=yolov8n.pt data=coco8.yaml imgsz=640
+
+Usage - formats:
+    $ yolo mode=val model=yolov8n.pt                 # PyTorch
+                          yolov8n.torchscript        # TorchScript
+                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
+                          yolov8n_openvino_model     # OpenVINO
+                          yolov8n.engine             # TensorRT
+                          yolov8n.mlpackage          # CoreML (macOS-only)
+                          yolov8n_saved_model        # TensorFlow SavedModel
+                          yolov8n.pb                 # TensorFlow GraphDef
+                          yolov8n.tflite             # TensorFlow Lite
+                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
+                          yolov8n_paddle_model       # PaddlePaddle
+                          yolov8n.mnn                # MNN
+                          yolov8n_ncnn_model         # NCNN
+"""
+
+import json
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.autobackend import AutoBackend
+from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils.checks import check_imgsz
+from ultralytics.utils.ops import Profile
+from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+
+
+class BaseValidator:
+    """
+    BaseValidator.
+
+    A base class for creating validators.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the validator.
+        dataloader (DataLoader): Dataloader to use for validation.
+        pbar (tqdm): Progress bar to update during validation.
+        model (nn.Module): Model to validate.
+        data (dict): Data dictionary.
+        device (torch.device): Device to use for validation.
+        batch_i (int): Current batch index.
+        training (bool): Whether the model is in training mode.
+        names (dict): Class names.
+        seen: Records the number of images seen so far during validation.
+        stats: Placeholder for statistics during validation.
+        confusion_matrix: Placeholder for a confusion matrix.
+        nc: Number of classes.
+        iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
+        jdict (dict): Dictionary to store JSON validation results.
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
+                      batch processing times in milliseconds.
+        save_dir (Path): Directory to save results.
+        plots (dict): Dictionary to store plots for visualization.
+        callbacks (dict): Dictionary to store various callback functions.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initializes a BaseValidator instance.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            args (SimpleNamespace): Configuration for the validator.
+            _callbacks (dict): Dictionary to store various callback functions.
+        """
+        self.args = get_cfg(overrides=args)
+        self.dataloader = dataloader
+        self.pbar = pbar
+        self.stride = None
+        self.data = None
+        self.device = None
+        self.batch_i = None
+        self.training = True
+        self.names = None
+        self.seen = None
+        self.stats = None
+        self.confusion_matrix = None
+        self.nc = None
+        self.iouv = None
+        self.jdict = None
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+
+        self.save_dir = save_dir or get_save_dir(self.args)
+        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+        if self.args.conf is None:
+            self.args.conf = 0.001  # default conf=0.001
+        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
+
+        self.plots = {}
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+    @smart_inference_mode()
+    def __call__(self, trainer=None, model=None):
+        """Executes validation process, running inference on dataloader and computing performance metrics."""
+        self.training = trainer is not None
+        augment = self.args.augment and (not self.training)
+        if self.training:
+            self.device = trainer.device
+            self.data = trainer.data
+            # force FP16 val during training
+            self.args.half = self.device.type != "cpu" and trainer.amp
+            model = trainer.ema.ema or trainer.model
+            model = model.half() if self.args.half else model.float()
+            # self.model = model
+            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
+            model.eval()
+        else:
+            if str(self.args.model).endswith(".yaml") and model is None:
+                LOGGER.warning("WARNING ⚠️ validating an untrained model YAML will result in 0 mAP.")
+            callbacks.add_integration_callbacks(self)
+            model = AutoBackend(
+                weights=model or self.args.model,
+                device=select_device(self.args.device, self.args.batch),
+                dnn=self.args.dnn,
+                data=self.args.data,
+                fp16=self.args.half,
+            )
+            # self.model = model
+            self.device = model.device  # update device
+            self.args.half = model.fp16  # update half
+            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+            imgsz = check_imgsz(self.args.imgsz, stride=stride)
+            if engine:
+                self.args.batch = model.batch_size
+            elif not pt and not jit:
+                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
+                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
+
+            if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
+                self.data = check_det_dataset(self.args.data)
+            elif self.args.task == "classify":
+                self.data = check_cls_dataset(self.args.data, split=self.args.split)
+            else:
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
+
+            if self.device.type in {"cpu", "mps"}:
+                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
+            if not pt:
+                self.args.rect = False
+            self.stride = model.stride  # used in get_dataloader() for padding
+            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
+
+            model.eval()
+            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup
+
+        self.run_callbacks("on_val_start")
+        dt = (
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+            Profile(device=self.device),
+        )
+        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
+        self.init_metrics(de_parallel(model))
+        self.jdict = []  # empty before each val
+        for batch_i, batch in enumerate(bar):
+            self.run_callbacks("on_val_batch_start")
+            self.batch_i = batch_i
+            # Preprocess
+            with dt[0]:
+                batch = self.preprocess(batch)
+
+            # Inference
+            with dt[1]:
+                preds = model(batch["img"], augment=augment)
+
+            # Loss
+            with dt[2]:
+                if self.training:
+                    self.loss += model.loss(batch, preds)[1]
+
+            # Postprocess
+            with dt[3]:
+                preds = self.postprocess(preds)
+
+            self.update_metrics(preds, batch)
+            if self.args.plots and batch_i < 3:
+                self.plot_val_samples(batch, batch_i)
+                self.plot_predictions(batch, preds, batch_i)
+
+            self.run_callbacks("on_val_batch_end")
+        stats = self.get_stats()
+        self.check_stats(stats)
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+        self.finalize_metrics()
+        self.print_results()
+        self.run_callbacks("on_val_end")
+        if self.training:
+            model.float()
+            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
+        else:
+            LOGGER.info(
+                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
+                    *tuple(self.speed.values())
+                )
+            )
+            if self.args.save_json and self.jdict:
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
+                    json.dump(self.jdict, f)  # flatten and save
+                stats = self.eval_json(stats)  # update stats
+            if self.args.plots or self.args.save_json:
+                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+            return stats
+
+    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
+        """
+        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
+
+        Args:
+            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
+            true_classes (torch.Tensor): Target class indices of shape(M,).
+            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
+            use_scipy (bool): Whether to use scipy for matching (more precise).
+
+        Returns:
+            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
+        """
+        # Dx10 matrix, where D - detections, 10 - IoU thresholds
+        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
+        # LxD matrix where L - labels (rows), D - detections (columns)
+        correct_class = true_classes[:, None] == pred_classes
+        iou = iou * correct_class  # zero out the wrong classes
+        iou = iou.cpu().numpy()
+        for i, threshold in enumerate(self.iouv.cpu().tolist()):
+            if use_scipy:
+                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
+                import scipy  # scope import to avoid importing for all commands
+
+                cost_matrix = iou * (iou >= threshold)
+                if cost_matrix.any():
+                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
+                    valid = cost_matrix[labels_idx, detections_idx] > 0
+                    if valid.any():
+                        correct[detections_idx[valid], i] = True
+            else:
+                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
+                matches = np.array(matches).T
+                if matches.shape[0]:
+                    if matches.shape[0] > 1:
+                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                        # matches = matches[matches[:, 2].argsort()[::-1]]
+                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+                    correct[matches[:, 1].astype(int), i] = True
+        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
+
+    def add_callback(self, event: str, callback):
+        """Appends the given callback."""
+        self.callbacks[event].append(callback)
+
+    def run_callbacks(self, event: str):
+        """Runs all callbacks associated with a specified event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Get data loader from dataset path and batch size."""
+        raise NotImplementedError("get_dataloader function not implemented for this validator")
+
+    def build_dataset(self, img_path):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in validator")
+
+    def preprocess(self, batch):
+        """Preprocesses an input batch."""
+        return batch
+
+    def postprocess(self, preds):
+        """Preprocesses the predictions."""
+        return preds
+
+    def init_metrics(self, model):
+        """Initialize performance metrics for the YOLO model."""
+        pass
+
+    def update_metrics(self, preds, batch):
+        """Updates metrics based on predictions and batch."""
+        pass
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes and returns all metrics."""
+        pass
+
+    def get_stats(self):
+        """Returns statistics about the model's performance."""
+        return {}
+
+    def check_stats(self, stats):
+        """Checks statistics."""
+        pass
+
+    def print_results(self):
+        """Prints the results of the model's predictions."""
+        pass
+
+    def get_desc(self):
+        """Get description of the YOLO model."""
+        pass
+
+    @property
+    def metric_keys(self):
+        """Returns the metric keys used in YOLO training/validation."""
+        return []
+
+    def on_plot(self, name, data=None):
+        """Registers plots (e.g. to be consumed in callbacks)."""
+        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
+
+    # TODO: may need to put these following functions into callback
+    def plot_val_samples(self, batch, ni):
+        """Plots validation samples during training."""
+        pass
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots YOLO model predictions on batch images."""
+        pass
+
+    def pred_to_json(self, preds, batch):
+        """Convert predictions to JSON format."""
+        pass
+
+    def eval_json(self, stats):
+        """Evaluate and return JSON format of prediction statistics."""
+        pass