poe2-bot/tools/training/continue_training.sh
2026-02-10 22:56:51 -05:00

104 lines
3 KiB
Bash

#!/usr/bin/env bash
# Continue LSTM fine-tuning in batches until improvement plateaus.
# Stops when BCER improvement drops below 0.3% per batch.
#
# Run from PowerShell:
#   MSYS_NO_PATHCONV=1 wsl -d Ubuntu-22.04 -u root bash /mnt/c/Users/boki/repos/poe2trade/tools/training/continue_training.sh
set -euo pipefail

# Tunables — constants are readonly so a later typo can't silently reassign them.
readonly WORK_DIR="$HOME/poe2-tesseract-training"   # training workspace (holds output/, listfiles)
readonly TESSDATA_DIR="/mnt/c/Users/boki/repos/poe2trade/tools/OcrDaemon/tessdata"
readonly BATCH_SIZE=400        # iterations added per training batch
readonly MIN_IMPROVEMENT=0.3   # stop if BCER improves less than this (%) per batch
readonly MAX_TOTAL=5000        # absolute safety cap on total iterations

# Fail with a clear, contextual message if the workspace is missing rather
# than relying on set -e's bare exit.
cd "$WORK_DIR" || { echo "ERROR: training workspace not found: $WORK_DIR" >&2; exit 1; }
# Print the lowest (= best) BCER seen so far, parsed from checkpoint
# filenames like output/poe2_4.230_102_800.checkpoint.
# Prints nothing (and returns 0) when no named checkpoints exist, instead
# of dying: the original `ls | grep | sed` pipeline failed under
# `set -eo pipefail` on an empty output/ directory.
get_best_bcer() {
  local f bcers=()
  # Glob instead of parsing ls; skip the rolling "poe2_checkpoint" file.
  for f in output/poe2_*.checkpoint; do
    [[ -e "$f" && "${f##*/}" != poe2_checkpoint ]] || continue
    f=${f##*/}            # basename: poe2_4.230_102_800.checkpoint
    f=${f#poe2_}          # strip prefix: 4.230_102_800.checkpoint
    bcers+=("${f%%_*}")   # leading float field: 4.230
  done
  (( ${#bcers[@]} )) || return 0
  printf '%s\n' "${bcers[@]}" | sort -n | head -1
}
# Print the highest iteration count reached so far, parsed from the
# trailing number in checkpoint filenames (poe2_<bcer>_<err>_<iter>.checkpoint).
# Prints nothing (and returns 0) when no named checkpoints exist, instead
# of dying under `set -eo pipefail` like the original ls-parsing pipeline.
get_max_iter() {
  local f iters=()
  for f in output/poe2_*.checkpoint; do
    [[ -e "$f" && "${f##*/}" != poe2_checkpoint ]] || continue
    f=${f%.checkpoint}    # drop extension: output/poe2_4.230_102_800
    iters+=("${f##*_}")   # last _-separated field: 800
  done
  (( ${#iters[@]} )) || return 0
  printf '%s\n' "${iters[@]}" | sort -n | tail -1
}
# Capture the current training state, then announce the starting point
# before the batch loop kicks off.
prev_bcer=$(get_best_bcer)
current_max=$(get_max_iter)

printf '=== Continuing LSTM Training ===\n'
printf 'Starting BCER: %s%%\n' "$prev_bcer"
printf 'Starting iterations: %s\n' "$current_max"
printf 'Batch size: %s\n' "$BATCH_SIZE"
printf 'Min improvement threshold: %s%%\n' "$MIN_IMPROVEMENT"
printf '\n'
# Train in batches of BATCH_SIZE iterations; after each batch, compare the
# best BCER against the previous batch and stop once the gain falls below
# MIN_IMPROVEMENT (or the MAX_TOTAL iteration cap is reached).
batch=0
while true; do
  batch=$((batch + 1))
  new_max=$((current_max + BATCH_SIZE))

  # Hard ceiling so a non-converging run cannot spin forever.
  if (( new_max > MAX_TOTAL )); then
    echo "Reached safety cap of $MAX_TOTAL iterations. Stopping."
    break
  fi

  # BUG FIX: original printed "$current_max$new_max" with no separator,
  # fusing the two numbers together (e.g. "8001200").
  echo "── Batch $batch: iterations $current_max → $new_max ──"

  # Resume from the rolling checkpoint up to the new ceiling; show only the
  # tail of the (verbose) training log.
  lstmtraining \
    --continue_from output/poe2_checkpoint \
    --traineddata eng.traineddata \
    --train_listfile training_files.txt \
    --model_output output/poe2 \
    --max_iterations "$new_max" \
    --target_error_rate 0.005 \
    --debug_interval -1 2>&1 | tail -5

  new_bcer=$(get_best_bcer)
  echo ""
  echo "Batch $batch result: BCER ${prev_bcer}% → ${new_bcer}%"

  # Float math via awk (bash can't do float math). Pass values with -v
  # instead of interpolating them into the program text: an empty variable
  # in the original made awk a syntax error.
  improvement=$(awk -v prev="$prev_bcer" -v cur="$new_bcer" \
    'BEGIN {printf "%.3f", prev - cur}')
  echo "Improvement: ${improvement}%"

  # Stop once the per-batch gain drops below the threshold.
  stop=$(awk -v imp="$improvement" -v min="$MIN_IMPROVEMENT" \
    'BEGIN {print (imp < min) ? 1 : 0}')
  if [ "$stop" -eq 1 ]; then
    echo ""
    echo "Improvement (${improvement}%) < threshold (${MIN_IMPROVEMENT}%). Stopping."
    break
  fi

  prev_bcer="$new_bcer"
  current_max="$new_max"
  echo ""
done
# Package the best checkpoint into a deployable .traineddata model and
# copy it into the OCR daemon's tessdata directory.
printf '\n=== Packaging final model ===\n'
final_bcer=$(get_best_bcer)
printf 'Final BCER: %s%%\n' "$final_bcer"

lstmtraining --stop_training \
  --continue_from output/poe2_checkpoint \
  --traineddata eng.traineddata \
  --model_output output/poe2.traineddata

cp output/poe2.traineddata "$TESSDATA_DIR/poe2.traineddata"
printf 'Model saved to: %s/poe2.traineddata\n' "$TESSDATA_DIR"
ls -lh "$TESSDATA_DIR/poe2.traineddata"
printf '\n=== Done ===\n'