154 lines
4.5 KiB
Python
154 lines
4.5 KiB
Python
"""
|
|
Persistent Python YOLO detection daemon (stdin/stdout JSON-per-line protocol).
|
|
|
|
Loads a YOLOv11 model and serves inference requests over stdin/stdout.
|
|
Managed as a subprocess by PythonDetectBridge in Poe2Trade.Screen.
|
|
|
|
Request: {"cmd": "detect", "imageBase64": "...", "conf": 0.3, "iou": 0.45, "imgsz": 640}
|
|
Response: {"ok": true, "count": 3, "inferenceMs": 12.5, "detections": [...]}
|
|
"""
|
|
|
|
import sys
|
|
import json
|
|
import time
|
|
|
|
_models = {}
|
|
|
|
|
|
def _redirect_stdout_to_stderr():
|
|
"""Redirect stdout to stderr so library print() calls don't corrupt the JSON protocol."""
|
|
real_stdout = sys.stdout
|
|
sys.stdout = sys.stderr
|
|
return real_stdout
|
|
|
|
|
|
def _restore_stdout(real_stdout):
|
|
sys.stdout = real_stdout
|
|
|
|
|
|
def load_model(name="enemy-v1"):
|
|
if name in _models:
|
|
return _models[name]
|
|
|
|
import os
|
|
from ultralytics import YOLO
|
|
|
|
model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
|
|
|
|
# Prefer TensorRT engine if available, fall back to .pt
|
|
engine_path = os.path.join(model_dir, f"{name}.engine")
|
|
pt_path = os.path.join(model_dir, f"{name}.pt")
|
|
model_path = engine_path if os.path.exists(engine_path) else pt_path
|
|
|
|
if not os.path.exists(model_path):
|
|
raise FileNotFoundError(f"Model not found: {pt_path} (also checked {engine_path})")
|
|
|
|
sys.stderr.write(f"Loading YOLO model '{name}' from {model_path}...\n")
|
|
sys.stderr.flush()
|
|
|
|
real_stdout = _redirect_stdout_to_stderr()
|
|
try:
|
|
model = YOLO(model_path)
|
|
# Warmup with dummy inference (triggers CUDA init)
|
|
import numpy as np
|
|
dummy = np.zeros((640, 640, 3), dtype=np.uint8)
|
|
model.predict(dummy, verbose=False)
|
|
finally:
|
|
_restore_stdout(real_stdout)
|
|
|
|
_models[name] = model
|
|
sys.stderr.write(f"YOLO model '{name}' loaded and warmed up.\n")
|
|
sys.stderr.flush()
|
|
return model
|
|
|
|
|
|
def _decode_image_bgr(image_base64):
    """Decode base64-encoded image bytes into a contiguous H x W x 3 BGR numpy array."""
    import base64
    import io

    import numpy as np
    from PIL import Image

    with Image.open(io.BytesIO(base64.b64decode(image_base64))) as pil_img:
        # Force 3-channel RGB. Without it, a greyscale/palette image becomes a
        # 2-D array and the channel flip below raises, and an RGBA PNG would be
        # flipped to ABGR (4 channels) instead of BGR.
        rgb = np.asarray(pil_img.convert("RGB"))
    # PIL gives RGB, but ultralytics model.predict() assumes numpy arrays are
    # BGR. The [::-1] view has a negative stride, so pass a contiguous copy.
    return np.ascontiguousarray(rgb[:, :, ::-1])


def _results_to_detections(results):
    """Convert ultralytics prediction results into the protocol's detection dicts."""
    detections = []
    for result in results:
        boxes = result.boxes
        if boxes is None:
            continue
        # Convert each field to Python lists once per result instead of
        # indexing the Boxes object box by box.
        for (x1, y1, x2, y2), raw_cls, raw_conf in zip(
            boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist()
        ):
            x, y = int(x1), int(y1)
            w, h = int(x2 - x1), int(y2 - y1)
            class_id = int(raw_cls)
            class_name = result.names[class_id] if result.names else str(class_id)
            detections.append({
                "class": class_name,
                "classId": class_id,
                "confidence": round(float(raw_conf), 4),
                "x": x, "y": y, "width": w, "height": h,
                "cx": x + w // 2, "cy": y + h // 2,
            })
    return detections


