added trade-tracking and example rust strats

This commit is contained in:
Boki 2025-07-03 22:55:23 -04:00
parent 0a4702d12a
commit 3a7557c8f4
15 changed files with 2108 additions and 29 deletions

Binary file not shown.

View file

@ -101,6 +101,58 @@ impl BacktestEngine {
slow_period,
)));
}
"mean_reversion" => {
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(20.0) as usize;
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
.unwrap_or(2.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::MeanReversionFixedStrategy::new(
name.clone(),
id,
lookback_period,
entry_threshold,
position_size,
)));
}
"momentum" => {
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(14.0) as usize;
let momentum_threshold: f64 = parameters.get_named_property::<f64>("momentumThreshold")
.unwrap_or(5.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::MomentumStrategy::new(
name.clone(),
id,
lookback_period,
momentum_threshold,
position_size,
)));
}
"pairs_trading" => {
let pair_a: String = parameters.get_named_property::<String>("pairA")?;
let pair_b: String = parameters.get_named_property::<String>("pairB")?;
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(20.0) as usize;
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
.unwrap_or(2.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::PairsTradingStrategy::new(
name.clone(),
id,
pair_a,
pair_b,
lookback_period,
entry_threshold,
position_size,
)));
}
_ => {
return Err(Error::from_reason(format!("Unknown strategy type: {}", strategy_type)));
}

View file

