added trade-tracking and example rust strats
This commit is contained in:
parent
0a4702d12a
commit
3a7557c8f4
15 changed files with 2108 additions and 29 deletions
Binary file not shown.
|
|
@ -101,6 +101,58 @@ impl BacktestEngine {
|
|||
slow_period,
|
||||
)));
|
||||
}
|
||||
"mean_reversion" => {
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(20.0) as usize;
|
||||
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
|
||||
.unwrap_or(2.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::MeanReversionFixedStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
"momentum" => {
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(14.0) as usize;
|
||||
let momentum_threshold: f64 = parameters.get_named_property::<f64>("momentumThreshold")
|
||||
.unwrap_or(5.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::MomentumStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
lookback_period,
|
||||
momentum_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
"pairs_trading" => {
|
||||
let pair_a: String = parameters.get_named_property::<String>("pairA")?;
|
||||
let pair_b: String = parameters.get_named_property::<String>("pairB")?;
|
||||
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
|
||||
.unwrap_or(20.0) as usize;
|
||||
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
|
||||
.unwrap_or(2.0);
|
||||
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
|
||||
.unwrap_or(100.0);
|
||||
|
||||
engine.add_strategy(Box::new(crate::strategies::PairsTradingStrategy::new(
|
||||
name.clone(),
|
||||
id,
|
||||
pair_a,
|
||||
pair_b,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
)));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::from_reason(format!("Unknown strategy type: {}", strategy_type)));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use crate::{
|
|||
|
||||
use super::{
|
||||
BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType,
|
||||
Strategy, Signal, SignalType, BacktestResult,
|
||||
Strategy, Signal, SignalType, BacktestResult, TradeTracker,
|
||||
};
|
||||
|
||||
pub struct BacktestEngine {
|
||||
|
|
@ -38,6 +38,9 @@ pub struct BacktestEngine {
|
|||
|
||||
// Price tracking
|
||||
last_prices: HashMap<String, f64>,
|
||||
|
||||
// Trade tracking
|
||||
trade_tracker: TradeTracker,
|
||||
}
|
||||
|
||||
impl BacktestEngine {
|
||||
|
|
@ -67,6 +70,7 @@ impl BacktestEngine {
|
|||
profitable_trades: 0,
|
||||
total_pnl: 0.0,
|
||||
last_prices: HashMap::new(),
|
||||
trade_tracker: TradeTracker::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -394,6 +398,9 @@ impl BacktestEngine {
|
|||
// Record the fill with symbol and side information
|
||||
self.state.write().record_fill(order.symbol.clone(), order.side, fill.clone());
|
||||
|
||||
// Track trades
|
||||
self.trade_tracker.process_fill(&order.symbol, order.side, &fill);
|
||||
|
||||
// Update cash
|
||||
let cash_change = match order.side {
|
||||
crate::Side::Buy => -(fill.quantity * fill.price + fill.commission),
|
||||
|
|
@ -410,6 +417,13 @@ impl BacktestEngine {
|
|||
}
|
||||
}
|
||||
|
||||
eprintln!("Fill processed: {} {} @ {} (side: {:?})",
|
||||
fill.quantity, order.symbol, fill.price, order.side);
|
||||
eprintln!("Current position after fill: {}",
|
||||
self.position_tracker.get_position(&order.symbol)
|
||||
.map(|p| p.quantity)
|
||||
.unwrap_or(0.0));
|
||||
|
||||
// Update metrics
|
||||
self.total_trades += 1;
|
||||
if update.resulting_position.realized_pnl > 0.0 {
|
||||
|
|
@ -470,24 +484,43 @@ impl BacktestEngine {
|
|||
let total_pnl = realized_pnl + unrealized_pnl;
|
||||
let total_return = (total_pnl / self.config.initial_capital) * 100.0;
|
||||
|
||||
// Get completed trades from trade tracker for metrics
|
||||
let completed_trades = self.trade_tracker.get_completed_trades();
|
||||
|
||||
// Calculate metrics from completed trades
|
||||
let completed_trade_count = completed_trades.len();
|
||||
let profitable_trades = completed_trades.iter().filter(|t| t.pnl > 0.0).count();
|
||||
let total_wins: f64 = completed_trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).sum();
|
||||
let total_losses: f64 = completed_trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl.abs()).sum();
|
||||
let avg_win = if profitable_trades > 0 { total_wins / profitable_trades as f64 } else { 0.0 };
|
||||
let avg_loss = if completed_trade_count > profitable_trades {
|
||||
total_losses / (completed_trade_count - profitable_trades) as f64
|
||||
} else { 0.0 };
|
||||
let profit_factor = if total_losses > 0.0 { total_wins / total_losses } else { 0.0 };
|
||||
|
||||
// For the API, return all fills (not just completed trades)
|
||||
// This shows all trading activity
|
||||
let all_fills = state.completed_trades.clone();
|
||||
let total_trades = all_fills.len();
|
||||
|
||||
BacktestResult {
|
||||
config: self.config.clone(),
|
||||
metrics: super::BacktestMetrics {
|
||||
total_return,
|
||||
total_trades: self.total_trades,
|
||||
profitable_trades: self.profitable_trades,
|
||||
win_rate: if self.total_trades > 0 {
|
||||
(self.profitable_trades as f64 / self.total_trades as f64) * 100.0
|
||||
total_trades,
|
||||
profitable_trades,
|
||||
win_rate: if completed_trade_count > 0 {
|
||||
(profitable_trades as f64 / completed_trade_count as f64) * 100.0
|
||||
} else { 0.0 },
|
||||
profit_factor: 0.0, // TODO: Calculate properly
|
||||
profit_factor,
|
||||
sharpe_ratio: 0.0, // TODO: Calculate properly
|
||||
max_drawdown: 0.0, // TODO: Calculate properly
|
||||
total_pnl,
|
||||
avg_win: 0.0, // TODO: Calculate properly
|
||||
avg_loss: 0.0, // TODO: Calculate properly
|
||||
avg_win,
|
||||
avg_loss,
|
||||
},
|
||||
equity_curve: state.equity_curve.clone(),
|
||||
trades: state.completed_trades.clone(),
|
||||
trades: all_fills,
|
||||
final_positions: self.position_tracker.get_all_positions()
|
||||
.into_iter()
|
||||
.map(|p| (p.symbol.clone(), p))
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ pub mod engine;
|
|||
pub mod event;
|
||||
pub mod strategy;
|
||||
pub mod results;
|
||||
pub mod trade_tracker;
|
||||
|
||||
pub use engine::BacktestEngine;
|
||||
pub use event::{BacktestEvent, EventType};
|
||||
pub use strategy::{Strategy, Signal, SignalType};
|
||||
pub use results::{BacktestResult, BacktestMetrics};
|
||||
pub use trade_tracker::{TradeTracker, CompletedTrade as TrackedTrade};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletedTrade {
|
||||
|
|
|
|||
258
apps/stock/core/src/backtest/trade_tracker.rs
Normal file
258
apps/stock/core/src/backtest/trade_tracker.rs
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::{Fill, Side};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletedTrade {
|
||||
pub id: String,
|
||||
pub symbol: String,
|
||||
pub entry_time: DateTime<Utc>,
|
||||
pub exit_time: DateTime<Utc>,
|
||||
pub entry_price: f64,
|
||||
pub exit_price: f64,
|
||||
pub quantity: f64,
|
||||
pub side: Side, // Side of the opening trade
|
||||
pub pnl: f64,
|
||||
pub pnl_percent: f64,
|
||||
pub commission: f64,
|
||||
pub duration_seconds: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OpenPosition {
|
||||
symbol: String,
|
||||
side: Side,
|
||||
quantity: f64,
|
||||
entry_price: f64,
|
||||
entry_time: DateTime<Utc>,
|
||||
commission: f64,
|
||||
}
|
||||
|
||||
/// Tracks fills and matches them into completed trades
|
||||
pub struct TradeTracker {
|
||||
open_positions: HashMap<String, VecDeque<OpenPosition>>,
|
||||
completed_trades: Vec<CompletedTrade>,
|
||||
trade_counter: u64,
|
||||
}
|
||||
|
||||
impl TradeTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
open_positions: HashMap::new(),
|
||||
completed_trades: Vec::new(),
|
||||
trade_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_fill(&mut self, symbol: &str, side: Side, fill: &Fill) {
|
||||
let positions = self.open_positions.entry(symbol.to_string()).or_insert_with(VecDeque::new);
|
||||
|
||||
// Check if this fill closes existing positions
|
||||
let mut remaining_quantity = fill.quantity;
|
||||
let mut fills_to_remove = Vec::new();
|
||||
|
||||
for (idx, open_pos) in positions.iter_mut().enumerate() {
|
||||
// Only match against opposite side positions
|
||||
if open_pos.side == side {
|
||||
continue;
|
||||
}
|
||||
|
||||
if remaining_quantity <= 0.0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let matched_quantity = remaining_quantity.min(open_pos.quantity);
|
||||
|
||||
// Calculate PnL
|
||||
let (pnl, pnl_percent) = Self::calculate_pnl(
|
||||
&open_pos,
|
||||
fill.price,
|
||||
matched_quantity,
|
||||
fill.commission,
|
||||
);
|
||||
|
||||
// Create completed trade
|
||||
self.trade_counter += 1;
|
||||
let completed_trade = CompletedTrade {
|
||||
id: format!("trade-{}", self.trade_counter),
|
||||
symbol: symbol.to_string(),
|
||||
entry_time: open_pos.entry_time,
|
||||
exit_time: fill.timestamp,
|
||||
entry_price: open_pos.entry_price,
|
||||
exit_price: fill.price,
|
||||
quantity: matched_quantity,
|
||||
side: open_pos.side.clone(),
|
||||
pnl,
|
||||
pnl_percent,
|
||||
commission: open_pos.commission + (fill.commission * matched_quantity / fill.quantity),
|
||||
duration_seconds: (fill.timestamp - open_pos.entry_time).num_seconds(),
|
||||
};
|
||||
|
||||
self.completed_trades.push(completed_trade);
|
||||
|
||||
// Update open position
|
||||
open_pos.quantity -= matched_quantity;
|
||||
remaining_quantity -= matched_quantity;
|
||||
|
||||
if open_pos.quantity <= 0.0 {
|
||||
fills_to_remove.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove fully closed positions
|
||||
for idx in fills_to_remove.iter().rev() {
|
||||
positions.remove(*idx);
|
||||
}
|
||||
|
||||
// If there's remaining quantity, it opens a new position
|
||||
if remaining_quantity > 0.0 {
|
||||
let new_position = OpenPosition {
|
||||
symbol: symbol.to_string(),
|
||||
side,
|
||||
quantity: remaining_quantity,
|
||||
entry_price: fill.price,
|
||||
entry_time: fill.timestamp,
|
||||
commission: fill.commission * remaining_quantity / fill.quantity,
|
||||
};
|
||||
positions.push_back(new_position);
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_pnl(open_pos: &OpenPosition, exit_price: f64, quantity: f64, exit_commission: f64) -> (f64, f64) {
|
||||
let entry_value = open_pos.entry_price * quantity;
|
||||
let exit_value = exit_price * quantity;
|
||||
|
||||
let gross_pnl = match open_pos.side {
|
||||
Side::Buy => exit_value - entry_value,
|
||||
Side::Sell => entry_value - exit_value,
|
||||
};
|
||||
|
||||
let commission = open_pos.commission * (quantity / open_pos.quantity) + exit_commission * (quantity / open_pos.quantity);
|
||||
let net_pnl = gross_pnl - commission;
|
||||
let pnl_percent = (net_pnl / entry_value) * 100.0;
|
||||
|
||||
(net_pnl, pnl_percent)
|
||||
}
|
||||
|
||||
pub fn get_completed_trades(&self) -> &[CompletedTrade] {
|
||||
&self.completed_trades
|
||||
}
|
||||
|
||||
pub fn get_open_positions(&self) -> HashMap<String, Vec<(Side, f64, f64)>> {
|
||||
let mut result = HashMap::new();
|
||||
|
||||
for (symbol, positions) in &self.open_positions {
|
||||
let pos_info: Vec<(Side, f64, f64)> = positions
|
||||
.iter()
|
||||
.map(|p| (p.side.clone(), p.quantity, p.entry_price))
|
||||
.collect();
|
||||
|
||||
if !pos_info.is_empty() {
|
||||
result.insert(symbol.clone(), pos_info);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn get_net_position(&self, symbol: &str) -> f64 {
|
||||
let positions = match self.open_positions.get(symbol) {
|
||||
Some(pos) => pos,
|
||||
None => return 0.0,
|
||||
};
|
||||
|
||||
let mut net = 0.0;
|
||||
for pos in positions {
|
||||
match pos.side {
|
||||
Side::Buy => net += pos.quantity,
|
||||
Side::Sell => net -= pos.quantity,
|
||||
}
|
||||
}
|
||||
|
||||
net
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simple_round_trip() {
|
||||
let mut tracker = TradeTracker::new();
|
||||
|
||||
// Buy 100 shares at $50
|
||||
let buy_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 50.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
|
||||
|
||||
// Sell 100 shares at $55
|
||||
let sell_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 55.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill);
|
||||
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 1);
|
||||
|
||||
let trade = &trades[0];
|
||||
assert_eq!(trade.symbol, "AAPL");
|
||||
assert_eq!(trade.quantity, 100.0);
|
||||
assert_eq!(trade.entry_price, 50.0);
|
||||
assert_eq!(trade.exit_price, 55.0);
|
||||
assert_eq!(trade.pnl, 498.0); // (55-50)*100 - 2 commission
|
||||
assert_eq!(trade.side, Side::Buy);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_fills() {
|
||||
let mut tracker = TradeTracker::new();
|
||||
|
||||
// Buy 100 shares at $50
|
||||
let buy_fill = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 50.0,
|
||||
quantity: 100.0,
|
||||
commission: 1.0,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
|
||||
|
||||
// Sell 60 shares at $55
|
||||
let sell_fill1 = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 55.0,
|
||||
quantity: 60.0,
|
||||
commission: 0.6,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill1);
|
||||
|
||||
// Check we have one completed trade and remaining position
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 1);
|
||||
assert_eq!(trades[0].quantity, 60.0);
|
||||
|
||||
assert_eq!(tracker.get_net_position("AAPL"), 40.0);
|
||||
|
||||
// Sell remaining 40 shares at $52
|
||||
let sell_fill2 = Fill {
|
||||
timestamp: Utc::now(),
|
||||
price: 52.0,
|
||||
quantity: 40.0,
|
||||
commission: 0.4,
|
||||
};
|
||||
tracker.process_fill("AAPL", Side::Sell, &sell_fill2);
|
||||
|
||||
// Now we should have 2 completed trades and no position
|
||||
let trades = tracker.get_completed_trades();
|
||||
assert_eq!(trades.len(), 2);
|
||||
assert_eq!(tracker.get_net_position("AAPL"), 0.0);
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ pub mod api;
|
|||
pub mod analytics;
|
||||
pub mod indicators;
|
||||
pub mod backtest;
|
||||
pub mod strategies;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use positions::{Position, PositionUpdate, TradeRecord, ClosedTrade};
|
||||
|
|
|
|||
210
apps/stock/core/src/strategies/mean_reversion.rs
Normal file
210
apps/stock/core/src/strategies/mean_reversion.rs
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Mean Reversion Strategy
|
||||
///
|
||||
/// This strategy identifies when a security's price deviates significantly from its
|
||||
/// moving average and trades on the assumption that it will revert back to the mean.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - BUY when price falls below (MA - threshold * std_dev)
|
||||
/// - SELL when price rises above (MA + threshold * std_dev)
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit long when price reaches MA
|
||||
/// - Exit short when price reaches MA
|
||||
pub struct MeanReversionStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
exit_threshold: f64, // Typically 0 (at the mean)
|
||||
position_size: f64,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
positions: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl MeanReversionStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
exit_threshold: 0.0, // Exit at the mean
|
||||
position_size,
|
||||
price_history: HashMap::new(),
|
||||
positions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_mean(prices: &[f64]) -> f64 {
|
||||
prices.iter().sum::<f64>() / prices.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
|
||||
let variance = prices.iter()
|
||||
.map(|p| (p - mean).powi(2))
|
||||
.sum::<f64>() / prices.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MeanReversionStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if history.len() > self.lookback_period {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(history);
|
||||
let std_dev = Self::calculate_std_dev(history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if current_position == 0.0 {
|
||||
if price < lower_band {
|
||||
// Price is oversold, buy
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, lower_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), self.position_size);
|
||||
} else if price > upper_band {
|
||||
// Price is overbought, sell short
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, upper_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), -self.position_size);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if current_position > 0.0 {
|
||||
// We're long, exit when price crosses above mean (not just touches)
|
||||
// or when we hit stop loss
|
||||
let stop_loss = lower_band - std_dev; // Stop loss below lower band
|
||||
if price >= mean * 1.01 || price <= stop_loss {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit long: price ${:.2} {} mean ${:.2}",
|
||||
price, if price >= mean * 1.01 { "crossed above" } else { "hit stop loss below" }, mean
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"exit_type": "mean_reversion",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// We're short, exit when price crosses below mean
|
||||
// or when we hit stop loss
|
||||
let stop_loss = upper_band + std_dev; // Stop loss above upper band
|
||||
if price <= mean * 0.99 || price >= stop_loss {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit short: price ${:.2} {} mean ${:.2}",
|
||||
price, if price <= mean * 0.99 { "crossed below" } else { "hit stop loss above" }, mean
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"exit_type": "mean_reversion",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
// Position tracking is handled in on_market_data for simplicity
|
||||
eprintln!("Mean reversion fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"exit_threshold": self.exit_threshold,
|
||||
"position_size": self.position_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
279
apps/stock/core/src/strategies/mean_reversion_fixed.rs
Normal file
279
apps/stock/core/src/strategies/mean_reversion_fixed.rs
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
use std::collections::HashMap;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Fixed Mean Reversion Strategy that properly tracks positions
|
||||
///
|
||||
/// This version doesn't maintain its own position tracking but relies
|
||||
/// on the position information passed through on_fill callbacks
|
||||
pub struct MeanReversionFixedStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
exit_threshold: f64, // Exit when price moves back this fraction toward mean
|
||||
position_size: f64,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
current_positions: HashMap<String, f64>, // Track actual positions from fills
|
||||
entry_prices: HashMap<String, f64>, // Track entry prices for exit decisions
|
||||
}
|
||||
|
||||
impl MeanReversionFixedStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
exit_threshold: 0.3, // Exit when price moves 30% back to mean
|
||||
position_size,
|
||||
price_history: HashMap::new(),
|
||||
current_positions: HashMap::new(),
|
||||
entry_prices: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_mean(prices: &[f64]) -> f64 {
|
||||
prices.iter().sum::<f64>() / prices.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
|
||||
let variance = prices.iter()
|
||||
.map(|p| (p - mean).powi(2))
|
||||
.sum::<f64>() / prices.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MeanReversionFixedStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if history.len() > self.lookback_period {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(history);
|
||||
let std_dev = Self::calculate_std_dev(history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
// Get actual position from our tracking
|
||||
let current_position = self.current_positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Entry signals - only when flat
|
||||
if current_position.abs() < 0.001 {
|
||||
if price < lower_band {
|
||||
// Price is oversold, buy
|
||||
eprintln!("Mean reversion: {} oversold at ${:.2}, buying (lower band: ${:.2}, mean: ${:.2})",
|
||||
symbol, price, lower_band, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, lower_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
} else if price > upper_band {
|
||||
// Price is overbought, sell short
|
||||
eprintln!("Mean reversion: {} overbought at ${:.2}, selling short (upper band: ${:.2}, mean: ${:.2})",
|
||||
symbol, price, upper_band, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
|
||||
price, upper_band, mean, std_dev
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"upper_band": upper_band,
|
||||
"lower_band": lower_band,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
// Exit signals - only when we have a position
|
||||
else if current_position > 0.0 {
|
||||
// We're long - check exit conditions
|
||||
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
|
||||
let target_price = entry_price + (mean - entry_price) * self.exit_threshold;
|
||||
let stop_loss = lower_band - std_dev; // Stop loss below lower band
|
||||
|
||||
if price >= target_price {
|
||||
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing long",
|
||||
symbol, target_price, entry_price, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit long: price ${:.2} reached target ${:.2} (entry: ${:.2})",
|
||||
price, target_price, entry_price
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"entry_price": entry_price,
|
||||
"target_price": target_price,
|
||||
"exit_type": "target",
|
||||
})),
|
||||
});
|
||||
} else if price <= stop_loss {
|
||||
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing long",
|
||||
symbol, stop_loss);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Mean reversion stop loss: price ${:.2} <= stop ${:.2}",
|
||||
price, stop_loss
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"stop_loss": stop_loss,
|
||||
"price": price,
|
||||
"exit_type": "stop_loss",
|
||||
})),
|
||||
});
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// We're short - check exit conditions
|
||||
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
|
||||
let target_price = entry_price - (entry_price - mean) * self.exit_threshold;
|
||||
let stop_loss = upper_band + std_dev; // Stop loss above upper band
|
||||
|
||||
if price <= target_price {
|
||||
eprintln!("Mean reversion: {} reached target ${:.2} (entry: ${:.2}, mean: ${:.2}), closing short",
|
||||
symbol, target_price, entry_price, mean);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion exit short: price ${:.2} reached target ${:.2} (entry: ${:.2})",
|
||||
price, target_price, entry_price
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"mean": mean,
|
||||
"price": price,
|
||||
"entry_price": entry_price,
|
||||
"target_price": target_price,
|
||||
"exit_type": "target",
|
||||
})),
|
||||
});
|
||||
} else if price >= stop_loss {
|
||||
eprintln!("Mean reversion: {} hit stop loss ${:.2}, closing short",
|
||||
symbol, stop_loss);
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Mean reversion stop loss: price ${:.2} >= stop ${:.2}",
|
||||
price, stop_loss
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"stop_loss": stop_loss,
|
||||
"price": price,
|
||||
"exit_type": "stop_loss",
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
// Update our position tracking based on actual fills
|
||||
let current = self.current_positions.get(symbol).copied().unwrap_or(0.0);
|
||||
let new_position = if side.contains("Buy") {
|
||||
current + quantity
|
||||
} else {
|
||||
current - quantity
|
||||
};
|
||||
|
||||
eprintln!("Mean reversion fill: {} {} @ {} - {}, position: {} -> {}",
|
||||
quantity, symbol, price, side, current, new_position);
|
||||
|
||||
// Update position
|
||||
if new_position.abs() < 0.001 {
|
||||
// Position closed
|
||||
self.current_positions.remove(symbol);
|
||||
self.entry_prices.remove(symbol);
|
||||
} else {
|
||||
self.current_positions.insert(symbol.to_string(), new_position);
|
||||
// Track entry price for new positions
|
||||
if current.abs() < 0.001 {
|
||||
self.entry_prices.insert(symbol.to_string(), price);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"exit_threshold": self.exit_threshold,
|
||||
"position_size": self.position_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
9
apps/stock/core/src/strategies/mod.rs
Normal file
9
apps/stock/core/src/strategies/mod.rs
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
pub mod mean_reversion;
|
||||
pub mod mean_reversion_fixed;
|
||||
pub mod momentum;
|
||||
pub mod pairs_trading;
|
||||
|
||||
pub use mean_reversion::MeanReversionStrategy;
|
||||
pub use mean_reversion_fixed::MeanReversionFixedStrategy;
|
||||
pub use momentum::MomentumStrategy;
|
||||
pub use pairs_trading::PairsTradingStrategy;
|
||||
228
apps/stock/core/src/strategies/momentum.rs
Normal file
228
apps/stock/core/src/strategies/momentum.rs
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Momentum Strategy
|
||||
///
|
||||
/// This strategy trades based on momentum indicators like rate of change (ROC) and
|
||||
/// relative strength. It aims to capture trends by buying securities showing
|
||||
/// upward momentum and selling those showing downward momentum.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - BUY when momentum crosses above threshold and accelerating
|
||||
/// - SELL when momentum crosses below -threshold and decelerating
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit long when momentum turns negative
|
||||
/// - Exit short when momentum turns positive
|
||||
pub struct MomentumStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
momentum_threshold: f64,
|
||||
position_size: f64,
|
||||
use_acceleration: bool,
|
||||
|
||||
// State
|
||||
price_history: HashMap<String, Vec<f64>>,
|
||||
momentum_history: HashMap<String, Vec<f64>>,
|
||||
positions: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl MomentumStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
lookback_period: usize,
|
||||
momentum_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
lookback_period,
|
||||
momentum_threshold,
|
||||
position_size,
|
||||
use_acceleration: true,
|
||||
price_history: HashMap::new(),
|
||||
momentum_history: HashMap::new(),
|
||||
positions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_momentum(prices: &[f64], lookback_period: usize) -> f64 {
|
||||
if prices.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let current = prices.last().unwrap();
|
||||
let past = prices[prices.len() - lookback_period.min(prices.len())];
|
||||
|
||||
((current - past) / past) * 100.0
|
||||
}
|
||||
|
||||
fn calculate_acceleration(momentum_values: &[f64]) -> f64 {
|
||||
if momentum_values.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let current = momentum_values.last().unwrap();
|
||||
let previous = momentum_values[momentum_values.len() - 2];
|
||||
|
||||
current - previous
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for MomentumStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update price history
|
||||
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
history.push(price);
|
||||
|
||||
// Keep reasonable history
|
||||
if history.len() > self.lookback_period * 2 {
|
||||
history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if history.len() >= self.lookback_period {
|
||||
// Calculate momentum
|
||||
let momentum = Self::calculate_momentum(history, self.lookback_period);
|
||||
|
||||
// Update momentum history
|
||||
let mom_history = self.momentum_history.entry(symbol.clone()).or_insert_with(Vec::new);
|
||||
mom_history.push(momentum);
|
||||
|
||||
if mom_history.len() > 5 {
|
||||
mom_history.remove(0);
|
||||
}
|
||||
|
||||
// Calculate acceleration if enabled
|
||||
let acceleration = if self.use_acceleration && mom_history.len() >= 2 {
|
||||
Self::calculate_acceleration(mom_history)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if current_position == 0.0 {
|
||||
// Long entry: strong positive momentum and accelerating
|
||||
if momentum > self.momentum_threshold &&
|
||||
(!self.use_acceleration || acceleration > 0.0) {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: momentum / 100.0, // Normalize strength
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Momentum buy: momentum {:.2}% > threshold {:.2}%, accel: {:.2}",
|
||||
momentum, self.momentum_threshold, acceleration
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"acceleration": acceleration,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), self.position_size);
|
||||
}
|
||||
// Short entry: strong negative momentum and decelerating
|
||||
else if momentum < -self.momentum_threshold &&
|
||||
(!self.use_acceleration || acceleration < 0.0) {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: momentum.abs() / 100.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Momentum sell: momentum {:.2}% < threshold -{:.2}%, accel: {:.2}",
|
||||
momentum, self.momentum_threshold, acceleration
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"acceleration": acceleration,
|
||||
"price": price,
|
||||
})),
|
||||
});
|
||||
self.positions.insert(symbol.clone(), -self.position_size);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if current_position > 0.0 {
|
||||
// Exit long when momentum turns negative
|
||||
if momentum < 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position),
|
||||
reason: Some(format!(
|
||||
"Momentum exit long: momentum turned negative {:.2}%",
|
||||
momentum
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"price": price,
|
||||
"exit_type": "momentum_reversal",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
} else if current_position < 0.0 {
|
||||
// Exit short when momentum turns positive
|
||||
if momentum > 0.0 {
|
||||
signals.push(Signal {
|
||||
symbol: symbol.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(current_position.abs()),
|
||||
reason: Some(format!(
|
||||
"Momentum exit short: momentum turned positive {:.2}%",
|
||||
momentum
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"momentum": momentum,
|
||||
"price": price,
|
||||
"exit_type": "momentum_reversal",
|
||||
})),
|
||||
});
|
||||
self.positions.remove(symbol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
eprintln!("Momentum fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"lookback_period": self.lookback_period,
|
||||
"momentum_threshold": self.momentum_threshold,
|
||||
"position_size": self.position_size,
|
||||
"use_acceleration": self.use_acceleration,
|
||||
})
|
||||
}
|
||||
}
|
||||
295
apps/stock/core/src/strategies/pairs_trading.rs
Normal file
295
apps/stock/core/src/strategies/pairs_trading.rs
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
MarketUpdate, MarketDataType,
|
||||
backtest::{Strategy, Signal, SignalType},
|
||||
};
|
||||
|
||||
/// Pairs Trading Strategy
|
||||
///
|
||||
/// This strategy trades the spread between two correlated securities. When the spread
|
||||
/// deviates from its historical mean, we trade expecting it to revert.
|
||||
///
|
||||
/// Entry Signals:
|
||||
/// - Long pair A, Short pair B when spread < (mean - threshold * std)
|
||||
/// - Short pair A, Long pair B when spread > (mean + threshold * std)
|
||||
///
|
||||
/// Exit Signals:
|
||||
/// - Exit when spread returns to mean
|
||||
pub struct PairsTradingStrategy {
|
||||
name: String,
|
||||
id: String,
|
||||
pair_a: String,
|
||||
pair_b: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64, // Number of standard deviations
|
||||
position_size: f64,
|
||||
hedge_ratio: f64, // How many shares of B per share of A
|
||||
|
||||
// State
|
||||
price_history_a: Vec<f64>,
|
||||
price_history_b: Vec<f64>,
|
||||
spread_history: Vec<f64>,
|
||||
positions: HashMap<String, f64>,
|
||||
last_prices: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl PairsTradingStrategy {
|
||||
pub fn new(
|
||||
name: String,
|
||||
id: String,
|
||||
pair_a: String,
|
||||
pair_b: String,
|
||||
lookback_period: usize,
|
||||
entry_threshold: f64,
|
||||
position_size: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
id,
|
||||
pair_a,
|
||||
pair_b,
|
||||
lookback_period,
|
||||
entry_threshold,
|
||||
position_size,
|
||||
hedge_ratio: 1.0, // Default 1:1, could be calculated dynamically
|
||||
price_history_a: Vec::new(),
|
||||
price_history_b: Vec::new(),
|
||||
spread_history: Vec::new(),
|
||||
positions: HashMap::new(),
|
||||
last_prices: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_spread(&self, price_a: f64, price_b: f64) -> f64 {
|
||||
price_a - self.hedge_ratio * price_b
|
||||
}
|
||||
|
||||
fn calculate_mean(values: &[f64]) -> f64 {
|
||||
values.iter().sum::<f64>() / values.len() as f64
|
||||
}
|
||||
|
||||
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
|
||||
let variance = values.iter()
|
||||
.map(|v| (v - mean).powi(2))
|
||||
.sum::<f64>() / values.len() as f64;
|
||||
variance.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strategy for PairsTradingStrategy {
|
||||
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
|
||||
let mut signals = Vec::new();
|
||||
|
||||
// Only process bar data
|
||||
if let MarketDataType::Bar(bar) = &data.data {
|
||||
let symbol = &data.symbol;
|
||||
let price = bar.close;
|
||||
|
||||
// Update last prices
|
||||
self.last_prices.insert(symbol.clone(), price);
|
||||
|
||||
// Update price histories
|
||||
if symbol == &self.pair_a {
|
||||
self.price_history_a.push(price);
|
||||
if self.price_history_a.len() > self.lookback_period {
|
||||
self.price_history_a.remove(0);
|
||||
}
|
||||
} else if symbol == &self.pair_b {
|
||||
self.price_history_b.push(price);
|
||||
if self.price_history_b.len() > self.lookback_period {
|
||||
self.price_history_b.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Only generate signals when we have prices for both pairs
|
||||
if let (Some(&price_a), Some(&price_b)) =
|
||||
(self.last_prices.get(&self.pair_a), self.last_prices.get(&self.pair_b)) {
|
||||
|
||||
// Calculate current spread
|
||||
let spread = self.calculate_spread(price_a, price_b);
|
||||
|
||||
// Update spread history
|
||||
self.spread_history.push(spread);
|
||||
if self.spread_history.len() > self.lookback_period {
|
||||
self.spread_history.remove(0);
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if self.spread_history.len() >= self.lookback_period {
|
||||
// Calculate statistics
|
||||
let mean = Self::calculate_mean(&self.spread_history);
|
||||
let std_dev = Self::calculate_std_dev(&self.spread_history, mean);
|
||||
|
||||
// Calculate bands
|
||||
let upper_band = mean + self.entry_threshold * std_dev;
|
||||
let lower_band = mean - self.entry_threshold * std_dev;
|
||||
|
||||
let position_a = self.positions.get(&self.pair_a).copied().unwrap_or(0.0);
|
||||
let position_b = self.positions.get(&self.pair_b).copied().unwrap_or(0.0);
|
||||
|
||||
// Check for entry signals
|
||||
if position_a == 0.0 && position_b == 0.0 {
|
||||
if spread < lower_band {
|
||||
// Spread too low: Buy A, Sell B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Pairs trade: spread ${:.2} < lower band ${:.2}",
|
||||
spread, lower_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "A",
|
||||
})),
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size * self.hedge_ratio),
|
||||
reason: Some(format!(
|
||||
"Pairs trade hedge: spread ${:.2} < lower band ${:.2}",
|
||||
spread, lower_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "B",
|
||||
})),
|
||||
});
|
||||
|
||||
self.positions.insert(self.pair_a.clone(), self.position_size);
|
||||
self.positions.insert(self.pair_b.clone(), -self.position_size * self.hedge_ratio);
|
||||
} else if spread > upper_band {
|
||||
// Spread too high: Sell A, Buy B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size),
|
||||
reason: Some(format!(
|
||||
"Pairs trade: spread ${:.2} > upper band ${:.2}",
|
||||
spread, upper_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "A",
|
||||
})),
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(self.position_size * self.hedge_ratio),
|
||||
reason: Some(format!(
|
||||
"Pairs trade hedge: spread ${:.2} > upper band ${:.2}",
|
||||
spread, upper_band
|
||||
)),
|
||||
metadata: Some(json!({
|
||||
"spread": spread,
|
||||
"mean": mean,
|
||||
"std_dev": std_dev,
|
||||
"pair": "B",
|
||||
})),
|
||||
});
|
||||
|
||||
self.positions.insert(self.pair_a.clone(), -self.position_size);
|
||||
self.positions.insert(self.pair_b.clone(), self.position_size * self.hedge_ratio);
|
||||
}
|
||||
}
|
||||
// Check for exit signals
|
||||
else if position_a != 0.0 && position_b != 0.0 {
|
||||
// Exit when spread returns to mean
|
||||
let spread_distance = (spread - mean).abs();
|
||||
let exit_threshold = std_dev * 0.1; // Exit near mean
|
||||
|
||||
if spread_distance < exit_threshold {
|
||||
// Close positions
|
||||
if position_a > 0.0 {
|
||||
// We're long A, short B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_a),
|
||||
reason: Some(format!(
|
||||
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
|
||||
spread, mean
|
||||
)),
|
||||
metadata: None,
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_b.abs()),
|
||||
reason: Some("Pairs trade exit: closing hedge".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
} else {
|
||||
// We're short A, long B
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_a.clone(),
|
||||
signal_type: SignalType::Buy,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_a.abs()),
|
||||
reason: Some(format!(
|
||||
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
|
||||
spread, mean
|
||||
)),
|
||||
metadata: None,
|
||||
});
|
||||
|
||||
signals.push(Signal {
|
||||
symbol: self.pair_b.clone(),
|
||||
signal_type: SignalType::Sell,
|
||||
strength: 1.0,
|
||||
quantity: Some(position_b),
|
||||
reason: Some("Pairs trade exit: closing hedge".to_string()),
|
||||
metadata: None,
|
||||
});
|
||||
}
|
||||
|
||||
self.positions.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signals
|
||||
}
|
||||
|
||||
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
|
||||
eprintln!("Pairs trading fill: {} {} @ {} - {}", quantity, symbol, price, side);
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn get_parameters(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"pair_a": self.pair_a,
|
||||
"pair_b": self.pair_b,
|
||||
"lookback_period": self.lookback_period,
|
||||
"entry_threshold": self.entry_threshold,
|
||||
"position_size": self.position_size,
|
||||
"hedge_ratio": self.hedge_ratio,
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue