533 lines
No EOL
20 KiB
Rust
533 lines
No EOL
20 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use parking_lot::RwLock;
|
|
use chrono::{DateTime, Utc};
|
|
use serde::{Serialize, Deserialize};
|
|
|
|
use crate::{
|
|
TradingMode, MarketDataSource, ExecutionHandler, TimeProvider,
|
|
MarketUpdate, MarketDataType, Order, Fill, Side,
|
|
positions::PositionTracker,
|
|
risk::RiskEngine,
|
|
orderbook::OrderBookManager,
|
|
};
|
|
|
|
use super::{
|
|
BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType,
|
|
Strategy, Signal, SignalType, BacktestResult, TradeTracker,
|
|
};
|
|
|
|
pub struct BacktestEngine {
|
|
config: BacktestConfig,
|
|
state: Arc<RwLock<BacktestState>>,
|
|
event_queue: Arc<RwLock<EventQueue>>,
|
|
strategies: Arc<RwLock<Vec<Box<dyn Strategy>>>>,
|
|
|
|
// Core components
|
|
position_tracker: Arc<PositionTracker>,
|
|
risk_engine: Arc<RiskEngine>,
|
|
orderbook_manager: Arc<OrderBookManager>,
|
|
time_provider: Arc<Box<dyn TimeProvider>>,
|
|
pub market_data_source: Arc<RwLock<Box<dyn MarketDataSource>>>,
|
|
execution_handler: Arc<RwLock<Box<dyn ExecutionHandler>>>,
|
|
|
|
// Metrics
|
|
total_trades: usize,
|
|
profitable_trades: usize,
|
|
total_pnl: f64,
|
|
|
|
// Price tracking
|
|
last_prices: HashMap<String, f64>,
|
|
|
|
// Trade tracking
|
|
trade_tracker: TradeTracker,
|
|
}
|
|
|
|
impl BacktestEngine {
|
|
pub fn new(
|
|
config: BacktestConfig,
|
|
mode: TradingMode,
|
|
time_provider: Box<dyn TimeProvider>,
|
|
market_data_source: Box<dyn MarketDataSource>,
|
|
execution_handler: Box<dyn ExecutionHandler>,
|
|
) -> Self {
|
|
let state = Arc::new(RwLock::new(
|
|
BacktestState::new(config.initial_capital, config.start_time)
|
|
));
|
|
|
|
Self {
|
|
config,
|
|
state,
|
|
event_queue: Arc::new(RwLock::new(EventQueue::new())),
|
|
strategies: Arc::new(RwLock::new(Vec::new())),
|
|
position_tracker: Arc::new(PositionTracker::new()),
|
|
risk_engine: Arc::new(RiskEngine::new()),
|
|
orderbook_manager: Arc::new(OrderBookManager::new()),
|
|
time_provider: Arc::new(time_provider),
|
|
market_data_source: Arc::new(RwLock::new(market_data_source)),
|
|
execution_handler: Arc::new(RwLock::new(execution_handler)),
|
|
total_trades: 0,
|
|
profitable_trades: 0,
|
|
total_pnl: 0.0,
|
|
last_prices: HashMap::new(),
|
|
trade_tracker: TradeTracker::new(),
|
|
}
|
|
}
|
|
|
|
pub fn add_strategy(&mut self, strategy: Box<dyn Strategy>) {
|
|
self.strategies.write().push(strategy);
|
|
}
|
|
|
|
pub async fn run(&mut self) -> Result<BacktestResult, String> {
|
|
eprintln!("=== BacktestEngine::run() START ===");
|
|
eprintln!("Config: start={}, end={}, symbols={:?}",
|
|
self.config.start_time, self.config.end_time, self.config.symbols);
|
|
eprintln!("Number of strategies loaded: {}", self.strategies.read().len());
|
|
|
|
// Initialize start time
|
|
if let Some(simulated_time) = self.time_provider.as_any()
|
|
.downcast_ref::<crate::core::time_providers::SimulatedTime>()
|
|
{
|
|
simulated_time.advance_to(self.config.start_time);
|
|
eprintln!("Time initialized to: {}", self.config.start_time);
|
|
}
|
|
|
|
// Load market data
|
|
eprintln!("Loading market data from data source...");
|
|
self.load_market_data().await?;
|
|
|
|
let queue_len = self.event_queue.read().len();
|
|
eprintln!("Event queue length after loading: {}", queue_len);
|
|
|
|
if queue_len == 0 {
|
|
eprintln!("WARNING: No events loaded! Check data source.");
|
|
}
|
|
|
|
// Main event loop
|
|
let mut iteration = 0;
|
|
while !self.event_queue.read().is_empty() {
|
|
iteration += 1;
|
|
if iteration <= 5 || iteration % 100 == 0 {
|
|
eprintln!("Processing iteration {} at time {}", iteration, self.time_provider.now());
|
|
}
|
|
|
|
// Get the next event's timestamp
|
|
let next_event_time = self.event_queue.read()
|
|
.peek_next()
|
|
.map(|e| e.timestamp);
|
|
|
|
if let Some(event_time) = next_event_time {
|
|
// Advance time to the next event
|
|
self.advance_time(event_time);
|
|
|
|
// Get all events at this timestamp
|
|
let current_time = self.time_provider.now();
|
|
let events = self.event_queue.write().pop_until(current_time);
|
|
|
|
for event in events {
|
|
self.process_event(event).await?;
|
|
}
|
|
|
|
// Update portfolio value
|
|
self.update_portfolio_value();
|
|
} else {
|
|
// No more events
|
|
break;
|
|
}
|
|
}
|
|
|
|
eprintln!("Backtest complete. Total trades: {}", self.total_trades);
|
|
|
|
// Generate results
|
|
Ok(self.generate_results())
|
|
}
|
|
|
|
async fn load_market_data(&mut self) -> Result<(), String> {
|
|
eprintln!("=== load_market_data START ===");
|
|
let mut data_source = self.market_data_source.write();
|
|
|
|
// Check if it's a HistoricalDataSource
|
|
if let Some(historical) = data_source.as_any()
|
|
.downcast_ref::<crate::core::market_data_sources::HistoricalDataSource>() {
|
|
eprintln!("Data source is HistoricalDataSource");
|
|
eprintln!("Historical data points available: {}", historical.data_len());
|
|
} else {
|
|
eprintln!("WARNING: Data source is NOT HistoricalDataSource!");
|
|
}
|
|
|
|
eprintln!("Seeking to start time: {}", self.config.start_time);
|
|
data_source.seek_to_time(self.config.start_time)?;
|
|
|
|
let mut count = 0;
|
|
let mut first_few = 0;
|
|
|
|
// Load all data into event queue
|
|
while let Some(update) = data_source.get_next_update().await {
|
|
if update.timestamp > self.config.end_time {
|
|
eprintln!("Reached end time at {} data points", count);
|
|
break;
|
|
}
|
|
|
|
count += 1;
|
|
|
|
// Log first few data points
|
|
if first_few < 3 {
|
|
eprintln!("Data point {}: symbol={}, time={}, type={:?}",
|
|
count, update.symbol, update.timestamp,
|
|
match &update.data {
|
|
MarketDataType::Bar(b) => format!("Bar(close={})", b.close),
|
|
MarketDataType::Quote(q) => format!("Quote(bid={}, ask={})", q.bid, q.ask),
|
|
MarketDataType::Trade(t) => format!("Trade(price={})", t.price),
|
|
}
|
|
);
|
|
first_few += 1;
|
|
}
|
|
|
|
if count % 100 == 0 {
|
|
eprintln!("Loaded {} data points so far...", count);
|
|
}
|
|
|
|
let event = BacktestEvent::market_data(update.timestamp, update);
|
|
self.event_queue.write().push(event);
|
|
}
|
|
|
|
eprintln!("=== load_market_data COMPLETE ===");
|
|
eprintln!("Total data points loaded: {}", count);
|
|
Ok(())
|
|
}
|
|
|
|
async fn process_event(&mut self, event: BacktestEvent) -> Result<(), String> {
|
|
match event.event_type {
|
|
EventType::MarketData(data) => {
|
|
self.process_market_data(data).await?;
|
|
}
|
|
EventType::OrderSubmitted(order) => {
|
|
self.process_order_submission(order).await?;
|
|
}
|
|
EventType::OrderFilled(_fill) => {
|
|
// Fills are already processed when orders are executed
|
|
// This event is just for recording
|
|
// Note: We now record fills in process_fill with symbol info
|
|
}
|
|
EventType::OrderCancelled(order_id) => {
|
|
self.process_order_cancellation(&order_id)?;
|
|
}
|
|
EventType::TimeUpdate(time) => {
|
|
self.advance_time(time);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn process_market_data(&mut self, data: MarketUpdate) -> Result<(), String> {
|
|
static mut MARKET_DATA_COUNT: usize = 0;
|
|
unsafe {
|
|
MARKET_DATA_COUNT += 1;
|
|
if MARKET_DATA_COUNT <= 3 || MARKET_DATA_COUNT % 100 == 0 {
|
|
eprintln!("process_market_data #{}: symbol={}, time={}",
|
|
MARKET_DATA_COUNT, data.symbol, data.timestamp);
|
|
}
|
|
}
|
|
|
|
// Update price tracking
|
|
match &data.data {
|
|
MarketDataType::Bar(bar) => {
|
|
self.last_prices.insert(data.symbol.clone(), bar.close);
|
|
}
|
|
MarketDataType::Quote(quote) => {
|
|
// Use mid price for quotes
|
|
let mid_price = (quote.bid + quote.ask) / 2.0;
|
|
self.last_prices.insert(data.symbol.clone(), mid_price);
|
|
}
|
|
MarketDataType::Trade(trade) => {
|
|
self.last_prices.insert(data.symbol.clone(), trade.price);
|
|
}
|
|
}
|
|
|
|
// Convert to simpler MarketData for strategies
|
|
let market_data = self.convert_to_market_data(&data);
|
|
|
|
// Send to strategies
|
|
let mut all_signals = Vec::new();
|
|
{
|
|
let mut strategies = self.strategies.write();
|
|
for (i, strategy) in strategies.iter_mut().enumerate() {
|
|
let signals = strategy.on_market_data(&market_data);
|
|
if !signals.is_empty() {
|
|
eprintln!("Strategy {} generated {} signals!", i, signals.len());
|
|
}
|
|
all_signals.extend(signals);
|
|
}
|
|
}
|
|
|
|
// Process signals
|
|
for signal in all_signals {
|
|
eprintln!("Processing signal: {:?}", signal);
|
|
self.process_signal(signal).await?;
|
|
}
|
|
|
|
// Check pending orders for fills
|
|
self.check_pending_orders(&data).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn convert_to_market_data(&self, update: &MarketUpdate) -> MarketUpdate {
|
|
// MarketData is a type alias for MarketUpdate
|
|
update.clone()
|
|
}
|
|
|
|
async fn process_signal(&mut self, signal: Signal) -> Result<(), String> {
|
|
// Only process strong signals
|
|
if signal.strength.abs() < 0.7 {
|
|
return Ok(());
|
|
}
|
|
|
|
// Convert signal to order
|
|
let order = self.signal_to_order(signal)?;
|
|
|
|
// Submit order
|
|
self.process_order_submission(order).await
|
|
}
|
|
|
|
fn signal_to_order(&self, signal: Signal) -> Result<Order, String> {
|
|
let quantity = signal.quantity.unwrap_or_else(|| {
|
|
// Calculate position size based on portfolio
|
|
self.calculate_position_size(&signal.symbol, signal.strength)
|
|
});
|
|
|
|
let side = match signal.signal_type {
|
|
SignalType::Buy => Side::Buy,
|
|
SignalType::Sell => Side::Sell,
|
|
SignalType::Close => {
|
|
// Determine side based on current position
|
|
let position = self.position_tracker.get_position(&signal.symbol);
|
|
if position.as_ref().map(|p| p.quantity > 0.0).unwrap_or(false) {
|
|
Side::Sell
|
|
} else {
|
|
Side::Buy
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(crate::Order {
|
|
id: format!("order_{}", uuid::Uuid::new_v4()),
|
|
symbol: signal.symbol,
|
|
side,
|
|
quantity,
|
|
order_type: crate::OrderType::Market,
|
|
time_in_force: crate::TimeInForce::Day,
|
|
})
|
|
}
|
|
|
|
async fn process_order_submission(&mut self, order: Order) -> Result<(), String> {
|
|
// Risk checks
|
|
// Get current position for the symbol
|
|
let current_position = self.position_tracker
|
|
.get_position(&order.symbol)
|
|
.map(|p| p.quantity);
|
|
|
|
let risk_check = self.risk_engine.check_order(&order, current_position);
|
|
if !risk_check.passed {
|
|
return Err(format!("Risk check failed: {:?}", risk_check.violations));
|
|
}
|
|
|
|
// Add to pending orders
|
|
self.state.write().add_pending_order(order.clone());
|
|
|
|
// For market orders in backtesting, fill immediately
|
|
if matches!(order.order_type, crate::OrderType::Market) {
|
|
self.check_order_fill(&order).await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn check_pending_orders(&mut self, market_data: &MarketUpdate) -> Result<(), String> {
|
|
let orders_to_check: Vec<Order> = {
|
|
let state = self.state.read();
|
|
state.pending_orders.values()
|
|
.filter(|o| o.symbol == market_data.symbol)
|
|
.cloned()
|
|
.collect()
|
|
};
|
|
|
|
for order in orders_to_check {
|
|
self.check_order_fill(&order).await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn check_order_fill(&mut self, order: &Order) -> Result<(), String> {
|
|
// Get current market price
|
|
let base_price = self.last_prices.get(&order.symbol)
|
|
.copied()
|
|
.ok_or_else(|| format!("No price available for symbol: {}", order.symbol))?;
|
|
|
|
// Apply slippage
|
|
let fill_price = match order.side {
|
|
crate::Side::Buy => base_price * (1.0 + self.config.slippage),
|
|
crate::Side::Sell => base_price * (1.0 - self.config.slippage),
|
|
};
|
|
|
|
// Create fill
|
|
let fill = crate::Fill {
|
|
timestamp: self.time_provider.now(),
|
|
price: fill_price,
|
|
quantity: order.quantity,
|
|
commission: order.quantity * fill_price * self.config.commission,
|
|
};
|
|
|
|
// Process the fill
|
|
self.process_fill(&order, fill).await
|
|
}
|
|
|
|
async fn process_fill(&mut self, order: &crate::Order, fill: crate::Fill) -> Result<(), String> {
|
|
// Remove from pending orders
|
|
self.state.write().remove_pending_order(&order.id);
|
|
|
|
// Update positions
|
|
let update = self.position_tracker.process_fill(
|
|
&order.symbol,
|
|
&fill,
|
|
order.side,
|
|
);
|
|
|
|
// Record the fill with symbol and side information
|
|
self.state.write().record_fill(order.symbol.clone(), order.side, fill.clone());
|
|
|
|
// Track trades
|
|
self.trade_tracker.process_fill(&order.symbol, order.side, &fill);
|
|
|
|
// Update cash
|
|
let cash_change = match order.side {
|
|
crate::Side::Buy => -(fill.quantity * fill.price + fill.commission),
|
|
crate::Side::Sell => fill.quantity * fill.price - fill.commission,
|
|
};
|
|
self.state.write().cash += cash_change;
|
|
|
|
// Notify strategies
|
|
{
|
|
let mut strategies = self.strategies.write();
|
|
for strategy in strategies.iter_mut() {
|
|
strategy.on_fill(&order.symbol, fill.quantity, fill.price,
|
|
&format!("{:?}", order.side));
|
|
}
|
|
}
|
|
|
|
eprintln!("Fill processed: {} {} @ {} (side: {:?})",
|
|
fill.quantity, order.symbol, fill.price, order.side);
|
|
eprintln!("Current position after fill: {}",
|
|
self.position_tracker.get_position(&order.symbol)
|
|
.map(|p| p.quantity)
|
|
.unwrap_or(0.0));
|
|
|
|
// Update metrics
|
|
self.total_trades += 1;
|
|
if update.resulting_position.realized_pnl > 0.0 {
|
|
self.profitable_trades += 1;
|
|
}
|
|
self.total_pnl = update.resulting_position.realized_pnl;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn process_order_cancellation(&mut self, order_id: &str) -> Result<(), String> {
|
|
self.state.write().remove_pending_order(order_id);
|
|
Ok(())
|
|
}
|
|
|
|
fn advance_time(&mut self, time: DateTime<Utc>) {
|
|
if let Some(simulated_time) = self.time_provider.as_any()
|
|
.downcast_ref::<crate::core::time_providers::SimulatedTime>()
|
|
{
|
|
simulated_time.advance_to(time);
|
|
}
|
|
self.state.write().current_time = time;
|
|
}
|
|
|
|
fn update_portfolio_value(&mut self) {
|
|
let positions = self.position_tracker.get_all_positions();
|
|
let mut portfolio_value = self.state.read().cash;
|
|
|
|
for position in positions {
|
|
// Use last known price for the symbol
|
|
let price = self.last_prices.get(&position.symbol).copied().unwrap_or(position.average_price);
|
|
let market_value = position.quantity * price;
|
|
portfolio_value += market_value;
|
|
}
|
|
|
|
self.state.write().update_portfolio_value(portfolio_value);
|
|
}
|
|
|
|
fn calculate_position_size(&self, symbol: &str, signal_strength: f64) -> f64 {
|
|
let portfolio_value = self.state.read().portfolio_value;
|
|
let allocation = 0.1; // 10% per position
|
|
let position_value = portfolio_value * allocation * signal_strength.abs();
|
|
let price = self.last_prices.get(symbol).copied().unwrap_or(100.0);
|
|
|
|
(position_value / price).floor()
|
|
}
|
|
|
|
fn get_next_event_time(&self) -> Option<DateTime<Utc>> {
|
|
// Get the timestamp of the next event in the queue
|
|
self.event_queue.read()
|
|
.peek_next()
|
|
.map(|event| event.timestamp)
|
|
}
|
|
|
|
fn generate_results(&self) -> BacktestResult {
|
|
let state = self.state.read();
|
|
let (realized_pnl, unrealized_pnl) = self.position_tracker.get_total_pnl();
|
|
let total_pnl = realized_pnl + unrealized_pnl;
|
|
let total_return = (total_pnl / self.config.initial_capital) * 100.0;
|
|
|
|
// Get completed trades from trade tracker for metrics
|
|
let completed_trades = self.trade_tracker.get_completed_trades();
|
|
|
|
// Calculate metrics from completed trades
|
|
let completed_trade_count = completed_trades.len();
|
|
let profitable_trades = completed_trades.iter().filter(|t| t.pnl > 0.0).count();
|
|
let total_wins: f64 = completed_trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).sum();
|
|
let total_losses: f64 = completed_trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl.abs()).sum();
|
|
let avg_win = if profitable_trades > 0 { total_wins / profitable_trades as f64 } else { 0.0 };
|
|
let avg_loss = if completed_trade_count > profitable_trades {
|
|
total_losses / (completed_trade_count - profitable_trades) as f64
|
|
} else { 0.0 };
|
|
let profit_factor = if total_losses > 0.0 { total_wins / total_losses } else { 0.0 };
|
|
|
|
// For the API, return all fills (not just completed trades)
|
|
// This shows all trading activity
|
|
let all_fills = state.completed_trades.clone();
|
|
let total_trades = all_fills.len();
|
|
|
|
BacktestResult {
|
|
config: self.config.clone(),
|
|
metrics: super::BacktestMetrics {
|
|
total_return,
|
|
total_trades,
|
|
profitable_trades,
|
|
win_rate: if completed_trade_count > 0 {
|
|
(profitable_trades as f64 / completed_trade_count as f64) * 100.0
|
|
} else { 0.0 },
|
|
profit_factor,
|
|
sharpe_ratio: 0.0, // TODO: Calculate properly
|
|
max_drawdown: 0.0, // TODO: Calculate properly
|
|
total_pnl,
|
|
avg_win,
|
|
avg_loss,
|
|
},
|
|
equity_curve: state.equity_curve.clone(),
|
|
trades: all_fills,
|
|
final_positions: self.position_tracker.get_all_positions()
|
|
.into_iter()
|
|
.map(|p| (p.symbol.clone(), p))
|
|
.collect(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add uuid dependency
|
|
use uuid::Uuid; |