#!/usr/bin/env bash
# Continue LSTM fine-tuning in batches until improvement plateaus.
# Stops when BCER improvement drops below 0.3% per batch.
#
# Run from PowerShell:
#   MSYS_NO_PATHCONV=1 wsl -d Ubuntu-22.04 -u root bash /mnt/c/Users/boki/repos/poe2trade/tools/training/continue_training.sh

set -euo pipefail

# Tunables — readonly so a typo later in the script can't silently clobber them.
readonly WORK_DIR="$HOME/poe2-tesseract-training"
readonly TESSDATA_DIR="/mnt/c/Users/boki/repos/poe2trade/tools/OcrDaemon/tessdata"
readonly BATCH_SIZE=400
readonly MIN_IMPROVEMENT=0.3 # stop if BCER improves less than this per batch
readonly MAX_TOTAL=5000      # absolute safety cap

# Fail with a clear message rather than relying on set -e's silent abort.
cd "$WORK_DIR" || { echo "ERROR: cannot cd to $WORK_DIR" >&2; exit 1; }
# Parse BCER from checkpoint filenames like poe2_4.230_102_800.checkpoint
# and print the lowest (best) value. Prints nothing if no checkpoints exist.
get_best_bcer() {
  # Trailing `|| true`: with zero checkpoints, ls and grep exit non-zero and
  # `set -euo pipefail` would otherwise abort the whole script from inside
  # the $(...) substitution at the call site, with no diagnostic.
  ls -1 output/poe2_*.checkpoint 2>/dev/null \
    | grep -v 'poe2_checkpoint$' \
    | sed 's/.*poe2_\([0-9.]*\)_.*/\1/' \
    | sort -n \
    | head -1 \
    || true
}
# Print the highest iteration count recorded in checkpoint filenames
# (the trailing number before .checkpoint). Prints nothing if none exist.
get_max_iter() {
  # Trailing `|| true`: with zero checkpoints, ls and grep exit non-zero and
  # pipefail would otherwise abort the script from inside $(...) at the caller.
  ls -1 output/poe2_*.checkpoint 2>/dev/null \
    | grep -v 'poe2_checkpoint$' \
    | sed 's/.*_\([0-9]*\)\.checkpoint/\1/' \
    | sort -n \
    | tail -1 \
    || true
}
prev_bcer=$(get_best_bcer)
current_max=$(get_max_iter)

# With no prior checkpoint there is nothing to continue from; fail fast with
# a clear message instead of a cryptic awk/arithmetic error on empty values.
if [[ -z "$prev_bcer" || -z "$current_max" ]]; then
  echo "ERROR: no output/poe2_*.checkpoint files found; run initial training first." >&2
  exit 1
fi

echo "=== Continuing LSTM Training ==="
echo "Starting BCER: ${prev_bcer}%"
echo "Starting iterations: $current_max"
echo "Batch size: $BATCH_SIZE"
echo "Min improvement threshold: ${MIN_IMPROVEMENT}%"
echo ""
batch=0
while true; do
  batch=$((batch + 1))
  new_max=$((current_max + BATCH_SIZE))

  if [ "$new_max" -gt "$MAX_TOTAL" ]; then
    echo "Reached safety cap of $MAX_TOTAL iterations. Stopping."
    break
  fi

  echo "── Batch $batch: iterations $current_max → $new_max ──"

  # Resume from the rolling checkpoint; lstmtraining writes new best-BCER
  # checkpoints under output/. Only the last few log lines are shown.
  lstmtraining \
    --continue_from output/poe2_checkpoint \
    --traineddata eng.traineddata \
    --train_listfile training_files.txt \
    --model_output output/poe2 \
    --max_iterations "$new_max" \
    --target_error_rate 0.005 \
    --debug_interval -1 2>&1 | tail -5

  new_bcer=$(get_best_bcer)
  if [[ -z "$new_bcer" ]]; then
    echo "ERROR: no checkpoint found after batch $batch; aborting." >&2
    exit 1
  fi
  echo ""
  echo "Batch $batch result: BCER ${prev_bcer}% → ${new_bcer}%"

  # Float math via awk (bash can't). Pass values with -v instead of
  # interpolating them into the program text: no quoting/injection pitfalls,
  # and an unexpected value can't turn into an awk syntax error.
  improvement=$(awk -v prev="$prev_bcer" -v cur="$new_bcer" \
    'BEGIN {printf "%.3f", prev - cur}')
  echo "Improvement: ${improvement}%"

  # Stop once a whole batch buys less than the configured threshold.
  stop=$(awk -v imp="$improvement" -v min="$MIN_IMPROVEMENT" \
    'BEGIN {print (imp < min) ? 1 : 0}')

  if [ "$stop" -eq 1 ]; then
    echo ""
    echo "Improvement (${improvement}%) < threshold (${MIN_IMPROVEMENT}%). Stopping."
    break
  fi

  prev_bcer="$new_bcer"
  current_max="$new_max"
  echo ""
done
# Package final model: freeze the best checkpoint into a distributable
# .traineddata file and copy it to the OCR daemon's tessdata directory.
printf '\n=== Packaging final model ===\n'
bcer_final=$(get_best_bcer)
printf 'Final BCER: %s%%\n' "$bcer_final"

lstmtraining --stop_training \
  --continue_from output/poe2_checkpoint \
  --traineddata eng.traineddata \
  --model_output output/poe2.traineddata

cp output/poe2.traineddata "$TESSDATA_DIR/poe2.traineddata"
printf 'Model saved to: %s\n' "$TESSDATA_DIR/poe2.traineddata"
ls -lh "$TESSDATA_DIR/poe2.traineddata"
printf '\n=== Done ===\n'
|