use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData}; use super::common::RollingWindow; /// Relative Strength Index (RSI) Indicator /// /// Measures momentum by comparing the magnitude of recent gains to recent losses /// RSI = 100 - (100 / (1 + RS)) /// where RS = Average Gain / Average Loss pub struct RSI { period: usize, avg_gain: f64, avg_loss: f64, prev_value: Option, window: RollingWindow, initialized: bool, } impl RSI { pub fn new(period: usize) -> Result { if period == 0 { return Err(IndicatorError::InvalidParameter( "Period must be greater than 0".to_string() )); } Ok(Self { period, avg_gain: 0.0, avg_loss: 0.0, prev_value: None, window: RollingWindow::new(period + 1), initialized: false, }) } /// Calculate RSI for a series of values pub fn calculate_series(values: &[f64], period: usize) -> Result, IndicatorError> { if period == 0 { return Err(IndicatorError::InvalidParameter( "Period must be greater than 0".to_string() )); } if values.len() <= period { return Err(IndicatorError::InsufficientData { required: period + 1, actual: values.len(), }); } let mut result = Vec::with_capacity(values.len() - period); let mut gains = Vec::with_capacity(values.len() - 1); let mut losses = Vec::with_capacity(values.len() - 1); // Calculate price changes for i in 1..values.len() { let change = values[i] - values[i - 1]; if change > 0.0 { gains.push(change); losses.push(0.0); } else { gains.push(0.0); losses.push(-change); } } // Calculate initial averages using SMA let initial_avg_gain: f64 = gains[0..period].iter().sum::() / period as f64; let initial_avg_loss: f64 = losses[0..period].iter().sum::() / period as f64; // Calculate first RSI let rs = if initial_avg_loss > 0.0 { initial_avg_gain / initial_avg_loss } else { 100.0 // If no losses, RSI is 100 }; result.push(100.0 - (100.0 / (1.0 + rs))); // Calculate remaining RSIs using EMA smoothing let mut avg_gain = initial_avg_gain; let mut avg_loss = initial_avg_loss; let alpha = 1.0 / period as f64; for i in period..gains.len() { // Wilder's smoothing method avg_gain = (avg_gain * (period - 1) as f64 + gains[i]) / period as f64; avg_loss = (avg_loss * (period - 1) as f64 + losses[i]) / period as f64; let rs = if avg_loss > 0.0 { avg_gain / avg_loss } else { 100.0 }; result.push(100.0 - (100.0 / (1.0 + rs))); } Ok(result) } fn calculate_rsi(&self) -> f64 { if self.avg_loss == 0.0 { 100.0 } else { let rs = self.avg_gain / self.avg_loss; 100.0 - (100.0 / (1.0 + rs)) } } } impl Indicator for RSI { fn calculate(&mut self, data: &PriceData) -> Result { let values = &data.close; if values.len() <= self.period { return Err(IndicatorError::InsufficientData { required: self.period + 1, actual: values.len(), }); } let rsi_values = Self::calculate_series(values, self.period)?; Ok(IndicatorResult::Series(rsi_values)) } fn reset(&mut self) { self.avg_gain = 0.0; self.avg_loss = 0.0; self.prev_value = None; self.window.clear(); self.initialized = false; } fn is_ready(&self) -> bool { self.initialized } } impl IncrementalIndicator for RSI { fn update(&mut self, value: f64) -> Result, IndicatorError> { self.window.push(value); if let Some(prev) = self.prev_value { let change = value - prev; let gain = if change > 0.0 { change } else { 0.0 }; let loss = if change < 0.0 { -change } else { 0.0 }; if !self.initialized && self.window.len() > self.period { // Initialize using first period values let values = self.window.as_slice(); let mut sum_gain = 0.0; let mut sum_loss = 0.0; for i in 1..=self.period { let change = values[i] - values[i - 1]; if change > 0.0 { sum_gain += change; } else { sum_loss += -change; } } self.avg_gain = sum_gain / self.period as f64; self.avg_loss = sum_loss / self.period as f64; self.initialized = true; } else if self.initialized { // Update using Wilder's smoothing self.avg_gain = (self.avg_gain * (self.period - 1) as f64 + gain) / self.period as f64; self.avg_loss = (self.avg_loss * (self.period - 1) as f64 + loss) / self.period as f64; } } self.prev_value = Some(value); if self.initialized { Ok(Some(self.calculate_rsi())) } else { Ok(None) } } fn current(&self) -> Option { if self.initialized { Some(self.calculate_rsi()) } else { None } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_rsi_calculation() { let values = vec![ 44.0, 44.25, 44.38, 44.38, 44.88, 45.05, 45.25, 45.38, 45.75, 46.03, 46.23, 46.08, 46.03, 45.85, 46.25, 46.38, 46.50 ]; let result = RSI::calculate_series(&values, 14).unwrap(); assert_eq!(result.len(), 3); // RSI should be between 0 and 100 for rsi in &result { assert!(*rsi >= 0.0 && *rsi <= 100.0); } } #[test] fn test_rsi_extremes() { // All gains - RSI should be close to 100 let increasing = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; let result = RSI::calculate_series(&increasing, 5).unwrap(); assert!(result.last().unwrap() > &95.0); // All losses - RSI should be close to 0 let decreasing = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]; let result = RSI::calculate_series(&decreasing, 5).unwrap(); assert!(result.last().unwrap() < &5.0); } }