# Source: poe2-bot/tools/python-detect/fix_labels.py (Python; 77 lines, 2.4 KiB as captured)

"""Fix labels generated by buggy annotate.py that multiplied by scale."""
import glob
import os
import cv2
import sys
def compute_scale(img_path):
    """Return the downscale factor that fits the image inside 1600x900.

    The image is loaded with OpenCV only to obtain its pixel dimensions.
    NOTE(review): the 1600x900 limit presumably mirrors the one used by
    annotate.py when it produced the bad labels -- confirm against that file.

    Args:
        img_path: Path of the image file to measure.

    Returns:
        None when OpenCV cannot read the file, 1.0 when the image already
        fits within the limits, otherwise the smaller of the width and
        height ratios (a value below 1.0).
    """
    image = cv2.imread(img_path)
    if image is None:
        return None

    height, width = image.shape[:2]
    limit_h, limit_w = 900, 1600

    # Neither dimension exceeds its limit: no scaling was applied.
    if height <= limit_h and width <= limit_w:
        return 1.0
    return min(limit_w / width, limit_h / height)
def fix_label(label_path, scale):
    """Undo the erroneous scale multiplication in one label file, in place.

    Every line made of exactly five whitespace-separated fields is read as
    ``<class> <cx> <cy> <w> <h>``: the four numeric fields are divided by
    *scale*, capped at 1.0, and written back with six decimal places. Any
    other line -- wrong field count, or fields that do not parse as numbers --
    is kept verbatim, so a single stray line cannot abort a dataset-wide run
    and leave it half-corrected.

    NOTE: the correction is not idempotent. Calling this twice on the same
    file divides the coordinates twice.

    Args:
        label_path: Path of the label text file to rewrite.
        scale: Factor the coordinates were wrongly multiplied by.

    Returns:
        True if at least one line was corrected and the file was rewritten.
        False if *scale* is 1.0 or the file contains no correctable line; the
        file is left untouched in both cases.
    """
    if scale == 1.0:
        return False

    with open(label_path) as f:
        lines = f.readlines()

    fixed = []
    changed = False
    for line in lines:
        parts = line.split()
        if len(parts) != 5:
            fixed.append(line)
            continue
        try:
            values = [float(field) / scale for field in parts[1:]]
        except ValueError:
            # Five tokens, but not a numeric annotation: pass it through.
            fixed.append(line)
            continue
        # Cap at 1.0 so a corrected value cannot exceed the upper bound.
        cx, cy, w, h = (min(value, 1.0) for value in values)
        fixed.append(f"{parts[0]} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")
        changed = True

    if not changed:
        # Nothing was corrected; do not rewrite the file or report a fix.
        return False

    with open(label_path, "w") as f:
        f.writelines(fixed)
    return True
def main():
    """Correct every label file under the dataset's train and valid splits.

    The dataset directory comes from the first command-line argument, or
    defaults to "../../training-data/boss-dataset" resolved against the
    current working directory. For each ``<split>/labels/*.txt`` the image
    with the same base name in ``<split>/images`` (.jpg, .jpeg or .png) is
    used to recompute the scale, and the label file is corrected in place.
    The first three corrected files are echoed for a visual check and a
    total is printed at the end.

    Exits with an error message if the dataset directory does not exist,
    rather than silently reporting zero fixed files.

    NOTE: run this once per dataset. The correction divides coordinates
    unconditionally, so a second run would over-correct the labels.
    """
    dataset_dir = sys.argv[1] if len(sys.argv) > 1 else "../../training-data/boss-dataset"
    dataset_dir = os.path.abspath(dataset_dir)
    if not os.path.isdir(dataset_dir):
        sys.exit(f"ERROR: dataset directory not found: {dataset_dir}")

    count = 0
    for split in ["train", "valid"]:
        label_dir = os.path.join(dataset_dir, split, "labels")
        img_dir = os.path.join(dataset_dir, split, "images")
        if not os.path.isdir(label_dir):
            # A missing split is tolerated; just move on to the next one.
            continue

        # glob order is arbitrary; sort so the verification output is stable.
        for label_path in sorted(glob.glob(os.path.join(label_dir, "*.txt"))):
            base = os.path.splitext(os.path.basename(label_path))[0]

            img_path = None
            for ext in (".jpg", ".jpeg", ".png"):
                candidate = os.path.join(img_dir, base + ext)
                if os.path.exists(candidate):
                    img_path = candidate
                    break
            if img_path is None:
                print(f" WARNING: no image for {label_path}")
                continue

            scale = compute_scale(img_path)
            if scale is None:
                print(f" WARNING: can't read {img_path}")
                continue

            if fix_label(label_path, scale):
                count += 1
                # Show first few for verification
                if count <= 3:
                    with open(label_path) as f:
                        print(f" Fixed {os.path.basename(label_path)}: {f.read().strip()}")

    print(f"\nFixed {count} label files")


if __name__ == "__main__":
    main()