77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
"""Fix labels generated by buggy annotate.py that multiplied by scale."""
|
|
import glob
|
|
import os
|
|
import cv2
|
|
import sys
|
|
|
|
def compute_scale(img_path, max_w=1600, max_h=900):
    """Return the factor that shrinks an image to fit inside a bounding box.

    Args:
        img_path: Path of the image file; it is loaded with ``cv2.imread``.
        max_w: Maximum allowed width in pixels.
        max_h: Maximum allowed height in pixels.

    Returns:
        ``None`` when the image cannot be read, ``1.0`` when the image
        already fits within ``max_w`` x ``max_h``, otherwise the
        aspect-preserving downscale factor (strictly below 1.0).
    """
    # NOTE(review): the 1600x900 defaults presumably mirror the limits used
    # by annotate.py when it resized images -- confirm against that script.
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        return None

    h, w = img.shape[:2]
    if h <= max_h and w <= max_w:
        return 1.0

    # The tighter of the two constraints decides the scale.
    return min(max_w / w, max_h / h)
|
|
|
|
def fix_label(label_path, scale):
    """Undo an erroneous multiplication by ``scale`` in one label file.

    Every line made of exactly five whitespace-separated tokens
    (``cls cx cy w h``) has its four numeric fields divided by ``scale``,
    clamped to the range [0, 1] and rewritten with six decimals. Any other
    line, including a five-token line whose numeric fields do not parse as
    floats, is kept verbatim.

    The file is rewritten in place only when at least one line was rescaled.

    NOTE: the operation is not idempotent; calling it twice on the same file
    divides the values twice.

    Args:
        label_path: Path of the label text file to repair.
        scale: Factor that was wrongly applied to the label values.

    Returns:
        ``True`` if the file was rewritten, ``False`` if ``scale`` is 1.0 or
        no line needed rescaling.

    Raises:
        ValueError: If ``scale`` is zero or negative.
    """
    if scale == 1.0:
        # Nothing was multiplied, so there is nothing to undo.
        return False
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")

    with open(label_path) as f:
        lines = f.readlines()

    fixed = []
    changed = False
    for line in lines:
        parts = line.split()
        if len(parts) != 5:
            fixed.append(line)
            continue
        try:
            values = [float(p) / scale for p in parts[1:]]
        except ValueError:
            # Not a numeric label line; leave it alone instead of aborting.
            fixed.append(line)
            continue
        # Dividing can push values past 1.0; keep them inside [0, 1].
        cx, cy, w, h = (min(max(v, 0.0), 1.0) for v in values)
        fixed.append(f"{parts[0]} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")
        changed = True

    if not changed:
        # Avoid rewriting (and reporting as fixed) a file with no label lines.
        return False

    with open(label_path, "w") as f:
        f.writelines(fixed)
    return True
|
|
|
|
def _find_image(img_dir, base, exts=(".jpg", ".jpeg", ".png")):
    """Return the first existing ``img_dir/base + ext`` path, or ``None``."""
    for ext in exts:
        candidate = os.path.join(img_dir, base + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def main():
    """Repair every label file in the ``train`` and ``valid`` splits.

    The dataset directory is taken from the first command-line argument and
    falls back to ``../../training-data/boss-dataset`` (resolved against the
    current working directory). For each ``<split>/labels/*.txt`` the
    matching image in ``<split>/images`` is located, its scale is computed
    with ``compute_scale`` and the label file is rewritten by ``fix_label``.
    The contents of the first three repaired files are printed for manual
    verification, followed by the total number of repaired files.
    """
    dataset_dir = sys.argv[1] if len(sys.argv) > 1 else "../../training-data/boss-dataset"
    dataset_dir = os.path.abspath(dataset_dir)
    if not os.path.isdir(dataset_dir):
        # Make a wrong path visible instead of silently reporting 0 fixes.
        print(f"  WARNING: dataset directory not found: {dataset_dir}")

    count = 0
    for split in ["train", "valid"]:
        label_dir = os.path.join(dataset_dir, split, "labels")
        img_dir = os.path.join(dataset_dir, split, "images")
        if not os.path.isdir(label_dir):
            continue

        # glob order is arbitrary; sort so runs (and the samples printed
        # below) are reproducible.
        for label_path in sorted(glob.glob(os.path.join(label_dir, "*.txt"))):
            base = os.path.splitext(os.path.basename(label_path))[0]
            img_path = _find_image(img_dir, base)
            if img_path is None:
                print(f"  WARNING: no image for {label_path}")
                continue

            scale = compute_scale(img_path)
            if scale is None:
                print(f"  WARNING: can't read {img_path}")
                continue

            if fix_label(label_path, scale):
                count += 1
                # Show first few for verification
                if count <= 3:
                    with open(label_path) as f:
                        print(f"  Fixed {os.path.basename(label_path)}: {f.read().strip()}")

    print(f"\nFixed {count} label files")
|
|
|
|
# Script entry point: run the repair only when executed directly, not on import.
if __name__ == "__main__":
    main()
|