223 lines
No EOL
7.2 KiB
Rust
223 lines
No EOL
7.2 KiB
Rust
use super::{Indicator, IncrementalIndicator, IndicatorResult, IndicatorError, PriceData};
|
|
use super::common::RollingWindow;
|
|
|
|
/// Relative Strength Index (RSI) Indicator
|
|
///
|
|
/// Measures momentum by comparing the magnitude of recent gains to recent losses
|
|
/// RSI = 100 - (100 / (1 + RS))
|
|
/// where RS = Average Gain / Average Loss
|
|
pub struct RSI {
|
|
period: usize,
|
|
avg_gain: f64,
|
|
avg_loss: f64,
|
|
prev_value: Option<f64>,
|
|
window: RollingWindow<f64>,
|
|
initialized: bool,
|
|
}
|
|
|
|
impl RSI {
|
|
pub fn new(period: usize) -> Result<Self, IndicatorError> {
|
|
if period == 0 {
|
|
return Err(IndicatorError::InvalidParameter(
|
|
"Period must be greater than 0".to_string()
|
|
));
|
|
}
|
|
|
|
Ok(Self {
|
|
period,
|
|
avg_gain: 0.0,
|
|
avg_loss: 0.0,
|
|
prev_value: None,
|
|
window: RollingWindow::new(period + 1),
|
|
initialized: false,
|
|
})
|
|
}
|
|
|
|
/// Calculate RSI for a series of values
|
|
pub fn calculate_series(values: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
|
|
if period == 0 {
|
|
return Err(IndicatorError::InvalidParameter(
|
|
"Period must be greater than 0".to_string()
|
|
));
|
|
}
|
|
|
|
if values.len() <= period {
|
|
return Err(IndicatorError::InsufficientData {
|
|
required: period + 1,
|
|
actual: values.len(),
|
|
});
|
|
}
|
|
|
|
let mut result = Vec::with_capacity(values.len() - period);
|
|
let mut gains = Vec::with_capacity(values.len() - 1);
|
|
let mut losses = Vec::with_capacity(values.len() - 1);
|
|
|
|
// Calculate price changes
|
|
for i in 1..values.len() {
|
|
let change = values[i] - values[i - 1];
|
|
if change > 0.0 {
|
|
gains.push(change);
|
|
losses.push(0.0);
|
|
} else {
|
|
gains.push(0.0);
|
|
losses.push(-change);
|
|
}
|
|
}
|
|
|
|
// Calculate initial averages using SMA
|
|
let initial_avg_gain: f64 = gains[0..period].iter().sum::<f64>() / period as f64;
|
|
let initial_avg_loss: f64 = losses[0..period].iter().sum::<f64>() / period as f64;
|
|
|
|
// Calculate first RSI
|
|
let rs = if initial_avg_loss > 0.0 {
|
|
initial_avg_gain / initial_avg_loss
|
|
} else {
|
|
100.0 // If no losses, RSI is 100
|
|
};
|
|
result.push(100.0 - (100.0 / (1.0 + rs)));
|
|
|
|
// Calculate remaining RSIs using EMA smoothing
|
|
let mut avg_gain = initial_avg_gain;
|
|
let mut avg_loss = initial_avg_loss;
|
|
let alpha = 1.0 / period as f64;
|
|
|
|
for i in period..gains.len() {
|
|
// Wilder's smoothing method
|
|
avg_gain = (avg_gain * (period - 1) as f64 + gains[i]) / period as f64;
|
|
avg_loss = (avg_loss * (period - 1) as f64 + losses[i]) / period as f64;
|
|
|
|
let rs = if avg_loss > 0.0 {
|
|
avg_gain / avg_loss
|
|
} else {
|
|
100.0
|
|
};
|
|
result.push(100.0 - (100.0 / (1.0 + rs)));
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
fn calculate_rsi(&self) -> f64 {
|
|
if self.avg_loss == 0.0 {
|
|
100.0
|
|
} else {
|
|
let rs = self.avg_gain / self.avg_loss;
|
|
100.0 - (100.0 / (1.0 + rs))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Indicator for RSI {
|
|
fn calculate(&mut self, data: &PriceData) -> Result<IndicatorResult, IndicatorError> {
|
|
let values = &data.close;
|
|
|
|
if values.len() <= self.period {
|
|
return Err(IndicatorError::InsufficientData {
|
|
required: self.period + 1,
|
|
actual: values.len(),
|
|
});
|
|
}
|
|
|
|
let rsi_values = Self::calculate_series(values, self.period)?;
|
|
Ok(IndicatorResult::Series(rsi_values))
|
|
}
|
|
|
|
fn reset(&mut self) {
|
|
self.avg_gain = 0.0;
|
|
self.avg_loss = 0.0;
|
|
self.prev_value = None;
|
|
self.window.clear();
|
|
self.initialized = false;
|
|
}
|
|
|
|
fn is_ready(&self) -> bool {
|
|
self.initialized
|
|
}
|
|
}
|
|
|
|
impl IncrementalIndicator for RSI {
|
|
fn update(&mut self, value: f64) -> Result<Option<f64>, IndicatorError> {
|
|
self.window.push(value);
|
|
|
|
if let Some(prev) = self.prev_value {
|
|
let change = value - prev;
|
|
let gain = if change > 0.0 { change } else { 0.0 };
|
|
let loss = if change < 0.0 { -change } else { 0.0 };
|
|
|
|
if !self.initialized && self.window.len() > self.period {
|
|
// Initialize using first period values
|
|
let values = self.window.as_slice();
|
|
let mut sum_gain = 0.0;
|
|
let mut sum_loss = 0.0;
|
|
|
|
for i in 1..=self.period {
|
|
let change = values[i] - values[i - 1];
|
|
if change > 0.0 {
|
|
sum_gain += change;
|
|
} else {
|
|
sum_loss += -change;
|
|
}
|
|
}
|
|
|
|
self.avg_gain = sum_gain / self.period as f64;
|
|
self.avg_loss = sum_loss / self.period as f64;
|
|
self.initialized = true;
|
|
} else if self.initialized {
|
|
// Update using Wilder's smoothing
|
|
self.avg_gain = (self.avg_gain * (self.period - 1) as f64 + gain) / self.period as f64;
|
|
self.avg_loss = (self.avg_loss * (self.period - 1) as f64 + loss) / self.period as f64;
|
|
}
|
|
}
|
|
|
|
self.prev_value = Some(value);
|
|
|
|
if self.initialized {
|
|
Ok(Some(self.calculate_rsi()))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
fn current(&self) -> Option<f64> {
|
|
if self.initialized {
|
|
Some(self.calculate_rsi())
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Seventeen closing prices containing both up and down moves.
    const PRICES: [f64; 17] = [
        44.0, 44.25, 44.38, 44.38, 44.88, 45.05, 45.25, 45.38, 45.75, 46.03, 46.23, 46.08,
        46.03, 45.85, 46.25, 46.38, 46.50,
    ];

    /// Asserts that two floats agree to within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_rsi_calculation() {
        let result = RSI::calculate_series(&PRICES, 14).unwrap();
        assert_eq!(result.len(), 3);

        // RSI should be between 0 and 100
        for rsi in &result {
            assert!((0.0..=100.0).contains(rsi));
        }
    }

    #[test]
    fn test_rsi_extremes() {
        // All gains - RSI should be close to 100
        let increasing = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = RSI::calculate_series(&increasing, 5).unwrap();
        assert!(*result.last().unwrap() > 95.0);

        // All losses - RSI should be close to 0
        let decreasing = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = RSI::calculate_series(&decreasing, 5).unwrap();
        assert!(*result.last().unwrap() < 5.0);
    }

    #[test]
    fn test_wilder_smoothing_known_values() {
        // Changes are +1, -1, +1. The seed averages are 0.5 / 0.5 (RSI 50).
        // One smoothing step with period 2 gives gain 0.75 and loss 0.25,
        // so RS = 3 and RSI = 75.
        let result = RSI::calculate_series(&[1.0, 2.0, 1.0, 2.0], 2).unwrap();
        assert_eq!(result.len(), 2);
        assert_close(result[0], 50.0);
        assert_close(result[1], 75.0);
    }

    #[test]
    fn test_invalid_parameters() {
        assert!(matches!(RSI::new(0), Err(IndicatorError::InvalidParameter(_))));
        assert!(matches!(
            RSI::calculate_series(&PRICES, 0),
            Err(IndicatorError::InvalidParameter(_))
        ));
        assert!(matches!(
            RSI::calculate_series(&[1.0, 2.0], 2),
            Err(IndicatorError::InsufficientData { required: 3, actual: 2 })
        ));
    }

    #[test]
    fn test_incremental_matches_series() {
        let period = 14;
        let series = RSI::calculate_series(&PRICES, period).unwrap();

        let mut rsi = RSI::new(period).unwrap();
        let mut streamed = Vec::new();
        for (i, &price) in PRICES.iter().enumerate() {
            let out = rsi.update(price).unwrap();
            // Warm-up: no value until `period + 1` prices have been seen.
            assert_eq!(out.is_some(), i >= period);
            streamed.extend(out);
        }

        assert_eq!(streamed.len(), series.len());
        for (&streamed_rsi, &batch_rsi) in streamed.iter().zip(&series) {
            assert_close(streamed_rsi, batch_rsi);
        }
    }

    #[test]
    fn test_reset_restarts_warm_up() {
        let mut rsi = RSI::new(3).unwrap();
        for price in [1.0, 2.0, 1.5, 2.5, 3.0] {
            rsi.update(price).unwrap();
        }
        assert!(Indicator::is_ready(&rsi));

        Indicator::reset(&mut rsi);
        assert!(!Indicator::is_ready(&rsi));
        assert!(IncrementalIndicator::current(&rsi).is_none());
        assert!(rsi.update(1.0).unwrap().is_none());
    }
}