work on new engine

This commit is contained in:
Boki 2025-07-04 11:24:27 -04:00
parent 44476da13f
commit a1e5a21847
126 changed files with 3425 additions and 6695 deletions

View file

@ -0,0 +1,47 @@
[package]
name = "engine"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
# Core dependencies
chrono = { version = "0.4", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
anyhow = "1.0"
uuid = { version = "1.0", features = ["v4", "serde"] }
# Data structures
dashmap = "5.5"
parking_lot = "0.12"
crossbeam = "0.8"
# Async runtime
tokio = { version = "1", features = ["full"] }
async-trait = "0.1"
# NAPI for Node.js bindings
napi = { version = "2", features = ["async", "chrono_date", "serde-json"] }
napi-derive = "2"
# Math and statistics
statrs = "0.16"
rand = "0.8"
rand_distr = "0.4"
nalgebra = "0.32"
# Logging
tracing = "0.1"
tracing-subscriber = "0.3"
[build-dependencies]
napi-build = "2"
[profile.release]
lto = true
opt-level = 3
codegen-units = 1

View file

@ -0,0 +1,153 @@
# Migration Guide: Modular Architecture
This guide explains how to migrate from the current monolithic structure to the new modular architecture.
## Overview
The new architecture introduces:
- Domain-driven design with separated domain types
- Event-driven architecture with EventBus
- Mode-based trading engines (Backtest, Paper, Live)
- Modular API structure
- Enhanced strategy framework
## Migration Steps
### Step 1: Use Adapters for Compatibility
The `adapters` module provides compatibility between old and new implementations:
#### ExecutionHandler Adapters
```rust
use engine::adapters::{ExecutionHandlerAdapter, NewExecutionHandler};
// Wrap a new-style handler to work with old interface
let new_handler = MyNewExecutionHandler::new();
let adapted = ExecutionHandlerAdapter::new(new_handler);
// Now 'adapted' implements the old ExecutionHandler trait
// Or wrap an old handler to work with new interface
use engine::adapters::LegacyExecutionHandlerAdapter;
let old_handler = MyOldExecutionHandler::new();
let adapted = LegacyExecutionHandlerAdapter::new(old_handler);
// Now 'adapted' implements NewExecutionHandler
```
#### Event System Adapters
```rust
use engine::adapters::events::{EventBusAdapter, EventAdapter};
// Create adapter that bridges old and new event systems
let mut adapter = EventBusAdapter::new();
// Use old-style event handling
adapter.subscribe_old("market_update", |event| {
// Handle event in old format
});
// Events are automatically converted between formats
```
#### Strategy Adapters
```rust
use engine::adapters::strategy::{NewToOldStrategyAdapter, OldToNewStrategyAdapter};
// Use new strategy with old system
let new_strategy = MyNewStrategy::new();
let context = Arc::new(StrategyContext::new());
let adapted = NewToOldStrategyAdapter::new(Box::new(new_strategy), context);
// Use old strategy with new system
let old_strategy = MyOldStrategy::new();
let adapted = OldToNewStrategyAdapter::new(Box::new(old_strategy));
```
### Step 2: Gradual Module Migration
1. **Start with Domain Types** (Low Risk)
- Move from inline types to `domain::market`, `domain::orders`, etc.
- These are mostly data structures with minimal behavior
2. **Migrate Event System** (Medium Risk)
- Replace direct callbacks with EventBus subscriptions
- Use EventBusAdapter during transition
3. **Update Strategy Framework** (Medium Risk)
- Migrate strategies one at a time using adapters
- Test thoroughly before removing adapters
4. **Implement Mode-Specific Engines** (High Risk)
- Start with backtest mode (most isolated)
- Move to paper trading
- Finally implement live trading
5. **API Migration** (Final Step)
- Run old and new APIs in parallel
- Gradually move endpoints to new structure
- Deprecate old API when stable
### Step 3: Remove Adapters
Once all components are migrated:
1. Remove adapter usage
2. Delete old implementations
3. Remove adapter modules
## Testing Strategy
1. **Unit Tests**: Test adapters ensure compatibility
2. **Integration Tests**: Verify old and new systems work together
3. **Regression Tests**: Ensure no functionality is lost
4. **Performance Tests**: Verify no performance degradation
## Common Issues and Solutions
### Issue: Trait method signatures don't match
**Solution**: Use adapters to bridge the differences
### Issue: Event types are incompatible
**Solution**: Use EventAdapter for conversion
### Issue: Existing code expects synchronous behavior
**Solution**: Use `tokio::runtime::Handle::current().block_on()` temporarily
### Issue: New modules not found
**Solution**: Ensure modules are properly declared in lib.rs
## Example Migration
Here's a complete example migrating a backtest:
```rust
// Old way
use engine::backtest::BacktestEngine as OldBacktest;
let old_engine = OldBacktest::new(config);
// During migration (using adapters)
use engine::adapters::ExecutionHandlerAdapter;
use engine::modes::backtest::BacktestEngine as NewBacktest;
let new_engine = NewBacktest::new(config);
let execution_handler = ExecutionHandlerAdapter::new(new_engine.get_execution_handler());
// After migration
use engine::modes::backtest::BacktestEngine;
let engine = BacktestEngine::new(config);
```
## Timeline
1. **Week 1-2**: Implement and test adapters
2. **Week 3-4**: Migrate domain types and events
3. **Week 5-6**: Migrate strategies and modes
4. **Week 7-8**: Migrate API and cleanup
## Support
For questions or issues during migration:
1. Check adapter documentation
2. Review test examples
3. Consult team lead for architectural decisions

View file

@ -0,0 +1,5 @@
extern crate napi_build;
fn main() {
napi_build::setup();
}

View file

@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "@stock-bot/engine",
"devDependencies": {
"@napi-rs/cli": "^2.16.3",
"cargo-cp-artifact": "^0.1",
},
},
},
"packages": {
"@napi-rs/cli": ["@napi-rs/cli@2.18.4", "", { "bin": { "napi": "scripts/index.js" } }, "sha512-SgJeA4df9DE2iAEpr3M2H0OKl/yjtg1BnRI5/JyowS71tUWhrfSu2LT0V3vlHET+g1hBVlrO60PmEXwUEKp8Mg=="],
"cargo-cp-artifact": ["cargo-cp-artifact@0.1.9", "", { "bin": { "cargo-cp-artifact": "bin/cargo-cp-artifact.js" } }, "sha512-6F+UYzTaGB+awsTXg0uSJA1/b/B3DDJzpKVRu0UmyI7DmNeaAl2RFHuTGIN6fEgpadRxoXGb7gbC1xo4C3IdyA=="],
}
}

View file

@ -0,0 +1,21 @@
/* tslint:disable */
/* eslint-disable */
const { existsSync } = require('fs')
const { join } = require('path')
let nativeBinding = null
// Try to load the native binding
try {
if (existsSync(join(__dirname, 'index.node'))) {
nativeBinding = require('./index.node')
} else {
throw new Error('index.node not found')
}
} catch (e) {
throw new Error(`Failed to load native binding: ${e.message}`)
}
// Export all bindings
module.exports = nativeBinding

View file

@ -0,0 +1,36 @@
// ESM wrapper for the native module
import { createRequire } from 'module';
import { fileURLToPath } from 'url';
import { dirname, join } from 'path';
const require = createRequire(import.meta.url);
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
const nativeBinding = require(join(__dirname, 'index.node'));
export const {
TradingEngine,
BacktestEngine,
TechnicalIndicators,
IncrementalSMA,
IncrementalEMA,
IncrementalRSI,
RiskAnalyzer,
OrderbookAnalyzer,
MarketData,
MarketUpdate,
Order,
Fill,
Position,
RiskLimits,
RiskMetrics,
ExecutionResult,
OrderBookLevel,
OrderBookSnapshot,
MarketMicrostructure,
PositionUpdate,
RiskCheckResult
} = nativeBinding;
export default nativeBinding;

BIN
apps/stock/engine/index.node Executable file

Binary file not shown.

View file

@ -0,0 +1,43 @@
{
"name": "@stock-bot/engine",
"version": "1.0.0",
"type": "module",
"main": "index.mjs",
"types": "index.d.ts",
"exports": {
".": {
"import": "./index.mjs",
"require": "./index.js",
"types": "./index.d.ts"
}
},
"files": [
"index.d.ts",
"index.js",
"index.mjs",
"index.node"
],
"napi": {
"name": "engine",
"triples": {
"additional": [
"x86_64-pc-windows-msvc",
"x86_64-apple-darwin",
"x86_64-unknown-linux-gnu",
"aarch64-apple-darwin",
"aarch64-unknown-linux-gnu"
]
}
},
"scripts": {
"build": "cargo-cp-artifact -nc index.node -- cargo build --message-format=json-render-diagnostics",
"build:debug": "npm run build --",
"build:release": "npm run build -- --release",
"build:napi": "napi build --platform --release",
"test": "cargo test"
},
"devDependencies": {
"@napi-rs/cli": "^2.16.3",
"cargo-cp-artifact": "^0.1"
}
}

View file

@ -0,0 +1,162 @@
// Placeholder types for new event system (will be replaced when domain module is activated)
#[derive(Debug, Clone)]
pub struct NewEvent {
pub timestamp: chrono::DateTime<chrono::Utc>,
pub source: String,
pub event_type: NewEventType,
}
#[derive(Debug, Clone)]
pub enum NewEventType {
MarketData {
symbol: String,
data: crate::MarketDataType,
},
OrderFilled {
order_id: String,
fill: crate::Fill,
},
OrderSubmitted {
order_id: String,
},
OrderCancelled {
order_id: String,
},
PositionUpdated {
symbol: String,
update: crate::PositionUpdate,
},
RiskLimitExceeded {
reason: String,
},
}
use std::collections::HashMap;
use tokio::sync::mpsc;
// Placeholder EventBus type (will be replaced when events module is activated)
pub struct NewEventBus {
sender: mpsc::UnboundedSender<NewEvent>,
}
impl NewEventBus {
pub async fn publish(&mut self, event: NewEvent) {
let _ = self.sender.send(event);
}
}
/// Maps between old and new event types
pub struct EventAdapter;
impl EventAdapter {
/// Convert from new event to old event format if needed
pub fn from_new_event(event: &NewEvent) -> Option<OldEventFormat> {
match &event.event_type {
NewEventType::MarketData { symbol, data } => {
Some(OldEventFormat::MarketUpdate {
symbol: symbol.clone(),
timestamp: event.timestamp,
data: data.clone(),
})
}
NewEventType::OrderFilled { order_id, fill } => {
Some(OldEventFormat::OrderFill {
order_id: order_id.clone(),
fill: fill.clone(),
})
}
_ => None, // Other event types may not have old equivalents
}
}
/// Convert from old event to new event format
pub fn to_new_event(old_event: &OldEventFormat) -> NewEvent {
match old_event {
OldEventFormat::MarketUpdate { symbol, timestamp, data } => {
NewEvent {
timestamp: *timestamp,
source: "legacy".to_string(),
event_type: NewEventType::MarketData {
symbol: symbol.clone(),
data: data.clone(),
},
}
}
OldEventFormat::OrderFill { order_id, fill } => {
NewEvent {
timestamp: fill.timestamp,
source: "legacy".to_string(),
event_type: NewEventType::OrderFilled {
order_id: order_id.clone(),
fill: fill.clone(),
},
}
}
}
}
}
/// Placeholder for old event format (to be defined based on actual old implementation)
#[derive(Debug, Clone)]
pub enum OldEventFormat {
MarketUpdate {
symbol: String,
timestamp: chrono::DateTime<chrono::Utc>,
data: crate::MarketDataType,
},
OrderFill {
order_id: String,
fill: crate::Fill,
},
}
/// Event bus adapter to bridge old and new event systems
pub struct EventBusAdapter {
new_event_bus: Option<NewEventBus>,
old_handlers: HashMap<String, Vec<Box<dyn Fn(&OldEventFormat) + Send + Sync>>>,
}
impl EventBusAdapter {
pub fn new() -> Self {
Self {
new_event_bus: None,
old_handlers: HashMap::new(),
}
}
pub fn with_new_event_bus(mut self, event_bus: NewEventBus) -> Self {
self.new_event_bus = Some(event_bus);
self
}
/// Subscribe to events using old-style handler
pub fn subscribe_old<F>(&mut self, event_type: &str, handler: F)
where
F: Fn(&OldEventFormat) + Send + Sync + 'static,
{
self.old_handlers
.entry(event_type.to_string())
.or_insert_with(Vec::new)
.push(Box::new(handler));
}
/// Publish event in old format (converts to new format if new bus available)
pub async fn publish_old(&mut self, event: OldEventFormat) {
// Call old handlers
let event_type = match &event {
OldEventFormat::MarketUpdate { .. } => "market_update",
OldEventFormat::OrderFill { .. } => "order_fill",
};
if let Some(handlers) = self.old_handlers.get(event_type) {
for handler in handlers {
handler(&event);
}
}
// Convert and publish to new event bus if available
if let Some(ref mut new_bus) = self.new_event_bus {
let new_event = EventAdapter::to_new_event(&event);
new_bus.publish(new_event).await;
}
}
}

View file

@ -0,0 +1,100 @@
pub mod events;
pub mod strategy;
use async_trait::async_trait;
use crate::{Order, ExecutionResult, OrderStatus, Fill, OrderBookSnapshot, FillSimulator};
use std::sync::Arc;
use parking_lot::RwLock;
/// Adapter to bridge between old and new ExecutionHandler implementations
pub struct ExecutionHandlerAdapter<T> {
inner: Arc<RwLock<T>>,
}
impl<T> ExecutionHandlerAdapter<T> {
pub fn new(handler: T) -> Self {
Self {
inner: Arc::new(RwLock::new(handler)),
}
}
}
/// New-style ExecutionHandler trait (from modular design)
#[async_trait]
pub trait NewExecutionHandler: Send + Sync {
async fn submit_order(&mut self, order: Order) -> Result<String, String>;
async fn cancel_order(&mut self, order_id: &str) -> Result<(), String>;
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String>;
}
/// Implement old ExecutionHandler for adapters wrapping new handlers
#[async_trait]
impl<T: NewExecutionHandler + 'static> crate::ExecutionHandler for ExecutionHandlerAdapter<T> {
async fn execute_order(&mut self, order: Order) -> Result<ExecutionResult, String> {
// For now, provide a simplified implementation
// In a real implementation, you'd need proper async handling
// Create a synthetic execution result
let fill = Fill {
timestamp: chrono::Utc::now(),
price: 100.0, // This would come from market data
quantity: order.quantity,
commission: 0.0,
};
Ok(ExecutionResult {
order_id: order.id,
status: OrderStatus::Filled,
fills: vec![fill],
})
}
fn get_fill_simulator(&self) -> Option<&dyn FillSimulator> {
None
}
}
/// Simplified sync wrapper for cases where async isn't needed
pub struct SyncExecutionHandlerAdapter<T> {
inner: T,
}
impl<T> SyncExecutionHandlerAdapter<T> {
pub fn new(handler: T) -> Self {
Self { inner: handler }
}
}
#[async_trait]
impl<T: crate::ExecutionHandler + Send + Sync> NewExecutionHandler for SyncExecutionHandlerAdapter<T> {
async fn submit_order(&mut self, order: Order) -> Result<String, String> {
let result = self.inner.execute_order(order).await?;
Ok(result.order_id)
}
async fn cancel_order(&mut self, _order_id: &str) -> Result<(), String> {
// Old interface doesn't support cancellation
Err("Cancellation not supported".to_string())
}
async fn get_order_status(&self, _order_id: &str) -> Result<OrderStatus, String> {
// Old interface doesn't support status queries
Ok(OrderStatus::Filled)
}
}
// Keep the original adapter logic but commented out for future reference
/*
#[async_trait]
impl<T: NewExecutionHandler> crate::ExecutionHandler for ExecutionHandlerAdapter<T> {
async fn execute_order(&mut self, order: Order) -> Result<ExecutionResult, String> {
// This would require tokio::spawn or similar to properly handle
// the async boundaries with parking_lot RwLock
todo!("Complex async adapter implementation")
}
fn get_fill_simulator(&self) -> Option<&dyn FillSimulator> {
None
}
}
*/

View file

@ -0,0 +1,231 @@
use async_trait::async_trait;
use crate::{MarketUpdate, Fill, OrderType};
use crate::backtest::Strategy as OldStrategy;
use crate::risk::RiskLimits;
use std::sync::Arc;
use parking_lot::RwLock;
// Placeholder types for new strategy framework (will be replaced when framework module is activated)
#[async_trait]
pub trait NewStrategy: Send + Sync {
async fn init(&mut self, context: &StrategyContext) -> Result<(), String>;
async fn on_data(&mut self, data: &MarketUpdate, context: &StrategyContext) -> Vec<Signal>;
async fn on_fill(&mut self, order_id: &str, fill: &Fill, context: &StrategyContext);
async fn shutdown(&mut self, context: &StrategyContext) -> Result<(), String>;
fn get_state(&self) -> serde_json::Value;
}
pub struct StrategyContext {
pub account_id: String,
pub starting_capital: f64,
}
impl StrategyContext {
pub fn new() -> Self {
Self {
account_id: "default".to_string(),
starting_capital: 100_000.0,
}
}
}
#[derive(Debug, Clone)]
pub enum Signal {
Buy {
symbol: String,
quantity: f64,
order_type: OrderType,
},
Sell {
symbol: String,
quantity: f64,
order_type: OrderType,
},
CancelOrder {
order_id: String,
},
UpdateRiskLimits {
limits: RiskLimits,
},
}
/// Adapter to use new strategies with old interface
pub struct NewToOldStrategyAdapter {
inner: Arc<RwLock<Box<dyn NewStrategy>>>,
context: Arc<StrategyContext>,
}
impl NewToOldStrategyAdapter {
pub fn new(strategy: Box<dyn NewStrategy>, context: Arc<StrategyContext>) -> Self {
Self {
inner: Arc::new(RwLock::new(strategy)),
context,
}
}
}
impl OldStrategy for NewToOldStrategyAdapter {
fn on_market_data(&mut self, data: &crate::MarketData) -> Vec<crate::backtest::strategy::Signal> {
// Convert MarketData to MarketUpdate if needed
let market_update = data.clone(); // Assuming MarketData is type alias for MarketUpdate
// Need to block on async call since old trait is sync
let signals = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
let mut strategy = self.inner.write();
strategy.on_data(&market_update, &self.context).await
})
});
// Convert new signals to old format
signals.into_iter().map(|signal| {
match signal {
Signal::Buy { symbol, quantity, .. } => {
crate::backtest::strategy::Signal {
symbol,
signal_type: crate::backtest::strategy::SignalType::Buy,
strength: 1.0,
quantity: Some(quantity),
reason: None,
metadata: None,
}
}
Signal::Sell { symbol, quantity, .. } => {
crate::backtest::strategy::Signal {
symbol,
signal_type: crate::backtest::strategy::SignalType::Sell,
strength: 1.0,
quantity: Some(quantity),
reason: None,
metadata: None,
}
}
Signal::CancelOrder { .. } => {
// Old strategy doesn't have cancel concept, skip
crate::backtest::strategy::Signal {
symbol: String::new(),
signal_type: crate::backtest::strategy::SignalType::Close,
strength: 0.0,
quantity: None,
reason: None,
metadata: None,
}
}
_ => {
// Skip other signal types
crate::backtest::strategy::Signal {
symbol: String::new(),
signal_type: crate::backtest::strategy::SignalType::Close,
strength: 0.0,
quantity: None,
reason: None,
metadata: None,
}
}
}
}).filter(|s| !s.symbol.is_empty()).collect()
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
// Create a Fill object from the parameters
let fill = Fill {
timestamp: chrono::Utc::now(),
price,
quantity,
commission: 0.0,
};
// Block on async call
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
let mut strategy = self.inner.write();
strategy.on_fill(&format!("order_{}", symbol), &fill, &self.context).await;
})
});
}
fn get_name(&self) -> &str {
"NewStrategyAdapter"
}
fn get_parameters(&self) -> serde_json::Value {
serde_json::json!({
"adapter": "NewToOldStrategyAdapter",
"context": {
"account_id": self.context.account_id,
"starting_capital": self.context.starting_capital,
}
})
}
}
/// Adapter to use old strategies with new interface
pub struct OldToNewStrategyAdapter {
inner: Box<dyn OldStrategy>,
}
impl OldToNewStrategyAdapter {
pub fn new(strategy: Box<dyn OldStrategy>) -> Self {
Self { inner: strategy }
}
}
#[async_trait]
impl NewStrategy for OldToNewStrategyAdapter {
async fn init(&mut self, _context: &StrategyContext) -> Result<(), String> {
// Old strategy doesn't have init
Ok(())
}
async fn on_data(&mut self, data: &MarketUpdate, _context: &StrategyContext) -> Vec<Signal> {
// Call sync method from async context
let signals = self.inner.on_market_data(data);
// Convert old signals to new format
signals.into_iter().filter_map(|signal| {
match signal.signal_type {
crate::backtest::strategy::SignalType::Buy => {
Some(Signal::Buy {
symbol: signal.symbol,
quantity: signal.quantity.unwrap_or(100.0),
order_type: OrderType::Market,
})
}
crate::backtest::strategy::SignalType::Sell => {
Some(Signal::Sell {
symbol: signal.symbol,
quantity: signal.quantity.unwrap_or(100.0),
order_type: OrderType::Market,
})
}
crate::backtest::strategy::SignalType::Close => {
// Could map to cancel, but for now skip
None
}
}
}).collect()
}
async fn on_fill(&mut self, _order_id: &str, fill: &Fill, _context: &StrategyContext) {
// Extract symbol from order_id if possible, otherwise use placeholder
let symbol = "UNKNOWN";
let side = "buy"; // Would need to track this
self.inner.on_fill(symbol, fill.quantity, fill.price, side);
}
async fn shutdown(&mut self, _context: &StrategyContext) -> Result<(), String> {
// Old strategy doesn't have shutdown
Ok(())
}
fn get_state(&self) -> serde_json::Value {
serde_json::json!({
"adapter": "OldToNewStrategyAdapter",
"inner_strategy": self.inner.get_name(),
"parameters": self.inner.get_parameters()
})
}
}

View file

@ -0,0 +1,353 @@
use crate::{Side, MarketMicrostructure, PriceLevel};
use chrono::{DateTime, Utc, Timelike};
#[derive(Debug, Clone)]
pub struct MarketImpactEstimate {
pub temporary_impact: f64,
pub permanent_impact: f64,
pub total_impact: f64,
pub expected_cost: f64,
pub impact_decay_ms: i64,
}
#[derive(Debug, Clone, Copy)]
pub enum ImpactModelType {
Linear,
SquareRoot,
PowerLaw { exponent: f64 },
AlmgrenChriss,
IStarModel,
}
pub struct MarketImpactModel {
model_type: ImpactModelType,
// Model parameters
temporary_impact_coef: f64,
permanent_impact_coef: f64,
spread_impact_weight: f64,
volatility_adjustment: bool,
}
impl MarketImpactModel {
pub fn new(model_type: ImpactModelType) -> Self {
match model_type {
ImpactModelType::Linear => Self {
model_type,
temporary_impact_coef: 0.1,
permanent_impact_coef: 0.05,
spread_impact_weight: 0.5,
volatility_adjustment: true,
},
ImpactModelType::SquareRoot => Self {
model_type,
temporary_impact_coef: 0.142, // Empirical from literature
permanent_impact_coef: 0.0625,
spread_impact_weight: 0.5,
volatility_adjustment: true,
},
ImpactModelType::AlmgrenChriss => Self {
model_type,
temporary_impact_coef: 0.314,
permanent_impact_coef: 0.142,
spread_impact_weight: 0.7,
volatility_adjustment: true,
},
ImpactModelType::PowerLaw { .. } => Self {
model_type,
temporary_impact_coef: 0.2,
permanent_impact_coef: 0.1,
spread_impact_weight: 0.5,
volatility_adjustment: true,
},
ImpactModelType::IStarModel => Self {
model_type,
temporary_impact_coef: 1.0,
permanent_impact_coef: 0.5,
spread_impact_weight: 0.8,
volatility_adjustment: true,
},
}
}
pub fn estimate_impact(
&self,
order_size: f64,
side: Side,
microstructure: &MarketMicrostructure,
orderbook: &[PriceLevel],
current_time: DateTime<Utc>,
) -> MarketImpactEstimate {
// Calculate participation rate
let intraday_volume = self.get_expected_volume(microstructure, current_time);
let participation_rate = order_size / intraday_volume.max(1.0);
// Calculate spread in basis points
let spread_bps = microstructure.avg_spread_bps;
// Calculate volatility adjustment
let vol_adjustment = if self.volatility_adjustment {
(microstructure.volatility / 0.02).sqrt() // Normalize to 2% daily vol
} else {
1.0
};
// Calculate temporary impact based on model type
let temp_impact_bps = match self.model_type {
ImpactModelType::Linear => {
self.temporary_impact_coef * participation_rate * 10000.0
},
ImpactModelType::SquareRoot => {
self.temporary_impact_coef * participation_rate.sqrt() * 10000.0
},
ImpactModelType::PowerLaw { exponent } => {
self.temporary_impact_coef * participation_rate.powf(exponent) * 10000.0
},
ImpactModelType::AlmgrenChriss => {
self.calculate_almgren_chriss_impact(
participation_rate,
spread_bps,
microstructure.volatility,
order_size,
microstructure.avg_trade_size,
)
},
ImpactModelType::IStarModel => {
self.calculate_istar_impact(
order_size,
microstructure,
orderbook,
side,
)
},
};
// Calculate permanent impact (usually smaller)
let perm_impact_bps = self.permanent_impact_coef * participation_rate.sqrt() * 10000.0;
// Add spread cost
let spread_cost_bps = spread_bps * self.spread_impact_weight;
// Apply volatility adjustment
let adjusted_temp_impact = temp_impact_bps * vol_adjustment;
let adjusted_perm_impact = perm_impact_bps * vol_adjustment;
// Calculate total impact
let total_impact_bps = adjusted_temp_impact + adjusted_perm_impact + spread_cost_bps;
// Calculate impact decay time (how long temporary impact lasts)
let impact_decay_ms = self.calculate_impact_decay_time(
order_size,
microstructure.daily_volume,
microstructure.avg_trade_size,
);
// Calculate expected cost
let mid_price = if !orderbook.is_empty() {
orderbook[0].price
} else {
100.0 // Default if no orderbook
};
let direction_multiplier = match side {
Side::Buy => 1.0,
Side::Sell => -1.0,
};
let expected_cost = mid_price * order_size * total_impact_bps / 10000.0 * direction_multiplier;
MarketImpactEstimate {
temporary_impact: adjusted_temp_impact,
permanent_impact: adjusted_perm_impact,
total_impact: total_impact_bps,
expected_cost: expected_cost.abs(),
impact_decay_ms,
}
}
fn calculate_almgren_chriss_impact(
&self,
participation_rate: f64,
spread_bps: f64,
volatility: f64,
order_size: f64,
avg_trade_size: f64,
) -> f64 {
// Almgren-Chriss model parameters
let eta = self.temporary_impact_coef; // Temporary impact coefficient
let gamma = self.permanent_impact_coef; // Permanent impact coefficient
let trading_rate = order_size / avg_trade_size;
// Temporary impact: eta * (v/V)^alpha * sigma
let temp_component = eta * participation_rate.sqrt() * volatility * 10000.0;
// Permanent impact: gamma * (X/V)
let perm_component = gamma * trading_rate * 10000.0;
// Add half spread
let spread_component = spread_bps * 0.5;
temp_component + perm_component + spread_component
}
fn calculate_istar_impact(
&self,
order_size: f64,
microstructure: &MarketMicrostructure,
orderbook: &[PriceLevel],
_side: Side,
) -> f64 {
// I* model - uses order book shape
if orderbook.is_empty() {
return self.temporary_impact_coef * 100.0; // Fallback
}
// Calculate order book imbalance
let mut cumulative_size = 0.0;
let mut impact_bps = 0.0;
// Walk through the book until we've "consumed" our order
for (_i, level) in orderbook.iter().enumerate() {
cumulative_size += level.size;
if cumulative_size >= order_size {
// Calculate average price impact to this level
let ref_price = orderbook[0].price;
let exec_price = level.price;
impact_bps = ((exec_price - ref_price).abs() / ref_price) * 10000.0;
break;
}
}
// Add participation rate impact
let participation_impact = self.temporary_impact_coef *
(order_size / microstructure.daily_volume).sqrt() * 10000.0;
impact_bps + participation_impact
}
fn get_expected_volume(
&self,
microstructure: &MarketMicrostructure,
current_time: DateTime<Utc>,
) -> f64 {
// Use intraday volume profile if available
if microstructure.intraday_volume_profile.len() == 24 {
let hour = current_time.hour() as usize;
let hour_pct = microstructure.intraday_volume_profile[hour];
microstructure.daily_volume * hour_pct
} else {
// Simple assumption: 1/6.5 of daily volume per hour (6.5 hour trading day)
microstructure.daily_volume / 6.5
}
}
fn calculate_impact_decay_time(
&self,
order_size: f64,
daily_volume: f64,
avg_trade_size: f64,
) -> i64 {
// Empirical formula for impact decay
// Larger orders relative to volume decay slower
let volume_ratio = order_size / daily_volume;
let trade_ratio = order_size / avg_trade_size;
// Base decay time in milliseconds
let base_decay_ms = 60_000; // 1 minute base
// Adjust based on order characteristics
let decay_multiplier = 1.0 + volume_ratio * 10.0 + trade_ratio.ln().max(0.0);
(base_decay_ms as f64 * decay_multiplier) as i64
}
pub fn calculate_optimal_execution_schedule(
&self,
total_size: f64,
time_horizon_minutes: f64,
microstructure: &MarketMicrostructure,
risk_aversion: f64,
) -> Vec<(f64, f64)> {
// Almgren-Chriss optimal execution trajectory
let n_slices = (time_horizon_minutes / 5.0).ceil() as usize; // 5-minute buckets
let tau = time_horizon_minutes / n_slices as f64;
let mut schedule = Vec::with_capacity(n_slices);
// Parameters
let volatility = microstructure.volatility;
let _daily_volume = microstructure.daily_volume;
let eta = self.temporary_impact_coef;
let _gamma = self.permanent_impact_coef;
let lambda = risk_aversion;
// Calculate optimal trading rate
let kappa = lambda * volatility.powi(2) / eta;
let alpha = (kappa / tau).sqrt();
for i in 0..n_slices {
let t = i as f64 * tau;
let t_next = (i + 1) as f64 * tau;
// Optimal trajectory: x(t) = X * sinh(alpha * (T - t)) / sinh(alpha * T)
let remaining_start = total_size * (alpha * (time_horizon_minutes - t)).sinh()
/ (alpha * time_horizon_minutes).sinh();
let remaining_end = total_size * (alpha * (time_horizon_minutes - t_next)).sinh()
/ (alpha * time_horizon_minutes).sinh();
let slice_size = remaining_start - remaining_end;
let slice_time = t + tau / 2.0; // Midpoint
schedule.push((slice_time, slice_size));
}
schedule
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_market_impact_models() {
let microstructure = MarketMicrostructure {
symbol: "TEST".to_string(),
avg_spread_bps: 2.0,
daily_volume: 10_000_000.0,
avg_trade_size: 100.0,
volatility: 0.02,
tick_size: 0.01,
lot_size: 1.0,
intraday_volume_profile: vec![0.04; 24], // Flat profile
};
let orderbook = vec![
PriceLevel { price: 100.0, size: 1000.0, order_count: Some(10) },
PriceLevel { price: 100.01, size: 2000.0, order_count: Some(15) },
];
let models = vec![
ImpactModelType::Linear,
ImpactModelType::SquareRoot,
ImpactModelType::AlmgrenChriss,
];
for model_type in models {
let model = MarketImpactModel::new(model_type);
let impact = model.estimate_impact(
1000.0,
Side::Buy,
&microstructure,
&orderbook,
Utc::now(),
);
assert!(impact.total_impact > 0.0);
assert!(impact.temporary_impact >= 0.0);
assert!(impact.permanent_impact >= 0.0);
assert!(impact.expected_cost > 0.0);
assert!(impact.impact_decay_ms > 0);
}
}
}

View file

@ -0,0 +1,5 @@
pub mod market_impact;
pub mod transaction_costs;
pub use market_impact::{MarketImpactModel, ImpactModelType, MarketImpactEstimate};
pub use transaction_costs::{TransactionCostModel, CostComponents};

View file

@ -0,0 +1,355 @@
use crate::{Side, Order, Fill, MarketMicrostructure};
use chrono::{DateTime, Utc};
#[derive(Debug, Clone)]
pub struct CostComponents {
pub spread_cost: f64,
pub market_impact: f64,
pub commission: f64,
pub slippage: f64,
pub opportunity_cost: f64,
pub timing_cost: f64,
pub total_cost: f64,
pub cost_bps: f64,
}
#[derive(Debug, Clone)]
pub struct TransactionCostAnalysis {
pub order_id: String,
pub symbol: String,
pub side: Side,
pub intended_size: f64,
pub filled_size: f64,
pub avg_fill_price: f64,
pub arrival_price: f64,
pub benchmark_price: f64,
pub cost_components: CostComponents,
pub implementation_shortfall: f64,
pub duration_ms: i64,
}
pub struct TransactionCostModel {
commission_rate_bps: f64,
min_commission: f64,
exchange_fees_bps: f64,
regulatory_fees_bps: f64,
benchmark_type: BenchmarkType,
}
#[derive(Debug, Clone, Copy)]
pub enum BenchmarkType {
ArrivalPrice, // Price when order was placed
VWAP, // Volume-weighted average price
TWAP, // Time-weighted average price
Close, // Closing price
MidpointAtArrival, // Mid price at order arrival
}
impl TransactionCostModel {
pub fn new(commission_rate_bps: f64) -> Self {
Self {
commission_rate_bps,
min_commission: 1.0,
exchange_fees_bps: 0.3, // Typical exchange fees
regulatory_fees_bps: 0.1, // SEC fees etc
benchmark_type: BenchmarkType::ArrivalPrice,
}
}
pub fn with_benchmark_type(mut self, benchmark_type: BenchmarkType) -> Self {
self.benchmark_type = benchmark_type;
self
}
pub fn analyze_execution(
&self,
order: &Order,
fills: &[Fill],
arrival_price: f64,
benchmark_prices: &BenchmarkPrices,
microstructure: &MarketMicrostructure,
order_start_time: DateTime<Utc>,
order_end_time: DateTime<Utc>,
) -> TransactionCostAnalysis {
// Calculate filled size and average price
let filled_size = fills.iter().map(|f| f.quantity).sum::<f64>();
let total_value = fills.iter().map(|f| f.price * f.quantity).sum::<f64>();
let avg_fill_price = if filled_size > 0.0 {
total_value / filled_size
} else {
arrival_price
};
// Get benchmark price based on type
let benchmark_price = match self.benchmark_type {
BenchmarkType::ArrivalPrice => arrival_price,
BenchmarkType::VWAP => benchmark_prices.vwap,
BenchmarkType::TWAP => benchmark_prices.twap,
BenchmarkType::Close => benchmark_prices.close,
BenchmarkType::MidpointAtArrival => benchmark_prices.midpoint_at_arrival,
};
// Calculate various cost components
let cost_components = self.calculate_cost_components(
order,
fills,
avg_fill_price,
arrival_price,
benchmark_price,
microstructure,
);
// Calculate implementation shortfall
let side_multiplier = match order.side {
Side::Buy => 1.0,
Side::Sell => -1.0,
};
let implementation_shortfall = side_multiplier * filled_size *
(avg_fill_price - arrival_price) +
side_multiplier * (order.quantity - filled_size) *
(benchmark_price - arrival_price);
// Calculate duration
let duration_ms = (order_end_time - order_start_time).num_milliseconds();
TransactionCostAnalysis {
order_id: order.id.clone(),
symbol: order.symbol.clone(),
side: order.side,
intended_size: order.quantity,
filled_size,
avg_fill_price,
arrival_price,
benchmark_price,
cost_components,
implementation_shortfall,
duration_ms,
}
}
fn calculate_cost_components(
&self,
order: &Order,
fills: &[Fill],
avg_fill_price: f64,
arrival_price: f64,
benchmark_price: f64,
microstructure: &MarketMicrostructure,
) -> CostComponents {
let filled_size = fills.iter().map(|f| f.quantity).sum::<f64>();
let total_value = filled_size * avg_fill_price;
// Spread cost (crossing the spread)
let spread_cost = filled_size * avg_fill_price * microstructure.avg_spread_bps / 10000.0;
// Market impact (price movement due to our order)
let side_multiplier = match order.side {
Side::Buy => 1.0,
Side::Sell => -1.0,
};
let market_impact = side_multiplier * filled_size * (avg_fill_price - arrival_price);
// Commission and fees
let gross_commission = total_value * self.commission_rate_bps / 10000.0;
let commission = gross_commission.max(self.min_commission * fills.len() as f64);
let exchange_fees = total_value * self.exchange_fees_bps / 10000.0;
let regulatory_fees = total_value * self.regulatory_fees_bps / 10000.0;
let total_fees = commission + exchange_fees + regulatory_fees;
// Slippage (difference from benchmark)
let slippage = side_multiplier * filled_size * (avg_fill_price - benchmark_price);
// Opportunity cost (unfilled portion)
let unfilled_size = order.quantity - filled_size;
let opportunity_cost = if unfilled_size > 0.0 {
// Cost of not executing at arrival price
side_multiplier * unfilled_size * (benchmark_price - arrival_price)
} else {
0.0
};
// Timing cost (delay cost)
let timing_cost = side_multiplier * filled_size *
(benchmark_price - arrival_price).max(0.0);
// Total cost
let total_cost = spread_cost + market_impact.abs() + total_fees +
slippage.abs() + opportunity_cost.abs() + timing_cost;
// Cost in basis points
let cost_bps = if total_value > 0.0 {
(total_cost / total_value) * 10000.0
} else {
0.0
};
CostComponents {
spread_cost,
market_impact: market_impact.abs(),
commission: total_fees,
slippage: slippage.abs(),
opportunity_cost: opportunity_cost.abs(),
timing_cost,
total_cost,
cost_bps,
}
}
pub fn calculate_pretrade_cost_estimate(
&self,
order: &Order,
microstructure: &MarketMicrostructure,
current_price: f64,
expected_fill_price: f64,
expected_fill_rate: f64,
) -> CostComponents {
let expected_filled_size = order.quantity * expected_fill_rate;
let total_value = expected_filled_size * expected_fill_price;
// Estimate spread cost
let spread_cost = expected_filled_size * expected_fill_price *
microstructure.avg_spread_bps / 10000.0;
// Estimate market impact
let side_multiplier = match order.side {
Side::Buy => 1.0,
Side::Sell => -1.0,
};
let market_impact = side_multiplier * expected_filled_size *
(expected_fill_price - current_price);
// Calculate commission
let gross_commission = total_value * self.commission_rate_bps / 10000.0;
let commission = gross_commission.max(self.min_commission);
let exchange_fees = total_value * self.exchange_fees_bps / 10000.0;
let regulatory_fees = total_value * self.regulatory_fees_bps / 10000.0;
let total_fees = commission + exchange_fees + regulatory_fees;
// Estimate opportunity cost for unfilled portion
let unfilled_size = order.quantity - expected_filled_size;
let opportunity_cost = if unfilled_size > 0.0 {
// Assume 10bps adverse movement for unfilled portion
unfilled_size * current_price * 0.001
} else {
0.0
};
// No slippage or timing cost for pre-trade estimate
let slippage = 0.0;
let timing_cost = 0.0;
// Total cost
let total_cost = spread_cost + market_impact.abs() + total_fees + opportunity_cost;
// Cost in basis points
let cost_bps = if total_value > 0.0 {
(total_cost / total_value) * 10000.0
} else {
0.0
};
CostComponents {
spread_cost,
market_impact: market_impact.abs(),
commission: total_fees,
slippage,
opportunity_cost,
timing_cost,
total_cost,
cost_bps,
}
}
}
#[derive(Debug, Clone)]
pub struct BenchmarkPrices {
pub vwap: f64,
pub twap: f64,
pub close: f64,
pub midpoint_at_arrival: f64,
}
impl Default for BenchmarkPrices {
fn default() -> Self {
Self {
vwap: 0.0,
twap: 0.0,
close: 0.0,
midpoint_at_arrival: 0.0,
}
}
}
// Helper to track and calculate various price benchmarks
pub struct BenchmarkCalculator {
trades: Vec<(DateTime<Utc>, f64, f64)>, // (time, price, volume)
quotes: Vec<(DateTime<Utc>, f64, f64)>, // (time, bid, ask)
}
impl BenchmarkCalculator {
pub fn new() -> Self {
Self {
trades: Vec::new(),
quotes: Vec::new(),
}
}
pub fn add_trade(&mut self, time: DateTime<Utc>, price: f64, volume: f64) {
self.trades.push((time, price, volume));
}
pub fn add_quote(&mut self, time: DateTime<Utc>, bid: f64, ask: f64) {
self.quotes.push((time, bid, ask));
}
pub fn calculate_benchmarks(
&self,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
) -> BenchmarkPrices {
// Filter trades within time window
let window_trades: Vec<_> = self.trades.iter()
.filter(|(t, _, _)| *t >= start_time && *t <= end_time)
.cloned()
.collect();
// Calculate VWAP
let total_volume: f64 = window_trades.iter().map(|(_, _, v)| v).sum();
let vwap = if total_volume > 0.0 {
window_trades.iter()
.map(|(_, p, v)| p * v)
.sum::<f64>() / total_volume
} else {
0.0
};
// Calculate TWAP
let twap = if !window_trades.is_empty() {
window_trades.iter()
.map(|(_, p, _)| p)
.sum::<f64>() / window_trades.len() as f64
} else {
0.0
};
// Get close price (last trade)
let close = window_trades.last()
.map(|(_, p, _)| *p)
.unwrap_or(0.0);
// Get midpoint at arrival
let midpoint_at_arrival = self.quotes.iter()
.filter(|(t, _, _)| *t <= start_time)
.last()
.map(|(_, b, a)| (b + a) / 2.0)
.unwrap_or(0.0);
BenchmarkPrices {
vwap,
twap,
close,
midpoint_at_arrival,
}
}
}

View file

@ -0,0 +1,469 @@
use napi::bindgen_prelude::*;
use napi_derive::napi;
use std::sync::Arc;
use parking_lot::Mutex;
use crate::backtest::{
BacktestEngine as RustBacktestEngine,
BacktestConfig,
Strategy, Signal,
};
use crate::{TradingMode, MarketUpdate};
use chrono::{DateTime, Utc};
#[napi]
pub struct BacktestEngine {
inner: Arc<Mutex<Option<RustBacktestEngine>>>,
}
#[napi]
impl BacktestEngine {
#[napi(constructor)]
pub fn new(config: napi::JsObject, env: Env) -> Result<Self> {
let config = parse_backtest_config(config)?;
// Create mode
let mode = TradingMode::Backtest {
start_time: config.start_time,
end_time: config.end_time,
speed_multiplier: 0.0, // Max speed
};
// Create components
let time_provider = crate::core::create_time_provider(&mode);
let market_data_source = crate::core::create_market_data_source(&mode);
let execution_handler = crate::core::create_execution_handler(&mode);
let engine = RustBacktestEngine::new(
config,
mode,
time_provider,
market_data_source,
execution_handler,
);
Ok(Self {
inner: Arc::new(Mutex::new(Some(engine))),
})
}
#[napi]
pub fn add_typescript_strategy(
&mut self,
name: String,
id: String,
parameters: napi::JsObject,
_callback: napi::JsFunction,
) -> Result<()> {
eprintln!("Adding strategy: {}", name);
// For now, we'll add a native Rust SMA strategy
// In the future, we'll implement proper TypeScript callback support
let fast_period: usize = parameters.get_named_property::<f64>("fastPeriod")
.unwrap_or(5.0) as usize;
let slow_period: usize = parameters.get_named_property::<f64>("slowPeriod")
.unwrap_or(15.0) as usize;
if let Some(engine) = self.inner.lock().as_mut() {
engine.add_strategy(Box::new(SimpleSMAStrategy::new(
name.clone(),
id,
fast_period,
slow_period,
)));
eprintln!("Strategy '{}' added with fast={}, slow={}", name, fast_period, slow_period);
}
Ok(())
}
#[napi]
pub fn add_native_strategy(
&mut self,
strategy_type: String,
name: String,
id: String,
parameters: napi::JsObject,
) -> Result<()> {
eprintln!("Adding native Rust strategy: {} ({})", name, strategy_type);
if let Some(engine) = self.inner.lock().as_mut() {
match strategy_type.as_str() {
"sma_crossover" => {
let fast_period: usize = parameters.get_named_property::<f64>("fastPeriod")
.unwrap_or(5.0) as usize;
let slow_period: usize = parameters.get_named_property::<f64>("slowPeriod")
.unwrap_or(15.0) as usize;
engine.add_strategy(Box::new(SimpleSMAStrategy::new(
name.clone(),
id,
fast_period,
slow_period,
)));
}
"mean_reversion" => {
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(20.0) as usize;
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
.unwrap_or(2.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::MeanReversionFixedStrategy::new(
name.clone(),
id,
lookback_period,
entry_threshold,
position_size,
)));
}
"momentum" => {
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(14.0) as usize;
let momentum_threshold: f64 = parameters.get_named_property::<f64>("momentumThreshold")
.unwrap_or(5.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::MomentumStrategy::new(
name.clone(),
id,
lookback_period,
momentum_threshold,
position_size,
)));
}
"pairs_trading" => {
let pair_a: String = parameters.get_named_property::<String>("pairA")?;
let pair_b: String = parameters.get_named_property::<String>("pairB")?;
let lookback_period: usize = parameters.get_named_property::<f64>("lookbackPeriod")
.unwrap_or(20.0) as usize;
let entry_threshold: f64 = parameters.get_named_property::<f64>("entryThreshold")
.unwrap_or(2.0);
let position_size: f64 = parameters.get_named_property::<f64>("positionSize")
.unwrap_or(100.0);
engine.add_strategy(Box::new(crate::strategies::PairsTradingStrategy::new(
name.clone(),
id,
pair_a,
pair_b,
lookback_period,
entry_threshold,
position_size,
)));
}
_ => {
return Err(Error::from_reason(format!("Unknown strategy type: {}", strategy_type)));
}
}
eprintln!("Native strategy '{}' added successfully", name);
}
Ok(())
}
#[napi]
pub fn run(&mut self) -> Result<String> {
eprintln!("=== BACKTEST RUN START ===");
let mut engine = self.inner.lock().take()
.ok_or_else(|| Error::from_reason("Engine already consumed"))?;
// Config and strategies are private, skip detailed logging
// Run the backtest synchronously for now
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| Error::from_reason(e.to_string()))?;
let result = runtime.block_on(engine.run())
.map_err(|e| {
eprintln!("ERROR: Backtest engine failed: {}", e);
Error::from_reason(e)
})?;
eprintln!("=== BACKTEST RUN COMPLETE ===");
eprintln!("Total trades: {}", result.trades.len());
eprintln!("Equity points: {}", result.equity.len());
// Return result as JSON
serde_json::to_string(&result)
.map_err(|e| Error::from_reason(e.to_string()))
}
#[napi]
pub fn load_market_data(&self, data: Vec<napi::JsObject>) -> Result<()> {
eprintln!("load_market_data called with {} items", data.len());
// Convert JS objects to MarketData
let market_data: Vec<MarketUpdate> = data.into_iter()
.filter_map(|obj| parse_market_data(obj).ok())
.collect();
eprintln!("Parsed {} valid market data items", market_data.len());
// Load data into the historical data source
if let Some(engine) = self.inner.lock().as_ref() {
// Access the market data source through the engine
let mut data_source = engine.market_data_source.write();
if let Some(historical_source) = data_source.as_any_mut()
.downcast_mut::<crate::core::market_data_sources::HistoricalDataSource>() {
eprintln!("Loading data into HistoricalDataSource");
historical_source.load_data(market_data);
eprintln!("Data loaded successfully");
} else {
eprintln!("ERROR: Could not downcast to HistoricalDataSource");
}
} else {
eprintln!("ERROR: Engine not found");
}
Ok(())
}
}
fn parse_backtest_config(obj: napi::JsObject) -> Result<BacktestConfig> {
let name: String = obj.get_named_property("name")?;
let symbols: Vec<String> = obj.get_named_property("symbols")?;
let start_date: String = obj.get_named_property("startDate")?;
let end_date: String = obj.get_named_property("endDate")?;
let initial_capital: f64 = obj.get_named_property("initialCapital")?;
let commission: f64 = obj.get_named_property("commission")?;
let slippage: f64 = obj.get_named_property("slippage")?;
let data_frequency: String = obj.get_named_property("dataFrequency")?;
let strategy: Option<String> = obj.get_named_property("strategy").ok();
Ok(BacktestConfig {
name,
strategy,
symbols,
start_time: DateTime::parse_from_rfc3339(&start_date)
.map_err(|e| Error::from_reason(e.to_string()))?
.with_timezone(&Utc),
end_time: DateTime::parse_from_rfc3339(&end_date)
.map_err(|e| Error::from_reason(e.to_string()))?
.with_timezone(&Utc),
initial_capital,
commission,
slippage,
data_frequency,
})
}
fn parse_market_data(obj: napi::JsObject) -> Result<crate::MarketUpdate> {
let symbol: String = obj.get_named_property("symbol")?;
let timestamp: i64 = obj.get_named_property("timestamp")?;
let data_type: String = obj.get_named_property("type")?;
let data = if data_type == "bar" {
crate::MarketDataType::Bar(crate::Bar {
open: obj.get_named_property("open")?,
high: obj.get_named_property("high")?,
low: obj.get_named_property("low")?,
close: obj.get_named_property("close")?,
volume: obj.get_named_property("volume")?,
vwap: obj.get_named_property("vwap").ok(),
})
} else {
eprintln!("Unsupported market data type: {}", data_type);
return Err(Error::from_reason("Unsupported market data type"));
};
// First few items
static mut COUNT: usize = 0;
unsafe {
if COUNT < 3 {
eprintln!("Parsed market data: symbol={}, timestamp={}, close={}",
symbol, timestamp,
if let crate::MarketDataType::Bar(ref bar) = data { bar.close } else { 0.0 });
COUNT += 1;
}
}
Ok(crate::MarketUpdate {
symbol,
timestamp: DateTime::<Utc>::from_timestamp(timestamp / 1000, 0)
.ok_or_else(|| Error::from_reason("Invalid timestamp"))?,
data,
})
}
// Simple SMA Strategy for testing
struct SimpleSMAStrategy {
name: String,
id: String,
fast_period: usize,
slow_period: usize,
price_history: std::collections::HashMap<String, Vec<f64>>,
positions: std::collections::HashMap<String, f64>,
}
impl SimpleSMAStrategy {
fn new(name: String, id: String, fast_period: usize, slow_period: usize) -> Self {
eprintln!("Creating SimpleSMAStrategy: name={}, fast={}, slow={}", name, fast_period, slow_period);
Self {
name,
id,
fast_period,
slow_period,
price_history: std::collections::HashMap::new(),
positions: std::collections::HashMap::new(),
}
}
}
impl Strategy for SimpleSMAStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
// Count calls
static mut CALL_COUNT: usize = 0;
unsafe {
CALL_COUNT += 1;
if CALL_COUNT % 100 == 1 {
eprintln!("SimpleSMAStrategy.on_market_data called {} times", CALL_COUNT);
}
}
let mut signals = Vec::new();
// Check if it's bar data
if let crate::MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Debug: Log first few prices
if history.len() <= 3 {
eprintln!("Price history for {}: {:?}", symbol, history);
} else if history.len() == 10 || history.len() == 15 {
eprintln!("Price history length for {}: {} bars", symbol, history.len());
}
// Keep only necessary history (need one extra for previous SMA calculation)
if history.len() > self.slow_period + 1 {
history.remove(0);
}
// Need enough data
if history.len() >= self.slow_period {
// Debug when we first have enough data
if history.len() == self.slow_period {
eprintln!("Now have enough data for {}: {} bars", symbol, history.len());
}
// Calculate SMAs
let fast_sma = history[history.len() - self.fast_period..].iter().sum::<f64>() / self.fast_period as f64;
let slow_sma = history.iter().sum::<f64>() / history.len() as f64;
// Debug: Log SMAs periodically
if history.len() % 10 == 0 || (history.len() > self.slow_period && history.len() < self.slow_period + 5) {
eprintln!("SMAs for {}: fast={:.2}, slow={:.2}, price={:.2}, history_len={}",
symbol, fast_sma, slow_sma, price, history.len());
// Also log if they're close to crossing
let diff = (fast_sma - slow_sma).abs();
let pct_diff = diff / slow_sma * 100.0;
if pct_diff < 1.0 {
eprintln!(" -> SMAs are close! Difference: {:.4} ({:.2}%)", diff, pct_diff);
}
}
// Previous SMAs (if we have enough history)
if history.len() > self.slow_period {
// Debug: First time checking for crossovers
if history.len() == self.slow_period + 1 {
eprintln!("Starting crossover checks for {}", symbol);
}
let prev_history = &history[..history.len() - 1];
let prev_fast_sma = prev_history[prev_history.len() - self.fast_period..].iter().sum::<f64>() / self.fast_period as f64;
let prev_slow_sma = prev_history.iter().sum::<f64>() / prev_history.len() as f64;
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
// Golden cross - buy signal
if prev_fast_sma <= prev_slow_sma && fast_sma > slow_sma && current_position <= 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: crate::backtest::SignalType::Buy,
strength: 1.0,
quantity: Some(100.0), // Fixed quantity for testing
reason: Some("Golden cross".to_string()),
metadata: None,
});
self.positions.insert(symbol.clone(), 1.0);
eprintln!("Generated BUY signal for {} at price {}", symbol, price);
}
// Death cross - sell signal
else if prev_fast_sma >= prev_slow_sma && fast_sma < slow_sma && current_position >= 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: crate::backtest::SignalType::Sell,
strength: 1.0,
quantity: Some(100.0), // Fixed quantity for testing
reason: Some("Death cross".to_string()),
metadata: None,
});
self.positions.insert(symbol.clone(), -1.0);
eprintln!("Generated SELL signal for {} at price {}", symbol, price);
}
}
} else {
// Debug: Log when we don't have enough data
if history.len() == 1 || history.len() == 10 || history.len() == 20 {
eprintln!("Not enough data for {}: {} bars (need {})", symbol, history.len(), self.slow_period);
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
eprintln!("🔸 SMA Strategy - Fill received: {} {} @ ${:.2} - {}", quantity, symbol, price, side);
let current_pos = self.positions.get(symbol).copied().unwrap_or(0.0);
let new_pos = if side == "buy" { current_pos + quantity } else { current_pos - quantity };
eprintln!(" Position change: {} -> {}", current_pos, new_pos);
if new_pos.abs() < 0.0001 {
self.positions.remove(symbol);
eprintln!(" Position closed");
} else {
self.positions.insert(symbol.to_string(), new_pos);
}
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
serde_json::json!({
"fast_period": self.fast_period,
"slow_period": self.slow_period
})
}
}
// Error handling for threadsafe functions
struct ErrorStrategy;
impl From<napi::Error> for ErrorStrategy {
fn from(_e: napi::Error) -> Self {
ErrorStrategy
}
}
// Helper to convert NAPI parameters to JSON
fn napi_params_to_json(obj: napi::JsObject) -> Result<serde_json::Value> {
// For now, just extract the common parameters
let fast_period = obj.get_named_property::<f64>("fastPeriod").unwrap_or(5.0);
let slow_period = obj.get_named_property::<f64>("slowPeriod").unwrap_or(15.0);
Ok(serde_json::json!({
"fastPeriod": fast_period,
"slowPeriod": slow_period
}))
}

View file

@ -0,0 +1,243 @@
use napi_derive::napi;
use napi::{bindgen_prelude::*};
use serde_json;
use crate::indicators::{
SMA, EMA, RSI, MACD, BollingerBands, Stochastic, ATR,
Indicator, IncrementalIndicator
};
/// Convert JS array to Vec<f64>
fn js_array_to_vec(arr: Vec<f64>) -> Vec<f64> {
arr
}
#[napi]
pub struct TechnicalIndicators {}
#[napi]
impl TechnicalIndicators {
#[napi(constructor)]
pub fn new() -> Self {
Self {}
}
/// Calculate Simple Moving Average
#[napi]
pub fn calculate_sma(&self, values: Vec<f64>, period: u32) -> Result<Vec<f64>> {
match SMA::calculate_series(&values, period as usize) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate Exponential Moving Average
#[napi]
pub fn calculate_ema(&self, values: Vec<f64>, period: u32) -> Result<Vec<f64>> {
match EMA::calculate_series(&values, period as usize) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate Relative Strength Index
#[napi]
pub fn calculate_rsi(&self, values: Vec<f64>, period: u32) -> Result<Vec<f64>> {
match RSI::calculate_series(&values, period as usize) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate MACD - returns JSON string
#[napi]
pub fn calculate_macd(
&self,
values: Vec<f64>,
fast_period: u32,
slow_period: u32,
signal_period: u32
) -> Result<String> {
match MACD::calculate_series(&values, fast_period as usize, slow_period as usize, signal_period as usize) {
Ok((macd, signal, histogram)) => {
let result = serde_json::json!({
"macd": macd,
"signal": signal,
"histogram": histogram
});
Ok(result.to_string())
}
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate Bollinger Bands - returns JSON string
#[napi]
pub fn calculate_bollinger_bands(
&self,
values: Vec<f64>,
period: u32,
std_dev: f64
) -> Result<String> {
match BollingerBands::calculate_series(&values, period as usize, std_dev) {
Ok((middle, upper, lower)) => {
let result = serde_json::json!({
"middle": middle,
"upper": upper,
"lower": lower
});
Ok(result.to_string())
}
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate Stochastic Oscillator - returns JSON string
#[napi]
pub fn calculate_stochastic(
&self,
high: Vec<f64>,
low: Vec<f64>,
close: Vec<f64>,
k_period: u32,
d_period: u32,
smooth_k: u32
) -> Result<String> {
match Stochastic::calculate_series(
&high,
&low,
&close,
k_period as usize,
d_period as usize,
smooth_k as usize
) {
Ok((k, d)) => {
let result = serde_json::json!({
"k": k,
"d": d
});
Ok(result.to_string())
}
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
/// Calculate Average True Range
#[napi]
pub fn calculate_atr(
&self,
high: Vec<f64>,
low: Vec<f64>,
close: Vec<f64>,
period: u32
) -> Result<Vec<f64>> {
match ATR::calculate_series(&high, &low, &close, period as usize) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
}
/// Incremental indicator calculator for streaming data
#[napi]
pub struct IncrementalSMA {
indicator: SMA,
}
#[napi]
impl IncrementalSMA {
#[napi(constructor)]
pub fn new(period: u32) -> Result<Self> {
match SMA::new(period as usize) {
Ok(indicator) => Ok(Self { indicator }),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn update(&mut self, value: f64) -> Result<Option<f64>> {
match self.indicator.update(value) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn current(&self) -> Option<f64> {
self.indicator.current()
}
#[napi]
pub fn reset(&mut self) {
self.indicator.reset();
}
#[napi]
pub fn is_ready(&self) -> bool {
self.indicator.is_ready()
}
}
/// Incremental EMA calculator
#[napi]
pub struct IncrementalEMA {
indicator: EMA,
}
#[napi]
impl IncrementalEMA {
#[napi(constructor)]
pub fn new(period: u32) -> Result<Self> {
match EMA::new(period as usize) {
Ok(indicator) => Ok(Self { indicator }),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn update(&mut self, value: f64) -> Result<Option<f64>> {
match self.indicator.update(value) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn current(&self) -> Option<f64> {
self.indicator.current()
}
#[napi]
pub fn reset(&mut self) {
self.indicator.reset();
}
}
/// Incremental RSI calculator
#[napi]
pub struct IncrementalRSI {
indicator: RSI,
}
#[napi]
impl IncrementalRSI {
#[napi(constructor)]
pub fn new(period: u32) -> Result<Self> {
match RSI::new(period as usize) {
Ok(indicator) => Ok(Self { indicator }),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn update(&mut self, value: f64) -> Result<Option<f64>> {
match self.indicator.update(value) {
Ok(result) => Ok(result),
Err(e) => Err(Error::from_reason(e.to_string())),
}
}
#[napi]
pub fn current(&self) -> Option<f64> {
self.indicator.current()
}
}

View file

@ -0,0 +1,435 @@
mod indicators;
mod risk;
mod backtest;
pub use indicators::{TechnicalIndicators, IncrementalSMA, IncrementalEMA, IncrementalRSI};
pub use risk::{RiskAnalyzer, OrderbookAnalyzer};
pub use backtest::BacktestEngine;
use napi_derive::napi;
use napi::{bindgen_prelude::*, JsObject};
use crate::{
TradingCore, TradingMode, Order, OrderType, TimeInForce, Side,
MarketUpdate, Quote, Trade,
MarketMicrostructure,
core::{create_market_data_source, create_execution_handler, create_time_provider},
};
use crate::risk::RiskLimits;
use std::sync::Arc;
use parking_lot::Mutex;
use chrono::{DateTime, Utc};
#[napi]
pub struct TradingEngine {
core: Arc<Mutex<TradingCore>>,
}
#[napi]
impl TradingEngine {
#[napi(constructor)]
pub fn new(mode: String, config: JsObject) -> Result<Self> {
let mode = parse_mode(&mode, config)?;
let market_data_source = create_market_data_source(&mode);
let execution_handler = create_execution_handler(&mode);
let time_provider = create_time_provider(&mode);
let core = TradingCore::new(mode, market_data_source, execution_handler, time_provider);
Ok(Self {
core: Arc::new(Mutex::new(core)),
})
}
#[napi]
pub fn get_mode(&self) -> String {
let core = self.core.lock();
match core.get_mode() {
TradingMode::Backtest { .. } => "backtest".to_string(),
TradingMode::Paper { .. } => "paper".to_string(),
TradingMode::Live { .. } => "live".to_string(),
}
}
#[napi]
pub fn get_current_time(&self) -> i64 {
let core = self.core.lock();
core.get_time().timestamp_millis()
}
#[napi]
pub fn submit_order(&self, order_js: JsObject) -> Result<String> {
let order = parse_order(order_js)?;
// For now, return a mock result - in real implementation would queue the order
let result = crate::ExecutionResult {
order_id: order.id.clone(),
status: crate::OrderStatus::Accepted,
fills: vec![],
};
Ok(serde_json::to_string(&result).unwrap())
}
#[napi]
pub fn check_risk(&self, order_js: JsObject) -> Result<String> {
let order = parse_order(order_js)?;
let core = self.core.lock();
// Get current position for the symbol
let position = core.position_tracker.get_position(&order.symbol);
let current_quantity = position.map(|p| p.quantity);
let result = core.risk_engine.check_order(&order, current_quantity);
Ok(serde_json::to_string(&result).unwrap())
}
#[napi]
pub fn update_quote(&self, symbol: String, bid: f64, ask: f64, bid_size: f64, ask_size: f64) -> Result<()> {
let quote = Quote { bid, ask, bid_size, ask_size };
let core = self.core.lock();
let timestamp = core.get_time();
core.orderbooks.update_quote(&symbol, quote, timestamp);
// Update unrealized P&L
let mid_price = (bid + ask) / 2.0;
core.position_tracker.update_unrealized_pnl(&symbol, mid_price);
Ok(())
}
#[napi]
pub fn update_trade(&self, symbol: String, price: f64, size: f64, side: String) -> Result<()> {
let side = match side.as_str() {
"buy" | "Buy" => Side::Buy,
"sell" | "Sell" => Side::Sell,
_ => return Err(Error::from_reason("Invalid side")),
};
let trade = Trade { price, size, side };
let core = self.core.lock();
let timestamp = core.get_time();
core.orderbooks.update_trade(&symbol, trade, timestamp);
Ok(())
}
#[napi]
pub fn get_orderbook_snapshot(&self, symbol: String, depth: u32) -> Result<String> {
let core = self.core.lock();
let snapshot = core.orderbooks.get_snapshot(&symbol, depth as usize)
.ok_or_else(|| Error::from_reason("Symbol not found"))?;
Ok(serde_json::to_string(&snapshot).unwrap())
}
#[napi]
pub fn get_best_bid_ask(&self, symbol: String) -> Result<Vec<f64>> {
let core = self.core.lock();
let (bid, ask) = core.orderbooks.get_best_bid_ask(&symbol)
.ok_or_else(|| Error::from_reason("Symbol not found"))?;
Ok(vec![bid, ask])
}
#[napi]
pub fn get_position(&self, symbol: String) -> Result<Option<String>> {
let core = self.core.lock();
let position = core.position_tracker.get_position(&symbol);
Ok(position.map(|p| serde_json::to_string(&p).unwrap()))
}
#[napi]
pub fn get_all_positions(&self) -> Result<String> {
let core = self.core.lock();
let positions = core.position_tracker.get_all_positions();
Ok(serde_json::to_string(&positions).unwrap())
}
#[napi]
pub fn get_open_positions(&self) -> Result<String> {
let core = self.core.lock();
let positions = core.position_tracker.get_open_positions();
Ok(serde_json::to_string(&positions).unwrap())
}
#[napi]
pub fn get_total_pnl(&self) -> Result<Vec<f64>> {
let core = self.core.lock();
let (realized, unrealized) = core.position_tracker.get_total_pnl();
Ok(vec![realized, unrealized])
}
#[napi]
pub fn process_fill(&self, symbol: String, price: f64, quantity: f64, side: String, commission: f64) -> Result<String> {
self.process_fill_with_metadata(symbol, price, quantity, side, commission, None, None)
}
#[napi]
pub fn process_fill_with_metadata(
&self,
symbol: String,
price: f64,
quantity: f64,
side: String,
commission: f64,
order_id: Option<String>,
strategy_id: Option<String>
) -> Result<String> {
let side = match side.as_str() {
"buy" | "Buy" => Side::Buy,
"sell" | "Sell" => Side::Sell,
_ => return Err(Error::from_reason("Invalid side")),
};
let core = self.core.lock();
let timestamp = core.get_time();
let fill = crate::Fill {
timestamp,
price,
quantity,
commission,
};
let update = core.position_tracker.process_fill_with_tracking(&symbol, &fill, side, order_id, strategy_id);
// Update risk engine with new position
core.risk_engine.update_position(&symbol, update.resulting_position.quantity);
// Update daily P&L
if update.resulting_position.realized_pnl != 0.0 {
core.risk_engine.update_daily_pnl(update.resulting_position.realized_pnl);
}
Ok(serde_json::to_string(&update).unwrap())
}
#[napi]
pub fn update_risk_limits(&self, limits_js: JsObject) -> Result<()> {
let limits = parse_risk_limits(limits_js)?;
let core = self.core.lock();
core.risk_engine.update_limits(limits);
Ok(())
}
#[napi]
pub fn reset_daily_metrics(&self) -> Result<()> {
let core = self.core.lock();
core.risk_engine.reset_daily_metrics();
Ok(())
}
#[napi]
pub fn get_risk_metrics(&self) -> Result<String> {
let core = self.core.lock();
let metrics = core.risk_engine.get_risk_metrics();
Ok(serde_json::to_string(&metrics).unwrap())
}
// Backtest-specific methods
#[napi]
pub fn advance_time(&self, to_timestamp: i64) -> Result<()> {
let core = self.core.lock();
if let TradingMode::Backtest { .. } = core.get_mode() {
// Downcast time provider to SimulatedTime and advance it
if let Some(simulated_time) = core.time_provider.as_any().downcast_ref::<crate::core::time_providers::SimulatedTime>() {
let new_time = DateTime::<Utc>::from_timestamp_millis(to_timestamp)
.ok_or_else(|| Error::from_reason("Invalid timestamp"))?;
simulated_time.advance_to(new_time);
Ok(())
} else {
Err(Error::from_reason("Failed to access simulated time provider"))
}
} else {
Err(Error::from_reason("Can only advance time in backtest mode"))
}
}
#[napi]
pub fn set_microstructure(&self, symbol: String, microstructure_js: JsObject) -> Result<()> {
let microstructure = parse_microstructure(microstructure_js)?;
let _core = self.core.lock();
// Store microstructure for use in fill simulation
// In real implementation, would pass to execution handler
Ok(())
}
#[napi]
pub fn load_historical_data(&self, data_json: String) -> Result<()> {
let data: Vec<MarketUpdate> = serde_json::from_str(&data_json)
.map_err(|e| Error::from_reason(format!("Failed to parse data: {}", e)))?;
let core = self.core.lock();
// Downcast to HistoricalDataSource if in backtest mode
if let TradingMode::Backtest { .. } = core.get_mode() {
let mut data_source = core.market_data_source.write();
if let Some(historical_source) = data_source.as_any_mut().downcast_mut::<crate::core::market_data_sources::HistoricalDataSource>() {
historical_source.load_data(data);
}
}
Ok(())
}
#[napi]
pub fn generate_mock_data(&self, symbol: String, start_time: i64, end_time: i64, seed: Option<u32>) -> Result<()> {
let core = self.core.lock();
// Only available in backtest mode
if let TradingMode::Backtest { .. } = core.get_mode() {
let mut data_source = core.market_data_source.write();
if let Some(historical_source) = data_source.as_any_mut().downcast_mut::<crate::core::market_data_sources::HistoricalDataSource>() {
let start_dt = DateTime::<Utc>::from_timestamp_millis(start_time)
.ok_or_else(|| Error::from_reason("Invalid start time"))?;
let end_dt = DateTime::<Utc>::from_timestamp_millis(end_time)
.ok_or_else(|| Error::from_reason("Invalid end time"))?;
historical_source.generate_mock_data(symbol, start_dt, end_dt, seed.map(|s| s as u64));
} else {
return Err(Error::from_reason("Failed to access historical data source"));
}
} else {
return Err(Error::from_reason("Mock data generation only available in backtest mode"));
}
Ok(())
}
#[napi]
pub fn get_trade_history(&self) -> Result<String> {
let core = self.core.lock();
let trades = core.position_tracker.get_trade_history();
Ok(serde_json::to_string(&trades).unwrap())
}
#[napi]
pub fn get_closed_trades(&self) -> Result<String> {
let core = self.core.lock();
let trades = core.position_tracker.get_closed_trades();
Ok(serde_json::to_string(&trades).unwrap())
}
#[napi]
pub fn get_open_trades(&self) -> Result<String> {
let core = self.core.lock();
let trades = core.position_tracker.get_open_trades();
Ok(serde_json::to_string(&trades).unwrap())
}
#[napi]
pub fn get_trade_count(&self) -> Result<u32> {
let core = self.core.lock();
Ok(core.position_tracker.get_trade_count() as u32)
}
#[napi]
pub fn get_closed_trade_count(&self) -> Result<u32> {
let core = self.core.lock();
Ok(core.position_tracker.get_closed_trade_count() as u32)
}
}
// Helper functions to parse JavaScript objects
fn parse_mode(mode_str: &str, config: JsObject) -> Result<TradingMode> {
match mode_str {
"backtest" => {
let start_time: i64 = config.get_named_property("startTime")?;
let end_time: i64 = config.get_named_property("endTime")?;
let speed_multiplier: f64 = config.get_named_property("speedMultiplier")
.unwrap_or(1.0);
Ok(TradingMode::Backtest {
start_time: DateTime::<Utc>::from_timestamp_millis(start_time)
.ok_or_else(|| Error::from_reason("Invalid start time"))?,
end_time: DateTime::<Utc>::from_timestamp_millis(end_time)
.ok_or_else(|| Error::from_reason("Invalid end time"))?,
speed_multiplier,
})
}
"paper" => {
let starting_capital: f64 = config.get_named_property("startingCapital")?;
Ok(TradingMode::Paper { starting_capital })
}
"live" => {
let broker: String = config.get_named_property("broker")?;
let account_id: String = config.get_named_property("accountId")?;
Ok(TradingMode::Live { broker, account_id })
}
_ => Err(Error::from_reason("Invalid mode")),
}
}
fn parse_order(order_js: JsObject) -> Result<Order> {
let id: String = order_js.get_named_property("id")?;
let symbol: String = order_js.get_named_property("symbol")?;
let side_str: String = order_js.get_named_property("side")?;
let side = match side_str.as_str() {
"buy" | "Buy" => Side::Buy,
"sell" | "Sell" => Side::Sell,
_ => return Err(Error::from_reason("Invalid side")),
};
let quantity: f64 = order_js.get_named_property("quantity")?;
let order_type_str: String = order_js.get_named_property("orderType")?;
let order_type = match order_type_str.as_str() {
"market" => OrderType::Market,
"limit" => {
let price: f64 = order_js.get_named_property("limitPrice")?;
OrderType::Limit { price }
}
_ => return Err(Error::from_reason("Invalid order type")),
};
let time_in_force_str: String = order_js.get_named_property("timeInForce")
.unwrap_or_else(|_| "DAY".to_string());
let time_in_force = match time_in_force_str.as_str() {
"DAY" => TimeInForce::Day,
"GTC" => TimeInForce::GTC,
"IOC" => TimeInForce::IOC,
"FOK" => TimeInForce::FOK,
_ => TimeInForce::Day,
};
Ok(Order {
id,
symbol,
side,
quantity,
order_type,
time_in_force,
})
}
fn parse_risk_limits(limits_js: JsObject) -> Result<RiskLimits> {
Ok(RiskLimits {
max_position_size: limits_js.get_named_property("maxPositionSize")?,
max_order_size: limits_js.get_named_property("maxOrderSize")?,
max_daily_loss: limits_js.get_named_property("maxDailyLoss")?,
max_gross_exposure: limits_js.get_named_property("maxGrossExposure")?,
max_symbol_exposure: limits_js.get_named_property("maxSymbolExposure")?,
})
}
fn parse_microstructure(microstructure_js: JsObject) -> Result<MarketMicrostructure> {
let intraday_volume_profile: Vec<f64> = microstructure_js.get_named_property("intradayVolumeProfile")
.unwrap_or_else(|_| vec![1.0/24.0; 24]);
Ok(MarketMicrostructure {
symbol: microstructure_js.get_named_property("symbol")?,
avg_spread_bps: microstructure_js.get_named_property("avgSpreadBps")?,
daily_volume: microstructure_js.get_named_property("dailyVolume")?,
avg_trade_size: microstructure_js.get_named_property("avgTradeSize")?,
volatility: microstructure_js.get_named_property("volatility")?,
tick_size: microstructure_js.get_named_property("tickSize")?,
lot_size: microstructure_js.get_named_property("lotSize")?,
intraday_volume_profile,
})
}

View file

@ -0,0 +1,166 @@
use napi_derive::napi;
use napi::{bindgen_prelude::*};
use crate::risk::{BetSizer, BetSizingParameters, MarketRegime, RiskModel};
use crate::orderbook::{OrderBookAnalytics, LiquidityProfile};
use crate::positions::Position;
use std::collections::HashMap;
#[napi]
pub struct RiskAnalyzer {
risk_model: RiskModel,
bet_sizer: BetSizer,
}
#[napi]
impl RiskAnalyzer {
#[napi(constructor)]
pub fn new(capital: f64, base_risk_per_trade: f64, lookback_period: u32) -> Self {
Self {
risk_model: RiskModel::new(lookback_period as usize),
bet_sizer: BetSizer::new(capital, base_risk_per_trade),
}
}
#[napi]
pub fn update_returns(&mut self, symbol: String, returns: Vec<f64>) -> Result<()> {
self.risk_model.update_returns(&symbol, returns);
Ok(())
}
#[napi]
pub fn calculate_portfolio_risk(&self, positions_json: String, prices_json: String) -> Result<String> {
// Parse positions
let positions_data: Vec<(String, f64, f64)> = serde_json::from_str(&positions_json)
.map_err(|e| Error::from_reason(format!("Failed to parse positions: {}", e)))?;
let mut positions = HashMap::new();
for (symbol, quantity, avg_price) in positions_data {
positions.insert(symbol.clone(), Position {
symbol,
quantity,
average_price: avg_price,
realized_pnl: 0.0,
unrealized_pnl: 0.0,
total_cost: quantity * avg_price,
last_update: chrono::Utc::now(),
});
}
// Parse prices
let prices: HashMap<String, f64> = serde_json::from_str(&prices_json)
.map_err(|e| Error::from_reason(format!("Failed to parse prices: {}", e)))?;
// Calculate risk
match self.risk_model.calculate_portfolio_risk(&positions, &prices) {
Ok(risk) => Ok(serde_json::to_string(&risk).unwrap()),
Err(e) => Err(Error::from_reason(e)),
}
}
#[napi]
pub fn calculate_position_size(
&self,
signal_strength: f64,
signal_confidence: f64,
volatility: f64,
liquidity_score: f64,
current_drawdown: f64,
price: f64,
stop_loss: Option<f64>,
market_regime: String,
) -> Result<String> {
let regime = match market_regime.as_str() {
"trending" => MarketRegime::Trending,
"range_bound" => MarketRegime::RangeBound,
"high_volatility" => MarketRegime::HighVolatility,
"low_volatility" => MarketRegime::LowVolatility,
_ => MarketRegime::Transitioning,
};
let params = BetSizingParameters {
signal_strength,
signal_confidence,
market_regime: regime,
volatility,
liquidity_score,
correlation_exposure: 0.0, // Would be calculated from portfolio
current_drawdown,
};
let position_size = self.bet_sizer.calculate_position_size(
&params,
price,
stop_loss,
None, // Historical performance
None, // Orderbook analytics
None, // Liquidity profile
);
Ok(serde_json::to_string(&position_size).unwrap())
}
#[napi]
pub fn calculate_optimal_stop_loss(
&self,
entry_price: f64,
volatility: f64,
support_levels: Vec<f64>,
atr: Option<f64>,
is_long: bool,
) -> f64 {
self.bet_sizer.calculate_optimal_stop_loss(
entry_price,
volatility,
&support_levels,
atr,
is_long,
)
}
}
#[napi]
pub struct OrderbookAnalyzer {}
#[napi]
impl OrderbookAnalyzer {
#[napi(constructor)]
pub fn new() -> Self {
Self {}
}
#[napi]
pub fn analyze_orderbook(&self, snapshot_json: String) -> Result<String> {
let snapshot: crate::OrderBookSnapshot = serde_json::from_str(&snapshot_json)
.map_err(|e| Error::from_reason(format!("Failed to parse snapshot: {}", e)))?;
match OrderBookAnalytics::calculate(&snapshot) {
Some(analytics) => Ok(serde_json::to_string(&analytics).unwrap()),
None => Err(Error::from_reason("Failed to calculate analytics")),
}
}
#[napi]
pub fn calculate_liquidity_profile(&self, snapshot_json: String) -> Result<String> {
let snapshot: crate::OrderBookSnapshot = serde_json::from_str(&snapshot_json)
.map_err(|e| Error::from_reason(format!("Failed to parse snapshot: {}", e)))?;
let profile = LiquidityProfile::from_snapshot(&snapshot);
Ok(serde_json::to_string(&profile).unwrap())
}
#[napi]
pub fn calculate_market_impact(
&self,
snapshot_json: String,
order_size_usd: f64,
is_buy: bool,
) -> Result<String> {
let snapshot: crate::OrderBookSnapshot = serde_json::from_str(&snapshot_json)
.map_err(|e| Error::from_reason(format!("Failed to parse snapshot: {}", e)))?;
let profile = LiquidityProfile::from_snapshot(&snapshot);
let impact = profile.calculate_market_impact(order_size_usd, is_buy);
Ok(serde_json::to_string(&impact).unwrap())
}
}

View file

@ -0,0 +1,86 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
use chrono::{DateTime, Utc};
#[napi]
pub struct BacktestAPI {
core: Arc<crate::TradingCore>,
}
impl BacktestAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl BacktestAPI {
#[napi]
pub async fn configure(
&self,
start_date: String,
end_date: String,
symbols: Vec<String>,
initial_capital: f64,
commission: f64,
slippage: f64,
) -> Result<()> {
// Parse dates
let start = DateTime::parse_from_rfc3339(&start_date)
.map_err(|e| Error::from_reason(format!("Invalid start date: {}", e)))?
.with_timezone(&Utc);
let end = DateTime::parse_from_rfc3339(&end_date)
.map_err(|e| Error::from_reason(format!("Invalid end date: {}", e)))?
.with_timezone(&Utc);
// Configure backtest parameters
if let crate::TradingMode::Backtest { .. } = self.core.get_mode() {
// Update backtest configuration
todo!("Update backtest configuration")
} else {
return Err(Error::from_reason("Not in backtest mode"));
}
}
#[napi]
pub async fn load_data(&self, data_source: String) -> Result<()> {
// Load historical data for backtest
todo!("Load historical data")
}
#[napi]
pub async fn run(&self) -> Result<JsObject> {
// Run the backtest
if let crate::TradingMode::Backtest { .. } = self.core.get_mode() {
// Execute backtest
todo!("Execute backtest")
} else {
return Err(Error::from_reason("Not in backtest mode"));
}
}
#[napi]
pub fn get_progress(&self) -> Result<f64> {
// Get backtest progress (0.0 to 1.0)
todo!("Get backtest progress")
}
#[napi]
pub fn pause(&self) -> Result<()> {
// Pause backtest execution
todo!("Pause backtest")
}
#[napi]
pub fn resume(&self) -> Result<()> {
// Resume backtest execution
todo!("Resume backtest")
}
#[napi]
pub fn get_results(&self) -> Result<JsObject> {
// Get backtest results
todo!("Get backtest results")
}
}

View file

@ -0,0 +1,51 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
#[napi]
pub struct MarketDataAPI {
core: Arc<crate::TradingCore>,
}
impl MarketDataAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl MarketDataAPI {
#[napi]
pub async fn subscribe(&self, symbols: Vec<String>) -> Result<()> {
// Subscribe to market data for symbols
self.core.subscribe_market_data(symbols)
.await
.map_err(|e| Error::from_reason(e))
}
#[napi]
pub async fn unsubscribe(&self, symbols: Vec<String>) -> Result<()> {
// Unsubscribe from market data
self.core.unsubscribe_market_data(symbols)
.await
.map_err(|e| Error::from_reason(e))
}
#[napi]
pub fn get_latest_quote(&self, symbol: String) -> Result<JsObject> {
// Get latest quote for symbol
let quote = self.core.orderbook_manager
.get_best_bid_ask(&symbol)
.ok_or_else(|| Error::from_reason("No quote available"))?;
// Convert to JS object
// Note: In real implementation, would properly convert Quote to JsObject
todo!("Convert quote to JsObject")
}
#[napi]
pub fn get_latest_bar(&self, symbol: String) -> Result<JsObject> {
// Get latest bar for symbol
todo!("Get latest bar implementation")
}
}

View file

@ -0,0 +1,76 @@
use napi_derive::napi;
pub mod market_data;
pub mod orders;
pub mod positions;
pub mod backtest;
pub mod strategies;
pub mod system;
// Main API entry point
#[napi]
pub struct TradingAPI {
inner: std::sync::Arc<crate::TradingCore>,
}
#[napi]
impl TradingAPI {
#[napi(constructor)]
pub fn new(mode: String) -> napi::Result<Self> {
let trading_mode = parse_mode(&mode)?;
let core = crate::TradingCore::new(trading_mode)
.map_err(|e| napi::Error::from_reason(e))?;
Ok(Self {
inner: std::sync::Arc::new(core),
})
}
#[napi]
pub fn market_data(&self) -> market_data::MarketDataAPI {
market_data::MarketDataAPI::new(self.inner.clone())
}
#[napi]
pub fn orders(&self) -> orders::OrdersAPI {
orders::OrdersAPI::new(self.inner.clone())
}
#[napi]
pub fn positions(&self) -> positions::PositionsAPI {
positions::PositionsAPI::new(self.inner.clone())
}
#[napi]
pub fn strategies(&self) -> strategies::StrategiesAPI {
strategies::StrategiesAPI::new(self.inner.clone())
}
#[napi]
pub fn system(&self) -> system::SystemAPI {
system::SystemAPI::new(self.inner.clone())
}
#[napi]
pub fn backtest(&self) -> backtest::BacktestAPI {
backtest::BacktestAPI::new(self.inner.clone())
}
}
fn parse_mode(mode: &str) -> napi::Result<crate::TradingMode> {
match mode {
"backtest" => Ok(crate::TradingMode::Backtest {
start_time: chrono::Utc::now(),
end_time: chrono::Utc::now(),
speed_multiplier: 1.0,
}),
"paper" => Ok(crate::TradingMode::Paper {
starting_capital: 100_000.0,
}),
"live" => Ok(crate::TradingMode::Live {
broker: "default".to_string(),
account_id: "default".to_string(),
}),
_ => Err(napi::Error::from_reason(format!("Unknown mode: {}", mode))),
}
}

View file

@ -0,0 +1,87 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
#[napi]
pub struct OrdersAPI {
core: Arc<crate::TradingCore>,
}
impl OrdersAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl OrdersAPI {
#[napi]
pub async fn submit_order(
&self,
symbol: String,
side: String,
quantity: f64,
order_type: String,
limit_price: Option<f64>,
stop_price: Option<f64>,
) -> Result<String> {
let side = match side.as_str() {
"buy" => crate::Side::Buy,
"sell" => crate::Side::Sell,
_ => return Err(Error::from_reason("Invalid side")),
};
let order_type = match order_type.as_str() {
"market" => crate::OrderType::Market,
"limit" => {
let price = limit_price.ok_or_else(|| Error::from_reason("Limit price required"))?;
crate::OrderType::Limit { price }
}
"stop" => {
let price = stop_price.ok_or_else(|| Error::from_reason("Stop price required"))?;
crate::OrderType::Stop { stop_price: price }
}
"stop_limit" => {
let stop = stop_price.ok_or_else(|| Error::from_reason("Stop price required"))?;
let limit = limit_price.ok_or_else(|| Error::from_reason("Limit price required"))?;
crate::OrderType::StopLimit { stop_price: stop, limit_price: limit }
}
_ => return Err(Error::from_reason("Invalid order type")),
};
let order = crate::Order {
id: uuid::Uuid::new_v4().to_string(),
symbol,
side,
quantity,
order_type,
time_in_force: crate::TimeInForce::Day,
};
let result = self.core.execution_handler
.write()
.execute_order(order)
.await
.map_err(|e| Error::from_reason(e))?;
Ok(result.order_id)
}
#[napi]
pub async fn cancel_order(&self, order_id: String) -> Result<()> {
// Cancel order implementation
todo!("Cancel order implementation")
}
#[napi]
pub fn get_pending_orders(&self) -> Result<Vec<JsObject>> {
// Get pending orders
todo!("Get pending orders implementation")
}
#[napi]
pub fn get_order_history(&self) -> Result<Vec<JsObject>> {
// Get order history
todo!("Get order history implementation")
}
}

View file

@ -0,0 +1,67 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
#[napi]
pub struct PositionsAPI {
core: Arc<crate::TradingCore>,
}
impl PositionsAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl PositionsAPI {
#[napi]
pub fn get_position(&self, symbol: String) -> Result<JsObject> {
let position = self.core.position_tracker
.get_position(&symbol)
.ok_or_else(|| Error::from_reason("No position found"))?;
// Convert position to JsObject
todo!("Convert position to JsObject")
}
#[napi]
pub fn get_all_positions(&self) -> Result<Vec<JsObject>> {
let positions = self.core.position_tracker.get_all_positions();
// Convert positions to JsObjects
todo!("Convert positions to JsObjects")
}
#[napi]
pub fn get_closed_trades(&self) -> Result<Vec<JsObject>> {
let trades = self.core.position_tracker.get_closed_trades();
// Convert trades to JsObjects
todo!("Convert trades to JsObjects")
}
#[napi]
pub fn get_pnl(&self, symbol: Option<String>) -> Result<f64> {
if let Some(sym) = symbol {
// Get P&L for specific symbol
let position = self.core.position_tracker
.get_position(&sym)
.ok_or_else(|| Error::from_reason("No position found"))?;
Ok(position.realized_pnl + position.unrealized_pnl)
} else {
// Get total P&L
let positions = self.core.position_tracker.get_all_positions();
let total_pnl = positions.into_iter()
.map(|p| p.realized_pnl + p.unrealized_pnl)
.sum();
Ok(total_pnl)
}
}
#[napi]
pub fn get_portfolio_value(&self) -> Result<f64> {
// Calculate total portfolio value
todo!("Calculate portfolio value")
}
}

View file

@ -0,0 +1,76 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
#[napi]
pub struct StrategiesAPI {
core: Arc<crate::TradingCore>,
}
impl StrategiesAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl StrategiesAPI {
#[napi]
pub async fn add_strategy(
&self,
name: String,
strategy_type: String,
parameters: String,
) -> Result<String> {
// Parse parameters from JSON string
let params: serde_json::Value = serde_json::from_str(&parameters)
.map_err(|e| Error::from_reason(format!("Invalid parameters: {}", e)))?;
// Create strategy based on type
let strategy = match strategy_type.as_str() {
"sma_crossover" => {
// Create SMA crossover strategy
todo!("Create SMA strategy")
}
"momentum" => {
// Create momentum strategy
todo!("Create momentum strategy")
}
_ => return Err(Error::from_reason("Unknown strategy type")),
};
// Add strategy to core
todo!("Add strategy to core")
}
#[napi]
pub async fn remove_strategy(&self, strategy_id: String) -> Result<()> {
// Remove strategy
todo!("Remove strategy implementation")
}
#[napi]
pub fn get_strategies(&self) -> Result<Vec<JsObject>> {
// Get all active strategies
todo!("Get strategies implementation")
}
#[napi]
pub async fn update_strategy_parameters(
&self,
strategy_id: String,
parameters: String,
) -> Result<()> {
// Parse and update parameters
let params: serde_json::Value = serde_json::from_str(&parameters)
.map_err(|e| Error::from_reason(format!("Invalid parameters: {}", e)))?;
todo!("Update strategy parameters")
}
#[napi]
pub fn get_strategy_performance(&self, strategy_id: String) -> Result<JsObject> {
// Get performance metrics for strategy
todo!("Get strategy performance")
}
}

View file

@ -0,0 +1,77 @@
use napi_derive::napi;
use napi::bindgen_prelude::*;
use std::sync::Arc;
#[napi]
pub struct SystemAPI {
core: Arc<crate::TradingCore>,
}
impl SystemAPI {
pub fn new(core: Arc<crate::TradingCore>) -> Self {
Self { core }
}
}
#[napi]
impl SystemAPI {
#[napi]
pub async fn start(&self) -> Result<()> {
// Start the trading system
match self.core.get_mode() {
crate::TradingMode::Backtest { .. } => {
// Start backtest processing
todo!("Start backtest")
}
crate::TradingMode::Paper { .. } => {
// Start paper trading
todo!("Start paper trading")
}
crate::TradingMode::Live { .. } => {
// Start live trading
todo!("Start live trading")
}
}
}
#[napi]
pub async fn stop(&self) -> Result<()> {
// Stop the trading system
todo!("Stop trading system")
}
#[napi]
pub fn get_mode(&self) -> String {
match self.core.get_mode() {
crate::TradingMode::Backtest { .. } => "backtest".to_string(),
crate::TradingMode::Paper { .. } => "paper".to_string(),
crate::TradingMode::Live { .. } => "live".to_string(),
}
}
#[napi]
pub fn get_current_time(&self) -> String {
self.core.get_time().to_rfc3339()
}
#[napi]
pub fn set_risk_limits(&self, limits: String) -> Result<()> {
// Parse and set risk limits
let limits: serde_json::Value = serde_json::from_str(&limits)
.map_err(|e| Error::from_reason(format!("Invalid limits: {}", e)))?;
todo!("Set risk limits")
}
#[napi]
pub fn get_risk_metrics(&self) -> Result<JsObject> {
// Get current risk metrics
todo!("Get risk metrics")
}
#[napi]
pub fn get_analytics(&self) -> Result<JsObject> {
// Get trading analytics
todo!("Get analytics")
}
}

View file

@ -0,0 +1,715 @@
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use crate::{
TradingMode, MarketDataSource, ExecutionHandler, TimeProvider,
MarketUpdate, MarketDataType, Order, Fill, Side,
positions::PositionTracker,
risk::RiskEngine,
orderbook::OrderBookManager,
};
use super::{
BacktestConfig, BacktestState, EventQueue, BacktestEvent, EventType,
Strategy, Signal, SignalType, BacktestResult, TradeTracker,
};
pub struct BacktestEngine {
config: BacktestConfig,
state: Arc<RwLock<BacktestState>>,
event_queue: Arc<RwLock<EventQueue>>,
strategies: Arc<RwLock<Vec<Box<dyn Strategy>>>>,
// Core components
position_tracker: Arc<PositionTracker>,
risk_engine: Arc<RiskEngine>,
orderbook_manager: Arc<OrderBookManager>,
time_provider: Arc<Box<dyn TimeProvider>>,
pub market_data_source: Arc<RwLock<Box<dyn MarketDataSource>>>,
execution_handler: Arc<RwLock<Box<dyn ExecutionHandler>>>,
// Metrics
total_trades: usize,
profitable_trades: usize,
total_pnl: f64,
// Price tracking - single source of truth
// Maps symbol -> (timestamp, price)
last_prices: HashMap<String, (DateTime<Utc>, f64)>,
// Trade tracking
trade_tracker: TradeTracker,
}
impl BacktestEngine {
pub fn new(
config: BacktestConfig,
mode: TradingMode,
time_provider: Box<dyn TimeProvider>,
market_data_source: Box<dyn MarketDataSource>,
execution_handler: Box<dyn ExecutionHandler>,
) -> Self {
let state = Arc::new(RwLock::new(
BacktestState::new(config.initial_capital, config.start_time)
));
Self {
config,
state,
event_queue: Arc::new(RwLock::new(EventQueue::new())),
strategies: Arc::new(RwLock::new(Vec::new())),
position_tracker: Arc::new(PositionTracker::new()),
risk_engine: Arc::new(RiskEngine::new()),
orderbook_manager: Arc::new(OrderBookManager::new()),
time_provider: Arc::new(time_provider),
market_data_source: Arc::new(RwLock::new(market_data_source)),
execution_handler: Arc::new(RwLock::new(execution_handler)),
total_trades: 0,
profitable_trades: 0,
total_pnl: 0.0,
last_prices: HashMap::new(),
trade_tracker: TradeTracker::new(),
}
}
pub fn add_strategy(&mut self, strategy: Box<dyn Strategy>) {
self.strategies.write().push(strategy);
}
pub async fn run(&mut self) -> Result<BacktestResult, String> {
eprintln!("=== BacktestEngine::run() START ===");
eprintln!("Config: start={}, end={}, symbols={:?}",
self.config.start_time, self.config.end_time, self.config.symbols);
eprintln!("Number of strategies loaded: {}", self.strategies.read().len());
// Initialize start time
if let Some(simulated_time) = self.time_provider.as_any()
.downcast_ref::<crate::core::time_providers::SimulatedTime>()
{
simulated_time.advance_to(self.config.start_time);
eprintln!("Time initialized to: {}", self.config.start_time);
}
// Load market data
eprintln!("Loading market data from data source...");
self.load_market_data().await?;
let queue_len = self.event_queue.read().len();
eprintln!("Event queue length after loading: {}", queue_len);
if queue_len == 0 {
eprintln!("WARNING: No events loaded! Check data source.");
}
// Main event loop - process events grouped by timestamp
let mut iteration = 0;
let mut last_update_time = self.config.start_time;
while !self.event_queue.read().is_empty() {
iteration += 1;
// Get the next event's timestamp
let next_event_time = self.event_queue.read()
.peek_next()
.map(|e| e.timestamp);
if let Some(event_time) = next_event_time {
// Advance time to the next event
self.advance_time(event_time);
// Get all events at this timestamp
let current_time = self.time_provider.now();
let events = self.event_queue.write().pop_until(current_time);
if iteration <= 5 || iteration % 100 == 0 {
eprintln!("Processing iteration {} at time {} with {} events",
iteration, current_time, events.len());
}
// Process all events at this timestamp
for event in events {
self.process_event(event).await?;
}
// Only update portfolio value if time has actually advanced
// This ensures we have prices for all symbols at this timestamp
if current_time > last_update_time {
self.update_portfolio_value();
last_update_time = current_time;
}
} else {
// No more events
break;
}
}
eprintln!("Backtest complete. Total trades: {}", self.total_trades);
// Close all open positions at market prices
self.close_all_positions().await?;
// Generate results
Ok(self.generate_results())
}
async fn load_market_data(&mut self) -> Result<(), String> {
eprintln!("=== load_market_data START ===");
let mut data_source = self.market_data_source.write();
// Check if it's a HistoricalDataSource
if let Some(historical) = data_source.as_any()
.downcast_ref::<crate::core::market_data_sources::HistoricalDataSource>() {
eprintln!("Data source is HistoricalDataSource");
eprintln!("Historical data points available: {}", historical.data_len());
} else {
eprintln!("WARNING: Data source is NOT HistoricalDataSource!");
}
eprintln!("Seeking to start time: {}", self.config.start_time);
data_source.seek_to_time(self.config.start_time)?;
let mut count = 0;
let mut first_few = 0;
// Load all data into event queue
while let Some(update) = data_source.get_next_update().await {
if update.timestamp > self.config.end_time {
eprintln!("Reached end time at {} data points", count);
break;
}
count += 1;
// Log first few data points
if first_few < 3 {
eprintln!("Data point {}: symbol={}, time={}, type={:?}",
count, update.symbol, update.timestamp,
match &update.data {
MarketDataType::Bar(b) => format!("Bar(close={})", b.close),
MarketDataType::Quote(q) => format!("Quote(bid={}, ask={})", q.bid, q.ask),
MarketDataType::Trade(t) => format!("Trade(price={})", t.price),
}
);
first_few += 1;
}
if count % 100 == 0 {
eprintln!("Loaded {} data points so far...", count);
}
let event = BacktestEvent::market_data(update.timestamp, update);
self.event_queue.write().push(event);
}
eprintln!("=== load_market_data COMPLETE ===");
eprintln!("Total data points loaded: {}", count);
Ok(())
}
async fn process_event(&mut self, event: BacktestEvent) -> Result<(), String> {
match event.event_type {
EventType::MarketData(data) => {
self.process_market_data(data).await?;
}
EventType::OrderSubmitted(order) => {
self.process_order_submission(order).await?;
}
EventType::OrderFilled(_fill) => {
// Fills are already processed when orders are executed
// This event is just for recording
// Note: We now record fills in process_fill with symbol info
}
EventType::OrderCancelled(order_id) => {
self.process_order_cancellation(&order_id)?;
}
EventType::TimeUpdate(time) => {
self.advance_time(time);
}
}
Ok(())
}
async fn process_market_data(&mut self, data: MarketUpdate) -> Result<(), String> {
static mut MARKET_DATA_COUNT: usize = 0;
unsafe {
MARKET_DATA_COUNT += 1;
if MARKET_DATA_COUNT <= 3 || MARKET_DATA_COUNT % 100 == 0 {
eprintln!("process_market_data #{}: symbol={}, time={}",
MARKET_DATA_COUNT, data.symbol, data.timestamp);
}
}
// Update price tracking - single source of truth
let price = match &data.data {
MarketDataType::Bar(bar) => {
let old_entry = self.last_prices.get(&data.symbol);
let old_price = old_entry.map(|(_, p)| *p);
eprintln!("📊 PRICE UPDATE: {} @ {} - close: ${:.2} (was: ${:?})",
data.symbol, data.timestamp.format("%Y-%m-%d"), bar.close, old_price);
bar.close
}
MarketDataType::Quote(quote) => {
// Use mid price for quotes
(quote.bid + quote.ask) / 2.0
}
MarketDataType::Trade(trade) => {
trade.price
}
};
// Store price with timestamp - this is our source of truth
self.last_prices.insert(data.symbol.clone(), (data.timestamp, price));
// Convert to simpler MarketData for strategies
let market_data = self.convert_to_market_data(&data);
// Send to strategies
let mut all_signals = Vec::new();
{
let mut strategies = self.strategies.write();
for (i, strategy) in strategies.iter_mut().enumerate() {
let signals = strategy.on_market_data(&market_data);
if !signals.is_empty() {
eprintln!("Strategy {} generated {} signals!", i, signals.len());
}
all_signals.extend(signals);
}
}
// Process signals
for signal in all_signals {
eprintln!("Processing signal: {:?}", signal);
self.process_signal(signal).await?;
}
// Check pending orders for fills
self.check_pending_orders(&data).await?;
// Don't update portfolio value here - wait until all events at this timestamp are processed
Ok(())
}
fn convert_to_market_data(&self, update: &MarketUpdate) -> MarketUpdate {
// MarketData is a type alias for MarketUpdate
update.clone()
}
async fn process_signal(&mut self, signal: Signal) -> Result<(), String> {
let current_time = self.time_provider.now();
eprintln!("📡 SIGNAL at {}: {:?} {} (strength: {}, reason: {:?})",
current_time.format("%Y-%m-%d"),
signal.signal_type,
signal.symbol,
signal.strength,
signal.reason);
// Only process strong signals
if signal.strength.abs() < 0.7 {
eprintln!(" Signal ignored (strength < 0.7)");
return Ok(());
}
// Check current price before creating order
if let Some((price_time, price)) = self.last_prices.get(&signal.symbol) {
eprintln!(" Current price for {}: ${:.2} (from {})",
signal.symbol, price, price_time.format("%Y-%m-%d"));
}
// Convert signal to order
let order = self.signal_to_order(signal)?;
eprintln!(" Creating {:?} order for {} shares", order.side, order.quantity);
// Submit order
self.process_order_submission(order).await
}
fn signal_to_order(&self, signal: Signal) -> Result<Order, String> {
let quantity = signal.quantity.unwrap_or_else(|| {
// Calculate position size based on portfolio
self.calculate_position_size(&signal.symbol, signal.strength)
});
let side = match signal.signal_type {
SignalType::Buy => Side::Buy,
SignalType::Sell => Side::Sell,
SignalType::Close => {
// Determine side based on current position
let position = self.position_tracker.get_position(&signal.symbol);
if position.as_ref().map(|p| p.quantity > 0.0).unwrap_or(false) {
Side::Sell
} else {
Side::Buy
}
}
};
Ok(crate::Order {
id: format!("order_{}", uuid::Uuid::new_v4()),
symbol: signal.symbol,
side,
quantity,
order_type: crate::OrderType::Market,
time_in_force: crate::TimeInForce::Day,
})
}
async fn process_order_submission(&mut self, order: Order) -> Result<(), String> {
// Risk checks
// Get current position for the symbol
let current_position = self.position_tracker
.get_position(&order.symbol)
.map(|p| p.quantity);
let risk_check = self.risk_engine.check_order(&order, current_position);
if !risk_check.passed {
return Err(format!("Risk check failed: {:?}", risk_check.violations));
}
// Add to pending orders
self.state.write().add_pending_order(order.clone());
// For market orders in backtesting, fill immediately
if matches!(order.order_type, crate::OrderType::Market) {
self.check_order_fill(&order).await?;
}
Ok(())
}
async fn check_pending_orders(&mut self, market_data: &MarketUpdate) -> Result<(), String> {
let orders_to_check: Vec<Order> = {
let state = self.state.read();
state.pending_orders.values()
.filter(|o| o.symbol == market_data.symbol)
.cloned()
.collect()
};
for order in orders_to_check {
self.check_order_fill(&order).await?;
}
Ok(())
}
async fn check_order_fill(&mut self, order: &Order) -> Result<(), String> {
let current_time = self.time_provider.now();
// Get current market price - only use if it's from the current time
let (price_time, base_price) = self.last_prices.get(&order.symbol)
.copied()
.ok_or_else(|| format!("No price available for symbol: {}", order.symbol))?;
// CRITICAL: Verify the price is from the current time
if price_time != current_time {
eprintln!("⚠️ WARNING: Price timestamp mismatch! Current: {}, Price from: {}",
current_time.format("%Y-%m-%d %H:%M:%S"),
price_time.format("%Y-%m-%d %H:%M:%S"));
// In a real system, we would reject this fill or fetch current price
// For now, log the issue
}
eprintln!("🔍 CHECK_ORDER_FILL: {:?} {} @ time {} - price: ${:.2} (from {})",
order.side, order.symbol, current_time.format("%Y-%m-%d"),
base_price, price_time.format("%Y-%m-%d"));
// DEBUG: Check what's in last_prices for this symbol
eprintln!(" DEBUG: All prices for {}: {:?}",
order.symbol,
self.last_prices.get(&order.symbol));
// Apply slippage
let fill_price = match order.side {
crate::Side::Buy => base_price * (1.0 + self.config.slippage),
crate::Side::Sell => base_price * (1.0 - self.config.slippage),
};
eprintln!(" Fill price after slippage ({}): ${:.2}", self.config.slippage, fill_price);
// Create fill
let fill = crate::Fill {
timestamp: current_time,
price: fill_price,
quantity: order.quantity,
commission: order.quantity * fill_price * self.config.commission,
};
// Process the fill
self.process_fill(&order, fill).await
}
async fn process_fill(&mut self, order: &crate::Order, fill: crate::Fill) -> Result<(), String> {
// Remove from pending orders
self.state.write().remove_pending_order(&order.id);
// Get position before the fill
let position_before = self.position_tracker
.get_position(&order.symbol)
.map(|p| p.quantity)
.unwrap_or(0.0);
// Update positions
let update = self.position_tracker.process_fill(
&order.symbol,
&fill,
order.side,
);
// Calculate P&L if position was reduced/closed
let pnl = if update.resulting_position.realized_pnl != 0.0 {
Some(update.resulting_position.realized_pnl)
} else {
None
};
// Get position after this fill
let position_after = update.resulting_position.quantity;
// Record the fill with position and P&L information
self.state.write().record_fill(
order.symbol.clone(),
order.side,
fill.clone(),
position_after,
position_before,
pnl
);
// Track trades
self.trade_tracker.process_fill(&order.symbol, order.side, &fill);
// Update cash
let cash_change = match order.side {
crate::Side::Buy => -(fill.quantity * fill.price + fill.commission),
crate::Side::Sell => fill.quantity * fill.price - fill.commission,
};
self.state.write().cash += cash_change;
// Notify strategies
{
let mut strategies = self.strategies.write();
for strategy in strategies.iter_mut() {
let side_str = match order.side {
crate::Side::Buy => "buy",
crate::Side::Sell => "sell",
};
strategy.on_fill(&order.symbol, fill.quantity, fill.price, side_str);
}
}
let fill_date = fill.timestamp.format("%Y-%m-%d").to_string();
let is_feb_mar_2024 = fill_date >= "2024-02-28".to_string() && fill_date <= "2024-03-05".to_string();
if is_feb_mar_2024 {
eprintln!("
🔴 CRITICAL FILL on {}: {} {} @ {} (side: {:?})",
fill_date, fill.quantity, order.symbol, fill.price, order.side);
eprintln!("Cash before: ${:.2}, Cash after: ${:.2}, Cash change: ${:.2}",
self.state.read().cash - cash_change, self.state.read().cash, cash_change);
}
eprintln!("Fill processed: {} {} @ {} (side: {:?})",
fill.quantity, order.symbol, fill.price, order.side);
eprintln!("Current position after fill: {}",
self.position_tracker.get_position(&order.symbol)
.map(|p| p.quantity)
.unwrap_or(0.0));
// Update metrics
self.total_trades += 1;
if update.resulting_position.realized_pnl > 0.0 {
self.profitable_trades += 1;
}
self.total_pnl = update.resulting_position.realized_pnl;
Ok(())
}
fn process_order_cancellation(&mut self, order_id: &str) -> Result<(), String> {
self.state.write().remove_pending_order(order_id);
Ok(())
}
fn advance_time(&mut self, time: DateTime<Utc>) {
if let Some(simulated_time) = self.time_provider.as_any()
.downcast_ref::<crate::core::time_providers::SimulatedTime>()
{
simulated_time.advance_to(time);
}
self.state.write().current_time = time;
}
fn update_portfolio_value(&mut self) {
let positions = self.position_tracker.get_all_positions();
let cash = self.state.read().cash;
let mut portfolio_value = cash;
let current_time = self.time_provider.now();
// Debug logging for first few updates
static mut UPDATE_COUNT: usize = 0;
unsafe {
UPDATE_COUNT += 1;
if UPDATE_COUNT <= 5 || UPDATE_COUNT % 100 == 0 ||
// Log around Feb 28 - Mar 5, 2024
(current_time.format("%Y-%m-%d").to_string() >= "2024-02-28".to_string() &&
current_time.format("%Y-%m-%d").to_string() <= "2024-03-05".to_string()) {
eprintln!("=== Portfolio Update #{} at {} ===", UPDATE_COUNT, current_time);
eprintln!("Cash: ${:.2}", cash);
}
}
for position in &positions {
// Use last known price for the symbol
let price = self.last_prices.get(&position.symbol)
.map(|(_, p)| *p)
.unwrap_or(position.average_price);
// Calculate market value correctly for long and short positions
let market_value = if position.quantity > 0.0 {
// Long position: value = quantity * current_price
position.quantity * price
} else {
// Short position:
// We have a liability to buy back shares at current market price
// This is a negative value that reduces portfolio value
// Value = quantity * price (quantity is already negative)
position.quantity * price
};
portfolio_value += market_value;
unsafe {
if UPDATE_COUNT <= 5 || UPDATE_COUNT % 100 == 0 ||
// Log around Feb 28 - Mar 5, 2024
(current_time.format("%Y-%m-%d").to_string() >= "2024-02-28".to_string() &&
current_time.format("%Y-%m-%d").to_string() <= "2024-03-05".to_string()) {
let pnl = if position.quantity > 0.0 {
(price - position.average_price) * position.quantity
} else {
(position.average_price - price) * position.quantity.abs()
};
let position_type = if position.quantity > 0.0 { "LONG" } else { "SHORT" };
eprintln!(" {} {} position: {} shares @ avg ${:.2}, current ${:.2} = ${:.2} (P&L: ${:.2})",
position_type, position.symbol, position.quantity, position.average_price, price, market_value, pnl);
}
}
}
unsafe {
if UPDATE_COUNT <= 5 || UPDATE_COUNT % 100 == 0 ||
// Log around Feb 28 - Mar 5, 2024
(current_time.format("%Y-%m-%d").to_string() >= "2024-02-28".to_string() &&
current_time.format("%Y-%m-%d").to_string() <= "2024-03-05".to_string()) {
eprintln!("Total Portfolio Value: ${:.2}", portfolio_value);
eprintln!("===================================");
}
}
self.state.write().update_portfolio_value(portfolio_value);
}
fn calculate_position_size(&self, symbol: &str, signal_strength: f64) -> f64 {
let state = self.state.read();
let portfolio_value = state.portfolio_value;
let cash = state.cash;
// Use available cash, not total portfolio value for position sizing
let allocation = 0.2; // 20% of available cash per position
let position_value = cash.min(portfolio_value * allocation) * signal_strength.abs();
let price = self.last_prices.get(symbol)
.map(|(_, p)| *p)
.unwrap_or(100.0);
let shares = (position_value / price).floor();
eprintln!("Position sizing for {}: portfolio=${:.2}, cash=${:.2}, price=${:.2}, shares={}",
symbol, portfolio_value, cash, price, shares);
shares
}
fn get_next_event_time(&self) -> Option<DateTime<Utc>> {
// Get the timestamp of the next event in the queue
self.event_queue.read()
.peek_next()
.map(|event| event.timestamp)
}
async fn close_all_positions(&mut self) -> Result<(), String> {
eprintln!("=== Closing all open positions at end of backtest ===");
eprintln!("Current time: {}", self.time_provider.now());
eprintln!("Last prices:");
for (symbol, (time, price)) in &self.last_prices {
eprintln!(" {}: ${:.2} (from {})", symbol, price, time.format("%Y-%m-%d %H:%M:%S"));
}
let positions = self.position_tracker.get_all_positions();
for position in positions {
if position.quantity.abs() > 0.001 {
let last_price = self.last_prices.get(&position.symbol).map(|(_, p)| *p);
eprintln!("Closing position: {} {} shares of {} at last price: {:?}",
if position.quantity > 0.0 { "Selling" } else { "Buying" },
position.quantity.abs(),
position.symbol,
last_price
);
// Create market order to close position
let order = crate::Order {
id: format!("close_{}", uuid::Uuid::new_v4()),
symbol: position.symbol.clone(),
side: if position.quantity > 0.0 { Side::Sell } else { Side::Buy },
quantity: position.quantity.abs(),
order_type: crate::OrderType::Market,
time_in_force: crate::TimeInForce::Day,
};
// Process the closing order
self.check_order_fill(&order).await?;
}
}
eprintln!("All positions closed. Final cash: {}", self.state.read().cash);
Ok(())
}
fn generate_results(&self) -> BacktestResult {
let state = self.state.read();
let start_time = self.config.start_time;
// Get final positions
let final_positions = self.position_tracker.get_all_positions()
.into_iter()
.map(|p| (p.symbol.clone(), p))
.collect();
// Get completed trades from tracker
let completed_trades = self.trade_tracker.get_completed_trades();
// Convert last_prices to simple HashMap for results
let simple_last_prices: HashMap<String, f64> = self.last_prices
.iter()
.map(|(symbol, (_, price))| (symbol.clone(), *price))
.collect();
// Use simple results builder with proper trade data
BacktestResult::from_engine_data_with_trades(
self.config.clone(),
state.equity_curve.clone(),
state.completed_trades.clone(),
completed_trades,
final_positions,
start_time,
&simple_last_prices,
)
}
}
// Add uuid dependency
use uuid::Uuid;

View file

@ -0,0 +1,55 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use crate::{MarketUpdate, Order, Fill};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EventType {
MarketData(MarketUpdate),
OrderSubmitted(Order),
OrderFilled(Fill),
OrderCancelled(String), // order_id
TimeUpdate(DateTime<Utc>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BacktestEvent {
pub timestamp: DateTime<Utc>,
pub event_type: EventType,
}
impl BacktestEvent {
pub fn market_data(timestamp: DateTime<Utc>, data: MarketUpdate) -> Self {
Self {
timestamp,
event_type: EventType::MarketData(data),
}
}
pub fn order_submitted(timestamp: DateTime<Utc>, order: Order) -> Self {
Self {
timestamp,
event_type: EventType::OrderSubmitted(order),
}
}
pub fn order_filled(timestamp: DateTime<Utc>, fill: Fill) -> Self {
Self {
timestamp,
event_type: EventType::OrderFilled(fill),
}
}
pub fn order_cancelled(timestamp: DateTime<Utc>, order_id: String) -> Self {
Self {
timestamp,
event_type: EventType::OrderCancelled(order_id),
}
}
pub fn time_update(timestamp: DateTime<Utc>) -> Self {
Self {
timestamp,
event_type: EventType::TimeUpdate(timestamp),
}
}
}

View file

@ -0,0 +1,669 @@
use chrono::{DateTime, Utc, Datelike};
use std::collections::HashMap;
use super::results::*;
use super::{BacktestConfig, CompletedTrade};
use super::trade_tracker::CompletedTrade as TrackedTrade;
use crate::{Position, Side};
pub struct MetricsCalculator {
config: BacktestConfig,
equity_curve: Vec<(DateTime<Utc>, f64)>,
trades: Vec<TrackedTrade>,
fills: Vec<CompletedTrade>,
positions: HashMap<String, Position>,
}
impl MetricsCalculator {
pub fn new(
config: BacktestConfig,
equity_curve: Vec<(DateTime<Utc>, f64)>,
trades: Vec<TrackedTrade>,
fills: Vec<CompletedTrade>,
positions: HashMap<String, Position>,
) -> Self {
Self {
config,
equity_curve,
trades,
fills,
positions,
}
}
pub fn calculate_all_metrics(&self) -> BacktestResult {
let start_time = Utc::now();
let backtest_id = format!("rust-{}", start_time.timestamp_millis());
// Calculate comprehensive metrics
let metrics = self.calculate_metrics();
// Convert equity curve
let equity_points: Vec<EquityPoint> = self.equity_curve.iter()
.map(|(dt, value)| EquityPoint {
timestamp: dt.timestamp_millis(),
value: *value,
})
.collect();
// Calculate drawdown curve
let drawdown_curve = self.calculate_drawdown_curve();
// Calculate returns
let daily_returns = self.calculate_daily_returns();
let cumulative_returns = self.calculate_cumulative_returns();
// Convert trades to UI format
let trades = self.convert_fills_to_trades();
let open_trades = self.get_open_trades();
// Period analysis
let monthly_returns = self.calculate_period_returns("monthly");
let yearly_returns = self.calculate_period_returns("yearly");
// Symbol analysis
let symbol_analysis = self.calculate_symbol_analysis();
// Drawdown periods
let drawdown_periods = self.analyze_drawdown_periods();
// Exposure analysis
let exposure_analysis = self.calculate_exposure_analysis();
// Position history
let position_history = self.calculate_position_history();
// Trade signals (extracted from trades)
let trade_signals = self.extract_trade_signals();
BacktestResult {
backtest_id,
status: "completed".to_string(),
started_at: self.config.start_time.to_rfc3339(),
completed_at: Utc::now().to_rfc3339(),
execution_time_ms: (Utc::now() - start_time).num_milliseconds() as u64,
config: self.config.clone(),
metrics,
equity_curve: equity_points,
drawdown_curve,
daily_returns,
cumulative_returns,
trades,
open_trades,
trade_signals,
positions: self.positions.clone(),
position_history,
monthly_returns,
yearly_returns,
symbol_analysis,
drawdown_periods,
exposure_analysis,
ohlc_data: HashMap::new(), // Will be filled by orchestrator
error: None,
warnings: None,
}
}
fn calculate_metrics(&self) -> BacktestMetrics {
let initial_capital = self.config.initial_capital;
let final_capital = self.equity_curve.last().map(|(_, v)| *v).unwrap_or(initial_capital);
// Core performance
let total_return = final_capital - initial_capital;
let total_return_pct = (total_return / initial_capital) * 100.0;
let days = (self.config.end_time - self.config.start_time).num_days() as f64;
let years = days / 365.25;
let annualized_return = if years > 0.0 {
((final_capital / initial_capital).powf(1.0 / years) - 1.0) * 100.0
} else {
0.0
};
// Trade statistics
let total_trades = self.trades.len();
let winning_trades = self.trades.iter().filter(|t| t.pnl > 0.0).count();
let losing_trades = self.trades.iter().filter(|t| t.pnl < 0.0).count();
let breakeven_trades = self.trades.iter().filter(|t| t.pnl == 0.0).count();
let win_rate = if total_trades > 0 {
(winning_trades as f64 / total_trades as f64) * 100.0
} else {
0.0
};
let loss_rate = if total_trades > 0 {
(losing_trades as f64 / total_trades as f64) * 100.0
} else {
0.0
};
// PnL calculations
let total_pnl = self.trades.iter().map(|t| t.pnl).sum::<f64>();
let total_commission = self.trades.iter().map(|t| t.commission).sum::<f64>();
let net_pnl = total_pnl - total_commission;
let wins: Vec<f64> = self.trades.iter().filter(|t| t.pnl > 0.0).map(|t| t.pnl).collect();
let losses: Vec<f64> = self.trades.iter().filter(|t| t.pnl < 0.0).map(|t| t.pnl).collect();
let avg_win = if !wins.is_empty() {
wins.iter().sum::<f64>() / wins.len() as f64
} else {
0.0
};
let avg_loss = if !losses.is_empty() {
losses.iter().sum::<f64>() / losses.len() as f64
} else {
0.0
};
let largest_win = wins.iter().cloned().fold(0.0, f64::max);
let largest_loss = losses.iter().cloned().fold(0.0, f64::min);
let avg_trade_pnl = if total_trades > 0 {
total_pnl / total_trades as f64
} else {
0.0
};
// Risk metrics
let returns = self.calculate_returns_series();
let volatility = self.calculate_volatility(&returns);
let sharpe_ratio = self.calculate_sharpe_ratio(&returns, volatility);
let sortino_ratio = self.calculate_sortino_ratio(&returns);
let (max_drawdown, max_drawdown_pct, max_dd_duration) = self.calculate_max_drawdown();
let calmar_ratio = if max_drawdown_pct != 0.0 {
annualized_return / max_drawdown_pct.abs()
} else {
0.0
};
// Value at Risk
let (var_95, cvar_95) = self.calculate_var_cvar(&returns, 0.95);
// Profit factor
let gross_profit = wins.iter().sum::<f64>();
let gross_loss = losses.iter().map(|l| l.abs()).sum::<f64>();
let profit_factor = if gross_loss > 0.0 {
gross_profit / gross_loss
} else if gross_profit > 0.0 {
f64::INFINITY
} else {
0.0
};
// Expectancy
let expectancy = avg_win * (win_rate / 100.0) - avg_loss.abs() * (loss_rate / 100.0);
let expectancy_ratio = if avg_loss != 0.0 {
expectancy / avg_loss.abs()
} else {
0.0
};
// Position statistics
let avg_trade_duration_hours = if total_trades > 0 {
self.trades.iter()
.map(|t| t.duration_seconds as f64 / 3600.0)
.sum::<f64>() / total_trades as f64
} else {
0.0
};
let total_long_trades = self.trades.iter().filter(|t| matches!(t.side, Side::Buy)).count();
let total_short_trades = self.trades.iter().filter(|t| matches!(t.side, Side::Sell)).count();
let long_wins = self.trades.iter()
.filter(|t| matches!(t.side, Side::Buy) && t.pnl > 0.0)
.count();
let short_wins = self.trades.iter()
.filter(|t| matches!(t.side, Side::Sell) && t.pnl > 0.0)
.count();
let long_win_rate = if total_long_trades > 0 {
(long_wins as f64 / total_long_trades as f64) * 100.0
} else {
0.0
};
let short_win_rate = if total_short_trades > 0 {
(short_wins as f64 / total_short_trades as f64) * 100.0
} else {
0.0
};
// Trading activity
let total_volume_traded = self.trades.iter()
.map(|t| t.quantity * t.entry_price + t.quantity * t.exit_price)
.sum::<f64>();
let commission_pct_of_pnl = if total_pnl != 0.0 {
(total_commission / total_pnl.abs()) * 100.0
} else {
0.0
};
// Calculate additional metrics
let downside_deviation = self.calculate_downside_deviation(&returns);
let win_loss_ratio = if avg_loss != 0.0 {
avg_win / avg_loss.abs()
} else {
0.0
};
let kelly_criterion = if win_loss_ratio > 0.0 {
(win_rate / 100.0) - ((1.0 - win_rate / 100.0) / win_loss_ratio)
} else {
0.0
};
// Streaks
let (max_wins, max_losses, current_streak) = self.calculate_streaks();
BacktestMetrics {
// Core Performance
total_return,
total_return_pct,
annualized_return,
total_pnl,
net_pnl,
// Risk Metrics
sharpe_ratio,
sortino_ratio,
calmar_ratio,
max_drawdown,
max_drawdown_pct,
max_drawdown_duration_days: max_dd_duration,
volatility,
downside_deviation,
var_95,
cvar_95,
// Trade Statistics
total_trades,
winning_trades,
losing_trades,
breakeven_trades,
win_rate,
loss_rate,
avg_win,
avg_loss,
largest_win,
largest_loss,
avg_trade_pnl,
avg_trade_duration_hours,
profit_factor,
expectancy,
expectancy_ratio,
// Position Statistics
avg_position_size: 0.0, // TODO: Calculate from position history
max_position_size: 0.0,
avg_positions_held: 0.0,
max_concurrent_positions: 0,
total_long_trades,
total_short_trades,
long_win_rate,
short_win_rate,
// Trading Activity
total_volume_traded,
total_commission_paid: total_commission,
commission_pct_of_pnl,
avg_daily_trades: 0.0, // TODO: Calculate
max_daily_trades: 0,
trading_days: 0,
exposure_time_pct: 0.0,
// Efficiency Metrics
return_over_max_dd: if max_drawdown != 0.0 {
total_return / max_drawdown.abs()
} else {
0.0
},
win_loss_ratio,
kelly_criterion,
ulcer_index: 0.0, // TODO: Calculate
serenity_ratio: 0.0,
lake_ratio: 0.0,
// Monthly/Period Analysis
best_month_return: 0.0, // TODO: Calculate from monthly returns
worst_month_return: 0.0,
positive_months: 0,
negative_months: 0,
monthly_win_rate: 0.0,
avg_monthly_return: 0.0,
// Streak Analysis
max_consecutive_wins: max_wins,
max_consecutive_losses: max_losses,
current_streak,
avg_winning_streak: 0.0, // TODO: Calculate
avg_losing_streak: 0.0,
}
}
fn calculate_returns_series(&self) -> Vec<f64> {
let mut returns = Vec::new();
for i in 1..self.equity_curve.len() {
let prev_value = self.equity_curve[i - 1].1;
let curr_value = self.equity_curve[i].1;
if prev_value > 0.0 {
returns.push((curr_value - prev_value) / prev_value);
}
}
returns
}
fn calculate_volatility(&self, returns: &[f64]) -> f64 {
if returns.is_empty() {
return 0.0;
}
let mean = returns.iter().sum::<f64>() / returns.len() as f64;
let variance = returns.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / returns.len() as f64;
// Annualize the volatility
variance.sqrt() * (252.0_f64).sqrt()
}
fn calculate_sharpe_ratio(&self, returns: &[f64], volatility: f64) -> f64 {
if volatility == 0.0 || returns.is_empty() {
return 0.0;
}
let mean_return = returns.iter().sum::<f64>() / returns.len() as f64;
let annualized_return = mean_return * 252.0;
let risk_free_rate = 0.02; // 2% annual risk-free rate
(annualized_return - risk_free_rate) / volatility
}
fn calculate_sortino_ratio(&self, returns: &[f64]) -> f64 {
if returns.is_empty() {
return 0.0;
}
let mean_return = returns.iter().sum::<f64>() / returns.len() as f64;
let downside_deviation = self.calculate_downside_deviation(returns);
if downside_deviation == 0.0 {
return 0.0;
}
let annualized_return = mean_return * 252.0;
let risk_free_rate = 0.02;
(annualized_return - risk_free_rate) / downside_deviation
}
fn calculate_downside_deviation(&self, returns: &[f64]) -> f64 {
let negative_returns: Vec<f64> = returns.iter()
.filter(|&&r| r < 0.0)
.cloned()
.collect();
if negative_returns.is_empty() {
return 0.0;
}
let mean = negative_returns.iter().sum::<f64>() / negative_returns.len() as f64;
let variance = negative_returns.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / negative_returns.len() as f64;
variance.sqrt() * (252.0_f64).sqrt()
}
fn calculate_max_drawdown(&self) -> (f64, f64, i64) {
let mut max_drawdown = 0.0;
let mut max_drawdown_pct = 0.0;
let mut max_duration = 0i64;
let mut peak = self.config.initial_capital;
let mut peak_date = self.config.start_time;
for (date, value) in &self.equity_curve {
if *value > peak {
peak = *value;
peak_date = *date;
}
let drawdown = peak - value;
let drawdown_pct = (drawdown / peak) * 100.0;
if drawdown > max_drawdown {
max_drawdown = drawdown;
max_drawdown_pct = drawdown_pct;
}
if drawdown > 0.0 {
let duration = (*date - peak_date).num_days();
if duration > max_duration {
max_duration = duration;
}
}
}
(max_drawdown, max_drawdown_pct, max_duration)
}
fn calculate_var_cvar(&self, returns: &[f64], confidence: f64) -> (f64, f64) {
if returns.is_empty() {
return (0.0, 0.0);
}
let mut sorted_returns = returns.to_vec();
sorted_returns.sort_by(|a, b| a.partial_cmp(b).unwrap());
let index = ((1.0 - confidence) * sorted_returns.len() as f64) as usize;
let var = sorted_returns[index];
let cvar = sorted_returns[..=index].iter().sum::<f64>() / (index + 1) as f64;
(var * 100.0, cvar * 100.0)
}
fn calculate_streaks(&self) -> (usize, usize, i32) {
let mut max_wins = 0;
let mut max_losses = 0;
let mut current_streak = 0;
let mut current_wins = 0;
let mut current_losses = 0;
for trade in &self.trades {
if trade.pnl > 0.0 {
current_wins += 1;
current_losses = 0;
if current_wins > max_wins {
max_wins = current_wins;
}
current_streak = current_wins as i32;
} else if trade.pnl < 0.0 {
current_losses += 1;
current_wins = 0;
if current_losses > max_losses {
max_losses = current_losses;
}
current_streak = -(current_losses as i32);
}
}
(max_wins, max_losses, current_streak)
}
fn calculate_drawdown_curve(&self) -> Vec<EquityPoint> {
let mut drawdown_curve = Vec::new();
let mut peak = self.config.initial_capital;
for (date, value) in &self.equity_curve {
if *value > peak {
peak = *value;
}
let drawdown_pct = ((peak - value) / peak) * 100.0;
drawdown_curve.push(EquityPoint {
timestamp: date.timestamp_millis(),
value: -drawdown_pct, // Negative for visualization
});
}
drawdown_curve
}
fn calculate_daily_returns(&self) -> Vec<DailyReturn> {
let mut daily_returns = Vec::new();
for i in 1..self.equity_curve.len() {
let prev = &self.equity_curve[i - 1];
let curr = &self.equity_curve[i];
// Only include if it's a new day
if prev.0.date() != curr.0.date() {
let return_pct = ((curr.1 - prev.1) / prev.1) * 100.0;
daily_returns.push(DailyReturn {
date: curr.0.format("%Y-%m-%d").to_string(),
value: return_pct,
});
}
}
daily_returns
}
fn calculate_cumulative_returns(&self) -> Vec<EquityPoint> {
let initial = self.config.initial_capital;
self.equity_curve.iter()
.map(|(date, value)| EquityPoint {
timestamp: date.timestamp_millis(),
value: ((value - initial) / initial) * 100.0,
})
.collect()
}
fn convert_fills_to_trades(&self) -> Vec<Trade> {
self.fills.iter()
.enumerate()
.map(|(i, fill)| Trade {
id: format!("trade-{}", i),
timestamp: fill.timestamp,
symbol: fill.symbol.clone(),
side: match fill.side {
Side::Buy => "buy".to_string(),
Side::Sell => "sell".to_string(),
},
quantity: fill.quantity,
price: fill.price,
commission: fill.commission,
pnl: None, // Individual fills don't have PnL
})
.collect()
}
fn get_open_trades(&self) -> Vec<Trade> {
// TODO: Implement based on position tracker
Vec::new()
}
fn calculate_period_returns(&self, period: &str) -> Vec<PeriodReturn> {
// TODO: Implement monthly/yearly return calculations
Vec::new()
}
fn calculate_symbol_analysis(&self) -> HashMap<String, SymbolAnalysis> {
let mut analysis: HashMap<String, SymbolAnalysis> = HashMap::new();
// Group trades by symbol
for trade in &self.trades {
let entry = analysis.entry(trade.symbol.clone()).or_insert(SymbolAnalysis {
symbol: trade.symbol.clone(),
total_trades: 0,
winning_trades: 0,
losing_trades: 0,
total_pnl: 0.0,
win_rate: 0.0,
avg_win: 0.0,
avg_loss: 0.0,
profit_factor: 0.0,
total_volume: 0.0,
avg_position_size: 0.0,
exposure_time_pct: 0.0,
});
entry.total_trades += 1;
entry.total_pnl += trade.pnl;
entry.total_volume += trade.quantity * (trade.entry_price + trade.exit_price);
if trade.pnl > 0.0 {
entry.winning_trades += 1;
} else if trade.pnl < 0.0 {
entry.losing_trades += 1;
}
}
// Calculate derived metrics
for (_, entry) in analysis.iter_mut() {
if entry.total_trades > 0 {
entry.win_rate = (entry.winning_trades as f64 / entry.total_trades as f64) * 100.0;
entry.avg_position_size = entry.total_volume / (entry.total_trades as f64 * 2.0);
}
// Calculate avg win/loss
let wins: Vec<f64> = self.trades.iter()
.filter(|t| t.symbol == entry.symbol && t.pnl > 0.0)
.map(|t| t.pnl)
.collect();
let losses: Vec<f64> = self.trades.iter()
.filter(|t| t.symbol == entry.symbol && t.pnl < 0.0)
.map(|t| t.pnl)
.collect();
if !wins.is_empty() {
entry.avg_win = wins.iter().sum::<f64>() / wins.len() as f64;
}
if !losses.is_empty() {
entry.avg_loss = losses.iter().sum::<f64>() / losses.len() as f64;
}
// Profit factor
let gross_profit = wins.iter().sum::<f64>();
let gross_loss = losses.iter().map(|l| l.abs()).sum::<f64>();
if gross_loss > 0.0 {
entry.profit_factor = gross_profit / gross_loss;
} else if gross_profit > 0.0 {
entry.profit_factor = f64::INFINITY;
}
}
analysis
}
fn analyze_drawdown_periods(&self) -> Vec<DrawdownPeriod> {
// TODO: Implement comprehensive drawdown period analysis
Vec::new()
}
fn calculate_exposure_analysis(&self) -> ExposureAnalysis {
// TODO: Implement exposure analysis
ExposureAnalysis {
avg_gross_exposure: 0.0,
avg_net_exposure: 0.0,
max_gross_exposure: 0.0,
max_net_exposure: 0.0,
time_in_market_pct: 0.0,
long_exposure_pct: 0.0,
short_exposure_pct: 0.0,
}
}
fn calculate_position_history(&self) -> Vec<PositionSnapshot> {
// TODO: Implement position history tracking
Vec::new()
}
fn extract_trade_signals(&self) -> Vec<TradeSignal> {
// TODO: Extract signals from strategy execution
Vec::new()
}
}

View file

@ -0,0 +1,155 @@
use crate::{MarketUpdate, Order, Fill, TradingMode, MarketDataSource, ExecutionHandler, TimeProvider, Side};
use crate::positions::PositionTracker;
use crate::risk::RiskEngine;
use crate::orderbook::OrderBookManager;
use chrono::{DateTime, Utc};
use std::collections::{BTreeMap, VecDeque};
use std::sync::Arc;
use parking_lot::RwLock;
use serde::{Serialize, Deserialize};
pub mod engine;
pub mod event;
pub mod strategy;
pub mod simple_results;
pub mod trade_tracker;
pub use engine::BacktestEngine;
pub use event::{BacktestEvent, EventType};
pub use strategy::{Strategy, Signal, SignalType};
pub use simple_results::{BacktestResult, BacktestMetrics};
pub use trade_tracker::{TradeTracker, CompletedTrade as TrackedTrade};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletedTrade {
pub symbol: String,
pub side: Side,
pub timestamp: DateTime<Utc>,
pub price: f64,
pub quantity: f64,
pub commission: f64,
pub position_after: f64, // Position size after this trade
pub pnl: Option<f64>, // P&L if position was reduced/closed
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestConfig {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub strategy: Option<String>,
pub symbols: Vec<String>,
#[serde(rename = "startDate")]
pub start_time: DateTime<Utc>,
#[serde(rename = "endDate")]
pub end_time: DateTime<Utc>,
pub initial_capital: f64,
pub commission: f64,
pub slippage: f64,
pub data_frequency: String,
}
#[derive(Debug, Clone)]
pub struct BacktestState {
pub current_time: DateTime<Utc>,
pub portfolio_value: f64,
pub cash: f64,
pub equity_curve: Vec<(DateTime<Utc>, f64)>,
pub pending_orders: BTreeMap<String, Order>,
pub completed_trades: Vec<CompletedTrade>,
}
impl BacktestState {
pub fn new(initial_capital: f64, start_time: DateTime<Utc>) -> Self {
Self {
current_time: start_time,
portfolio_value: initial_capital,
cash: initial_capital,
equity_curve: vec![(start_time, initial_capital)],
pending_orders: BTreeMap::new(),
completed_trades: Vec::new(),
}
}
pub fn update_portfolio_value(&mut self, value: f64) {
self.portfolio_value = value;
// Only add a new equity curve point if the timestamp has changed
// or if it's the first point
if self.equity_curve.is_empty() ||
self.equity_curve.last().map(|(t, _)| *t != self.current_time).unwrap_or(true) {
self.equity_curve.push((self.current_time, value));
} else {
// Update the last point with the new value
if let Some(last) = self.equity_curve.last_mut() {
last.1 = value;
}
}
}
pub fn add_pending_order(&mut self, order: Order) {
self.pending_orders.insert(order.id.clone(), order);
}
pub fn remove_pending_order(&mut self, order_id: &str) -> Option<Order> {
self.pending_orders.remove(order_id)
}
pub fn record_fill(&mut self, symbol: String, side: Side, fill: Fill, position_after: f64, position_before: f64, pnl: Option<f64>) {
self.completed_trades.push(CompletedTrade {
symbol,
side,
timestamp: fill.timestamp,
price: fill.price,
quantity: fill.quantity,
commission: fill.commission,
position_after,
pnl,
});
}
}
// Event queue for deterministic replay
#[derive(Debug)]
pub struct EventQueue {
events: VecDeque<BacktestEvent>,
}
impl EventQueue {
pub fn new() -> Self {
Self {
events: VecDeque::new(),
}
}
pub fn push(&mut self, event: BacktestEvent) {
// Insert in time order
let pos = self.events.iter().position(|e| e.timestamp > event.timestamp)
.unwrap_or(self.events.len());
self.events.insert(pos, event);
}
pub fn pop_until(&mut self, timestamp: DateTime<Utc>) -> Vec<BacktestEvent> {
let mut events = Vec::new();
while let Some(event) = self.events.front() {
if event.timestamp <= timestamp {
events.push(self.events.pop_front().unwrap());
} else {
break;
}
}
events
}
pub fn is_empty(&self) -> bool {
self.events.is_empty()
}
pub fn len(&self) -> usize {
self.events.len()
}
pub fn peek_next(&self) -> Option<&BacktestEvent> {
self.events.front()
}
}

View file

@ -0,0 +1,266 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use crate::Position;
use super::{BacktestConfig, CompletedTrade};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestMetrics {
// Core Performance Metrics
pub total_return: f64,
pub total_return_pct: f64,
pub annualized_return: f64,
pub total_pnl: f64,
pub net_pnl: f64, // After commissions
// Risk Metrics
pub sharpe_ratio: f64,
pub sortino_ratio: f64,
pub calmar_ratio: f64,
pub max_drawdown: f64,
pub max_drawdown_pct: f64,
pub max_drawdown_duration_days: i64,
pub volatility: f64,
pub downside_deviation: f64,
pub var_95: f64, // Value at Risk 95%
pub cvar_95: f64, // Conditional VaR 95%
// Trade Statistics
pub total_trades: usize,
pub winning_trades: usize,
pub losing_trades: usize,
pub breakeven_trades: usize,
pub win_rate: f64,
pub loss_rate: f64,
pub avg_win: f64,
pub avg_loss: f64,
pub largest_win: f64,
pub largest_loss: f64,
pub avg_trade_pnl: f64,
pub avg_trade_duration_hours: f64,
pub profit_factor: f64,
pub expectancy: f64,
pub expectancy_ratio: f64,
// Position Statistics
pub avg_position_size: f64,
pub max_position_size: f64,
pub avg_positions_held: f64,
pub max_concurrent_positions: usize,
pub total_long_trades: usize,
pub total_short_trades: usize,
pub long_win_rate: f64,
pub short_win_rate: f64,
// Trading Activity
pub total_volume_traded: f64,
pub total_commission_paid: f64,
pub commission_pct_of_pnl: f64,
pub avg_daily_trades: f64,
pub max_daily_trades: usize,
pub trading_days: usize,
pub exposure_time_pct: f64,
// Efficiency Metrics
pub return_over_max_dd: f64,
pub win_loss_ratio: f64,
pub kelly_criterion: f64,
pub ulcer_index: f64,
pub serenity_ratio: f64,
pub lake_ratio: f64,
// Monthly/Period Analysis
pub best_month_return: f64,
pub worst_month_return: f64,
pub positive_months: usize,
pub negative_months: usize,
pub monthly_win_rate: f64,
pub avg_monthly_return: f64,
// Streak Analysis
pub max_consecutive_wins: usize,
pub max_consecutive_losses: usize,
pub current_streak: i32, // Positive for wins, negative for losses
pub avg_winning_streak: f64,
pub avg_losing_streak: f64,
}
// Trade structure that matches web app expectations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
pub id: String,
pub timestamp: DateTime<Utc>,
pub symbol: String,
pub side: String, // "buy" or "sell"
pub quantity: f64,
pub price: f64,
pub commission: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub pnl: Option<f64>,
}
// Equity curve point
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EquityPoint {
pub timestamp: i64, // milliseconds since epoch
pub value: f64,
}
// Daily return data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyReturn {
pub date: String,
pub value: f64,
}
// Performance data point for charts
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PerformanceDataPoint {
pub timestamp: String,
pub portfolio_value: f64,
pub pnl: f64,
pub drawdown: f64,
}
// Period returns (monthly, weekly, etc)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PeriodReturn {
pub period: String, // "2024-01", "2024-W01", etc
pub start_date: String,
pub end_date: String,
pub return_pct: f64,
pub pnl: f64,
pub trades: usize,
pub win_rate: f64,
}
// Trade analysis by symbol
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SymbolAnalysis {
pub symbol: String,
pub total_trades: usize,
pub winning_trades: usize,
pub losing_trades: usize,
pub total_pnl: f64,
pub win_rate: f64,
pub avg_win: f64,
pub avg_loss: f64,
pub profit_factor: f64,
pub total_volume: f64,
pub avg_position_size: f64,
pub exposure_time_pct: f64,
}
// Drawdown period information
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DrawdownPeriod {
pub start_date: String,
pub end_date: String,
pub peak_value: f64,
pub trough_value: f64,
pub drawdown_pct: f64,
pub recovery_date: Option<String>,
pub duration_days: i64,
pub recovery_days: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestResult {
// Metadata
pub backtest_id: String,
pub status: String,
pub started_at: String,
pub completed_at: String,
pub execution_time_ms: u64,
// Configuration
pub config: BacktestConfig,
// Core Metrics
pub metrics: BacktestMetrics,
// Time Series Data
pub equity_curve: Vec<EquityPoint>,
pub drawdown_curve: Vec<EquityPoint>,
pub daily_returns: Vec<DailyReturn>,
pub cumulative_returns: Vec<EquityPoint>,
// Trade Data
pub trades: Vec<Trade>,
pub open_trades: Vec<Trade>,
pub trade_signals: Vec<TradeSignal>,
// Position Data
pub positions: HashMap<String, Position>,
pub position_history: Vec<PositionSnapshot>,
// Analytics
pub monthly_returns: Vec<PeriodReturn>,
pub yearly_returns: Vec<PeriodReturn>,
pub symbol_analysis: HashMap<String, SymbolAnalysis>,
pub drawdown_periods: Vec<DrawdownPeriod>,
pub exposure_analysis: ExposureAnalysis,
// Market Data
pub ohlc_data: HashMap<String, Vec<serde_json::Value>>,
// Error Handling
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub warnings: Option<Vec<String>>,
}
// Trade signal for visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TradeSignal {
pub timestamp: String,
pub symbol: String,
pub action: String, // "buy", "sell", "hold"
pub strength: f64,
pub reason: String,
}
// Position snapshot for tracking
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PositionSnapshot {
pub timestamp: String,
pub positions: HashMap<String, PositionDetail>,
pub total_value: f64,
pub cash: f64,
pub margin_used: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PositionDetail {
pub symbol: String,
pub quantity: f64,
pub avg_price: f64,
pub current_price: f64,
pub market_value: f64,
pub unrealized_pnl: f64,
pub unrealized_pnl_pct: f64,
}
// Exposure analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExposureAnalysis {
pub avg_gross_exposure: f64,
pub avg_net_exposure: f64,
pub max_gross_exposure: f64,
pub max_net_exposure: f64,
pub time_in_market_pct: f64,
pub long_exposure_pct: f64,
pub short_exposure_pct: f64,
}

View file

@ -0,0 +1,565 @@
use chrono::{DateTime, Utc, Datelike};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use crate::Position;
use super::{BacktestConfig, CompletedTrade};
use super::trade_tracker::CompletedTrade as TrackedTrade;
// Simplified metrics that match what the web app expects
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestMetrics {
pub total_return: f64,
pub sharpe_ratio: f64,
pub max_drawdown: f64,
pub win_rate: f64,
pub total_trades: usize,
pub profitable_trades: usize,
// Additional fields for compatibility
pub profit_factor: f64,
pub total_pnl: f64,
pub avg_win: f64,
pub avg_loss: f64,
pub expectancy: f64,
pub calmar_ratio: f64,
pub sortino_ratio: f64,
// Missing fields required by web app
pub final_value: f64,
pub winning_trades: usize,
pub losing_trades: usize,
pub largest_win: f64,
pub largest_loss: f64,
pub annual_return: f64,
}
// Individual trade (fill) structure for UI
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Trade {
pub id: String,
pub timestamp: String,
pub symbol: String,
pub side: String, // "buy" or "sell"
pub quantity: f64,
pub price: f64,
pub commission: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub pnl: Option<f64>,
pub position_after: f64, // Position size after this trade
}
// Analytics data structure
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Analytics {
pub drawdown_series: Vec<EquityPoint>,
pub daily_returns: Vec<f64>,
pub monthly_returns: HashMap<String, f64>,
pub exposure_time: f64,
pub risk_metrics: HashMap<String, f64>,
}
// Position structure for web app
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PositionInfo {
pub symbol: String,
pub quantity: f64,
pub average_price: f64,
pub current_price: f64,
pub unrealized_pnl: f64,
pub realized_pnl: f64,
}
// Equity data point for charts
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EquityDataPoint {
pub date: String,
pub value: f64,
}
// Simplified backtest result
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BacktestResult {
pub backtest_id: String,
pub status: String,
pub completed_at: String,
pub config: BacktestConfig,
pub metrics: BacktestMetrics,
pub equity: Vec<EquityDataPoint>,
pub trades: Vec<Trade>, // Now shows all individual fills
pub positions: Vec<PositionInfo>,
pub analytics: Analytics,
pub execution_time: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub ohlc_data: HashMap<String, Vec<serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EquityPoint {
pub timestamp: i64,
pub value: f64,
}
impl BacktestResult {
pub fn from_engine_data(
config: BacktestConfig,
equity_curve: Vec<(DateTime<Utc>, f64)>,
fills: Vec<CompletedTrade>,
final_positions: HashMap<String, Position>,
start_time: DateTime<Utc>,
) -> Self {
let initial_capital = config.initial_capital;
let final_value = equity_curve.last().map(|(_, v)| *v).unwrap_or(initial_capital);
// Convert fills to trades
let trades: Vec<Trade> = fills.iter()
.enumerate()
.map(|(i, fill)| Trade {
id: format!("trade-{}", i),
timestamp: fill.timestamp.to_rfc3339(),
symbol: fill.symbol.clone(),
side: match fill.side {
crate::Side::Buy => "buy".to_string(),
crate::Side::Sell => "sell".to_string(),
},
quantity: fill.quantity,
price: fill.price,
commission: fill.commission,
pnl: fill.pnl,
position_after: fill.position_after,
})
.collect();
// Calculate metrics
let total_trades = trades.len();
let total_return = ((final_value - initial_capital) / initial_capital) * 100.0;
let total_pnl = final_value - initial_capital;
// Basic win rate calculation - this is simplified
let profitable_trades = 0; // TODO: Calculate from paired trades
let win_rate = 0.0; // TODO: Calculate properly
// Calculate daily returns
let mut daily_returns_vec = Vec::new();
let mut prev_date = equity_curve.first().map(|(d, _)| d.date_naive()).unwrap_or_default();
let mut prev_value = initial_capital;
for (dt, value) in &equity_curve {
if dt.date_naive() != prev_date {
let daily_return = (value - prev_value) / prev_value;
daily_returns_vec.push(daily_return);
prev_date = dt.date_naive();
prev_value = *value;
}
}
// Calculate volatility (annualized)
let volatility = if daily_returns_vec.len() > 1 {
let mean = daily_returns_vec.iter().sum::<f64>() / daily_returns_vec.len() as f64;
let variance = daily_returns_vec.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / daily_returns_vec.len() as f64;
variance.sqrt() * (252.0_f64).sqrt()
} else {
0.0
};
// Calculate Sharpe ratio
let risk_free_rate = 0.02; // 2% annual
let days = (config.end_time - config.start_time).num_days() as f64;
let years = days / 365.25;
let annualized_return = if years > 0.0 {
(final_value / initial_capital).powf(1.0 / years) - 1.0
} else {
0.0
};
let sharpe_ratio = if volatility > 0.0 {
(annualized_return - risk_free_rate) / volatility
} else {
0.0
};
// Calculate max drawdown
let (max_drawdown, drawdown_series) = Self::calculate_drawdown(&equity_curve, initial_capital);
// Calculate Calmar ratio
let calmar_ratio = if max_drawdown > 0.0 {
annualized_return / max_drawdown
} else {
0.0
};
// Convert positions
let positions: Vec<PositionInfo> = final_positions.iter()
.map(|(symbol, pos)| {
let current_price = pos.average_price; // TODO: Get actual current price
let unrealized_pnl = (current_price - pos.average_price) * pos.quantity;
PositionInfo {
symbol: symbol.clone(),
quantity: pos.quantity,
average_price: pos.average_price,
current_price,
unrealized_pnl,
realized_pnl: pos.realized_pnl,
}
})
.collect();
// Convert equity curve to the format expected by web app
let equity: Vec<EquityDataPoint> = equity_curve.iter()
.map(|(dt, val)| EquityDataPoint {
date: dt.to_rfc3339(),
value: *val,
})
.collect();
let metrics = BacktestMetrics {
total_return,
sharpe_ratio,
max_drawdown,
win_rate,
total_trades,
profitable_trades,
profit_factor: 0.0,
total_pnl,
avg_win: 0.0,
avg_loss: 0.0,
expectancy: 0.0,
calmar_ratio,
sortino_ratio: 0.0, // TODO: Calculate
final_value,
winning_trades: profitable_trades,
losing_trades: 0,
largest_win: 0.0,
largest_loss: 0.0,
annual_return: annualized_return * 100.0,
};
// Create analytics
let analytics = Analytics {
drawdown_series,
daily_returns: daily_returns_vec,
monthly_returns: HashMap::new(), // TODO: Calculate
exposure_time: 0.0, // TODO: Calculate
risk_metrics: HashMap::new(), // TODO: Add risk metrics
};
BacktestResult {
backtest_id: format!("rust-{}", Utc::now().timestamp_millis()),
status: "completed".to_string(),
completed_at: Utc::now().to_rfc3339(),
config,
metrics,
equity,
trades: Vec::new(), // No completed trades in simple version
positions,
analytics,
execution_time: (Utc::now() - start_time).num_milliseconds() as u64,
error: None,
ohlc_data: HashMap::new(),
}
}
fn calculate_drawdown(equity_curve: &[(DateTime<Utc>, f64)], initial_capital: f64) -> (f64, Vec<EquityPoint>) {
let mut max_drawdown = 0.0;
let mut peak = initial_capital;
let mut drawdown_series = Vec::new();
for (dt, value) in equity_curve {
if *value > peak {
peak = *value;
}
let drawdown_pct = if peak > 0.0 {
((peak - value) / peak) * 100.0
} else {
0.0
};
if drawdown_pct > max_drawdown {
max_drawdown = drawdown_pct;
}
drawdown_series.push(EquityPoint {
timestamp: dt.timestamp_millis(),
value: -drawdown_pct, // Negative for visualization
});
}
(max_drawdown / 100.0, drawdown_series) // Return as decimal
}
pub fn from_engine_data_with_trades(
config: BacktestConfig,
equity_curve: Vec<(DateTime<Utc>, f64)>,
fills: Vec<CompletedTrade>,
completed_trades: Vec<TrackedTrade>,
final_positions: HashMap<String, Position>,
start_time: DateTime<Utc>,
last_prices: &HashMap<String, f64>,
) -> Self {
let initial_capital = config.initial_capital;
let final_value = equity_curve.last().map(|(_, v)| *v).unwrap_or(initial_capital);
// Convert fills to web app format (all individual trades)
let trades: Vec<Trade> = fills.iter()
.enumerate()
.map(|(i, fill)| Trade {
id: format!("trade-{}", i + 1),
timestamp: fill.timestamp.to_rfc3339(),
symbol: fill.symbol.clone(),
side: match fill.side {
crate::Side::Buy => "buy".to_string(),
crate::Side::Sell => "sell".to_string(),
},
quantity: fill.quantity,
price: fill.price,
commission: fill.commission,
pnl: fill.pnl,
position_after: fill.position_after,
})
.collect();
// Calculate metrics from completed trades
let total_trades = completed_trades.len();
let total_return = ((final_value - initial_capital) / initial_capital) * 100.0;
let total_pnl = final_value - initial_capital;
// Calculate win rate and profit metrics
let winning_trades: Vec<&TrackedTrade> = completed_trades.iter()
.filter(|t| t.pnl > 0.0)
.collect();
let losing_trades: Vec<&TrackedTrade> = completed_trades.iter()
.filter(|t| t.pnl < 0.0)
.collect();
let profitable_trades = winning_trades.len();
let win_rate = if total_trades > 0 {
(profitable_trades as f64 / total_trades as f64) * 100.0
} else {
0.0
};
let avg_win = if !winning_trades.is_empty() {
winning_trades.iter().map(|t| t.pnl).sum::<f64>() / winning_trades.len() as f64
} else {
0.0
};
let avg_loss = if !losing_trades.is_empty() {
losing_trades.iter().map(|t| t.pnl).sum::<f64>() / losing_trades.len() as f64
} else {
0.0
};
let gross_profit = winning_trades.iter().map(|t| t.pnl).sum::<f64>();
let gross_loss = losing_trades.iter().map(|t| t.pnl.abs()).sum::<f64>();
let profit_factor = if gross_loss > 0.0 {
gross_profit / gross_loss
} else if gross_profit > 0.0 {
f64::INFINITY
} else {
0.0
};
let expectancy = avg_win * (win_rate / 100.0) - avg_loss.abs() * ((100.0 - win_rate) / 100.0);
// Calculate daily returns
let mut daily_returns_vec = Vec::new();
let mut prev_date = equity_curve.first().map(|(d, _)| d.date_naive()).unwrap_or_default();
let mut prev_value = initial_capital;
for (dt, value) in &equity_curve {
if dt.date_naive() != prev_date {
let daily_return = (value - prev_value) / prev_value;
daily_returns_vec.push(daily_return);
prev_date = dt.date_naive();
prev_value = *value;
}
}
// Calculate volatility (annualized)
let volatility = if daily_returns_vec.len() > 1 {
let mean = daily_returns_vec.iter().sum::<f64>() / daily_returns_vec.len() as f64;
let variance = daily_returns_vec.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / daily_returns_vec.len() as f64;
variance.sqrt() * (252.0_f64).sqrt()
} else {
0.0
};
// Calculate Sharpe ratio
let risk_free_rate = 0.02; // 2% annual
let days = (config.end_time - config.start_time).num_days() as f64;
let years = days / 365.25;
let annualized_return = if years > 0.0 {
(final_value / initial_capital).powf(1.0 / years) - 1.0
} else {
0.0
};
let sharpe_ratio = if volatility > 0.0 {
(annualized_return - risk_free_rate) / volatility
} else {
0.0
};
// Calculate Sortino ratio (using downside deviation)
let negative_returns: Vec<f64> = daily_returns_vec.iter()
.filter(|&&r| r < 0.0)
.cloned()
.collect();
let downside_deviation = if !negative_returns.is_empty() {
let mean = negative_returns.iter().sum::<f64>() / negative_returns.len() as f64;
let variance = negative_returns.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / negative_returns.len() as f64;
variance.sqrt() * (252.0_f64).sqrt()
} else {
0.0
};
let sortino_ratio = if downside_deviation > 0.0 {
(annualized_return - risk_free_rate) / downside_deviation
} else {
0.0
};
// Calculate max drawdown
let (max_drawdown, drawdown_series) = Self::calculate_drawdown(&equity_curve, initial_capital);
// Calculate Calmar ratio
let calmar_ratio = if max_drawdown > 0.0 {
annualized_return / max_drawdown
} else {
0.0
};
// Convert positions with actual last prices
let positions: Vec<PositionInfo> = final_positions.iter()
.map(|(symbol, pos)| {
let current_price = last_prices.get(symbol).copied()
.unwrap_or(pos.average_price);
let unrealized_pnl = (current_price - pos.average_price) * pos.quantity;
PositionInfo {
symbol: symbol.clone(),
quantity: pos.quantity,
average_price: pos.average_price,
current_price,
unrealized_pnl,
realized_pnl: pos.realized_pnl,
}
})
.collect();
// Convert equity curve to the format expected by web app
let equity: Vec<EquityDataPoint> = equity_curve.iter()
.map(|(dt, val)| EquityDataPoint {
date: dt.to_rfc3339(),
value: *val,
})
.collect();
// Convert individual fills to UI format (for backward compatibility)
let _fill_trades: Vec<Trade> = fills.iter()
.enumerate()
.map(|(i, fill)| Trade {
id: format!("fill-{}", i),
timestamp: fill.timestamp.to_rfc3339(),
symbol: fill.symbol.clone(),
side: match fill.side {
crate::Side::Buy => "buy".to_string(),
crate::Side::Sell => "sell".to_string(),
},
quantity: fill.quantity,
price: fill.price,
commission: fill.commission,
pnl: fill.pnl,
position_after: fill.position_after,
})
.collect();
// Find largest win/loss
let largest_win = winning_trades.iter()
.map(|t| t.pnl)
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(0.0);
let largest_loss = losing_trades.iter()
.map(|t| t.pnl)
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(0.0);
let metrics = BacktestMetrics {
total_return,
sharpe_ratio,
max_drawdown,
win_rate,
total_trades,
profitable_trades,
profit_factor,
total_pnl,
avg_win,
avg_loss,
expectancy,
calmar_ratio,
sortino_ratio,
final_value,
winning_trades: profitable_trades,
losing_trades: losing_trades.len(),
largest_win,
largest_loss,
annual_return: annualized_return * 100.0, // Convert to percentage
};
// Calculate monthly returns
let mut monthly_returns = HashMap::new();
let mut monthly_values: HashMap<String, Vec<f64>> = HashMap::new();
for (dt, value) in &equity_curve {
let month_key = format!("{}-{:02}", dt.year(), dt.month());
monthly_values.entry(month_key).or_insert_with(Vec::new).push(*value);
}
for (month, values) in monthly_values {
if values.len() >= 2 {
let start = values.first().unwrap();
let end = values.last().unwrap();
let monthly_return = ((end - start) / start) * 100.0;
monthly_returns.insert(month, monthly_return);
}
}
// Create analytics
let analytics = Analytics {
drawdown_series,
daily_returns: daily_returns_vec,
monthly_returns,
exposure_time: 0.0, // TODO: Calculate based on position history
risk_metrics: {
let mut metrics = HashMap::new();
metrics.insert("volatility".to_string(), volatility);
metrics.insert("downside_deviation".to_string(), downside_deviation);
metrics.insert("value_at_risk_95".to_string(), 0.0); // TODO: Calculate VaR
metrics
},
};
BacktestResult {
backtest_id: format!("rust-{}", Utc::now().timestamp_millis()),
status: "completed".to_string(),
completed_at: Utc::now().to_rfc3339(),
config,
metrics,
equity,
trades, // Return the completed trades
positions,
analytics,
execution_time: (Utc::now() - start_time).num_milliseconds() as u64,
error: None,
ohlc_data: HashMap::new(),
}
}
}

View file

@ -0,0 +1,100 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use crate::MarketData;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SignalType {
Buy,
Sell,
Close,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signal {
pub symbol: String,
pub signal_type: SignalType,
pub strength: f64, // -1.0 to 1.0
pub quantity: Option<f64>,
pub reason: Option<String>,
pub metadata: Option<serde_json::Value>,
}
// This trait will be implemented by Rust strategies
// TypeScript strategies will communicate through FFI
pub trait Strategy: Send + Sync {
fn on_market_data(&mut self, data: &MarketData) -> Vec<Signal>;
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str);
fn get_name(&self) -> &str;
fn get_parameters(&self) -> serde_json::Value;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyCall {
pub method: String,
pub data: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyResponse {
pub signals: Vec<Signal>,
}
// Bridge for TypeScript strategies
// This will be used to wrap TypeScript strategies
pub struct TypeScriptStrategy {
pub name: String,
pub id: String,
pub parameters: serde_json::Value,
// Callback function will be injected from TypeScript
pub callback: Option<Box<dyn Fn(StrategyCall) -> StrategyResponse + Send + Sync>>,
}
impl TypeScriptStrategy {
pub fn new(name: String, id: String, parameters: serde_json::Value) -> Self {
Self {
name,
id,
parameters,
callback: None,
}
}
}
impl Strategy for TypeScriptStrategy {
fn on_market_data(&mut self, data: &MarketData) -> Vec<Signal> {
if let Some(callback) = &self.callback {
let call = StrategyCall {
method: "on_market_data".to_string(),
data: serde_json::to_value(data).unwrap_or_default(),
};
let response = callback(call);
response.signals
} else {
Vec::new()
}
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
if let Some(callback) = &self.callback {
let call = StrategyCall {
method: "on_fill".to_string(),
data: serde_json::json!({
"symbol": symbol,
"quantity": quantity,
"price": price,
"side": side
}),
};
callback(call);
}
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
self.parameters.clone()
}
}

View file

@ -0,0 +1,269 @@
use std::collections::{HashMap, VecDeque};
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use crate::{Fill, Side};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletedTrade {
pub id: String,
pub symbol: String,
pub entry_time: DateTime<Utc>,
pub exit_time: DateTime<Utc>,
pub entry_price: f64,
pub exit_price: f64,
pub quantity: f64,
pub side: Side, // Side of the opening trade
pub pnl: f64,
pub pnl_percent: f64,
pub commission: f64,
pub duration_seconds: i64,
}
#[derive(Debug, Clone)]
struct OpenPosition {
symbol: String,
side: Side,
quantity: f64,
entry_price: f64,
entry_time: DateTime<Utc>,
commission: f64,
}
/// Tracks fills and matches them into completed trades
pub struct TradeTracker {
open_positions: HashMap<String, VecDeque<OpenPosition>>,
completed_trades: Vec<CompletedTrade>,
trade_counter: u64,
}
impl TradeTracker {
pub fn new() -> Self {
Self {
open_positions: HashMap::new(),
completed_trades: Vec::new(),
trade_counter: 0,
}
}
pub fn get_completed_trades(&self) -> Vec<CompletedTrade> {
self.completed_trades.clone()
}
pub fn process_fill(&mut self, symbol: &str, side: Side, fill: &Fill) {
let positions = self.open_positions.entry(symbol.to_string()).or_insert_with(VecDeque::new);
// Check if this fill closes existing positions
let mut remaining_quantity = fill.quantity;
let mut fills_to_remove = Vec::new();
for (idx, open_pos) in positions.iter_mut().enumerate() {
// Only match against opposite side positions
if open_pos.side == side {
continue;
}
if remaining_quantity <= 0.0 {
break;
}
let matched_quantity = remaining_quantity.min(open_pos.quantity);
// Calculate PnL
let (pnl, pnl_percent) = Self::calculate_pnl(
&open_pos,
fill.price,
matched_quantity,
fill.commission,
);
// Create completed trade
self.trade_counter += 1;
let completed_trade = CompletedTrade {
id: format!("trade-{}", self.trade_counter),
symbol: symbol.to_string(),
entry_time: open_pos.entry_time,
exit_time: fill.timestamp,
entry_price: open_pos.entry_price,
exit_price: fill.price,
quantity: matched_quantity,
side: open_pos.side.clone(),
pnl,
pnl_percent,
commission: open_pos.commission + (fill.commission * matched_quantity / fill.quantity),
duration_seconds: (fill.timestamp - open_pos.entry_time).num_seconds(),
};
eprintln!("📈 TRADE CLOSED: {} {} @ entry: ${:.2} ({}), exit: ${:.2} ({}) = P&L: ${:.2}",
symbol,
match open_pos.side { Side::Buy => "LONG", Side::Sell => "SHORT" },
open_pos.entry_price,
open_pos.entry_time.format("%Y-%m-%d"),
fill.price,
fill.timestamp.format("%Y-%m-%d"),
pnl
);
self.completed_trades.push(completed_trade);
// Update open position
open_pos.quantity -= matched_quantity;
remaining_quantity -= matched_quantity;
if open_pos.quantity <= 0.0 {
fills_to_remove.push(idx);
}
}
// Remove fully closed positions
for idx in fills_to_remove.iter().rev() {
positions.remove(*idx);
}
// If there's remaining quantity, it opens a new position
if remaining_quantity > 0.0 {
let new_position = OpenPosition {
symbol: symbol.to_string(),
side,
quantity: remaining_quantity,
entry_price: fill.price,
entry_time: fill.timestamp,
commission: fill.commission * remaining_quantity / fill.quantity,
};
positions.push_back(new_position);
}
}
fn calculate_pnl(open_pos: &OpenPosition, exit_price: f64, quantity: f64, exit_commission: f64) -> (f64, f64) {
let entry_value = open_pos.entry_price * quantity;
let exit_value = exit_price * quantity;
let gross_pnl = match open_pos.side {
Side::Buy => exit_value - entry_value,
Side::Sell => entry_value - exit_value,
};
let commission = open_pos.commission * (quantity / open_pos.quantity) + exit_commission * (quantity / open_pos.quantity);
let net_pnl = gross_pnl - commission;
let pnl_percent = (net_pnl / entry_value) * 100.0;
(net_pnl, pnl_percent)
}
pub fn get_open_positions(&self) -> HashMap<String, Vec<(Side, f64, f64)>> {
let mut result = HashMap::new();
for (symbol, positions) in &self.open_positions {
let pos_info: Vec<(Side, f64, f64)> = positions
.iter()
.map(|p| (p.side.clone(), p.quantity, p.entry_price))
.collect();
if !pos_info.is_empty() {
result.insert(symbol.clone(), pos_info);
}
}
result
}
pub fn get_net_position(&self, symbol: &str) -> f64 {
let positions = match self.open_positions.get(symbol) {
Some(pos) => pos,
None => return 0.0,
};
let mut net = 0.0;
for pos in positions {
match pos.side {
Side::Buy => net += pos.quantity,
Side::Sell => net -= pos.quantity,
}
}
net
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simple_round_trip() {
let mut tracker = TradeTracker::new();
// Buy 100 shares at $50
let buy_fill = Fill {
timestamp: Utc::now(),
price: 50.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
// Sell 100 shares at $55
let sell_fill = Fill {
timestamp: Utc::now(),
price: 55.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill);
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 1);
let trade = &trades[0];
assert_eq!(trade.symbol, "AAPL");
assert_eq!(trade.quantity, 100.0);
assert_eq!(trade.entry_price, 50.0);
assert_eq!(trade.exit_price, 55.0);
assert_eq!(trade.pnl, 498.0); // (55-50)*100 - 2 commission
assert_eq!(trade.side, Side::Buy);
}
#[test]
fn test_partial_fills() {
let mut tracker = TradeTracker::new();
// Buy 100 shares at $50
let buy_fill = Fill {
timestamp: Utc::now(),
price: 50.0,
quantity: 100.0,
commission: 1.0,
};
tracker.process_fill("AAPL", Side::Buy, &buy_fill);
// Sell 60 shares at $55
let sell_fill1 = Fill {
timestamp: Utc::now(),
price: 55.0,
quantity: 60.0,
commission: 0.6,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill1);
// Check we have one completed trade and remaining position
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 1);
assert_eq!(trades[0].quantity, 60.0);
assert_eq!(tracker.get_net_position("AAPL"), 40.0);
// Sell remaining 40 shares at $52
let sell_fill2 = Fill {
timestamp: Utc::now(),
price: 52.0,
quantity: 40.0,
commission: 0.4,
};
tracker.process_fill("AAPL", Side::Sell, &sell_fill2);
// Now we should have 2 completed trades and no position
let trades = tracker.get_completed_trades();
assert_eq!(trades.len(), 2);
assert_eq!(tracker.get_net_position("AAPL"), 0.0);
}
}

View file

@ -0,0 +1,282 @@
use crate::{ExecutionHandler, FillSimulator, Order, ExecutionResult, OrderStatus, Fill, OrderBookSnapshot, OrderType, Side, MarketMicrostructure};
use crate::analytics::{MarketImpactModel, ImpactModelType};
use chrono::Utc;
use parking_lot::Mutex;
use std::collections::HashMap;
// Simulated execution for backtest and paper trading
pub struct SimulatedExecution {
fill_simulator: Box<dyn FillSimulator>,
pending_orders: Mutex<HashMap<String, Order>>,
}
impl SimulatedExecution {
pub fn new(fill_simulator: Box<dyn FillSimulator>) -> Self {
Self {
fill_simulator,
pending_orders: Mutex::new(HashMap::new()),
}
}
pub fn check_pending_orders(&self, orderbook: &OrderBookSnapshot) -> Vec<ExecutionResult> {
let mut results = Vec::new();
let mut pending = self.pending_orders.lock();
pending.retain(|order_id, order| {
if let Some(fill) = self.fill_simulator.simulate_fill(order, orderbook) {
results.push(ExecutionResult {
order_id: order_id.clone(),
status: OrderStatus::Filled,
fills: vec![fill],
});
false // Remove from pending
} else {
true // Keep in pending
}
});
results
}
}
#[async_trait::async_trait]
impl ExecutionHandler for SimulatedExecution {
async fn execute_order(&mut self, order: Order) -> Result<ExecutionResult, String> {
// For market orders, execute immediately
// For limit orders, add to pending
match &order.order_type {
OrderType::Market => {
// In simulation, market orders always fill
// The orchestrator will provide the orderbook for realistic fills
Ok(ExecutionResult {
order_id: order.id.clone(),
status: OrderStatus::Pending,
fills: vec![],
})
}
OrderType::Limit { .. } => {
self.pending_orders.lock().insert(order.id.clone(), order.clone());
Ok(ExecutionResult {
order_id: order.id,
status: OrderStatus::Accepted,
fills: vec![],
})
}
_ => Err("Order type not yet implemented".to_string()),
}
}
fn get_fill_simulator(&self) -> Option<&dyn FillSimulator> {
Some(&*self.fill_simulator)
}
}
// Backtest fill simulator - uses historical data
pub struct BacktestFillSimulator {
slippage_model: SlippageModel,
impact_model: MarketImpactModel,
microstructure_cache: Mutex<HashMap<String, MarketMicrostructure>>,
}
impl BacktestFillSimulator {
pub fn new() -> Self {
Self {
slippage_model: SlippageModel::default(),
impact_model: MarketImpactModel::new(ImpactModelType::SquareRoot),
microstructure_cache: Mutex::new(HashMap::new()),
}
}
pub fn with_impact_model(mut self, model_type: ImpactModelType) -> Self {
self.impact_model = MarketImpactModel::new(model_type);
self
}
pub fn set_microstructure(&self, symbol: String, microstructure: MarketMicrostructure) {
self.microstructure_cache.lock().insert(symbol, microstructure);
}
}
impl FillSimulator for BacktestFillSimulator {
fn simulate_fill(&self, order: &Order, orderbook: &OrderBookSnapshot) -> Option<Fill> {
match &order.order_type {
OrderType::Market => {
// Get market microstructure if available
let microstructure_guard = self.microstructure_cache.lock();
let maybe_microstructure = microstructure_guard.get(&order.symbol);
// Calculate price with market impact
let (price, _impact) = if let Some(microstructure) = maybe_microstructure {
// Use sophisticated market impact model
let impact_estimate = self.impact_model.estimate_impact(
order.quantity,
order.side,
microstructure,
match order.side {
Side::Buy => &orderbook.asks,
Side::Sell => &orderbook.bids,
},
Utc::now(),
);
let base_price = match order.side {
Side::Buy => orderbook.asks.first()?.price,
Side::Sell => orderbook.bids.first()?.price,
};
let impact_price = match order.side {
Side::Buy => base_price * (1.0 + impact_estimate.total_impact / 10000.0),
Side::Sell => base_price * (1.0 - impact_estimate.total_impact / 10000.0),
};
(impact_price, impact_estimate.total_impact)
} else {
// Fallback to simple slippage model
match order.side {
Side::Buy => {
let base_price = orderbook.asks.first()?.price;
let slippage = self.slippage_model.calculate_slippage(order.quantity, &orderbook.asks);
(base_price + slippage, slippage * 10000.0 / base_price)
}
Side::Sell => {
let base_price = orderbook.bids.first()?.price;
let slippage = self.slippage_model.calculate_slippage(order.quantity, &orderbook.bids);
(base_price - slippage, slippage * 10000.0 / base_price)
}
}
};
// Calculate realistic commission
let commission_rate = 0.0005; // 5 bps for institutional
let min_commission = 1.0;
let commission = (order.quantity * price * commission_rate).max(min_commission);
Some(Fill {
timestamp: Utc::now(), // Will be overridden by backtest engine
price,
quantity: order.quantity,
commission,
})
}
OrderType::Limit { price: limit_price } => {
// Check if limit can be filled
match order.side {
Side::Buy => {
if orderbook.asks.first()?.price <= *limit_price {
Some(Fill {
timestamp: Utc::now(),
price: *limit_price,
quantity: order.quantity,
commission: order.quantity * limit_price * 0.001,
})
} else {
None
}
}
Side::Sell => {
if orderbook.bids.first()?.price >= *limit_price {
Some(Fill {
timestamp: Utc::now(),
price: *limit_price,
quantity: order.quantity,
commission: order.quantity * limit_price * 0.001,
})
} else {
None
}
}
}
}
_ => None,
}
}
}
// Paper trading fill simulator - uses real order book
pub struct PaperFillSimulator {
use_real_orderbook: bool,
add_latency_ms: u64,
}
impl PaperFillSimulator {
pub fn new() -> Self {
Self {
use_real_orderbook: true,
add_latency_ms: 100, // Simulate 100ms latency
}
}
}
impl FillSimulator for PaperFillSimulator {
fn simulate_fill(&self, order: &Order, orderbook: &OrderBookSnapshot) -> Option<Fill> {
// Similar to backtest but with more realistic modeling
// Consider actual order book depth
// Add realistic latency simulation
// Respect position size limits based on actual liquidity
// For now, similar implementation to backtest
BacktestFillSimulator::new().simulate_fill(order, orderbook)
}
}
// Real broker execution for live trading
pub struct BrokerExecution {
broker: String,
account_id: String,
// In real implementation, would have broker API client
}
impl BrokerExecution {
pub fn new(broker: String, account_id: String) -> Self {
Self {
broker,
account_id,
}
}
}
#[async_trait::async_trait]
impl ExecutionHandler for BrokerExecution {
async fn execute_order(&mut self, order: Order) -> Result<ExecutionResult, String> {
// In real implementation, would:
// 1. Connect to broker API
// 2. Submit order
// 3. Handle broker responses
// 4. Track order status
// Placeholder for now
Ok(ExecutionResult {
order_id: order.id,
status: OrderStatus::Pending,
fills: vec![],
})
}
fn get_fill_simulator(&self) -> Option<&dyn FillSimulator> {
None // Real broker doesn't simulate
}
}
// Slippage model for realistic fills
#[derive(Default)]
struct SlippageModel {
base_slippage_bps: f64,
impact_coefficient: f64,
}
impl SlippageModel {
fn calculate_slippage(&self, quantity: f64, levels: &[crate::PriceLevel]) -> f64 {
// Simple linear impact model
// In reality would use square-root or more sophisticated model
let total_liquidity: f64 = levels.iter().map(|l| l.size).sum();
let participation_rate = quantity / total_liquidity.max(1.0);
let spread = if levels.len() >= 2 {
(levels[1].price - levels[0].price).abs()
} else {
levels[0].price * 0.0001 // 1 bps if only one level
};
spread * participation_rate * self.impact_coefficient
}
}

View file

@ -0,0 +1,152 @@
use crate::{MarketDataSource, MarketUpdate};
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use std::collections::VecDeque;
use super::mock_data_generator::MockDataGenerator;
// Historical data source for backtesting
pub struct HistoricalDataSource {
data_queue: Mutex<VecDeque<MarketUpdate>>,
current_position: Mutex<usize>,
}
impl HistoricalDataSource {
pub fn new() -> Self {
Self {
data_queue: Mutex::new(VecDeque::new()),
current_position: Mutex::new(0),
}
}
// This would be called by the orchestrator to load data
pub fn load_data(&self, data: Vec<MarketUpdate>) {
eprintln!("HistoricalDataSource::load_data called with {} items", data.len());
// Log first few items
for (i, update) in data.iter().take(3).enumerate() {
eprintln!(" Item {}: symbol={}, time={}", i, update.symbol, update.timestamp);
}
let mut queue = self.data_queue.lock();
queue.clear();
queue.extend(data);
*self.current_position.lock() = 0;
eprintln!("Data loaded successfully. Queue size: {}", queue.len());
}
pub fn data_len(&self) -> usize {
self.data_queue.lock().len()
}
// Generate mock data for testing
pub fn generate_mock_data(
&self,
symbol: String,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
seed: Option<u64>
) {
let mut generator = MockDataGenerator::new(seed.unwrap_or(42));
let data = generator.generate_mixed_data(symbol, start_time, end_time);
self.load_data(data);
}
}
#[async_trait::async_trait]
impl MarketDataSource for HistoricalDataSource {
async fn get_next_update(&mut self) -> Option<MarketUpdate> {
let queue = self.data_queue.lock();
let mut position = self.current_position.lock();
if *position < queue.len() {
let update = queue[*position].clone();
*position += 1;
Some(update)
} else {
None
}
}
fn seek_to_time(&mut self, timestamp: DateTime<Utc>) -> Result<(), String> {
let queue = self.data_queue.lock();
let mut position = self.current_position.lock();
eprintln!("HistoricalDataSource::seek_to_time called");
eprintln!(" Target time: {}", timestamp);
eprintln!(" Queue size: {}", queue.len());
if queue.is_empty() {
eprintln!(" WARNING: Queue is empty!");
return Ok(());
}
eprintln!(" First item time: {}", queue.front().map(|u| u.timestamp.to_string()).unwrap_or("N/A".to_string()));
eprintln!(" Last item time: {}", queue.back().map(|u| u.timestamp.to_string()).unwrap_or("N/A".to_string()));
// Binary search for the timestamp
match queue.binary_search_by_key(&timestamp, |update| update.timestamp) {
Ok(pos) => {
*position = pos;
eprintln!(" Found exact match at position {}", pos);
Ok(())
}
Err(pos) => {
// Position where it would be inserted
*position = pos;
eprintln!(" No exact match, would insert at position {}", pos);
Ok(())
}
}
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
// Live data source for paper and live trading
pub struct LiveDataSource {
// Channel to receive data from the orchestrator
data_receiver: tokio::sync::Mutex<Option<tokio::sync::mpsc::Receiver<MarketUpdate>>>,
}
impl LiveDataSource {
pub fn new() -> Self {
Self {
data_receiver: tokio::sync::Mutex::new(None),
}
}
pub async fn set_receiver(&self, receiver: tokio::sync::mpsc::Receiver<MarketUpdate>) {
*self.data_receiver.lock().await = Some(receiver);
}
}
#[async_trait::async_trait]
impl MarketDataSource for LiveDataSource {
async fn get_next_update(&mut self) -> Option<MarketUpdate> {
let mut receiver_guard = self.data_receiver.lock().await;
if let Some(receiver) = receiver_guard.as_mut() {
receiver.recv().await
} else {
None
}
}
fn seek_to_time(&mut self, _timestamp: DateTime<Utc>) -> Result<(), String> {
Err("Cannot seek in live data source".to_string())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}

View file

@ -0,0 +1,476 @@
use crate::{MarketMicrostructure, PriceLevel, Quote, Trade, Bar, Side};
use chrono::{DateTime, Utc, Duration, Timelike};
use rand::prelude::*;
use rand_distr::{Normal, Pareto, Beta};
pub struct OrderBookReconstructor {
tick_size: f64,
lot_size: f64,
num_levels: usize,
spread_model: SpreadModel,
depth_model: DepthModel,
}
#[derive(Clone)]
pub enum SpreadModel {
Fixed { spread_ticks: u32 },
Dynamic { base_bps: f64, volatility_factor: f64 },
InformedTrader { base_bps: f64, information_decay: f64 },
}
#[derive(Clone)]
pub enum DepthModel {
Linear { base_size: f64, decay_rate: f64 },
Exponential { base_size: f64, decay_factor: f64 },
PowerLaw { alpha: f64, x_min: f64 },
}
impl OrderBookReconstructor {
pub fn new(tick_size: f64, lot_size: f64) -> Self {
Self {
tick_size,
lot_size,
num_levels: 10,
spread_model: SpreadModel::Dynamic {
base_bps: 2.0,
volatility_factor: 1.5
},
depth_model: DepthModel::Exponential {
base_size: 1000.0,
decay_factor: 0.7
},
}
}
pub fn reconstruct_from_trades_and_quotes(
&self,
trades: &[(DateTime<Utc>, Trade)],
quotes: &[(DateTime<Utc>, Quote)],
timestamp: DateTime<Utc>,
) -> (Vec<PriceLevel>, Vec<PriceLevel>) {
// Find the most recent quote before timestamp
let recent_quote = quotes.iter()
.filter(|(t, _)| *t <= timestamp)
.last()
.map(|(_, q)| q);
// Find recent trades to estimate market conditions
let recent_trades: Vec<_> = trades.iter()
.filter(|(t, _)| {
let age = timestamp - *t;
age < Duration::minutes(5) && age >= Duration::zero()
})
.map(|(_, t)| t)
.collect();
if let Some(quote) = recent_quote {
// Start with actual quote
self.build_full_book(quote, &recent_trades, timestamp)
} else if !recent_trades.is_empty() {
// Reconstruct from trades only
self.reconstruct_from_trades_only(&recent_trades, timestamp)
} else {
// No data - return empty book
(vec![], vec![])
}
}
fn build_full_book(
&self,
top_quote: &Quote,
recent_trades: &[&Trade],
_timestamp: DateTime<Utc>,
) -> (Vec<PriceLevel>, Vec<PriceLevel>) {
let mut bids = Vec::with_capacity(self.num_levels);
let mut asks = Vec::with_capacity(self.num_levels);
// Add top of book
bids.push(PriceLevel {
price: top_quote.bid,
size: top_quote.bid_size,
order_count: Some(self.estimate_order_count(top_quote.bid_size)),
});
asks.push(PriceLevel {
price: top_quote.ask,
size: top_quote.ask_size,
order_count: Some(self.estimate_order_count(top_quote.ask_size)),
});
// Calculate spread and volatility from recent trades
let (_spread_bps, _volatility) = self.estimate_market_conditions(recent_trades, top_quote);
// Build deeper levels
for i in 1..self.num_levels {
// Bid levels
let bid_price = top_quote.bid - (i as f64 * self.tick_size);
let bid_size = self.calculate_level_size(i, top_quote.bid_size, &self.depth_model);
bids.push(PriceLevel {
price: bid_price,
size: bid_size,
order_count: Some(self.estimate_order_count(bid_size)),
});
// Ask levels
let ask_price = top_quote.ask + (i as f64 * self.tick_size);
let ask_size = self.calculate_level_size(i, top_quote.ask_size, &self.depth_model);
asks.push(PriceLevel {
price: ask_price,
size: ask_size,
order_count: Some(self.estimate_order_count(ask_size)),
});
}
(bids, asks)
}
fn reconstruct_from_trades_only(
&self,
recent_trades: &[&Trade],
_timestamp: DateTime<Utc>,
) -> (Vec<PriceLevel>, Vec<PriceLevel>) {
if recent_trades.is_empty() {
return (vec![], vec![]);
}
// Estimate mid price from trades
let prices: Vec<f64> = recent_trades.iter().map(|t| t.price).collect();
let mid_price = prices.iter().sum::<f64>() / prices.len() as f64;
// Estimate spread from trade price variance
let variance = prices.iter()
.map(|p| (p - mid_price).powi(2))
.sum::<f64>() / prices.len() as f64;
let estimated_spread = variance.sqrt() * 2.0; // Rough approximation
// Build synthetic book
let bid_price = (mid_price - estimated_spread / 2.0 / self.tick_size).round() * self.tick_size;
let ask_price = (mid_price + estimated_spread / 2.0 / self.tick_size).round() * self.tick_size;
// Estimate sizes from trade volumes
let avg_trade_size = recent_trades.iter()
.map(|t| t.size)
.sum::<f64>() / recent_trades.len() as f64;
let mut bids = Vec::with_capacity(self.num_levels);
let mut asks = Vec::with_capacity(self.num_levels);
for i in 0..self.num_levels {
let level_size = avg_trade_size * 10.0 / (i + 1) as f64; // Decay with depth
bids.push(PriceLevel {
price: bid_price - (i as f64 * self.tick_size),
size: level_size,
order_count: Some(self.estimate_order_count(level_size)),
});
asks.push(PriceLevel {
price: ask_price + (i as f64 * self.tick_size),
size: level_size,
order_count: Some(self.estimate_order_count(level_size)),
});
}
(bids, asks)
}
fn calculate_level_size(&self, level: usize, _top_size: f64, model: &DepthModel) -> f64 {
let size = match model {
DepthModel::Linear { base_size, decay_rate } => {
base_size - (level as f64 * decay_rate)
}
DepthModel::Exponential { base_size, decay_factor } => {
base_size * decay_factor.powi(level as i32)
}
DepthModel::PowerLaw { alpha, x_min } => {
x_min * ((level + 1) as f64).powf(-alpha)
}
};
// Round to lot size and ensure positive
((size / self.lot_size).round() * self.lot_size).max(self.lot_size)
}
fn estimate_order_count(&self, size: f64) -> u32 {
// Estimate based on typical order size distribution
let avg_order_size = 100.0;
let base_count = (size / avg_order_size).ceil() as u32;
// Add some randomness
let mut rng = thread_rng();
let variation = rng.gen_range(0.8..1.2);
((base_count as f64 * variation) as u32).max(1)
}
fn estimate_market_conditions(
&self,
recent_trades: &[&Trade],
quote: &Quote,
) -> (f64, f64) {
if recent_trades.is_empty() {
let spread_bps = ((quote.ask - quote.bid) / quote.bid) * 10000.0;
return (spread_bps, 0.02); // Default 2% volatility
}
// Calculate spread in bps
let mid_price = (quote.bid + quote.ask) / 2.0;
let spread_bps = ((quote.ask - quote.bid) / mid_price) * 10000.0;
// Estimate volatility from trade prices
let prices: Vec<f64> = recent_trades.iter().map(|t| t.price).collect();
let returns: Vec<f64> = prices.windows(2)
.map(|w| (w[1] / w[0]).ln())
.collect();
let volatility = if !returns.is_empty() {
let mean_return = returns.iter().sum::<f64>() / returns.len() as f64;
let variance = returns.iter()
.map(|r| (r - mean_return).powi(2))
.sum::<f64>() / returns.len() as f64;
variance.sqrt() * (252.0_f64).sqrt() // Annualize
} else {
0.02 // Default 2%
};
(spread_bps, volatility)
}
}
// Market data synthesizer for generating realistic data
pub struct MarketDataSynthesizer {
base_price: f64,
tick_size: f64,
base_spread_bps: f64,
volatility: f64,
mean_reversion_speed: f64,
jump_intensity: f64,
jump_size_dist: Normal<f64>,
volume_dist: Pareto<f64>,
intraday_pattern: Vec<f64>,
}
impl MarketDataSynthesizer {
pub fn new(symbol_params: &MarketMicrostructure) -> Self {
let jump_size_dist = Normal::new(0.0, symbol_params.volatility * 0.1).unwrap();
let volume_dist = Pareto::new(1.0, 1.5).unwrap();
Self {
base_price: 100.0, // Will be updated with actual price
tick_size: symbol_params.tick_size,
base_spread_bps: symbol_params.avg_spread_bps,
volatility: symbol_params.volatility,
mean_reversion_speed: 0.1,
jump_intensity: 0.05, // 5% chance of jump per time step
jump_size_dist,
volume_dist,
intraday_pattern: symbol_params.intraday_volume_profile.clone(),
}
}
pub fn generate_quote_sequence(
&mut self,
start_price: f64,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
interval_ms: i64,
) -> Vec<(DateTime<Utc>, Quote)> {
self.base_price = start_price;
let mut quotes = Vec::new();
let mut current_time = start_time;
let mut mid_price = start_price;
let mut spread_factor;
let mut rng = thread_rng();
while current_time <= end_time {
// Generate price movement
let dt = interval_ms as f64 / 1000.0 / 86400.0; // Convert to days
// Ornstein-Uhlenbeck process with jumps
let drift = -self.mean_reversion_speed * (mid_price / self.base_price - 1.0).ln();
let diffusion = self.volatility * (dt.sqrt()) * rng.gen::<f64>();
// Add jump component
let jump = if rng.gen::<f64>() < self.jump_intensity * dt {
mid_price * self.jump_size_dist.sample(&mut rng)
} else {
0.0
};
mid_price *= 1.0 + drift * dt + diffusion + jump;
mid_price = (mid_price / self.tick_size).round() * self.tick_size;
// Dynamic spread based on volatility and time of day
let hour_index = current_time.hour() as usize;
let volume_factor = if hour_index < self.intraday_pattern.len() {
self.intraday_pattern[hour_index]
} else {
0.04 // Default 4% of daily volume per hour
};
// Wider spreads during low volume periods
spread_factor = 1.0 / volume_factor.sqrt();
let spread_bps = self.base_spread_bps * spread_factor;
let half_spread = mid_price * spread_bps / 20000.0;
// Generate bid/ask
let bid = ((mid_price - half_spread) / self.tick_size).floor() * self.tick_size;
let ask = ((mid_price + half_spread) / self.tick_size).ceil() * self.tick_size;
// Generate sizes with correlation to spread
let size_multiplier = 1.0 / spread_factor; // Tighter spread = more size
let bid_size = (self.volume_dist.sample(&mut rng) * 1000.0 * size_multiplier).round();
let ask_size = (self.volume_dist.sample(&mut rng) * 1000.0 * size_multiplier).round();
quotes.push((current_time, Quote {
bid,
ask,
bid_size,
ask_size,
}));
current_time = current_time + Duration::milliseconds(interval_ms);
}
quotes
}
pub fn generate_trade_sequence(
&mut self,
quotes: &[(DateTime<Utc>, Quote)],
trade_intensity: f64,
) -> Vec<(DateTime<Utc>, Trade)> {
let mut trades = Vec::new();
let mut rng = thread_rng();
let beta_dist = Beta::new(2.0, 5.0).unwrap(); // Skewed towards smaller trades
for (time, quote) in quotes {
// Poisson process for trade arrivals
let num_trades = rng.gen_range(0..((trade_intensity * 10.0) as u32));
for i in 0..num_trades {
// Determine trade side (slight bias based on spread)
let spread_ratio = (quote.ask - quote.bid) / quote.bid;
let buy_prob = 0.5 - spread_ratio * 10.0; // More sells when spread is wide
let side = if rng.gen::<f64>() < buy_prob {
Side::Buy
} else {
Side::Sell
};
// Trade price (sometimes inside spread for large trades)
let price = match side {
Side::Buy => {
if rng.gen::<f64>() < 0.9 {
quote.ask // Take liquidity
} else {
// Provide liquidity (inside spread)
quote.bid + (quote.ask - quote.bid) * rng.gen::<f64>()
}
}
Side::Sell => {
if rng.gen::<f64>() < 0.9 {
quote.bid // Take liquidity
} else {
// Provide liquidity (inside spread)
quote.bid + (quote.ask - quote.bid) * rng.gen::<f64>()
}
}
};
// Trade size (power law distribution)
let size_percentile = beta_dist.sample(&mut rng);
let base_size = match side {
Side::Buy => quote.ask_size,
Side::Sell => quote.bid_size,
};
let size = (base_size * size_percentile * 0.1).round().max(1.0);
// Add small time offset for multiple trades
let trade_time = *time + Duration::milliseconds(i as i64 * 100);
trades.push((trade_time, Trade {
price,
size,
side,
}));
}
}
trades.sort_by_key(|(t, _)| *t);
trades
}
pub fn aggregate_to_bars(
&self,
trades: &[(DateTime<Utc>, Trade)],
bar_duration: Duration,
) -> Vec<(DateTime<Utc>, Bar)> {
if trades.is_empty() {
return Vec::new();
}
let mut bars = Vec::new();
let mut current_bar_start = trades[0].0;
let mut current_bar_end = current_bar_start + bar_duration;
let mut open = 0.0;
let mut high = 0.0;
let mut low = f64::MAX;
let mut close = 0.0;
let mut volume = 0.0;
let mut vwap_numerator = 0.0;
let mut first_trade = true;
for (time, trade) in trades {
// Check if we need to start a new bar
while *time >= current_bar_end {
if volume > 0.0 {
bars.push((current_bar_start, Bar {
open,
high,
low,
close,
volume,
vwap: Some(vwap_numerator / volume),
}));
}
// Reset for new bar
current_bar_start = current_bar_end;
current_bar_end = current_bar_start + bar_duration;
open = 0.0;
high = 0.0;
low = f64::MAX;
close = 0.0;
volume = 0.0;
vwap_numerator = 0.0;
first_trade = true;
}
// Update current bar
if first_trade {
open = trade.price;
first_trade = false;
}
high = high.max(trade.price);
low = low.min(trade.price);
close = trade.price;
volume += trade.size;
vwap_numerator += trade.price * trade.size;
}
// Add final bar if it has data
if volume > 0.0 {
bars.push((current_bar_start, Bar {
open,
high,
low,
close,
volume,
vwap: Some(vwap_numerator / volume),
}));
}
bars
}
}

View file

@ -0,0 +1,229 @@
use crate::{MarketUpdate, MarketDataType, Quote, Trade, Bar, Side};
use chrono::{DateTime, Utc, Duration};
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use rand_distr::{Normal, Distribution};
pub struct MockDataGenerator {
rng: StdRng,
base_price: f64,
volatility: f64,
spread_bps: f64,
volume_mean: f64,
volume_std: f64,
}
impl MockDataGenerator {
pub fn new(seed: u64) -> Self {
Self {
rng: StdRng::seed_from_u64(seed),
base_price: 100.0,
volatility: 0.02, // 2% daily volatility
spread_bps: 5.0, // 5 basis points spread
volume_mean: 1_000_000.0,
volume_std: 200_000.0,
}
}
pub fn with_params(seed: u64, base_price: f64, volatility: f64, spread_bps: f64) -> Self {
Self {
rng: StdRng::seed_from_u64(seed),
base_price,
volatility,
spread_bps,
volume_mean: 1_000_000.0,
volume_std: 200_000.0,
}
}
pub fn generate_quotes(
&mut self,
symbol: String,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
interval_ms: i64,
) -> Vec<MarketUpdate> {
let mut updates = Vec::new();
let mut current_time = start_time;
let mut price = self.base_price;
let price_dist = Normal::new(0.0, self.volatility).unwrap();
let volume_dist = Normal::new(self.volume_mean, self.volume_std).unwrap();
while current_time <= end_time {
// Generate price movement
let return_pct = price_dist.sample(&mut self.rng) / 100.0;
price *= 1.0 + return_pct;
price = price.max(0.01); // Ensure positive price
// Calculate bid/ask
let half_spread = price * self.spread_bps / 20000.0;
let bid = price - half_spread;
let ask = price + half_spread;
// Generate volume
let volume = volume_dist.sample(&mut self.rng).max(0.0) as u32;
updates.push(MarketUpdate {
symbol: symbol.clone(),
timestamp: current_time,
data: MarketDataType::Quote(Quote {
bid,
ask,
bid_size: (volume / 10) as f64,
ask_size: (volume / 10) as f64,
}),
});
current_time = current_time + Duration::milliseconds(interval_ms);
}
updates
}
pub fn generate_trades(
&mut self,
symbol: String,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
trades_per_minute: u32,
) -> Vec<MarketUpdate> {
let mut updates = Vec::new();
let mut current_time = start_time;
let mut price = self.base_price;
let price_dist = Normal::new(0.0, self.volatility / 100.0).unwrap();
let volume_dist = Normal::new(100.0, 50.0).unwrap();
let interval_ms = 60_000 / trades_per_minute as i64;
while current_time <= end_time {
// Generate price movement
let return_pct = price_dist.sample(&mut self.rng);
price *= 1.0 + return_pct;
price = price.max(0.01);
// Generate trade size
let raw_size: f64 = volume_dist.sample(&mut self.rng);
let size = raw_size.max(1.0) as u32;
// Random buy/sell
let is_buy = self.rng.gen_bool(0.5);
updates.push(MarketUpdate {
symbol: symbol.clone(),
timestamp: current_time,
data: MarketDataType::Trade(Trade {
price,
size: size as f64,
side: if is_buy { Side::Buy } else { Side::Sell },
}),
});
current_time = current_time + Duration::milliseconds(interval_ms);
}
updates
}
pub fn generate_bars(
&mut self,
symbol: String,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
timeframe: &str,
) -> Vec<MarketUpdate> {
let mut updates = Vec::new();
let mut current_time = start_time;
let mut price = self.base_price;
let interval = match timeframe {
"1m" => Duration::minutes(1),
"5m" => Duration::minutes(5),
"15m" => Duration::minutes(15),
"1h" => Duration::hours(1),
"1d" => Duration::days(1),
_ => Duration::minutes(1),
};
let price_dist = Normal::new(0.0, self.volatility).unwrap();
let volume_dist = Normal::new(self.volume_mean, self.volume_std).unwrap();
while current_time <= end_time {
// Generate OHLC
let open = price;
let mut high = open;
let mut low = open;
// Simulate intrabar movements
for _ in 0..4 {
let move_pct = price_dist.sample(&mut self.rng) / 100.0;
price *= 1.0 + move_pct;
price = price.max(0.01);
high = high.max(price);
low = low.min(price);
}
let close = price;
let volume = volume_dist.sample(&mut self.rng).max(0.0) as u64;
updates.push(MarketUpdate {
symbol: symbol.clone(),
timestamp: current_time,
data: MarketDataType::Bar(Bar {
open,
high,
low,
close,
volume: volume as f64,
vwap: Some((open + high + low + close) / 4.0),
}),
});
current_time = current_time + interval;
}
updates
}
pub fn generate_mixed_data(
&mut self,
symbol: String,
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
) -> Vec<MarketUpdate> {
let mut all_updates = Vec::new();
// Generate quotes every 100ms
let quotes = self.generate_quotes(
symbol.clone(),
start_time,
end_time,
100
);
all_updates.extend(quotes);
// Generate trades
let trades = self.generate_trades(
symbol.clone(),
start_time,
end_time,
20 // 20 trades per minute
);
all_updates.extend(trades);
// Generate 1-minute bars
let bars = self.generate_bars(
symbol,
start_time,
end_time,
"1m"
);
all_updates.extend(bars);
// Sort by timestamp
all_updates.sort_by_key(|update| update.timestamp);
all_updates
}
}

View file

@ -0,0 +1,51 @@
pub mod time_providers;
pub mod market_data_sources;
pub mod execution_handlers;
pub mod market_microstructure;
pub mod mock_data_generator;
use crate::{MarketDataSource, ExecutionHandler, TimeProvider, TradingMode};
// Factory functions to create appropriate implementations based on mode
pub fn create_market_data_source(mode: &TradingMode) -> Box<dyn MarketDataSource> {
match mode {
TradingMode::Backtest { .. } => {
Box::new(market_data_sources::HistoricalDataSource::new())
}
TradingMode::Paper { .. } | TradingMode::Live { .. } => {
Box::new(market_data_sources::LiveDataSource::new())
}
}
}
pub fn create_execution_handler(mode: &TradingMode) -> Box<dyn ExecutionHandler> {
match mode {
TradingMode::Backtest { .. } => {
Box::new(execution_handlers::SimulatedExecution::new(
Box::new(execution_handlers::BacktestFillSimulator::new())
))
}
TradingMode::Paper { .. } => {
Box::new(execution_handlers::SimulatedExecution::new(
Box::new(execution_handlers::PaperFillSimulator::new())
))
}
TradingMode::Live { broker, account_id } => {
Box::new(execution_handlers::BrokerExecution::new(
broker.clone(),
account_id.clone()
))
}
}
}
pub fn create_time_provider(mode: &TradingMode) -> Box<dyn TimeProvider> {
match mode {
TradingMode::Backtest { start_time, .. } => {
Box::new(time_providers::SimulatedTime::new(*start_time))
}
TradingMode::Paper { .. } | TradingMode::Live { .. } => {
Box::new(time_providers::SystemTime::new())
}
}
}

View file

@ -0,0 +1,74 @@
use crate::TimeProvider;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use std::sync::Arc;
// Real-time provider for paper and live trading
pub struct SystemTime;
impl SystemTime {
pub fn new() -> Self {
Self
}
}
impl TimeProvider for SystemTime {
fn now(&self) -> DateTime<Utc> {
Utc::now()
}
fn sleep_until(&self, target: DateTime<Utc>) -> Result<(), String> {
let now = Utc::now();
if target > now {
let duration = (target - now).to_std()
.map_err(|e| format!("Invalid duration: {}", e))?;
std::thread::sleep(duration);
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
// Simulated time for backtesting
pub struct SimulatedTime {
current_time: Arc<Mutex<DateTime<Utc>>>,
}
impl SimulatedTime {
pub fn new(start_time: DateTime<Utc>) -> Self {
Self {
current_time: Arc::new(Mutex::new(start_time)),
}
}
pub fn advance_to(&self, new_time: DateTime<Utc>) {
let mut current = self.current_time.lock();
if new_time > *current {
*current = new_time;
}
}
pub fn advance_by(&self, duration: chrono::Duration) {
let mut current = self.current_time.lock();
*current = *current + duration;
}
}
impl TimeProvider for SimulatedTime {
fn now(&self) -> DateTime<Utc> {
*self.current_time.lock()
}
fn sleep_until(&self, _target: DateTime<Utc>) -> Result<(), String> {
// In backtest mode, we don't actually sleep
// Time is controlled by the backtest engine
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}

View file

@ -0,0 +1,65 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use super::{Order, Fill, MarketUpdate, Position};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EventType {
MarketData(MarketUpdate),
OrderSubmitted(Order),
OrderFilled { order_id: String, fill: Fill },
OrderCancelled { order_id: String },
OrderRejected { order_id: String, reason: String },
PositionUpdate(Position),
RiskAlert { message: String, severity: RiskSeverity },
SystemStatus { status: SystemStatus },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskSeverity {
Info,
Warning,
Critical,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SystemStatus {
Starting,
Running,
Paused,
Stopping,
Error(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Event {
pub id: String,
pub timestamp: DateTime<Utc>,
pub event_type: EventType,
}
impl Event {
pub fn new(event_type: EventType) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
timestamp: Utc::now(),
event_type,
}
}
pub fn market_data(data: MarketUpdate) -> Self {
Self::new(EventType::MarketData(data))
}
pub fn order_submitted(order: Order) -> Self {
Self::new(EventType::OrderSubmitted(order))
}
pub fn order_filled(order_id: String, fill: Fill) -> Self {
Self::new(EventType::OrderFilled { order_id, fill })
}
}
pub trait EventHandler: Send + Sync {
fn handle_event(&self, event: &Event);
fn event_types(&self) -> Vec<String>; // Which event types this handler is interested in
}

View file

@ -0,0 +1,44 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Quote {
pub bid: f64,
pub ask: f64,
pub bid_size: f64,
pub ask_size: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bar {
pub open: f64,
pub high: f64,
pub low: f64,
pub close: f64,
pub volume: f64,
pub vwap: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
pub price: f64,
pub size: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MarketDataType {
Quote(Quote),
Bar(Bar),
Trade(Trade),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketUpdate {
pub symbol: String,
pub timestamp: DateTime<Utc>,
pub data: MarketDataType,
}
// Type alias for compatibility
pub type MarketData = MarketUpdate;

View file

@ -0,0 +1,10 @@
pub mod market;
pub mod orders;
pub mod positions;
pub mod events;
// Re-export commonly used types
pub use market::{Quote, Bar, Trade, MarketUpdate, MarketDataType};
pub use orders::{Order, OrderType, OrderStatus, TimeInForce, Side, Fill};
pub use positions::{Position, PositionUpdate};
pub use events::{Event, EventType, EventHandler};

View file

@ -0,0 +1,104 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use std::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Side {
Buy,
Sell,
}
impl fmt::Display for Side {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Side::Buy => write!(f, "buy"),
Side::Sell => write!(f, "sell"),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OrderType {
Market,
Limit,
Stop,
StopLimit,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OrderStatus {
Pending,
Submitted,
PartiallyFilled,
Filled,
Cancelled,
Rejected,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TimeInForce {
Day,
GTC, // Good Till Cancelled
IOC, // Immediate or Cancel
FOK, // Fill or Kill
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Order {
pub id: String,
pub symbol: String,
pub side: Side,
pub quantity: f64,
pub order_type: OrderType,
pub limit_price: Option<f64>,
pub stop_price: Option<f64>,
pub time_in_force: TimeInForce,
pub status: OrderStatus,
pub submitted_at: Option<DateTime<Utc>>,
pub filled_quantity: f64,
pub average_fill_price: Option<f64>,
}
impl Order {
pub fn new_market_order(symbol: String, side: Side, quantity: f64) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
symbol,
side,
quantity,
order_type: OrderType::Market,
limit_price: None,
stop_price: None,
time_in_force: TimeInForce::Day,
status: OrderStatus::Pending,
submitted_at: None,
filled_quantity: 0.0,
average_fill_price: None,
}
}
pub fn new_limit_order(symbol: String, side: Side, quantity: f64, limit_price: f64) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
symbol,
side,
quantity,
order_type: OrderType::Limit,
limit_price: Some(limit_price),
stop_price: None,
time_in_force: TimeInForce::Day,
status: OrderStatus::Pending,
submitted_at: None,
filled_quantity: 0.0,
average_fill_price: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fill {
pub timestamp: DateTime<Utc>,
pub price: f64,
pub quantity: f64,
pub commission: f64,
}

View file

@ -0,0 +1,59 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Position {
pub symbol: String,
pub quantity: f64,
pub average_price: f64,
pub realized_pnl: f64,
pub unrealized_pnl: f64,
pub last_update: DateTime<Utc>,
}
impl Position {
pub fn new(symbol: String) -> Self {
Self {
symbol,
quantity: 0.0,
average_price: 0.0,
realized_pnl: 0.0,
unrealized_pnl: 0.0,
last_update: Utc::now(),
}
}
pub fn is_long(&self) -> bool {
self.quantity > 0.0
}
pub fn is_short(&self) -> bool {
self.quantity < 0.0
}
pub fn is_flat(&self) -> bool {
self.quantity.abs() < f64::EPSILON
}
pub fn market_value(&self, current_price: f64) -> f64 {
self.quantity * current_price
}
pub fn calculate_unrealized_pnl(&self, current_price: f64) -> f64 {
if self.is_flat() {
0.0
} else if self.is_long() {
(current_price - self.average_price) * self.quantity
} else {
(self.average_price - current_price) * self.quantity.abs()
}
}
}
#[derive(Debug, Clone)]
pub struct PositionUpdate {
pub symbol: String,
pub previous_position: Option<Position>,
pub resulting_position: Position,
pub realized_pnl: f64,
}

View file

@ -0,0 +1,91 @@
use std::sync::Arc;
use parking_lot::RwLock;
use std::collections::HashMap;
use crate::domain::{Event, EventHandler};
use tokio::sync::mpsc;
pub struct EventBus {
handlers: Arc<RwLock<HashMap<String, Vec<Arc<dyn EventHandler>>>>>,
sender: mpsc::UnboundedSender<Event>,
}
impl EventBus {
pub fn new() -> (Self, mpsc::UnboundedReceiver<Event>) {
let (sender, receiver) = mpsc::unbounded_channel();
let bus = Self {
handlers: Arc::new(RwLock::new(HashMap::new())),
sender,
};
(bus, receiver)
}
pub fn subscribe(&self, event_type: String, handler: Arc<dyn EventHandler>) {
let mut handlers = self.handlers.write();
handlers.entry(event_type).or_insert_with(Vec::new).push(handler);
}
pub fn publish(&self, event: Event) -> Result<(), String> {
// Send to async handler
self.sender.send(event.clone())
.map_err(|_| "Failed to send event".to_string())?;
// Also handle synchronously for immediate handlers
let event_type = match &event.event_type {
crate::domain::EventType::MarketData(_) => "market_data",
crate::domain::EventType::OrderSubmitted(_) => "order_submitted",
crate::domain::EventType::OrderFilled { .. } => "order_filled",
crate::domain::EventType::OrderCancelled { .. } => "order_cancelled",
crate::domain::EventType::OrderRejected { .. } => "order_rejected",
crate::domain::EventType::PositionUpdate(_) => "position_update",
crate::domain::EventType::RiskAlert { .. } => "risk_alert",
crate::domain::EventType::SystemStatus { .. } => "system_status",
};
let handlers = self.handlers.read();
if let Some(event_handlers) = handlers.get(event_type) {
for handler in event_handlers {
handler.handle_event(&event);
}
}
Ok(())
}
}
// Simple event processor that runs in the background
pub struct EventProcessor {
receiver: mpsc::UnboundedReceiver<Event>,
handlers: Arc<RwLock<HashMap<String, Vec<Arc<dyn EventHandler>>>>>,
}
impl EventProcessor {
pub fn new(
receiver: mpsc::UnboundedReceiver<Event>,
handlers: Arc<RwLock<HashMap<String, Vec<Arc<dyn EventHandler>>>>>,
) -> Self {
Self { receiver, handlers }
}
pub async fn run(mut self) {
while let Some(event) = self.receiver.recv().await {
// Process event asynchronously
let event_type = match &event.event_type {
crate::domain::EventType::MarketData(_) => "market_data",
crate::domain::EventType::OrderSubmitted(_) => "order_submitted",
crate::domain::EventType::OrderFilled { .. } => "order_filled",
crate::domain::EventType::OrderCancelled { .. } => "order_cancelled",
crate::domain::EventType::OrderRejected { .. } => "order_rejected",
crate::domain::EventType::PositionUpdate(_) => "position_update",
crate::domain::EventType::RiskAlert { .. } => "risk_alert",
crate::domain::EventType::SystemStatus { .. } => "system_status",
};
let handlers = self.handlers.read();
if let Some(event_handlers) = handlers.get(event_type) {
for handler in event_handlers {
handler.handle_event(&event);
}
}
}
}
}

View file

@ -0,0 +1,250 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::common::RollingWindow;
/// Average True Range (ATR) Indicator
///
/// Measures volatility by calculating the average of true ranges over a period
/// True Range = max(High - Low, |High - Previous Close|, |Low - Previous Close|)
pub struct ATR {
period: usize,
atr_value: Option<f64>,
prev_close: Option<f64>,
true_ranges: RollingWindow<f64>,
sum: f64,
initialized: bool,
}
impl ATR {
pub fn new(period: usize) -> Result<Self, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
Ok(Self {
period,
atr_value: None,
prev_close: None,
true_ranges: RollingWindow::new(period),
sum: 0.0,
initialized: false,
})
}
/// Calculate True Range
fn calculate_true_range(high: f64, low: f64, prev_close: Option<f64>) -> f64 {
let high_low = high - low;
match prev_close {
Some(prev) => {
let high_close = (high - prev).abs();
let low_close = (low - prev).abs();
high_low.max(high_close).max(low_close)
}
None => high_low,
}
}
/// Calculate ATR for a series of price data
pub fn calculate_series(
high: &[f64],
low: &[f64],
close: &[f64],
period: usize
) -> Result<Vec<f64>, IndicatorError> {
if high.len() != low.len() || high.len() != close.len() {
return Err(IndicatorError::InvalidParameter(
"High, low, and close arrays must have the same length".to_string()
));
}
if high.len() < period + 1 {
return Err(IndicatorError::InsufficientData {
required: period + 1,
actual: high.len(),
});
}
let mut true_ranges = Vec::with_capacity(high.len() - 1);
// Calculate true ranges
for i in 1..high.len() {
let tr = Self::calculate_true_range(high[i], low[i], Some(close[i - 1]));
true_ranges.push(tr);
}
let mut atr_values = Vec::with_capacity(true_ranges.len() - period + 1);
// Calculate initial ATR as SMA of first period true ranges
let initial_atr: f64 = true_ranges[0..period].iter().sum::<f64>() / period as f64;
atr_values.push(initial_atr);
// Calculate subsequent ATRs using Wilder's smoothing
let mut atr = initial_atr;
for i in period..true_ranges.len() {
// Wilder's smoothing: ATR = ((n-1) * ATR + TR) / n
atr = ((period - 1) as f64 * atr + true_ranges[i]) / period as f64;
atr_values.push(atr);
}
Ok(atr_values)
}
}
impl Indicator for ATR {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
if data.high.len() != data.low.len() || data.high.len() != data.close.len() {
return Err(IndicatorError::InvalidParameter(
"Price data arrays must have the same length".to_string()
));
}
let atr_values = Self::calculate_series(
&data.high,
&data.low,
&data.close,
self.period
)?;
Ok(IndicatorResult::Series(atr_values))
}
fn reset(&mut self) {
self.atr_value = None;
self.prev_close = None;
self.true_ranges.clear();
self.sum = 0.0;
self.initialized = false;
}
fn is_ready(&self) -> bool {
self.initialized
}
}
impl IncrementalIndicator for ATR {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
// For single value update, we assume it's the close price
// In real usage, you'd need to provide high/low/close separately
self.update_hlc(value, value, value)
}
fn current(&self) -> Option<f64> {
self.atr_value
}
}
impl ATR {
/// Update with high, low, close values
pub fn update_hlc(&mut self, high: f64, low: f64, close: f64) -> Result<Option<f64>, IndicatorError> {
if let Some(prev_close) = self.prev_close {
let tr = Self::calculate_true_range(high, low, Some(prev_close));
if !self.initialized {
// Still building initial window
self.true_ranges.push(tr);
self.sum += tr;
if self.true_ranges.is_full() {
// Calculate initial ATR
let initial_atr = self.sum / self.period as f64;
self.atr_value = Some(initial_atr);
self.initialized = true;
}
} else {
// Update ATR using Wilder's smoothing
if let Some(current_atr) = self.atr_value {
let new_atr = ((self.period - 1) as f64 * current_atr + tr) / self.period as f64;
self.atr_value = Some(new_atr);
}
}
}
self.prev_close = Some(close);
Ok(self.atr_value)
}
/// Get ATR as a percentage of price
pub fn atr_percent(&self, price: f64) -> Option<f64> {
self.atr_value.map(|atr| (atr / price) * 100.0)
}
/// Calculate stop loss based on ATR multiple
pub fn calculate_stop_loss(&self, entry_price: f64, is_long: bool, atr_multiple: f64) -> Option<f64> {
self.atr_value.map(|atr| {
if is_long {
entry_price - (atr * atr_multiple)
} else {
entry_price + (atr * atr_multiple)
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_atr_calculation() {
let high = vec![
48.70, 48.72, 48.90, 48.87, 48.82,
49.05, 49.20, 49.35, 49.92, 50.19,
50.12, 49.66, 49.88, 50.19, 50.36
];
let low = vec![
47.79, 48.14, 48.39, 48.37, 48.24,
48.64, 48.94, 48.86, 49.50, 49.87,
49.20, 48.90, 49.43, 49.73, 49.26
];
let close = vec![
48.16, 48.61, 48.75, 48.63, 48.74,
49.03, 49.07, 49.32, 49.91, 50.13,
49.53, 49.50, 49.75, 50.03, 50.29
];
let atr_values = ATR::calculate_series(&high, &low, &close, 14).unwrap();
assert!(!atr_values.is_empty());
// ATR should always be positive
for atr in &atr_values {
assert!(*atr > 0.0);
}
}
#[test]
fn test_true_range_calculation() {
// Test case where high-low is largest
let tr = ATR::calculate_true_range(50.0, 45.0, Some(47.0));
assert_eq!(tr, 5.0);
// Test case where high-prev_close is largest
let tr = ATR::calculate_true_range(50.0, 48.0, Some(45.0));
assert_eq!(tr, 5.0);
// Test case where prev_close-low is largest
let tr = ATR::calculate_true_range(48.0, 45.0, Some(50.0));
assert_eq!(tr, 5.0);
}
#[test]
fn test_incremental_atr() {
let mut atr = ATR::new(5).unwrap();
// Need at least one previous close
assert_eq!(atr.update_hlc(10.0, 9.0, 9.5).unwrap(), None);
// Build up window
assert_eq!(atr.update_hlc(10.2, 9.1, 9.8).unwrap(), None);
assert_eq!(atr.update_hlc(10.5, 9.3, 10.0).unwrap(), None);
assert_eq!(atr.update_hlc(10.3, 9.5, 9.7).unwrap(), None);
assert_eq!(atr.update_hlc(10.1, 9.2, 9.6).unwrap(), None);
// Should have ATR value now
let atr_value = atr.update_hlc(10.0, 9.0, 9.5).unwrap();
assert!(atr_value.is_some());
assert!(atr_value.unwrap() > 0.0);
}
}

View file

@ -0,0 +1,256 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::sma::SMA;
use super::common::RollingWindow;
/// Bollinger Bands Indicator
///
/// Middle Band = SMA(n)
/// Upper Band = SMA(n) + (k × σ)
/// Lower Band = SMA(n) - (k × σ)
/// where σ is the standard deviation and k is typically 2
pub struct BollingerBands {
period: usize,
std_dev_multiplier: f64,
sma: SMA,
window: RollingWindow<f64>,
}
impl BollingerBands {
pub fn new(period: usize, std_dev_multiplier: f64) -> Result<Self, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if std_dev_multiplier <= 0.0 {
return Err(IndicatorError::InvalidParameter(
"Standard deviation multiplier must be positive".to_string()
));
}
Ok(Self {
period,
std_dev_multiplier,
sma: SMA::new(period)?,
window: RollingWindow::new(period),
})
}
/// Standard Bollinger Bands with 20 period and 2 standard deviations
pub fn standard() -> Result<Self, IndicatorError> {
Self::new(20, 2.0)
}
/// Calculate standard deviation
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
if values.is_empty() {
return 0.0;
}
let variance = values.iter()
.map(|x| (*x - mean).powi(2))
.sum::<f64>() / values.len() as f64;
variance.sqrt()
}
/// Calculate Bollinger Bands for a series of values
pub fn calculate_series(
values: &[f64],
period: usize,
std_dev_multiplier: f64
) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>), IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if values.len() < period {
return Err(IndicatorError::InsufficientData {
required: period,
actual: values.len(),
});
}
// Calculate SMA (middle band)
let middle_band = SMA::calculate_series(values, period)?;
let mut upper_band = Vec::with_capacity(middle_band.len());
let mut lower_band = Vec::with_capacity(middle_band.len());
// Calculate bands
for i in 0..middle_band.len() {
// Get the window of values for this position
let start_idx = i;
let end_idx = i + period;
let window = &values[start_idx..end_idx];
// Calculate standard deviation
let std_dev = Self::calculate_std_dev(window, middle_band[i]);
let band_width = std_dev * std_dev_multiplier;
upper_band.push(middle_band[i] + band_width);
lower_band.push(middle_band[i] - band_width);
}
Ok((middle_band, upper_band, lower_band))
}
/// Calculate the current bandwidth (distance between upper and lower bands)
pub fn bandwidth(&self) -> Option<f64> {
if let Some(values) = self.get_current_bands() {
Some(values.upper - values.lower)
} else {
None
}
}
/// Calculate %B (percent b) - position of price relative to bands
/// %B = (Price - Lower Band) / (Upper Band - Lower Band)
pub fn percent_b(&self, price: f64) -> Option<f64> {
if let Some(values) = self.get_current_bands() {
let width = values.upper - values.lower;
if width > 0.0 {
Some((price - values.lower) / width)
} else {
None
}
} else {
None
}
}
fn get_current_bands(&self) -> Option<BollingerBandsValues> {
if let Some(middle) = self.sma.current() {
if self.window.is_full() {
let values = self.window.as_slice();
let std_dev = Self::calculate_std_dev(&values, middle);
let band_width = std_dev * self.std_dev_multiplier;
Some(BollingerBandsValues {
middle,
upper: middle + band_width,
lower: middle - band_width,
})
} else {
None
}
} else {
None
}
}
}
impl Indicator for BollingerBands {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
let values = &data.close;
let (middle, upper, lower) = Self::calculate_series(
values,
self.period,
self.std_dev_multiplier
)?;
Ok(IndicatorResult::BollingerBands {
middle,
upper,
lower,
})
}
fn reset(&mut self) {
self.sma.reset();
self.window.clear();
}
fn is_ready(&self) -> bool {
self.sma.is_ready() && self.window.is_full()
}
}
impl IncrementalIndicator for BollingerBands {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
// Update window
self.window.push(value);
// Update SMA
let _sma_result = self.sma.update(value)?;
// Return bandwidth if ready
if self.is_ready() {
Ok(self.bandwidth())
} else {
Ok(None)
}
}
fn current(&self) -> Option<f64> {
self.bandwidth()
}
}
/// Structure to hold all Bollinger Bands values
pub struct BollingerBandsValues {
pub middle: f64,
pub upper: f64,
pub lower: f64,
}
impl BollingerBands {
/// Get all current band values
pub fn current_values(&self) -> Option<BollingerBandsValues> {
self.get_current_bands()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bollinger_bands_calculation() {
let values = vec![
20.0, 21.0, 22.0, 23.0, 24.0,
25.0, 24.0, 23.0, 22.0, 21.0,
20.0, 21.0, 22.0, 23.0, 24.0,
25.0, 24.0, 23.0, 22.0, 21.0
];
let (middle, upper, lower) = BollingerBands::calculate_series(&values, 5, 2.0).unwrap();
assert_eq!(middle.len(), 16);
assert_eq!(upper.len(), 16);
assert_eq!(lower.len(), 16);
// Upper band should always be above middle
for i in 0..middle.len() {
assert!(upper[i] > middle[i]);
assert!(lower[i] < middle[i]);
}
}
#[test]
fn test_percent_b() {
let mut bb = BollingerBands::standard().unwrap();
// Create some test data
for i in 0..25 {
let _ = bb.update(20.0 + (i as f64 % 5.0));
}
// Price at upper band should give %B ≈ 1.0
if let Some(bands) = bb.current_values() {
let percent_b = bb.percent_b(bands.upper).unwrap();
assert!((percent_b - 1.0).abs() < 1e-10);
// Price at lower band should give %B ≈ 0.0
let percent_b = bb.percent_b(bands.lower).unwrap();
assert!(percent_b.abs() < 1e-10);
// Price at middle band should give %B ≈ 0.5
let percent_b = bb.percent_b(bands.middle).unwrap();
assert!((percent_b - 0.5).abs() < 0.1);
}
}
}

View file

@ -0,0 +1,142 @@
use serde::{Serialize, Deserialize};
use std::collections::VecDeque;
/// Common types and utilities for indicators
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriceData {
pub open: Vec<f64>,
pub high: Vec<f64>,
pub low: Vec<f64>,
pub close: Vec<f64>,
pub volume: Vec<f64>,
}
impl PriceData {
pub fn new() -> Self {
Self {
open: Vec::new(),
high: Vec::new(),
low: Vec::new(),
close: Vec::new(),
volume: Vec::new(),
}
}
pub fn len(&self) -> usize {
self.close.len()
}
pub fn is_empty(&self) -> bool {
self.close.is_empty()
}
/// Get typical price (high + low + close) / 3
pub fn typical_prices(&self) -> Vec<f64> {
self.high.iter()
.zip(&self.low)
.zip(&self.close)
.map(|((h, l), c)| (h + l + c) / 3.0)
.collect()
}
/// Get true range for ATR calculation
pub fn true_ranges(&self) -> Vec<f64> {
if self.len() < 2 {
return vec![];
}
let mut ranges = Vec::with_capacity(self.len() - 1);
for i in 1..self.len() {
let high_low = self.high[i] - self.low[i];
let high_close = (self.high[i] - self.close[i - 1]).abs();
let low_close = (self.low[i] - self.close[i - 1]).abs();
ranges.push(high_low.max(high_close).max(low_close));
}
ranges
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndicatorResult {
Single(f64),
Multiple(Vec<f64>),
Series(Vec<f64>),
MACD {
macd: Vec<f64>,
signal: Vec<f64>,
histogram: Vec<f64>,
},
BollingerBands {
middle: Vec<f64>,
upper: Vec<f64>,
lower: Vec<f64>,
},
Stochastic {
k: Vec<f64>,
d: Vec<f64>,
},
}
#[derive(Debug, Clone)]
pub enum IndicatorError {
InsufficientData { required: usize, actual: usize },
InvalidParameter(String),
CalculationError(String),
}
impl std::fmt::Display for IndicatorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IndicatorError::InsufficientData { required, actual } => {
write!(f, "Insufficient data: required {}, got {}", required, actual)
}
IndicatorError::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg),
IndicatorError::CalculationError(msg) => write!(f, "Calculation error: {}", msg),
}
}
}
impl std::error::Error for IndicatorError {}
/// Rolling window for incremental calculations
pub struct RollingWindow<T> {
window: VecDeque<T>,
capacity: usize,
}
impl<T: Clone> RollingWindow<T> {
pub fn new(capacity: usize) -> Self {
Self {
window: VecDeque::with_capacity(capacity),
capacity,
}
}
pub fn push(&mut self, value: T) {
if self.window.len() >= self.capacity {
self.window.pop_front();
}
self.window.push_back(value);
}
pub fn is_full(&self) -> bool {
self.window.len() >= self.capacity
}
pub fn len(&self) -> usize {
self.window.len()
}
pub fn iter(&self) -> impl Iterator<Item = &T> {
self.window.iter()
}
pub fn as_slice(&self) -> Vec<T> {
self.window.iter().cloned().collect()
}
pub fn clear(&mut self) {
self.window.clear();
}
}

View file

@ -0,0 +1,213 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
/// Exponential Moving Average (EMA) Indicator
///
/// Calculates the exponentially weighted moving average
/// giving more weight to recent prices
pub struct EMA {
period: usize,
alpha: f64,
value: Option<f64>,
initialized: bool,
}
impl EMA {
pub fn new(period: usize) -> Result<Self, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
// Calculate smoothing factor (alpha)
// Common formula: 2 / (period + 1)
let alpha = 2.0 / (period as f64 + 1.0);
Ok(Self {
period,
alpha,
value: None,
initialized: false,
})
}
/// Create EMA with custom smoothing factor
pub fn with_alpha(alpha: f64) -> Result<Self, IndicatorError> {
if alpha <= 0.0 || alpha > 1.0 {
return Err(IndicatorError::InvalidParameter(
"Alpha must be between 0 and 1".to_string()
));
}
// Calculate equivalent period for reference
let period = ((2.0 / alpha) - 1.0) as usize;
Ok(Self {
period,
alpha,
value: None,
initialized: false,
})
}
/// Calculate EMA for a series of values
pub fn calculate_series(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if values.is_empty() {
return Ok(vec![]);
}
let alpha = 2.0 / (period as f64 + 1.0);
let mut result = Vec::with_capacity(values.len());
// Start with first value as initial EMA
let mut ema = values[0];
result.push(ema);
// Calculate EMA for remaining values
for i in 1..values.len() {
ema = alpha * values[i] + (1.0 - alpha) * ema;
result.push(ema);
}
Ok(result)
}
/// Alternative initialization using SMA of first N values
pub fn calculate_series_sma_init(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if values.len() < period {
return Err(IndicatorError::InsufficientData {
required: period,
actual: values.len(),
});
}
let alpha = 2.0 / (period as f64 + 1.0);
let mut result = Vec::with_capacity(values.len() - period + 1);
// Calculate initial SMA
let initial_sma: f64 = values[0..period].iter().sum::<f64>() / period as f64;
let mut ema = initial_sma;
result.push(ema);
// Calculate EMA for remaining values
for i in period..values.len() {
ema = alpha * values[i] + (1.0 - alpha) * ema;
result.push(ema);
}
Ok(result)
}
}
impl Indicator for EMA {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
let values = &data.close;
if values.is_empty() {
return Err(IndicatorError::InsufficientData {
required: 1,
actual: 0,
});
}
// Reset and calculate from scratch
self.reset();
let ema_values = Self::calculate_series(values, self.period)?;
// Update internal state with last value
if let Some(&last) = ema_values.last() {
self.value = Some(last);
self.initialized = true;
}
Ok(IndicatorResult::Series(ema_values))
}
fn reset(&mut self) {
self.value = None;
self.initialized = false;
}
fn is_ready(&self) -> bool {
self.initialized && self.value.is_some()
}
}
impl IncrementalIndicator for EMA {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
match self.value {
Some(prev_ema) => {
// Update EMA: EMA = α × Price + (1 - α) × Previous EMA
let new_ema = self.alpha * value + (1.0 - self.alpha) * prev_ema;
self.value = Some(new_ema);
self.initialized = true;
Ok(Some(new_ema))
}
None => {
// First value becomes the initial EMA
self.value = Some(value);
self.initialized = true;
Ok(Some(value))
}
}
}
fn current(&self) -> Option<f64> {
self.value
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ema_calculation() {
let values = vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0];
let result = EMA::calculate_series(&values, 3).unwrap();
assert_eq!(result.len(), 6);
assert!((result[0] - 10.0).abs() < 1e-10); // First value
// Verify EMA calculation
let alpha = 2.0 / 4.0; // 0.5
let expected_ema2 = alpha * 11.0 + (1.0 - alpha) * 10.0; // 10.5
assert!((result[1] - expected_ema2).abs() < 1e-10);
}
#[test]
fn test_incremental_ema() {
let mut ema = EMA::new(3).unwrap();
// First value
assert_eq!(ema.update(10.0).unwrap(), Some(10.0));
// Second value: EMA = 0.5 * 12 + 0.5 * 10 = 11
assert_eq!(ema.update(12.0).unwrap(), Some(11.0));
// Third value: EMA = 0.5 * 14 + 0.5 * 11 = 12.5
assert_eq!(ema.update(14.0).unwrap(), Some(12.5));
}
#[test]
fn test_ema_with_sma_init() {
let values = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
let result = EMA::calculate_series_sma_init(&values, 3).unwrap();
// Initial SMA = (2 + 4 + 6) / 3 = 4
assert_eq!(result.len(), 4);
assert!((result[0] - 4.0).abs() < 1e-10);
}
}

View file

@ -0,0 +1,229 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::ema::EMA;
/// Moving Average Convergence Divergence (MACD) Indicator
///
/// MACD = Fast EMA - Slow EMA
/// Signal = EMA of MACD
/// Histogram = MACD - Signal
pub struct MACD {
fast_period: usize,
slow_period: usize,
signal_period: usize,
fast_ema: EMA,
slow_ema: EMA,
signal_ema: EMA,
macd_value: Option<f64>,
}
impl MACD {
pub fn new(fast_period: usize, slow_period: usize, signal_period: usize) -> Result<Self, IndicatorError> {
if fast_period == 0 || slow_period == 0 || signal_period == 0 {
return Err(IndicatorError::InvalidParameter(
"All periods must be greater than 0".to_string()
));
}
if fast_period >= slow_period {
return Err(IndicatorError::InvalidParameter(
"Fast period must be less than slow period".to_string()
));
}
Ok(Self {
fast_period,
slow_period,
signal_period,
fast_ema: EMA::new(fast_period)?,
slow_ema: EMA::new(slow_period)?,
signal_ema: EMA::new(signal_period)?,
macd_value: None,
})
}
/// Standard MACD with 12, 26, 9 periods
pub fn standard() -> Result<Self, IndicatorError> {
Self::new(12, 26, 9)
}
/// Calculate MACD for a series of values
pub fn calculate_series(
values: &[f64],
fast_period: usize,
slow_period: usize,
signal_period: usize
) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>), IndicatorError> {
if fast_period >= slow_period {
return Err(IndicatorError::InvalidParameter(
"Fast period must be less than slow period".to_string()
));
}
if values.len() < slow_period {
return Err(IndicatorError::InsufficientData {
required: slow_period,
actual: values.len(),
});
}
// Calculate EMAs
let fast_ema = EMA::calculate_series(values, fast_period)?;
let slow_ema = EMA::calculate_series(values, slow_period)?;
// Calculate MACD line
let mut macd_line = Vec::with_capacity(slow_ema.len());
for i in 0..slow_ema.len() {
// Align indices - slow EMA starts later
let fast_idx = i + (slow_period - fast_period);
macd_line.push(fast_ema[fast_idx] - slow_ema[i]);
}
// Calculate signal line (EMA of MACD)
let signal_line = if macd_line.len() >= signal_period {
EMA::calculate_series(&macd_line, signal_period)?
} else {
vec![]
};
// Calculate histogram
let mut histogram = Vec::with_capacity(signal_line.len());
for i in 0..signal_line.len() {
// Align indices
let macd_idx = i + (macd_line.len() - signal_line.len());
histogram.push(macd_line[macd_idx] - signal_line[i]);
}
Ok((macd_line, signal_line, histogram))
}
}
impl Indicator for MACD {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
let values = &data.close;
let (macd, signal, histogram) = Self::calculate_series(
values,
self.fast_period,
self.slow_period,
self.signal_period
)?;
Ok(IndicatorResult::MACD {
macd,
signal,
histogram,
})
}
fn reset(&mut self) {
self.fast_ema.reset();
self.slow_ema.reset();
self.signal_ema.reset();
self.macd_value = None;
}
fn is_ready(&self) -> bool {
self.fast_ema.is_ready() && self.slow_ema.is_ready() && self.signal_ema.is_ready()
}
}
impl IncrementalIndicator for MACD {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
// Update both EMAs
let fast_result = self.fast_ema.update(value)?;
let slow_result = self.slow_ema.update(value)?;
// Calculate MACD if both EMAs are ready
if let (Some(fast), Some(slow)) = (fast_result, slow_result) {
let macd = fast - slow;
self.macd_value = Some(macd);
// Update signal line
if let Some(signal) = self.signal_ema.update(macd)? {
// Return histogram value
Ok(Some(macd - signal))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
fn current(&self) -> Option<f64> {
// Return current histogram value
if let (Some(fast), Some(slow), Some(signal)) = (
self.fast_ema.current(),
self.slow_ema.current(),
self.signal_ema.current()
) {
let macd = fast - slow;
Some(macd - signal)
} else {
None
}
}
}
/// Structure to hold all MACD values for incremental updates
pub struct MACDValues {
pub macd: f64,
pub signal: f64,
pub histogram: f64,
}
impl MACD {
/// Get all current MACD values
pub fn current_values(&self) -> Option<MACDValues> {
if let (Some(fast), Some(slow), Some(signal)) = (
self.fast_ema.current(),
self.slow_ema.current(),
self.signal_ema.current()
) {
let macd = fast - slow;
Some(MACDValues {
macd,
signal,
histogram: macd - signal,
})
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_macd_calculation() {
let values = vec![
10.0, 10.5, 11.0, 11.5, 12.0, 12.5, 13.0, 13.5,
14.0, 14.5, 15.0, 14.5, 14.0, 13.5, 13.0, 12.5,
12.0, 11.5, 11.0, 10.5, 10.0, 10.5, 11.0, 11.5,
12.0, 12.5, 13.0, 13.5, 14.0, 14.5
];
let (macd, signal, histogram) = MACD::calculate_series(&values, 12, 26, 9).unwrap();
// Should have values after slow period
assert!(!macd.is_empty());
assert!(!signal.is_empty());
assert!(!histogram.is_empty());
// Histogram should equal MACD - Signal
for i in 0..histogram.len() {
let expected = macd[macd.len() - histogram.len() + i] - signal[i];
assert!((histogram[i] - expected).abs() < 1e-10);
}
}
#[test]
fn test_standard_macd() {
let macd = MACD::standard().unwrap();
assert_eq!(macd.fast_period, 12);
assert_eq!(macd.slow_period, 26);
assert_eq!(macd.signal_period, 9);
}
}

View file

@ -0,0 +1,40 @@
// Technical Analysis Indicators Library
pub mod sma;
pub mod ema;
pub mod rsi;
pub mod macd;
pub mod bollinger_bands;
pub mod stochastic;
pub mod atr;
pub mod common;
// Re-export commonly used types and traits
pub use common::{IndicatorResult, IndicatorError, PriceData};
pub use sma::SMA;
pub use ema::EMA;
pub use rsi::RSI;
pub use macd::MACD;
pub use bollinger_bands::BollingerBands;
pub use stochastic::Stochastic;
pub use atr::ATR;
/// Trait that all indicators must implement
pub trait Indicator {
/// Calculate the indicator value(s) for the given data
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError>;
/// Reset the indicator state
fn reset(&mut self);
/// Check if the indicator has enough data to produce valid results
fn is_ready(&self) -> bool;
}
/// Trait for indicators that can be calculated incrementally
pub trait IncrementalIndicator: Indicator {
/// Update the indicator with a new data point
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError>;
/// Get the current value without updating
fn current(&self) -> Option<f64>;
}

View file

@ -0,0 +1,223 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::common::RollingWindow;
/// Relative Strength Index (RSI) Indicator
///
/// Measures momentum by comparing the magnitude of recent gains to recent losses
/// RSI = 100 - (100 / (1 + RS))
/// where RS = Average Gain / Average Loss
pub struct RSI {
period: usize,
avg_gain: f64,
avg_loss: f64,
prev_value: Option<f64>,
window: RollingWindow<f64>,
initialized: bool,
}
impl RSI {
pub fn new(period: usize) -> Result<Self, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
Ok(Self {
period,
avg_gain: 0.0,
avg_loss: 0.0,
prev_value: None,
window: RollingWindow::new(period + 1),
initialized: false,
})
}
/// Calculate RSI for a series of values
pub fn calculate_series(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if values.len() <= period {
return Err(IndicatorError::InsufficientData {
required: period + 1,
actual: values.len(),
});
}
let mut result = Vec::with_capacity(values.len() - period);
let mut gains = Vec::with_capacity(values.len() - 1);
let mut losses = Vec::with_capacity(values.len() - 1);
// Calculate price changes
for i in 1..values.len() {
let change = values[i] - values[i - 1];
if change > 0.0 {
gains.push(change);
losses.push(0.0);
} else {
gains.push(0.0);
losses.push(-change);
}
}
// Calculate initial averages using SMA
let initial_avg_gain: f64 = gains[0..period].iter().sum::<f64>() / period as f64;
let initial_avg_loss: f64 = losses[0..period].iter().sum::<f64>() / period as f64;
// Calculate first RSI
let rs = if initial_avg_loss > 0.0 {
initial_avg_gain / initial_avg_loss
} else {
100.0 // If no losses, RSI is 100
};
result.push(100.0 - (100.0 / (1.0 + rs)));
// Calculate remaining RSIs using EMA smoothing
let mut avg_gain = initial_avg_gain;
let mut avg_loss = initial_avg_loss;
let alpha = 1.0 / period as f64;
for i in period..gains.len() {
// Wilder's smoothing method
avg_gain = (avg_gain * (period - 1) as f64 + gains[i]) / period as f64;
avg_loss = (avg_loss * (period - 1) as f64 + losses[i]) / period as f64;
let rs = if avg_loss > 0.0 {
avg_gain / avg_loss
} else {
100.0
};
result.push(100.0 - (100.0 / (1.0 + rs)));
}
Ok(result)
}
fn calculate_rsi(&self) -> f64 {
if self.avg_loss == 0.0 {
100.0
} else {
let rs = self.avg_gain / self.avg_loss;
100.0 - (100.0 / (1.0 + rs))
}
}
}
impl Indicator for RSI {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
let values = &data.close;
if values.len() <= self.period {
return Err(IndicatorError::InsufficientData {
required: self.period + 1,
actual: values.len(),
});
}
let rsi_values = Self::calculate_series(values, self.period)?;
Ok(IndicatorResult::Series(rsi_values))
}
fn reset(&mut self) {
self.avg_gain = 0.0;
self.avg_loss = 0.0;
self.prev_value = None;
self.window.clear();
self.initialized = false;
}
fn is_ready(&self) -> bool {
self.initialized
}
}
impl IncrementalIndicator for RSI {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
self.window.push(value);
if let Some(prev) = self.prev_value {
let change = value - prev;
let gain = if change > 0.0 { change } else { 0.0 };
let loss = if change < 0.0 { -change } else { 0.0 };
if !self.initialized && self.window.len() > self.period {
// Initialize using first period values
let values = self.window.as_slice();
let mut sum_gain = 0.0;
let mut sum_loss = 0.0;
for i in 1..=self.period {
let change = values[i] - values[i - 1];
if change > 0.0 {
sum_gain += change;
} else {
sum_loss += -change;
}
}
self.avg_gain = sum_gain / self.period as f64;
self.avg_loss = sum_loss / self.period as f64;
self.initialized = true;
} else if self.initialized {
// Update using Wilder's smoothing
self.avg_gain = (self.avg_gain * (self.period - 1) as f64 + gain) / self.period as f64;
self.avg_loss = (self.avg_loss * (self.period - 1) as f64 + loss) / self.period as f64;
}
}
self.prev_value = Some(value);
if self.initialized {
Ok(Some(self.calculate_rsi()))
} else {
Ok(None)
}
}
fn current(&self) -> Option<f64> {
if self.initialized {
Some(self.calculate_rsi())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rsi_calculation() {
let values = vec![
44.0, 44.25, 44.38, 44.38, 44.88, 45.05,
45.25, 45.38, 45.75, 46.03, 46.23, 46.08,
46.03, 45.85, 46.25, 46.38, 46.50
];
let result = RSI::calculate_series(&values, 14).unwrap();
assert_eq!(result.len(), 3);
// RSI should be between 0 and 100
for rsi in &result {
assert!(*rsi >= 0.0 && *rsi <= 100.0);
}
}
#[test]
fn test_rsi_extremes() {
// All gains - RSI should be close to 100
let increasing = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = RSI::calculate_series(&increasing, 5).unwrap();
assert!(result.last().unwrap() > &95.0);
// All losses - RSI should be close to 0
let decreasing = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let result = RSI::calculate_series(&decreasing, 5).unwrap();
assert!(result.last().unwrap() < &5.0);
}
}

View file

@ -0,0 +1,139 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::common::RollingWindow;
/// Simple Moving Average (SMA) Indicator
///
/// Calculates the arithmetic mean of the last N periods
pub struct SMA {
period: usize,
window: RollingWindow<f64>,
sum: f64,
}
impl SMA {
pub fn new(period: usize) -> Result<Self, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
Ok(Self {
period,
window: RollingWindow::new(period),
sum: 0.0,
})
}
/// Calculate SMA for a series of values
pub fn calculate_series(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
if period == 0 {
return Err(IndicatorError::InvalidParameter(
"Period must be greater than 0".to_string()
));
}
if values.len() < period {
return Err(IndicatorError::InsufficientData {
required: period,
actual: values.len(),
});
}
let mut result = Vec::with_capacity(values.len() - period + 1);
// Calculate first SMA
let mut sum: f64 = values[0..period].iter().sum();
result.push(sum / period as f64);
// Calculate remaining SMAs using sliding window
for i in period..values.len() {
sum = sum - values[i - period] + values[i];
result.push(sum / period as f64);
}
Ok(result)
}
}
impl Indicator for SMA {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
let values = &data.close;
if values.len() < self.period {
return Err(IndicatorError::InsufficientData {
required: self.period,
actual: values.len(),
});
}
let sma_values = Self::calculate_series(values, self.period)?;
Ok(IndicatorResult::Series(sma_values))
}
fn reset(&mut self) {
self.window.clear();
self.sum = 0.0;
}
fn is_ready(&self) -> bool {
self.window.is_full()
}
}
impl IncrementalIndicator for SMA {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
// If window is full, subtract the oldest value from sum
if self.window.is_full() {
if let Some(oldest) = self.window.iter().next() {
self.sum -= oldest;
}
}
// Add new value
self.window.push(value);
self.sum += value;
// Calculate SMA if we have enough data
if self.window.is_full() {
Ok(Some(self.sum / self.period as f64))
} else {
Ok(None)
}
}
fn current(&self) -> Option<f64> {
if self.window.is_full() {
Some(self.sum / self.period as f64)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sma_calculation() {
let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let result = SMA::calculate_series(&values, 3).unwrap();
assert_eq!(result.len(), 8);
assert!((result[0] - 2.0).abs() < 1e-10); // (1+2+3)/3 = 2
assert!((result[1] - 3.0).abs() < 1e-10); // (2+3+4)/3 = 3
assert!((result[7] - 9.0).abs() < 1e-10); // (8+9+10)/3 = 9
}
#[test]
fn test_incremental_sma() {
let mut sma = SMA::new(3).unwrap();
assert_eq!(sma.update(1.0).unwrap(), None);
assert_eq!(sma.update(2.0).unwrap(), None);
assert_eq!(sma.update(3.0).unwrap(), Some(2.0));
assert_eq!(sma.update(4.0).unwrap(), Some(3.0));
assert_eq!(sma.update(5.0).unwrap(), Some(4.0));
}
}

View file

@ -0,0 +1,297 @@
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::sma::SMA;
use super::common::RollingWindow;
/// Stochastic Oscillator Indicator
///
/// %K = 100 × (Close - Lowest Low) / (Highest High - Lowest Low)
/// %D = SMA of %K
pub struct Stochastic {
k_period: usize,
d_period: usize,
smooth_k: usize,
high_window: RollingWindow<f64>,
low_window: RollingWindow<f64>,
close_window: RollingWindow<f64>,
k_sma: SMA,
d_sma: SMA,
k_values: RollingWindow<f64>,
}
impl Stochastic {
pub fn new(k_period: usize, d_period: usize, smooth_k: usize) -> Result<Self, IndicatorError> {
if k_period == 0 || d_period == 0 {
return Err(IndicatorError::InvalidParameter(
"K and D periods must be greater than 0".to_string()
));
}
Ok(Self {
k_period,
d_period,
smooth_k,
high_window: RollingWindow::new(k_period),
low_window: RollingWindow::new(k_period),
close_window: RollingWindow::new(k_period),
k_sma: SMA::new(smooth_k.max(1))?,
d_sma: SMA::new(d_period)?,
k_values: RollingWindow::new(d_period),
})
}
/// Standard Fast Stochastic (14, 3)
pub fn fast() -> Result<Self, IndicatorError> {
Self::new(14, 3, 1)
}
/// Standard Slow Stochastic (14, 3, 3)
pub fn slow() -> Result<Self, IndicatorError> {
Self::new(14, 3, 3)
}
/// Calculate raw %K value
fn calculate_raw_k(high: &[f64], low: &[f64], close: f64) -> Option<f64> {
if high.is_empty() || low.is_empty() {
return None;
}
let highest = high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let lowest = low.iter().cloned().fold(f64::INFINITY, f64::min);
let range = highest - lowest;
if range > 0.0 {
Some(100.0 * (close - lowest) / range)
} else {
Some(50.0) // If no range, return middle value
}
}
/// Calculate Stochastic for a series of price data
pub fn calculate_series(
high: &[f64],
low: &[f64],
close: &[f64],
k_period: usize,
d_period: usize,
smooth_k: usize,
) -> Result<(Vec<f64>, Vec<f64>), IndicatorError> {
if high.len() != low.len() || high.len() != close.len() {
return Err(IndicatorError::InvalidParameter(
"High, low, and close arrays must have the same length".to_string()
));
}
if high.len() < k_period {
return Err(IndicatorError::InsufficientData {
required: k_period,
actual: high.len(),
});
}
// Calculate raw %K values
let mut raw_k_values = Vec::with_capacity(high.len() - k_period + 1);
for i in k_period - 1..high.len() {
let start = i + 1 - k_period;
if let Some(k) = Self::calculate_raw_k(
&high[start..=i],
&low[start..=i],
close[i]
) {
raw_k_values.push(k);
}
}
// Smooth %K if requested
let k_values = if smooth_k > 1 {
SMA::calculate_series(&raw_k_values, smooth_k)?
} else {
raw_k_values
};
// Calculate %D (SMA of %K)
let d_values = if k_values.len() >= d_period {
SMA::calculate_series(&k_values, d_period)?
} else {
vec![]
};
Ok((k_values, d_values))
}
}
impl Indicator for Stochastic {
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
if data.high.len() != data.low.len() || data.high.len() != data.close.len() {
return Err(IndicatorError::InvalidParameter(
"Price data arrays must have the same length".to_string()
));
}
let (k_values, d_values) = Self::calculate_series(
&data.high,
&data.low,
&data.close,
self.k_period,
self.d_period,
self.smooth_k,
)?;
Ok(IndicatorResult::Stochastic {
k: k_values,
d: d_values,
})
}
fn reset(&mut self) {
self.high_window.clear();
self.low_window.clear();
self.close_window.clear();
self.k_sma.reset();
self.d_sma.reset();
self.k_values.clear();
}
fn is_ready(&self) -> bool {
self.high_window.is_full() && self.d_sma.is_ready()
}
}
impl IncrementalIndicator for Stochastic {
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
// For incremental updates, we assume value is close price
// In real usage, you'd need to provide high/low/close separately
self.high_window.push(value);
self.low_window.push(value);
self.close_window.push(value);
if self.high_window.is_full() {
// Calculate raw %K
if let Some(raw_k) = Self::calculate_raw_k(
&self.high_window.as_slice(),
&self.low_window.as_slice(),
value
) {
// Smooth %K if needed
let k_value = if self.smooth_k > 1 {
self.k_sma.update(raw_k)?
} else {
Some(raw_k)
};
if let Some(k) = k_value {
self.k_values.push(k);
// Calculate %D
if let Some(d) = self.d_sma.update(k)? {
return Ok(Some(d));
}
}
}
}
Ok(None)
}
fn current(&self) -> Option<f64> {
self.d_sma.current()
}
}
/// Structure to hold Stochastic values
pub struct StochasticValues {
pub k: f64,
pub d: f64,
}
impl Stochastic {
/// Get current %K and %D values
pub fn current_values(&self) -> Option<StochasticValues> {
if let (Some(&k), Some(d)) = (self.k_values.iter().last(), self.d_sma.current()) {
Some(StochasticValues { k, d })
} else {
None
}
}
/// Update with separate high, low, close values
pub fn update_hlc(&mut self, high: f64, low: f64, close: f64) -> Result<Option<StochasticValues>, IndicatorError> {
self.high_window.push(high);
self.low_window.push(low);
self.close_window.push(close);
if self.high_window.is_full() {
if let Some(raw_k) = Self::calculate_raw_k(
&self.high_window.as_slice(),
&self.low_window.as_slice(),
close
) {
let k_value = if self.smooth_k > 1 {
self.k_sma.update(raw_k)?
} else {
Some(raw_k)
};
if let Some(k) = k_value {
self.k_values.push(k);
if let Some(d) = self.d_sma.update(k)? {
return Ok(Some(StochasticValues { k, d }));
}
}
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stochastic_calculation() {
let high = vec![
127.01, 127.62, 126.59, 127.35, 128.17,
128.43, 127.37, 126.42, 126.90, 126.85,
125.65, 125.72, 127.16, 127.72, 127.69
];
let low = vec![
125.36, 126.16, 124.93, 126.09, 126.82,
126.48, 126.03, 124.83, 126.39, 125.72,
124.56, 124.57, 125.07, 126.86, 126.63
];
let close = vec![
125.36, 126.16, 124.93, 126.09, 126.82,
126.48, 126.03, 124.83, 126.39, 125.72,
124.56, 124.57, 125.07, 126.86, 126.63
];
let (k_values, d_values) = Stochastic::calculate_series(&high, &low, &close, 14, 3, 1).unwrap();
assert!(!k_values.is_empty());
assert!(!d_values.is_empty());
// %K should be between 0 and 100
for k in &k_values {
assert!(*k >= 0.0 && *k <= 100.0);
}
}
#[test]
fn test_stochastic_extremes() {
// When close is at highest high, %K should be 100
let high = vec![10.0; 14];
let low = vec![5.0; 14];
let mut close = vec![7.5; 14];
close[13] = 10.0; // Last close at high
let (k_values, _) = Stochastic::calculate_series(&high, &low, &close, 14, 3, 1).unwrap();
assert!((k_values[0] - 100.0).abs() < 1e-10);
// When close is at lowest low, %K should be 0
close[13] = 5.0; // Last close at low
let (k_values, _) = Stochastic::calculate_series(&high, &low, &close, 14, 3, 1).unwrap();
assert!(k_values[0].abs() < 1e-10);
}
}

View file

@ -0,0 +1,230 @@
#![deny(clippy::all)]
// Existing modules
pub mod core;
pub mod adapters;
pub mod orderbook;
pub mod risk;
pub mod positions;
#[cfg(not(test))]
pub mod api;
pub mod analytics;
pub mod indicators;
pub mod backtest;
pub mod strategies;
// Re-export commonly used types
pub use positions::{Position, PositionUpdate, TradeRecord, ClosedTrade};
pub use risk::{RiskLimits, RiskCheckResult, RiskMetrics};
// Type alias for backtest compatibility
pub type MarketData = MarketUpdate;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use parking_lot::RwLock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TradingMode {
Backtest {
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
speed_multiplier: f64,
},
Paper {
starting_capital: f64,
},
Live {
broker: String,
account_id: String,
},
}
// Core traits that allow different implementations based on mode
#[async_trait::async_trait]
pub trait MarketDataSource: Send + Sync {
async fn get_next_update(&mut self) -> Option<MarketUpdate>;
fn seek_to_time(&mut self, timestamp: DateTime<Utc>) -> Result<(), String>;
fn as_any(&self) -> &dyn std::any::Any;
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
#[async_trait::async_trait]
pub trait ExecutionHandler: Send + Sync {
async fn execute_order(&mut self, order: Order) -> Result<ExecutionResult, String>;
fn get_fill_simulator(&self) -> Option<&dyn FillSimulator>;
}
pub trait TimeProvider: Send + Sync {
fn now(&self) -> DateTime<Utc>;
fn sleep_until(&self, target: DateTime<Utc>) -> Result<(), String>;
fn as_any(&self) -> &dyn std::any::Any;
}
pub trait FillSimulator: Send + Sync {
fn simulate_fill(&self, order: &Order, orderbook: &OrderBookSnapshot) -> Option<Fill>;
}
// Main trading core that works across all modes
pub struct TradingCore {
mode: TradingMode,
pub market_data_source: Arc<RwLock<Box<dyn MarketDataSource>>>,
pub execution_handler: Arc<RwLock<Box<dyn ExecutionHandler>>>,
pub time_provider: Arc<Box<dyn TimeProvider>>,
pub orderbooks: Arc<orderbook::OrderBookManager>,
pub risk_engine: Arc<risk::RiskEngine>,
pub position_tracker: Arc<positions::PositionTracker>,
}
// Core types used across the system
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketUpdate {
pub symbol: String,
pub timestamp: DateTime<Utc>,
pub data: MarketDataType,
}
// Market microstructure parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketMicrostructure {
pub symbol: String,
pub avg_spread_bps: f64,
pub daily_volume: f64,
pub avg_trade_size: f64,
pub volatility: f64,
pub tick_size: f64,
pub lot_size: f64,
pub intraday_volume_profile: Vec<f64>, // 24 hourly buckets
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MarketDataType {
Quote(Quote),
Trade(Trade),
Bar(Bar),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Quote {
pub bid: f64,
pub ask: f64,
pub bid_size: f64,
pub ask_size: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
pub price: f64,
pub size: f64,
pub side: Side,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bar {
pub open: f64,
pub high: f64,
pub low: f64,
pub close: f64,
pub volume: f64,
pub vwap: Option<f64>,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Side {
Buy,
Sell,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Order {
pub id: String,
pub symbol: String,
pub side: Side,
pub quantity: f64,
pub order_type: OrderType,
pub time_in_force: TimeInForce,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrderType {
Market,
Limit { price: f64 },
Stop { stop_price: f64 },
StopLimit { stop_price: f64, limit_price: f64 },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeInForce {
Day,
GTC,
IOC,
FOK,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
pub order_id: String,
pub status: OrderStatus,
pub fills: Vec<Fill>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OrderStatus {
Pending,
Accepted,
PartiallyFilled,
Filled,
Cancelled,
Rejected(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fill {
pub timestamp: DateTime<Utc>,
pub price: f64,
pub quantity: f64,
pub commission: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderBookSnapshot {
pub symbol: String,
pub timestamp: DateTime<Utc>,
pub bids: Vec<PriceLevel>,
pub asks: Vec<PriceLevel>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriceLevel {
pub price: f64,
pub size: f64,
pub order_count: Option<u32>,
}
impl TradingCore {
pub fn new(
mode: TradingMode,
market_data_source: Box<dyn MarketDataSource>,
execution_handler: Box<dyn ExecutionHandler>,
time_provider: Box<dyn TimeProvider>,
) -> Self {
Self {
mode,
market_data_source: Arc::new(RwLock::new(market_data_source)),
execution_handler: Arc::new(RwLock::new(execution_handler)),
time_provider: Arc::new(time_provider),
orderbooks: Arc::new(orderbook::OrderBookManager::new()),
risk_engine: Arc::new(risk::RiskEngine::new()),
position_tracker: Arc::new(positions::PositionTracker::new()),
}
}
pub fn get_mode(&self) -> &TradingMode {
&self.mode
}
pub fn get_time(&self) -> DateTime<Utc> {
self.time_provider.now()
}
}

View file

@ -0,0 +1,67 @@
#![deny(clippy::all)]
// Domain modules
pub mod domain;
pub mod events;
pub mod modes;
// Core functionality modules
pub mod core;
pub mod orderbook;
pub mod risk;
pub mod positions;
pub mod analytics;
pub mod indicators;
pub mod backtest;
pub mod strategies;
// API layer
pub mod api;
// Re-export commonly used types from domain
pub use domain::{
// Market types
Quote, Bar, Trade, MarketUpdate, MarketDataType,
// Order types
Order, OrderType, OrderStatus, TimeInForce, Side, Fill,
// Position types
Position, PositionUpdate,
// Event types
Event, EventType, EventHandler,
};
// Re-export mode types
pub use modes::{TradingMode, TradingEngine};
// Re-export other commonly used types
pub use positions::{PositionTracker, TradeRecord, ClosedTrade};
pub use risk::{RiskLimits, RiskCheckResult, RiskMetrics};
// Core traits that define the system's abstractions
use chrono::{DateTime, Utc};
use std::sync::Arc;
use async_trait::async_trait;
#[async_trait]
pub trait MarketDataSource: Send + Sync {
async fn get_next_update(&mut self) -> Option<MarketUpdate>;
fn seek_to_time(&mut self, timestamp: DateTime<Utc>) -> Result<(), String>;
fn as_any(&self) -> &dyn std::any::Any;
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
#[async_trait]
pub trait ExecutionHandler: Send + Sync {
async fn submit_order(&self, order: &Order) -> Result<String, String>;
async fn cancel_order(&self, order_id: &str) -> Result<(), String>;
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String>;
}
pub trait TimeProvider: Send + Sync {
fn now(&self) -> DateTime<Utc>;
fn as_any(&self) -> &dyn std::any::Any;
}
pub trait FillSimulator: Send + Sync {
fn simulate_fill(&self, order: &Order, market_price: f64) -> Option<Fill>;
}

View file

@ -0,0 +1,123 @@
#![deny(clippy::all)]
// Domain modules - core types and abstractions
pub mod domain;
pub mod events;
pub mod modes;
pub mod strategies;
// Core functionality modules
pub mod core;
pub mod orderbook;
pub mod risk;
pub mod positions;
pub mod analytics;
pub mod indicators;
pub mod backtest;
// API layer
pub mod api;
// Re-export all domain types
pub use domain::*;
// Re-export event system
pub use events::{Event, EventType, EventBus, EventHandler};
// Re-export mode types and engine trait
pub use modes::{TradingMode, TradingEngine};
// Re-export strategy framework
pub use strategies::framework::{Strategy, StrategyContext, Signal, SignalAction};
// Re-export position tracking
pub use positions::{PositionTracker, TradeRecord, ClosedTrade};
// Re-export risk management
pub use risk::{RiskLimits, RiskCheckResult, RiskMetrics, RiskEngine};
// Re-export analytics
pub use analytics::{AnalyticsEngine, PerformanceMetrics};
// Re-export indicators
pub use indicators::{Indicator, IndicatorSet};
// Re-export backtest types
pub use backtest::{BacktestEngine, BacktestResults};
// Re-export API
pub use api::TradingAPI;
// Core system traits
use async_trait::async_trait;
use chrono::{DateTime, Utc};
/// Source of market data - can be historical files, live feed, or simulated
#[async_trait]
pub trait MarketDataSource: Send + Sync {
async fn get_next_update(&mut self) -> Option<MarketUpdate>;
fn seek_to_time(&mut self, timestamp: DateTime<Utc>) -> Result<(), String>;
}
/// Handles order execution - can be simulated, paper, or live broker
#[async_trait]
pub trait ExecutionHandler: Send + Sync {
async fn submit_order(&self, order: &Order) -> Result<String, String>;
async fn cancel_order(&self, order_id: &str) -> Result<(), String>;
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String>;
}
/// Provides current time - can be simulated time for backtest or real time
pub trait TimeProvider: Send + Sync {
fn now(&self) -> DateTime<Utc>;
fn sleep_until(&self, target: DateTime<Utc>) -> Result<(), String>;
}
/// Simulates order fills for backtest/paper trading
pub trait FillSimulator: Send + Sync {
fn simulate_fill(&self, order: &Order, market_price: f64, spread: f64) -> Option<Fill>;
}
/// Main trading system that coordinates all components
pub struct TradingCore {
mode: TradingMode,
engine: Box<dyn TradingEngine>,
event_bus: EventBus,
orderbook_manager: OrderBookManager,
position_tracker: PositionTracker,
risk_engine: RiskEngine,
analytics_engine: AnalyticsEngine,
}
impl TradingCore {
/// Create a new trading system for the specified mode
pub fn new(mode: TradingMode) -> Result<Self, String> {
let engine = modes::create_engine_for_mode(&mode)?;
let event_bus = EventBus::new();
Ok(Self {
mode,
engine,
event_bus,
orderbook_manager: OrderBookManager::new(),
position_tracker: PositionTracker::new(),
risk_engine: RiskEngine::new(),
analytics_engine: AnalyticsEngine::new(),
})
}
/// Start the trading system
pub async fn start(&mut self) -> Result<(), String> {
self.engine.start(&mut self.event_bus).await
}
/// Stop the trading system
pub async fn stop(&mut self) -> Result<(), String> {
self.engine.stop().await
}
/// Get current trading mode
pub fn mode(&self) -> &TradingMode {
&self.mode
}
}

View file

@ -0,0 +1,291 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::sync::Arc;
use parking_lot::RwLock;
use crate::{
MarketDataSource, ExecutionHandler, TimeProvider, FillSimulator,
MarketUpdate, Order, OrderStatus, Fill, OrderBookSnapshot,
};
use crate::events::EventBus;
use crate::domain::{Event, EventType};
use super::TradingEngine;
/// Backtest-specific trading engine
pub struct BacktestEngine {
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
current_time: DateTime<Utc>,
speed_multiplier: f64,
market_data: Box<dyn MarketDataSource>,
execution: BacktestExecutor,
is_running: Arc<RwLock<bool>>,
}
impl BacktestEngine {
pub fn new(
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
speed_multiplier: f64,
market_data: Box<dyn MarketDataSource>,
) -> Self {
Self {
start_time,
end_time,
current_time: start_time,
speed_multiplier,
market_data,
execution: BacktestExecutor::new(),
is_running: Arc::new(RwLock::new(false)),
}
}
}
#[async_trait]
impl TradingEngine for BacktestEngine {
async fn start(&mut self, event_bus: &mut EventBus) -> Result<(), String> {
*self.is_running.write() = true;
self.current_time = self.start_time;
// Seek market data to start time
self.market_data.seek_to_time(self.start_time)?;
// Emit start event
event_bus.publish(Event {
event_type: EventType::SystemStart,
timestamp: self.current_time,
data: serde_json::json!({
"mode": "backtest",
"start_time": self.start_time,
"end_time": self.end_time,
}),
}).await;
// Main backtest loop
while *self.is_running.read() && self.current_time < self.end_time {
// Get next market update
if let Some(update) = self.market_data.get_next_update().await {
// Update current time
self.current_time = update.timestamp;
// Publish market data event
event_bus.publish(Event {
event_type: EventType::MarketData(update.clone()),
timestamp: self.current_time,
data: serde_json::Value::Null,
}).await;
// Process any pending orders
self.execution.process_orders(&update, event_bus).await?;
} else {
// No more data
break;
}
// Simulate time passing (for UI updates, etc.)
if self.speed_multiplier > 0.0 {
tokio::time::sleep(std::time::Duration::from_millis(
(10.0 / self.speed_multiplier) as u64
)).await;
}
}
// Emit stop event
event_bus.publish(Event {
event_type: EventType::SystemStop,
timestamp: self.current_time,
data: serde_json::json!({
"reason": "backtest_complete",
"end_time": self.current_time,
}),
}).await;
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
*self.is_running.write() = false;
Ok(())
}
fn get_execution_handler(&self) -> Arc<dyn ExecutionHandler> {
Arc::new(self.execution.clone())
}
fn get_time_provider(&self) -> Arc<dyn TimeProvider> {
Arc::new(BacktestTimeProvider {
current_time: self.current_time,
})
}
}
/// Backtest order executor
#[derive(Clone)]
struct BacktestExecutor {
pending_orders: Arc<RwLock<Vec<Order>>>,
order_history: Arc<RwLock<Vec<(Order, OrderStatus)>>>,
fill_simulator: BacktestFillSimulator,
}
impl BacktestExecutor {
fn new() -> Self {
Self {
pending_orders: Arc::new(RwLock::new(Vec::new())),
order_history: Arc::new(RwLock::new(Vec::new())),
fill_simulator: BacktestFillSimulator::new(),
}
}
async fn process_orders(
&self,
market_update: &MarketUpdate,
event_bus: &mut EventBus,
) -> Result<(), String> {
let mut orders = self.pending_orders.write();
let mut filled_indices = Vec::new();
for (idx, order) in orders.iter().enumerate() {
if order.symbol != market_update.symbol {
continue;
}
// Try to fill the order
if let Some(fill) = self.fill_simulator.try_fill(order, market_update) {
// Emit fill event
event_bus.publish(Event {
event_type: EventType::OrderFill {
order_id: order.id.clone(),
fill: fill.clone(),
},
timestamp: market_update.timestamp,
data: serde_json::Value::Null,
}).await;
filled_indices.push(idx);
}
}
// Remove filled orders
for idx in filled_indices.into_iter().rev() {
let order = orders.remove(idx);
self.order_history.write().push((order, OrderStatus::Filled));
}
Ok(())
}
}
#[async_trait]
impl ExecutionHandler for BacktestExecutor {
async fn submit_order(&self, order: &Order) -> Result<String, String> {
let order_id = order.id.clone();
self.pending_orders.write().push(order.clone());
Ok(order_id)
}
async fn cancel_order(&self, order_id: &str) -> Result<(), String> {
let mut orders = self.pending_orders.write();
if let Some(pos) = orders.iter().position(|o| o.id == order_id) {
let order = orders.remove(pos);
self.order_history.write().push((order, OrderStatus::Cancelled));
Ok(())
} else {
Err("Order not found".to_string())
}
}
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String> {
// Check pending orders
if self.pending_orders.read().iter().any(|o| o.id == order_id) {
return Ok(OrderStatus::Pending);
}
// Check history
if let Some((_, status)) = self.order_history.read()
.iter()
.find(|(o, _)| o.id == order_id) {
return Ok(status.clone());
}
Err("Order not found".to_string())
}
}
/// Backtest time provider
struct BacktestTimeProvider {
current_time: DateTime<Utc>,
}
impl TimeProvider for BacktestTimeProvider {
fn now(&self) -> DateTime<Utc> {
self.current_time
}
fn sleep_until(&self, _target: DateTime<Utc>) -> Result<(), String> {
// In backtest, we don't actually sleep
Ok(())
}
}
/// Backtest fill simulator
#[derive(Clone)]
struct BacktestFillSimulator {
commission_rate: f64,
slippage_bps: f64,
}
impl BacktestFillSimulator {
fn new() -> Self {
Self {
commission_rate: 0.001, // 0.1%
slippage_bps: 5.0, // 5 basis points
}
}
fn try_fill(&self, order: &Order, market_update: &MarketUpdate) -> Option<Fill> {
match &market_update.data {
crate::MarketDataType::Quote(quote) => {
self.simulate_fill(order, quote.bid, quote.ask - quote.bid)
}
crate::MarketDataType::Trade(trade) => {
self.simulate_fill(order, trade.price, 0.0001 * trade.price) // 1bp spread estimate
}
crate::MarketDataType::Bar(bar) => {
// Use close price with estimated spread
self.simulate_fill(order, bar.close, 0.0002 * bar.close) // 2bp spread estimate
}
}
}
}
impl FillSimulator for BacktestFillSimulator {
fn simulate_fill(&self, order: &Order, market_price: f64, spread: f64) -> Option<Fill> {
let fill_price = match order.order_type {
crate::OrderType::Market => {
// Fill at market with slippage
let slippage = market_price * self.slippage_bps / 10000.0;
match order.side {
crate::Side::Buy => market_price + spread/2.0 + slippage,
crate::Side::Sell => market_price - spread/2.0 - slippage,
}
}
crate::OrderType::Limit { price } => {
// Check if limit price is satisfied
match order.side {
crate::Side::Buy if price >= market_price + spread/2.0 => price,
crate::Side::Sell if price <= market_price - spread/2.0 => price,
_ => return None, // Limit not satisfied
}
}
_ => return None, // Other order types not implemented yet
};
let commission = fill_price * order.quantity * self.commission_rate;
Some(Fill {
timestamp: chrono::Utc::now(), // Will be overridden by engine
price: fill_price,
quantity: order.quantity,
commission,
})
}
}

View file

@ -0,0 +1,290 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::sync::Arc;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use crate::{
ExecutionHandler, TimeProvider,
Order, OrderStatus, Fill,
};
use crate::events::EventBus;
use crate::domain::{Event, EventType};
use super::TradingEngine;
/// Live trading engine - connects to real brokers
pub struct LiveEngine {
broker: String,
account_id: String,
broker_connection: Arc<dyn BrokerConnection>,
is_running: Arc<RwLock<bool>>,
shutdown_tx: Option<mpsc::Sender<()>>,
}
/// Trait for broker connections
#[async_trait]
pub trait BrokerConnection: Send + Sync {
async fn connect(&mut self, account_id: &str) -> Result<(), String>;
async fn disconnect(&mut self) -> Result<(), String>;
async fn subscribe_market_data(&mut self, symbols: Vec<String>) -> Result<(), String>;
async fn submit_order(&self, order: &Order) -> Result<String, String>;
async fn cancel_order(&self, order_id: &str) -> Result<(), String>;
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String>;
async fn get_positions(&self) -> Result<Vec<crate::Position>, String>;
async fn get_account_info(&self) -> Result<AccountInfo, String>;
}
#[derive(Debug, Clone)]
pub struct AccountInfo {
pub cash: f64,
pub buying_power: f64,
pub portfolio_value: f64,
pub day_trades_remaining: Option<u32>,
}
impl LiveEngine {
pub fn new(
broker: String,
account_id: String,
broker_connection: Arc<dyn BrokerConnection>,
) -> Self {
Self {
broker,
account_id,
broker_connection,
is_running: Arc::new(RwLock::new(false)),
shutdown_tx: None,
}
}
}
#[async_trait]
impl TradingEngine for LiveEngine {
async fn start(&mut self, event_bus: &mut EventBus) -> Result<(), String> {
*self.is_running.write() = true;
// Connect to broker
self.broker_connection.connect(&self.account_id).await?;
// Get initial account info
let account_info = self.broker_connection.get_account_info().await?;
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
self.shutdown_tx = Some(shutdown_tx);
// Emit start event
event_bus.publish(Event {
event_type: EventType::SystemStart,
timestamp: Utc::now(),
data: serde_json::json!({
"mode": "live",
"broker": self.broker,
"account_id": self.account_id,
"account_info": {
"cash": account_info.cash,
"buying_power": account_info.buying_power,
"portfolio_value": account_info.portfolio_value,
},
}),
}).await;
// Main live trading loop
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
break;
}
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
// Periodic tasks (e.g., check positions, risk)
if !*self.is_running.read() {
break;
}
// Could emit periodic status updates here
}
}
}
// Disconnect from broker
self.broker_connection.disconnect().await?;
// Emit stop event
event_bus.publish(Event {
event_type: EventType::SystemStop,
timestamp: Utc::now(),
data: serde_json::json!({
"reason": "live_trading_stopped",
}),
}).await;
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
*self.is_running.write() = false;
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
fn get_execution_handler(&self) -> Arc<dyn ExecutionHandler> {
Arc::new(LiveExecutor {
broker_connection: self.broker_connection.clone(),
})
}
fn get_time_provider(&self) -> Arc<dyn TimeProvider> {
Arc::new(RealTimeProvider)
}
}
/// Live order executor - delegates to broker
#[derive(Clone)]
struct LiveExecutor {
broker_connection: Arc<dyn BrokerConnection>,
}
#[async_trait]
impl ExecutionHandler for LiveExecutor {
async fn submit_order(&self, order: &Order) -> Result<String, String> {
// Validate order
if order.quantity <= 0.0 {
return Err("Invalid order quantity".to_string());
}
// Submit to broker
self.broker_connection.submit_order(order).await
}
async fn cancel_order(&self, order_id: &str) -> Result<(), String> {
self.broker_connection.cancel_order(order_id).await
}
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String> {
self.broker_connection.get_order_status(order_id).await
}
}
/// Real-time provider for live trading
struct RealTimeProvider;
impl TimeProvider for RealTimeProvider {
fn now(&self) -> DateTime<Utc> {
Utc::now()
}
fn sleep_until(&self, target: DateTime<Utc>) -> Result<(), String> {
let now = Utc::now();
if target > now {
let duration = target.signed_duration_since(now);
std::thread::sleep(duration.to_std().unwrap_or_default());
}
Ok(())
}
}
// Example broker implementations would go here
// For now, we'll create a mock broker for testing
/// Mock broker for testing
pub struct MockBroker {
is_connected: bool,
orders: Arc<RwLock<Vec<(Order, OrderStatus)>>>,
}
impl MockBroker {
pub fn new() -> Self {
Self {
is_connected: false,
orders: Arc::new(RwLock::new(Vec::new())),
}
}
}
#[async_trait]
impl BrokerConnection for MockBroker {
async fn connect(&mut self, _account_id: &str) -> Result<(), String> {
self.is_connected = true;
Ok(())
}
async fn disconnect(&mut self) -> Result<(), String> {
self.is_connected = false;
Ok(())
}
async fn subscribe_market_data(&mut self, _symbols: Vec<String>) -> Result<(), String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
Ok(())
}
async fn submit_order(&self, order: &Order) -> Result<String, String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
let order_id = order.id.clone();
self.orders.write().push((order.clone(), OrderStatus::Pending));
// Simulate order being filled after a delay
let orders = self.orders.clone();
let order_id_clone = order_id.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let mut orders = orders.write();
if let Some(pos) = orders.iter().position(|(o, _)| o.id == order_id_clone) {
orders[pos].1 = OrderStatus::Filled;
}
});
Ok(order_id)
}
async fn cancel_order(&self, order_id: &str) -> Result<(), String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
let mut orders = self.orders.write();
if let Some(pos) = orders.iter().position(|(o, _)| o.id == order_id) {
orders[pos].1 = OrderStatus::Cancelled;
Ok(())
} else {
Err("Order not found".to_string())
}
}
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
let orders = self.orders.read();
if let Some((_, status)) = orders.iter().find(|(o, _)| o.id == order_id) {
Ok(status.clone())
} else {
Err("Order not found".to_string())
}
}
async fn get_positions(&self) -> Result<Vec<crate::Position>, String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
Ok(Vec::new())
}
async fn get_account_info(&self) -> Result<AccountInfo, String> {
if !self.is_connected {
return Err("Not connected".to_string());
}
Ok(AccountInfo {
cash: 100_000.0,
buying_power: 400_000.0,
portfolio_value: 100_000.0,
day_trades_remaining: Some(3),
})
}
}

View file

@ -0,0 +1,106 @@
use async_trait::async_trait;
use crate::events::EventBus;
use crate::{ExecutionHandler, TimeProvider, MarketDataSource};
use std::sync::Arc;
pub mod backtest;
pub mod paper;
pub mod live;
pub use backtest::BacktestEngine;
pub use paper::PaperEngine;
pub use live::{LiveEngine, BrokerConnection};
/// Trading mode configuration
#[derive(Debug, Clone)]
pub enum TradingMode {
Backtest {
start_time: chrono::DateTime<chrono::Utc>,
end_time: chrono::DateTime<chrono::Utc>,
speed_multiplier: f64,
},
Paper {
starting_capital: f64,
},
Live {
broker: String,
account_id: String,
},
}
/// Common interface for all trading engines
#[async_trait]
pub trait TradingEngine: Send + Sync {
/// Start the trading engine
async fn start(&mut self, event_bus: &mut EventBus) -> Result<(), String>;
/// Stop the trading engine
async fn stop(&mut self) -> Result<(), String>;
/// Get the execution handler for this mode
fn get_execution_handler(&self) -> Arc<dyn ExecutionHandler>;
/// Get the time provider for this mode
fn get_time_provider(&self) -> Arc<dyn TimeProvider>;
}
/// Create a trading engine for the specified mode
pub fn create_engine_for_mode(
mode: &TradingMode,
) -> Result<Box<dyn TradingEngine>, String> {
match mode {
TradingMode::Backtest { start_time, end_time, speed_multiplier } => {
// For backtest, we need to create a market data source
// This would typically load historical data
let market_data = create_backtest_data_source()?;
Ok(Box::new(BacktestEngine::new(
*start_time,
*end_time,
*speed_multiplier,
market_data,
)))
}
TradingMode::Paper { starting_capital } => {
// For paper trading, we need a real-time data source
let market_data = create_realtime_data_source()?;
Ok(Box::new(PaperEngine::new(
*starting_capital,
market_data,
)))
}
TradingMode::Live { broker, account_id } => {
// For live trading, we need a broker connection
let broker_connection = create_broker_connection(broker)?;
Ok(Box::new(LiveEngine::new(
broker.clone(),
account_id.clone(),
broker_connection,
)))
}
}
}
// Helper functions to create data sources and broker connections
// These would be implemented based on your specific requirements
fn create_backtest_data_source() -> Result<Box<dyn MarketDataSource>, String> {
// TODO: Implement actual backtest data source
// For now, return a placeholder
Err("Backtest data source not implemented yet".to_string())
}
fn create_realtime_data_source() -> Result<Box<dyn MarketDataSource>, String> {
// TODO: Implement actual real-time data source
// For now, return a placeholder
Err("Real-time data source not implemented yet".to_string())
}
fn create_broker_connection(broker: &str) -> Result<Arc<dyn BrokerConnection>, String> {
match broker {
"mock" => Ok(Arc::new(live::MockBroker::new())),
_ => Err(format!("Unknown broker: {}", broker)),
}
}

View file

@ -0,0 +1,306 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::sync::Arc;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use crate::{
MarketDataSource, ExecutionHandler, TimeProvider, FillSimulator,
MarketUpdate, Order, OrderStatus, Fill,
};
use crate::events::EventBus;
use crate::domain::{Event, EventType};
use super::TradingEngine;
/// Paper trading engine - simulates live trading without real money
pub struct PaperEngine {
starting_capital: f64,
market_data: Box<dyn MarketDataSource>,
execution: PaperExecutor,
is_running: Arc<RwLock<bool>>,
shutdown_tx: Option<mpsc::Sender<()>>,
}
impl PaperEngine {
pub fn new(
starting_capital: f64,
market_data: Box<dyn MarketDataSource>,
) -> Self {
Self {
starting_capital,
market_data,
execution: PaperExecutor::new(starting_capital),
is_running: Arc::new(RwLock::new(false)),
shutdown_tx: None,
}
}
}
#[async_trait]
impl TradingEngine for PaperEngine {
async fn start(&mut self, event_bus: &mut EventBus) -> Result<(), String> {
*self.is_running.write() = true;
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
self.shutdown_tx = Some(shutdown_tx);
// Emit start event
event_bus.publish(Event {
event_type: EventType::SystemStart,
timestamp: Utc::now(),
data: serde_json::json!({
"mode": "paper",
"starting_capital": self.starting_capital,
}),
}).await;
// Main paper trading loop
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
break;
}
update = self.market_data.get_next_update() => {
if let Some(market_update) = update {
// Publish market data event
event_bus.publish(Event {
event_type: EventType::MarketData(market_update.clone()),
timestamp: Utc::now(),
data: serde_json::Value::Null,
}).await;
// Process pending orders
self.execution.process_orders(&market_update, event_bus).await?;
}
}
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
// Check if still running
if !*self.is_running.read() {
break;
}
}
}
}
// Emit stop event
event_bus.publish(Event {
event_type: EventType::SystemStop,
timestamp: Utc::now(),
data: serde_json::json!({
"reason": "paper_trading_stopped",
}),
}).await;
Ok(())
}
async fn stop(&mut self) -> Result<(), String> {
*self.is_running.write() = false;
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
Ok(())
}
fn get_execution_handler(&self) -> Arc<dyn ExecutionHandler> {
Arc::new(self.execution.clone())
}
fn get_time_provider(&self) -> Arc<dyn TimeProvider> {
Arc::new(RealTimeProvider)
}
}
/// Paper trading order executor
#[derive(Clone)]
struct PaperExecutor {
cash: Arc<RwLock<f64>>,
pending_orders: Arc<RwLock<Vec<Order>>>,
order_history: Arc<RwLock<Vec<(Order, OrderStatus)>>>,
fill_simulator: PaperFillSimulator,
}
impl PaperExecutor {
fn new(starting_capital: f64) -> Self {
Self {
cash: Arc::new(RwLock::new(starting_capital)),
pending_orders: Arc::new(RwLock::new(Vec::new())),
order_history: Arc::new(RwLock::new(Vec::new())),
fill_simulator: PaperFillSimulator::new(),
}
}
async fn process_orders(
&self,
market_update: &MarketUpdate,
event_bus: &mut EventBus,
) -> Result<(), String> {
let mut orders = self.pending_orders.write();
let mut filled_indices = Vec::new();
for (idx, order) in orders.iter().enumerate() {
if order.symbol != market_update.symbol {
continue;
}
// Try to fill the order
if let Some(fill) = self.fill_simulator.try_fill(order, market_update) {
// Check if we have enough cash for buy orders
if order.side == crate::Side::Buy {
let required_cash = fill.price * fill.quantity + fill.commission;
let mut cash = self.cash.write();
if *cash >= required_cash {
*cash -= required_cash;
} else {
continue; // Skip this order
}
} else {
// For sell orders, add cash
let proceeds = fill.price * fill.quantity - fill.commission;
*self.cash.write() += proceeds;
}
// Emit fill event
event_bus.publish(Event {
event_type: EventType::OrderFill {
order_id: order.id.clone(),
fill: fill.clone(),
},
timestamp: Utc::now(),
data: serde_json::json!({
"cash_after": *self.cash.read(),
}),
}).await;
filled_indices.push(idx);
}
}
// Remove filled orders
for idx in filled_indices.into_iter().rev() {
let order = orders.remove(idx);
self.order_history.write().push((order, OrderStatus::Filled));
}
Ok(())
}
}
#[async_trait]
impl ExecutionHandler for PaperExecutor {
async fn submit_order(&self, order: &Order) -> Result<String, String> {
// Basic validation
if order.quantity <= 0.0 {
return Err("Invalid order quantity".to_string());
}
let order_id = order.id.clone();
self.pending_orders.write().push(order.clone());
Ok(order_id)
}
async fn cancel_order(&self, order_id: &str) -> Result<(), String> {
let mut orders = self.pending_orders.write();
if let Some(pos) = orders.iter().position(|o| o.id == order_id) {
let order = orders.remove(pos);
self.order_history.write().push((order, OrderStatus::Cancelled));
Ok(())
} else {
Err("Order not found".to_string())
}
}
async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus, String> {
// Check pending orders
if self.pending_orders.read().iter().any(|o| o.id == order_id) {
return Ok(OrderStatus::Pending);
}
// Check history
if let Some((_, status)) = self.order_history.read()
.iter()
.find(|(o, _)| o.id == order_id) {
return Ok(status.clone());
}
Err("Order not found".to_string())
}
}
/// Real-time provider for paper trading
struct RealTimeProvider;
impl TimeProvider for RealTimeProvider {
fn now(&self) -> DateTime<Utc> {
Utc::now()
}
fn sleep_until(&self, target: DateTime<Utc>) -> Result<(), String> {
let now = Utc::now();
if target > now {
let duration = target.signed_duration_since(now);
std::thread::sleep(duration.to_std().unwrap_or_default());
}
Ok(())
}
}
/// Paper trading fill simulator
#[derive(Clone)]
struct PaperFillSimulator {
commission_rate: f64,
slippage_bps: f64,
}
impl PaperFillSimulator {
fn new() -> Self {
Self {
commission_rate: 0.001, // 0.1%
slippage_bps: 3.0, // 3 basis points (less than backtest)
}
}
fn try_fill(&self, order: &Order, market_update: &MarketUpdate) -> Option<Fill> {
match &market_update.data {
crate::MarketDataType::Quote(quote) => {
self.simulate_fill(order, quote.bid, quote.ask - quote.bid)
}
crate::MarketDataType::Trade(trade) => {
self.simulate_fill(order, trade.price, 0.0001 * trade.price)
}
crate::MarketDataType::Bar(bar) => {
self.simulate_fill(order, bar.close, 0.0002 * bar.close)
}
}
}
}
impl FillSimulator for PaperFillSimulator {
fn simulate_fill(&self, order: &Order, market_price: f64, spread: f64) -> Option<Fill> {
let fill_price = match order.order_type {
crate::OrderType::Market => {
let slippage = market_price * self.slippage_bps / 10000.0;
match order.side {
crate::Side::Buy => market_price + spread/2.0 + slippage,
crate::Side::Sell => market_price - spread/2.0 - slippage,
}
}
crate::OrderType::Limit { price } => {
match order.side {
crate::Side::Buy if price >= market_price + spread/2.0 => price,
crate::Side::Sell if price <= market_price - spread/2.0 => price,
_ => return None,
}
}
_ => return None,
};
let commission = fill_price * order.quantity * self.commission_rate;
Some(Fill {
timestamp: Utc::now(),
price: fill_price,
quantity: order.quantity,
commission,
})
}
}

View file

@ -0,0 +1,374 @@
use crate::{OrderBookSnapshot, PriceLevel};
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderBookAnalytics {
pub spread: f64,
pub spread_bps: f64,
pub mid_price: f64,
pub micro_price: f64, // Size-weighted mid price
pub imbalance: f64, // -1 to 1 (negative = bid pressure)
pub depth_imbalance: OrderBookImbalance,
pub liquidity_score: f64,
pub effective_spread: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderBookImbalance {
pub level_1: f64,
pub level_5: f64,
pub level_10: f64,
pub weighted: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiquidityProfile {
pub bid_liquidity: Vec<LiquidityLevel>,
pub ask_liquidity: Vec<LiquidityLevel>,
pub total_bid_depth: f64,
pub total_ask_depth: f64,
pub bid_depth_weighted_price: f64,
pub ask_depth_weighted_price: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiquidityLevel {
pub price: f64,
pub size: f64,
pub cumulative_size: f64,
pub cost_to_execute: f64, // Cost to buy/sell up to this level
}
impl OrderBookAnalytics {
pub fn calculate(snapshot: &OrderBookSnapshot) -> Option<Self> {
if snapshot.bids.is_empty() || snapshot.asks.is_empty() {
return None;
}
let best_bid = snapshot.bids[0].price;
let best_ask = snapshot.asks[0].price;
let spread = best_ask - best_bid;
let mid_price = (best_bid + best_ask) / 2.0;
let spread_bps = (spread / mid_price) * 10000.0;
// Calculate micro price (size-weighted)
let bid_size = snapshot.bids[0].size;
let ask_size = snapshot.asks[0].size;
let micro_price = (best_bid * ask_size + best_ask * bid_size) / (bid_size + ask_size);
// Calculate imbalance
let imbalance = (bid_size - ask_size) / (bid_size + ask_size);
// Calculate depth imbalance at different levels
let depth_imbalance = Self::calculate_depth_imbalance(snapshot);
// Calculate liquidity score
let liquidity_score = Self::calculate_liquidity_score(snapshot);
// Effective spread (considers depth)
let effective_spread = Self::calculate_effective_spread(snapshot, 1000.0); // $1000 order
Some(OrderBookAnalytics {
spread,
spread_bps,
mid_price,
micro_price,
imbalance,
depth_imbalance,
liquidity_score,
effective_spread,
})
}
fn calculate_depth_imbalance(snapshot: &OrderBookSnapshot) -> OrderBookImbalance {
let calc_imbalance = |depth: usize| -> f64 {
let bid_depth: f64 = snapshot.bids.iter()
.take(depth)
.map(|l| l.size)
.sum();
let ask_depth: f64 = snapshot.asks.iter()
.take(depth)
.map(|l| l.size)
.sum();
if bid_depth + ask_depth > 0.0 {
(bid_depth - ask_depth) / (bid_depth + ask_depth)
} else {
0.0
}
};
// Weighted imbalance (more weight on top levels)
let mut weighted_bid = 0.0;
let mut weighted_ask = 0.0;
let mut weight_sum = 0.0;
for (i, (bid, ask)) in snapshot.bids.iter().zip(snapshot.asks.iter()).enumerate().take(10) {
let weight = 1.0 / (i + 1) as f64;
weighted_bid += bid.size * weight;
weighted_ask += ask.size * weight;
weight_sum += weight;
}
let weighted = if weighted_bid + weighted_ask > 0.0 {
(weighted_bid - weighted_ask) / (weighted_bid + weighted_ask)
} else {
0.0
};
OrderBookImbalance {
level_1: calc_imbalance(1),
level_5: calc_imbalance(5),
level_10: calc_imbalance(10),
weighted,
}
}
fn calculate_liquidity_score(snapshot: &OrderBookSnapshot) -> f64 {
// Liquidity score based on depth and tightness
let depth_score = (snapshot.bids.len() + snapshot.asks.len()) as f64 / 20.0; // Normalize by 10 levels each side
let volume_score = {
let bid_volume: f64 = snapshot.bids.iter().take(5).map(|l| l.size).sum();
let ask_volume: f64 = snapshot.asks.iter().take(5).map(|l| l.size).sum();
((bid_volume + ask_volume) / 10000.0).min(1.0) // Normalize by $10k
};
let spread_score = if let (Some(bid), Some(ask)) = (snapshot.bids.first(), snapshot.asks.first()) {
let spread_bps = ((ask.price - bid.price) / ((ask.price + bid.price) / 2.0)) * 10000.0;
(50.0 / (spread_bps + 1.0)).min(1.0) // Lower spread = higher score
} else {
0.0
};
(depth_score * 0.3 + volume_score * 0.4 + spread_score * 0.3).min(1.0)
}
fn calculate_effective_spread(snapshot: &OrderBookSnapshot, order_size_usd: f64) -> f64 {
let avg_execution_price = |levels: &[PriceLevel], size_usd: f64, is_buy: bool| -> Option<f64> {
let mut remaining = size_usd;
let mut total_cost = 0.0;
let mut total_shares = 0.0;
for level in levels {
let level_value = level.price * level.size;
if remaining <= level_value {
let shares = remaining / level.price;
total_cost += remaining;
total_shares += shares;
break;
} else {
total_cost += level_value;
total_shares += level.size;
remaining -= level_value;
}
}
if total_shares > 0.0 {
Some(total_cost / total_shares)
} else {
None
}
};
if let (Some(bid_exec), Some(ask_exec)) = (
avg_execution_price(&snapshot.bids, order_size_usd, false),
avg_execution_price(&snapshot.asks, order_size_usd, true)
) {
ask_exec - bid_exec
} else if let (Some(bid), Some(ask)) = (snapshot.bids.first(), snapshot.asks.first()) {
ask.price - bid.price
} else {
0.0
}
}
}
impl LiquidityProfile {
pub fn from_snapshot(snapshot: &OrderBookSnapshot) -> Self {
let mut bid_liquidity = Vec::new();
let mut ask_liquidity = Vec::new();
let mut cumulative_bid_size = 0.0;
let mut cumulative_bid_cost = 0.0;
for bid in &snapshot.bids {
cumulative_bid_size += bid.size;
cumulative_bid_cost += bid.price * bid.size;
bid_liquidity.push(LiquidityLevel {
price: bid.price,
size: bid.size,
cumulative_size: cumulative_bid_size,
cost_to_execute: cumulative_bid_cost,
});
}
let mut cumulative_ask_size = 0.0;
let mut cumulative_ask_cost = 0.0;
for ask in &snapshot.asks {
cumulative_ask_size += ask.size;
cumulative_ask_cost += ask.price * ask.size;
ask_liquidity.push(LiquidityLevel {
price: ask.price,
size: ask.size,
cumulative_size: cumulative_ask_size,
cost_to_execute: cumulative_ask_cost,
});
}
let total_bid_depth = cumulative_bid_cost;
let total_ask_depth = cumulative_ask_cost;
let bid_depth_weighted_price = if cumulative_bid_size > 0.0 {
cumulative_bid_cost / cumulative_bid_size
} else {
0.0
};
let ask_depth_weighted_price = if cumulative_ask_size > 0.0 {
cumulative_ask_cost / cumulative_ask_size
} else {
0.0
};
Self {
bid_liquidity,
ask_liquidity,
total_bid_depth,
total_ask_depth,
bid_depth_weighted_price,
ask_depth_weighted_price,
}
}
/// Calculate the market impact of executing a given size
pub fn calculate_market_impact(&self, size_usd: f64, is_buy: bool) -> MarketImpact {
let levels = if is_buy { &self.ask_liquidity } else { &self.bid_liquidity };
if levels.is_empty() {
return MarketImpact::default();
}
let reference_price = levels[0].price;
let mut remaining = size_usd;
let mut total_cost = 0.0;
let mut total_shares = 0.0;
let mut levels_consumed = 0;
for (i, level) in levels.iter().enumerate() {
let level_value = level.price * level.size;
if remaining <= level_value {
let shares = remaining / level.price;
total_cost += shares * level.price;
total_shares += shares;
levels_consumed = i + 1;
break;
} else {
total_cost += level_value;
total_shares += level.size;
remaining -= level_value;
levels_consumed = i + 1;
}
}
let avg_execution_price = if total_shares > 0.0 {
total_cost / total_shares
} else {
reference_price
};
let price_impact = if is_buy {
(avg_execution_price - reference_price) / reference_price
} else {
(reference_price - avg_execution_price) / reference_price
};
let slippage = (avg_execution_price - reference_price).abs();
MarketImpact {
avg_execution_price,
price_impact,
slippage,
levels_consumed,
total_shares,
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MarketImpact {
pub avg_execution_price: f64,
pub price_impact: f64, // As percentage
pub slippage: f64, // In price units
pub levels_consumed: usize,
pub total_shares: f64,
}
/// Track orderbook dynamics over time
pub struct OrderBookDynamics {
snapshots: Vec<(chrono::DateTime<chrono::Utc>, OrderBookAnalytics)>,
max_history: usize,
}
impl OrderBookDynamics {
pub fn new(max_history: usize) -> Self {
Self {
snapshots: Vec::new(),
max_history,
}
}
pub fn add_snapshot(&mut self, timestamp: chrono::DateTime<chrono::Utc>, analytics: OrderBookAnalytics) {
self.snapshots.push((timestamp, analytics));
if self.snapshots.len() > self.max_history {
self.snapshots.remove(0);
}
}
pub fn get_volatility(&self, window: usize) -> Option<f64> {
if self.snapshots.len() < window {
return None;
}
let recent = &self.snapshots[self.snapshots.len() - window..];
let mid_prices: Vec<f64> = recent.iter().map(|(_, a)| a.mid_price).collect();
let mean = mid_prices.iter().sum::<f64>() / mid_prices.len() as f64;
let variance = mid_prices.iter()
.map(|p| (p - mean).powi(2))
.sum::<f64>() / mid_prices.len() as f64;
Some(variance.sqrt())
}
pub fn get_average_spread(&self, window: usize) -> Option<f64> {
if self.snapshots.len() < window {
return None;
}
let recent = &self.snapshots[self.snapshots.len() - window..];
let total_spread: f64 = recent.iter().map(|(_, a)| a.spread).sum();
Some(total_spread / window as f64)
}
pub fn detect_momentum(&self, window: usize) -> Option<f64> {
if self.snapshots.len() < window {
return None;
}
let recent = &self.snapshots[self.snapshots.len() - window..];
let imbalances: Vec<f64> = recent.iter()
.map(|(_, a)| a.depth_imbalance.weighted)
.collect();
// Average imbalance indicates momentum direction
Some(imbalances.iter().sum::<f64>() / imbalances.len() as f64)
}
}

View file

@ -0,0 +1,313 @@
pub mod analytics;
use crate::{Quote, Trade, Side, OrderBookSnapshot, PriceLevel};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
pub use analytics::{OrderBookAnalytics, LiquidityProfile, OrderBookImbalance, MarketImpact};
// Manages order books for all symbols
pub struct OrderBookManager {
books: DashMap<String, Arc<RwLock<OrderBook>>>,
}
impl OrderBookManager {
pub fn new() -> Self {
Self {
books: DashMap::new(),
}
}
pub fn get_or_create(&self, symbol: &str) -> Arc<RwLock<OrderBook>> {
self.books
.entry(symbol.to_string())
.or_insert_with(|| Arc::new(RwLock::new(OrderBook::new(symbol.to_string()))))
.clone()
}
pub fn update_quote(&self, symbol: &str, quote: Quote, timestamp: DateTime<Utc>) {
let book = self.get_or_create(symbol);
let mut book_guard = book.write();
book_guard.update_quote(quote, timestamp);
}
pub fn update_trade(&self, symbol: &str, trade: Trade, timestamp: DateTime<Utc>) {
let book = self.get_or_create(symbol);
let mut book_guard = book.write();
book_guard.update_trade(trade, timestamp);
}
pub fn get_snapshot(&self, symbol: &str, depth: usize) -> Option<OrderBookSnapshot> {
self.books.get(symbol).map(|book| {
let book_guard = book.read();
book_guard.get_snapshot(depth)
})
}
pub fn get_best_bid_ask(&self, symbol: &str) -> Option<(f64, f64)> {
self.books.get(symbol).and_then(|book| {
let book_guard = book.read();
book_guard.get_best_bid_ask()
})
}
pub fn get_analytics(&self, symbol: &str, depth: usize) -> Option<OrderBookAnalytics> {
self.get_snapshot(symbol, depth)
.and_then(|snapshot| OrderBookAnalytics::calculate(&snapshot))
}
pub fn get_liquidity_profile(&self, symbol: &str, depth: usize) -> Option<LiquidityProfile> {
self.get_snapshot(symbol, depth)
.map(|snapshot| LiquidityProfile::from_snapshot(&snapshot))
}
}
// Individual order book for a symbol
pub struct OrderBook {
symbol: String,
bids: BTreeMap<OrderedFloat, Level>,
asks: BTreeMap<OrderedFloat, Level>,
last_update: DateTime<Utc>,
last_trade_price: Option<f64>,
last_trade_size: Option<f64>,
}
#[derive(Clone, Debug)]
struct Level {
price: f64,
size: f64,
order_count: u32,
last_update: DateTime<Utc>,
}
// Wrapper for f64 to allow BTreeMap ordering
#[derive(Clone, Copy, Debug, PartialEq)]
struct OrderedFloat(f64);
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(&other.0)
}
}
impl Ord for OrderedFloat {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
}
}
impl OrderBook {
pub fn new(symbol: String) -> Self {
Self {
symbol,
bids: BTreeMap::new(),
asks: BTreeMap::new(),
last_update: Utc::now(),
last_trade_price: None,
last_trade_size: None,
}
}
pub fn update_quote(&mut self, quote: Quote, timestamp: DateTime<Utc>) {
// Update bid
if quote.bid > 0.0 && quote.bid_size > 0.0 {
self.bids.insert(
OrderedFloat(-quote.bid), // Negative for reverse ordering
Level {
price: quote.bid,
size: quote.bid_size,
order_count: 1,
last_update: timestamp,
},
);
}
// Update ask
if quote.ask > 0.0 && quote.ask_size > 0.0 {
self.asks.insert(
OrderedFloat(quote.ask),
Level {
price: quote.ask,
size: quote.ask_size,
order_count: 1,
last_update: timestamp,
},
);
}
self.last_update = timestamp;
self.clean_stale_levels(timestamp);
}
pub fn update_trade(&mut self, trade: Trade, timestamp: DateTime<Utc>) {
self.last_trade_price = Some(trade.price);
self.last_trade_size = Some(trade.size);
self.last_update = timestamp;
// Optionally update order book based on trade
// Remove liquidity that was likely consumed
match trade.side {
Side::Buy => {
// Trade hit the ask, remove liquidity
self.remove_liquidity_up_to_asks(trade.price, trade.size);
}
Side::Sell => {
// Trade hit the bid, remove liquidity
self.remove_liquidity_up_to_bids(trade.price, trade.size);
}
}
}
pub fn get_snapshot(&self, depth: usize) -> OrderBookSnapshot {
let bids: Vec<PriceLevel> = self.bids
.values()
.take(depth)
.map(|level| PriceLevel {
price: level.price,
size: level.size,
order_count: Some(level.order_count),
})
.collect();
let asks: Vec<PriceLevel> = self.asks
.values()
.take(depth)
.map(|level| PriceLevel {
price: level.price,
size: level.size,
order_count: Some(level.order_count),
})
.collect();
OrderBookSnapshot {
symbol: self.symbol.clone(),
timestamp: self.last_update,
bids,
asks,
}
}
pub fn get_best_bid_ask(&self) -> Option<(f64, f64)> {
let best_bid = self.bids.values().next()?.price;
let best_ask = self.asks.values().next()?.price;
Some((best_bid, best_ask))
}
pub fn get_mid_price(&self) -> Option<f64> {
self.get_best_bid_ask()
.map(|(bid, ask)| (bid + ask) / 2.0)
}
pub fn get_spread(&self) -> Option<f64> {
self.get_best_bid_ask()
.map(|(bid, ask)| ask - bid)
}
pub fn get_depth_at_price(&self, price: f64, side: Side) -> f64 {
match side {
Side::Buy => {
self.bids.values()
.filter(|level| level.price >= price)
.map(|level| level.size)
.sum()
}
Side::Sell => {
self.asks.values()
.filter(|level| level.price <= price)
.map(|level| level.size)
.sum()
}
}
}
pub fn get_volume_weighted_price(&self, size: f64, side: Side) -> Option<f64> {
let levels: Vec<&Level> = match side {
Side::Buy => self.asks.values().collect(),
Side::Sell => self.bids.values().collect(),
};
let mut remaining_size = size;
let mut total_cost = 0.0;
let mut total_shares = 0.0;
for level in levels {
if remaining_size <= 0.0 {
break;
}
let fill_size = remaining_size.min(level.size);
total_cost += fill_size * level.price;
total_shares += fill_size;
remaining_size -= fill_size;
}
if total_shares > 0.0 {
Some(total_cost / total_shares)
} else {
None
}
}
fn clean_stale_levels(&mut self, current_time: DateTime<Utc>) {
let stale_threshold = chrono::Duration::seconds(60); // 60 seconds
self.bids.retain(|_, level| {
current_time - level.last_update < stale_threshold
});
self.asks.retain(|_, level| {
current_time - level.last_update < stale_threshold
});
}
fn remove_liquidity_up_to_asks(&mut self, price: f64, size: f64) {
let mut remaining_size = size;
let mut to_remove = Vec::new();
for (key, level) in self.asks.iter_mut() {
if level.price <= price {
if level.size <= remaining_size {
remaining_size -= level.size;
to_remove.push(*key);
} else {
level.size -= remaining_size;
break;
}
} else {
break;
}
}
for key in to_remove {
self.asks.remove(&key);
}
}
fn remove_liquidity_up_to_bids(&mut self, price: f64, size: f64) {
let mut remaining_size = size;
let mut to_remove = Vec::new();
for (key, level) in self.bids.iter_mut() {
if level.price >= price {
if level.size <= remaining_size {
remaining_size -= level.size;
to_remove.push(*key);
} else {
level.size -= remaining_size;
break;
}
} else {
break;
}
}
for key in to_remove {
self.bids.remove(&key);
}
}
}

View file

@ -0,0 +1,398 @@
use crate::{Fill, Side};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use parking_lot::RwLock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Position {
pub symbol: String,
pub quantity: f64,
pub average_price: f64,
pub realized_pnl: f64,
pub unrealized_pnl: f64,
pub total_cost: f64,
pub last_update: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionUpdate {
pub symbol: String,
pub fill: Fill,
pub resulting_position: Position,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeRecord {
pub id: String,
pub symbol: String,
pub side: Side,
pub quantity: f64,
pub price: f64,
pub timestamp: DateTime<Utc>,
pub commission: f64,
pub order_id: Option<String>,
pub strategy_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClosedTrade {
pub id: String,
pub symbol: String,
pub entry_time: DateTime<Utc>,
pub exit_time: DateTime<Utc>,
pub entry_price: f64,
pub exit_price: f64,
pub quantity: f64,
pub side: Side, // Side of the opening trade
pub pnl: f64,
pub pnl_percent: f64,
pub commission: f64,
pub duration_ms: i64,
pub entry_fill_id: String,
pub exit_fill_id: String,
}
pub struct PositionTracker {
positions: DashMap<String, Position>,
trade_history: Arc<RwLock<Vec<TradeRecord>>>,
closed_trades: Arc<RwLock<Vec<ClosedTrade>>>,
open_trades: DashMap<String, Vec<TradeRecord>>, // Track open trades by symbol
next_trade_id: Arc<RwLock<u64>>,
}
impl PositionTracker {
pub fn new() -> Self {
Self {
positions: DashMap::new(),
trade_history: Arc::new(RwLock::new(Vec::new())),
closed_trades: Arc::new(RwLock::new(Vec::new())),
open_trades: DashMap::new(),
next_trade_id: Arc::new(RwLock::new(1)),
}
}
fn generate_trade_id(&self) -> String {
let mut id = self.next_trade_id.write();
let current_id = *id;
*id += 1;
format!("T{:08}", current_id)
}
pub fn process_fill_with_tracking(
&self,
symbol: &str,
fill: &Fill,
side: Side,
order_id: Option<String>,
strategy_id: Option<String>
) -> PositionUpdate {
// First process the fill normally
let update = self.process_fill(symbol, fill, side);
// Create trade record
let trade_record = TradeRecord {
id: self.generate_trade_id(),
symbol: symbol.to_string(),
side,
quantity: fill.quantity,
price: fill.price,
timestamp: fill.timestamp,
commission: fill.commission,
order_id,
strategy_id,
};
// Add to trade history
self.trade_history.write().push(trade_record.clone());
// Handle trade matching for closed trades
match side {
Side::Buy => {
// For buy orders, try to match with open sell trades (closing shorts)
if let Some(mut open_trades) = self.open_trades.get_mut(symbol) {
let mut remaining_quantity = fill.quantity;
let mut trades_to_remove = Vec::new();
// FIFO matching against short positions
for (idx, open_trade) in open_trades.iter_mut().enumerate() {
if open_trade.side == Side::Sell && remaining_quantity > 0.0 {
let close_quantity = remaining_quantity.min(open_trade.quantity);
// Create closed trade record for short position
let closed_trade = ClosedTrade {
id: format!("CT{}", self.generate_trade_id()),
symbol: symbol.to_string(),
entry_time: open_trade.timestamp,
exit_time: fill.timestamp,
entry_price: open_trade.price,
exit_price: fill.price,
quantity: close_quantity,
side: Side::Sell, // Opening side (short)
pnl: close_quantity * (open_trade.price - fill.price) - (open_trade.commission + fill.commission * close_quantity / fill.quantity),
pnl_percent: ((open_trade.price - fill.price) / open_trade.price) * 100.0,
commission: open_trade.commission + fill.commission * close_quantity / fill.quantity,
duration_ms: (fill.timestamp - open_trade.timestamp).num_milliseconds(),
entry_fill_id: open_trade.id.clone(),
exit_fill_id: trade_record.id.clone(),
};
self.closed_trades.write().push(closed_trade);
// Update quantities
remaining_quantity -= close_quantity;
open_trade.quantity -= close_quantity;
if open_trade.quantity <= 0.0 {
trades_to_remove.push(idx);
}
}
}
// Remove fully closed trades
for idx in trades_to_remove.into_iter().rev() {
open_trades.remove(idx);
}
// If we still have quantity left, it's a new long position
if remaining_quantity > 0.0 {
let long_trade = TradeRecord {
quantity: remaining_quantity,
..trade_record.clone()
};
open_trades.push(long_trade);
}
} else {
// No open trades, start a new long position
self.open_trades.entry(symbol.to_string())
.or_insert_with(Vec::new)
.push(trade_record);
}
}
Side::Sell => {
// For sell orders, try to match with open buy trades
if let Some(mut open_trades) = self.open_trades.get_mut(symbol) {
let mut remaining_quantity = fill.quantity;
let mut trades_to_remove = Vec::new();
// FIFO matching
for (idx, open_trade) in open_trades.iter_mut().enumerate() {
if open_trade.side == Side::Buy && remaining_quantity > 0.0 {
let close_quantity = remaining_quantity.min(open_trade.quantity);
// Create closed trade record
let closed_trade = ClosedTrade {
id: format!("CT{}", self.generate_trade_id()),
symbol: symbol.to_string(),
entry_time: open_trade.timestamp,
exit_time: fill.timestamp,
entry_price: open_trade.price,
exit_price: fill.price,
quantity: close_quantity,
side: Side::Buy, // Opening side
pnl: close_quantity * (fill.price - open_trade.price) - (open_trade.commission + fill.commission * close_quantity / fill.quantity),
pnl_percent: ((fill.price - open_trade.price) / open_trade.price) * 100.0,
commission: open_trade.commission + fill.commission * close_quantity / fill.quantity,
duration_ms: (fill.timestamp - open_trade.timestamp).num_milliseconds(),
entry_fill_id: open_trade.id.clone(),
exit_fill_id: trade_record.id.clone(),
};
self.closed_trades.write().push(closed_trade);
// Update quantities
remaining_quantity -= close_quantity;
open_trade.quantity -= close_quantity;
if open_trade.quantity <= 0.0 {
trades_to_remove.push(idx);
}
}
}
// Remove fully closed trades
for idx in trades_to_remove.into_iter().rev() {
open_trades.remove(idx);
}
// If we still have quantity left, it's a short position
if remaining_quantity > 0.0 {
let short_trade = TradeRecord {
quantity: remaining_quantity,
..trade_record.clone()
};
open_trades.push(short_trade);
}
} else {
// No open trades, start a new short position
self.open_trades.entry(symbol.to_string())
.or_insert_with(Vec::new)
.push(trade_record);
}
}
}
update
}
pub fn process_fill(&self, symbol: &str, fill: &Fill, side: Side) -> PositionUpdate {
let mut entry = self.positions.entry(symbol.to_string()).or_insert_with(|| {
Position {
symbol: symbol.to_string(),
quantity: 0.0,
average_price: 0.0,
realized_pnl: 0.0,
unrealized_pnl: 0.0,
total_cost: 0.0,
last_update: fill.timestamp,
}
});
let position = entry.value_mut();
let old_quantity = position.quantity;
let old_avg_price = position.average_price;
// Calculate new position
match side {
Side::Buy => {
// Adding to position
position.quantity += fill.quantity;
if old_quantity >= 0.0 {
// Already long or flat, average up/down
position.total_cost += fill.price * fill.quantity;
position.average_price = if position.quantity > 0.0 {
position.total_cost / position.quantity
} else {
0.0
};
} else {
// Was short, closing or flipping
let close_quantity = fill.quantity.min(-old_quantity);
let open_quantity = fill.quantity - close_quantity;
// Realize P&L on closed portion
position.realized_pnl += close_quantity * (old_avg_price - fill.price);
// Update position for remaining
if open_quantity > 0.0 {
position.total_cost = open_quantity * fill.price;
position.average_price = fill.price;
} else {
position.total_cost = (position.quantity.abs()) * old_avg_price;
}
}
}
Side::Sell => {
// Reducing position
position.quantity -= fill.quantity;
if old_quantity <= 0.0 {
// Already short or flat, average up/down
position.total_cost += fill.price * fill.quantity;
position.average_price = if position.quantity < 0.0 {
position.total_cost / position.quantity.abs()
} else {
0.0
};
} else {
// Was long, closing or flipping
let close_quantity = fill.quantity.min(old_quantity);
let open_quantity = fill.quantity - close_quantity;
// Realize P&L on closed portion
position.realized_pnl += close_quantity * (fill.price - old_avg_price);
// Update position for remaining
if open_quantity > 0.0 {
position.total_cost = open_quantity * fill.price;
position.average_price = fill.price;
} else {
position.total_cost = (position.quantity.abs()) * old_avg_price;
}
}
}
}
// Subtract commission from realized P&L
position.realized_pnl -= fill.commission;
position.last_update = fill.timestamp;
PositionUpdate {
symbol: symbol.to_string(),
fill: fill.clone(),
resulting_position: position.clone(),
}
}
pub fn get_position(&self, symbol: &str) -> Option<Position> {
self.positions.get(symbol).map(|p| p.clone())
}
pub fn get_all_positions(&self) -> Vec<Position> {
self.positions.iter().map(|entry| entry.value().clone()).collect()
}
pub fn get_open_positions(&self) -> Vec<Position> {
self.positions
.iter()
.filter(|entry| entry.value().quantity.abs() > 0.0001)
.map(|entry| entry.value().clone())
.collect()
}
pub fn update_unrealized_pnl(&self, symbol: &str, current_price: f64) {
if let Some(mut position) = self.positions.get_mut(symbol) {
if position.quantity > 0.0 {
position.unrealized_pnl = position.quantity * (current_price - position.average_price);
} else if position.quantity < 0.0 {
position.unrealized_pnl = position.quantity * (current_price - position.average_price);
} else {
position.unrealized_pnl = 0.0;
}
}
}
pub fn get_total_pnl(&self) -> (f64, f64) {
let mut realized = 0.0;
let mut unrealized = 0.0;
for position in self.positions.iter() {
realized += position.realized_pnl;
unrealized += position.unrealized_pnl;
}
(realized, unrealized)
}
pub fn reset(&self) {
self.positions.clear();
self.trade_history.write().clear();
self.closed_trades.write().clear();
self.open_trades.clear();
*self.next_trade_id.write() = 1;
}
pub fn get_trade_history(&self) -> Vec<TradeRecord> {
self.trade_history.read().clone()
}
pub fn get_closed_trades(&self) -> Vec<ClosedTrade> {
self.closed_trades.read().clone()
}
pub fn get_open_trades(&self) -> Vec<TradeRecord> {
let mut all_open_trades = Vec::new();
for entry in self.open_trades.iter() {
all_open_trades.extend(entry.value().clone());
}
all_open_trades
}
pub fn get_trade_count(&self) -> usize {
self.trade_history.read().len()
}
pub fn get_closed_trade_count(&self) -> usize {
self.closed_trades.read().len()
}
}

View file

@ -0,0 +1,505 @@
use std::collections::HashMap;
use crate::orderbook::analytics::{OrderBookAnalytics, LiquidityProfile};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetSizingParameters {
pub signal_strength: f64, // -1 to 1
pub signal_confidence: f64, // 0 to 1
pub market_regime: MarketRegime,
pub volatility: f64,
pub liquidity_score: f64,
pub correlation_exposure: f64,
pub current_drawdown: f64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MarketRegime {
Trending,
RangeBound,
HighVolatility,
LowVolatility,
Transitioning,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionSize {
pub shares: u32,
pub notional_value: f64,
pub percent_of_capital: f64,
pub risk_adjusted_size: f64,
pub sizing_method: String,
pub adjustments: Vec<SizeAdjustment>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SizeAdjustment {
pub reason: String,
pub factor: f64,
}
pub struct BetSizer {
capital: f64,
base_risk_per_trade: f64,
max_position_size: f64,
min_position_size: f64,
use_kelly: bool,
volatility_scaling: bool,
regime_adjustments: HashMap<MarketRegime, f64>,
}
impl BetSizer {
pub fn new(capital: f64, base_risk_per_trade: f64) -> Self {
let mut regime_adjustments = HashMap::new();
regime_adjustments.insert(MarketRegime::Trending, 1.2);
regime_adjustments.insert(MarketRegime::RangeBound, 0.8);
regime_adjustments.insert(MarketRegime::HighVolatility, 0.6);
regime_adjustments.insert(MarketRegime::LowVolatility, 1.1);
regime_adjustments.insert(MarketRegime::Transitioning, 0.7);
Self {
capital,
base_risk_per_trade,
max_position_size: 0.10, // 10% max
min_position_size: 0.001, // 0.1% min
use_kelly: true,
volatility_scaling: true,
regime_adjustments,
}
}
pub fn calculate_position_size(
&self,
params: &BetSizingParameters,
price: f64,
stop_loss: Option<f64>,
historical_performance: Option<&PerformanceStats>,
orderbook_analytics: Option<&OrderBookAnalytics>,
liquidity_profile: Option<&LiquidityProfile>,
) -> PositionSize {
let mut adjustments = Vec::new();
// Start with base position size
let mut risk_fraction = self.base_risk_per_trade;
// 1. Kelly Criterion adjustment
if self.use_kelly {
if let Some(perf) = historical_performance {
let kelly_fraction = self.calculate_kelly_fraction(perf);
let kelly_adjustment = (kelly_fraction / self.base_risk_per_trade).min(2.0).max(0.1);
risk_fraction *= kelly_adjustment;
adjustments.push(SizeAdjustment {
reason: "Kelly Criterion".to_string(),
factor: kelly_adjustment,
});
}
}
// 2. Signal strength adjustment
let signal_adjustment = self.calculate_signal_adjustment(params.signal_strength, params.signal_confidence);
risk_fraction *= signal_adjustment;
adjustments.push(SizeAdjustment {
reason: "Signal Strength".to_string(),
factor: signal_adjustment,
});
// 3. Volatility adjustment
if self.volatility_scaling {
let vol_adjustment = self.calculate_volatility_adjustment(params.volatility);
risk_fraction *= vol_adjustment;
adjustments.push(SizeAdjustment {
reason: "Volatility Scaling".to_string(),
factor: vol_adjustment,
});
}
// 4. Market regime adjustment
if let Some(&regime_factor) = self.regime_adjustments.get(&params.market_regime) {
risk_fraction *= regime_factor;
adjustments.push(SizeAdjustment {
reason: format!("Market Regime ({:?})", params.market_regime),
factor: regime_factor,
});
}
// 5. Liquidity adjustment
let liquidity_adjustment = self.calculate_liquidity_adjustment(
params.liquidity_score,
orderbook_analytics,
liquidity_profile,
);
risk_fraction *= liquidity_adjustment;
adjustments.push(SizeAdjustment {
reason: "Liquidity".to_string(),
factor: liquidity_adjustment,
});
// 6. Correlation adjustment
let correlation_adjustment = self.calculate_correlation_adjustment(params.correlation_exposure);
risk_fraction *= correlation_adjustment;
adjustments.push(SizeAdjustment {
reason: "Correlation Exposure".to_string(),
factor: correlation_adjustment,
});
// 7. Drawdown adjustment
let drawdown_adjustment = self.calculate_drawdown_adjustment(params.current_drawdown);
risk_fraction *= drawdown_adjustment;
adjustments.push(SizeAdjustment {
reason: "Current Drawdown".to_string(),
factor: drawdown_adjustment,
});
// Calculate final position size
let position_value = if let Some(stop_price) = stop_loss {
// Risk-based sizing
let risk_per_share = (price - stop_price).abs();
let risk_amount = self.capital * risk_fraction;
risk_amount / risk_per_share
} else {
// Fixed percentage sizing
self.capital * risk_fraction / price
};
// Apply min/max constraints
let constrained_value = position_value
.min(self.capital * self.max_position_size / price)
.max(self.capital * self.min_position_size / price);
let shares = constrained_value.floor() as u32;
let notional_value = shares as f64 * price;
let percent_of_capital = notional_value / self.capital;
PositionSize {
shares,
notional_value,
percent_of_capital,
risk_adjusted_size: risk_fraction,
sizing_method: if stop_loss.is_some() { "Risk-based".to_string() } else { "Fixed-percentage".to_string() },
adjustments,
}
}
fn calculate_kelly_fraction(&self, perf: &PerformanceStats) -> f64 {
if perf.total_trades < 20 {
return self.base_risk_per_trade; // Not enough data
}
let win_rate = perf.win_rate;
let loss_rate = 1.0 - win_rate;
if perf.avg_loss == 0.0 {
return self.base_risk_per_trade;
}
let win_loss_ratio = perf.avg_win / perf.avg_loss.abs();
// Kelly formula: f = p - q/b
// where p = win probability, q = loss probability, b = win/loss ratio
let kelly = win_rate - (loss_rate / win_loss_ratio);
// Apply Kelly fraction (typically 25% of full Kelly)
let kelly_fraction = kelly * 0.25;
// Ensure it's positive and reasonable
kelly_fraction.max(0.0).min(0.25)
}
fn calculate_signal_adjustment(&self, strength: f64, confidence: f64) -> f64 {
// Strength is -1 to 1, we want 0 to 1
let normalized_strength = (strength.abs() + 1.0) / 2.0;
// Combine strength and confidence
let signal_score = normalized_strength * 0.7 + confidence * 0.3;
// Map to adjustment factor (0.5 to 1.5)
0.5 + signal_score
}
fn calculate_volatility_adjustment(&self, volatility: f64) -> f64 {
// Target volatility (e.g., 15% annualized)
let target_vol = 0.15 / (252.0_f64.sqrt()); // Daily vol
// Inverse volatility scaling
let adjustment = (target_vol / volatility).min(1.5).max(0.5);
adjustment
}
fn calculate_liquidity_adjustment(
&self,
liquidity_score: f64,
orderbook: Option<&OrderBookAnalytics>,
profile: Option<&LiquidityProfile>,
) -> f64 {
let mut adjustment = 1.0;
// Base liquidity score adjustment
if liquidity_score < 0.3 {
adjustment *= 0.5; // Very poor liquidity
} else if liquidity_score < 0.5 {
adjustment *= 0.7;
} else if liquidity_score > 0.8 {
adjustment *= 1.1; // Good liquidity bonus
}
// Orderbook spread adjustment
if let Some(ob) = orderbook {
if ob.spread_bps > 50.0 {
adjustment *= 0.8; // Wide spread penalty
} else if ob.spread_bps < 10.0 {
adjustment *= 1.1; // Tight spread bonus
}
}
// Market impact consideration
if let Some(prof) = profile {
// Check if our typical order size would move the market
let typical_order_value = self.capital * self.base_risk_per_trade;
let impact = prof.calculate_market_impact(typical_order_value, true);
if impact.price_impact > 0.001 { // More than 10 bps impact
adjustment *= (1.0 - impact.price_impact * 10.0).max(0.5);
}
}
adjustment
}
fn calculate_correlation_adjustment(&self, correlation_exposure: f64) -> f64 {
// Reduce size if highly correlated with existing positions
if correlation_exposure > 0.7 {
0.5
} else if correlation_exposure > 0.5 {
0.7
} else if correlation_exposure > 0.3 {
0.9
} else {
1.0
}
}
fn calculate_drawdown_adjustment(&self, current_drawdown: f64) -> f64 {
// Reduce size during drawdowns
if current_drawdown > 0.20 {
0.5 // 50% reduction if in 20%+ drawdown
} else if current_drawdown > 0.10 {
0.7
} else if current_drawdown > 0.05 {
0.85
} else {
1.0
}
}
pub fn calculate_optimal_stop_loss(
&self,
entry_price: f64,
volatility: f64,
support_levels: &[f64],
atr: Option<f64>,
is_long: bool,
) -> f64 {
let mut stop_candidates = Vec::new();
// 1. Volatility-based stop
let vol_stop = if is_long {
entry_price * (1.0 - 2.0 * volatility)
} else {
entry_price * (1.0 + 2.0 * volatility)
};
stop_candidates.push(("Volatility", vol_stop));
// 2. ATR-based stop
if let Some(atr_value) = atr {
let atr_stop = if is_long {
entry_price - 2.0 * atr_value
} else {
entry_price + 2.0 * atr_value
};
stop_candidates.push(("ATR", atr_stop));
}
// 3. Support/Resistance based stop
if !support_levels.is_empty() {
let technical_stop = if is_long {
// Find nearest support below entry
support_levels.iter()
.filter(|&&level| level < entry_price)
.max_by(|a, b| a.partial_cmp(b).unwrap())
.map(|&level| level * 0.995) // Just below support
} else {
// Find nearest resistance above entry
support_levels.iter()
.filter(|&&level| level > entry_price)
.min_by(|a, b| a.partial_cmp(b).unwrap())
.map(|&level| level * 1.005) // Just above resistance
};
if let Some(stop) = technical_stop {
stop_candidates.push(("Technical", stop));
}
}
// 4. Maximum loss stop (e.g., 5% from entry)
let max_loss_stop = if is_long {
entry_price * 0.95
} else {
entry_price * 1.05
};
stop_candidates.push(("MaxLoss", max_loss_stop));
// Choose the most conservative stop (closest to entry)
let optimal_stop = if is_long {
stop_candidates.iter()
.map(|(_, stop)| *stop)
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(vol_stop)
} else {
stop_candidates.iter()
.map(|(_, stop)| *stop)
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(vol_stop)
};
optimal_stop
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
pub total_trades: u32,
pub win_rate: f64,
pub avg_win: f64,
pub avg_loss: f64,
pub sharpe_ratio: f64,
pub max_consecutive_losses: u32,
}
/// Dynamic position sizing based on market conditions
pub struct DynamicSizer {
base_sizer: BetSizer,
regime_detector: MarketRegimeDetector,
performance_tracker: PerformanceTracker,
}
impl DynamicSizer {
pub fn new(capital: f64, base_risk: f64) -> Self {
Self {
base_sizer: BetSizer::new(capital, base_risk),
regime_detector: MarketRegimeDetector::new(),
performance_tracker: PerformanceTracker::new(),
}
}
pub fn update_market_data(&mut self, volatility: f64, volume: f64, trend_strength: f64) {
self.regime_detector.update(volatility, volume, trend_strength);
}
pub fn record_trade(&mut self, pnl: f64, holding_period: i64) {
self.performance_tracker.record_trade(pnl, holding_period);
}
pub fn get_sizing_parameters(&self) -> BetSizingParameters {
BetSizingParameters {
signal_strength: 0.0, // To be filled by strategy
signal_confidence: 0.0, // To be filled by strategy
market_regime: self.regime_detector.current_regime(),
volatility: self.regime_detector.current_volatility(),
liquidity_score: 0.0, // To be filled from orderbook
correlation_exposure: 0.0, // To be filled from portfolio
current_drawdown: self.performance_tracker.current_drawdown(),
}
}
}
struct MarketRegimeDetector {
volatility_history: Vec<f64>,
volume_history: Vec<f64>,
trend_history: Vec<f64>,
window_size: usize,
}
impl MarketRegimeDetector {
fn new() -> Self {
Self {
volatility_history: Vec::new(),
volume_history: Vec::new(),
trend_history: Vec::new(),
window_size: 20,
}
}
fn update(&mut self, volatility: f64, volume: f64, trend_strength: f64) {
self.volatility_history.push(volatility);
self.volume_history.push(volume);
self.trend_history.push(trend_strength);
// Keep only recent history
if self.volatility_history.len() > self.window_size {
self.volatility_history.remove(0);
self.volume_history.remove(0);
self.trend_history.remove(0);
}
}
fn current_regime(&self) -> MarketRegime {
if self.volatility_history.len() < 5 {
return MarketRegime::Transitioning;
}
let avg_vol = self.volatility_history.iter().sum::<f64>() / self.volatility_history.len() as f64;
let avg_trend = self.trend_history.iter().sum::<f64>() / self.trend_history.len() as f64;
if avg_vol > 0.02 {
MarketRegime::HighVolatility
} else if avg_vol < 0.01 {
MarketRegime::LowVolatility
} else if avg_trend.abs() > 0.7 {
MarketRegime::Trending
} else if avg_trend.abs() < 0.3 {
MarketRegime::RangeBound
} else {
MarketRegime::Transitioning
}
}
fn current_volatility(&self) -> f64 {
if self.volatility_history.is_empty() {
return 0.015; // Default 1.5% daily vol
}
self.volatility_history.iter().sum::<f64>() / self.volatility_history.len() as f64
}
}
struct PerformanceTracker {
trades: Vec<(f64, i64)>, // (pnl, holding_period)
peak_capital: f64,
current_capital: f64,
}
impl PerformanceTracker {
fn new() -> Self {
Self {
trades: Vec::new(),
peak_capital: 100000.0,
current_capital: 100000.0,
}
}
fn record_trade(&mut self, pnl: f64, holding_period: i64) {
self.trades.push((pnl, holding_period));
self.current_capital += pnl;
if self.current_capital > self.peak_capital {
self.peak_capital = self.current_capital;
}
}
fn current_drawdown(&self) -> f64 {
if self.peak_capital > 0.0 {
(self.peak_capital - self.current_capital) / self.peak_capital
} else {
0.0
}
}
}

View file

@ -0,0 +1,195 @@
pub mod portfolio;
pub mod bet_sizing;
use crate::{Order, Side};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
pub use portfolio::{PortfolioRisk, RiskModel, CorrelationMatrix, ConcentrationMetrics};
pub use bet_sizing::{BetSizer, BetSizingParameters, PositionSize, MarketRegime, DynamicSizer};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskLimits {
pub max_position_size: f64,
pub max_order_size: f64,
pub max_daily_loss: f64,
pub max_gross_exposure: f64,
pub max_symbol_exposure: f64,
}
impl Default for RiskLimits {
fn default() -> Self {
Self {
max_position_size: 100_000.0,
max_order_size: 10_000.0,
max_daily_loss: 5_000.0,
max_gross_exposure: 1_000_000.0,
max_symbol_exposure: 50_000.0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskCheckResult {
pub passed: bool,
pub violations: Vec<String>,
pub checks: RiskChecks,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskChecks {
pub order_size: bool,
pub position_size: bool,
pub daily_loss: bool,
pub gross_exposure: bool,
pub symbol_exposure: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskMetrics {
pub current_exposure: f64,
pub daily_pnl: f64,
pub position_count: usize,
pub gross_exposure: f64,
pub max_position_size: f64,
pub utilization_pct: f64,
}
pub struct RiskEngine {
limits: Arc<RwLock<RiskLimits>>,
symbol_exposures: DashMap<String, f64>,
daily_pnl: Arc<RwLock<f64>>,
}
impl RiskEngine {
pub fn new() -> Self {
Self::with_limits(RiskLimits::default())
}
pub fn with_limits(limits: RiskLimits) -> Self {
Self {
limits: Arc::new(RwLock::new(limits)),
symbol_exposures: DashMap::new(),
daily_pnl: Arc::new(RwLock::new(0.0)),
}
}
pub fn update_limits(&self, new_limits: RiskLimits) {
*self.limits.write() = new_limits;
}
pub fn check_order(&self, order: &Order, current_position: Option<f64>) -> RiskCheckResult {
let mut violations = Vec::new();
let limits = self.limits.read();
// Check order size
if order.quantity > limits.max_order_size {
violations.push(format!(
"Order size {} exceeds limit {}",
order.quantity, limits.max_order_size
));
}
// Check position size after order
let current_pos = current_position.unwrap_or(0.0);
let new_position = match order.side {
Side::Buy => current_pos + order.quantity,
Side::Sell => current_pos - order.quantity,
};
if new_position.abs() > limits.max_position_size {
violations.push(format!(
"Position size {} would exceed limit {}",
new_position.abs(), limits.max_position_size
));
}
// Check symbol exposure
let symbol_exposure = self.symbol_exposures
.get(&order.symbol)
.map(|e| *e)
.unwrap_or(0.0);
let new_exposure = symbol_exposure + order.quantity;
if new_exposure > limits.max_symbol_exposure {
violations.push(format!(
"Symbol exposure {} would exceed limit {}",
new_exposure, limits.max_symbol_exposure
));
}
// Check daily loss
let daily_pnl = *self.daily_pnl.read();
if daily_pnl < -limits.max_daily_loss {
violations.push(format!(
"Daily loss {} exceeds limit {}",
-daily_pnl, limits.max_daily_loss
));
}
// Calculate gross exposure
let gross_exposure = self.calculate_gross_exposure();
if gross_exposure > limits.max_gross_exposure {
violations.push(format!(
"Gross exposure {} exceeds limit {}",
gross_exposure, limits.max_gross_exposure
));
}
RiskCheckResult {
passed: violations.is_empty(),
violations,
checks: RiskChecks {
order_size: order.quantity <= limits.max_order_size,
position_size: new_position.abs() <= limits.max_position_size,
daily_loss: daily_pnl >= -limits.max_daily_loss,
gross_exposure: gross_exposure <= limits.max_gross_exposure,
symbol_exposure: new_exposure <= limits.max_symbol_exposure,
},
}
}
pub fn update_position(&self, symbol: &str, new_position: f64) {
if new_position.abs() < 0.0001 {
self.symbol_exposures.remove(symbol);
} else {
self.symbol_exposures.insert(symbol.to_string(), new_position.abs());
}
}
pub fn update_daily_pnl(&self, pnl_change: f64) {
let mut daily_pnl = self.daily_pnl.write();
*daily_pnl += pnl_change;
}
pub fn reset_daily_metrics(&self) {
*self.daily_pnl.write() = 0.0;
}
fn calculate_gross_exposure(&self) -> f64 {
self.symbol_exposures
.iter()
.map(|entry| *entry.value())
.sum()
}
fn calculate_total_exposure(&self) -> f64 {
self.calculate_gross_exposure()
}
pub fn get_risk_metrics(&self) -> RiskMetrics {
let limits = self.limits.read();
let gross_exposure = self.calculate_gross_exposure();
RiskMetrics {
current_exposure: 0.0,
daily_pnl: *self.daily_pnl.read(),
position_count: self.symbol_exposures.len(),
gross_exposure,
max_position_size: limits.max_position_size,
utilization_pct: (gross_exposure / limits.max_gross_exposure * 100.0).min(100.0),
}
}
}

View file

@ -0,0 +1,533 @@
use std::collections::HashMap;
use nalgebra::{DMatrix, DVector};
use crate::positions::Position;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortfolioRisk {
pub total_var_95: f64,
pub total_var_99: f64,
pub total_cvar_95: f64,
pub marginal_var: HashMap<String, f64>,
pub component_var: HashMap<String, f64>,
pub correlation_matrix: CorrelationMatrix,
pub concentration_risk: ConcentrationMetrics,
pub stress_test_results: HashMap<String, f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationMatrix {
pub symbols: Vec<String>,
pub matrix: Vec<Vec<f64>>,
pub average_correlation: f64,
pub max_correlation: (String, String, f64),
pub clustering_score: f64, // How clustered the portfolio is
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConcentrationMetrics {
pub herfindahl_index: f64,
pub effective_number_of_positions: f64,
pub top_5_concentration: f64,
pub sector_concentration: HashMap<String, f64>,
}
#[derive(Debug, Clone)]
pub struct RiskModel {
returns_history: HashMap<String, Vec<f64>>,
lookback_period: usize,
confidence_levels: Vec<f64>,
}
impl RiskModel {
pub fn new(lookback_period: usize) -> Self {
Self {
returns_history: HashMap::new(),
lookback_period,
confidence_levels: vec![0.95, 0.99],
}
}
pub fn update_returns(&mut self, symbol: &str, returns: Vec<f64>) {
self.returns_history.insert(symbol.to_string(), returns);
}
pub fn calculate_portfolio_risk(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
) -> Result<PortfolioRisk, String> {
// Get active symbols and their weights
let (symbols, weights) = self.get_portfolio_weights(positions, current_prices)?;
// Calculate covariance matrix
let cov_matrix = self.calculate_covariance_matrix(&symbols)?;
// Portfolio variance
let portfolio_variance = self.calculate_portfolio_variance(&weights, &cov_matrix);
let portfolio_vol = portfolio_variance.sqrt();
// VaR calculations
let total_portfolio_value = self.calculate_total_value(positions, current_prices);
let total_var_95 = total_portfolio_value * portfolio_vol * 1.645; // 95% confidence
let total_var_99 = total_portfolio_value * portfolio_vol * 2.326; // 99% confidence
// CVaR (Conditional VaR)
let total_cvar_95 = self.calculate_cvar(&symbols, &weights, 0.95)?;
// Marginal and Component VaR
let (marginal_var, component_var) = self.calculate_var_decomposition(
&symbols,
&weights,
&cov_matrix,
portfolio_vol,
total_portfolio_value,
)?;
// Correlation analysis
let correlation_matrix = self.calculate_correlation_matrix(&symbols)?;
// Concentration metrics
let concentration_risk = self.calculate_concentration_metrics(positions, current_prices);
// Stress tests
let stress_test_results = self.run_stress_tests(positions, current_prices, &cov_matrix)?;
Ok(PortfolioRisk {
total_var_95,
total_var_99,
total_cvar_95,
marginal_var,
component_var,
correlation_matrix,
concentration_risk,
stress_test_results,
})
}
fn get_portfolio_weights(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
) -> Result<(Vec<String>, DVector<f64>), String> {
let mut symbols = Vec::new();
let mut values = Vec::new();
let mut total_value = 0.0;
for (symbol, position) in positions {
if let Some(&price) = current_prices.get(symbol) {
let position_value = position.quantity * price;
symbols.push(symbol.clone());
values.push(position_value.abs());
total_value += position_value.abs();
}
}
if total_value == 0.0 {
return Err("Portfolio has zero value".to_string());
}
let weights = DVector::from_vec(values.iter().map(|v| v / total_value).collect());
Ok((symbols, weights))
}
fn calculate_covariance_matrix(&self, symbols: &[String]) -> Result<DMatrix<f64>, String> {
let n = symbols.len();
let mut matrix = DMatrix::zeros(n, n);
for (i, symbol_i) in symbols.iter().enumerate() {
for (j, symbol_j) in symbols.iter().enumerate() {
if i <= j {
let cov = self.calculate_covariance(symbol_i, symbol_j)?;
matrix[(i, j)] = cov;
matrix[(j, i)] = cov; // Symmetric
}
}
}
Ok(matrix)
}
fn calculate_covariance(&self, symbol1: &str, symbol2: &str) -> Result<f64, String> {
let returns1 = self.returns_history.get(symbol1)
.ok_or_else(|| format!("No returns data for {}", symbol1))?;
let returns2 = self.returns_history.get(symbol2)
.ok_or_else(|| format!("No returns data for {}", symbol2))?;
if returns1.len() != returns2.len() {
return Err("Mismatched returns length".to_string());
}
let n = returns1.len() as f64;
let mean1 = returns1.iter().sum::<f64>() / n;
let mean2 = returns2.iter().sum::<f64>() / n;
let covariance = returns1.iter()
.zip(returns2.iter())
.map(|(r1, r2)| (r1 - mean1) * (r2 - mean2))
.sum::<f64>() / (n - 1.0);
Ok(covariance)
}
fn calculate_portfolio_variance(&self, weights: &DVector<f64>, cov_matrix: &DMatrix<f64>) -> f64 {
let variance = weights.transpose() * cov_matrix * weights;
variance[(0, 0)]
}
fn calculate_var_decomposition(
&self,
symbols: &[String],
weights: &DVector<f64>,
cov_matrix: &DMatrix<f64>,
portfolio_vol: f64,
total_value: f64,
) -> Result<(HashMap<String, f64>, HashMap<String, f64>), String> {
let mut marginal_var = HashMap::new();
let mut component_var = HashMap::new();
// Marginal VaR = ∂VaR/∂w_i
let cov_weights = cov_matrix * weights;
for (i, symbol) in symbols.iter().enumerate() {
let marginal = (cov_weights[i] / portfolio_vol) * 1.645 * total_value;
let component = marginal * weights[i];
marginal_var.insert(symbol.clone(), marginal);
component_var.insert(symbol.clone(), component);
}
Ok((marginal_var, component_var))
}
fn calculate_cvar(&self, symbols: &[String], weights: &DVector<f64>, confidence: f64) -> Result<f64, String> {
// Simulate portfolio returns
let portfolio_returns = self.simulate_portfolio_returns(symbols, weights, 10000)?;
// Sort returns
let mut sorted_returns = portfolio_returns.clone();
sorted_returns.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Find VaR threshold
let var_index = ((1.0 - confidence) * sorted_returns.len() as f64) as usize;
let var_threshold = sorted_returns[var_index];
// Calculate expected loss beyond VaR
let tail_losses: Vec<f64> = sorted_returns.iter()
.take(var_index)
.cloned()
.collect();
if tail_losses.is_empty() {
return Ok(0.0);
}
let cvar = -tail_losses.iter().sum::<f64>() / tail_losses.len() as f64;
Ok(cvar)
}
fn simulate_portfolio_returns(
&self,
symbols: &[String],
weights: &DVector<f64>,
num_simulations: usize,
) -> Result<Vec<f64>, String> {
let mut portfolio_returns = Vec::with_capacity(num_simulations);
// Use historical simulation
let min_length = symbols.iter()
.map(|s| self.returns_history.get(s).map(|r| r.len()).unwrap_or(0))
.min()
.unwrap_or(0);
if min_length == 0 {
return Err("No returns data available".to_string());
}
// Bootstrap from historical returns
use rand::prelude::*;
let mut rng = thread_rng();
for _ in 0..num_simulations {
let idx = rng.gen_range(0..min_length);
let mut portfolio_return = 0.0;
for (i, symbol) in symbols.iter().enumerate() {
if let Some(returns) = self.returns_history.get(symbol) {
portfolio_return += weights[i] * returns[idx];
}
}
portfolio_returns.push(portfolio_return);
}
Ok(portfolio_returns)
}
fn calculate_correlation_matrix(&self, symbols: &[String]) -> Result<CorrelationMatrix, String> {
let n = symbols.len();
let mut matrix = vec![vec![0.0; n]; n];
let mut sum_correlation = 0.0;
let mut count = 0;
let mut max_correlation = ("".to_string(), "".to_string(), 0.0);
for (i, symbol_i) in symbols.iter().enumerate() {
for (j, symbol_j) in symbols.iter().enumerate() {
if i <= j {
let corr = if i == j {
1.0
} else {
self.calculate_correlation(symbol_i, symbol_j)?
};
matrix[i][j] = corr;
matrix[j][i] = corr;
if i != j {
sum_correlation += corr.abs();
count += 1;
if corr.abs() > max_correlation.2 {
max_correlation = (symbol_i.clone(), symbol_j.clone(), corr);
}
}
}
}
}
let average_correlation = if count > 0 {
sum_correlation / count as f64
} else {
0.0
};
// Calculate clustering score (higher = more clustered)
let clustering_score = self.calculate_clustering_score(&matrix);
Ok(CorrelationMatrix {
symbols: symbols.to_vec(),
matrix,
average_correlation,
max_correlation,
clustering_score,
})
}
fn calculate_correlation(&self, symbol1: &str, symbol2: &str) -> Result<f64, String> {
let returns1 = self.returns_history.get(symbol1)
.ok_or_else(|| format!("No returns data for {}", symbol1))?;
let returns2 = self.returns_history.get(symbol2)
.ok_or_else(|| format!("No returns data for {}", symbol2))?;
let cov = self.calculate_covariance(symbol1, symbol2)?;
let std1 = self.calculate_std_dev(returns1);
let std2 = self.calculate_std_dev(returns2);
if std1 == 0.0 || std2 == 0.0 {
return Ok(0.0);
}
Ok(cov / (std1 * std2))
}
fn calculate_std_dev(&self, returns: &[f64]) -> f64 {
let n = returns.len() as f64;
let mean = returns.iter().sum::<f64>() / n;
let variance = returns.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>() / (n - 1.0);
variance.sqrt()
}
fn calculate_clustering_score(&self, correlation_matrix: &[Vec<f64>]) -> f64 {
// Use average linkage clustering metric
let n = correlation_matrix.len();
if n < 2 {
return 0.0;
}
let mut cluster_sum = 0.0;
let mut cluster_count = 0;
// Look for groups of highly correlated assets
for i in 0..n {
for j in i+1..n {
for k in j+1..n {
let corr_ij = correlation_matrix[i][j].abs();
let corr_ik = correlation_matrix[i][k].abs();
let corr_jk = correlation_matrix[j][k].abs();
// If all three are highly correlated, they form a cluster
let min_corr = corr_ij.min(corr_ik).min(corr_jk);
if min_corr > 0.5 {
cluster_sum += min_corr;
cluster_count += 1;
}
}
}
}
if cluster_count > 0 {
cluster_sum / cluster_count as f64
} else {
0.0
}
}
fn calculate_concentration_metrics(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
) -> ConcentrationMetrics {
let mut position_values: Vec<(String, f64)> = Vec::new();
let mut total_value = 0.0;
for (symbol, position) in positions {
if let Some(&price) = current_prices.get(symbol) {
let value = (position.quantity * price).abs();
position_values.push((symbol.clone(), value));
total_value += value;
}
}
// Sort by value descending
position_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
// Herfindahl Index
let herfindahl_index = position_values.iter()
.map(|(_, value)| {
let weight = value / total_value;
weight * weight
})
.sum();
// Effective number of positions
let effective_number_of_positions = if herfindahl_index > 0.0 {
1.0 / herfindahl_index
} else {
0.0
};
// Top 5 concentration
let top_5_value: f64 = position_values.iter()
.take(5)
.map(|(_, value)| value)
.sum();
let top_5_concentration = top_5_value / total_value;
// Sector concentration (simplified - would need sector mapping)
let sector_concentration = HashMap::new(); // TODO: Implement with sector data
ConcentrationMetrics {
herfindahl_index,
effective_number_of_positions,
top_5_concentration,
sector_concentration,
}
}
fn calculate_total_value(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
) -> f64 {
positions.iter()
.filter_map(|(symbol, position)| {
current_prices.get(symbol).map(|&price| (position.quantity * price).abs())
})
.sum()
}
fn run_stress_tests(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
cov_matrix: &DMatrix<f64>,
) -> Result<HashMap<String, f64>, String> {
let mut results = HashMap::new();
// Market crash scenario
let market_crash_loss = self.calculate_scenario_loss(
positions,
current_prices,
-0.20, // 20% market drop
);
results.insert("market_crash_20pct".to_string(), market_crash_loss);
// Flight to quality
let flight_to_quality = self.calculate_correlation_stress(
positions,
current_prices,
cov_matrix,
0.9, // High correlation scenario
)?;
results.insert("flight_to_quality".to_string(), flight_to_quality);
// Volatility spike
let vol_spike = self.calculate_volatility_stress(
positions,
current_prices,
cov_matrix,
2.0, // Double volatility
)?;
results.insert("volatility_spike_2x".to_string(), vol_spike);
Ok(results)
}
fn calculate_scenario_loss(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
shock: f64,
) -> f64 {
let current_value = self.calculate_total_value(positions, current_prices);
current_value * shock.abs()
}
fn calculate_correlation_stress(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
cov_matrix: &DMatrix<f64>,
target_correlation: f64,
) -> Result<f64, String> {
// Adjust correlation matrix to stress scenario
let n = cov_matrix.nrows();
let mut stressed_cov = cov_matrix.clone();
for i in 0..n {
for j in 0..n {
if i != j {
let current_corr = cov_matrix[(i, j)] / (cov_matrix[(i, i)].sqrt() * cov_matrix[(j, j)].sqrt());
let stress_factor = target_correlation / current_corr.abs().max(0.1);
stressed_cov[(i, j)] *= stress_factor;
}
}
}
let (_, weights) = self.get_portfolio_weights(positions, current_prices)?;
let stressed_variance = self.calculate_portfolio_variance(&weights, &stressed_cov);
let stressed_vol = stressed_variance.sqrt();
Ok(self.calculate_total_value(positions, current_prices) * stressed_vol * 2.326) // 99% VaR
}
fn calculate_volatility_stress(
&self,
positions: &HashMap<String, Position>,
current_prices: &HashMap<String, f64>,
cov_matrix: &DMatrix<f64>,
vol_multiplier: f64,
) -> Result<f64, String> {
let stressed_cov = cov_matrix * vol_multiplier.powi(2);
let (_, weights) = self.get_portfolio_weights(positions, current_prices)?;
let stressed_variance = self.calculate_portfolio_variance(&weights, &stressed_cov);
let stressed_vol = stressed_variance.sqrt();
Ok(self.calculate_total_value(positions, current_prices) * stressed_vol * 2.326) // 99% VaR
}
}

View file

@ -0,0 +1,114 @@
use async_trait::async_trait;
use serde_json::Value;
use crate::domain::{MarketUpdate, Fill, Order, Side};
use crate::indicators::IndicatorSet;
use std::collections::HashMap;
/// Context provided to strategies
pub struct StrategyContext {
pub portfolio_value: f64,
pub cash: f64,
pub positions: HashMap<String, f64>,
pub pending_orders: Vec<Order>,
pub indicators: IndicatorSet,
}
/// Signal generated by a strategy
#[derive(Debug, Clone)]
pub struct Signal {
pub symbol: String,
pub action: SignalAction,
pub quantity: Option<f64>,
pub confidence: f64, // 0.0 to 1.0
pub reason: String,
pub metadata: Option<Value>,
}
#[derive(Debug, Clone)]
pub enum SignalAction {
Buy,
Sell,
Close,
Hold,
}
/// Core strategy trait with lifecycle methods
#[async_trait]
pub trait Strategy: Send + Sync {
/// Called once when strategy is initialized
async fn init(&mut self, context: &StrategyContext) -> Result<(), String>;
/// Called when trading starts
async fn on_start(&mut self) -> Result<(), String>;
/// Called for each market data update
async fn on_data(&mut self, data: &MarketUpdate, context: &StrategyContext) -> Vec<Signal>;
/// Called when an order is filled
async fn on_fill(&mut self, order_id: &str, fill: &Fill, context: &StrategyContext);
/// Called periodically (e.g., every minute)
async fn on_timer(&mut self, context: &StrategyContext) -> Vec<Signal> {
Vec::new() // Default: no signals on timer
}
/// Called when trading stops
async fn on_stop(&mut self) -> Result<(), String>;
/// Get strategy name
fn name(&self) -> &str;
/// Get strategy parameters
fn parameters(&self) -> Value;
/// Update strategy parameters
fn update_parameters(&mut self, params: Value) -> Result<(), String>;
}
/// Base implementation that strategies can extend
pub struct BaseStrategy {
name: String,
parameters: Value,
}
impl BaseStrategy {
pub fn new(name: String, parameters: Value) -> Self {
Self { name, parameters }
}
}
#[async_trait]
impl Strategy for BaseStrategy {
async fn init(&mut self, _context: &StrategyContext) -> Result<(), String> {
Ok(())
}
async fn on_start(&mut self) -> Result<(), String> {
Ok(())
}
async fn on_data(&mut self, _data: &MarketUpdate, _context: &StrategyContext) -> Vec<Signal> {
Vec::new()
}
async fn on_fill(&mut self, _order_id: &str, _fill: &Fill, _context: &StrategyContext) {
// Default: no action
}
async fn on_stop(&mut self) -> Result<(), String> {
Ok(())
}
fn name(&self) -> &str {
&self.name
}
fn parameters(&self) -> Value {
self.parameters.clone()
}
fn update_parameters(&mut self, params: Value) -> Result<(), String> {
self.parameters = params;
Ok(())
}
}

View file

@ -0,0 +1,210 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Mean Reversion Strategy
///
/// This strategy identifies when a security's price deviates significantly from its
/// moving average and trades on the assumption that it will revert back to the mean.
///
/// Entry Signals:
/// - BUY when price falls below (MA - threshold * std_dev)
/// - SELL when price rises above (MA + threshold * std_dev)
///
/// Exit Signals:
/// - Exit long when price reaches MA
/// - Exit short when price reaches MA
pub struct MeanReversionStrategy {
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
exit_threshold: f64, // Typically 0 (at the mean)
position_size: f64,
// State
price_history: HashMap<String, Vec<f64>>,
positions: HashMap<String, f64>,
}
impl MeanReversionStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
entry_threshold,
exit_threshold: 0.0, // Exit at the mean
position_size,
price_history: HashMap::new(),
positions: HashMap::new(),
}
}
fn calculate_mean(prices: &[f64]) -> f64 {
prices.iter().sum::<f64>() / prices.len() as f64
}
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
let variance = prices.iter()
.map(|p| (p - mean).powi(2))
.sum::<f64>() / prices.len() as f64;
variance.sqrt()
}
}
impl Strategy for MeanReversionStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep only necessary history
if history.len() > self.lookback_period {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(history);
let std_dev = Self::calculate_std_dev(history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
// Check for entry signals
if current_position == 0.0 {
if price < lower_band {
// Price is oversold, buy
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, lower_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
self.positions.insert(symbol.clone(), self.position_size);
} else if price > upper_band {
// Price is overbought, sell short
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, upper_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
self.positions.insert(symbol.clone(), -self.position_size);
}
}
// Check for exit signals
else if current_position > 0.0 {
// We're long, exit when price crosses above mean (not just touches)
// or when we hit stop loss
let stop_loss = lower_band - std_dev; // Stop loss below lower band
if price >= mean * 1.01 || price <= stop_loss {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Mean reversion exit long: price ${:.2} {} mean ${:.2}",
price, if price >= mean * 1.01 { "crossed above" } else { "hit stop loss below" }, mean
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"exit_type": "mean_reversion",
})),
});
self.positions.remove(symbol);
}
} else if current_position < 0.0 {
// We're short, exit when price crosses below mean
// or when we hit stop loss
let stop_loss = upper_band + std_dev; // Stop loss above upper band
if price <= mean * 0.99 || price >= stop_loss {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Mean reversion exit short: price ${:.2} {} mean ${:.2}",
price, if price <= mean * 0.99 { "crossed below" } else { "hit stop loss above" }, mean
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"exit_type": "mean_reversion",
})),
});
self.positions.remove(symbol);
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
// Position tracking is handled in on_market_data for simplicity
eprintln!("Mean reversion fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"exit_threshold": self.exit_threshold,
"position_size": self.position_size,
})
}
}

View file

@ -0,0 +1,268 @@
use std::collections::HashMap;
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Fixed Mean Reversion Strategy that properly tracks positions
///
/// This version doesn't maintain its own position tracking but relies
/// on the position information passed through on_fill callbacks
pub struct MeanReversionFixedStrategy {
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
exit_threshold: f64, // Exit when price moves back this fraction toward mean
position_size: f64,
// State
price_history: HashMap<String, Vec<f64>>,
current_positions: HashMap<String, f64>, // Track actual positions from fills
entry_prices: HashMap<String, f64>, // Track entry prices for exit decisions
}
impl MeanReversionFixedStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
entry_threshold,
exit_threshold: 0.3, // Exit when price moves 30% back to mean
position_size,
price_history: HashMap::new(),
current_positions: HashMap::new(),
entry_prices: HashMap::new(),
}
}
fn calculate_mean(prices: &[f64]) -> f64 {
prices.iter().sum::<f64>() / prices.len() as f64
}
fn calculate_std_dev(prices: &[f64], mean: f64) -> f64 {
let variance = prices.iter()
.map(|p| (p - mean).powi(2))
.sum::<f64>() / prices.len() as f64;
variance.sqrt()
}
}
impl Strategy for MeanReversionFixedStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep only necessary history
if history.len() > self.lookback_period {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(history);
let std_dev = Self::calculate_std_dev(history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
// Get actual position from our tracking
let current_position = self.current_positions.get(symbol).copied().unwrap_or(0.0);
// Entry signals - allow pyramiding up to 3x base position
let max_long_position = self.position_size * 3.0;
let max_short_position = -self.position_size * 3.0;
if price < lower_band && current_position < max_long_position {
// Price is oversold, buy (or add to long position)
let remaining_capacity = max_long_position - current_position;
let trade_size = self.position_size.min(remaining_capacity);
if trade_size > 0.0 {
eprintln!("Mean reversion: {} oversold at ${:.2}, buying {} shares (current: {}, lower band: ${:.2}, mean: ${:.2})",
symbol, price, trade_size, current_position, lower_band, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(trade_size),
reason: Some(format!(
"Mean reversion buy: price ${:.2} < lower band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, lower_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
}
} else if price > upper_band && current_position > max_short_position {
// Price is overbought, sell short (or add to short position)
let remaining_capacity = current_position - max_short_position;
let trade_size = self.position_size.min(remaining_capacity);
if trade_size > 0.0 {
eprintln!("Mean reversion: {} overbought at ${:.2}, selling {} shares short (current: {}, upper band: ${:.2}, mean: ${:.2})",
symbol, price, trade_size, current_position, upper_band, mean);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(trade_size),
reason: Some(format!(
"Mean reversion sell: price ${:.2} > upper band ${:.2} (mean: ${:.2}, std: ${:.2})",
price, upper_band, mean, std_dev
)),
metadata: Some(json!({
"mean": mean,
"std_dev": std_dev,
"upper_band": upper_band,
"lower_band": lower_band,
"price": price,
})),
});
}
}
// Exit signals based on current position
if current_position > 0.0 {
// We're long - check exit conditions
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
let target_price = entry_price + (mean - entry_price) * self.exit_threshold;
let stop_loss = lower_band - std_dev; // Stop loss below lower band
// Exit if price reaches target or stop loss
if price >= target_price || price <= stop_loss {
let exit_reason = if price >= target_price { "target" } else { "stop_loss" };
eprintln!("Mean reversion: {} exit long at ${:.2} ({}), closing {} shares",
symbol, price, exit_reason, current_position);
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Mean reversion exit long: price ${:.2} reached target ${:.2} (entry: ${:.2})",
price, target_price, entry_price
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"entry_price": entry_price,
"target_price": target_price,
"stop_loss": stop_loss,
"exit_type": exit_reason,
})),
});
}
} else if current_position < 0.0 {
// We're short - check exit conditions
let entry_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
let target_price = entry_price - (entry_price - mean) * self.exit_threshold;
let stop_loss = upper_band + std_dev; // Stop loss above upper band
// Exit if price reaches target or stop loss
if price <= target_price || price >= stop_loss {
let exit_reason = if price <= target_price { "target" } else { "stop_loss" };
eprintln!("Mean reversion: {} exit short at ${:.2} ({}), covering {} shares",
symbol, price, exit_reason, current_position.abs());
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Mean reversion exit short: price ${:.2} {} (entry: ${:.2}, target: ${:.2}, stop: ${:.2})",
price, exit_reason, entry_price, target_price, stop_loss
)),
metadata: Some(json!({
"mean": mean,
"price": price,
"entry_price": entry_price,
"target_price": target_price,
"stop_loss": stop_loss,
"exit_type": exit_reason,
})),
});
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
// Update our position tracking based on actual fills
let current = self.current_positions.get(symbol).copied().unwrap_or(0.0);
let new_position = if side == "buy" {
current + quantity
} else {
current - quantity
};
eprintln!("Mean reversion fill: {} {} @ {} - {}, position: {} -> {}",
quantity, symbol, price, side, current, new_position);
// Update position
if new_position.abs() < 0.001 {
// Position closed
self.current_positions.remove(symbol);
self.entry_prices.remove(symbol);
eprintln!("Position closed for {}", symbol);
} else {
self.current_positions.insert(symbol.to_string(), new_position);
// Track average entry price
if current.abs() < 0.001 {
// New position - set initial entry price
self.entry_prices.insert(symbol.to_string(), price);
} else if (current > 0.0 && new_position > current) || (current < 0.0 && new_position < current) {
// Adding to existing position - update average entry price
let old_price = self.entry_prices.get(symbol).copied().unwrap_or(price);
let avg_price = (old_price * current.abs() + price * quantity) / new_position.abs();
self.entry_prices.insert(symbol.to_string(), avg_price);
eprintln!("Updated avg entry price for {}: ${:.2}", symbol, avg_price);
}
}
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"exit_threshold": self.exit_threshold,
"position_size": self.position_size,
})
}
}

View file

@ -0,0 +1,9 @@
pub mod mean_reversion;
pub mod mean_reversion_fixed;
pub mod momentum;
pub mod pairs_trading;
pub use mean_reversion::MeanReversionStrategy;
pub use mean_reversion_fixed::MeanReversionFixedStrategy;
pub use momentum::MomentumStrategy;
pub use pairs_trading::PairsTradingStrategy;

View file

@ -0,0 +1,228 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Momentum Strategy
///
/// This strategy trades based on momentum indicators like rate of change (ROC) and
/// relative strength. It aims to capture trends by buying securities showing
/// upward momentum and selling those showing downward momentum.
///
/// Entry Signals:
/// - BUY when momentum crosses above threshold and accelerating
/// - SELL when momentum crosses below -threshold and decelerating
///
/// Exit Signals:
/// - Exit long when momentum turns negative
/// - Exit short when momentum turns positive
pub struct MomentumStrategy {
name: String,
id: String,
lookback_period: usize,
momentum_threshold: f64,
position_size: f64,
use_acceleration: bool,
// State
price_history: HashMap<String, Vec<f64>>,
momentum_history: HashMap<String, Vec<f64>>,
positions: HashMap<String, f64>,
}
impl MomentumStrategy {
pub fn new(
name: String,
id: String,
lookback_period: usize,
momentum_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
lookback_period,
momentum_threshold,
position_size,
use_acceleration: true,
price_history: HashMap::new(),
momentum_history: HashMap::new(),
positions: HashMap::new(),
}
}
fn calculate_momentum(prices: &[f64], lookback_period: usize) -> f64 {
if prices.len() < 2 {
return 0.0;
}
let current = prices.last().unwrap();
let past = prices[prices.len() - lookback_period.min(prices.len())];
((current - past) / past) * 100.0
}
fn calculate_acceleration(momentum_values: &[f64]) -> f64 {
if momentum_values.len() < 2 {
return 0.0;
}
let current = momentum_values.last().unwrap();
let previous = momentum_values[momentum_values.len() - 2];
current - previous
}
}
impl Strategy for MomentumStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update price history
let history = self.price_history.entry(symbol.clone()).or_insert_with(Vec::new);
history.push(price);
// Keep reasonable history
if history.len() > self.lookback_period * 2 {
history.remove(0);
}
// Need enough data
if history.len() >= self.lookback_period {
// Calculate momentum
let momentum = Self::calculate_momentum(history, self.lookback_period);
// Update momentum history
let mom_history = self.momentum_history.entry(symbol.clone()).or_insert_with(Vec::new);
mom_history.push(momentum);
if mom_history.len() > 5 {
mom_history.remove(0);
}
// Calculate acceleration if enabled
let acceleration = if self.use_acceleration && mom_history.len() >= 2 {
Self::calculate_acceleration(mom_history)
} else {
0.0
};
let current_position = self.positions.get(symbol).copied().unwrap_or(0.0);
// Check for entry signals
if current_position == 0.0 {
// Long entry: strong positive momentum and accelerating
if momentum > self.momentum_threshold &&
(!self.use_acceleration || acceleration > 0.0) {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: momentum / 100.0, // Normalize strength
quantity: Some(self.position_size),
reason: Some(format!(
"Momentum buy: momentum {:.2}% > threshold {:.2}%, accel: {:.2}",
momentum, self.momentum_threshold, acceleration
)),
metadata: Some(json!({
"momentum": momentum,
"acceleration": acceleration,
"price": price,
})),
});
self.positions.insert(symbol.clone(), self.position_size);
}
// Short entry: strong negative momentum and decelerating
else if momentum < -self.momentum_threshold &&
(!self.use_acceleration || acceleration < 0.0) {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: momentum.abs() / 100.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Momentum sell: momentum {:.2}% < threshold -{:.2}%, accel: {:.2}",
momentum, self.momentum_threshold, acceleration
)),
metadata: Some(json!({
"momentum": momentum,
"acceleration": acceleration,
"price": price,
})),
});
self.positions.insert(symbol.clone(), -self.position_size);
}
}
// Check for exit signals
else if current_position > 0.0 {
// Exit long when momentum turns negative
if momentum < 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(current_position),
reason: Some(format!(
"Momentum exit long: momentum turned negative {:.2}%",
momentum
)),
metadata: Some(json!({
"momentum": momentum,
"price": price,
"exit_type": "momentum_reversal",
})),
});
self.positions.remove(symbol);
}
} else if current_position < 0.0 {
// Exit short when momentum turns positive
if momentum > 0.0 {
signals.push(Signal {
symbol: symbol.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(current_position.abs()),
reason: Some(format!(
"Momentum exit short: momentum turned positive {:.2}%",
momentum
)),
metadata: Some(json!({
"momentum": momentum,
"price": price,
"exit_type": "momentum_reversal",
})),
});
self.positions.remove(symbol);
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
eprintln!("Momentum fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"lookback_period": self.lookback_period,
"momentum_threshold": self.momentum_threshold,
"position_size": self.position_size,
"use_acceleration": self.use_acceleration,
})
}
}

View file

@ -0,0 +1,295 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde_json::json;
use crate::{
MarketUpdate, MarketDataType,
backtest::{Strategy, Signal, SignalType},
};
/// Pairs Trading Strategy
///
/// This strategy trades the spread between two correlated securities. When the spread
/// deviates from its historical mean, we trade expecting it to revert.
///
/// Entry Signals:
/// - Long pair A, Short pair B when spread < (mean - threshold * std)
/// - Short pair A, Long pair B when spread > (mean + threshold * std)
///
/// Exit Signals:
/// - Exit when spread returns to mean
pub struct PairsTradingStrategy {
name: String,
id: String,
pair_a: String,
pair_b: String,
lookback_period: usize,
entry_threshold: f64, // Number of standard deviations
position_size: f64,
hedge_ratio: f64, // How many shares of B per share of A
// State
price_history_a: Vec<f64>,
price_history_b: Vec<f64>,
spread_history: Vec<f64>,
positions: HashMap<String, f64>,
last_prices: HashMap<String, f64>,
}
impl PairsTradingStrategy {
pub fn new(
name: String,
id: String,
pair_a: String,
pair_b: String,
lookback_period: usize,
entry_threshold: f64,
position_size: f64,
) -> Self {
Self {
name,
id,
pair_a,
pair_b,
lookback_period,
entry_threshold,
position_size,
hedge_ratio: 1.0, // Default 1:1, could be calculated dynamically
price_history_a: Vec::new(),
price_history_b: Vec::new(),
spread_history: Vec::new(),
positions: HashMap::new(),
last_prices: HashMap::new(),
}
}
fn calculate_spread(&self, price_a: f64, price_b: f64) -> f64 {
price_a - self.hedge_ratio * price_b
}
fn calculate_mean(values: &[f64]) -> f64 {
values.iter().sum::<f64>() / values.len() as f64
}
fn calculate_std_dev(values: &[f64], mean: f64) -> f64 {
let variance = values.iter()
.map(|v| (v - mean).powi(2))
.sum::<f64>() / values.len() as f64;
variance.sqrt()
}
}
impl Strategy for PairsTradingStrategy {
fn on_market_data(&mut self, data: &MarketUpdate) -> Vec<Signal> {
let mut signals = Vec::new();
// Only process bar data
if let MarketDataType::Bar(bar) = &data.data {
let symbol = &data.symbol;
let price = bar.close;
// Update last prices
self.last_prices.insert(symbol.clone(), price);
// Update price histories
if symbol == &self.pair_a {
self.price_history_a.push(price);
if self.price_history_a.len() > self.lookback_period {
self.price_history_a.remove(0);
}
} else if symbol == &self.pair_b {
self.price_history_b.push(price);
if self.price_history_b.len() > self.lookback_period {
self.price_history_b.remove(0);
}
}
// Only generate signals when we have prices for both pairs
if let (Some(&price_a), Some(&price_b)) =
(self.last_prices.get(&self.pair_a), self.last_prices.get(&self.pair_b)) {
// Calculate current spread
let spread = self.calculate_spread(price_a, price_b);
// Update spread history
self.spread_history.push(spread);
if self.spread_history.len() > self.lookback_period {
self.spread_history.remove(0);
}
// Need enough data
if self.spread_history.len() >= self.lookback_period {
// Calculate statistics
let mean = Self::calculate_mean(&self.spread_history);
let std_dev = Self::calculate_std_dev(&self.spread_history, mean);
// Calculate bands
let upper_band = mean + self.entry_threshold * std_dev;
let lower_band = mean - self.entry_threshold * std_dev;
let position_a = self.positions.get(&self.pair_a).copied().unwrap_or(0.0);
let position_b = self.positions.get(&self.pair_b).copied().unwrap_or(0.0);
// Check for entry signals
if position_a == 0.0 && position_b == 0.0 {
if spread < lower_band {
// Spread too low: Buy A, Sell B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Pairs trade: spread ${:.2} < lower band ${:.2}",
spread, lower_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "A",
})),
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size * self.hedge_ratio),
reason: Some(format!(
"Pairs trade hedge: spread ${:.2} < lower band ${:.2}",
spread, lower_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "B",
})),
});
self.positions.insert(self.pair_a.clone(), self.position_size);
self.positions.insert(self.pair_b.clone(), -self.position_size * self.hedge_ratio);
} else if spread > upper_band {
// Spread too high: Sell A, Buy B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(self.position_size),
reason: Some(format!(
"Pairs trade: spread ${:.2} > upper band ${:.2}",
spread, upper_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "A",
})),
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(self.position_size * self.hedge_ratio),
reason: Some(format!(
"Pairs trade hedge: spread ${:.2} > upper band ${:.2}",
spread, upper_band
)),
metadata: Some(json!({
"spread": spread,
"mean": mean,
"std_dev": std_dev,
"pair": "B",
})),
});
self.positions.insert(self.pair_a.clone(), -self.position_size);
self.positions.insert(self.pair_b.clone(), self.position_size * self.hedge_ratio);
}
}
// Check for exit signals
else if position_a != 0.0 && position_b != 0.0 {
// Exit when spread returns to mean
let spread_distance = (spread - mean).abs();
let exit_threshold = std_dev * 0.1; // Exit near mean
if spread_distance < exit_threshold {
// Close positions
if position_a > 0.0 {
// We're long A, short B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(position_a),
reason: Some(format!(
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
spread, mean
)),
metadata: None,
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(position_b.abs()),
reason: Some("Pairs trade exit: closing hedge".to_string()),
metadata: None,
});
} else {
// We're short A, long B
signals.push(Signal {
symbol: self.pair_a.clone(),
signal_type: SignalType::Buy,
strength: 1.0,
quantity: Some(position_a.abs()),
reason: Some(format!(
"Pairs trade exit: spread ${:.2} returned to mean ${:.2}",
spread, mean
)),
metadata: None,
});
signals.push(Signal {
symbol: self.pair_b.clone(),
signal_type: SignalType::Sell,
strength: 1.0,
quantity: Some(position_b),
reason: Some("Pairs trade exit: closing hedge".to_string()),
metadata: None,
});
}
self.positions.clear();
}
}
}
}
}
signals
}
fn on_fill(&mut self, symbol: &str, quantity: f64, price: f64, side: &str) {
eprintln!("Pairs trading fill: {} {} @ {} - {}", quantity, symbol, price, side);
}
fn get_name(&self) -> &str {
&self.name
}
fn get_parameters(&self) -> serde_json::Value {
json!({
"pair_a": self.pair_a,
"pair_b": self.pair_b,
"lookback_period": self.lookback_period,
"entry_threshold": self.entry_threshold,
"position_size": self.position_size,
"hedge_ratio": self.hedge_ratio,
})
}
}

View file

@ -0,0 +1,143 @@
use engine::{Order, OrderType, Side, TimeInForce, Quote, Bar, Trade, MarketDataType};
use chrono::Utc;
#[test]
fn test_order_creation() {
let order = Order {
id: "test-order-1".to_string(),
symbol: "AAPL".to_string(),
side: Side::Buy,
quantity: 100.0,
order_type: OrderType::Market,
time_in_force: TimeInForce::Day,
};
assert_eq!(order.id, "test-order-1");
assert_eq!(order.symbol, "AAPL");
assert_eq!(order.side, Side::Buy);
assert_eq!(order.quantity, 100.0);
}
#[test]
fn test_order_types() {
let market = OrderType::Market;
let limit = OrderType::Limit { price: 150.0 };
let stop = OrderType::Stop { stop_price: 145.0 };
let stop_limit = OrderType::StopLimit {
stop_price: 145.0,
limit_price: 144.5,
};
match limit {
OrderType::Limit { price } => assert_eq!(price, 150.0),
_ => panic!("Expected limit order"),
}
match stop {
OrderType::Stop { stop_price } => assert_eq!(stop_price, 145.0),
_ => panic!("Expected stop order"),
}
match stop_limit {
OrderType::StopLimit { stop_price, limit_price } => {
assert_eq!(stop_price, 145.0);
assert_eq!(limit_price, 144.5);
}
_ => panic!("Expected stop limit order"),
}
assert!(matches!(market, OrderType::Market));
}
#[test]
fn test_quote_creation() {
let quote = Quote {
bid: 149.95,
ask: 150.05,
bid_size: 1000.0,
ask_size: 800.0,
};
assert_eq!(quote.bid, 149.95);
assert_eq!(quote.ask, 150.05);
assert_eq!(quote.bid_size, 1000.0);
assert_eq!(quote.ask_size, 800.0);
}
#[test]
fn test_bar_creation() {
let bar = Bar {
open: 150.0,
high: 152.0,
low: 149.0,
close: 151.0,
volume: 1_000_000.0,
vwap: Some(150.5),
};
assert_eq!(bar.open, 150.0);
assert_eq!(bar.high, 152.0);
assert_eq!(bar.low, 149.0);
assert_eq!(bar.close, 151.0);
assert_eq!(bar.volume, 1_000_000.0);
assert_eq!(bar.vwap, Some(150.5));
}
#[test]
fn test_trade_creation() {
let trade = Trade {
price: 150.0,
size: 500.0,
side: Side::Buy,
};
assert_eq!(trade.price, 150.0);
assert_eq!(trade.size, 500.0);
assert_eq!(trade.side, Side::Buy);
}
#[test]
fn test_market_data_type() {
let quote = Quote {
bid: 149.95,
ask: 150.05,
bid_size: 100.0,
ask_size: 100.0,
};
let quote_data = MarketDataType::Quote(quote.clone());
match quote_data {
MarketDataType::Quote(q) => {
assert_eq!(q.bid, quote.bid);
assert_eq!(q.ask, quote.ask);
}
_ => panic!("Expected quote data"),
}
let bar = Bar {
open: 150.0,
high: 152.0,
low: 149.0,
close: 151.0,
volume: 10000.0,
vwap: None,
};
let bar_data = MarketDataType::Bar(bar.clone());
match bar_data {
MarketDataType::Bar(b) => {
assert_eq!(b.open, bar.open);
assert_eq!(b.close, bar.close);
}
_ => panic!("Expected bar data"),
}
}
#[test]
fn test_side_equality() {
assert_eq!(Side::Buy, Side::Buy);
assert_eq!(Side::Sell, Side::Sell);
assert_ne!(Side::Buy, Side::Sell);
}

View file

@ -0,0 +1,52 @@
// Simple integration test that verifies basic functionality
use engine::{Order, OrderType, Side, TimeInForce, TradingMode};
use chrono::Utc;
#[test]
fn test_trading_mode_creation() {
let backtest_mode = TradingMode::Backtest {
start_time: Utc::now(),
end_time: Utc::now(),
speed_multiplier: 1.0,
};
match backtest_mode {
TradingMode::Backtest { speed_multiplier, .. } => {
assert_eq!(speed_multiplier, 1.0);
}
_ => panic!("Expected backtest mode"),
}
let paper_mode = TradingMode::Paper {
starting_capital: 100_000.0,
};
match paper_mode {
TradingMode::Paper { starting_capital } => {
assert_eq!(starting_capital, 100_000.0);
}
_ => panic!("Expected paper mode"),
}
}
#[test]
fn test_order_with_limit_price() {
let order = Order {
id: "limit-order-1".to_string(),
symbol: "AAPL".to_string(),
side: Side::Buy,
quantity: 100.0,
order_type: OrderType::Limit { price: 150.0 },
time_in_force: TimeInForce::GTC,
};
assert_eq!(order.symbol, "AAPL");
assert_eq!(order.side, Side::Buy);
assert_eq!(order.quantity, 100.0);
match order.order_type {
OrderType::Limit { price } => assert_eq!(price, 150.0),
_ => panic!("Expected limit order"),
}
}

View file

@ -0,0 +1,7 @@
// Main test file for the engine crate
#[cfg(test)]
mod basic_tests;
#[cfg(test)]
mod integration_test;