def handle_detect(req):
    """Run YOLO detection on the base64-encoded image carried by a request.

    Request keys read:
        imageBase64 (required), conf (default 0.3), iou (default 0.45),
        imgsz (default 640), model (default "enemy-v1").

    Returns:
        ``{"ok": False, "error": "Missing imageBase64"}`` when no image is given.
        Otherwise ``{"ok": True, "count", "inferenceMs", "detections"}``. Each
        detection has class name/id, confidence rounded to 4 decimals, an
        integer ``x``/``y``/``width``/``height`` box and its integer centre
        ``cx``/``cy``.

    Raises:
        Whatever decoding, ``load_model`` or inference raises; ``main`` turns
        such exceptions into error responses.
    """
    image_base64 = req.get("imageBase64")
    if not image_base64:
        # Validate before importing numpy/PIL so a bad request fails fast.
        return {"ok": False, "error": "Missing imageBase64"}

    img = _decode_image_bgr(image_base64)

    conf = req.get("conf", 0.3)
    iou = req.get("iou", 0.45)
    imgsz = req.get("imgsz", 640)
    model = load_model(req.get("model", "enemy-v1"))

    # Keep any library print() output off the protocol stream during inference.
    real_stdout = _redirect_stdout_to_stderr()
    try:
        start = time.perf_counter()
        results = model.predict(img, conf=conf, iou=iou, imgsz=imgsz, verbose=False)
        inference_ms = (time.perf_counter() - start) * 1000
    finally:
        _restore_stdout(real_stdout)

    detections = _results_to_detections(results)
    return {
        "ok": True,
        "count": len(detections),
        "inferenceMs": round(inference_ms, 2),
        "detections": detections,
    }
|
|
|
|
|
|
def handle_request(req):
    """Dispatch one decoded request by its ``cmd`` field and return the response dict.

    ``"ping"`` answers ``{"ok": True, "pong": True}`` and ``"detect"`` is
    delegated to ``handle_detect``. Any other value, including a missing
    ``cmd``, yields an ``"Unknown command: ..."`` error response.
    """
    command = req.get("cmd")
    if command == "ping":
        return {"ok": True, "pong": True}
    elif command == "detect":
        return handle_detect(req)
    else:
        return {"ok": False, "error": f"Unknown command: {command}"}
|
|
|
|
|
|
def main():
    """Run the daemon loop: announce readiness, then answer one JSON request per line.

    Writes ``{"ok": true, "ready": true}`` immediately; models load lazily on
    the first ``detect`` request. Each non-blank stdin line is parsed as JSON
    and dispatched via ``handle_request``, and exactly one JSON response line
    is written for it. An exception while parsing or handling becomes
    ``{"ok": false, "error": str(exc)}``, and its traceback is logged to
    stderr so stdout carries only protocol lines. Returns at stdin EOF.
    """
    import traceback

    def send(payload):
        # One JSON document per line, flushed so the parent process sees it now.
        sys.stdout.write(json.dumps(payload) + "\n")
        sys.stdout.flush()

    send({"ok": True, "ready": True})

    for raw_line in sys.stdin:
        line = raw_line.strip()
        if not line:
            continue
        try:
            req = json.loads(line)
            resp = handle_request(req)
        except Exception as e:
            # The client only gets the message; keep the full traceback on
            # stderr so failures inside the daemon can actually be diagnosed.
            traceback.print_exc(file=sys.stderr)
            sys.stderr.flush()
            resp = {"ok": False, "error": str(e)}
        send(resp)


if __name__ == "__main__":
    main()
|