moving engine to rust
This commit is contained in:
parent
d14380d740
commit
16ac28a565
16 changed files with 1598 additions and 3 deletions
187
apps/stock/orchestrator/src/backtest/RustBacktestEngine.ts
Normal file
187
apps/stock/orchestrator/src/backtest/RustBacktestEngine.ts
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
import { BacktestEngine as RustEngine } from '@stock-bot/core';
|
||||
import { RustStrategy } from '../strategies/RustStrategy';
|
||||
import { MarketData, BacktestConfig } from '../types';
|
||||
import { StorageService } from '../services/StorageService';
|
||||
import { IServiceContainer } from '@stock-bot/di';
|
||||
|
||||
export interface RustBacktestConfig {
|
||||
name: string;
|
||||
symbols: string[];
|
||||
startDate: string;
|
||||
endDate: string;
|
||||
initialCapital: number;
|
||||
commission: number;
|
||||
slippage: number;
|
||||
dataFrequency: string;
|
||||
}
|
||||
|
||||
export interface RustBacktestResult {
|
||||
config: RustBacktestConfig;
|
||||
metrics: {
|
||||
totalReturn: number;
|
||||
totalTrades: number;
|
||||
profitableTrades: number;
|
||||
winRate: number;
|
||||
profitFactor: number;
|
||||
sharpeRatio: number;
|
||||
maxDrawdown: number;
|
||||
totalPnl: number;
|
||||
avgWin: number;
|
||||
avgLoss: number;
|
||||
};
|
||||
equityCurve: Array<{ date: string; value: number }>;
|
||||
trades: any[];
|
||||
finalPositions: Record<string, any>;
|
||||
}
|
||||
|
||||
export class RustBacktestEngine {
|
||||
private engine: RustEngine;
|
||||
private container: IServiceContainer;
|
||||
private storageService: StorageService;
|
||||
private config: RustBacktestConfig;
|
||||
|
||||
constructor(
|
||||
container: IServiceContainer,
|
||||
storageService: StorageService,
|
||||
config: BacktestConfig
|
||||
) {
|
||||
this.container = container;
|
||||
this.storageService = storageService;
|
||||
|
||||
// Convert config for Rust
|
||||
const rustConfig: RustBacktestConfig = {
|
||||
name: config.name || 'Backtest',
|
||||
symbols: config.symbols,
|
||||
startDate: config.startDate,
|
||||
endDate: config.endDate,
|
||||
initialCapital: config.initialCapital,
|
||||
commission: config.commission || 0.001,
|
||||
slippage: config.slippage || 0.0001,
|
||||
dataFrequency: config.dataFrequency || '1d',
|
||||
};
|
||||
|
||||
this.config = rustConfig;
|
||||
this.engine = new RustEngine(rustConfig as any);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a TypeScript strategy to the backtest
|
||||
*/
|
||||
addStrategy(strategy: RustStrategy): void {
|
||||
strategy.register(this.engine);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load historical data for the backtest
|
||||
*/
|
||||
async loadData(): Promise<void> {
|
||||
// Get config from engine
|
||||
const config = this.getConfig();
|
||||
|
||||
const startDate = new Date(config.startDate);
|
||||
const endDate = new Date(config.endDate);
|
||||
|
||||
// Load data for each symbol
|
||||
for (const symbol of config.symbols) {
|
||||
const bars = await this.storageService.getHistoricalBars(
|
||||
symbol,
|
||||
startDate,
|
||||
endDate,
|
||||
config.dataFrequency
|
||||
);
|
||||
|
||||
// Convert to Rust format
|
||||
const marketData = bars.map(bar => ({
|
||||
symbol,
|
||||
timestamp: bar.timestamp.getTime(), // Convert to milliseconds
|
||||
type: 'bar',
|
||||
open: bar.open,
|
||||
high: bar.high,
|
||||
low: bar.low,
|
||||
close: bar.close,
|
||||
volume: bar.volume,
|
||||
vwap: bar.vwap || (bar.high + bar.low + bar.close) / 3,
|
||||
}));
|
||||
|
||||
// Load into Rust engine
|
||||
console.log(`Loading ${marketData.length} bars for ${symbol}`);
|
||||
this.engine.loadMarketData(marketData as any);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the backtest
|
||||
*/
|
||||
async run(): Promise<RustBacktestResult> {
|
||||
console.log('Starting backtest run...');
|
||||
|
||||
// Load data first
|
||||
await this.loadData();
|
||||
|
||||
// Run backtest in Rust
|
||||
const resultJson = await this.engine.run();
|
||||
|
||||
// Parse result and convert snake_case to camelCase
|
||||
const rustResult = JSON.parse(resultJson);
|
||||
|
||||
const result: RustBacktestResult = {
|
||||
config: rustResult.config,
|
||||
metrics: {
|
||||
totalReturn: rustResult.metrics.total_return,
|
||||
totalTrades: rustResult.metrics.total_trades,
|
||||
profitableTrades: rustResult.metrics.profitable_trades,
|
||||
winRate: rustResult.metrics.win_rate,
|
||||
profitFactor: rustResult.metrics.profit_factor,
|
||||
sharpeRatio: rustResult.metrics.sharpe_ratio,
|
||||
maxDrawdown: rustResult.metrics.max_drawdown,
|
||||
totalPnl: rustResult.metrics.total_pnl,
|
||||
avgWin: rustResult.metrics.avg_win,
|
||||
avgLoss: rustResult.metrics.avg_loss,
|
||||
},
|
||||
equityCurve: rustResult.equity_curve.map((point: any) => ({
|
||||
date: point[0],
|
||||
value: point[1],
|
||||
})),
|
||||
trades: rustResult.trades,
|
||||
finalPositions: rustResult.final_positions,
|
||||
};
|
||||
|
||||
// Log results
|
||||
this.container.logger.info('Rust backtest completed', {
|
||||
totalReturn: result.metrics.totalReturn,
|
||||
totalTrades: result.metrics.totalTrades,
|
||||
sharpeRatio: result.metrics.sharpeRatio,
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the backtest configuration
|
||||
*/
|
||||
private getConfig(): RustBacktestConfig {
|
||||
return this.config;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory function to create a Rust-powered backtest
|
||||
*/
|
||||
export async function createRustBacktest(
|
||||
container: IServiceContainer,
|
||||
config: BacktestConfig,
|
||||
strategies: RustStrategy[]
|
||||
): Promise<RustBacktestResult> {
|
||||
const storageService = container.custom?.StorageService || new StorageService();
|
||||
|
||||
// Create engine
|
||||
const engine = new RustBacktestEngine(container, storageService, config);
|
||||
|
||||
// Add strategies
|
||||
for (const strategy of strategies) {
|
||||
engine.addStrategy(strategy);
|
||||
}
|
||||
|
||||
// Run backtest
|
||||
return await engine.run();
|
||||
}
|
||||
139
apps/stock/orchestrator/src/strategies/RustStrategy.ts
Normal file
139
apps/stock/orchestrator/src/strategies/RustStrategy.ts
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import { BacktestEngine } from '@stock-bot/core';
|
||||
import { MarketData } from '../types';
|
||||
|
||||
export interface Signal {
|
||||
symbol: string;
|
||||
signal_type: 'Buy' | 'Sell' | 'Close';
|
||||
strength: number; // -1.0 to 1.0
|
||||
quantity?: number;
|
||||
reason?: string;
|
||||
metadata?: any;
|
||||
}
|
||||
|
||||
export interface StrategyCall {
|
||||
method: string;
|
||||
data: any;
|
||||
}
|
||||
|
||||
export interface StrategyResponse {
|
||||
signals: Signal[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Base class for TypeScript strategies that run in the Rust backtest engine
|
||||
*/
|
||||
export abstract class RustStrategy {
|
||||
protected name: string;
|
||||
protected id: string;
|
||||
protected parameters: Record<string, any>;
|
||||
protected positions: Map<string, number> = new Map();
|
||||
|
||||
constructor(name: string, id: string, parameters: Record<string, any> = {}) {
|
||||
this.name = name;
|
||||
this.id = id;
|
||||
this.parameters = parameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main callback that Rust will call
|
||||
*/
|
||||
public handleCall(call: StrategyCall): StrategyResponse {
|
||||
switch (call.method) {
|
||||
case 'on_market_data':
|
||||
const signals = this.onMarketData(call.data);
|
||||
return { signals };
|
||||
|
||||
case 'on_fill':
|
||||
this.onFill(
|
||||
call.data.symbol,
|
||||
call.data.quantity,
|
||||
call.data.price,
|
||||
call.data.side
|
||||
);
|
||||
return { signals: [] };
|
||||
|
||||
default:
|
||||
return { signals: [] };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when new market data is received
|
||||
*/
|
||||
protected abstract onMarketData(data: MarketData): Signal[];
|
||||
|
||||
/**
|
||||
* Called when an order is filled
|
||||
*/
|
||||
protected onFill(symbol: string, quantity: number, price: number, side: string): void {
|
||||
const currentPosition = this.positions.get(symbol) || 0;
|
||||
const newPosition = side === 'buy' ?
|
||||
currentPosition + quantity :
|
||||
currentPosition - quantity;
|
||||
|
||||
if (Math.abs(newPosition) < 0.0001) {
|
||||
this.positions.delete(symbol);
|
||||
} else {
|
||||
this.positions.set(symbol, newPosition);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to create a buy signal
|
||||
*/
|
||||
protected buySignal(symbol: string, strength: number = 1.0, reason?: string): Signal {
|
||||
return {
|
||||
symbol,
|
||||
signal_type: 'Buy',
|
||||
strength,
|
||||
reason,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to create a sell signal
|
||||
*/
|
||||
protected sellSignal(symbol: string, strength: number = 1.0, reason?: string): Signal {
|
||||
return {
|
||||
symbol,
|
||||
signal_type: 'Sell',
|
||||
strength,
|
||||
reason,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to create a close position signal
|
||||
*/
|
||||
protected closeSignal(symbol: string, reason?: string): Signal {
|
||||
return {
|
||||
symbol,
|
||||
signal_type: 'Close',
|
||||
strength: 1.0,
|
||||
reason,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Register this strategy with a backtest engine
|
||||
*/
|
||||
public register(engine: BacktestEngine): void {
|
||||
console.log(`Registering strategy ${this.name} with id ${this.id}`);
|
||||
|
||||
// Convert the handleCall method to what Rust expects
|
||||
const callback = (callJson: string) => {
|
||||
console.log('Strategy callback called with:', callJson);
|
||||
const call: StrategyCall = JSON.parse(callJson);
|
||||
const response = this.handleCall(call);
|
||||
console.log('Strategy response:', response);
|
||||
return JSON.stringify(response);
|
||||
};
|
||||
|
||||
engine.addTypescriptStrategy(
|
||||
this.name,
|
||||
this.id,
|
||||
this.parameters,
|
||||
callback as any
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
import { RustStrategy, Signal } from '../RustStrategy';
|
||||
import { MarketData } from '../../types';
|
||||
|
||||
interface BarData {
|
||||
open: number;
|
||||
high: number;
|
||||
low: number;
|
||||
close: number;
|
||||
volume: number;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export class SimpleMovingAverageCrossoverRust extends RustStrategy {
|
||||
private priceHistory: Map<string, number[]> = new Map();
|
||||
private lastCrossover: Map<string, 'golden' | 'death' | null> = new Map();
|
||||
private barsSinceSignal: Map<string, number> = new Map();
|
||||
|
||||
private fastPeriod: number;
|
||||
private slowPeriod: number;
|
||||
private minHoldingBars: number;
|
||||
|
||||
constructor(parameters: {
|
||||
fastPeriod?: number;
|
||||
slowPeriod?: number;
|
||||
minHoldingBars?: number;
|
||||
} = {}) {
|
||||
super('SimpleMovingAverageCrossover', `sma-crossover-${Date.now()}`, parameters);
|
||||
|
||||
this.fastPeriod = parameters.fastPeriod || 5;
|
||||
this.slowPeriod = parameters.slowPeriod || 15;
|
||||
this.minHoldingBars = parameters.minHoldingBars || 3;
|
||||
}
|
||||
|
||||
protected onMarketData(data: MarketData): Signal[] {
|
||||
const signals: Signal[] = [];
|
||||
|
||||
// Only process bar data
|
||||
if (data.data.type !== 'bar') {
|
||||
return signals;
|
||||
}
|
||||
|
||||
const bar = data.data as BarData;
|
||||
const symbol = data.symbol;
|
||||
const price = bar.close;
|
||||
|
||||
// Update price history
|
||||
if (!this.priceHistory.has(symbol)) {
|
||||
this.priceHistory.set(symbol, []);
|
||||
this.barsSinceSignal.set(symbol, 0);
|
||||
}
|
||||
|
||||
const history = this.priceHistory.get(symbol)!;
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if (history.length > this.slowPeriod) {
|
||||
history.shift();
|
||||
}
|
||||
|
||||
// Update bars since signal
|
||||
const currentBar = this.barsSinceSignal.get(symbol) || 0;
|
||||
this.barsSinceSignal.set(symbol, currentBar + 1);
|
||||
|
||||
// Need enough data
|
||||
if (history.length < this.slowPeriod) {
|
||||
return signals;
|
||||
}
|
||||
|
||||
// Calculate moving averages
|
||||
const fastMA = this.calculateSMA(history, this.fastPeriod);
|
||||
const slowMA = this.calculateSMA(history, this.slowPeriod);
|
||||
|
||||
// Calculate previous MAs
|
||||
const prevHistory = history.slice(0, -1);
|
||||
const prevFastMA = this.calculateSMA(prevHistory, this.fastPeriod);
|
||||
const prevSlowMA = this.calculateSMA(prevHistory, this.slowPeriod);
|
||||
|
||||
// Check for crossovers
|
||||
const currentPosition = this.positions.get(symbol) || 0;
|
||||
const lastCrossover = this.lastCrossover.get(symbol);
|
||||
const barsSinceSignal = this.barsSinceSignal.get(symbol) || 0;
|
||||
|
||||
// Golden cross - bullish signal
|
||||
if (prevFastMA <= prevSlowMA && fastMA > slowMA) {
|
||||
this.lastCrossover.set(symbol, 'golden');
|
||||
|
||||
if (currentPosition < 0) {
|
||||
// Close short position
|
||||
signals.push(this.closeSignal(symbol, 'Golden cross - Close short'));
|
||||
} else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
|
||||
// Open long position
|
||||
signals.push(this.buySignal(symbol, 0.8, 'Golden cross - Open long'));
|
||||
this.barsSinceSignal.set(symbol, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Death cross - bearish signal
|
||||
else if (prevFastMA >= prevSlowMA && fastMA < slowMA) {
|
||||
this.lastCrossover.set(symbol, 'death');
|
||||
|
||||
if (currentPosition > 0) {
|
||||
// Close long position
|
||||
signals.push(this.closeSignal(symbol, 'Death cross - Close long'));
|
||||
} else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
|
||||
// Open short position
|
||||
signals.push(this.sellSignal(symbol, 0.8, 'Death cross - Open short'));
|
||||
this.barsSinceSignal.set(symbol, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Trend following - enter on pullbacks
|
||||
else if (currentPosition === 0 && barsSinceSignal >= this.minHoldingBars) {
|
||||
if (lastCrossover === 'golden' && fastMA > slowMA) {
|
||||
// Bullish trend continues
|
||||
signals.push(this.buySignal(symbol, 0.8, 'Bullish trend - Open long'));
|
||||
this.barsSinceSignal.set(symbol, 0);
|
||||
} else if (lastCrossover === 'death' && fastMA < slowMA) {
|
||||
// Bearish trend continues
|
||||
signals.push(this.sellSignal(symbol, 0.8, 'Bearish trend - Open short'));
|
||||
this.barsSinceSignal.set(symbol, 0);
|
||||
}
|
||||
}
|
||||
|
||||
return signals;
|
||||
}
|
||||
|
||||
private calculateSMA(prices: number[], period: number): number {
|
||||
if (prices.length < period) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const relevantPrices = prices.slice(-period);
|
||||
const sum = relevantPrices.reduce((a, b) => a + b, 0);
|
||||
return sum / period;
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue