got detection somewhat working

This commit is contained in:
Boki 2026-02-20 21:57:51 -05:00
parent 40d30115bf
commit c75b2b27f0
10 changed files with 500 additions and 56 deletions

View file

@ -91,11 +91,36 @@ public class BotOrchestrator : IAsyncDisposable
BossRunExecutor = new BossRunExecutor(game, screen, inventory, logWatcher, store.Settings, BossDetector); BossRunExecutor = new BossRunExecutor(game, screen, inventory, logWatcher, store.Settings, BossDetector);
logWatcher.AreaEntered += _ => Navigation.Reset(); logWatcher.AreaEntered += area =>
{
Navigation.Reset();
OnAreaEntered(area);
};
logWatcher.Start(); // start early so area events fire even before Bot.Start() logWatcher.Start(); // start early so area events fire even before Bot.Start()
_paused = store.Settings.Paused; _paused = store.Settings.Paused;
} }
// Boss zones → boss name mapping
private static readonly Dictionary<string, string> BossZones = new(StringComparer.OrdinalIgnoreCase)
{
["The Black Cathedral"] = "kulemak",
};
private void OnAreaEntered(string area)
{
    // Entering a mapped boss zone arms the detector for that boss;
    // entering any other area disarms it if it was running.
    if (BossZones.TryGetValue(area, out var bossName))
    {
        BossDetector.SetBoss(bossName);
        BossDetector.Enabled = true;
        Log.Information("Boss zone detected: {Area} → enabling {Boss} detector", area, bossName);
        return;
    }

    if (BossDetector.Enabled)
    {
        BossDetector.Enabled = false;
        Log.Information("Left boss zone → disabling boss detector");
    }
}
public bool IsReady => _started; public bool IsReady => _started;
public bool IsPaused => _paused; public bool IsPaused => _paused;

View file

