added training
This commit is contained in:
parent
528453a321
commit
cc50368d3b
7 changed files with 901 additions and 1 deletions
104
tools/training/continue_training.sh
Normal file
104
tools/training/continue_training.sh
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env bash
|
||||
# Continue LSTM fine-tuning in batches until improvement plateaus.
|
||||
# Stops when BCER improvement drops below 0.3% per batch.
|
||||
#
|
||||
# Run from PowerShell:
|
||||
# MSYS_NO_PATHCONV=1 wsl -d Ubuntu-22.04 -u root bash /mnt/c/Users/boki/repos/poe2trade/tools/training/continue_training.sh
|
||||
|
||||
# Strict mode: abort on a failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# Settings. Each keeps its original default but may be overridden from the
# environment, e.g. `BATCH_SIZE=200 bash continue_training.sh`.
WORK_DIR="${WORK_DIR:-$HOME/poe2-tesseract-training}"
TESSDATA_DIR="${TESSDATA_DIR:-/mnt/c/Users/boki/repos/poe2trade/tools/OcrDaemon/tessdata}"
BATCH_SIZE="${BATCH_SIZE:-400}"
MIN_IMPROVEMENT="${MIN_IMPROVEMENT:-0.3}" # stop if BCER improves less than this per batch
MAX_TOTAL="${MAX_TOTAL:-5000}"            # absolute safety cap

# Every relative path used below (output/, eng.traineddata, training_files.txt)
# is resolved against WORK_DIR, so fail loudly if it cannot be entered.
if ! cd "$WORK_DIR"; then
  printf 'error: cannot enter training directory: %s\n' "$WORK_DIR" >&2
  exit 1