@ -14,7 +14,7 @@ use crate::{
use super::{
BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType,
Strategy, Signal, SignalType, BacktestResult,
Strategy, Signal, SignalType, BacktestResult, TradeTracker,
};
pub struct BacktestEngine {
@ -38,6 +38,9 @@ pub struct BacktestEngine {
// Price tracking
last_prices: HashMap<String, f64>,
// Trade tracking
trade_tracker: TradeTracker,
}
impl BacktestEngine {
@ -67,6 +70,7 @@ impl BacktestEngine {
profitable_trades: 0,
total_pnl: 0.0,
last_prices: HashMap::new(),
trade_tracker: TradeTracker::new(),
}
}
@ -394,6 +398,9 @@ impl BacktestEngine {
// Record the fill with symbol and side information
self.state.write().record_fill(order.symbol.clone(), order.side, fill.clone());
// Track trades
self.trade_tracker.process_fill(&order.symbol, order.side, &fill);
// Update cash
let cash_change = match order.side {
crate::Side::Buy => -(fill.quantity * fill.price + fill.commission),
@ -410,6 +417,13 @@ impl BacktestEngine {
}
}
eprintln!("Fill processed: {} {} @ {} (side: {:?})",
fill.quantity, order.symbol, fill.price, order.side);
eprintln!("Current position after fill: {}",
self.position_tracker.get_position(&order.symbol)
.map(|p| p.quantity)
.unwrap_or(0.0));
// Update metrics
self.total_trades += 1;
if update.resulting_position.realized_pnl > 0.0 {
@ -470,24 +484,43 @@ impl BacktestEngine {
let total_pnl = realized_pnl + unrealized_pnl;
let total_return = (total_pnl / self.config.initial_capital) * 100.0;
// Get completed trades from trade tracker for metrics
let completed_trades = self.trade_tracker.get_completed_trades();
// Calculate metrics from completed trades
let completed_trade_count = completed_trades.len();
let profitable_trades = completed_trades.iter().filter(|t| t.pnl > 0.0).count();
let total_wins: f64 = completed_trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).sum();
let total_losses: f64 = completed_trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl.abs()).sum();
let avg_win = if profitable_trades > 0 { total_wins / profitable_trades as f64 } else { 0.0 };
let avg_loss = if completed_trade_count > profitable_trades {
total_losses / (completed_trade_count - profitable_trades) as f64
} else { 0.0 };
let profit_factor = if total_losses > 0.0 { total_wins / total_losses } else { 0.0 };
// For the API, return all fills (not just completed trades)
// This shows all trading activity
let all_fills = state.completed_trades.clone();
let total_trades = all_fills.len();
BacktestResult {
config: self.config.clone(),
metrics: super::BacktestMetrics {
total_return,
total_trades: self.total_trades,
profitable_trades: self.profitable_trades,
win_rate: if self.total_trades > 0 {
(self.profitable_trades as f64 / self.total_trades as f64) * 100.0
total_trades,
profitable_trades,
win_rate: if completed_trade_count > 0 {
(profitable_trades as f64 / completed_trade_count as f64) * 100.0
} else { 0.0 },
profit_factor: 0.0, // TODO: Calculate properly
profit_factor,
sharpe_ratio: 0.0, // TODO: Calculate properly
max_drawdown: 0.0, // TODO: Calculate properly
total_pnl,
avg_win: 0.0, // TODO: Calculate properly
avg_loss: 0.0, // TODO: Calculate properly
avg_win,
avg_loss,
},
equity_curve: state.equity_curve.clone(),
trades: state.completed_trades.clone(),
trades: all_fills,
final_positions: self.position_tracker.get_all_positions()
.into_iter()
.map(|p| (p.symbol.clone(), p))

View file

@ -12,11 +12,13 @@ pub mod engine;
pub mod event;
pub mod strategy;
pub mod results;
pub mod trade_tracker;
pub use engine::BacktestEngine;
pub use event::{BacktestEvent, EventType};
pub use strategy::{Strategy, Signal, SignalType};
pub use results::{BacktestResult, BacktestMetrics};
pub use trade_tracker::{TradeTracker, CompletedTrade as TrackedTrade};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletedTrade {

View file

@ -0,0 +1,258 @@
use std::collections::{HashMap, VecDeque};
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use crate::{Fill, Side};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletedTrade {
pub id: String,
pub symbol: String,
pub entry_time: DateTime<Utc>,
pub exit_time: DateTime<Utc>,
pub entry_price: f64,
pub exit_price: f64,
pub quantity: f64,
pub side: Side, // Side of the opening trade
pub pnl: f64,
pub pnl_percent: f64,
pub commission: f64,
pub duration_seconds: i64,
}
#[derive(Debug, Clone)]
struct OpenPosition {
symbol: String,
side: Side,
quantity: f64,
entry_price: f64,
entry_time: DateTime<Utc>,
commission: f64,
}
/// Tracks fills and matches them into completed trades
pub struct TradeTracker {
open_positions: HashMap<String, VecDeque<OpenPosition>>,
completed_trades: Vec<CompletedTrade>,
trade_counter: u64,
}
impl TradeTracker {
pub fn new() -> Self {
Self {
open_positions: HashMap::new(),
completed_trades: Vec::new(),
trade_counter: 0,
}
}
pub fn process_fill(&mut self, symbol: &str, side: Side, fill: &Fill) {
let positions = self.open_positions.entry(symbol.to_string()).or_insert_with(VecDeque::new);
// Check if this fill closes existing positions
let mut remaining_quantity = fill.quantity;
let mut fills_to_remove = Vec::new();
for (idx, open_pos) in positions.iter_mut().enumerate() {
// Only match against opposite side positions
if open_pos.side == side {
continue;
}
if remaining_quantity <= 0.0 {
break;
}
let matched_quantity = remaining_quantity.min(open_pos.quantity);
// Calculate PnL
let (pnl, pnl_percent) = Self::calculate_pnl(
&open_pos,
fill.price,
matched_quantity,
fill.commission,
);
// Create completed trade
self.trade_counter += 1;
let completed_trade = CompletedTrade {
id: format!("trade-{}", self.trade_counter),
symbol: symbol.to_string(),
entry_time: open_pos.entry_time,
exit_time: fill.timestamp,
entry_price: open_pos.entry_price,
exit_price: fill.price,
quantity: matched_quantity,
side: open_pos.side.clone(),
pnl,
pnl_percent,
commission: open_pos.commission + (fill.commission * matched_quantity / fill.quantity),
duration_seconds: (fill.timestamp - open_pos.entry_time).num_seconds(),
};
self.completed_trades.push(completed_trade);
// Update open position
open_pos.quantity -= matched_quantity;
remaining_quantity -= matched_quantity;
if open_pos.quantity <= 0.0 {
fills_to_remove.push(idx);
}
}
// Remove fully closed positions
for idx in fills_to_remove.iter().rev() {
positions.remove(*idx);
}
// If there's remaining quantity, it opens a new position
if remaining_quantity > 0.0 {
let new_position = OpenPosition {
symbol: symbol.to_string(),
side,
quantity: remaining_quantity,
entry_price: fill.price,
entry_time: fill.timestamp,
commission: fill.commission * remaining_quantity / fill.quantity,
};
positions.push_back(new_position);
}
}
fn calculate_pnl(open_pos: &OpenPosition, exit_price: f64, quantity: f64, exit_commission: f64) -> (f64, f64) {
let entry_value = open_pos.entry_price * quantity;
let exit_value = exit_price * quantity;
let gross_pnl = match open_pos.side {
Side::Buy => exit_value - entry_value,
Side::Sell => entry_value - exit_value,
};
let commission = open_pos.commission * (quantity / open_pos.quantity) + exit_commission * (quantity / open_pos.quantity);
let net_pnl = gross_pnl - commission;
let pnl_percent = (net_pnl / entry_value) * 100.0;
(net_pnl, pnl_percent)
}
pub fn get_completed_trades(&self) -> &[CompletedTrade] {
&self.completed_trades
}
pub fn get_open_positions(&self) -> HashMap<String, Vec<(Side, f64, f64)>> {
let mut result = HashMap::new();
for (symbol, positions) in &self.open_positions {
let pos_info: Vec<(Side, f64, f64)> = positions
.iter()
.map(|p| (p.side.clone(), p.quantity, p.entry_price))
.collect();
if !pos_info.is_empty() {
result.insert(symbol.clone(), pos_info);
}
}
result
}
pub fn get_net_position(&self, symbol: &str) -> f64 {
let positions = match self.open_positions.get(symbol) {
Some(pos) => pos,
None => return 0.0,
};
let mut net = 0.0;
for pos in positions {
match pos.side {
Side::Buy => net += pos.quantity,
Side::Sell => net -= pos.quantity,
}
}
net
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simple_round_trip() {
let mut tracker = TradeTracker::new();
// Buy 100 shares at $50
let buy_fill = Fill {
timestamp: Utc::now(),
price: 50.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
// Sell 100 shares at $55
let sell_fill = Fill {
timestamp: Utc::now(),
price: 55.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill);
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 1);
let trade = &trades[0];
assert_eq!(trade.symbol, "AAPL");
assert_eq!(trade.quantity, 100.0);
assert_eq!(trade.entry_price, 50.0);
assert_eq!(trade.exit_price, 55.0);
assert_eq!(trade.pnl, 498.0); // (55-50)*100 - 2 commission
assert_eq!(trade.side, Side::Buy);
}
#[test]
fn test_partial_fills() {
let mut tracker = TradeTracker::new();
// Buy 100 shares at $50
let buy_fill = Fill {
timestamp: Utc::now(),
price: 50.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
// Sell 60 shares at $55
let sell_fill1 = Fill {
timestamp: Utc::now(),
price: 55.0,
quantity: 60.0,
commission: 0.6,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill1);
// Check we have one completed trade and remaining position
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 1);
assert_eq!(trades[0].quantity, 60.0);
assert_eq!(tracker.get_net_position("AAPL"), 40.0);
// Sell remaining 40 shares at $52
let sell_fill2 = Fill {
timestamp: Utc::now(),
price: 52.0,
quantity: 40.0,
commission: 0.4,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill2);
// Now we should have 2 completed trades and no position
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 2);
assert_eq!(tracker.get_net_position("AAPL"), 0.0);
}
}

View file

@ -8,6 +8,7 @@ pub mod api;
pub mod analytics;
pub mod indicators;
pub mod backtest;
pub mod strategies;
// Re-export commonly used types
pub use positions::{Position, PositionUpdate, TradeRecord, ClosedTrade};

View file

@ -0,0 +1,210 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Mean Reversion Strategy
///
/// This strategy identifies when a security's price deviates significantly from its
/// moving average and trades on the assumption that it will revert back to the mean.
///
/// Entry Signals:
/// - BUY when price falls below (MA - threshold * std_dev)
/// - SELL when price rises above (MA + threshold * std_dev)
///
/// Exit Signals:
/// - Exit long when price reaches MA
/// - Exit short when price reaches MA
pub struct MeanReversionStrategy {
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
exit_threshold: f64, // Typically 0 (at the mean)
position_size: f64,
// State
price_history: HashMap<String, Vec<f64>>,
positions: HashMap<String, f64>,
}
impl MeanReversionStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
entry_threshold,
exit_threshold: 0.0, // Exit at the mean
position_size,
price_history: HashMap::new(),
positions: HashMap::new(),
}
}
fn calculate_mean(prices: &[f64]) -> f64 {
prices.iter().sum::<f64>() / prices.len() as f64
}
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
let variance = prices.iter()
.map(|p| (p - mean).powi(2))
.sum::<f64>() / prices.len() as f64;
variance.sqrt()
}
}
impl Strategy for MeanReversionStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep only necessary history
if history.len() > self.lookback_period {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(history);
let std_dev = Self::calculate_std_dev(history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
// Check for entry signals
if current_position == 0.0 {
if price < lower_band {
// Price is oversold, buy
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, lower_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
self.positions.insert(symbol.clone(), self.position_size);
} else if price > upper_band {
// Price is overbought, sell short
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, upper_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
self.positions.insert(symbol.clone(), -self.position_size);
}
}
// Check for exit signals
else if current_position > 0.0 {
// We're long, exit when price crosses above mean (not just touches)
// or when we hit stop loss
let stop_loss = lower_band - std_dev; // Stop loss below lower band
if price >= mean * 1.01 || price <= stop_loss {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Mean reversion exit long: price ${:.2} {} mean ${:.2}",
price, if price >= mean * 1.01 { "crossed above" } else { "hit stop loss below" }, mean
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"exit_type": "mean_reversion",
})),
});
self.positions.remove(symbol);
}
} else if current_position < 0.0 {
// We're short, exit when price crosses below mean
// or when we hit stop loss
let stop_loss = upper_band + std_dev; // Stop loss above upper band
if price <= mean * 0.99 || price >= stop_loss {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Mean reversion exit short: price ${:.2} {} mean ${:.2}",
price, if price <= mean * 0.99 { "crossed below" } else { "hit stop loss above" }, mean
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"exit_type": "mean_reversion",
})),
});
self.positions.remove(symbol);
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
// Position tracking is handled in on_market_data for simplicity
eprintln!("Mean reversion fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"exit_threshold": self.exit_threshold,
"position_size": self.position_size,
})
}
}

View file

@ -0,0 +1,279 @@
use std::collections::HashMap;
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Fixed Mean Reversion Strategy that properly tracks positions
///
/// This version doesn't maintain its own position tracking but relies
/// on the position information passed through on_fill callbacks
pub struct MeanReversionFixedStrategy {
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
exit_threshold: f64, // Exit when price moves back this fraction toward mean
position_size: f64,
// State
price_history: HashMap<String, Vec<f64>>,
current_positions: HashMap<String, f64>, // Track actual positions from fills
entry_prices: HashMap<String, f64>, // Track entry prices for exit decisions
}
impl MeanReversionFixedStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
entry_threshold,
exit_threshold: 0.3, // Exit when price moves 30% back to mean
position_size,
price_history: HashMap::new(),
current_positions: HashMap::new(),
entry_prices: HashMap::new(),
}
}
fn calculate_mean(prices: &[f64]) -> f64 {
prices.iter().sum::<f64>() / prices.len() as f64
}
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
let variance = prices.iter()
.map(|p| (p - mean).powi(2))
.sum::<f64>() / prices.len() as f64;
variance.sqrt()
}
}
impl Strategy for MeanReversionFixedStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep only necessary history
if history.len() > self.lookback_period {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(history);
let std_dev = Self::calculate_std_dev(history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
// Get actual position from our tracking
let current_position = self.current_positions.get(symbol).copied().unwrap_or(0.0);
// Entry signals - only when flat
if current_position.abs() < 0.001 {
if price < lower_band {
// Price is oversold, buy
eprintln!("Mean reversion: {} oversold at ${:.2}, buying (lower band: ${:.2}, mean: ${:.2})",
symbol, price, lower_band, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, lower_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
} else if price > upper_band {
// Price is overbought, sell short
eprintln!("Mean reversion: {} overbought at ${:.2}, selling short (upper band: ${:.2}, mean: ${:.2})",
symbol, price, upper_band, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, upper_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
}
}
// Exit signals - only when we have a position
else if current_position > 0.0 {
// We're long - check exit conditions
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
let target_price = entry_price + (mean - entry_price) * self.exit_threshold;
let stop_loss = lower_band - std_dev; // Stop loss below lower band
if price >= target_price {
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing long",
symbol, target_price, entry_price, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Mean reversion exit long: price ${:.2} reached target ${:.2} (entry: ${:.2})",
price, target_price, entry_price
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"entry_price": entry_price,
"target_price": target_price,
"exit_type": "target",
})),
});
} else if price <= stop_loss {
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing long",
symbol, stop_loss);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Mean reversion stop loss: price ${:.2} <= stop ${:.2}",
price, stop_loss
)),
metadata: Some(json!({
"stop_loss": stop_loss,
"price": price,
"exit_type": "stop_loss",
})),
});
}
} else if current_position < 0.0 {
// We're short - check exit conditions
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
let target_price = entry_price - (entry_price - mean) * self.exit_threshold;
let stop_loss = upper_band + std_dev; // Stop loss above upper band
if price <= target_price {
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing short",
symbol, target_price, entry_price, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Mean reversion exit short: price ${:.2} reached target ${:.2} (entry: ${:.2})",
price, target_price, entry_price
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"entry_price": entry_price,
"target_price": target_price,
"exit_type": "target",
})),
});
} else if price >= stop_loss {
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing short",
symbol, stop_loss);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Mean reversion stop loss: price ${:.2} >= stop ${:.2}",
price, stop_loss
)),
metadata: Some(json!({
"stop_loss": stop_loss,
"price": price,
"exit_type": "stop_loss",
})),
});
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
// Update our position tracking based on actual fills
let current = self.current_positions.get(symbol).copied().unwrap_or(0.0);
let new_position = if side.contains("Buy") {
current + quantity
} else {
current - quantity
};
eprintln!("Mean reversion fill: {} {} @ {} - {}, position: {} -> {}",
quantity, symbol, price, side, current, new_position);
// Update position
if new_position.abs() < 0.001 {
// Position closed
self.current_positions.remove(symbol);
self.entry_prices.remove(symbol);
} else {
self.current_positions.insert(symbol.to_string(), new_position);
// Track entry price for new positions
if current.abs() < 0.001 {
self.entry_prices.insert(symbol.to_string(), price);
}
}
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"exit_threshold": self.exit_threshold,
"position_size": self.position_size,
})
}
}

View file

@ -0,0 +1,9 @@
pub mod mean_reversion;
pub mod mean_reversion_fixed;
pub mod momentum;
pub mod pairs_trading;
pub use mean_reversion::MeanReversionStrategy;
pub use mean_reversion_fixed::MeanReversionFixedStrategy;
pub use momentum::MomentumStrategy;
pub use pairs_trading::PairsTradingStrategy;

View file

@ -0,0 +1,228 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Momentum Strategy
///
/// This strategy trades based on momentum indicators like rate of change (ROC) and
/// relative strength. It aims to capture trends by buying securities showing
/// upward momentum and selling those showing downward momentum.
///
/// Entry Signals:
/// - BUY when momentum crosses above threshold and accelerating
/// - SELL when momentum crosses below -threshold and decelerating
///
/// Exit Signals:
/// - Exit long when momentum turns negative
/// - Exit short when momentum turns positive
pub struct MomentumStrategy {
name: String,
id: String,
lookback_period: usize,
momentum_threshold: f64,
position_size: f64,
use_acceleration: bool,
// State
price_history: HashMap<String, Vec<f64>>,
momentum_history: HashMap<String, Vec<f64>>,
positions: HashMap<String, f64>,
}
impl MomentumStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
momentum_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
momentum_threshold,
position_size,
use_acceleration: true,
price_history: HashMap::new(),
momentum_history: HashMap::new(),
positions: HashMap::new(),
}
}
fn calculate_momentum(prices: &[f64], lookback_period: usize) -> f64 {
if prices.len() < 2 {
return 0.0;
}
let current = prices.last().unwrap();
let past = prices[prices.len() - lookback_period.min(prices.len())];
((current - past) / past) * 100.0
}
fn calculate_acceleration(momentum_values: &[f64]) -> f64 {
if momentum_values.len() < 2 {
return 0.0;
}
let current = momentum_values.last().unwrap();
let previous = momentum_values[momentum_values.len() - 2];
current - previous
}
}
impl Strategy for MomentumStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep reasonable history
if history.len() > self.lookback_period * 2 {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate momentum
let momentum = Self::calculate_momentum(history, self.lookback_period);
// Update momentum history
let mom_history = self.momentum_history.entry(symbol.clone()).or_insert_with(Vec::new);
mom_history.push(momentum);
if mom_history.len() > 5 {
mom_history.remove(0);
}
// Calculate acceleration if enabled
let acceleration = if self.use_acceleration && mom_history.len() >= 2 {
Self::calculate_acceleration(mom_history)
} else {
0.0
};
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
// Check for entry signals
if current_position == 0.0 {
// Long entry: strong positive momentum and accelerating
if momentum > self.momentum_threshold &&
(!self.use_acceleration || acceleration > 0.0) {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: momentum / 100.0, // Normalize strength
quantity: Some(self.position_size),
reason: Some(format!(
"Momentum buy: momentum {:.2}% > threshold {:.2}%, accel: {:.2}",
momentum, self.momentum_threshold, acceleration
)),
metadata: Some(json!({
"momentum": momentum,
"acceleration": acceleration,
"price": price,
})),
});
self.positions.insert(symbol.clone(), self.position_size);
}
// Short entry: strong negative momentum and decelerating
else if momentum < -self.momentum_threshold &&
(!self.use_acceleration || acceleration < 0.0) {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: momentum.abs() / 100.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Momentum sell: momentum {:.2}% < threshold -{:.2}%, accel: {:.2}",
momentum, self.momentum_threshold, acceleration
)),
metadata: Some(json!({
"momentum": momentum,
"acceleration": acceleration,
"price": price,
})),
});
self.positions.insert(symbol.clone(), -self.position_size);
}
}
// Check for exit signals
else if current_position > 0.0 {
// Exit long when momentum turns negative
if momentum < 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Momentum exit long: momentum turned negative {:.2}%",
momentum
)),
metadata: Some(json!({
"momentum": momentum,
"price": price,
"exit_type": "momentum_reversal",
})),
});
self.positions.remove(symbol);
}
} else if current_position < 0.0 {
// Exit short when momentum turns positive
if momentum > 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Momentum exit short: momentum turned positive {:.2}%",
momentum
)),
metadata: Some(json!({
"momentum": momentum,
"price": price,
"exit_type": "momentum_reversal",
})),
});
self.positions.remove(symbol);
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
eprintln!("Momentum fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"momentum_threshold": self.momentum_threshold,
"position_size": self.position_size,
"use_acceleration": self.use_acceleration,
})
}
}

View file

@ -0,0 +1,295 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Pairs Trading Strategy
///
/// This strategy trades the spread between two correlated securities. When the spread
/// deviates from its historical mean, we trade expecting it to revert.
///
/// Entry Signals:
/// - Long pair A, Short pair B when spread < (mean - threshold * std)
/// - Short pair A, Long pair B when spread > (mean + threshold * std)
///
/// Exit Signals:
/// - Exit when spread returns to mean
pub struct PairsTradingStrategy {
name: String,
id: String,
pair_a: String,
pair_b: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
position_size: f64,
hedge_ratio: f64, // How many shares of B per share of A
// State
price_history_a: Vec<f64>,
price_history_b: Vec<f64>,
spread_history: Vec<f64>,
positions: HashMap<String, f64>,
last_prices: HashMap<String, f64>,
}
impl PairsTradingStrategy {
pub fn new(
name: String,
id: String,
pair_a: String,
pair_b: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
pair_a,
pair_b,
lookback_period,
entry_threshold,
position_size,
hedge_ratio: 1.0, // Default 1:1, could be calculated dynamically
price_history_a: Vec::new(),
price_history_b: Vec::new(),
spread_history: Vec::new(),
positions: HashMap::new(),
last_prices: HashMap::new(),
}
}
fn calculate_spread(&self, price_a: f64, price_b: f64) -> f64 {
price_a - self.hedge_ratio * price_b
}
fn calculate_mean(values: &[f64]) -> f64 {
values.iter().sum::<f64>() / values.len() as f64
}
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
let variance = values.iter()
.map(|v| (v - mean).powi(2))
.sum::<f64>() / values.len() as f64;
variance.sqrt()
}
}
impl Strategy for PairsTradingStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update last prices
self.last_prices.insert(symbol.clone(), price);
// Update price histories
if symbol == &self.pair_a {
self.price_history_a.push(price);
if self.price_history_a.len() > self.lookback_period {
self.price_history_a.remove(0);
}
} else if symbol == &self.pair_b {
self.price_history_b.push(price);
if self.price_history_b.len() > self.lookback_period {
self.price_history_b.remove(0);
}
}
// Only generate signals when we have prices for both pairs
if let (Some(&price_a), Some(&price_b)) =
(self.last_prices.get(&self.pair_a), self.last_prices.get(&self.pair_b)) {
// Calculate current spread
let spread = self.calculate_spread(price_a, price_b);
// Update spread history
self.spread_history.push(spread);
if self.spread_history.len() > self.lookback_period {
self.spread_history.remove(0);
}
// Need enough data
if self.spread_history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(&self.spread_history);
let std_dev = Self::calculate_std_dev(&self.spread_history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
let position_a = self.positions.get(&self.pair_a).copied().unwrap_or(0.0);
let position_b = self.positions.get(&self.pair_b).copied().unwrap_or(0.0);
// Check for entry signals
if position_a == 0.0 && position_b == 0.0 {
if spread < lower_band {
// Spread too low: Buy A, Sell B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Pairs trade: spread ${:.2} < lower band ${:.2}",
spread, lower_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "A",
})),
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size * self.hedge_ratio),
reason: Some(format!(
"Pairs trade hedge: spread ${:.2} < lower band ${:.2}",
spread, lower_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "B",
})),
});
self.positions.insert(self.pair_a.clone(), self.position_size);
self.positions.insert(self.pair_b.clone(), -self.position_size * self.hedge_ratio);
} else if spread > upper_band {
// Spread too high: Sell A, Buy B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Pairs trade: spread ${:.2} > upper band ${:.2}",
spread, upper_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "A",
})),
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size * self.hedge_ratio),
reason: Some(format!(
"Pairs trade hedge: spread ${:.2} > upper band ${:.2}",
spread, upper_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "B",
})),
});
self.positions.insert(self.pair_a.clone(), -self.position_size);
self.positions.insert(self.pair_b.clone(), self.position_size * self.hedge_ratio);
}
}
// Check for exit signals
else if position_a != 0.0 && position_b != 0.0 {
// Exit when spread returns to mean
let spread_distance = (spread - mean).abs();
let exit_threshold = std_dev * 0.1; // Exit near mean
if spread_distance < exit_threshold {
// Close positions
if position_a > 0.0 {
// We're long A, short B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(position_a),
reason: Some(format!(
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
spread, mean
)),
metadata: None,
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(position_b.abs()),
reason: Some("Pairs trade exit: closing hedge".to_string()),
metadata: None,
});
} else {
// We're short A, long B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(position_a.abs()),
reason: Some(format!(
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
spread, mean
)),
metadata: None,
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(position_b),
reason: Some("Pairs trade exit: closing hedge".to_string()),
metadata: None,
});
}
self.positions.clear();
}
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
eprintln!("Pairs trading fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"pair_a": self.pair_a,
"pair_b": self.pair_b,
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"position_size": self.position_size,
"hedge_ratio": self.hedge_ratio,
})
}
}

