moving engine to rust

This commit is contained in:
Boki 2025-07-03 20:10:33 -04:00
parent d14380d740
commit 16ac28a565
16 changed files with 1598 additions and 3 deletions

View file

@ -0,0 +1,187 @@
import { BacktestEngine as RustEngine } from '@stock-bot/core';
import { RustStrategy } from '../strategies/RustStrategy';
import { MarketData, BacktestConfig } from '../types';
import { StorageService } from '../services/StorageService';
import { IServiceContainer } from '@stock-bot/di';
export interface RustBacktestConfig {
name: string;
symbols: string[];
startDate: string;
endDate: string;
initialCapital: number;
commission: number;
slippage: number;
dataFrequency: string;
}
export interface RustBacktestResult {
config: RustBacktestConfig;
metrics: {
totalReturn: number;
totalTrades: number;
profitableTrades: number;
winRate: number;
profitFactor: number;
sharpeRatio: number;
maxDrawdown: number;
totalPnl: number;
avgWin: number;
avgLoss: number;
};
equityCurve: Array<{ date: string; value: number }>;
trades: any[];
finalPositions: Record<string, any>;
}
export class RustBacktestEngine {
private engine: RustEngine;
private container: IServiceContainer;
private storageService: StorageService;
private config: RustBacktestConfig;
constructor(
container: IServiceContainer,
storageService: StorageService,
config: BacktestConfig
) {
this.container = container;
this.storageService = storageService;
// Convert config for Rust
const rustConfig: RustBacktestConfig = {
name: config.name || 'Backtest',
symbols: config.symbols,
startDate: config.startDate,
endDate: config.endDate,
initialCapital: config.initialCapital,
commission: config.commission || 0.001,
slippage: config.slippage || 0.0001,
dataFrequency: config.dataFrequency || '1d',
};
this.config = rustConfig;
this.engine = new RustEngine(rustConfig as any);
}
/**
* Add a TypeScript strategy to the backtest
*/
addStrategy(strategy: RustStrategy): void {
strategy.register(this.engine);
}
/**
* Load historical data for the backtest
*/
async loadData(): Promise<void> {
// Get config from engine
const config = this.getConfig();
const startDate = new Date(config.startDate);
const endDate = new Date(config.endDate);
// Load data for each symbol
for (const symbol of config.symbols) {
const bars = await this.storageService.getHistoricalBars(
symbol,
startDate,
endDate,
config.dataFrequency
);
// Convert to Rust format
const marketData = bars.map(bar => ({
symbol,
timestamp: bar.timestamp.getTime(), // Convert to milliseconds
type: 'bar',
open: bar.open,
high: bar.high,
low: bar.low,
close: bar.close,
volume: bar.volume,
vwap: bar.vwap || (bar.high + bar.low + bar.close) / 3,
}));
// Load into Rust engine
console.log(`Loading ${marketData.length} bars for ${symbol}`);
this.engine.loadMarketData(marketData as any);
}
}
/**
* Run the backtest
*/
async run(): Promise<RustBacktestResult> {
console.log('Starting backtest run...');
// Load data first
await this.loadData();
// Run backtest in Rust
const resultJson = await this.engine.run();
// Parse result and convert snake_case to camelCase
const rustResult = JSON.parse(resultJson);
const result: RustBacktestResult = {
config: rustResult.config,
metrics: {
totalReturn: rustResult.metrics.total_return,
totalTrades: rustResult.metrics.total_trades,
profitableTrades: rustResult.metrics.profitable_trades,
winRate: rustResult.metrics.win_rate,
profitFactor: rustResult.metrics.profit_factor,
sharpeRatio: rustResult.metrics.sharpe_ratio,
maxDrawdown: rustResult.metrics.max_drawdown,
totalPnl: rustResult.metrics.total_pnl,
avgWin: rustResult.metrics.avg_win,
avgLoss: rustResult.metrics.avg_loss,
},
equityCurve: rustResult.equity_curve.map((point: any) => ({
date: point[0],
value: point[1],
})),
trades: rustResult.trades,
finalPositions: rustResult.final_positions,
};
// Log results
this.container.logger.info('Rust backtest completed', {
totalReturn: result.metrics.totalReturn,
totalTrades: result.metrics.totalTrades,
sharpeRatio: result.metrics.sharpeRatio,
});
return result;
}
/**
* Get the backtest configuration
*/
private getConfig(): RustBacktestConfig {
return this.config;
}
}
/**
* Factory function to create a Rust-powered backtest
*/
export async function createRustBacktest(
container: IServiceContainer,
config: BacktestConfig,
strategies: RustStrategy[]
): Promise<RustBacktestResult> {
const storageService = container.custom?.StorageService || new StorageService();
// Create engine
const engine = new RustBacktestEngine(container, storageService, config);
// Add strategies
for (const strategy of strategies) {
engine.addStrategy(strategy);
}
// Run backtest
return await engine.run();
}