fi
|
||||
|
||||
#######################################
# Print the lowest BCER encoded in the checkpoint file names under output/.
# Names look like poe2_4.230_102_800.checkpoint; the first numeric field after
# "poe2_" is taken as the error rate. Files whose names do not follow that
# shape (including the rolling output/poe2_checkpoint) are ignored.
# Arguments: none (output/ is resolved against the current directory)
# Outputs:   the smallest error rate on stdout, spelled as in the file name
# Returns:   0 on success; 1, with a message on stderr, if nothing matches
#######################################
get_best_bcer() {
  local -r name_re='^poe2_([0-9]+(\.[0-9]+)?)_.*\.checkpoint$'
  local ckpt
  local -a rates=()

  # Glob directly instead of parsing `ls`. If nothing matches, bash leaves the
  # literal pattern in place and it simply fails the regex below.
  for ckpt in output/poe2_*.checkpoint; do
    if [[ "${ckpt##*/}" =~ $name_re ]]; then
      rates+=("${BASH_REMATCH[1]}")
    fi
  done

  if ((${#rates[@]} == 0)); then
    printf 'error: no poe2_<bcer>_*.checkpoint files found in %s/output\n' "$PWD" >&2
    return 1
  fi

  # Numeric minimum in a single awk pass; the original spelling is printed.
  printf '%s\n' "${rates[@]}" \
    | awk 'NR == 1 || $1 + 0 < best + 0 { best = $1 } END { print best }'
}
|
||||
|
||||
#######################################
# Print the highest iteration count encoded in the checkpoint file names under
# output/. The iteration is the last numeric field before ".checkpoint", e.g.
# 800 in poe2_4.230_102_800.checkpoint.
# Arguments: none (output/ is resolved against the current directory)
# Outputs:   the largest iteration count on stdout, as a plain decimal integer
# Returns:   0 on success; 1, with a message on stderr, if nothing matches
#######################################
get_max_iter() {
  local -r name_re='_([0-9]+)\.checkpoint$'
  local ckpt
  local -i iter=0 highest=-1

  # Glob directly instead of parsing `ls`. If nothing matches, bash leaves the
  # literal pattern in place and it simply fails the regex below.
  for ckpt in output/poe2_*.checkpoint; do
    if [[ "${ckpt##*/}" =~ $name_re ]]; then
      # Force base 10 so a zero-padded field is not read as octal.
      iter=$((10#${BASH_REMATCH[1]}))
      if ((iter > highest)); then
        highest=$iter
      fi
    fi
  done

  if ((highest < 0)); then
    printf 'error: no poe2_*_<iterations>.checkpoint files found in %s/output\n' "$PWD" >&2
    return 1
  fi

  printf '%d\n' "$highest"
}
|
||||
|
||||
# Starting point, read from the checkpoint file names already in output/.
# A failing helper is turned into an empty value here so that the check below
# can report the problem instead of `set -e` ending the script without a word.
prev_bcer=$(get_best_bcer) || prev_bcer=""
current_max=$(get_max_iter) || current_max=""

# Both values are fed to arithmetic later on; refuse to start unless they are
# plain numbers.
bcer_re='^[0-9]+(\.[0-9]+)?$'
iter_re='^[0-9]+$'
if ! [[ "$prev_bcer" =~ $bcer_re && "$current_max" =~ $iter_re ]]; then
  printf 'error: could not read starting BCER/iterations from %s/output/poe2_*.checkpoint\n' "$PWD" >&2
  printf '       (got BCER=%s, iterations=%s)\n' "${prev_bcer:-<none>}" "${current_max:-<none>}" >&2
  exit 1
fi

echo "=== Continuing LSTM Training ==="
echo "Starting BCER: ${prev_bcer}%"
echo "Starting iterations: $current_max"
echo "Batch size: $BATCH_SIZE"
echo "Min improvement threshold: ${MIN_IMPROVEMENT}%"
echo ""
|
||||
|
||||
# Train in batches of BATCH_SIZE iterations. After each batch the best BCER is
# re-read from the checkpoint names; the loop ends when the gain over the
# previous batch falls below MIN_IMPROVEMENT or the next batch would pass
# MAX_TOTAL.
batch=0
while true; do
  batch=$((batch + 1))
  new_max=$((current_max + BATCH_SIZE))

  if [[ "$new_max" -gt "$MAX_TOTAL" ]]; then
    echo "Reached safety cap of $MAX_TOTAL iterations. Stopping."
    break
  fi

  echo "── Batch $batch: iterations $current_max → $new_max ──"

  # Only the last five lines of trainer output are shown. With `pipefail`
  # active, a failing lstmtraining still fails the pipeline and stops the script.
  lstmtraining \
    --continue_from output/poe2_checkpoint \
    --traineddata eng.traineddata \
    --train_listfile training_files.txt \
    --model_output output/poe2 \
    --max_iterations "$new_max" \
    --target_error_rate 0.005 \
    --debug_interval -1 2>&1 | tail -5

  new_bcer=$(get_best_bcer)
  echo ""
  echo "Batch $batch result: BCER ${prev_bcer}% → ${new_bcer}%"

  # Bash has no float arithmetic, so awk does it. Values are handed over with
  # -v rather than spliced into the program text, so an unexpected value can
  # never be parsed as awk code.
  improvement=$(awk -v before="$prev_bcer" -v after="$new_bcer" \
    'BEGIN { printf "%.3f", before - after }')
  echo "Improvement: ${improvement}%"

  # 1 when the (rounded) improvement is below the threshold, else 0.
  stop=$(awk -v gain="$improvement" -v floor="$MIN_IMPROVEMENT" \
    'BEGIN { print ((gain + 0 < floor + 0) ? 1 : 0) }')

  if [[ "$stop" -eq 1 ]]; then
    echo ""
    echo "Improvement (${improvement}%) < threshold (${MIN_IMPROVEMENT}%). Stopping."
    break
  fi

  prev_bcer="$new_bcer"
  current_max="$new_max"
  echo ""
done
|
||||
|
||||
# Convert the rolling training checkpoint into a .traineddata file and copy it
# into the OCR daemon's tessdata directory.
packaged_model="output/poe2.traineddata"
installed_model="$TESSDATA_DIR/poe2.traineddata"

echo ""
echo "=== Packaging final model ==="
final_bcer=$(get_best_bcer)
echo "Final BCER: ${final_bcer}%"

# --stop_training writes a recognition model from the checkpoint instead of
# running further training iterations.
package_args=(
  --stop_training
  --continue_from output/poe2_checkpoint
  --traineddata eng.traineddata
  --model_output "$packaged_model"
)
lstmtraining "${package_args[@]}"

cp "$packaged_model" "$installed_model"
echo "Model saved to: $installed_model"
ls -lh "$installed_model"

echo ""
echo "=== Done ==="
|
||||
Loading…
Add table
Add a link
Reference in a new issue