View file

@ -140,7 +140,7 @@ export class RustBacktestAdapter extends EventEmitter {
timestamp: new Date(point[0]).getTime(),
value: point[1],
})),
trades: this.transformFillsToTrades(rustResult.trades || []),
trades: this.transformCompletedTradesToFills(rustResult.trades || []),
dailyReturns: this.calculateDailyReturns(rustResult.equity_curve),
finalPositions: rustResult.final_positions || {},
executionTime: Date.now() - startTime,
@ -290,9 +290,24 @@ export class RustBacktestAdapter extends EventEmitter {
// Use native Rust strategy for maximum performance
this.container.logger.info('Using native Rust strategy implementation');
// Map strategy names to their Rust strategy types
let strategyType = 'sma_crossover'; // default
if (strategyName.toLowerCase().includes('mean') || strategyName.toLowerCase().includes('reversion')) {
strategyType = 'mean_reversion';
} else if (strategyName.toLowerCase().includes('momentum')) {
strategyType = 'momentum';
} else if (strategyName.toLowerCase().includes('pairs')) {
strategyType = 'pairs_trading';
} else if (strategyName.toLowerCase().includes('sma') || strategyName.toLowerCase().includes('crossover')) {
strategyType = 'sma_crossover';
}
this.container.logger.info(`Mapped strategy '${strategyName}' to type '${strategyType}'`);
// Use the addNativeStrategy method instead
this.currentEngine.addNativeStrategy(
'sma_crossover', // strategy type
strategyType,
strategyName,
`strategy-${Date.now()}`,
parameters
@ -340,26 +355,168 @@ export class RustBacktestAdapter extends EventEmitter {
};
}
private transformFillsToTrades(completedTrades: any[]): any[] {
// Now we have CompletedTrade objects with symbol and side information
return completedTrades.map((trade, index) => {
const timestamp = new Date(trade.timestamp);
const side = trade.side === 'Buy' ? 'buy' : 'sell';
return {
id: `trade-${index}`,
private transformCompletedTradesToFills(completedTrades: any[]): any[] {
// Convert completed trades (paired entry/exit) back to individual fills for the UI
const fills: any[] = [];
let fillId = 0;
completedTrades.forEach(trade => {
// Create entry fill
fills.push({
id: `fill-${fillId++}`,
timestamp: trade.entry_time || trade.entryDate,
symbol: trade.symbol,
entryDate: timestamp.toISOString(),
exitDate: timestamp.toISOString(), // Same as entry for individual fills
entryPrice: trade.price,
exitPrice: trade.price,
side: trade.side === 'Buy' || trade.side === 'long' ? 'buy' : 'sell',
quantity: trade.quantity,
side,
pnl: 0, // Would need to calculate from paired trades
pnlPercent: 0,
commission: trade.commission,
duration: 0, // Would need to calculate from paired trades
};
price: trade.entry_price || trade.entryPrice,
commission: trade.commission / 2, // Split commission between entry and exit
});
// Create exit fill (opposite side)
const exitSide = (trade.side === 'Buy' || trade.side === 'long') ? 'sell' : 'buy';
fills.push({
id: `fill-${fillId++}`,
timestamp: trade.exit_time || trade.exitDate,
symbol: trade.symbol,
side: exitSide,
quantity: trade.quantity,
price: trade.exit_price || trade.exitPrice,
commission: trade.commission / 2,
pnl: trade.pnl,
});
});
// Sort by timestamp
fills.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime());
return fills;
}
private transformFillsToTrades(completedTrades: any[]): any[] {
// Group fills by symbol to match entries with exits
const fillsBySymbol: { [symbol: string]: any[] } = {};
completedTrades.forEach(trade => {
if (!fillsBySymbol[trade.symbol]) {
fillsBySymbol[trade.symbol] = [];
}
fillsBySymbol[trade.symbol].push(trade);
});
const pairedTrades: any[] = [];
const openPositions: { [symbol: string]: any[] } = {};
// Process each symbol's fills chronologically
Object.entries(fillsBySymbol).forEach(([symbol, fills]) => {
// Sort by timestamp
fills.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime());
fills.forEach((fill, idx) => {
const isBuy = fill.side === 'Buy';
const timestamp = new Date(fill.timestamp);
if (!openPositions[symbol]) {
openPositions[symbol] = [];
}
const openPos = openPositions[symbol];
// For buy fills, add to open positions
if (isBuy) {
openPos.push(fill);
} else {
// For sell fills, match with open buy positions (FIFO)
if (openPos.length > 0 && openPos[0].side === 'Buy') {
const entry = openPos.shift();
const entryDate = new Date(entry.timestamp);
const duration = (timestamp.getTime() - entryDate.getTime()) / 1000; // seconds
const pnl = (fill.price - entry.price) * fill.quantity - entry.commission - fill.commission;
const pnlPercent = ((fill.price - entry.price) / entry.price) * 100;
pairedTrades.push({
id: `trade-${pairedTrades.length}`,
symbol,
entryDate: entryDate.toISOString(),
exitDate: timestamp.toISOString(),
entryPrice: entry.price,
exitPrice: fill.price,
quantity: fill.quantity,
side: 'long',
pnl,
pnlPercent,
commission: entry.commission + fill.commission,
duration,
});
} else {
// This is a short entry
openPos.push(fill);
}
}
// For short positions (sell first, then buy to cover)
if (!isBuy && openPos.length > 0) {
const lastPos = openPos[openPos.length - 1];
if (lastPos.side === 'Sell' && idx < fills.length - 1) {
const nextFill = fills[idx + 1];
if (nextFill && nextFill.side === 'Buy') {
// We'll handle this when we process the buy fill
}
}
}
// Handle buy fills that close short positions
if (isBuy && openPos.length > 1) {
const shortPos = openPos.find(p => p.side === 'Sell');
if (shortPos) {
const shortIdx = openPos.indexOf(shortPos);
openPos.splice(shortIdx, 1);
const entryDate = new Date(shortPos.timestamp);
const duration = (timestamp.getTime() - entryDate.getTime()) / 1000;
const pnl = (shortPos.price - fill.price) * fill.quantity - shortPos.commission - fill.commission;
const pnlPercent = ((shortPos.price - fill.price) / shortPos.price) * 100;
pairedTrades.push({
id: `trade-${pairedTrades.length}`,
symbol,
entryDate: entryDate.toISOString(),
exitDate: timestamp.toISOString(),
entryPrice: shortPos.price,
exitPrice: fill.price,
quantity: fill.quantity,
side: 'short',
pnl,
pnlPercent,
commission: shortPos.commission + fill.commission,
duration,
});
}
}
});
// Add any remaining open positions as incomplete trades
const remainingOpenPositions = openPositions[symbol] || [];
remainingOpenPositions.forEach(pos => {
const timestamp = new Date(pos.timestamp);
const side = pos.side === 'Buy' ? 'buy' : 'sell';
pairedTrades.push({
id: `trade-${pairedTrades.length}`,
symbol,
entryDate: timestamp.toISOString(),
exitDate: timestamp.toISOString(), // Same as entry for open positions
entryPrice: pos.price,
exitPrice: pos.price,
quantity: pos.quantity,
side,
pnl: 0, // No PnL for open positions
pnlPercent: 0,
commission: pos.commission,
duration: 0,
});
});
});
return pairedTrades;
}
}

View file

@ -0,0 +1,180 @@
import { RustBacktestAdapter } from './src/backtest/RustBacktestAdapter';
import { IServiceContainer } from '@stock-bot/di';
import { BacktestConfig } from './src/types';
// Mock container
const mockContainer: IServiceContainer = {
logger: {
info: console.log,
error: console.error,
warn: console.warn,
debug: console.log,
},
mongodb: {} as any,
postgres: {} as any,
redis: {} as any,
custom: {},
} as IServiceContainer;
// Mock storage service that returns test data
class MockStorageService {
async getHistoricalBars(symbol: string, startDate: Date, endDate: Date, frequency: string) {
console.log(`MockStorageService: Getting bars for ${symbol} from ${startDate} to ${endDate}`);
// Generate test data with mean reverting behavior
const bars = [];
const startTime = startDate.getTime();
const endTime = endDate.getTime();
const dayMs = 24 * 60 * 60 * 1000;
let time = startTime;
let dayIndex = 0;
// Base prices for different symbols
const basePrices = {
'AAPL': 150,
'GOOGL': 2800,
'MSFT': 400,
};
const basePrice = basePrices[symbol as keyof typeof basePrices] || 100;
while (time <= endTime) {
// Create mean reverting price movement
// Price oscillates around the base price with increasing then decreasing deviations
const cycleLength = 40; // 40 day cycle
const positionInCycle = dayIndex % cycleLength;
const halfCycle = cycleLength / 2;
let deviation;
if (positionInCycle < halfCycle) {
// First half: price moves away from mean
deviation = (positionInCycle / halfCycle) * 0.1; // Up to 10% deviation
} else {
// Second half: price reverts to mean
deviation = ((cycleLength - positionInCycle) / halfCycle) * 0.1;
}
// Alternate between above and below mean
const cycleNumber = Math.floor(dayIndex / cycleLength);
const multiplier = cycleNumber % 2 === 0 ? 1 : -1;
const price = basePrice * (1 + multiplier * deviation);
// Add some noise
const noise = (Math.random() - 0.5) * 0.02 * basePrice;
const finalPrice = price + noise;
bars.push({
timestamp: new Date(time),
open: finalPrice * 0.99,
high: finalPrice * 1.01,
low: finalPrice * 0.98,
close: finalPrice,
volume: 1000000,
vwap: finalPrice,
});
time += dayMs;
dayIndex++;
}
console.log(`Generated ${bars.length} bars for ${symbol}, first close: ${bars[0].close.toFixed(2)}, last close: ${bars[bars.length - 1].close.toFixed(2)}`);
return bars;
}
}
// Test the backtest
async function testMeanReversionBacktest() {
console.log('=== Testing Mean Reversion Backtest ===\n');
// Create adapter with mock storage
const adapter = new RustBacktestAdapter(mockContainer);
(adapter as any).storageService = new MockStorageService();
const config: BacktestConfig = {
name: 'Mean Reversion Test',
strategy: 'mean_reversion',
symbols: ['AAPL', 'GOOGL', 'MSFT'],
startDate: '2024-01-01T00:00:00Z',
endDate: '2024-06-01T00:00:00Z',
initialCapital: 100000,
commission: 0.001,
slippage: 0.0001,
dataFrequency: '1d',
config: {
lookbackPeriod: 20,
entryThreshold: 2.0,
positionSize: 100,
},
};
try {
console.log('Starting backtest...\n');
const result = await adapter.runBacktest(config);
console.log('\n=== Backtest Results ===');
console.log(`Status: ${result.status}`);
console.log(`Total Trades: ${result.metrics.totalTrades}`);
console.log(`Profitable Trades: ${result.metrics.profitableTrades}`);
console.log(`Win Rate: ${result.metrics.winRate.toFixed(2)}%`);
console.log(`Total Return: ${result.metrics.totalReturn.toFixed(2)}%`);
console.log(`Sharpe Ratio: ${result.metrics.sharpeRatio.toFixed(2)}`);
console.log(`Max Drawdown: ${result.metrics.maxDrawdown.toFixed(2)}%`);
console.log('\n=== Trade Analysis ===');
console.log(`Number of completed trades: ${result.trades.length}`);
// Analyze trades by symbol
const tradesBySymbol: Record<string, any[]> = {};
result.trades.forEach(trade => {
if (!tradesBySymbol[trade.symbol]) {
tradesBySymbol[trade.symbol] = [];
}
tradesBySymbol[trade.symbol].push(trade);
});
Object.entries(tradesBySymbol).forEach(([symbol, trades]) => {
console.log(`\n${symbol}: ${trades.length} trades`);
const longTrades = trades.filter(t => t.side === 'long');
const shortTrades = trades.filter(t => t.side === 'short');
console.log(` - Long trades: ${longTrades.length}`);
console.log(` - Short trades: ${shortTrades.length}`);
// Count buy/sell pairs
const buyTrades = trades.filter(t => t.side === 'buy');
const sellTrades = trades.filter(t => t.side === 'sell');
console.log(` - Buy trades: ${buyTrades.length}`);
console.log(` - Sell trades: ${sellTrades.length}`);
// Show first few trades
console.log(` - First 3 trades:`);
trades.slice(0, 3).forEach((trade, idx) => {
console.log(` ${idx + 1}. ${trade.side} - Price: $${trade.price.toFixed(2)}, Quantity: ${trade.quantity}${trade.pnl ? `, PnL: $${trade.pnl.toFixed(2)}` : ''}`);
});
});
// Check position distribution
const allDurations = result.trades.map(t => t.duration / 86400); // Convert to days
const avgDuration = allDurations.reduce((a, b) => a + b, 0) / allDurations.length;
const minDuration = Math.min(...allDurations);
const maxDuration = Math.max(...allDurations);
console.log('\n=== Duration Analysis ===');
console.log(`Average trade duration: ${avgDuration.toFixed(1)} days`);
console.log(`Min duration: ${minDuration.toFixed(1)} days`);
console.log(`Max duration: ${maxDuration.toFixed(1)} days`);
// Final positions
console.log('\n=== Final Positions ===');
Object.entries(result.finalPositions).forEach(([symbol, position]) => {
console.log(`${symbol}: ${position}`);
});
} catch (error) {
console.error('Backtest failed:', error);
}
}
// Run the test
testMeanReversionBacktest().catch(console.error);

View file

@ -0,0 +1,101 @@
import { RustBacktestAdapter } from './src/backtest/RustBacktestAdapter';
import { IServiceContainer } from '@stock-bot/di';
import { BacktestConfig } from './src/types';
// Mock container
const mockContainer: IServiceContainer = {
logger: {
info: console.log,
error: console.error,
warn: console.warn,
debug: console.log,
},
mongodb: {} as any,
postgres: {} as any,
redis: {} as any,
custom: {},
} as IServiceContainer;
// Mock storage service
class MockStorageService {
async getHistoricalBars(symbol: string, startDate: Date, endDate: Date, frequency: string) {
const bars = [];
const startTime = startDate.getTime();
const endTime = endDate.getTime();
const dayMs = 24 * 60 * 60 * 1000;
let time = startTime;
let dayIndex = 0;
// Simple oscillating price for testing
while (time <= endTime) {
const price = 100 + 10 * Math.sin(dayIndex * 0.2);
bars.push({
timestamp: new Date(time),
open: price * 0.99,
high: price * 1.01,
low: price * 0.98,
close: price,
volume: 1000000,
vwap: price,
});
time += dayMs;
dayIndex++;
}
return bars;
}
}
// Test the backtest
async function testTradeFormat() {
console.log('=== Testing Trade Format ===\n');
const adapter = new RustBacktestAdapter(mockContainer);
(adapter as any).storageService = new MockStorageService();
const config: BacktestConfig = {
name: 'Trade Format Test',
strategy: 'Simple Moving Average Crossover',
symbols: ['TEST'],
startDate: '2024-01-01T00:00:00Z',
endDate: '2024-03-01T00:00:00Z',
initialCapital: 100000,
commission: 0.001,
slippage: 0.0001,
dataFrequency: '1d',
config: {
fastPeriod: 5,
slowPeriod: 15,
},
};
try {
const result = await adapter.runBacktest(config);
console.log('\n=== Trade Format ===');
console.log('Number of trades:', result.trades.length);
console.log('\nFirst 3 trades:');
result.trades.slice(0, 3).forEach((trade, idx) => {
console.log(`\nTrade ${idx + 1}:`, JSON.stringify(trade, null, 2));
});
// Check what format the trades are in
if (result.trades.length > 0) {
const firstTrade = result.trades[0];
console.log('\n=== Trade Structure Analysis ===');
console.log('Keys:', Object.keys(firstTrade));
console.log('Has entryDate/exitDate?', 'entryDate' in firstTrade && 'exitDate' in firstTrade);
console.log('Has timestamp?', 'timestamp' in firstTrade);
console.log('Has side field?', 'side' in firstTrade);
console.log('Side value:', firstTrade.side);
}
} catch (error) {
console.error('Test failed:', error);
}
}
testTradeFormat().catch(console.error);