1use std::sync::atomic::{AtomicUsize, Ordering};
7
8use serde::{Deserialize, Serialize};
9
10use super::pool::{CaConnection, CaStatus};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum FailoverStrategy {
16 ActivePassive,
19
20 RoundRobin,
23
24 Weighted,
27
28 LatencyBased,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum FallbackBehavior {
37 Reject,
39 QueueAndRetry,
42}
43
44pub struct StrategySelector {
46 strategy: FailoverStrategy,
47 rr_counter: AtomicUsize,
49}
50
51impl StrategySelector {
52 pub fn new(strategy: FailoverStrategy) -> Self {
54 Self {
55 strategy,
56 rr_counter: AtomicUsize::new(0),
57 }
58 }
59
60 pub fn select(&self, candidates: &[(&CaConnection, &CaStatus)]) -> Option<CaConnection> {
65 if candidates.is_empty() {
66 return None;
67 }
68
69 match &self.strategy {
70 FailoverStrategy::ActivePassive => self.select_active_passive(candidates),
71 FailoverStrategy::RoundRobin => self.select_round_robin(candidates),
72 FailoverStrategy::Weighted => self.select_weighted(candidates),
73 FailoverStrategy::LatencyBased => self.select_latency_based(candidates),
74 }
75 }
76
77 fn select_active_passive(
80 &self,
81 candidates: &[(&CaConnection, &CaStatus)],
82 ) -> Option<CaConnection> {
83 candidates
84 .iter()
85 .min_by_key(|(conn, _)| conn.priority)
86 .map(|(conn, _)| (*conn).clone())
87 }
88
89 fn select_round_robin(
91 &self,
92 candidates: &[(&CaConnection, &CaStatus)],
93 ) -> Option<CaConnection> {
94 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
95 candidates.get(idx).map(|(conn, _)| (*conn).clone())
96 }
97
98 fn select_weighted(&self, candidates: &[(&CaConnection, &CaStatus)]) -> Option<CaConnection> {
104 let total_weight: u32 = candidates.iter().map(|(c, _)| c.weight).sum();
105 if total_weight == 0 {
106 return self.select_round_robin(candidates);
107 }
108
109 let tick = self.rr_counter.fetch_add(1, Ordering::Relaxed) as u32 % total_weight;
110 let mut cumulative = 0u32;
111 for (conn, _) in candidates {
112 cumulative += conn.weight;
113 if tick < cumulative {
114 return Some((*conn).clone());
115 }
116 }
117
118 candidates.last().map(|(c, _)| (*c).clone())
120 }
121
122 fn select_latency_based(
124 &self,
125 candidates: &[(&CaConnection, &CaStatus)],
126 ) -> Option<CaConnection> {
127 candidates
128 .iter()
129 .min_by(|(_, a), (_, b)| {
130 a.latency_ema_ms
131 .partial_cmp(&b.latency_ema_ms)
132 .unwrap_or(std::cmp::Ordering::Equal)
133 })
134 .map(|(conn, _)| (*conn).clone())
135 }
136}