"""Fix labels generated by buggy annotate.py that multiplied by scale.

The buggy annotator multiplied every YOLO coordinate (cx, cy, w, h) by the
display scale factor it used to shrink large images. This script undoes that
by dividing each coordinate by the same factor, recomputed from the image size.

WARNING: the correction is not idempotent. Running the script a second time on
already-fixed labels divides by the scale again and corrupts them, so run it
exactly once per dataset (or restore the labels from a backup first).
"""
import glob
import os
import sys

# Largest image size (height, width) that annotate.py displayed without
# shrinking. These must match the values annotate.py used, otherwise the
# recomputed scale will be wrong.
MAX_DISPLAY_H = 900
MAX_DISPLAY_W = 1600

# Extensions probed, in order, for each label's matching image. Upper-case
# variants come last so camera-style names such as "IMG_001.JPG" are also
# found on case-sensitive filesystems.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")


def _scale_for_size(h, w, max_h=MAX_DISPLAY_H, max_w=MAX_DISPLAY_W):
    """Return the shrink factor for an h x w image (1.0 if it already fits)."""
    if h > max_h or w > max_w:
        return min(max_w / w, max_h / h)
    return 1.0


def compute_scale(img_path, max_h=MAX_DISPLAY_H, max_w=MAX_DISPLAY_W):
    """Return the display scale annotate.py applied to the image at img_path.

    Args:
        img_path: path to the image file.
        max_h: maximum display height used by annotate.py.
        max_w: maximum display width used by annotate.py.

    Returns:
        The scale factor (1.0 when the image was not shrunk), or None if the
        image cannot be read.
    """
    # Imported here so the pure label-rewriting code in this module can be
    # used (and tested) without OpenCV installed.
    import cv2

    img = cv2.imread(img_path)
    if img is None:
        return None
    h, w = img.shape[:2]
    return _scale_for_size(h, w, max_h, max_w)


def _fix_line(line, scale):
    """Return line with its box coordinates divided by scale.

    Lines that are not exactly ``cls cx cy w h`` with numeric coordinates are
    returned unchanged. Corrected values are clamped to at most 1.0.
    """
    parts = line.strip().split()
    if len(parts) != 5:
        return line
    cls = parts[0]
    try:
        coords = [float(p) / scale for p in parts[1:]]
    except ValueError:
        # Keep malformed lines rather than aborting half-way through the
        # dataset: a partially completed run cannot be safely re-run.
        return line
    cx, cy, w, h = (min(c, 1.0) for c in coords)
    return f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n"


def fix_label(label_path, scale):
    """Undo the buggy scaling in one YOLO label file, rewriting it in place.

    Args:
        label_path: path to a ``.txt`` label file with one box per line.
        scale: the factor annotate.py multiplied the coordinates by.

    Returns:
        True if the file was rewritten, or False if scale is 1.0 (the image
        was never shrunk, so its labels are already correct).
    """
    if scale == 1.0:
        return False
    with open(label_path, encoding="utf-8") as f:
        fixed = [_fix_line(line, scale) for line in f]
    # Write to a sibling temp file and atomically swap it in, so an interrupted
    # run never leaves a truncated label file behind.
    tmp_path = label_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(fixed)
    os.replace(tmp_path, label_path)
    return True


def _find_image(img_dir, base):
    """Return the path of the image named base in img_dir, or None."""
    for ext in IMAGE_EXTENSIONS:
        candidate = os.path.join(img_dir, base + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def main():
    """Fix every label file under <dataset>/{train,valid}/labels.

    The dataset directory is taken from argv[1], defaulting to a path relative
    to the current working directory.
    """
    dataset_dir = sys.argv[1] if len(sys.argv) > 1 else "../../training-data/boss-dataset"
    dataset_dir = os.path.abspath(dataset_dir)

    count = 0
    for split in ["train", "valid"]:
        label_dir = os.path.join(dataset_dir, split, "labels")
        img_dir = os.path.join(dataset_dir, split, "images")
        if not os.path.isdir(label_dir):
            continue

        # Sorted so the "first few" samples printed below are reproducible.
        for label_path in sorted(glob.glob(os.path.join(label_dir, "*.txt"))):
            base = os.path.splitext(os.path.basename(label_path))[0]
            img_path = _find_image(img_dir, base)
            if img_path is None:
                print(f" WARNING: no image for {label_path}")
                continue

            scale = compute_scale(img_path)
            if scale is None:
                print(f" WARNING: can't read {img_path}")
                continue

            if fix_label(label_path, scale):
                count += 1
                # Show first few for verification
                if count <= 3:
                    with open(label_path, encoding="utf-8") as f:
                        print(f" Fixed {os.path.basename(label_path)}: {f.read().strip()}")

    print(f"\nFixed {count} label files")


if __name__ == "__main__":
    main()