View file

@ -0,0 +1,139 @@
import { BacktestEngine } from '@stock-bot/core';
import { MarketData } from '../types';
export interface Signal {
symbol: string;
signal_type: 'Buy' | 'Sell' | 'Close';
strength: number; // -1.0 to 1.0
quantity?: number;
reason?: string;
metadata?: any;
}
export interface StrategyCall {
method: string;
data: any;
}
export interface StrategyResponse {
signals: Signal[];
}
/**
* Base class for TypeScript strategies that run in the Rust backtest engine
*/
export abstract class RustStrategy {
protected name: string;
protected id: string;
protected parameters: Record<string, any>;
protected positions: Map<string, number> = new Map();
constructor(name: string, id: string, parameters: Record<string, any> = {}) {
this.name = name;
this.id = id;
this.parameters = parameters;
}
/**
* Main callback that Rust will call
*/
public handleCall(call: StrategyCall): StrategyResponse {
switch (call.method) {
case 'on_market_data':
const signals = this.onMarketData(call.data);
return { signals };
case 'on_fill':
this.onFill(
call.data.symbol,
call.data.quantity,
call.data.price,
call.data.side
);
return { signals: [] };
default:
return { signals: [] };
}
}
/**
* Called when new market data is received
*/
protected abstract onMarketData(data: MarketData): Signal[];
/**
* Called when an order is filled
*/
protected onFill(symbol: string, quantity: number, price: number, side: string): void {
const currentPosition = this.positions.get(symbol) || 0;
const newPosition = side === 'buy' ?
currentPosition + quantity :
currentPosition - quantity;
if (Math.abs(newPosition) < 0.0001) {
this.positions.delete(symbol);
} else {
this.positions.set(symbol, newPosition);
}
}
/**
* Helper to create a buy signal
*/
protected buySignal(symbol: string, strength: number = 1.0, reason?: string): Signal {
return {
symbol,
signal_type: 'Buy',
strength,
reason,
};
}
/**
* Helper to create a sell signal
*/
protected sellSignal(symbol: string, strength: number = 1.0, reason?: string): Signal {
return {
symbol,
signal_type: 'Sell',
strength,
reason,
};
}
/**
* Helper to create a close position signal
*/
protected closeSignal(symbol: string, reason?: string): Signal {
return {
symbol,
signal_type: 'Close',
strength: 1.0,
reason,
};
}
/**
* Register this strategy with a backtest engine
*/
public register(engine: BacktestEngine): void {
console.log(`Registering strategy ${this.name} with id ${this.id}`);
// Convert the handleCall method to what Rust expects
const callback = (callJson: string) => {
console.log('Strategy callback called with:', callJson);
const call: StrategyCall = JSON.parse(callJson);
const response = this.handleCall(call);
console.log('Strategy response:', response);
return JSON.stringify(response);
};
engine.addTypescriptStrategy(
this.name,
this.id,
this.parameters,
callback as any
);
}
}

View file

