diff --git a/src/Poe2Trade.Bot/BotOrchestrator.cs b/src/Poe2Trade.Bot/BotOrchestrator.cs index 7c60b2c..8def5e6 100644 --- a/src/Poe2Trade.Bot/BotOrchestrator.cs +++ b/src/Poe2Trade.Bot/BotOrchestrator.cs @@ -91,11 +91,36 @@ public class BotOrchestrator : IAsyncDisposable BossRunExecutor = new BossRunExecutor(game, screen, inventory, logWatcher, store.Settings, BossDetector); - logWatcher.AreaEntered += _ => Navigation.Reset(); + logWatcher.AreaEntered += area => + { + Navigation.Reset(); + OnAreaEntered(area); + }; logWatcher.Start(); // start early so area events fire even before Bot.Start() _paused = store.Settings.Paused; } + // Boss zones → boss name mapping + private static readonly Dictionary BossZones = new(StringComparer.OrdinalIgnoreCase) + { + ["The Black Cathedral"] = "kulemak", + }; + + private void OnAreaEntered(string area) + { + if (BossZones.TryGetValue(area, out var boss)) + { + BossDetector.SetBoss(boss); + BossDetector.Enabled = true; + Log.Information("Boss zone detected: {Area} → enabling {Boss} detector", area, boss); + } + else if (BossDetector.Enabled) + { + BossDetector.Enabled = false; + Log.Information("Left boss zone → disabling boss detector"); + } + } + public bool IsReady => _started; public bool IsPaused => _paused; diff --git a/src/Poe2Trade.Screen/BossDetector.cs b/src/Poe2Trade.Screen/BossDetector.cs index 6ea8fd4..22f353b 100644 --- a/src/Poe2Trade.Screen/BossDetector.cs +++ b/src/Poe2Trade.Screen/BossDetector.cs @@ -1,78 +1,228 @@ -using Poe2Trade.Core; +using OpenCvSharp; using Serilog; using Region = Poe2Trade.Core.Region; namespace Poe2Trade.Screen; +/// +/// Detects bosses using YOLO running on a background thread. +/// Process() feeds frames, YOLO updates Latest with authoritative positions. +/// At ~26ms inference, YOLO runs fast enough to not need template tracking. +/// public class BossDetector : IFrameConsumer, IDisposable { - private const int DetectEveryNFrames = 6; private const int MinConsecutiveFrames = 2; + private const string ModelsDir = "tools/python-detect/models"; - private readonly PythonDetectBridge _bridge = new(); - private volatile BossSnapshot _latest = new([], 0, 0); - private int _frameCounter; - private int _consecutiveDetections; + private OnnxYoloDetector? _detector; private string _modelName = "boss-kulemak"; + private string _bossName = "kulemak"; + private volatile BossSnapshot _latest = new([], 0, 0); + private BossSnapshot _previous = new([], 0, 0); + private int _consecutiveDetections; + private int _inferenceCount; + + // Async frame-slot: Process() drops frame here, background loop runs YOLO + private volatile Mat? _pendingFrame; + private readonly ManualResetEventSlim _frameReady = new(false); + private Task? _inferenceLoop; + private CancellationTokenSource? _cts; + private bool _enabled; + + public bool Enabled + { + get => _enabled; + set + { + if (_enabled == value) return; + _enabled = value; + if (value) + StartLoop(); + else + StopLoop(); + } + } - public bool Enabled { get; set; } public BossSnapshot Latest => _latest; public event Action? BossDetected; public void SetBoss(string bossName) { + _bossName = bossName; _modelName = $"boss-{bossName}"; _consecutiveDetections = 0; + _inferenceCount = 0; + + if (_enabled) + { + StopLoop(); + StartLoop(); + } } + /// + /// Called by FramePipeline every frame. + /// Feeds the latest frame to YOLO (drops older pending frames). + /// public void Process(ScreenFrame frame) { - if (!Enabled) return; - if (++_frameCounter % DetectEveryNFrames != 0) return; + if (!_enabled) return; + + var fullRegion = new Region(0, 0, frame.Width, frame.Height); + using var bgr = frame.CropBgr(fullRegion); + var clone = bgr.Clone(); + var old = Interlocked.Exchange(ref _pendingFrame, clone); + old?.Dispose(); + _frameReady.Set(); + } + + private static Rect ClampRect(Rect r, int maxW, int maxH) + { + int x = Math.Max(0, r.X); + int y = Math.Max(0, r.Y); + int w = Math.Min(r.Width, maxW - x); + int h = Math.Min(r.Height, maxH - y); + return new Rect(x, y, Math.Max(0, w), Math.Max(0, h)); + } + + // ── YOLO inference (background thread) ────────────────────── + + private void StartLoop() + { + StopLoop(); + + var modelPath = Path.GetFullPath(Path.Combine(ModelsDir, $"{_modelName}.onnx")); + if (!File.Exists(modelPath)) + { + Log.Error("BossDetector: ONNX model not found at {Path}", modelPath); + return; + } try { - // Use full frame — model was trained on full 2560x1440 screenshots - var fullRegion = new Region(0, 0, frame.Width, frame.Height); - using var bgr = frame.CropBgr(fullRegion); - var result = _bridge.Detect(bgr, conf: 0.60f, imgsz: 1280, model: _modelName); - - var bosses = new List(result.Count); - foreach (var det in result.Detections) - { - bosses.Add(new DetectedBoss( - det.ClassName, - det.Confidence, - det.X, - det.Y, - det.Width, - det.Height, - det.Cx, - det.Cy)); - } - - var snapshot = new BossSnapshot( - bosses.AsReadOnly(), - DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), - result.InferenceMs); - - _latest = snapshot; - if (bosses.Count > 0) - { - _consecutiveDetections++; - if (_consecutiveDetections >= MinConsecutiveFrames) - BossDetected?.Invoke(snapshot); - } - else - { - _consecutiveDetections = 0; - } + _detector = new OnnxYoloDetector(modelPath, [_bossName], confThreshold: 0.40f); } catch (Exception ex) { - Log.Debug(ex, "BossDetector YOLO failed"); + Log.Error(ex, "BossDetector: failed to load ONNX model {Path}", modelPath); + return; } + + _cts = new CancellationTokenSource(); + _inferenceLoop = Task.Factory.StartNew( + () => InferenceLoop(_cts.Token), + _cts.Token, + TaskCreationOptions.LongRunning, + TaskScheduler.Default).Unwrap(); + + Log.Information("BossDetector: started inference loop with {Model}", _modelName); } - public void Dispose() => _bridge.Dispose(); + private void StopLoop() + { + if (_cts != null) + { + _cts.Cancel(); + _frameReady.Set(); + try { _inferenceLoop?.Wait(TimeSpan.FromSeconds(3)); } catch { /* timeout ok */ } + _cts.Dispose(); + _cts = null; + _inferenceLoop = null; + } + + _detector?.Dispose(); + _detector = null; + + var old = Interlocked.Exchange(ref _pendingFrame, null); + old?.Dispose(); + _frameReady.Reset(); + _consecutiveDetections = 0; + } + + private async Task InferenceLoop(CancellationToken ct) + { + Log.Information("BossDetector: inference loop started"); + + while (!ct.IsCancellationRequested) + { + try + { + _frameReady.Wait(ct); + _frameReady.Reset(); + + var frame = Interlocked.Exchange(ref _pendingFrame, null); + if (frame == null) continue; + + try + { + var (detections, totalMs, preMs, infMs) = _detector!.Detect(frame); + _inferenceCount++; + + if (_inferenceCount % 15 == 0) + Log.Information("BossDetect: {Count} hits, total={Total:F0}ms (pre={Pre:F0} inf={Inf:F0}), consecutive={Consec}", + detections.Count, totalMs, preMs, infMs, _consecutiveDetections); + + if (detections.Count > 0) + { + _consecutiveDetections++; + + var timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + var deltaMs = (float)(timestamp - _previous.Timestamp); + + var enriched = new List(detections.Count); + foreach (var det in detections) + { + float vx = 0, vy = 0; + if (deltaMs > 0 && _previous.Bosses.Count > 0) + { + var prev = _previous.Bosses.FirstOrDefault(b => b.ClassName == det.ClassName); + if (prev != null) + { + vx = (det.Cx - prev.Cx) / deltaMs; + vy = (det.Cy - prev.Cy) / deltaMs; + } + } + enriched.Add(det with { VxPerMs = vx, VyPerMs = vy }); + } + + var snapshot = new BossSnapshot( + enriched.AsReadOnly(), + timestamp, + totalMs); + _latest = snapshot; + _previous = snapshot; + + if (_consecutiveDetections >= MinConsecutiveFrames) + BossDetected?.Invoke(snapshot); + } + else + { + _consecutiveDetections = 0; + _latest = new BossSnapshot([], DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), 0); + } + } + finally + { + frame.Dispose(); + } + } + catch (OperationCanceledException) + { + break; + } + catch (Exception ex) + { + Log.Warning(ex, "BossDetector inference failed"); + await Task.Delay(100, ct); + } + } + + Log.Information("BossDetector: inference loop stopped"); + } + + public void Dispose() + { + StopLoop(); + _frameReady.Dispose(); + } } diff --git a/src/Poe2Trade.Screen/DetectionTypes.cs b/src/Poe2Trade.Screen/DetectionTypes.cs index 6bb9b9f..3d05050 100644 --- a/src/Poe2Trade.Screen/DetectionTypes.cs +++ b/src/Poe2Trade.Screen/DetectionTypes.cs @@ -15,7 +15,8 @@ public record DetectedBoss( string ClassName, float Confidence, int X, int Y, int Width, int Height, - int Cx, int Cy); + int Cx, int Cy, + float VxPerMs = 0, float VyPerMs = 0); public record BossSnapshot( IReadOnlyList Bosses, diff --git a/src/Poe2Trade.Screen/OnnxYoloDetector.cs b/src/Poe2Trade.Screen/OnnxYoloDetector.cs new file mode 100644 index 0000000..006ce6f --- /dev/null +++ b/src/Poe2Trade.Screen/OnnxYoloDetector.cs @@ -0,0 +1,227 @@ +using System.Runtime.InteropServices; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; +using OpenCvSharp; +using OpenCvSharp.Dnn; +using Serilog; + +namespace Poe2Trade.Screen; + +/// +/// YOLO11 object detection via ONNX Runtime with CUDA GPU acceleration. +/// Handles letterbox preprocessing, inference, and NMS postprocessing. +/// Buffers are pooled to avoid LOH allocations that trigger Gen2 GC pauses. +/// +public class OnnxYoloDetector : IDisposable +{ + private readonly InferenceSession _session; + private readonly string[] _classNames; + private readonly int _imgSize; + private readonly float _confThreshold; + private readonly float _iouThreshold; + private readonly string _inputName; + private bool _warmedUp; + + // Pooled buffers — allocated once, reused every inference (avoids LOH/GC pressure) + private readonly float[] _tensorBuffer; // 3 * imgSize * imgSize (~1.2MB for 640) + private float[]? _outputBuffer; // rowSize * numDetections, sized on first use + + // Pre-allocated Mats for preprocessing (reused every inference — avoids alloc/GC per frame) + private readonly Mat _resized = new(); + private readonly Mat _padded; + private readonly Mat _rgb = new(); + private readonly Mat _floatMat = new(); + + public OnnxYoloDetector(string modelPath, string[] classNames, + float confThreshold = 0.40f, float iouThreshold = 0.45f) + { + _classNames = classNames; + _confThreshold = confThreshold; + _iouThreshold = iouThreshold; + + var opts = new SessionOptions(); + opts.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL; + opts.InterOpNumThreads = 1; // single model, no inter-op parallelism needed + opts.IntraOpNumThreads = Environment.ProcessorCount / 2; // use half the cores (leave room for game + pipeline) + opts.ExecutionMode = ExecutionMode.ORT_SEQUENTIAL; // sequential is faster for single inference + // CPU EP — avoids GPU contention with DXGI screen capture + Log.Information("OnnxYolo: using CPU EP, intra threads={Threads}", opts.IntraOpNumThreads); + + _session = new InferenceSession(modelPath, opts); + _inputName = _session.InputNames[0]; + + // Read imgSize from the model's input shape (NCHW: [1, 3, H, W]) + var inputMeta = _session.InputMetadata[_inputName]; + _imgSize = inputMeta.Dimensions[2]; // H == W for square YOLO input + + _tensorBuffer = new float[3 * _imgSize * _imgSize]; + _padded = new Mat(_imgSize, _imgSize, MatType.CV_8UC3, new Scalar(114, 114, 114)); + + Log.Information("OnnxYolo: loaded {Path} (input: {Input}, imgSize: {ImgSize})", + modelPath, _inputName, _imgSize); + } + + /// + /// Run detection on a BGR Mat. Returns detected bosses in original image coordinates. + /// + public (List Detections, float TotalMs, float PreMs, float InfMs) Detect(Mat bgrMat) + { + var swTotal = System.Diagnostics.Stopwatch.StartNew(); + + // 1. Letterbox preprocess (reuses _tensorBuffer) + var swPre = System.Diagnostics.Stopwatch.StartNew(); + var (tensor, scale, padX, padY) = Preprocess(bgrMat); + swPre.Stop(); + + // 2. Run inference + var swInf = System.Diagnostics.Stopwatch.StartNew(); + var inputs = new List + { + NamedOnnxValue.CreateFromTensor(_inputName, tensor) + }; + using var results = _session.Run(inputs); + swInf.Stop(); + + // 3. Parse output (reuses _outputBuffer) + var outputTensor = results.First().AsTensor(); + var detections = Postprocess(outputTensor, scale, padX, padY, bgrMat.Width, bgrMat.Height); + + swTotal.Stop(); + var totalMs = (float)swTotal.Elapsed.TotalMilliseconds; + + if (!_warmedUp) + { + _warmedUp = true; + Log.Information("OnnxYolo warmup: pre={Pre:F0}ms inf={Inf:F0}ms total={Total:F0}ms", + swPre.Elapsed.TotalMilliseconds, swInf.Elapsed.TotalMilliseconds, totalMs); + } + + return (detections, totalMs, (float)swPre.Elapsed.TotalMilliseconds, (float)swInf.Elapsed.TotalMilliseconds); + } + + private (DenseTensor tensor, float scale, int padX, int padY) Preprocess(Mat bgrMat) + { + int origW = bgrMat.Width, origH = bgrMat.Height; + + float scale = Math.Min((float)_imgSize / origW, (float)_imgSize / origH); + int newW = (int)Math.Round(origW * scale); + int newH = (int)Math.Round(origH * scale); + + int padX = (_imgSize - newW) / 2; + int padY = (_imgSize - newH) / 2; + + Cv2.Resize(bgrMat, _resized, new Size(newW, newH), interpolation: InterpolationFlags.Linear); + + _padded.SetTo(new Scalar(114, 114, 114)); + _resized.CopyTo(_padded[new Rect(padX, padY, newW, newH)]); + + Cv2.CvtColor(_padded, _rgb, ColorConversionCodes.BGR2RGB); + + _rgb.ConvertTo(_floatMat, MatType.CV_32FC3, 1.0 / 255.0); + + // HWC → NCHW via channel split + Marshal.Copy into pooled buffer + int pixels = _imgSize * _imgSize; + Cv2.Split(_floatMat, out Mat[] channels); + try + { + for (int c = 0; c < 3; c++) + Marshal.Copy(channels[c].Data, _tensorBuffer, c * pixels, pixels); + } + finally + { + foreach (var ch in channels) ch.Dispose(); + } + + // Wrap pooled buffer in tensor (no copy — DenseTensor references the array) + var tensor = new DenseTensor(_tensorBuffer, [1, 3, _imgSize, _imgSize]); + return (tensor, scale, padX, padY); + } + + private List Postprocess(Tensor output, float scale, + int padX, int padY, int origW, int origH) + { + int numClasses = _classNames.Length; + int numDetections = output.Dimensions[2]; + int rowSize = output.Dimensions[1]; // 4 + nc + int flatSize = rowSize * numDetections; + + // Reuse output buffer (resize only if model output shape changed) + if (_outputBuffer == null || _outputBuffer.Length < flatSize) + _outputBuffer = new float[flatSize]; + + if (output is DenseTensor dense) + dense.Buffer.Span.CopyTo(_outputBuffer); + else + for (int r = 0; r < rowSize; r++) + for (int i = 0; i < numDetections; i++) + _outputBuffer[r * numDetections + i] = output[0, r, i]; + + var boxes = new List(); + var confidences = new List(); + var classIds = new List(); + + for (int i = 0; i < numDetections; i++) + { + float bestConf = 0; + int bestClass = 0; + for (int c = 0; c < numClasses; c++) + { + float conf = _outputBuffer[(4 + c) * numDetections + i]; + if (conf > bestConf) + { + bestConf = conf; + bestClass = c; + } + } + + if (bestConf < _confThreshold) continue; + + float cx = _outputBuffer[0 * numDetections + i]; + float cy = _outputBuffer[1 * numDetections + i]; + float w = _outputBuffer[2 * numDetections + i]; + float h = _outputBuffer[3 * numDetections + i]; + + float x1 = (cx - w / 2 - padX) / scale; + float y1 = (cy - h / 2 - padY) / scale; + float bw = w / scale; + float bh = h / scale; + + x1 = Math.Max(0, x1); + y1 = Math.Max(0, y1); + bw = Math.Min(bw, origW - x1); + bh = Math.Min(bh, origH - y1); + + boxes.Add(new Rect((int)x1, (int)y1, (int)bw, (int)bh)); + confidences.Add(bestConf); + classIds.Add(bestClass); + } + + if (boxes.Count == 0) + return []; + + CvDnn.NMSBoxes(boxes, confidences, _confThreshold, _iouThreshold, out int[] indices); + + var detections = new List(indices.Length); + foreach (var idx in indices) + { + var box = boxes[idx]; + detections.Add(new DetectedBoss( + _classNames[classIds[idx]], + confidences[idx], + box.X, box.Y, box.Width, box.Height, + box.X + box.Width / 2, + box.Y + box.Height / 2)); + } + + return detections; + } + + public void Dispose() + { + _session.Dispose(); + _resized.Dispose(); + _padded.Dispose(); + _rgb.Dispose(); + _floatMat.Dispose(); + } +} diff --git a/src/Poe2Trade.Screen/Poe2Trade.Screen.csproj b/src/Poe2Trade.Screen/Poe2Trade.Screen.csproj index 65b90fb..bcfbcf4 100644 --- a/src/Poe2Trade.Screen/Poe2Trade.Screen.csproj +++ b/src/Poe2Trade.Screen/Poe2Trade.Screen.csproj @@ -11,6 +11,7 @@ + diff --git a/src/Poe2Trade.Screen/PythonDetectBridge.cs b/src/Poe2Trade.Screen/PythonDetectBridge.cs index c262cf1..c97c016 100644 --- a/src/Poe2Trade.Screen/PythonDetectBridge.cs +++ b/src/Poe2Trade.Screen/PythonDetectBridge.cs @@ -39,7 +39,7 @@ class PythonDetectBridge : IDisposable { EnsureRunning(); - var imageBytes = bgrMat.ToBytes(".png"); + var imageBytes = bgrMat.ToBytes(".jpg", [(int)ImwriteFlags.JpegQuality, 95]); var imageBase64 = Convert.ToBase64String(imageBytes); var req = new Dictionary diff --git a/src/Poe2Trade.Ui/Overlay/D2dOverlay.cs b/src/Poe2Trade.Ui/Overlay/D2dOverlay.cs index dd11c49..ee7bab2 100644 --- a/src/Poe2Trade.Ui/Overlay/D2dOverlay.cs +++ b/src/Poe2Trade.Ui/Overlay/D2dOverlay.cs @@ -185,6 +185,7 @@ public sealed class D2dOverlay return new OverlayState( Enemies: detection.Enemies, Bosses: bossDetection.Bosses, + BossTimestampMs: bossDetection.Timestamp, InferenceMs: detection.InferenceMs, Hud: _bot.HudReader.Current, NavState: _bot.Navigation.State, diff --git a/src/Poe2Trade.Ui/Overlay/IOverlayLayer.cs b/src/Poe2Trade.Ui/Overlay/IOverlayLayer.cs index 402742a..62b579a 100644 --- a/src/Poe2Trade.Ui/Overlay/IOverlayLayer.cs +++ b/src/Poe2Trade.Ui/Overlay/IOverlayLayer.cs @@ -6,6 +6,7 @@ namespace Poe2Trade.Ui.Overlay; public record OverlayState( IReadOnlyList Enemies, IReadOnlyList Bosses, + long BossTimestampMs, float InferenceMs, HudSnapshot? Hud, NavigationState NavState, diff --git a/src/Poe2Trade.Ui/Overlay/Layers/D2dEnemyBoxLayer.cs b/src/Poe2Trade.Ui/Overlay/Layers/D2dEnemyBoxLayer.cs index 48ea2c0..7a5513d 100644 --- a/src/Poe2Trade.Ui/Overlay/Layers/D2dEnemyBoxLayer.cs +++ b/src/Poe2Trade.Ui/Overlay/Layers/D2dEnemyBoxLayer.cs @@ -53,10 +53,15 @@ internal sealed class D2dEnemyBoxLayer : ID2dOverlayLayer, IDisposable rt.DrawTextLayout(new System.Numerics.Vector2(labelX, labelY), layout, textBrush); } - // Boss bounding boxes (cyan) + // Boss bounding boxes (cyan) — extrapolate position to compensate for inference delay + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + var ageMs = (float)Math.Clamp(now - state.BossTimestampMs, 0, 60); + foreach (var boss in state.Bosses) { - var rect = new RectangleF(boss.X, boss.Y, boss.Width, boss.Height); + float dx = boss.VxPerMs * ageMs; + float dy = boss.VyPerMs * ageMs; + var rect = new RectangleF(boss.X + dx, boss.Y + dy, boss.Width, boss.Height); rt.DrawRectangle(rect, ctx.Cyan, 3f); var pct = Math.Clamp((int)(boss.Confidence * 100), 0, 100); @@ -68,8 +73,8 @@ internal sealed class D2dEnemyBoxLayer : ID2dOverlayLayer, IDisposable } var m = layout.Metrics; - var labelX = boss.X; - var labelY = boss.Y - m.Height - 2; + var labelX = boss.X + dx; + var labelY = boss.Y + dy - m.Height - 2; rt.FillRectangle( new RectangleF(labelX - 1, labelY - 1, m.Width + 2, m.Height + 2), diff --git a/tools/python-detect/manage.py b/tools/python-detect/manage.py index 37112ed..a8e190f 100644 --- a/tools/python-detect/manage.py +++ b/tools/python-detect/manage.py @@ -7,6 +7,7 @@ Subcommands (all take a positional boss name): runs kulemak List training runs + metrics table annotate kulemak [dir] Launch annotation GUI prelabel kulemak [dir] [--model boss-kulemak] Auto-label unlabeled images + export kulemak [--imgsz 640] Export .pt to ONNX format """ import argparse @@ -291,6 +292,32 @@ def cmd_prelabel(args): _PRELABEL_MODEL_DEFAULT = "__auto__" +# ── export ────────────────────────────────────────────────────── +def cmd_export(args): + """Export .pt model to .onnx format for ONNX Runtime inference.""" + boss = args.boss + model_name = f"boss-{boss}" + pt_path = os.path.join(MODELS_DIR, f"{model_name}.pt") + + if not os.path.exists(pt_path): + print(f"Model not found: {pt_path}") + return + + from ultralytics import YOLO + model = YOLO(pt_path) + + print(f"Exporting {pt_path} -> ONNX (imgsz={args.imgsz})...") + model.export(format="onnx", imgsz=args.imgsz, opset=17, simplify=True, dynamic=False) + + # ultralytics writes the .onnx next to the .pt file + onnx_src = os.path.join(MODELS_DIR, f"{model_name}.onnx") + if os.path.exists(onnx_src): + size_mb = os.path.getsize(onnx_src) / (1024 * 1024) + print(f"\nExported: {onnx_src} ({size_mb:.1f} MB)") + else: + print(f"\nWarning: expected output not found at {onnx_src}") + + # ── CLI ────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser( @@ -315,10 +342,10 @@ def main(): p = sub.add_parser("train", help="Train YOLO model") p.add_argument("boss", help="Boss name (e.g. kulemak)") p.add_argument("--data", default=None, help="Path to data.yaml") - p.add_argument("--model", default="yolo11s", help="YOLO model variant") + p.add_argument("--model", default="yolo11n", help="YOLO model variant") p.add_argument("--epochs", type=int, default=200, help="Training epochs") - p.add_argument("--imgsz", type=int, default=1280, help="Image size") - p.add_argument("--batch", type=int, default=8, help="Batch size") + p.add_argument("--imgsz", type=int, default=640, help="Image size") + p.add_argument("--batch", type=int, default=16, help="Batch size") p.add_argument("--device", default="0", help="CUDA device") p.add_argument("--name", default=None, help="Run name (auto-increments if omitted)") @@ -333,6 +360,11 @@ def main(): p.add_argument("--model", default=_PRELABEL_MODEL_DEFAULT, help="Model name in models/ (default: boss-{boss})") p.add_argument("--conf", type=float, default=0.20, help="Confidence threshold") + # export + p = sub.add_parser("export", help="Export .pt model to ONNX format") + p.add_argument("boss", help="Boss name (e.g. kulemak)") + p.add_argument("--imgsz", type=int, default=640, help="Image size for export") + args = parser.parse_args() if args.command is None: @@ -345,6 +377,7 @@ def main(): "train": cmd_train, "runs": cmd_runs, "prelabel": cmd_prelabel, + "export": cmd_export, } commands[args.command](args)