stock-bot/apps/stock/engine/src/indicators/rsi.rs
2025-07-04 11:24:27 -04:00

223 lines
No EOL
7.2 KiB
Rust

use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
use super::common::RollingWindow;
/// Relative Strength Index (RSI) indicator.
///
/// Measures momentum by comparing the magnitude of recent gains to recent losses:
///
/// RSI = 100 - (100 / (1 + RS))
/// where RS = Average Gain / Average Loss
///
/// The struct carries the running state used by the incremental `update` path:
/// the two smoothed averages, the previously seen value, and the window of
/// recent values used to seed the averages.
pub struct RSI {
/// Number of price changes the averages are taken over; `new` rejects 0.
period: usize,
/// Running average of upward price changes (0.0 until seeded).
avg_gain: f64,
/// Running average of downward price changes, stored as a non-negative
/// magnitude (0.0 until seeded).
avg_loss: f64,
/// Most recent value passed to `update`, or `None` before the first call
/// and after `reset`.
prev_value: Option<f64>,
/// Recent input values; `new` constructs it with `period + 1`, and
/// `update` reads it to seed the averages.
window: RollingWindow<f64>,
/// Set once the averages have been seeded; this is what `is_ready` reports.
initialized: bool,
}
impl RSI {
    /// Creates an RSI indicator that averages gains and losses over `period`
    /// price changes.
    ///
    /// The indicator starts with zeroed averages, no previous value, and
    /// `initialized == false`.
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InvalidParameter`] when `period` is 0.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        if period == 0 {
            return Err(IndicatorError::InvalidParameter(
                "Period must be greater than 0".to_string(),
            ));
        }

        Ok(Self {
            period,
            avg_gain: 0.0,
            avg_loss: 0.0,
            prev_value: None,
            // `period` price changes are derived from `period + 1` prices.
            window: RollingWindow::new(period + 1),
            initialized: false,
        })
    }

    /// Calculates the RSI for a series of values.
    ///
    /// The first output is seeded with the simple average of the first
    /// `period` gains and losses; every later output applies Wilder's
    /// smoothing to the previous averages. The result holds
    /// `values.len() - period` entries, the first of which corresponds to
    /// `values[period]`.
    ///
    /// When the average loss is zero the RSI is reported as exactly 100.0,
    /// the same convention the incremental path uses. A NaN input propagates
    /// as NaN rather than being masked.
    ///
    /// # Errors
    ///
    /// * [`IndicatorError::InvalidParameter`] when `period` is 0.
    /// * [`IndicatorError::InsufficientData`] when `values` has fewer than
    ///   `period + 1` elements.
    pub fn calculate_series(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
        if period == 0 {
            return Err(IndicatorError::InvalidParameter(
                "Period must be greater than 0".to_string(),
            ));
        }
        if values.len() <= period {
            return Err(IndicatorError::InsufficientData {
                required: period + 1,
                actual: values.len(),
            });
        }

        // Split every consecutive price change into a gain and a loss
        // magnitude; exactly one of the two is non-zero for a real move.
        let (gains, losses): (Vec<f64>, Vec<f64>) = values
            .windows(2)
            .map(|pair| {
                let change = pair[1] - pair[0];
                if change > 0.0 {
                    (change, 0.0)
                } else {
                    (0.0, -change)
                }
            })
            .unzip();

        let span = period as f64;
        let carried = (period - 1) as f64;

        // Seed the averages with a simple mean over the first `period` changes.
        let mut avg_gain = gains[..period].iter().sum::<f64>() / span;
        let mut avg_loss = losses[..period].iter().sum::<f64>() / span;

        let mut result = Vec::with_capacity(values.len() - period);
        result.push(Self::rsi_from_averages(avg_gain, avg_loss));

        // Wilder's smoothing for every remaining change.
        for (&gain, &loss) in gains[period..].iter().zip(&losses[period..]) {
            avg_gain = (avg_gain * carried + gain) / span;
            avg_loss = (avg_loss * carried + loss) / span;
            result.push(Self::rsi_from_averages(avg_gain, avg_loss));
        }

        Ok(result)
    }

    /// Returns the RSI implied by the indicator's current averages.
    fn calculate_rsi(&self) -> f64 {
        Self::rsi_from_averages(self.avg_gain, self.avg_loss)
    }

    /// Converts an average gain / average loss pair into an RSI value.
    ///
    /// A zero average loss makes RS unbounded, so the RSI is pinned to 100.0
    /// instead of dividing by zero.
    fn rsi_from_averages(avg_gain: f64, avg_loss: f64) -> f64 {
        if avg_loss == 0.0 {
            100.0
        } else {
            let rs = avg_gain / avg_loss;
            100.0 - (100.0 / (1.0 + rs))
        }
    }
}
impl Indicator for RSI {
    /// Runs the batch RSI over `data.close` and returns it as a series.
    ///
    /// This does not touch the incremental state held in `self`; only
    /// `self.period` is read.
    ///
    /// # Errors
    ///
    /// Returns [`IndicatorError::InsufficientData`] when there are not more
    /// than `period` closing prices, and forwards any error produced by
    /// `calculate_series`.
    fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
        let closes = &data.close;
        let available = closes.len();

        if available <= self.period {
            return Err(IndicatorError::InsufficientData {
                required: self.period + 1,
                actual: available,
            });
        }

        Self::calculate_series(closes, self.period).map(IndicatorResult::Series)
    }

    /// Returns the indicator to its freshly constructed state, keeping `period`.
    fn reset(&mut self) {
        self.initialized = false;
        self.window.clear();
        self.prev_value = None;
        self.avg_loss = 0.0;
        self.avg_gain = 0.0;
    }

    /// Reports whether the running averages have been seeded.
    fn is_ready(&self) -> bool {
        self.initialized
    }
}
impl IncrementalIndicator for RSI {
    /// Feeds one new value into the indicator.
    ///
    /// Returns `Ok(None)` until more than `period` values have been seen.
    /// The call that completes the seed returns the first RSI; every later
    /// call returns the RSI after one step of Wilder's smoothing.
    fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
        self.window.push(value);

        // Store `value` for the next call while retrieving the prior one.
        if let Some(last) = self.prev_value.replace(value) {
            if self.initialized {
                // Steady state: blend the newest change into the averages.
                let delta = value - last;
                let up = if delta > 0.0 { delta } else { 0.0 };
                let down = if delta < 0.0 { -delta } else { 0.0 };
                let span = self.period as f64;
                let carried = (self.period - 1) as f64;
                self.avg_gain = (self.avg_gain * carried + up) / span;
                self.avg_loss = (self.avg_loss * carried + down) / span;
            } else if self.window.len() > self.period {
                // Seed: plain mean of the first `period` changes in the window.
                let seed = self.window.as_slice();
                let (up_total, down_total) = (1..=self.period)
                    .map(|i| seed[i] - seed[i - 1])
                    .fold((0.0_f64, 0.0_f64), |(up, down), delta| {
                        if delta > 0.0 {
                            (up + delta, down)
                        } else {
                            (up, down - delta)
                        }
                    });
                let span = self.period as f64;
                self.avg_gain = up_total / span;
                self.avg_loss = down_total / span;
                self.initialized = true;
            }
        }

        Ok(self.current())
    }

    /// Returns the RSI for the current averages, or `None` before seeding.
    fn current(&self) -> Option<f64> {
        self.initialized.then(|| self.calculate_rsi())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 17-point series with period 14 must yield 3 values, all within [0, 100].
    #[test]
    fn test_rsi_calculation() {
        let closes = [
            44.0, 44.25, 44.38, 44.38, 44.88, 45.05, 45.25, 45.38, 45.75, 46.03, 46.23, 46.08,
            46.03, 45.85, 46.25, 46.38, 46.50,
        ];

        let series = RSI::calculate_series(&closes, 14).unwrap();

        assert_eq!(series.len(), 3);
        // RSI is bounded by construction.
        assert!(series.iter().all(|rsi| (0.0..=100.0).contains(rsi)));
    }

    /// Strictly rising input pushes RSI towards 100; strictly falling towards 0.
    #[test]
    fn test_rsi_extremes() {
        let rising = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let last_up = RSI::calculate_series(&rising, 5)
            .unwrap()
            .last()
            .copied()
            .unwrap();
        assert!(last_up > 95.0);

        let falling = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let last_down = RSI::calculate_series(&falling, 5)
            .unwrap()
            .last()
            .copied()
            .unwrap();
        assert!(last_down < 5.0);
    }
}