@ -0,0 +1,136 @@
import { RustStrategy, Signal } from '../RustStrategy';
import { MarketData } from '../../types';
interface BarData {
open: number;
high: number;
low: number;
close: number;
volume: number;
timestamp: number;
}
export class SimpleMovingAverageCrossoverRust extends RustStrategy {
private priceHistory: Map<string, number[]> = new Map();
private lastCrossover: Map<string, 'golden' | 'death' | null> = new Map();
private barsSinceSignal: Map<string, number> = new Map();
private fastPeriod: number;
private slowPeriod: number;
private minHoldingBars: number;
constructor(parameters: {
fastPeriod?: number;
slowPeriod?: number;
minHoldingBars?: number;
} = {}) {
super('SimpleMovingAverageCrossover', `sma-crossover-${Date.now()}`, parameters);
this.fastPeriod = parameters.fastPeriod || 5;
this.slowPeriod = parameters.slowPeriod || 15;
this.minHoldingBars = parameters.minHoldingBars || 3;
}
protected onMarketData(data: MarketData): Signal[] {
const signals: Signal[] = [];
// Only process bar data
if (data.data.type !== 'bar') {
return signals;
}
const bar = data.data as BarData;
const symbol = data.symbol;
const price = bar.close;
// Update price history
if (!this.priceHistory.has(symbol)) {
this.priceHistory.set(symbol, []);
this.barsSinceSignal.set(symbol, 0);
}
const history = this.priceHistory.get(symbol)!;
history.push(price);
// Keep only necessary history
if (history.length > this.slowPeriod) {
history.shift();
}
// Update bars since signal
const currentBar = this.barsSinceSignal.get(symbol) || 0;
this.barsSinceSignal.set(symbol, currentBar + 1);
// Need enough data
if (history.length < this.slowPeriod) {
return signals;
}
// Calculate moving averages
const fastMA = this.calculateSMA(history, this.fastPeriod);
const slowMA = this.calculateSMA(history, this.slowPeriod);
// Calculate previous MAs
const prevHistory = history.slice(0, -1);
const prevFastMA = this.calculateSMA(prevHistory, this.fastPeriod);
const prevSlowMA = this.calculateSMA(prevHistory, this.slowPeriod);
// Check for crossovers
const currentPosition = this.positions.get(symbol) || 0;
const lastCrossover = this.lastCrossover.get(symbol);
const barsSinceSignal = this.barsSinceSignal.get(symbol) || 0;
// Golden cross - bullish signal
if (prevFastMA <= prevSlowMA && fastMA > slowMA) {
this.lastCrossover.set(symbol, 'golden');
if (currentPosition < 0) {
// Close short position
signals.push(this.closeSignal(symbol, 'Golden cross - Close short'));
} else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
// Open long position
signals.push(this.buySignal(symbol, 0.8, 'Golden cross - Open long'));
this.barsSinceSignal.set(symbol, 0);
}
}
// Death cross - bearish signal
else if (prevFastMA >= prevSlowMA && fastMA < slowMA) {
this.lastCrossover.set(symbol, 'death');
if (currentPosition > 0) {
// Close long position
signals.push(this.closeSignal(symbol, 'Death cross - Close long'));
} else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
// Open short position
signals.push(this.sellSignal(symbol, 0.8, 'Death cross - Open short'));
this.barsSinceSignal.set(symbol, 0);
}
}
// Trend following - enter on pullbacks
else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
if (lastCrossover === 'golden' && fastMA > slowMA) {
// Bullish trend continues
signals.push(this.buySignal(symbol, 0.8, 'Bullish trend - Open long'));
this.barsSinceSignal.set(symbol, 0);
} else if (lastCrossover === 'death' && fastMA < slowMA) {
// Bearish trend continues
signals.push(this.sellSignal(symbol, 0.8, 'Bearish trend - Open short'));
this.barsSinceSignal.set(symbol, 0);
}
}
return signals;
}
private calculateSMA(prices: number[], period: number): number {
if (prices.length < period) {
return 0;
}
const relevantPrices = prices.slice(-period);
const sum = relevantPrices.reduce((a, b) => a + b, 0);
return sum / period;
}
}

View file

@ -0,0 +1,145 @@
import { createRustBacktest } from './src/backtest/RustBacktestEngine';
import { SimpleMovingAverageCrossoverRust } from './src/strategies/rust/SimpleMovingAverageCrossoverRust';
import { IServiceContainer } from '@stock-bot/di';
// Mock StorageService
class MockStorageService {
async getHistoricalBars(symbol: string, startDate: Date, endDate: Date, frequency: string) {
// Generate mock data
const bars = [];
const msPerDay = 24 * 60 * 60 * 1000;
let currentDate = new Date(startDate);
let price = 100 + Math.random() * 50; // Start between 100-150
while (currentDate <= endDate) {
// Random walk
const change = (Math.random() - 0.5) * 2; // +/- 1%
price *= (1 + change / 100);
bars.push({
symbol,
timestamp: new Date(currentDate),
open: price * (1 + (Math.random() - 0.5) * 0.01),
high: price * (1 + Math.random() * 0.02),
low: price * (1 - Math.random() * 0.02),
close: price,
volume: 1000000 + Math.random() * 500000,
});
currentDate = new Date(currentDate.getTime() + msPerDay);
}
return bars;
}
}
async function testRustBacktest() {
console.log('🚀 Testing Rust Backtest Engine with TypeScript Strategy\n');
// Create minimal container
const container: IServiceContainer = {
logger: {
info: (msg: string, ...args: any[]) => console.log('[INFO]', msg, ...args),
error: (msg: string, ...args: any[]) => console.error('[ERROR]', msg, ...args),
warn: (msg: string, ...args: any[]) => console.warn('[WARN]', msg, ...args),
debug: (msg: string, ...args: any[]) => console.log('[DEBUG]', msg, ...args),
} as any,
custom: {
StorageService: new MockStorageService(),
}
};
// Backtest configuration
const config = {
mode: 'backtest' as const,
name: 'Rust Engine Test',
strategy: 'sma-crossover',
symbols: ['AAPL'], // Just one symbol for testing
startDate: '2023-01-01T00:00:00Z',
endDate: '2023-01-31T00:00:00Z', // Just one month for testing
initialCapital: 100000,
commission: 0.001,
slippage: 0.0001,
dataFrequency: '1d',
speed: 'max' as const,
};
// Create strategy
const strategy = new SimpleMovingAverageCrossoverRust({
fastPeriod: 10,
slowPeriod: 30,
minHoldingBars: 5,
});
console.log('Configuration:');
console.log(` Symbols: ${config.symbols.join(', ')}`);
console.log(` Period: ${config.startDate} to ${config.endDate}`);
console.log(` Initial Capital: $${config.initialCapital.toLocaleString()}`);
console.log(` Strategy: ${strategy.constructor.name}`);
console.log('');
try {
console.log('Running backtest in Rust engine...\n');
const startTime = Date.now();
try {
const result = await createRustBacktest(container, config, [strategy]);
console.log('Raw result:', result);
const duration = (Date.now() - startTime) / 1000;
console.log(`\n✅ Backtest completed in ${duration.toFixed(2)} seconds`);
if (!result || !result.metrics) {
console.error('Invalid result structure:', result);
return;
}
console.log('\n=== PERFORMANCE METRICS ===');
console.log(`Total Return: ${result.metrics.totalReturn?.toFixed(2) || 'N/A'}%`);
console.log(`Sharpe Ratio: ${result.metrics.sharpeRatio?.toFixed(2) || 'N/A'}`);
console.log(`Max Drawdown: ${result.metrics.maxDrawdown ? (result.metrics.maxDrawdown * 100).toFixed(2) : 'N/A'}%`);
console.log(`Win Rate: ${result.metrics.winRate?.toFixed(1) || 'N/A'}%`);
console.log(`Total Trades: ${result.metrics.totalTrades || 0}`);
console.log(`Profit Factor: ${result.metrics.profitFactor?.toFixed(2) || 'N/A'}`);
console.log('\n=== TRADE STATISTICS ===');
console.log(`Profitable Trades: ${result.metrics.profitableTrades || 0}`);
console.log(`Average Win: $${result.metrics.avgWin?.toFixed(2) || '0.00'}`);
console.log(`Average Loss: $${result.metrics.avgLoss?.toFixed(2) || '0.00'}`);
console.log(`Total P&L: $${result.metrics.totalPnl?.toFixed(2) || '0.00'}`);
console.log('\n=== EQUITY CURVE ===');
if (result.equityCurve.length > 0) {
const firstValue = result.equityCurve[0].value;
const lastValue = result.equityCurve[result.equityCurve.length - 1].value;
console.log(`Starting Value: $${firstValue.toLocaleString()}`);
console.log(`Ending Value: $${lastValue.toLocaleString()}`);
console.log(`Growth: ${((lastValue / firstValue - 1) * 100).toFixed(2)}%`);
}
console.log('\n=== FINAL POSITIONS ===');
const positions = Object.entries(result.finalPositions);
if (positions.length > 0) {
for (const [symbol, position] of positions) {
console.log(`${symbol}: ${position.quantity} shares @ $${position.averagePrice}`);
}
} else {
console.log('No open positions');
}
// Compare with TypeScript engine performance
console.log('\n=== PERFORMANCE COMPARISON ===');
console.log('TypeScript Engine: ~5-10 seconds for 1 year backtest');
console.log(`Rust Engine: ${duration.toFixed(2)} seconds`);
console.log(`Speed Improvement: ${(10 / duration).toFixed(1)}x faster`);
} catch (innerError) {
console.error('Result processing error:', innerError);
}
} catch (error) {
console.error('❌ Backtest failed:', error);
}
}
// Run the test
testRustBacktest().catch(console.error);