added trade-tracking and example rust strats
This commit is contained in:
parent
0a4702d12a
commit
3a7557c8f4
15 changed files with 2108 additions and 29 deletions
Binary file not shown.
|
|
@ -101,6 +101,58 @@ impl BacktestEngine {
|
|||
slow_period,
|
||||
)));
|
||||
}
|
||||
"mean_reversion" => {
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(20.0) as usize;
|
||||
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
|
||||
.unwrap_or(2.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::MeanReversionFixedStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
"momentum" => {
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(14.0) as usize;
|
||||
let momentum_threshold: f64 = parameters.get_named_property::<f64>("momentumThreshold")
|
||||
.unwrap_or(5.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::MomentumStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
lookback_period,
|
||||
momentum_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
"pairs_trading" => {
|
||||
let pair_a: String = parameters.get_named_property::<String>("pairA")?;
|
||||
let pair_b: String = parameters.get_named_property::<String>("pairB")?;
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(20.0) as usize;
|
||||
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
|
||||
.unwrap_or(2.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::PairsTradingStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
pair_a,
|
||||
pair_b,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::from_reason(format!("Unknown strategy type: {}", strategy_type)));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use crate::{
|
|||
|
||||
use super::{
|
||||
BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType,
|
||||
Strategy, Signal, SignalType, BacktestResult,
|
||||
Strategy, Signal, SignalType, BacktestResult, TradeTracker,
|
||||
};
|
||||
|
||||
pub struct BacktestEngine {
|
||||
|
|
@ -38,6 +38,9 @@ pub struct BacktestEngine {
|
|||
|
||||
// Price tracking
|
||||
last_prices: HashMap<String, f64>,
|
||||
|
||||
// Trade tracking
|
||||
trade_tracker: TradeTracker,
|
||||
}
|
||||
|
||||
impl BacktestEngine {
|
||||
|
|
@ -67,6 +70,7 @@ impl BacktestEngine {
|
|||
profitable_trades: 0,
|
||||
total_pnl: 0.0,
|
||||
last_prices: HashMap::new(),
|
||||
trade_tracker: TradeTracker::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -394,6 +398,9 @@ impl BacktestEngine {
|
|||
// Record the fill with symbol and side information
|
||||
self.state.write().record_fill(order.symbol.clone(), order.side, fill.clone());
|
||||
|
||||
// Track trades
|
||||
self.trade_tracker.process_fill(&order.symbol, order.side, &fill);
|
||||
|
||||
// Update cash
|
||||
let cash_change = match order.side {
|
||||
crate::Side::Buy => -(fill.quantity * fill.price + fill.commission),
|
||||
|
|
@ -410,6 +417,13 @@ impl BacktestEngine {
|
|||
}
|
||||
}
|
||||
|
||||
eprintln!("Fill processed: {} {} @ {} (side: {:?})",
|
||||
fill.quantity, order.symbol, fill.price, order.side);
|
||||
eprintln!("Current position after fill: {}",
|
||||
self.position_tracker.get_position(&order.symbol)
|
||||
.map(|p| p.quantity)
|
||||
.unwrap_or(0.0));
|
||||
|
||||
// Update metrics
|
||||
self.total_trades += 1;
|
||||
if update.resulting_position.realized_pnl > 0.0 {
|
||||
|
|
@ -470,24 +484,43 @@ impl BacktestEngine {
|
|||
let total_pnl = realized_pnl + unrealized_pnl;
|
||||
let total_return = (total_pnl / self.config.initial_capital) * 100.0;
|
||||
|
||||
// Get completed trades from trade tracker for metrics
|
||||
let completed_trades = self.trade_tracker.get_completed_trades();
|
||||
|
||||
// Calculate metrics from completed trades
|
||||
let completed_trade_count = completed_trades.len();
|
||||
let profitable_trades = completed_trades.iter().filter(|t| t.pnl > 0.0).count();
|
||||
let total_wins: f64 = completed_trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).sum();
|
||||
let total_losses: f64 = completed_trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl.abs()).sum();
|
||||
let avg_win = if profitable_trades > 0 { total_wins / profitable_trades as f64 } else { 0.0 };
|
||||
let avg_loss = if completed_trade_count > profitable_trades {
|
||||
total_losses / (completed_trade_count - profitable_trades) as f64
|
||||
} else { 0.0 };
|
||||
let profit_factor = if total_losses > 0.0 { total_wins / total_losses } else { 0.0 };
|
||||
|
||||
// For the API, return all fills (not just completed trades)
|
||||
// This shows all trading activity
|
||||
let all_fills = state.completed_trades.clone();
|
||||
let total_trades = all_fills.len();
|
||||
|
||||
BacktestResult {
|
||||
config: self.config.clone(),
|
||||
metrics: super::BacktestMetrics {
|
||||
total_return,
|
||||
total_trades: self.total_trades,
|
||||
profitable_trades: self.profitable_trades,
|
||||
win_rate: if self.total_trades > 0 {
|
||||
(self.profitable_trades as f64 / self.total_trades as f64) * 100.0
|
||||
total_trades,
|
||||
profitable_trades,
|
||||
win_rate: if completed_trade_count > 0 {
|
||||
(profitable_trades as f64 / completed_trade_count as f64) * 100.0
|
||||
} else { 0.0 },
|
||||
profit_factor: 0.0, // TODO: Calculate properly
|
||||
profit_factor,
|
||||
sharpe_ratio: 0.0, // TODO: Calculate properly
|
||||
max_drawdown: 0.0, // TODO: Calculate properly
|
||||
total_pnl,
|
||||
avg_win: 0.0, // TODO: Calculate properly
|
||||
avg_loss: 0.0, // TODO: Calculate properly
|
||||
avg_win,
|
||||
avg_loss,
|
||||
},
|
||||
equity_curve: state.equity_curve.clone(),
|
||||
trades: state.completed_trades.clone(),
|
||||
trades: all_fills,
|
||||
final_positions: self.position_tracker.get_all_positions()
|
||||
.into_iter()
|
||||
.map(|p| (p.symbol.clone(), p))
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ pub mod engine;
|
|||
pub mod event;
|
||||
pub mod strategy;
|
||||
pub mod results;
|
||||
pub mod trade_tracker;
|
||||
|
||||
pub use engine::BacktestEngine;
|
||||
pub use event::{BacktestEvent, EventType};
|
||||
pub use strategy::{Strategy, Signal, SignalType};
|
||||
pub use results::{BacktestResult, BacktestMetrics};
|
||||
pub use trade_tracker::{TradeTracker, CompletedTrade as TrackedTrade};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletedTrade {
|
||||
|
|
|
|||
258
apps/stock/core/src/backtest/trade_tracker.rs
Normal file
258
apps/stock/core/src/backtest/trade_tracker.rs
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::{Fill, Side};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletedTrade {
|
||||
pub id: String,
|
||||
pub symbol: String,
|
||||
pub entry_time: DateTime<Utc>,
|
||||
pub exit_time: DateTime<Utc>,
|
||||
pub entry_price: f64,
|
||||
pub exit_price: f64,
|
||||
pub quantity: f64,
|
||||
pub side: Side, // Side of the opening trade
|
||||
pub pnl: f64,
|
||||
pub pnl_percent: f64,
|
||||
pub commission: f64,
|
||||
pub duration_seconds: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OpenPosition {
|
||||
symbol: String,
|
||||
side: Side,
|
||||
quantity: f64,
|
||||
entry_price: f64,
|
||||
entry_time: DateTime<Utc>,
|
||||
commission: f64,
|
||||
}
|
||||
|
||||
/// Tracks fills and matches them into completed trades
|
||||
pub struct TradeTracker {
|
||||
open_positions: HashMap<String, VecDeque<OpenPosition>>,
|
||||
completed_trades: Vec<CompletedTrade>,
|
||||
trade_counter: u64,
|
||||
}
|
||||
|
||||
impl TradeTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
open_positions: HashMap::new(),
|
||||
completed_trades: Vec::new(),
|
||||
trade_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_fill(&mut self, symbol: &str, side: Side, fill: &Fill) {
|
||||
let positions = self.open_positions.entry(symbol.to_string()).or_insert_with(VecDeque::new);
|
||||
|
||||
// Check if this fill closes existing positions
|
||||
let mut remaining_quantity = fill.quantity;
|
||||
let mut fills_to_remove = Vec::new();
|
||||
|
||||
for (idx, open_pos) in positions.iter_mut().enumerate() {
|
||||
// Only match against opposite side positions
|
||||
if open_pos.side == side {
|
||||
continue;
|
||||
}
|
||||
|
||||
if remaining_quantity <= 0.0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let matched_quantity = remaining_quantity.min(open_pos.quantity);
|
||||
|
||||
// Calculate PnL
|
||||
let (pnl, pnl_percent) = Self::calculate_pnl(
|
||||
&open_pos,
|
||||
fill.price,
|
||||
matched_quantity,
|
||||
fill.commission,
|
||||
);
|
||||
|
||||
// Create completed trade
|
||||
self.trade_counter += 1;
|
||||
let completed_trade = CompletedTrade {
|
||||
id: format!("trade-{}", self.trade_counter),
|
||||
symbol: symbol.to_string(),
|
||||
entry_time: open_pos.entry_time,
|
||||
exit_time: fill.timestamp,
|
||||
entry_price: open_pos.entry_price,
|
||||
exit_price: fill.price,
|
||||
quantity: matched_quantity,
|
||||
side: open_pos.side.clone(),
|
||||
pnl,
|
||||
pnl_percent,
|
||||
commission: open_pos.commission + (fill.commission * matched_quantity / fill.quantity),
|
||||
duration_seconds: (fill.timestamp - open_pos.entry_time).num_seconds(),
|
||||
};
|
||||
|
||||
self.completed_trades.push(completed_trade);
|
||||
|
||||
// Update open position
|
||||
open_pos.quantity -= matched_quantity;
|
||||
remaining_quantity -= matched_quantity;
|
||||
|
||||
if open_pos.quantity <= 0.0 {
|
||||
fills_to_remove.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove fully closed positions
|
||||
for idx in fills_to_remove.iter().rev() {
|
||||
positions.remove(*idx);
|
||||
}
|
||||
|
||||
// If there's remaining quantity, it opens a new position
|
||||
if remaining_quantity > 0.0 {
|
||||
let new_position = OpenPosition {
|
||||
symbol: symbol.to_string(),
|
||||
side,
|
||||
quantity: remaining_quantity,
|
||||
entry_price: fill.price,
|
||||
entry_time: fill.timestamp,
|
||||
commission: fill.commission * remaining_quantity / fill.quantity,
|
||||
};
|
||||
positions.push_back(new_position);
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_pnl(open_pos: &OpenPosition, exit_price: f64, quantity: f64, exit_commission: f64) -> (f64, f64) {
|
||||
let entry_value = open_pos.entry_price * quantity;
|
||||
let exit_value = exit_price * quantity;
|
||||
|
||||
let gross_pnl = match open_pos.side {
|
||||
Side::Buy => exit_value - entry_value,
|
||||
Side::Sell => entry_value - exit_value,
|
||||
};
|
||||
|
||||
let commission = open_pos.commission * (quantity / open_pos.quantity) + exit_commission * (quantity / open_pos.quantity);
|
||||
let net_pnl = gross_pnl - commission;
|
||||
let pnl_percent = (net_pnl / entry_value) * 100.0;
|
||||
|
||||
(net_pnl, pnl_percent)
|
||||
}
|
||||
|
||||
pub fn get_completed_trades(&self) -> &[CompletedTrade] {
|
||||
&self.completed_trades
|
||||
}
|
||||
|
||||
pub fn get_open_positions(&self) -> HashMap<String, Vec<(Side, f64, f64)>> {
|
||||
let mut result = HashMap::new();
|
||||
|
||||
for (symbol, positions) in &self.open_positions {
|
||||
let pos_info: Vec<(Side, f64, f64)> = positions
|
||||
.iter()
|
||||
.map(|p| (p.side.clone(), p.quantity, p.entry_price))
|
||||
.collect();
|
||||
|
||||
if !pos_info.is_empty() {
|
||||
result.insert(symbol.clone(), pos_info);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn get_net_position(&self, symbol: &str) -> f64 {
|
||||
let positions = match self.open_positions.get(symbol) {
|
||||
Some(pos) => pos,
|
||||
None => return 0.0,
|
||||
};
|
||||
|
||||
let mut net = 0.0;
|
||||
for pos in positions {
|
||||
match pos.side {
|
||||
Side::Buy => net += pos.quantity,
|
||||
Side::Sell => net -= pos.quantity,
|
||||
}
|
||||
}
|
||||
|
||||
net
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simple_round_trip() {
|
||||
let mut tracker = TradeTracker::new();
|
||||
|
||||
// Buy 100 shares at $50
|
||||
let buy_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 50.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
|
||||
|
||||
// Sell 100 shares at $55
|
||||
let sell_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 55.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill);
|
||||
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 1);
|
||||
|
||||
let trade = &trades[0];
|
||||
assert_eq!(trade.symbol, "AAPL");
|
||||
assert_eq!(trade.quantity, 100.0);
|
||||
assert_eq!(trade.entry_price, 50.0);
|
||||
assert_eq!(trade.exit_price, 55.0);
|
||||
assert_eq!(trade.pnl, 498.0); // (55-50)*100 - 2 commission
|
||||
assert_eq!(trade.side, Side::Buy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_fills() {
|
||||
let mut tracker = TradeTracker::new();
|
||||
|
||||
// Buy 100 shares at $50
|
||||
let buy_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 50.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
|
||||
|
||||
// Sell 60 shares at $55
|
||||
let sell_fill1 = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 55.0,
|
||||
quantity: 60.0,
|
||||
commission: 0.6,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill1);
|
||||
|
||||
// Check we have one completed trade and remaining position
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 1);
|
||||
assert_eq!(trades[0].quantity, 60.0);
|
||||
|
||||
assert_eq!(tracker.get_net_position("AAPL"), 40.0);
|
||||
|
||||
// Sell remaining 40 shares at $52
|
||||
let sell_fill2 = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 52.0,
|
||||
quantity: 40.0,
|
||||
commission: 0.4,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill2);
|
||||
|
||||
// Now we should have 2 completed trades and no position
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 2);
|
||||
assert_eq!(tracker.get_net_position("AAPL"), 0.0);
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ pub mod api;
|
|||
pub mod analytics;
|
||||
pub mod indicators;
|
||||
pub mod backtest;
|
||||
pub mod strategies;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use positions::{Position, PositionUpdate, TradeRecord, ClosedTrade};
|
||||
|
|
|
|||
210
apps/stock/core/src/strategies/mean_reversion.rs
Normal file
210
apps/stock/core/src/strategies/mean_reversion.rs
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Mean Reversion Strategy
|
||||
///
|
||||
/// This strategy identifies when a security's price deviates significantly from its
|
||||
/// moving average and trades on the assumption that it will revert back to the mean.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - BUY when price falls below (MA - threshold * std_dev)
|
||||
/// - SELL when price rises above (MA + threshold * std_dev)
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit long when price reaches MA
|
||||
/// - Exit short when price reaches MA
|
||||
pub struct MeanReversionStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
exit_threshold: f64, // Typically 0 (at the mean)
|
||||
position_size: f64,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
positions: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl MeanReversionStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
exit_threshold: 0.0, // Exit at the mean
|
||||
position_size,
|
||||
price_history: HashMap::new(),
|
||||
positions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_mean(prices: &[f64]) -> f64 {
|
||||
prices.iter().sum::<f64>() / prices.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
|
||||
let variance = prices.iter()
|
||||
.map(|p| (p - mean).powi(2))
|
||||
.sum::<f64>() / prices.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MeanReversionStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if history.len() > self.lookback_period {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(history);
|
||||
let std_dev = Self::calculate_std_dev(history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if current_position == 0.0 {
|
||||
if price < lower_band {
|
||||
// Price is oversold, buy
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, lower_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), self.position_size);
|
||||
} else if price > upper_band {
|
||||
// Price is overbought, sell short
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, upper_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), -self.position_size);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if current_position > 0.0 {
|
||||
// We're long, exit when price crosses above mean (not just touches)
|
||||
// or when we hit stop loss
|
||||
let stop_loss = lower_band - std_dev; // Stop loss below lower band
|
||||
if price >= mean * 1.01 || price <= stop_loss {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit long: price ${:.2} {} mean ${:.2}",
|
||||
price, if price >= mean * 1.01 { "crossed above" } else { "hit stop loss below" }, mean
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"exit_type": "mean_reversion",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// We're short, exit when price crosses below mean
|
||||
// or when we hit stop loss
|
||||
let stop_loss = upper_band + std_dev; // Stop loss above upper band
|
||||
if price <= mean * 0.99 || price >= stop_loss {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit short: price ${:.2} {} mean ${:.2}",
|
||||
price, if price <= mean * 0.99 { "crossed below" } else { "hit stop loss above" }, mean
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"exit_type": "mean_reversion",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
// Position tracking is handled in on_market_data for simplicity
|
||||
eprintln!("Mean reversion fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"exit_threshold": self.exit_threshold,
|
||||
"position_size": self.position_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
279
apps/stock/core/src/strategies/mean_reversion_fixed.rs
Normal file
279
apps/stock/core/src/strategies/mean_reversion_fixed.rs
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
use std::collections::HashMap;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Fixed Mean Reversion Strategy that properly tracks positions
|
||||
///
|
||||
/// This version doesn't maintain its own position tracking but relies
|
||||
/// on the position information passed through on_fill callbacks
|
||||
pub struct MeanReversionFixedStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
exit_threshold: f64, // Exit when price moves back this fraction toward mean
|
||||
position_size: f64,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
current_positions: HashMap<String, f64>, // Track actual positions from fills
|
||||
entry_prices: HashMap<String, f64>, // Track entry prices for exit decisions
|
||||
}
|
||||
|
||||
impl MeanReversionFixedStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
exit_threshold: 0.3, // Exit when price moves 30% back to mean
|
||||
position_size,
|
||||
price_history: HashMap::new(),
|
||||
current_positions: HashMap::new(),
|
||||
entry_prices: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_mean(prices: &[f64]) -> f64 {
|
||||
prices.iter().sum::<f64>() / prices.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
|
||||
let variance = prices.iter()
|
||||
.map(|p| (p - mean).powi(2))
|
||||
.sum::<f64>() / prices.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MeanReversionFixedStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if history.len() > self.lookback_period {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(history);
|
||||
let std_dev = Self::calculate_std_dev(history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
// Get actual position from our tracking
|
||||
let current_position = self.current_positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Entry signals - only when flat
|
||||
if current_position.abs() < 0.001 {
|
||||
if price < lower_band {
|
||||
// Price is oversold, buy
|
||||
eprintln!("Mean reversion: {} oversold at ${:.2}, buying (lower band: ${:.2}, mean: ${:.2})",
|
||||
symbol, price, lower_band, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, lower_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
} else if price > upper_band {
|
||||
// Price is overbought, sell short
|
||||
eprintln!("Mean reversion: {} overbought at ${:.2}, selling short (upper band: ${:.2}, mean: ${:.2})",
|
||||
symbol, price, upper_band, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, upper_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
// Exit signals - only when we have a position
|
||||
else if current_position > 0.0 {
|
||||
// We're long - check exit conditions
|
||||
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
|
||||
let target_price = entry_price + (mean - entry_price) * self.exit_threshold;
|
||||
let stop_loss = lower_band - std_dev; // Stop loss below lower band
|
||||
|
||||
if price >= target_price {
|
||||
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing long",
|
||||
symbol, target_price, entry_price, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit long: price ${:.2} reached target ${:.2} (entry: ${:.2})",
|
||||
price, target_price, entry_price
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"entry_price": entry_price,
|
||||
"target_price": target_price,
|
||||
"exit_type": "target",
|
||||
})),
|
||||
});
|
||||
} else if price <= stop_loss {
|
||||
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing long",
|
||||
symbol, stop_loss);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion stop loss: price ${:.2} <= stop ${:.2}",
|
||||
price, stop_loss
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"stop_loss": stop_loss,
|
||||
"price": price,
|
||||
"exit_type": "stop_loss",
|
||||
})),
|
||||
});
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// We're short - check exit conditions
|
||||
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
|
||||
let target_price = entry_price - (entry_price - mean) * self.exit_threshold;
|
||||
let stop_loss = upper_band + std_dev; // Stop loss above upper band
|
||||
|
||||
if price <= target_price {
|
||||
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing short",
|
||||
symbol, target_price, entry_price, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit short: price ${:.2} reached target ${:.2} (entry: ${:.2})",
|
||||
price, target_price, entry_price
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"entry_price": entry_price,
|
||||
"target_price": target_price,
|
||||
"exit_type": "target",
|
||||
})),
|
||||
});
|
||||
} else if price >= stop_loss {
|
||||
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing short",
|
||||
symbol, stop_loss);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion stop loss: price ${:.2} >= stop ${:.2}",
|
||||
price, stop_loss
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"stop_loss": stop_loss,
|
||||
"price": price,
|
||||
"exit_type": "stop_loss",
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
// Update our position tracking based on actual fills
|
||||
let current = self.current_positions.get(symbol).copied().unwrap_or(0.0);
|
||||
let new_position = if side.contains("Buy") {
|
||||
current + quantity
|
||||
} else {
|
||||
current - quantity
|
||||
};
|
||||
|
||||
eprintln!("Mean reversion fill: {} {} @ {} - {}, position: {} -> {}",
|
||||
quantity, symbol, price, side, current, new_position);
|
||||
|
||||
// Update position
|
||||
if new_position.abs() < 0.001 {
|
||||
// Position closed
|
||||
self.current_positions.remove(symbol);
|
||||
self.entry_prices.remove(symbol);
|
||||
} else {
|
||||
self.current_positions.insert(symbol.to_string(), new_position);
|
||||
// Track entry price for new positions
|
||||
if current.abs() < 0.001 {
|
||||
self.entry_prices.insert(symbol.to_string(), price);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"exit_threshold": self.exit_threshold,
|
||||
"position_size": self.position_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
9
apps/stock/core/src/strategies/mod.rs
Normal file
9
apps/stock/core/src/strategies/mod.rs
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
pub mod mean_reversion;
|
||||
pub mod mean_reversion_fixed;
|
||||
pub mod momentum;
|
||||
pub mod pairs_trading;
|
||||
|
||||
pub use mean_reversion::MeanReversionStrategy;
|
||||
pub use mean_reversion_fixed::MeanReversionFixedStrategy;
|
||||
pub use momentum::MomentumStrategy;
|
||||
pub use pairs_trading::PairsTradingStrategy;
|
||||
228
apps/stock/core/src/strategies/momentum.rs
Normal file
228
apps/stock/core/src/strategies/momentum.rs
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Momentum Strategy
|
||||
///
|
||||
/// This strategy trades based on momentum indicators like rate of change (ROC) and
|
||||
/// relative strength. It aims to capture trends by buying securities showing
|
||||
/// upward momentum and selling those showing downward momentum.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - BUY when momentum crosses above threshold and accelerating
|
||||
/// - SELL when momentum crosses below -threshold and decelerating
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit long when momentum turns negative
|
||||
/// - Exit short when momentum turns positive
|
||||
pub struct MomentumStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
momentum_threshold: f64,
|
||||
position_size: f64,
|
||||
use_acceleration: bool,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
momentum_history: HashMap<String, Vec<f64>>,
|
||||
positions: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl MomentumStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
momentum_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
momentum_threshold,
|
||||
position_size,
|
||||
use_acceleration: true,
|
||||
price_history: HashMap::new(),
|
||||
momentum_history: HashMap::new(),
|
||||
positions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_momentum(prices: &[f64], lookback_period: usize) -> f64 {
|
||||
if prices.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let current = prices.last().unwrap();
|
||||
let past = prices[prices.len() - lookback_period.min(prices.len())];
|
||||
|
||||
((current - past) / past) * 100.0
|
||||
}
|
||||
|
||||
fn calculate_acceleration(momentum_values: &[f64]) -> f64 {
|
||||
if momentum_values.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let current = momentum_values.last().unwrap();
|
||||
let previous = momentum_values[momentum_values.len() - 2];
|
||||
|
||||
current - previous
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MomentumStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep reasonable history
|
||||
if history.len() > self.lookback_period * 2 {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate momentum
|
||||
let momentum = Self::calculate_momentum(history, self.lookback_period);
|
||||
|
||||
// Update momentum history
|
||||
let mom_history = self.momentum_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
mom_history.push(momentum);
|
||||
|
||||
if mom_history.len() > 5 {
|
||||
mom_history.remove(0);
|
||||
}
|
||||
|
||||
// Calculate acceleration if enabled
|
||||
let acceleration = if self.use_acceleration && mom_history.len() >= 2 {
|
||||
Self::calculate_acceleration(mom_history)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if current_position == 0.0 {
|
||||
// Long entry: strong positive momentum and accelerating
|
||||
if momentum > self.momentum_threshold &&
|
||||
(!self.use_acceleration || acceleration > 0.0) {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: momentum / 100.0, // Normalize strength
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Momentum buy: momentum {:.2}% > threshold {:.2}%, accel: {:.2}",
|
||||
momentum, self.momentum_threshold, acceleration
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"acceleration": acceleration,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), self.position_size);
|
||||
}
|
||||
// Short entry: strong negative momentum and decelerating
|
||||
else if momentum < -self.momentum_threshold &&
|
||||
(!self.use_acceleration || acceleration < 0.0) {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: momentum.abs() / 100.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Momentum sell: momentum {:.2}% < threshold -{:.2}%, accel: {:.2}",
|
||||
momentum, self.momentum_threshold, acceleration
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"acceleration": acceleration,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), -self.position_size);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if current_position > 0.0 {
|
||||
// Exit long when momentum turns negative
|
||||
if momentum < 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Momentum exit long: momentum turned negative {:.2}%",
|
||||
momentum
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"price": price,
|
||||
"exit_type": "momentum_reversal",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// Exit short when momentum turns positive
|
||||
if momentum > 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Momentum exit short: momentum turned positive {:.2}%",
|
||||
momentum
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"price": price,
|
||||
"exit_type": "momentum_reversal",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
eprintln!("Momentum fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"momentum_threshold": self.momentum_threshold,
|
||||
"position_size": self.position_size,
|
||||
"use_acceleration": self.use_acceleration,
|
||||
})
|
||||
}
|
||||
}
|
||||
295
apps/stock/core/src/strategies/pairs_trading.rs
Normal file
295
apps/stock/core/src/strategies/pairs_trading.rs
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Pairs Trading Strategy
|
||||
///
|
||||
/// This strategy trades the spread between two correlated securities. When the spread
|
||||
/// deviates from its historical mean, we trade expecting it to revert.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - Long pair A, Short pair B when spread < (mean - threshold * std)
|
||||
/// - Short pair A, Long pair B when spread > (mean + threshold * std)
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit when spread returns to mean
|
||||
pub struct PairsTradingStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
pair_a: String,
|
||||
pair_b: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
position_size: f64,
|
||||
hedge_ratio: f64, // How many shares of B per share of A
|
||||
|
||||
// State
|
||||
price_history_a: Vec<f64>,
|
||||
price_history_b: Vec<f64>,
|
||||
spread_history: Vec<f64>,
|
||||
positions: HashMap<String, f64>,
|
||||
last_prices: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl PairsTradingStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
pair_a: String,
|
||||
pair_b: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
pair_a,
|
||||
pair_b,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
hedge_ratio: 1.0, // Default 1:1, could be calculated dynamically
|
||||
price_history_a: Vec::new(),
|
||||
price_history_b: Vec::new(),
|
||||
spread_history: Vec::new(),
|
||||
positions: HashMap::new(),
|
||||
last_prices: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_spread(&self, price_a: f64, price_b: f64) -> f64 {
|
||||
price_a - self.hedge_ratio * price_b
|
||||
}
|
||||
|
||||
fn calculate_mean(values: &[f64]) -> f64 {
|
||||
values.iter().sum::<f64>() / values.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
|
||||
let variance = values.iter()
|
||||
.map(|v| (v - mean).powi(2))
|
||||
.sum::<f64>() / values.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for PairsTradingStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update last prices
|
||||
self.last_prices.insert(symbol.clone(), price);
|
||||
|
||||
// Update price histories
|
||||
if symbol == &self.pair_a {
|
||||
self.price_history_a.push(price);
|
||||
if self.price_history_a.len() > self.lookback_period {
|
||||
self.price_history_a.remove(0);
|
||||
}
|
||||
} else if symbol == &self.pair_b {
|
||||
self.price_history_b.push(price);
|
||||
if self.price_history_b.len() > self.lookback_period {
|
||||
self.price_history_b.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Only generate signals when we have prices for both pairs
|
||||
if let (Some(&price_a), Some(&price_b)) =
|
||||
(self.last_prices.get(&self.pair_a), self.last_prices.get(&self.pair_b)) {
|
||||
|
||||
// Calculate current spread
|
||||
let spread = self.calculate_spread(price_a, price_b);
|
||||
|
||||
// Update spread history
|
||||
self.spread_history.push(spread);
|
||||
if self.spread_history.len() > self.lookback_period {
|
||||
self.spread_history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if self.spread_history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(&self.spread_history);
|
||||
let std_dev = Self::calculate_std_dev(&self.spread_history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
let position_a = self.positions.get(&self.pair_a).copied().unwrap_or(0.0);
|
||||
let position_b = self.positions.get(&self.pair_b).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if position_a == 0.0 && position_b == 0.0 {
|
||||
if spread < lower_band {
|
||||
// Spread too low: Buy A, Sell B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Pairs trade: spread ${:.2} < lower band ${:.2}",
|
||||
spread, lower_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "A",
|
||||
})),
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size * self.hedge_ratio),
|
||||
reason: Some(format!(
|
||||
"Pairs trade hedge: spread ${:.2} < lower band ${:.2}",
|
||||
spread, lower_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "B",
|
||||
})),
|
||||
});
|
||||
|
||||
self.positions.insert(self.pair_a.clone(), self.position_size);
|
||||
self.positions.insert(self.pair_b.clone(), -self.position_size * self.hedge_ratio);
|
||||
} else if spread > upper_band {
|
||||
// Spread too high: Sell A, Buy B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Pairs trade: spread ${:.2} > upper band ${:.2}",
|
||||
spread, upper_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "A",
|
||||
})),
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size * self.hedge_ratio),
|
||||
reason: Some(format!(
|
||||
"Pairs trade hedge: spread ${:.2} > upper band ${:.2}",
|
||||
spread, upper_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "B",
|
||||
})),
|
||||
});
|
||||
|
||||
self.positions.insert(self.pair_a.clone(), -self.position_size);
|
||||
self.positions.insert(self.pair_b.clone(), self.position_size * self.hedge_ratio);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if position_a != 0.0 && position_b != 0.0 {
|
||||
// Exit when spread returns to mean
|
||||
let spread_distance = (spread - mean).abs();
|
||||
let exit_threshold = std_dev * 0.1; // Exit near mean
|
||||
|
||||
if spread_distance < exit_threshold {
|
||||
// Close positions
|
||||
if position_a > 0.0 {
|
||||
// We're long A, short B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_a),
|
||||
reason: Some(format!(
|
||||
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
|
||||
spread, mean
|
||||
)),
|
||||
metadata: None,
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_b.abs()),
|
||||
reason: Some("Pairs trade exit: closing hedge".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
} else {
|
||||
// We're short A, long B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_a.abs()),
|
||||
reason: Some(format!(
|
||||
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
|
||||
spread, mean
|
||||
)),
|
||||
metadata: None,
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_b),
|
||||
reason: Some("Pairs trade exit: closing hedge".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
}
|
||||
|
||||
self.positions.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
eprintln!("Pairs trading fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"pair_a": self.pair_a,
|
||||
"pair_b": self.pair_b,
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"position_size": self.position_size,
|
||||
"hedge_ratio": self.hedge_ratio,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -140,7 +140,7 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
timestamp: new Date(point[0]).getTime(),
|
||||
value: point[1],
|
||||
})),
|
||||
trades: this.transformFillsToTrades(rustResult.trades || []),
|
||||
trades: this.transformCompletedTradesToFills(rustResult.trades || []),
|
||||
dailyReturns: this.calculateDailyReturns(rustResult.equity_curve),
|
||||
finalPositions: rustResult.final_positions || {},
|
||||
executionTime: Date.now() - startTime,
|
||||
|
|
@ -290,9 +290,24 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
// Use native Rust strategy for maximum performance
|
||||
this.container.logger.info('Using native Rust strategy implementation');
|
||||
|
||||
// Map strategy names to their Rust strategy types
|
||||
let strategyType = 'sma_crossover'; // default
|
||||
|
||||
if (strategyName.toLowerCase().includes('mean') || strategyName.toLowerCase().includes('reversion')) {
|
||||
strategyType = 'mean_reversion';
|
||||
} else if (strategyName.toLowerCase().includes('momentum')) {
|
||||
strategyType = 'momentum';
|
||||
} else if (strategyName.toLowerCase().includes('pairs')) {
|
||||
strategyType = 'pairs_trading';
|
||||
} else if (strategyName.toLowerCase().includes('sma') || strategyName.toLowerCase().includes('crossover')) {
|
||||
strategyType = 'sma_crossover';
|
||||
}
|
||||
|
||||
this.container.logger.info(`Mapped strategy '${strategyName}' to type '${strategyType}'`);
|
||||
|
||||
// Use the addNativeStrategy method instead
|
||||
this.currentEngine.addNativeStrategy(
|
||||
'sma_crossover', // strategy type
|
||||
strategyType,
|
||||
strategyName,
|
||||
`strategy-${Date.now()}`,
|
||||
parameters
|
||||
|
|
@ -340,26 +355,168 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
};
|
||||
}
|
||||
|
||||
private transformFillsToTrades(completedTrades: any[]): any[] {
|
||||
// Now we have CompletedTrade objects with symbol and side information
|
||||
return completedTrades.map((trade, index) => {
|
||||
const timestamp = new Date(trade.timestamp);
|
||||
const side = trade.side === 'Buy' ? 'buy' : 'sell';
|
||||
|
||||
return {
|
||||
id: `trade-${index}`,
|
||||
private transformCompletedTradesToFills(completedTrades: any[]): any[] {
|
||||
// Convert completed trades (paired entry/exit) back to individual fills for the UI
|
||||
const fills: any[] = [];
|
||||
let fillId = 0;
|
||||
|
||||
completedTrades.forEach(trade => {
|
||||
// Create entry fill
|
||||
fills.push({
|
||||
id: `fill-${fillId++}`,
|
||||
timestamp: trade.entry_time || trade.entryDate,
|
||||
symbol: trade.symbol,
|
||||
entryDate: timestamp.toISOString(),
|
||||
exitDate: timestamp.toISOString(), // Same as entry for individual fills
|
||||
entryPrice: trade.price,
|
||||
exitPrice: trade.price,
|
||||
side: trade.side === 'Buy' || trade.side === 'long' ? 'buy' : 'sell',
|
||||
quantity: trade.quantity,
|
||||
side,
|
||||
pnl: 0, // Would need to calculate from paired trades
|
||||
pnlPercent: 0,
|
||||
commission: trade.commission,
|
||||
duration: 0, // Would need to calculate from paired trades
|
||||
};
|
||||
price: trade.entry_price || trade.entryPrice,
|
||||
commission: trade.commission / 2, // Split commission between entry and exit
|
||||
});
|
||||
|
||||
// Create exit fill (opposite side)
|
||||
const exitSide = (trade.side === 'Buy' || trade.side === 'long') ? 'sell' : 'buy';
|
||||
fills.push({
|
||||
id: `fill-${fillId++}`,
|
||||
timestamp: trade.exit_time || trade.exitDate,
|
||||
symbol: trade.symbol,
|
||||
side: exitSide,
|
||||
quantity: trade.quantity,
|
||||
price: trade.exit_price || trade.exitPrice,
|
||||
commission: trade.commission / 2,
|
||||
pnl: trade.pnl,
|
||||
});
|
||||
});
|
||||
|
||||
// Sort by timestamp
|
||||
fills.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime());
|
||||
|
||||
return fills;
|
||||
}
|
||||
|
||||
private transformFillsToTrades(completedTrades: any[]): any[] {
|
||||
// Group fills by symbol to match entries with exits
|
||||
const fillsBySymbol: { [symbol: string]: any[] } = {};
|
||||
|
||||
completedTrades.forEach(trade => {
|
||||
if (!fillsBySymbol[trade.symbol]) {
|
||||
fillsBySymbol[trade.symbol] = [];
|
||||
}
|
||||
fillsBySymbol[trade.symbol].push(trade);
|
||||
});
|
||||
|
||||
const pairedTrades: any[] = [];
|
||||
const openPositions: { [symbol: string]: any[] } = {};
|
||||
|
||||
// Process each symbol's fills chronologically
|
||||
Object.entries(fillsBySymbol).forEach(([symbol, fills]) => {
|
||||
// Sort by timestamp
|
||||
fills.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime());
|
||||
|
||||
fills.forEach((fill, idx) => {
|
||||
const isBuy = fill.side === 'Buy';
|
||||
const timestamp = new Date(fill.timestamp);
|
||||
|
||||
if (!openPositions[symbol]) {
|
||||
openPositions[symbol] = [];
|
||||
}
|
||||
|
||||
const openPos = openPositions[symbol];
|
||||
|
||||
// For buy fills, add to open positions
|
||||
if (isBuy) {
|
||||
openPos.push(fill);
|
||||
} else {
|
||||
// For sell fills, match with open buy positions (FIFO)
|
||||
if (openPos.length > 0 && openPos[0].side === 'Buy') {
|
||||
const entry = openPos.shift();
|
||||
const entryDate = new Date(entry.timestamp);
|
||||
const duration = (timestamp.getTime() - entryDate.getTime()) / 1000; // seconds
|
||||
const pnl = (fill.price - entry.price) * fill.quantity - entry.commission - fill.commission;
|
||||
const pnlPercent = ((fill.price - entry.price) / entry.price) * 100;
|
||||
|
||||
pairedTrades.push({
|
||||
id: `trade-${pairedTrades.length}`,
|
||||
symbol,
|
||||
entryDate: entryDate.toISOString(),
|
||||
exitDate: timestamp.toISOString(),
|
||||
entryPrice: entry.price,
|
||||
exitPrice: fill.price,
|
||||
quantity: fill.quantity,
|
||||
side: 'long',
|
||||
pnl,
|
||||
pnlPercent,
|
||||
commission: entry.commission + fill.commission,
|
||||
duration,
|
||||
});
|
||||
} else {
|
||||
// This is a short entry
|
||||
openPos.push(fill);
|
||||
}
|
||||
}
|
||||
|
||||
// For short positions (sell first, then buy to cover)
|
||||
if (!isBuy && openPos.length > 0) {
|
||||
const lastPos = openPos[openPos.length - 1];
|
||||
if (lastPos.side === 'Sell' && idx < fills.length - 1) {
|
||||
const nextFill = fills[idx + 1];
|
||||
if (nextFill && nextFill.side === 'Buy') {
|
||||
// We'll handle this when we process the buy fill
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle buy fills that close short positions
|
||||
if (isBuy && openPos.length > 1) {
|
||||
const shortPos = openPos.find(p => p.side === 'Sell');
|
||||
if (shortPos) {
|
||||
const shortIdx = openPos.indexOf(shortPos);
|
||||
openPos.splice(shortIdx, 1);
|
||||
|
||||
const entryDate = new Date(shortPos.timestamp);
|
||||
const duration = (timestamp.getTime() - entryDate.getTime()) / 1000;
|
||||
const pnl = (shortPos.price - fill.price) * fill.quantity - shortPos.commission - fill.commission;
|
||||
const pnlPercent = ((shortPos.price - fill.price) / shortPos.price) * 100;
|
||||
|
||||
pairedTrades.push({
|
||||
id: `trade-${pairedTrades.length}`,
|
||||
symbol,
|
||||
entryDate: entryDate.toISOString(),
|
||||
exitDate: timestamp.toISOString(),
|
||||
entryPrice: shortPos.price,
|
||||
exitPrice: fill.price,
|
||||
quantity: fill.quantity,
|
||||
side: 'short',
|
||||
pnl,
|
||||
pnlPercent,
|
||||
commission: shortPos.commission + fill.commission,
|
||||
duration,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add any remaining open positions as incomplete trades
|
||||
const remainingOpenPositions = openPositions[symbol] || [];
|
||||
remainingOpenPositions.forEach(pos => {
|
||||
const timestamp = new Date(pos.timestamp);
|
||||
const side = pos.side === 'Buy' ? 'buy' : 'sell';
|
||||
|
||||
pairedTrades.push({
|
||||
id: `trade-${pairedTrades.length}`,
|
||||
symbol,
|
||||
entryDate: timestamp.toISOString(),
|
||||
exitDate: timestamp.toISOString(), // Same as entry for open positions
|
||||
entryPrice: pos.price,
|
||||
exitPrice: pos.price,
|
||||
quantity: pos.quantity,
|
||||
side,
|
||||
pnl: 0, // No PnL for open positions
|
||||
pnlPercent: 0,
|
||||
commission: pos.commission,
|
||||
duration: 0,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return pairedTrades;
|
||||
}
|
||||
}
|
||||
180
apps/stock/orchestrator/test-mean-reversion.ts
Normal file
180
apps/stock/orchestrator/test-mean-reversion.ts
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import { RustBacktestAdapter } from './src/backtest/RustBacktestAdapter';
|
||||
import { IServiceContainer } from '@stock-bot/di';
|
||||
import { BacktestConfig } from './src/types';
|
||||
|
||||
// Mock container
|
||||
const mockContainer: IServiceContainer = {
|
||||
logger: {
|
||||
info: console.log,
|
||||
error: console.error,
|
||||
warn: console.warn,
|
||||
debug: console.log,
|
||||
},
|
||||
mongodb: {} as any,
|
||||
postgres: {} as any,
|
||||
redis: {} as any,
|
||||
custom: {},
|
||||
} as IServiceContainer;
|
||||
|
||||
// Mock storage service that returns test data
|
||||
class MockStorageService {
|
||||
async getHistoricalBars(symbol: string, startDate: Date, endDate: Date, frequency: string) {
|
||||
console.log(`MockStorageService: Getting bars for ${symbol} from ${startDate} to ${endDate}`);
|
||||
|
||||
// Generate test data with mean reverting behavior
|
||||
const bars = [];
|
||||
const startTime = startDate.getTime();
|
||||
const endTime = endDate.getTime();
|
||||
const dayMs = 24 * 60 * 60 * 1000;
|
||||
|
||||
let time = startTime;
|
||||
let dayIndex = 0;
|
||||
|
||||
// Base prices for different symbols
|
||||
const basePrices = {
|
||||
'AAPL': 150,
|
||||
'GOOGL': 2800,
|
||||
'MSFT': 400,
|
||||
};
|
||||
|
||||
const basePrice = basePrices[symbol as keyof typeof basePrices] || 100;
|
||||
|
||||
while (time <= endTime) {
|
||||
// Create mean reverting price movement
|
||||
// Price oscillates around the base price with increasing then decreasing deviations
|
||||
const cycleLength = 40; // 40 day cycle
|
||||
const positionInCycle = dayIndex % cycleLength;
|
||||
const halfCycle = cycleLength / 2;
|
||||
|
||||
let deviation;
|
||||
if (positionInCycle < halfCycle) {
|
||||
// First half: price moves away from mean
|
||||
deviation = (positionInCycle / halfCycle) * 0.1; // Up to 10% deviation
|
||||
} else {
|
||||
// Second half: price reverts to mean
|
||||
deviation = ((cycleLength - positionInCycle) / halfCycle) * 0.1;
|
||||
}
|
||||
|
||||
// Alternate between above and below mean
|
||||
const cycleNumber = Math.floor(dayIndex / cycleLength);
|
||||
const multiplier = cycleNumber % 2 === 0 ? 1 : -1;
|
||||
|
||||
const price = basePrice * (1 + multiplier * deviation);
|
||||
|
||||
// Add some noise
|
||||
const noise = (Math.random() - 0.5) * 0.02 * basePrice;
|
||||
const finalPrice = price + noise;
|
||||
|
||||
bars.push({
|
||||
timestamp: new Date(time),
|
||||
open: finalPrice * 0.99,
|
||||
high: finalPrice * 1.01,
|
||||
low: finalPrice * 0.98,
|
||||
close: finalPrice,
|
||||
volume: 1000000,
|
||||
vwap: finalPrice,
|
||||
});
|
||||
|
||||
time += dayMs;
|
||||
dayIndex++;
|
||||
}
|
||||
|
||||
console.log(`Generated ${bars.length} bars for ${symbol}, first close: ${bars[0].close.toFixed(2)}, last close: ${bars[bars.length - 1].close.toFixed(2)}`);
|
||||
return bars;
|
||||
}
|
||||
}
|
||||
|
||||
// Test the backtest
|
||||
async function testMeanReversionBacktest() {
|
||||
console.log('=== Testing Mean Reversion Backtest ===\n');
|
||||
|
||||
// Create adapter with mock storage
|
||||
const adapter = new RustBacktestAdapter(mockContainer);
|
||||
(adapter as any).storageService = new MockStorageService();
|
||||
|
||||
const config: BacktestConfig = {
|
||||
name: 'Mean Reversion Test',
|
||||
strategy: 'mean_reversion',
|
||||
symbols: ['AAPL', 'GOOGL', 'MSFT'],
|
||||
startDate: '2024-01-01T00:00:00Z',
|
||||
endDate: '2024-06-01T00:00:00Z',
|
||||
initialCapital: 100000,
|
||||
commission: 0.001,
|
||||
slippage: 0.0001,
|
||||
dataFrequency: '1d',
|
||||
config: {
|
||||
lookbackPeriod: 20,
|
||||
entryThreshold: 2.0,
|
||||
positionSize: 100,
|
||||
},
|
||||
};
|
||||
|
||||
try {
|
||||
console.log('Starting backtest...\n');
|
||||
const result = await adapter.runBacktest(config);
|
||||
|
||||
console.log('\n=== Backtest Results ===');
|
||||
console.log(`Status: ${result.status}`);
|
||||
console.log(`Total Trades: ${result.metrics.totalTrades}`);
|
||||
console.log(`Profitable Trades: ${result.metrics.profitableTrades}`);
|
||||
console.log(`Win Rate: ${result.metrics.winRate.toFixed(2)}%`);
|
||||
console.log(`Total Return: ${result.metrics.totalReturn.toFixed(2)}%`);
|
||||
console.log(`Sharpe Ratio: ${result.metrics.sharpeRatio.toFixed(2)}`);
|
||||
console.log(`Max Drawdown: ${result.metrics.maxDrawdown.toFixed(2)}%`);
|
||||
|
||||
console.log('\n=== Trade Analysis ===');
|
||||
console.log(`Number of completed trades: ${result.trades.length}`);
|
||||
|
||||
// Analyze trades by symbol
|
||||
const tradesBySymbol: Record<string, any[]> = {};
|
||||
result.trades.forEach(trade => {
|
||||
if (!tradesBySymbol[trade.symbol]) {
|
||||
tradesBySymbol[trade.symbol] = [];
|
||||
}
|
||||
tradesBySymbol[trade.symbol].push(trade);
|
||||
});
|
||||
|
||||
Object.entries(tradesBySymbol).forEach(([symbol, trades]) => {
|
||||
console.log(`\n${symbol}: ${trades.length} trades`);
|
||||
const longTrades = trades.filter(t => t.side === 'long');
|
||||
const shortTrades = trades.filter(t => t.side === 'short');
|
||||
console.log(` - Long trades: ${longTrades.length}`);
|
||||
console.log(` - Short trades: ${shortTrades.length}`);
|
||||
|
||||
// Count buy/sell pairs
|
||||
const buyTrades = trades.filter(t => t.side === 'buy');
|
||||
const sellTrades = trades.filter(t => t.side === 'sell');
|
||||
console.log(` - Buy trades: ${buyTrades.length}`);
|
||||
console.log(` - Sell trades: ${sellTrades.length}`);
|
||||
|
||||
// Show first few trades
|
||||
console.log(` - First 3 trades:`);
|
||||
trades.slice(0, 3).forEach((trade, idx) => {
|
||||
console.log(` ${idx + 1}. ${trade.side} - Price: $${trade.price.toFixed(2)}, Quantity: ${trade.quantity}${trade.pnl ? `, PnL: $${trade.pnl.toFixed(2)}` : ''}`);
|
||||
});
|
||||
});
|
||||
|
||||
// Check position distribution
|
||||
const allDurations = result.trades.map(t => t.duration / 86400); // Convert to days
|
||||
const avgDuration = allDurations.reduce((a, b) => a + b, 0) / allDurations.length;
|
||||
const minDuration = Math.min(...allDurations);
|
||||
const maxDuration = Math.max(...allDurations);
|
||||
|
||||
console.log('\n=== Duration Analysis ===');
|
||||
console.log(`Average trade duration: ${avgDuration.toFixed(1)} days`);
|
||||
console.log(`Min duration: ${minDuration.toFixed(1)} days`);
|
||||
console.log(`Max duration: ${maxDuration.toFixed(1)} days`);
|
||||
|
||||
// Final positions
|
||||
console.log('\n=== Final Positions ===');
|
||||
Object.entries(result.finalPositions).forEach(([symbol, position]) => {
|
||||
console.log(`${symbol}: ${position}`);
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
console.error('Backtest failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Run the test
|
||||
testMeanReversionBacktest().catch(console.error);
|
||||
101
apps/stock/orchestrator/test-trade-format.ts
Normal file
101
apps/stock/orchestrator/test-trade-format.ts
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import { RustBacktestAdapter } from './src/backtest/RustBacktestAdapter';
|
||||
import { IServiceContainer } from '@stock-bot/di';
|
||||
import { BacktestConfig } from './src/types';
|
||||
|
||||
// Mock container
|
||||
const mockContainer: IServiceContainer = {
|
||||
logger: {
|
||||
info: console.log,
|
||||
error: console.error,
|
||||
warn: console.warn,
|
||||
debug: console.log,
|
||||
},
|
||||
mongodb: {} as any,
|
||||
postgres: {} as any,
|
||||
redis: {} as any,
|
||||
custom: {},
|
||||
} as IServiceContainer;
|
||||
|
||||
// Mock storage service
|
||||
class MockStorageService {
|
||||
async getHistoricalBars(symbol: string, startDate: Date, endDate: Date, frequency: string) {
|
||||
const bars = [];
|
||||
const startTime = startDate.getTime();
|
||||
const endTime = endDate.getTime();
|
||||
const dayMs = 24 * 60 * 60 * 1000;
|
||||
|
||||
let time = startTime;
|
||||
let dayIndex = 0;
|
||||
|
||||
// Simple oscillating price for testing
|
||||
while (time <= endTime) {
|
||||
const price = 100 + 10 * Math.sin(dayIndex * 0.2);
|
||||
|
||||
bars.push({
|
||||
timestamp: new Date(time),
|
||||
open: price * 0.99,
|
||||
high: price * 1.01,
|
||||
low: price * 0.98,
|
||||
close: price,
|
||||
volume: 1000000,
|
||||
vwap: price,
|
||||
});
|
||||
|
||||
time += dayMs;
|
||||
dayIndex++;
|
||||
}
|
||||
|
||||
return bars;
|
||||
}
|
||||
}
|
||||
|
||||
// Test the backtest
|
||||
async function testTradeFormat() {
|
||||
console.log('=== Testing Trade Format ===\n');
|
||||
|
||||
const adapter = new RustBacktestAdapter(mockContainer);
|
||||
(adapter as any).storageService = new MockStorageService();
|
||||
|
||||
const config: BacktestConfig = {
|
||||
name: 'Trade Format Test',
|
||||
strategy: 'Simple Moving Average Crossover',
|
||||
symbols: ['TEST'],
|
||||
startDate: '2024-01-01T00:00:00Z',
|
||||
endDate: '2024-03-01T00:00:00Z',
|
||||
initialCapital: 100000,
|
||||
commission: 0.001,
|
||||
slippage: 0.0001,
|
||||
dataFrequency: '1d',
|
||||
config: {
|
||||
fastPeriod: 5,
|
||||
slowPeriod: 15,
|
||||
},
|
||||
};
|
||||
|
||||
try {
|
||||
const result = await adapter.runBacktest(config);
|
||||
|
||||
console.log('\n=== Trade Format ===');
|
||||
console.log('Number of trades:', result.trades.length);
|
||||
console.log('\nFirst 3 trades:');
|
||||
result.trades.slice(0, 3).forEach((trade, idx) => {
|
||||
console.log(`\nTrade ${idx + 1}:`, JSON.stringify(trade, null, 2));
|
||||
});
|
||||
|
||||
// Check what format the trades are in
|
||||
if (result.trades.length > 0) {
|
||||
const firstTrade = result.trades[0];
|
||||
console.log('\n=== Trade Structure Analysis ===');
|
||||
console.log('Keys:', Object.keys(firstTrade));
|
||||
console.log('Has entryDate/exitDate?', 'entryDate' in firstTrade && 'exitDate' in firstTrade);
|
||||
console.log('Has timestamp?', 'timestamp' in firstTrade);
|
||||
console.log('Has side field?', 'side' in firstTrade);
|
||||
console.log('Side value:', firstTrade.side);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Test failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
testTradeFormat().catch(console.error);
|
||||
274
test-rust-strategies.js
Normal file
274
test-rust-strategies.js
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
#!/usr/bin/env bun
|
||||
|
||||
import { BacktestEngine } from './apps/stock/core/index.js';
|
||||
|
||||
// Test configuration
|
||||
const config = {
|
||||
name: 'Native Rust Strategies Test',
|
||||
symbols: ['AA', 'AAS'],
|
||||
startDate: '2024-01-01T00:00:00Z',
|
||||
endDate: '2024-01-31T00:00:00Z',
|
||||
initialCapital: 100000,
|
||||
commission: 0.001,
|
||||
slippage: 0.0001,
|
||||
dataFrequency: '1d',
|
||||
};
|
||||
|
||||
// Create the Rust engine
|
||||
const engine = new BacktestEngine(config);
|
||||
|
||||
// Test different native Rust strategies
|
||||
console.log('Testing native Rust strategies...\n');
|
||||
|
||||
// 1. Test Mean Reversion Strategy
|
||||
console.log('1. Mean Reversion Strategy');
|
||||
engine.addNativeStrategy(
|
||||
'mean_reversion',
|
||||
'Mean Reversion AA/AAS',
|
||||
'mean-rev-1',
|
||||
{
|
||||
lookbackPeriod: 20,
|
||||
entryThreshold: 2.0,
|
||||
positionSize: 1000,
|
||||
}
|
||||
);
|
||||
|
||||
// Generate synthetic market data with mean-reverting characteristics
|
||||
const testData = [];
|
||||
const startDate = new Date('2024-01-01');
|
||||
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const date = new Date(startDate);
|
||||
date.setDate(date.getDate() + i);
|
||||
|
||||
// AA: Mean-reverting around 100
|
||||
const aaMean = 100;
|
||||
const aaPrice = aaMean + Math.sin(i / 5) * 10 + (Math.random() - 0.5) * 5;
|
||||
|
||||
// AAS: Mean-reverting around 50
|
||||
const aasMean = 50;
|
||||
const aasPrice = aasMean + Math.sin(i / 5) * 5 + (Math.random() - 0.5) * 2.5;
|
||||
|
||||
testData.push({
|
||||
symbol: 'AA',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aaPrice - 0.5,
|
||||
high: aaPrice + 0.5,
|
||||
low: aaPrice - 1,
|
||||
close: aaPrice,
|
||||
volume: 1000000,
|
||||
vwap: aaPrice,
|
||||
});
|
||||
|
||||
testData.push({
|
||||
symbol: 'AAS',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aasPrice - 0.25,
|
||||
high: aasPrice + 0.25,
|
||||
low: aasPrice - 0.5,
|
||||
close: aasPrice,
|
||||
volume: 500000,
|
||||
vwap: aasPrice,
|
||||
});
|
||||
}
|
||||
|
||||
console.log(`Loading ${testData.length} market data points...`);
|
||||
engine.loadMarketData(testData);
|
||||
|
||||
// Run the backtest
|
||||
console.log('Running mean reversion backtest...');
|
||||
try {
|
||||
const resultJson = engine.run();
|
||||
const result = JSON.parse(resultJson);
|
||||
|
||||
console.log('\nResults:');
|
||||
console.log('Total trades:', result.trades?.length || 0);
|
||||
console.log('Win rate:', result.metrics.win_rate?.toFixed(2) + '%');
|
||||
console.log('Profit factor:', result.metrics.profit_factor?.toFixed(2));
|
||||
console.log('Total PnL:', '$' + result.metrics.total_pnl?.toFixed(2));
|
||||
console.log('Final equity:', '$' + result.equity_curve[result.equity_curve.length - 1]?.[1].toFixed(2));
|
||||
|
||||
// Show some trades
|
||||
if (result.trades && result.trades.length > 0) {
|
||||
console.log('\nFirst few trades:');
|
||||
result.trades.slice(0, 5).forEach((trade, i) => {
|
||||
console.log(` ${i + 1}. ${trade.symbol} ${trade.side} @ $${trade.price.toFixed(2)}`);
|
||||
});
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Backtest failed:', error);
|
||||
}
|
||||
|
||||
console.log('\n' + '='.repeat(50) + '\n');
|
||||
|
||||
// 2. Test Momentum Strategy
|
||||
console.log('2. Momentum Strategy');
|
||||
|
||||
// Create a new engine for momentum strategy
|
||||
const engine2 = new BacktestEngine(config);
|
||||
|
||||
engine2.addNativeStrategy(
|
||||
'momentum',
|
||||
'Momentum Trading',
|
||||
'momentum-1',
|
||||
{
|
||||
lookbackPeriod: 10,
|
||||
momentumThreshold: 5.0,
|
||||
positionSize: 1000,
|
||||
}
|
||||
);
|
||||
|
||||
// Generate trending market data
|
||||
const trendData = [];
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const date = new Date(startDate);
|
||||
date.setDate(date.getDate() + i);
|
||||
|
||||
// AA: Uptrend
|
||||
const aaPrice = 100 + i * 2 + (Math.random() - 0.5) * 2;
|
||||
|
||||
// AAS: Downtrend then uptrend
|
||||
const aasPrice = i < 15
|
||||
? 50 - i * 1 + (Math.random() - 0.5) * 1
|
||||
: 35 + (i - 15) * 1.5 + (Math.random() - 0.5) * 1;
|
||||
|
||||
trendData.push({
|
||||
symbol: 'AA',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aaPrice - 0.5,
|
||||
high: aaPrice + 0.5,
|
||||
low: aaPrice - 1,
|
||||
close: aaPrice,
|
||||
volume: 1000000,
|
||||
vwap: aaPrice,
|
||||
});
|
||||
|
||||
trendData.push({
|
||||
symbol: 'AAS',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aasPrice - 0.25,
|
||||
high: aasPrice + 0.25,
|
||||
low: aasPrice - 0.5,
|
||||
close: aasPrice,
|
||||
volume: 500000,
|
||||
vwap: aasPrice,
|
||||
});
|
||||
}
|
||||
|
||||
engine2.loadMarketData(trendData);
|
||||
|
||||
console.log('Running momentum backtest...');
|
||||
try {
|
||||
const resultJson = engine2.run();
|
||||
const result = JSON.parse(resultJson);
|
||||
|
||||
console.log('\nResults:');
|
||||
console.log('Total trades:', result.trades?.length || 0);
|
||||
console.log('Win rate:', result.metrics.win_rate?.toFixed(2) + '%');
|
||||
console.log('Profit factor:', result.metrics.profit_factor?.toFixed(2));
|
||||
console.log('Total PnL:', '$' + result.metrics.total_pnl?.toFixed(2));
|
||||
console.log('Final equity:', '$' + result.equity_curve[result.equity_curve.length - 1]?.[1].toFixed(2));
|
||||
|
||||
} catch (error) {
|
||||
console.error('Backtest failed:', error);
|
||||
}
|
||||
|
||||
console.log('\n' + '='.repeat(50) + '\n');
|
||||
|
||||
// 3. Test Pairs Trading Strategy
|
||||
console.log('3. Pairs Trading Strategy');
|
||||
|
||||
const engine3 = new BacktestEngine(config);
|
||||
|
||||
engine3.addNativeStrategy(
|
||||
'pairs_trading',
|
||||
'Pairs Trading AA/AAS',
|
||||
'pairs-1',
|
||||
{
|
||||
pairA: 'AA',
|
||||
pairB: 'AAS',
|
||||
lookbackPeriod: 20,
|
||||
entryThreshold: 2.0,
|
||||
positionSize: 1000,
|
||||
}
|
||||
);
|
||||
|
||||
// Generate correlated market data with spread deviations
|
||||
const pairsData = [];
|
||||
for (let i = 0; i < 30; i++) {
|
||||
const date = new Date(startDate);
|
||||
date.setDate(date.getDate() + i);
|
||||
|
||||
// Base prices
|
||||
const basePrice = 100 + Math.sin(i / 10) * 5;
|
||||
|
||||
// Spread oscillates around 50
|
||||
const spread = 50 + Math.sin(i / 3) * 10 + (Math.random() - 0.5) * 2;
|
||||
|
||||
const aaPrice = basePrice;
|
||||
const aasPrice = basePrice - spread;
|
||||
|
||||
pairsData.push({
|
||||
symbol: 'AA',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aaPrice - 0.5,
|
||||
high: aaPrice + 0.5,
|
||||
low: aaPrice - 1,
|
||||
close: aaPrice,
|
||||
volume: 1000000,
|
||||
vwap: aaPrice,
|
||||
});
|
||||
|
||||
pairsData.push({
|
||||
symbol: 'AAS',
|
||||
timestamp: date.getTime(),
|
||||
type: 'bar',
|
||||
open: aasPrice - 0.25,
|
||||
high: aasPrice + 0.25,
|
||||
low: aasPrice - 0.5,
|
||||
close: aasPrice,
|
||||
volume: 500000,
|
||||
vwap: aasPrice,
|
||||
});
|
||||
}
|
||||
|
||||
engine3.loadMarketData(pairsData);
|
||||
|
||||
console.log('Running pairs trading backtest...');
|
||||
try {
|
||||
const resultJson = engine3.run();
|
||||
const result = JSON.parse(resultJson);
|
||||
|
||||
console.log('\nResults:');
|
||||
console.log('Total trades:', result.trades?.length || 0);
|
||||
console.log('Win rate:', result.metrics.win_rate?.toFixed(2) + '%');
|
||||
console.log('Profit factor:', result.metrics.profit_factor?.toFixed(2));
|
||||
console.log('Total PnL:', '$' + result.metrics.total_pnl?.toFixed(2));
|
||||
console.log('Final equity:', '$' + result.equity_curve[result.equity_curve.length - 1]?.[1].toFixed(2));
|
||||
|
||||
// Show paired trades
|
||||
if (result.trades && result.trades.length > 0) {
|
||||
console.log('\nPairs trades (showing pairs):');
|
||||
for (let i = 0; i < result.trades.length && i < 6; i += 2) {
|
||||
const trade1 = result.trades[i];
|
||||
const trade2 = result.trades[i + 1];
|
||||
if (trade2) {
|
||||
console.log(` Pair ${Math.floor(i/2) + 1}: ${trade1.symbol} ${trade1.side} @ $${trade1.price.toFixed(2)}, ${trade2.symbol} ${trade2.side} @ $${trade2.price.toFixed(2)}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Backtest failed:', error);
|
||||
}
|
||||
|
||||
console.log('\n' + '='.repeat(50));
|
||||
console.log('\nNative Rust strategies test complete!');
|
||||
console.log('\nThese strategies run at microsecond speeds in Rust,');
|
||||
console.log('perfect for high-frequency trading and production use.');
|
||||
Loading…
Add table
Add a link
Reference in a new issue