added rust engine and adapter pattern
This commit is contained in:
parent
a58072cf93
commit
0a4702d12a
6 changed files with 328 additions and 186 deletions
|
|
@ -3,6 +3,7 @@ import { IServiceContainer } from '@stock-bot/di';
|
|||
import { BacktestEngine as RustEngine } from '@stock-bot/core';
|
||||
import { BacktestConfig, BacktestResult } from '../types';
|
||||
import { StorageService } from '../services/StorageService';
|
||||
import { StrategyExecutor, SMACrossoverStrategy } from '../strategies/StrategyExecutor';
|
||||
|
||||
/**
|
||||
* Adapter that bridges the orchestrator with the Rust backtest engine
|
||||
|
|
@ -12,6 +13,7 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
private storageService: StorageService;
|
||||
private currentEngine?: RustEngine;
|
||||
private isRunning = false;
|
||||
private strategyExecutor: StrategyExecutor;
|
||||
|
||||
constructor(container: IServiceContainer) {
|
||||
super();
|
||||
|
|
@ -22,6 +24,7 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
container.postgres,
|
||||
null
|
||||
);
|
||||
this.strategyExecutor = new StrategyExecutor();
|
||||
}
|
||||
|
||||
async runBacktest(config: BacktestConfig): Promise<BacktestResult> {
|
||||
|
|
@ -237,165 +240,64 @@ export class RustBacktestAdapter extends EventEmitter {
|
|||
private registerStrategy(strategyName: string, parameters: any): void {
|
||||
if (!this.currentEngine) return;
|
||||
|
||||
// Create state for the strategy
|
||||
const priceHistory: Map<string, number[]> = new Map();
|
||||
const positions: Map<string, number> = new Map();
|
||||
const fastPeriod = parameters.fastPeriod || 5;
|
||||
const slowPeriod = parameters.slowPeriod || 15;
|
||||
|
||||
this.container.logger.info('Registering TypeScript strategy', {
|
||||
this.container.logger.info('Registering strategy', {
|
||||
strategyName,
|
||||
fastPeriod,
|
||||
slowPeriod
|
||||
parameters
|
||||
});
|
||||
|
||||
// Create a TypeScript strategy callback
|
||||
let callCount = 0;
|
||||
const callback = (callJson: string) => {
|
||||
callCount++;
|
||||
const call = JSON.parse(callJson);
|
||||
// Check if we should use TypeScript or native Rust implementation
|
||||
const useTypeScript = parameters.useTypeScriptImplementation || false;
|
||||
|
||||
if (useTypeScript) {
|
||||
// Register TypeScript strategy
|
||||
this.container.logger.info('Using TypeScript strategy implementation');
|
||||
|
||||
// Log every 10th call to see if we're getting data
|
||||
if (callCount % 10 === 1) {
|
||||
this.container.logger.info(`Strategy callback called ${callCount} times, method: ${call.method}`);
|
||||
}
|
||||
|
||||
if (call.method === 'on_market_data') {
|
||||
const marketData = call.data;
|
||||
const signals: any[] = [];
|
||||
// For now, we'll use the SMACrossoverStrategy
|
||||
if (strategyName.toLowerCase().includes('sma') || strategyName.toLowerCase().includes('crossover')) {
|
||||
const strategyId = `strategy-${Date.now()}`;
|
||||
this.strategyExecutor.registerStrategy(strategyId, SMACrossoverStrategy);
|
||||
|
||||
// Debug log first few data points
|
||||
if (priceHistory.size === 0) {
|
||||
this.container.logger.info('First market data received:', JSON.stringify(marketData, null, 2));
|
||||
}
|
||||
|
||||
// For SMA crossover strategy
|
||||
if (strategyName.toLowerCase().includes('sma') || strategyName.toLowerCase().includes('crossover')) {
|
||||
// Log the structure to understand the data format
|
||||
if (callCount === 1) {
|
||||
this.container.logger.info('Market data structure:', {
|
||||
hasData: !!marketData.data,
|
||||
hasBar: !!marketData.data?.Bar,
|
||||
hasClose: !!marketData.data?.close,
|
||||
dataKeys: marketData.data ? Object.keys(marketData.data) : [],
|
||||
});
|
||||
// Create a callback that uses the strategy executor
|
||||
const callback = (callJson: string) => {
|
||||
try {
|
||||
const call = JSON.parse(callJson);
|
||||
|
||||
if (call.method === 'on_market_data') {
|
||||
const signals = this.strategyExecutor.onMarketData(call.data);
|
||||
return JSON.stringify({ signals });
|
||||
} else if (call.method === 'on_fill') {
|
||||
const { symbol, quantity, price, side } = call.data;
|
||||
this.strategyExecutor.onFill(symbol, quantity, price, side);
|
||||
return JSON.stringify({ signals: [] });
|
||||
}
|
||||
|
||||
return JSON.stringify({ signals: [] });
|
||||
} catch (error) {
|
||||
this.container.logger.error('Strategy execution error:', error);
|
||||
return JSON.stringify({ signals: [] });
|
||||
}
|
||||
|
||||
// Check if it's bar data - handle different possible structures
|
||||
const isBar = marketData.data?.Bar ||
|
||||
(marketData.data && 'close' in marketData.data) ||
|
||||
(marketData && 'close' in marketData);
|
||||
|
||||
if (isBar) {
|
||||
const symbol = marketData.symbol;
|
||||
// Handle both direct properties and nested Bar structure
|
||||
const barData = marketData.data?.Bar || marketData.data || marketData;
|
||||
const price = barData.close;
|
||||
|
||||
// Log that we're processing bar data
|
||||
if (callCount <= 3) {
|
||||
this.container.logger.info(`Processing bar data for ${symbol}, price: ${price}`);
|
||||
}
|
||||
|
||||
// Update price history
|
||||
if (!priceHistory.has(symbol)) {
|
||||
priceHistory.set(symbol, []);
|
||||
}
|
||||
|
||||
const history = priceHistory.get(symbol)!;
|
||||
history.push(price);
|
||||
|
||||
// Keep only necessary history
|
||||
if (history.length > slowPeriod) {
|
||||
history.shift();
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if (history.length >= slowPeriod) {
|
||||
// Calculate SMAs
|
||||
const fastSMA = history.slice(-fastPeriod).reduce((a, b) => a + b, 0) / fastPeriod;
|
||||
const slowSMA = history.reduce((a, b) => a + b, 0) / slowPeriod;
|
||||
|
||||
// Log SMA values periodically
|
||||
if (history.length % 5 === 0 || history.length === slowPeriod) {
|
||||
this.container.logger.debug(`SMAs for ${symbol}: Fast(${fastPeriod})=${fastSMA.toFixed(2)}, Slow(${slowPeriod})=${slowSMA.toFixed(2)}, Price=${price.toFixed(2)}, History length=${history.length}`);
|
||||
}
|
||||
|
||||
// Previous SMAs (if we have enough history)
|
||||
if (history.length > slowPeriod) {
|
||||
const prevHistory = history.slice(0, -1);
|
||||
const prevFastSMA = prevHistory.slice(-fastPeriod).reduce((a, b) => a + b, 0) / fastPeriod;
|
||||
const prevSlowSMA = prevHistory.reduce((a, b) => a + b, 0) / slowPeriod;
|
||||
|
||||
const currentPosition = positions.get(symbol) || 0;
|
||||
|
||||
// Log crossover checks periodically
|
||||
if (history.length % 10 === 0) {
|
||||
this.container.logger.debug(`Crossover check for ${symbol}: prevFast=${prevFastSMA.toFixed(2)}, prevSlow=${prevSlowSMA.toFixed(2)}, currFast=${fastSMA.toFixed(2)}, currSlow=${slowSMA.toFixed(2)}, position=${currentPosition}`);
|
||||
}
|
||||
|
||||
// Golden cross - buy signal
|
||||
if (prevFastSMA <= prevSlowSMA && fastSMA > slowSMA && currentPosition <= 0) {
|
||||
this.container.logger.info(`Golden cross detected for ${symbol} at price ${price}`);
|
||||
signals.push({
|
||||
symbol,
|
||||
signal_type: 'Buy',
|
||||
strength: 1.0,
|
||||
quantity: 100, // Fixed quantity for testing
|
||||
reason: 'Golden cross'
|
||||
});
|
||||
positions.set(symbol, 1);
|
||||
}
|
||||
|
||||
// Death cross - sell signal
|
||||
else if (prevFastSMA >= prevSlowSMA && fastSMA < slowSMA && currentPosition >= 0) {
|
||||
this.container.logger.info(`Death cross detected for ${symbol} at price ${price}`);
|
||||
signals.push({
|
||||
symbol,
|
||||
signal_type: 'Sell',
|
||||
strength: 1.0,
|
||||
quantity: 100, // Fixed quantity for testing
|
||||
reason: 'Death cross'
|
||||
});
|
||||
positions.set(symbol, -1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Log while building up history
|
||||
if (history.length % 5 === 0 || history.length === 1) {
|
||||
this.container.logger.debug(`Building history for ${symbol}: ${history.length}/${slowPeriod} bars collected`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return JSON.stringify({ signals });
|
||||
// Register with Rust engine
|
||||
this.currentEngine.addTypescriptStrategy(
|
||||
strategyName,
|
||||
strategyId,
|
||||
parameters,
|
||||
callback
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Use native Rust strategy for maximum performance
|
||||
this.container.logger.info('Using native Rust strategy implementation');
|
||||
|
||||
if (call.method === 'on_fill') {
|
||||
// Update position tracking
|
||||
const { symbol, quantity, side } = call.data;
|
||||
const currentPos = positions.get(symbol) || 0;
|
||||
const newPos = side === 'buy' ? currentPos + quantity : currentPos - quantity;
|
||||
|
||||
if (Math.abs(newPos) < 0.0001) {
|
||||
positions.delete(symbol);
|
||||
} else {
|
||||
positions.set(symbol, newPos);
|
||||
}
|
||||
|
||||
return JSON.stringify({ signals: [] });
|
||||
}
|
||||
|
||||
return JSON.stringify({ signals: [] });
|
||||
};
|
||||
|
||||
this.currentEngine.addTypescriptStrategy(
|
||||
strategyName,
|
||||
`strategy-${Date.now()}`,
|
||||
parameters,
|
||||
callback
|
||||
);
|
||||
// Use the addNativeStrategy method instead
|
||||
this.currentEngine.addNativeStrategy(
|
||||
'sma_crossover', // strategy type
|
||||
strategyName,
|
||||
`strategy-${Date.now()}`,
|
||||
parameters
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private calculateExpectancy(metrics: any): number {
|
||||
|
|
|
|||
135
apps/stock/orchestrator/src/strategies/StrategyExecutor.ts
Normal file
135
apps/stock/orchestrator/src/strategies/StrategyExecutor.ts
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import { MarketData, Signal } from '../types';
|
||||
|
||||
export interface IStrategyExecutor {
|
||||
onMarketData(data: MarketData): Signal[];
|
||||
onFill(symbol: string, quantity: number, price: number, side: string): void;
|
||||
getState(): any;
|
||||
setState(state: any): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes strategies in-process for backtesting
|
||||
* This avoids the complexity of async callbacks between Rust and TypeScript
|
||||
*/
|
||||
export class StrategyExecutor implements IStrategyExecutor {
|
||||
private strategies: Map<string, any> = new Map();
|
||||
private strategyStates: Map<string, any> = new Map();
|
||||
|
||||
registerStrategy(id: string, strategy: any) {
|
||||
this.strategies.set(id, strategy);
|
||||
this.strategyStates.set(id, {
|
||||
priceHistory: new Map<string, number[]>(),
|
||||
positions: new Map<string, number>(),
|
||||
});
|
||||
}
|
||||
|
||||
onMarketData(data: MarketData): Signal[] {
|
||||
const allSignals: Signal[] = [];
|
||||
|
||||
for (const [id, strategy] of this.strategies) {
|
||||
const state = this.strategyStates.get(id)!;
|
||||
const signals = strategy.onMarketData(data, state);
|
||||
|
||||
if (signals && signals.length > 0) {
|
||||
allSignals.push(...signals);
|
||||
}
|
||||
}
|
||||
|
||||
return allSignals;
|
||||
}
|
||||
|
||||
onFill(symbol: string, quantity: number, price: number, side: string): void {
|
||||
for (const [id, strategy] of this.strategies) {
|
||||
const state = this.strategyStates.get(id)!;
|
||||
|
||||
if (strategy.onFill) {
|
||||
strategy.onFill({ symbol, quantity, price, side }, state);
|
||||
}
|
||||
|
||||
// Update position tracking
|
||||
const currentPos = state.positions.get(symbol) || 0;
|
||||
const newPos = side === 'buy' ? currentPos + quantity : currentPos - quantity;
|
||||
|
||||
if (Math.abs(newPos) < 0.0001) {
|
||||
state.positions.delete(symbol);
|
||||
} else {
|
||||
state.positions.set(symbol, newPos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getState(): any {
|
||||
return Object.fromEntries(this.strategyStates);
|
||||
}
|
||||
|
||||
setState(state: any): void {
|
||||
this.strategyStates = new Map(Object.entries(state));
|
||||
}
|
||||
}
|
||||
|
||||
// Example SMA Crossover Strategy
|
||||
export const SMACrossoverStrategy = {
|
||||
onMarketData(data: MarketData, state: any): Signal[] {
|
||||
const signals: Signal[] = [];
|
||||
|
||||
// Check if it's bar data
|
||||
if (data.type !== 'bar') return signals;
|
||||
|
||||
const { symbol, close } = data.data;
|
||||
const fastPeriod = 5;
|
||||
const slowPeriod = 15;
|
||||
|
||||
// Update price history
|
||||
if (!state.priceHistory.has(symbol)) {
|
||||
state.priceHistory.set(symbol, []);
|
||||
}
|
||||
|
||||
const history = state.priceHistory.get(symbol)!;
|
||||
history.push(close);
|
||||
|
||||
// Keep only necessary history
|
||||
if (history.length > slowPeriod + 1) {
|
||||
history.shift();
|
||||
}
|
||||
|
||||
// Need enough data
|
||||
if (history.length >= slowPeriod) {
|
||||
// Calculate SMAs
|
||||
const fastSMA = history.slice(-fastPeriod).reduce((a, b) => a + b, 0) / fastPeriod;
|
||||
const slowSMA = history.reduce((a, b) => a + b, 0) / history.length;
|
||||
|
||||
// Previous SMAs (if we have enough history)
|
||||
if (history.length > slowPeriod) {
|
||||
const prevHistory = history.slice(0, -1);
|
||||
const prevFastSMA = prevHistory.slice(-fastPeriod).reduce((a, b) => a + b, 0) / fastPeriod;
|
||||
const prevSlowSMA = prevHistory.reduce((a, b) => a + b, 0) / prevHistory.length;
|
||||
|
||||
const currentPosition = state.positions.get(symbol) || 0;
|
||||
|
||||
// Golden cross - buy signal
|
||||
if (prevFastSMA <= prevSlowSMA && fastSMA > slowSMA && currentPosition <= 0) {
|
||||
signals.push({
|
||||
symbol,
|
||||
signal_type: 'Buy',
|
||||
strength: 1.0,
|
||||
quantity: 100,
|
||||
reason: 'Golden cross',
|
||||
});
|
||||
}
|
||||
|
||||
// Death cross - sell signal
|
||||
else if (prevFastSMA >= prevSlowSMA && fastSMA < slowSMA && currentPosition >= 0) {
|
||||
signals.push({
|
||||
symbol,
|
||||
signal_type: 'Sell',
|
||||
strength: 1.0,
|
||||
quantity: 100,
|
||||
reason: 'Death cross',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return signals;
|
||||
},
|
||||
};
|
||||
Loading…
Add table
Add a link
Reference in a new issue