@ -1,78 +1,228 @@
using Poe2Trade.Core; using OpenCvSharp;
using Serilog; using Serilog;
using Region = Poe2Trade.Core.Region; using Region = Poe2Trade.Core.Region;
namespace Poe2Trade.Screen; namespace Poe2Trade.Screen;
/// <summary>
/// Detects bosses using YOLO running on a background thread.
/// Process() feeds frames, YOLO updates Latest with authoritative positions.
/// At ~26ms inference, YOLO runs fast enough to not need template tracking.
/// </summary>
public class BossDetector : IFrameConsumer, IDisposable public class BossDetector : IFrameConsumer, IDisposable
{ {
private const int DetectEveryNFrames = 6;
private const int MinConsecutiveFrames = 2; private const int MinConsecutiveFrames = 2;
private const string ModelsDir = "tools/python-detect/models";
private readonly PythonDetectBridge _bridge = new(); private OnnxYoloDetector? _detector;
private volatile BossSnapshot _latest = new([], 0, 0);
private int _frameCounter;
private int _consecutiveDetections;
private string _modelName = "boss-kulemak"; private string _modelName = "boss-kulemak";
private string _bossName = "kulemak";
private volatile BossSnapshot _latest = new([], 0, 0);
private BossSnapshot _previous = new([], 0, 0);
private int _consecutiveDetections;
private int _inferenceCount;
// Async frame-slot: Process() drops frame here, background loop runs YOLO
private volatile Mat? _pendingFrame;
private readonly ManualResetEventSlim _frameReady = new(false);
private Task? _inferenceLoop;
private CancellationTokenSource? _cts;
private bool _enabled;
// Enabling starts the background inference loop; disabling stops it and
// releases the detector/frame resources. Setting the same value is a no-op.
public bool Enabled
{
get => _enabled;
set
{
if (_enabled == value) return;
_enabled = value;
if (value)
StartLoop();
else
StopLoop();
}
}
public bool Enabled { get; set; }
public BossSnapshot Latest => _latest; public BossSnapshot Latest => _latest;
public event Action<BossSnapshot>? BossDetected; public event Action<BossSnapshot>? BossDetected;
public void SetBoss(string bossName) public void SetBoss(string bossName)
{ {
_bossName = bossName;
_modelName = $"boss-{bossName}"; _modelName = $"boss-{bossName}";
_consecutiveDetections = 0; _consecutiveDetections = 0;
_inferenceCount = 0;
if (_enabled)
{
StopLoop();
StartLoop();
}
} }
/// <summary>
/// Called by FramePipeline every frame.
/// Feeds the latest frame to YOLO (drops older pending frames).
/// </summary>
public void Process(ScreenFrame frame) public void Process(ScreenFrame frame)
{ {
if (!Enabled) return; if (!_enabled) return;
if (++_frameCounter % DetectEveryNFrames != 0) return;
var fullRegion = new Region(0, 0, frame.Width, frame.Height);
using var bgr = frame.CropBgr(fullRegion);
var clone = bgr.Clone();
var old = Interlocked.Exchange(ref _pendingFrame, clone);
old?.Dispose();
_frameReady.Set();
}
/// <summary>
/// Clamps a rectangle to the bounds [0, maxW) x [0, maxH). The result never
/// extends past the original rectangle's edges; degenerate results collapse
/// to zero width/height rather than going negative.
/// </summary>
private static Rect ClampRect(Rect r, int maxW, int maxH)
{
    int x = Math.Max(0, r.X);
    int y = Math.Max(0, r.Y);
    // Clamp the far edge first, then derive the size from the clamped origin.
    // (Clamping the origin alone would silently extend the right/bottom edge
    // for rects that start at negative coordinates.)
    int w = Math.Min(r.X + r.Width, maxW) - x;
    int h = Math.Min(r.Y + r.Height, maxH) - y;
    return new Rect(x, y, Math.Max(0, w), Math.Max(0, h));
}
// ── YOLO inference (background thread) ──────────────────────
private void StartLoop()
{
StopLoop();
var modelPath = Path.GetFullPath(Path.Combine(ModelsDir, $"{_modelName}.onnx"));
if (!File.Exists(modelPath))
{
Log.Error("BossDetector: ONNX model not found at {Path}", modelPath);
return;
}
try try
{ {
// Use full frame — model was trained on full 2560x1440 screenshots _detector = new OnnxYoloDetector(modelPath, [_bossName], confThreshold: 0.40f);
var fullRegion = new Region(0, 0, frame.Width, frame.Height); }
using var bgr = frame.CropBgr(fullRegion); catch (Exception ex)
var result = _bridge.Detect(bgr, conf: 0.60f, imgsz: 1280, model: _modelName);
var bosses = new List<DetectedBoss>(result.Count);
foreach (var det in result.Detections)
{ {
bosses.Add(new DetectedBoss( Log.Error(ex, "BossDetector: failed to load ONNX model {Path}", modelPath);
det.ClassName, return;
det.Confidence, }
det.X,
det.Y, _cts = new CancellationTokenSource();
det.Width, _inferenceLoop = Task.Factory.StartNew(
det.Height, () => InferenceLoop(_cts.Token),
det.Cx, _cts.Token,
det.Cy)); TaskCreationOptions.LongRunning,
TaskScheduler.Default).Unwrap();
Log.Information("BossDetector: started inference loop with {Model}", _modelName);
}
private void StopLoop()
{
    // Ask the background loop (if any) to exit and give it a moment to drain.
    if (_cts is not null)
    {
        _cts.Cancel();
        _frameReady.Set(); // wake the loop so it observes the cancellation
        try
        {
            _inferenceLoop?.Wait(TimeSpan.FromSeconds(3));
        }
        catch { /* timeout ok */ }
        _cts.Dispose();
        _cts = null;
        _inferenceLoop = null;
    }

    // Release the model and whatever frame was still queued for inference.
    _detector?.Dispose();
    _detector = null;
    Interlocked.Exchange(ref _pendingFrame, null)?.Dispose();
    _frameReady.Reset();
    _consecutiveDetections = 0;
}
private async Task InferenceLoop(CancellationToken ct)
{
Log.Information("BossDetector: inference loop started");
while (!ct.IsCancellationRequested)
{
try
{
_frameReady.Wait(ct);
_frameReady.Reset();
var frame = Interlocked.Exchange(ref _pendingFrame, null);
if (frame == null) continue;
try
{
var (detections, totalMs, preMs, infMs) = _detector!.Detect(frame);
_inferenceCount++;
if (_inferenceCount % 15 == 0)
Log.Information("BossDetect: {Count} hits, total={Total:F0}ms (pre={Pre:F0} inf={Inf:F0}), consecutive={Consec}",
detections.Count, totalMs, preMs, infMs, _consecutiveDetections);
if (detections.Count > 0)
{
_consecutiveDetections++;
var timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
var deltaMs = (float)(timestamp - _previous.Timestamp);
var enriched = new List<DetectedBoss>(detections.Count);
foreach (var det in detections)
{
float vx = 0, vy = 0;
if (deltaMs > 0 && _previous.Bosses.Count > 0)
{
var prev = _previous.Bosses.FirstOrDefault(b => b.ClassName == det.ClassName);
if (prev != null)
{
vx = (det.Cx - prev.Cx) / deltaMs;
vy = (det.Cy - prev.Cy) / deltaMs;
}
}
enriched.Add(det with { VxPerMs = vx, VyPerMs = vy });
} }
var snapshot = new BossSnapshot( var snapshot = new BossSnapshot(
bosses.AsReadOnly(), enriched.AsReadOnly(),
DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), timestamp,
result.InferenceMs); totalMs);
_latest = snapshot; _latest = snapshot;
if (bosses.Count > 0) _previous = snapshot;
{
_consecutiveDetections++;
if (_consecutiveDetections >= MinConsecutiveFrames) if (_consecutiveDetections >= MinConsecutiveFrames)
BossDetected?.Invoke(snapshot); BossDetected?.Invoke(snapshot);
} }
else else
{ {
_consecutiveDetections = 0; _consecutiveDetections = 0;
_latest = new BossSnapshot([], DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), 0);
} }
} }
finally
{
frame.Dispose();
}
}
catch (OperationCanceledException)
{
break;
}
catch (Exception ex) catch (Exception ex)
{ {
Log.Debug(ex, "BossDetector YOLO failed"); Log.Warning(ex, "BossDetector inference failed");
await Task.Delay(100, ct);
} }
} }
public void Dispose() => _bridge.Dispose(); Log.Information("BossDetector: inference loop stopped");
}
public void Dispose()
{
StopLoop();
_frameReady.Dispose();
}
} }

View file

@ -15,7 +15,8 @@ public record DetectedBoss(
string ClassName, string ClassName,
float Confidence, float Confidence,
int X, int Y, int Width, int Height, int X, int Y, int Width, int Height,
int Cx, int Cy); int Cx, int Cy,
float VxPerMs = 0, float VyPerMs = 0);
public record BossSnapshot( public record BossSnapshot(
IReadOnlyList<DetectedBoss> Bosses, IReadOnlyList<DetectedBoss> Bosses,

View file

@ -0,0 +1,227 @@
using System.Runtime.InteropServices;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
using OpenCvSharp.Dnn;
using Serilog;
namespace Poe2Trade.Screen;
/// <summary>
/// YOLO11 object detection via ONNX Runtime on the CPU execution provider.
/// Handles letterbox preprocessing, inference, and NMS postprocessing.
/// Buffers and Mats are pooled to avoid LOH allocations that trigger Gen2 GC pauses.
/// NOT thread-safe: Detect() reuses pooled buffers — call it from one thread only.
/// </summary>
public class OnnxYoloDetector : IDisposable
{
    private readonly InferenceSession _session;
    private readonly string[] _classNames;
    private readonly int _imgSize;          // square model input side, read from the model itself
    private readonly float _confThreshold;
    private readonly float _iouThreshold;
    private readonly string _inputName;
    private bool _warmedUp;

    // Pooled buffers — allocated once, reused every inference (avoids LOH/GC pressure)
    private readonly float[] _tensorBuffer; // 3 * imgSize * imgSize (~1.2MB for 640)
    private float[]? _outputBuffer;         // rowSize * numDetections, sized on first use

    // Pre-allocated Mats for preprocessing (reused every inference — avoids alloc/GC per frame)
    private readonly Mat _resized = new();
    private readonly Mat _padded;
    private readonly Mat _rgb = new();
    private readonly Mat _floatMat = new();

    /// <summary>
    /// Loads the ONNX model and prepares pooled buffers.
    /// </summary>
    /// <param name="modelPath">Path to the .onnx model file.</param>
    /// <param name="classNames">Class names in model output order.</param>
    /// <param name="confThreshold">Minimum confidence to keep a detection.</param>
    /// <param name="iouThreshold">IoU threshold for non-maximum suppression.</param>
    public OnnxYoloDetector(string modelPath, string[] classNames,
        float confThreshold = 0.40f, float iouThreshold = 0.45f)
    {
        _classNames = classNames;
        _confThreshold = confThreshold;
        _iouThreshold = iouThreshold;

        // SessionOptions is IDisposable and only needed while the session is built.
        using var opts = new SessionOptions();
        opts.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
        opts.InterOpNumThreads = 1; // single model, no inter-op parallelism needed
        opts.IntraOpNumThreads = Environment.ProcessorCount / 2; // use half the cores (leave room for game + pipeline)
        opts.ExecutionMode = ExecutionMode.ORT_SEQUENTIAL; // sequential is faster for single inference

        // CPU EP — avoids GPU contention with DXGI screen capture
        Log.Information("OnnxYolo: using CPU EP, intra threads={Threads}", opts.IntraOpNumThreads);
        _session = new InferenceSession(modelPath, opts);
        _inputName = _session.InputNames[0];

        // Read imgSize from the model's input shape (NCHW: [1, 3, H, W])
        var inputMeta = _session.InputMetadata[_inputName];
        _imgSize = inputMeta.Dimensions[2]; // H == W for square YOLO input
        _tensorBuffer = new float[3 * _imgSize * _imgSize];
        _padded = new Mat(_imgSize, _imgSize, MatType.CV_8UC3, new Scalar(114, 114, 114));

        Log.Information("OnnxYolo: loaded {Path} (input: {Input}, imgSize: {ImgSize})",
            modelPath, _inputName, _imgSize);
    }

    /// <summary>
    /// Run detection on a BGR Mat. Returns detected bosses in original image
    /// coordinates plus a timing breakdown (total / preprocess / inference) in ms.
    /// </summary>
    public (List<DetectedBoss> Detections, float TotalMs, float PreMs, float InfMs) Detect(Mat bgrMat)
    {
        var swTotal = System.Diagnostics.Stopwatch.StartNew();

        // 1. Letterbox preprocess (reuses _tensorBuffer)
        var swPre = System.Diagnostics.Stopwatch.StartNew();
        var (tensor, scale, padX, padY) = Preprocess(bgrMat);
        swPre.Stop();

        // 2. Run inference
        var swInf = System.Diagnostics.Stopwatch.StartNew();
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(_inputName, tensor)
        };
        using var results = _session.Run(inputs);
        swInf.Stop();

        // 3. Parse output (reuses _outputBuffer)
        var outputTensor = results.First().AsTensor<float>();
        var detections = Postprocess(outputTensor, scale, padX, padY, bgrMat.Width, bgrMat.Height);

        swTotal.Stop();
        var totalMs = (float)swTotal.Elapsed.TotalMilliseconds;
        if (!_warmedUp)
        {
            // First call carries one-time graph/thread warmup cost — log it once.
            _warmedUp = true;
            Log.Information("OnnxYolo warmup: pre={Pre:F0}ms inf={Inf:F0}ms total={Total:F0}ms",
                swPre.Elapsed.TotalMilliseconds, swInf.Elapsed.TotalMilliseconds, totalMs);
        }
        return (detections, totalMs, (float)swPre.Elapsed.TotalMilliseconds, (float)swInf.Elapsed.TotalMilliseconds);
    }

    /// <summary>
    /// Letterbox the BGR frame into the model's square input: scale to fit,
    /// pad with gray (114,114,114), convert to RGB float NCHW in a pooled buffer.
    /// Returns the input tensor plus the scale/pad needed to map boxes back.
    /// </summary>
    private (DenseTensor<float> tensor, float scale, int padX, int padY) Preprocess(Mat bgrMat)
    {
        int origW = bgrMat.Width, origH = bgrMat.Height;
        float scale = Math.Min((float)_imgSize / origW, (float)_imgSize / origH);
        int newW = (int)Math.Round(origW * scale);
        int newH = (int)Math.Round(origH * scale);
        int padX = (_imgSize - newW) / 2;
        int padY = (_imgSize - newH) / 2;

        Cv2.Resize(bgrMat, _resized, new Size(newW, newH), interpolation: InterpolationFlags.Linear);
        _padded.SetTo(new Scalar(114, 114, 114)); // reset pad color — _padded is reused across calls
        _resized.CopyTo(_padded[new Rect(padX, padY, newW, newH)]);
        Cv2.CvtColor(_padded, _rgb, ColorConversionCodes.BGR2RGB);
        _rgb.ConvertTo(_floatMat, MatType.CV_32FC3, 1.0 / 255.0);

        // HWC → NCHW via channel split + Marshal.Copy into pooled buffer
        int pixels = _imgSize * _imgSize;
        Cv2.Split(_floatMat, out Mat[] channels);
        try
        {
            for (int c = 0; c < 3; c++)
                Marshal.Copy(channels[c].Data, _tensorBuffer, c * pixels, pixels);
        }
        finally
        {
            foreach (var ch in channels) ch.Dispose();
        }

        // Wrap pooled buffer in tensor (no copy — DenseTensor references the array)
        var tensor = new DenseTensor<float>(_tensorBuffer, [1, 3, _imgSize, _imgSize]);
        return (tensor, scale, padX, padY);
    }

    /// <summary>
    /// Decode the raw YOLO output ([1, 4+nc, N]), filter by confidence, undo the
    /// letterbox transform, and apply NMS. Boxes are clamped to the original image.
    /// </summary>
    private List<DetectedBoss> Postprocess(Tensor<float> output, float scale,
        int padX, int padY, int origW, int origH)
    {
        int numClasses = _classNames.Length;
        int numDetections = output.Dimensions[2];
        int rowSize = output.Dimensions[1]; // 4 + nc
        int flatSize = rowSize * numDetections;

        // Reuse output buffer (resize only if model output shape changed)
        if (_outputBuffer == null || _outputBuffer.Length < flatSize)
            _outputBuffer = new float[flatSize];

        if (output is DenseTensor<float> dense)
            dense.Buffer.Span.CopyTo(_outputBuffer);
        else
            for (int r = 0; r < rowSize; r++)
                for (int i = 0; i < numDetections; i++)
                    _outputBuffer[r * numDetections + i] = output[0, r, i];

        var boxes = new List<Rect>();
        var confidences = new List<float>();
        var classIds = new List<int>();

        for (int i = 0; i < numDetections; i++)
        {
            // Pick the best-scoring class for this candidate box.
            float bestConf = 0;
            int bestClass = 0;
            for (int c = 0; c < numClasses; c++)
            {
                float conf = _outputBuffer[(4 + c) * numDetections + i];
                if (conf > bestConf)
                {
                    bestConf = conf;
                    bestClass = c;
                }
            }
            if (bestConf < _confThreshold) continue;

            // cx/cy/w/h are in letterboxed-input coordinates; undo pad then scale.
            float cx = _outputBuffer[0 * numDetections + i];
            float cy = _outputBuffer[1 * numDetections + i];
            float w = _outputBuffer[2 * numDetections + i];
            float h = _outputBuffer[3 * numDetections + i];

            float x1 = (cx - w / 2 - padX) / scale;
            float y1 = (cy - h / 2 - padY) / scale;
            float bw = w / scale;
            float bh = h / scale;

            x1 = Math.Max(0, x1);
            y1 = Math.Max(0, y1);
            bw = Math.Min(bw, origW - x1);
            bh = Math.Min(bh, origH - y1);

            boxes.Add(new Rect((int)x1, (int)y1, (int)bw, (int)bh));
            confidences.Add(bestConf);
            classIds.Add(bestClass);
        }

        if (boxes.Count == 0)
            return [];

        CvDnn.NMSBoxes(boxes, confidences, _confThreshold, _iouThreshold, out int[] indices);

        var detections = new List<DetectedBoss>(indices.Length);
        foreach (var idx in indices)
        {
            var box = boxes[idx];
            detections.Add(new DetectedBoss(
                _classNames[classIds[idx]],
                confidences[idx],
                box.X, box.Y, box.Width, box.Height,
                box.X + box.Width / 2,
                box.Y + box.Height / 2));
        }
        return detections;
    }

    public void Dispose()
    {
        _session.Dispose();
        _resized.Dispose();
        _padded.Dispose();
        _rgb.Dispose();
        _floatMat.Dispose();
    }
}

View file

@ -11,6 +11,7 @@
<PackageReference Include="OpenCvSharp4.runtime.win" Version="4.11.0.*" /> <PackageReference Include="OpenCvSharp4.runtime.win" Version="4.11.0.*" />
<PackageReference Include="System.Drawing.Common" Version="8.0.12" /> <PackageReference Include="System.Drawing.Common" Version="8.0.12" />
<PackageReference Include="Vortice.Direct3D11" Version="3.8.2" /> <PackageReference Include="Vortice.Direct3D11" Version="3.8.2" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.24.*" />
<PackageReference Include="Vortice.DXGI" Version="3.8.2" /> <PackageReference Include="Vortice.DXGI" Version="3.8.2" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View file

@ -39,7 +39,7 @@ class PythonDetectBridge : IDisposable
{ {
EnsureRunning(); EnsureRunning();
var imageBytes = bgrMat.ToBytes(".png"); var imageBytes = bgrMat.ToBytes(".jpg", [(int)ImwriteFlags.JpegQuality, 95]);
var imageBase64 = Convert.ToBase64String(imageBytes); var imageBase64 = Convert.ToBase64String(imageBytes);
var req = new Dictionary<string, object?> var req = new Dictionary<string, object?>

View file

@ -185,6 +185,7 @@ public sealed class D2dOverlay
return new OverlayState( return new OverlayState(
Enemies: detection.Enemies, Enemies: detection.Enemies,
Bosses: bossDetection.Bosses, Bosses: bossDetection.Bosses,
BossTimestampMs: bossDetection.Timestamp,
InferenceMs: detection.InferenceMs, InferenceMs: detection.InferenceMs,
Hud: _bot.HudReader.Current, Hud: _bot.HudReader.Current,
NavState: _bot.Navigation.State, NavState: _bot.Navigation.State,

View file

@ -6,6 +6,7 @@ namespace Poe2Trade.Ui.Overlay;
public record OverlayState( public record OverlayState(
IReadOnlyList<DetectedEnemy> Enemies, IReadOnlyList<DetectedEnemy> Enemies,
IReadOnlyList<DetectedBoss> Bosses, IReadOnlyList<DetectedBoss> Bosses,
long BossTimestampMs,
float InferenceMs, float InferenceMs,
HudSnapshot? Hud, HudSnapshot? Hud,
NavigationState NavState, NavigationState NavState,

View file

@ -53,10 +53,15 @@ internal sealed class D2dEnemyBoxLayer : ID2dOverlayLayer, IDisposable
rt.DrawTextLayout(new System.Numerics.Vector2(labelX, labelY), layout, textBrush); rt.DrawTextLayout(new System.Numerics.Vector2(labelX, labelY), layout, textBrush);
} }
// Boss bounding boxes (cyan) // Boss bounding boxes (cyan) — extrapolate position to compensate for inference delay
var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
var ageMs = (float)Math.Clamp(now - state.BossTimestampMs, 0, 60);
foreach (var boss in state.Bosses) foreach (var boss in state.Bosses)
{ {
var rect = new RectangleF(boss.X, boss.Y, boss.Width, boss.Height); float dx = boss.VxPerMs * ageMs;
float dy = boss.VyPerMs * ageMs;
var rect = new RectangleF(boss.X + dx, boss.Y + dy, boss.Width, boss.Height);
rt.DrawRectangle(rect, ctx.Cyan, 3f); rt.DrawRectangle(rect, ctx.Cyan, 3f);
var pct = Math.Clamp((int)(boss.Confidence * 100), 0, 100); var pct = Math.Clamp((int)(boss.Confidence * 100), 0, 100);
@ -68,8 +73,8 @@ internal sealed class D2dEnemyBoxLayer : ID2dOverlayLayer, IDisposable
} }
var m = layout.Metrics; var m = layout.Metrics;
var labelX = boss.X; var labelX = boss.X + dx;
var labelY = boss.Y - m.Height - 2; var labelY = boss.Y + dy - m.Height - 2;
rt.FillRectangle( rt.FillRectangle(
new RectangleF(labelX - 1, labelY - 1, m.Width + 2, m.Height + 2), new RectangleF(labelX - 1, labelY - 1, m.Width + 2, m.Height + 2),

View file

@ -7,6 +7,7 @@ Subcommands (all take a positional boss name):
runs kulemak List training runs + metrics table runs kulemak List training runs + metrics table
annotate kulemak [dir] Launch annotation GUI annotate kulemak [dir] Launch annotation GUI
prelabel kulemak [dir] [--model boss-kulemak] Auto-label unlabeled images prelabel kulemak [dir] [--model boss-kulemak] Auto-label unlabeled images
export kulemak [--imgsz 640] Export .pt to ONNX format
""" """
import argparse import argparse
@ -291,6 +292,32 @@ def cmd_prelabel(args):
_PRELABEL_MODEL_DEFAULT = "__auto__" _PRELABEL_MODEL_DEFAULT = "__auto__"
# ── export ──────────────────────────────────────────────────────
def cmd_export(args):
    """Export a trained .pt model to .onnx for ONNX Runtime inference.

    Looks up models/boss-<boss>.pt, exports it at the requested image size,
    and reports the resulting file and its size.
    """
    boss = args.boss
    model_name = f"boss-{boss}"
    pt_path = os.path.join(MODELS_DIR, f"{model_name}.pt")
    if not os.path.exists(pt_path):
        print(f"Model not found: {pt_path}")
        return
    from ultralytics import YOLO
    model = YOLO(pt_path)
    print(f"Exporting {pt_path} -> ONNX (imgsz={args.imgsz})...")
    # export() returns the path of the file it wrote — prefer that over
    # guessing, but fall back to the conventional next-to-.pt location.
    onnx_path = model.export(format="onnx", imgsz=args.imgsz, opset=17,
                             simplify=True, dynamic=False)
    if not onnx_path:
        onnx_path = os.path.join(MODELS_DIR, f"{model_name}.onnx")
    if os.path.exists(onnx_path):
        size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
        print(f"\nExported: {onnx_path} ({size_mb:.1f} MB)")
    else:
        print(f"\nWarning: expected output not found at {onnx_path}")
# ── CLI ────────────────────────────────────────────────────────── # ── CLI ──────────────────────────────────────────────────────────
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -315,10 +342,10 @@ def main():
p = sub.add_parser("train", help="Train YOLO model") p = sub.add_parser("train", help="Train YOLO model")
p.add_argument("boss", help="Boss name (e.g. kulemak)") p.add_argument("boss", help="Boss name (e.g. kulemak)")
p.add_argument("--data", default=None, help="Path to data.yaml") p.add_argument("--data", default=None, help="Path to data.yaml")
p.add_argument("--model", default="yolo11s", help="YOLO model variant") p.add_argument("--model", default="yolo11n", help="YOLO model variant")
p.add_argument("--epochs", type=int, default=200, help="Training epochs") p.add_argument("--epochs", type=int, default=200, help="Training epochs")
p.add_argument("--imgsz", type=int, default=1280, help="Image size") p.add_argument("--imgsz", type=int, default=640, help="Image size")
p.add_argument("--batch", type=int, default=8, help="Batch size") p.add_argument("--batch", type=int, default=16, help="Batch size")
p.add_argument("--device", default="0", help="CUDA device") p.add_argument("--device", default="0", help="CUDA device")
p.add_argument("--name", default=None, help="Run name (auto-increments if omitted)") p.add_argument("--name", default=None, help="Run name (auto-increments if omitted)")
@ -333,6 +360,11 @@ def main():
p.add_argument("--model", default=_PRELABEL_MODEL_DEFAULT, help="Model name in models/ (default: boss-{boss})") p.add_argument("--model", default=_PRELABEL_MODEL_DEFAULT, help="Model name in models/ (default: boss-{boss})")
p.add_argument("--conf", type=float, default=0.20, help="Confidence threshold") p.add_argument("--conf", type=float, default=0.20, help="Confidence threshold")
# export
p = sub.add_parser("export", help="Export .pt model to ONNX format")
p.add_argument("boss", help="Boss name (e.g. kulemak)")
p.add_argument("--imgsz", type=int, default=640, help="Image size for export")
args = parser.parse_args() args = parser.parse_args()
if args.command is None: if args.command is None:
@ -345,6 +377,7 @@ def main():
"train": cmd_train, "train": cmd_train,
"runs": cmd_runs, "runs": cmd_runs,
"prelabel": cmd_prelabel, "prelabel": cmd_prelabel,
"export": cmd_export,
} }
commands[args.command](args) commands[args.command](args)