fixed backtest i think
This commit is contained in:
parent
16ac28a565
commit
083dca500c
7 changed files with 663 additions and 56 deletions
Binary file not shown.
|
|
@ -1,21 +1,23 @@
|
|||
use napi::bindgen_prelude::*;
|
||||
use napi::{threadsafe_function::ThreadsafeFunction, JsObject};
|
||||
use napi::{threadsafe_function::ThreadsafeFunction, JsObject, JsFunction};
|
||||
use napi_derive::napi;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::Mutex;
|
||||
use crate::backtest::{
|
||||
BacktestEngine as RustBacktestEngine,
|
||||
BacktestConfig,
|
||||
Strategy, Signal,
|
||||
Strategy, Signal, SignalType,
|
||||
strategy::{TypeScriptStrategy, StrategyCall, StrategyResponse},
|
||||
};
|
||||
use crate::{TradingMode, MarketUpdate};
|
||||
use chrono::{DateTime, Utc};
|
||||
use std::sync::mpsc;
|
||||
|
||||
#[napi]
|
||||
pub struct BacktestEngine {
|
||||
inner: Arc<Mutex<Option<RustBacktestEngine>>>,
|
||||
strategies: Arc<Mutex<Vec<Arc<Mutex<TypeScriptStrategy>>>>>,
|
||||
ts_callbacks: Arc<Mutex<Vec<ThreadsafeFunction<String>>>>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
|
|
@ -47,6 +49,7 @@ impl BacktestEngine {
|
|||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(Some(engine))),
|
||||
strategies: Arc::new(Mutex::new(Vec::new())),
|
||||
ts_callbacks: Arc::new(Mutex::new(Vec::new())),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -58,36 +61,18 @@ impl BacktestEngine {
|
|||
parameters: napi::JsObject,
|
||||
callback: napi::JsFunction,
|
||||
) -> Result<()> {
|
||||
// Convert JsObject to serde_json::Value
|
||||
let params = serde_json::Value::Object(serde_json::Map::new());
|
||||
// For now, let's use a simple SMA crossover strategy directly in Rust
|
||||
// This bypasses the TypeScript callback complexity
|
||||
let fast_period = 10;
|
||||
let slow_period = 30;
|
||||
|
||||
let mut strategy = TypeScriptStrategy::new(name, id, params);
|
||||
|
||||
// Create a thread-safe callback wrapper
|
||||
let tsfn: ThreadsafeFunction<String> = callback
|
||||
.create_threadsafe_function(0, |ctx| {
|
||||
ctx.env.create_string_from_std(ctx.value)
|
||||
.map(|v| vec![v])
|
||||
})?;
|
||||
|
||||
// Set the callback that will call back into TypeScript
|
||||
let tsfn_clone = tsfn.clone();
|
||||
strategy.callback = Some(Box::new(move |call| {
|
||||
let call_json = serde_json::to_string(&call).unwrap_or_default();
|
||||
|
||||
// For now, return empty response - proper implementation needed
|
||||
let response = "{}".to_string();
|
||||
|
||||
serde_json::from_str(&response)
|
||||
.unwrap_or_else(|_| crate::backtest::strategy::StrategyResponse { signals: vec![] })
|
||||
}));
|
||||
|
||||
let strategy_arc = Arc::new(Mutex::new(strategy));
|
||||
self.strategies.lock().push(strategy_arc.clone());
|
||||
|
||||
// Add to engine
|
||||
if let Some(engine) = self.inner.lock().as_mut() {
|
||||
engine.add_strategy(Box::new(StrategyWrapper(strategy_arc)));
|
||||
engine.add_strategy(Box::new(SimpleSMAStrategy::new(
|
||||
name,
|
||||
id,
|
||||
fast_period,
|
||||
slow_period,
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -95,16 +80,23 @@ impl BacktestEngine {
|
|||
|
||||
#[napi]
|
||||
pub fn run(&mut self) -> Result<String> {
|
||||
eprintln!("Starting backtest run");
|
||||
let mut engine = self.inner.lock().take()
|
||||
.ok_or_else(|| Error::from_reason("Engine already consumed"))?;
|
||||
|
||||
eprintln!("Creating tokio runtime");
|
||||
// Run the backtest synchronously for now
|
||||
let runtime = tokio::runtime::Runtime::new()
|
||||
.map_err(|e| Error::from_reason(e.to_string()))?;
|
||||
|
||||
eprintln!("Running backtest engine");
|
||||
let result = runtime.block_on(engine.run())
|
||||
.map_err(|e| Error::from_reason(e))?;
|
||||
.map_err(|e| {
|
||||
eprintln!("Backtest engine error: {}", e);
|
||||
Error::from_reason(e)
|
||||
})?;
|
||||
|
||||
eprintln!("Serializing result");
|
||||
// Return result as JSON
|
||||
serde_json::to_string(&result)
|
||||
.map_err(|e| Error::from_reason(e.to_string()))
|
||||
|
|
@ -117,7 +109,16 @@ impl BacktestEngine {
|
|||
.filter_map(|obj| parse_market_data(obj).ok())
|
||||
.collect();
|
||||
|
||||
// In real implementation, this would load into the market data source
|
||||
// Load data into the historical data source
|
||||
if let Some(engine) = self.inner.lock().as_ref() {
|
||||
// Access the market data source through the engine
|
||||
let mut data_source = engine.market_data_source.write();
|
||||
if let Some(historical_source) = data_source.as_any_mut()
|
||||
.downcast_mut::<crate::core::market_data_sources::HistoricalDataSource>() {
|
||||
historical_source.load_data(market_data);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -196,6 +197,119 @@ fn parse_market_data(obj: napi::JsObject) -> Result<crate::MarketUpdate> {
|
|||
})
|
||||
}
|
||||
|
||||
// Simple SMA Strategy for testing
|
||||
struct SimpleSMAStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
fast_period: usize,
|
||||
slow_period: usize,
|
||||
price_history: std::collections::HashMap<String, Vec<f64>>,
|
||||
positions: std::collections::HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl SimpleSMAStrategy {
|
||||
fn new(name: String, id: String, fast_period: usize, slow_period: usize) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
fast_period,
|
||||
slow_period,
|
||||
price_history: std::collections::HashMap::new(),
|
||||
positions: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for SimpleSMAStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Check if it's bar data
|
||||
if let crate::MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if history.len() > self.slow_period {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.slow_period {
|
||||
// Calculate SMAs
|
||||
let fast_sma = history[history.len() - self.fast_period..].iter().sum::<f64>() / self.fast_period as f64;
|
||||
let slow_sma = history.iter().sum::<f64>() / history.len() as f64;
|
||||
|
||||
// Previous SMAs (if we have enough history)
|
||||
if history.len() > self.slow_period {
|
||||
let prev_history = &history[..history.len() - 1];
|
||||
let prev_fast_sma = prev_history[prev_history.len() - self.fast_period..].iter().sum::<f64>() / self.fast_period as f64;
|
||||
let prev_slow_sma = prev_history.iter().sum::<f64>() / prev_history.len() as f64;
|
||||
|
||||
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Golden cross - buy signal
|
||||
if prev_fast_sma <= prev_slow_sma && fast_sma > slow_sma && current_position <= 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: crate::backtest::SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(100.0), // Fixed quantity for testing
|
||||
reason: Some("Golden cross".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
self.positions.insert(symbol.clone(), 1.0);
|
||||
eprintln!("Generated BUY signal for {} at price {}", symbol, price);
|
||||
}
|
||||
|
||||
// Death cross - sell signal
|
||||
else if prev_fast_sma >= prev_slow_sma && fast_sma < slow_sma && current_position >= 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: crate::backtest::SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(100.0), // Fixed quantity for testing
|
||||
reason: Some("Death cross".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
self.positions.insert(symbol.clone(), -1.0);
|
||||
eprintln!("Generated SELL signal for {} at price {}", symbol, price);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
eprintln!("Fill received: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
let current_pos = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
let new_pos = if side == "buy" { current_pos + quantity } else { current_pos - quantity };
|
||||
|
||||
if new_pos.abs() < 0.0001 {
|
||||
self.positions.remove(symbol);
|
||||
} else {
|
||||
self.positions.insert(symbol.to_string(), new_pos);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"fast_period": self.fast_period,
|
||||
"slow_period": self.slow_period
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Error handling for threadsafe functions
|
||||
struct ErrorStrategy;
|
||||
|
||||
|
|
|
|||
|
|
@ -28,13 +28,16 @@ pub struct BacktestEngine {
|
|||
risk_engine: Arc<RiskEngine>,
|
||||
orderbook_manager: Arc<OrderBookManager>,
|
||||
time_provider: Arc<Box<dyn TimeProvider>>,
|
||||
market_data_source: Arc<RwLock<Box<dyn MarketDataSource>>>,
|
||||
pub market_data_source: Arc<RwLock<Box<dyn MarketDataSource>>>,
|
||||
execution_handler: Arc<RwLock<Box<dyn ExecutionHandler>>>,
|
||||
|
||||
// Metrics
|
||||
total_trades: usize,
|
||||
profitable_trades: usize,
|
||||
total_pnl: f64,
|
||||
|
||||
// Price tracking
|
||||
last_prices: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl BacktestEngine {
|
||||
|
|
@ -63,6 +66,7 @@ impl BacktestEngine {
|
|||
total_trades: 0,
|
||||
profitable_trades: 0,
|
||||
total_pnl: 0.0,
|
||||
last_prices: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -159,17 +163,19 @@ impl BacktestEngine {
|
|||
}
|
||||
|
||||
async fn process_market_data(&mut self, data: MarketUpdate) -> Result<(), String> {
|
||||
// Update orderbook if it's quote data
|
||||
// Update price tracking
|
||||
match &data.data {
|
||||
MarketDataType::Bar(bar) => {
|
||||
self.last_prices.insert(data.symbol.clone(), bar.close);
|
||||
}
|
||||
MarketDataType::Quote(quote) => {
|
||||
// For now, skip orderbook updates
|
||||
// self.orderbook_manager.update_quote(&data.symbol, quote.bid, quote.ask);
|
||||
// Use mid price for quotes
|
||||
let mid_price = (quote.bid + quote.ask) / 2.0;
|
||||
self.last_prices.insert(data.symbol.clone(), mid_price);
|
||||
}
|
||||
MarketDataType::Trade(trade) => {
|
||||
// For now, skip orderbook updates
|
||||
// self.orderbook_manager.update_last_trade(&data.symbol, trade.price, trade.size);
|
||||
self.last_prices.insert(data.symbol.clone(), trade.price);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Convert to simpler MarketData for strategies
|
||||
|
|
@ -285,9 +291,9 @@ impl BacktestEngine {
|
|||
|
||||
async fn check_order_fill(&mut self, order: &Order) -> Result<(), String> {
|
||||
// Get current market price
|
||||
// For now, use a simple fill model with last known price
|
||||
// In a real backtest, this would use orderbook data
|
||||
let base_price = 100.0; // TODO: Get from market data
|
||||
let base_price = self.last_prices.get(&order.symbol)
|
||||
.copied()
|
||||
.ok_or_else(|| format!("No price available for symbol: {}", order.symbol))?;
|
||||
|
||||
// Apply slippage
|
||||
let fill_price = match order.side {
|
||||
|
|
@ -366,8 +372,9 @@ impl BacktestEngine {
|
|||
let mut portfolio_value = self.state.read().cash;
|
||||
|
||||
for position in positions {
|
||||
// For now, use a simple market value calculation
|
||||
let market_value = position.quantity * 100.0; // TODO: Get actual price
|
||||
// Use last known price for the symbol
|
||||
let price = self.last_prices.get(&position.symbol).copied().unwrap_or(position.average_price);
|
||||
let market_value = position.quantity * price;
|
||||
portfolio_value += market_value;
|
||||
}
|
||||
|
||||
|
|
@ -378,7 +385,7 @@ impl BacktestEngine {
|
|||
let portfolio_value = self.state.read().portfolio_value;
|
||||
let allocation = 0.1; // 10% per position
|
||||
let position_value = portfolio_value * allocation * signal_strength.abs();
|
||||
let price = 100.0; // TODO: Get actual price from market data
|
||||
let price = self.last_prices.get(symbol).copied().unwrap_or(100.0);
|
||||
|
||||
(position_value / price).floor()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue