use std::collections::HashMap; use std::sync::Arc; use parking_lot::RwLock; use chrono::{DateTime, Utc}; use serde::{Serialize, Deserialize}; use crate::{ TradingMode, MarketDataSource, ExecutionHandler, TimeProvider, MarketUpdate, MarketDataType, Order, Fill, Side, positions::PositionTracker, risk::RiskEngine, orderbook::OrderBookManager, }; use super::{ BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType, Strategy, Signal, SignalType, BacktestResult, TradeTracker, }; pub struct BacktestEngine { config: BacktestConfig, state: Arc>, event_queue: Arc>, strategies: Arc>>>, // Core components position_tracker: Arc, risk_engine: Arc, orderbook_manager: Arc, time_provider: Arc>, pub market_data_source: Arc>>, execution_handler: Arc>>, // Metrics total_trades: usize, profitable_trades: usize, total_pnl: f64, // Price tracking last_prices: HashMap, // Trade tracking trade_tracker: TradeTracker, } impl BacktestEngine { pub fn new( config: BacktestConfig, mode: TradingMode, time_provider: Box, market_data_source: Box, execution_handler: Box, ) -> Self { let state = Arc::new(RwLock::new( BacktestState::new(config.initial_capital, config.start_time) )); Self { config, state, event_queue: Arc::new(RwLock::new(EventQueue::new())), strategies: Arc::new(RwLock::new(Vec::new())), position_tracker: Arc::new(PositionTracker::new()), risk_engine: Arc::new(RiskEngine::new()), orderbook_manager: Arc::new(OrderBookManager::new()), time_provider: Arc::new(time_provider), market_data_source: Arc::new(RwLock::new(market_data_source)), execution_handler: Arc::new(RwLock::new(execution_handler)), total_trades: 0, profitable_trades: 0, total_pnl: 0.0, last_prices: HashMap::new(), trade_tracker: TradeTracker::new(), } } pub fn add_strategy(&mut self, strategy: Box) { self.strategies.write().push(strategy); } pub async fn run(&mut self) -> Result { eprintln!("=== BacktestEngine::run() START ==="); eprintln!("Config: start={}, end={}, symbols={:?}", self.config.start_time, self.config.end_time, self.config.symbols); eprintln!("Number of strategies loaded: {}", self.strategies.read().len()); // Initialize start time if let Some(simulated_time) = self.time_provider.as_any() .downcast_ref::() { simulated_time.advance_to(self.config.start_time); eprintln!("Time initialized to: {}", self.config.start_time); } // Load market data eprintln!("Loading market data from data source..."); self.load_market_data().await?; let queue_len = self.event_queue.read().len(); eprintln!("Event queue length after loading: {}", queue_len); if queue_len == 0 { eprintln!("WARNING: No events loaded! Check data source."); } // Main event loop let mut iteration = 0; while !self.event_queue.read().is_empty() { iteration += 1; if iteration <= 5 || iteration % 100 == 0 { eprintln!("Processing iteration {} at time {}", iteration, self.time_provider.now()); } // Get the next event's timestamp let next_event_time = self.event_queue.read() .peek_next() .map(|e| e.timestamp); if let Some(event_time) = next_event_time { // Advance time to the next event self.advance_time(event_time); // Get all events at this timestamp let current_time = self.time_provider.now(); let events = self.event_queue.write().pop_until(current_time); for event in events { self.process_event(event).await?; } // Update portfolio value self.update_portfolio_value(); } else { // No more events break; } } eprintln!("Backtest complete. Total trades: {}", self.total_trades); // Generate results Ok(self.generate_results()) } async fn load_market_data(&mut self) -> Result<(), String> { eprintln!("=== load_market_data START ==="); let mut data_source = self.market_data_source.write(); // Check if it's a HistoricalDataSource if let Some(historical) = data_source.as_any() .downcast_ref::() { eprintln!("Data source is HistoricalDataSource"); eprintln!("Historical data points available: {}", historical.data_len()); } else { eprintln!("WARNING: Data source is NOT HistoricalDataSource!"); } eprintln!("Seeking to start time: {}", self.config.start_time); data_source.seek_to_time(self.config.start_time)?; let mut count = 0; let mut first_few = 0; // Load all data into event queue while let Some(update) = data_source.get_next_update().await { if update.timestamp > self.config.end_time { eprintln!("Reached end time at {} data points", count); break; } count += 1; // Log first few data points if first_few < 3 { eprintln!("Data point {}: symbol={}, time={}, type={:?}", count, update.symbol, update.timestamp, match &update.data { MarketDataType::Bar(b) => format!("Bar(close={})", b.close), MarketDataType::Quote(q) => format!("Quote(bid={}, ask={})", q.bid, q.ask), MarketDataType::Trade(t) => format!("Trade(price={})", t.price), } ); first_few += 1; } if count % 100 == 0 { eprintln!("Loaded {} data points so far...", count); } let event = BacktestEvent::market_data(update.timestamp, update); self.event_queue.write().push(event); } eprintln!("=== load_market_data COMPLETE ==="); eprintln!("Total data points loaded: {}", count); Ok(()) } async fn process_event(&mut self, event: BacktestEvent) -> Result<(), String> { match event.event_type { EventType::MarketData(data) => { self.process_market_data(data).await?; } EventType::OrderSubmitted(order) => { self.process_order_submission(order).await?; } EventType::OrderFilled(_fill) => { // Fills are already processed when orders are executed // This event is just for recording // Note: We now record fills in process_fill with symbol info } EventType::OrderCancelled(order_id) => { self.process_order_cancellation(&order_id)?; } EventType::TimeUpdate(time) => { self.advance_time(time); } } Ok(()) } async fn process_market_data(&mut self, data: MarketUpdate) -> Result<(), String> { static mut MARKET_DATA_COUNT: usize = 0; unsafe { MARKET_DATA_COUNT += 1; if MARKET_DATA_COUNT <= 3 || MARKET_DATA_COUNT % 100 == 0 { eprintln!("process_market_data #{}: symbol={}, time={}", MARKET_DATA_COUNT, data.symbol, data.timestamp); } } // Update price tracking match &data.data { MarketDataType::Bar(bar) => { self.last_prices.insert(data.symbol.clone(), bar.close); } MarketDataType::Quote(quote) => { // Use mid price for quotes let mid_price = (quote.bid + quote.ask) / 2.0; self.last_prices.insert(data.symbol.clone(), mid_price); } MarketDataType::Trade(trade) => { self.last_prices.insert(data.symbol.clone(), trade.price); } } // Convert to simpler MarketData for strategies let market_data = self.convert_to_market_data(&data); // Send to strategies let mut all_signals = Vec::new(); { let mut strategies = self.strategies.write(); for (i, strategy) in strategies.iter_mut().enumerate() { let signals = strategy.on_market_data(&market_data); if !signals.is_empty() { eprintln!("Strategy {} generated {} signals!", i, signals.len()); } all_signals.extend(signals); } } // Process signals for signal in all_signals { eprintln!("Processing signal: {:?}", signal); self.process_signal(signal).await?; } // Check pending orders for fills self.check_pending_orders(&data).await?; Ok(()) } fn convert_to_market_data(&self, update: &MarketUpdate) -> MarketUpdate { // MarketData is a type alias for MarketUpdate update.clone() } async fn process_signal(&mut self, signal: Signal) -> Result<(), String> { // Only process strong signals if signal.strength.abs() < 0.7 { return Ok(()); } // Convert signal to order let order = self.signal_to_order(signal)?; // Submit order self.process_order_submission(order).await } fn signal_to_order(&self, signal: Signal) -> Result { let quantity = signal.quantity.unwrap_or_else(|| { // Calculate position size based on portfolio self.calculate_position_size(&signal.symbol, signal.strength) }); let side = match signal.signal_type { SignalType::Buy => Side::Buy, SignalType::Sell => Side::Sell, SignalType::Close => { // Determine side based on current position let position = self.position_tracker.get_position(&signal.symbol); if position.as_ref().map(|p| p.quantity > 0.0).unwrap_or(false) { Side::Sell } else { Side::Buy } } }; Ok(crate::Order { id: format!("order_{}", uuid::Uuid::new_v4()), symbol: signal.symbol, side, quantity, order_type: crate::OrderType::Market, time_in_force: crate::TimeInForce::Day, }) } async fn process_order_submission(&mut self, order: Order) -> Result<(), String> { // Risk checks // Get current position for the symbol let current_position = self.position_tracker .get_position(&order.symbol) .map(|p| p.quantity); let risk_check = self.risk_engine.check_order(&order, current_position); if !risk_check.passed { return Err(format!("Risk check failed: {:?}", risk_check.violations)); } // Add to pending orders self.state.write().add_pending_order(order.clone()); // For market orders in backtesting, fill immediately if matches!(order.order_type, crate::OrderType::Market) { self.check_order_fill(&order).await?; } Ok(()) } async fn check_pending_orders(&mut self, market_data: &MarketUpdate) -> Result<(), String> { let orders_to_check: Vec = { let state = self.state.read(); state.pending_orders.values() .filter(|o| o.symbol == market_data.symbol) .cloned() .collect() }; for order in orders_to_check { self.check_order_fill(&order).await?; } Ok(()) } async fn check_order_fill(&mut self, order: &Order) -> Result<(), String> { // Get current market price let base_price = self.last_prices.get(&order.symbol) .copied() .ok_or_else(|| format!("No price available for symbol: {}", order.symbol))?; // Apply slippage let fill_price = match order.side { crate::Side::Buy => base_price * (1.0 + self.config.slippage), crate::Side::Sell => base_price * (1.0 - self.config.slippage), }; // Create fill let fill = crate::Fill { timestamp: self.time_provider.now(), price: fill_price, quantity: order.quantity, commission: order.quantity * fill_price * self.config.commission, }; // Process the fill self.process_fill(&order, fill).await } async fn process_fill(&mut self, order: &crate::Order, fill: crate::Fill) -> Result<(), String> { // Remove from pending orders self.state.write().remove_pending_order(&order.id); // Update positions let update = self.position_tracker.process_fill( &order.symbol, &fill, order.side, ); // Record the fill with symbol and side information self.state.write().record_fill(order.symbol.clone(), order.side, fill.clone()); // Track trades self.trade_tracker.process_fill(&order.symbol, order.side, &fill); // Update cash let cash_change = match order.side { crate::Side::Buy => -(fill.quantity * fill.price + fill.commission), crate::Side::Sell => fill.quantity * fill.price - fill.commission, }; self.state.write().cash += cash_change; // Notify strategies { let mut strategies = self.strategies.write(); for strategy in strategies.iter_mut() { strategy.on_fill(&order.symbol, fill.quantity, fill.price, &format!("{:?}", order.side)); } } eprintln!("Fill processed: {} {} @ {} (side: {:?})", fill.quantity, order.symbol, fill.price, order.side); eprintln!("Current position after fill: {}", self.position_tracker.get_position(&order.symbol) .map(|p| p.quantity) .unwrap_or(0.0)); // Update metrics self.total_trades += 1; if update.resulting_position.realized_pnl > 0.0 { self.profitable_trades += 1; } self.total_pnl = update.resulting_position.realized_pnl; Ok(()) } fn process_order_cancellation(&mut self, order_id: &str) -> Result<(), String> { self.state.write().remove_pending_order(order_id); Ok(()) } fn advance_time(&mut self, time: DateTime) { if let Some(simulated_time) = self.time_provider.as_any() .downcast_ref::() { simulated_time.advance_to(time); } self.state.write().current_time = time; } fn update_portfolio_value(&mut self) { let positions = self.position_tracker.get_all_positions(); let mut portfolio_value = self.state.read().cash; for position in positions { // Use last known price for the symbol let price = self.last_prices.get(&position.symbol).copied().unwrap_or(position.average_price); let market_value = position.quantity * price; portfolio_value += market_value; } self.state.write().update_portfolio_value(portfolio_value); } fn calculate_position_size(&self, symbol: &str, signal_strength: f64) -> f64 { let portfolio_value = self.state.read().portfolio_value; let allocation = 0.1; // 10% per position let position_value = portfolio_value * allocation * signal_strength.abs(); let price = self.last_prices.get(symbol).copied().unwrap_or(100.0); (position_value / price).floor() } fn get_next_event_time(&self) -> Option> { // Get the timestamp of the next event in the queue self.event_queue.read() .peek_next() .map(|event| event.timestamp) } fn generate_results(&self) -> BacktestResult { let state = self.state.read(); let (realized_pnl, unrealized_pnl) = self.position_tracker.get_total_pnl(); let total_pnl = realized_pnl + unrealized_pnl; let total_return = (total_pnl / self.config.initial_capital) * 100.0; // Get completed trades from trade tracker for metrics let completed_trades = self.trade_tracker.get_completed_trades(); // Calculate metrics from completed trades let completed_trade_count = completed_trades.len(); let profitable_trades = completed_trades.iter().filter(|t| t.pnl > 0.0).count(); let total_wins: f64 = completed_trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).sum(); let total_losses: f64 = completed_trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl.abs()).sum(); let avg_win = if profitable_trades > 0 { total_wins / profitable_trades as f64 } else { 0.0 }; let avg_loss = if completed_trade_count > profitable_trades { total_losses / (completed_trade_count - profitable_trades) as f64 } else { 0.0 }; let profit_factor = if total_losses > 0.0 { total_wins / total_losses } else { 0.0 }; // For the API, return all fills (not just completed trades) // This shows all trading activity let all_fills = state.completed_trades.clone(); let total_trades = all_fills.len(); BacktestResult { config: self.config.clone(), metrics: super::BacktestMetrics { total_return, total_trades, profitable_trades, win_rate: if completed_trade_count > 0 { (profitable_trades as f64 / completed_trade_count as f64) * 100.0 } else { 0.0 }, profit_factor, sharpe_ratio: 0.0, // TODO: Calculate properly max_drawdown: 0.0, // TODO: Calculate properly total_pnl, avg_win, avg_loss, }, equity_curve: state.equity_curve.clone(), trades: all_fills, final_positions: self.position_tracker.get_all_positions() .into_iter() .map(|p| (p.symbol.clone(), p)) .collect(), } } } // Add uuid dependency use uuid::Uuid;