1use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::time::{Duration, Instant};
21
/// DTLS protocol versions that a session can be tagged with.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DtlsVersion {
    /// DTLS 1.2.
    V1_2,
    /// DTLS 1.3.
    V1_3,
}

impl DtlsVersion {
    /// Returns the human-readable label for this version
    /// (`"DTLS 1.2"` or `"DTLS 1.3"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            DtlsVersion::V1_2 => "DTLS 1.2",
            DtlsVersion::V1_3 => "DTLS 1.3",
        }
    }
}
44
45#[derive(Debug, Clone)]
54pub struct DtlsSession {
55 session_id: Vec<u8>,
57 peer_addr: SocketAddr,
59 client_cert: Option<Vec<u8>>,
64 created_at: Instant,
66 protocol_version: DtlsVersion,
68}
69
70impl DtlsSession {
71 pub fn new(session_id: Vec<u8>, peer_addr: SocketAddr, protocol_version: DtlsVersion) -> Self {
73 Self {
74 session_id,
75 peer_addr,
76 client_cert: None,
77 created_at: Instant::now(),
78 protocol_version,
79 }
80 }
81
82 pub fn with_client_cert(
84 session_id: Vec<u8>,
85 peer_addr: SocketAddr,
86 protocol_version: DtlsVersion,
87 client_cert_der: Vec<u8>,
88 ) -> Self {
89 Self {
90 session_id,
91 peer_addr,
92 client_cert: Some(client_cert_der),
93 created_at: Instant::now(),
94 protocol_version,
95 }
96 }
97
98 pub fn session_id(&self) -> &[u8] {
100 &self.session_id
101 }
102
103 pub fn peer_addr(&self) -> SocketAddr {
105 self.peer_addr
106 }
107
108 pub fn client_cert(&self) -> Option<&[u8]> {
110 self.client_cert.as_deref()
111 }
112
113 pub fn created_at(&self) -> Instant {
115 self.created_at
116 }
117
118 pub fn protocol_version(&self) -> DtlsVersion {
120 self.protocol_version
121 }
122
123 pub fn is_expired(&self, ttl: Duration) -> bool {
125 self.created_at.elapsed() > ttl
126 }
127
128 pub fn client_cert_info(&self) -> Option<ClientCertInfo> {
133 let der = self.client_cert.as_ref()?;
134 Some(ClientCertInfo {
137 subject_dn: String::new(),
138 serial: Vec::new(),
139 der_bytes: der.clone(),
140 })
141 }
142}
143
/// Details about a client certificate, as returned by
/// `DtlsSession::client_cert_info`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClientCertInfo {
    /// Subject distinguished name, going by the field name.
    /// NOTE(review): `DtlsSession::client_cert_info` currently leaves this empty.
    pub subject_dn: String,
    /// Certificate serial number bytes, going by the field name.
    /// NOTE(review): `DtlsSession::client_cert_info` currently leaves this empty.
    pub serial: Vec<u8>,
    /// Copy of the certificate bytes held by the session.
    pub der_bytes: Vec<u8>,
}
159
160#[derive(Debug)]
171pub struct DtlsSessionCache {
172 sessions: HashMap<SocketAddr, DtlsSession>,
174 max_sessions: usize,
176 ttl: Duration,
178}
179
180impl DtlsSessionCache {
181 pub fn new(max_sessions: usize, ttl: Duration) -> Self {
189 Self {
190 sessions: HashMap::with_capacity(max_sessions),
191 max_sessions,
192 ttl,
193 }
194 }
195
196 pub fn insert(&mut self, session: DtlsSession) {
201 if self.sessions.len() >= self.max_sessions
202 && !self.sessions.contains_key(&session.peer_addr)
203 {
204 self.cleanup_expired();
205
206 if self.sessions.len() >= self.max_sessions {
208 if let Some(oldest_addr) = self.oldest_session_addr() {
209 self.sessions.remove(&oldest_addr);
210 }
211 }
212 }
213
214 self.sessions.insert(session.peer_addr, session);
215 }
216
217 pub fn get(&mut self, peer_addr: &SocketAddr) -> Option<&DtlsSession> {
222 if let Some(session) = self.sessions.get(peer_addr) {
224 if session.is_expired(self.ttl) {
225 self.sessions.remove(peer_addr);
226 return None;
227 }
228 }
229
230 self.sessions.get(peer_addr)
231 }
232
233 pub fn remove(&mut self, peer_addr: &SocketAddr) -> Option<DtlsSession> {
238 self.sessions.remove(peer_addr)
239 }
240
241 pub fn cleanup_expired(&mut self) -> usize {
245 let ttl = self.ttl;
246 let before = self.sessions.len();
247 self.sessions.retain(|_, session| !session.is_expired(ttl));
248 before - self.sessions.len()
249 }
250
251 pub fn len(&self) -> usize {
253 self.sessions.len()
254 }
255
256 pub fn is_empty(&self) -> bool {
258 self.sessions.is_empty()
259 }
260
261 pub fn max_sessions(&self) -> usize {
263 self.max_sessions
264 }
265
266 pub fn ttl(&self) -> Duration {
268 self.ttl
269 }
270
271 fn oldest_session_addr(&self) -> Option<SocketAddr> {
273 self.sessions
274 .iter()
275 .min_by_key(|(_, session)| session.created_at)
276 .map(|(addr, _)| *addr)
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// TTL long enough that nothing expires during a test run.
    const LONG_TTL: Duration = Duration::from_secs(3600);

    /// Builds an IPv4 peer address for `port`; the final octet is the low
    /// byte of the port (e.g. 5683 -> 51).
    fn test_addr(port: u16) -> SocketAddr {
        let [low_byte, _high_byte] = port.to_le_bytes();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, low_byte));
        SocketAddr::new(ip, port)
    }

    /// Builds a DTLS 1.3 session whose id is the single byte `id` and whose
    /// peer address is `test_addr(port)`.
    fn v1_3_session(id: u8, port: u16) -> DtlsSession {
        DtlsSession::new(vec![id], test_addr(port), DtlsVersion::V1_3)
    }

    #[test]
    fn test_dtls_version_as_str() {
        let expected = [
            (DtlsVersion::V1_2, "DTLS 1.2"),
            (DtlsVersion::V1_3, "DTLS 1.3"),
        ];
        for (version, label) in expected {
            assert_eq!(version.as_str(), label);
        }
    }

    #[test]
    fn test_session_creation() {
        let peer = test_addr(5683);
        let session = DtlsSession::new(vec![1, 2, 3], peer, DtlsVersion::V1_3);

        assert_eq!(session.session_id(), &[1u8, 2, 3][..]);
        assert_eq!(session.peer_addr(), peer);
        assert_eq!(session.client_cert(), None);
        assert_eq!(session.protocol_version(), DtlsVersion::V1_3);
    }

    #[test]
    fn test_session_with_client_cert() {
        let der: Vec<u8> = vec![0x30, 0x82, 0x01, 0x00];
        let session = DtlsSession::with_client_cert(
            vec![1, 2, 3],
            test_addr(5683),
            DtlsVersion::V1_2,
            der.clone(),
        );

        assert_eq!(session.client_cert(), Some(&der[..]));
        assert!(session.client_cert_info().is_some());
    }

    #[test]
    fn test_session_expiry() {
        let session = v1_3_session(1, 5683);

        // A fresh session is well inside a one-hour TTL...
        assert!(!session.is_expired(LONG_TTL));
        // ...and a zero TTL is expected to report it as expired.
        assert!(session.is_expired(Duration::ZERO));
    }

    #[test]
    fn test_cache_insert_and_get() {
        let peer = test_addr(5683);
        let mut cache = DtlsSessionCache::new(10, LONG_TTL);

        cache.insert(DtlsSession::new(vec![1, 2, 3], peer, DtlsVersion::V1_3));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        let found = cache.get(&peer).expect("session should be retrievable");
        assert_eq!(found.session_id(), &[1u8, 2, 3][..]);
    }

    #[test]
    fn test_cache_remove() {
        let peer = test_addr(5683);
        let mut cache = DtlsSessionCache::new(10, LONG_TTL);

        cache.insert(v1_3_session(1, 5683));
        assert_eq!(cache.len(), 1);

        assert!(cache.remove(&peer).is_some());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_eviction_on_capacity() {
        let mut cache = DtlsSessionCache::new(2, LONG_TTL);

        for id in 1..=2u8 {
            cache.insert(v1_3_session(id, u16::from(id)));
        }
        assert_eq!(cache.len(), 2);

        // A third peer must push out the oldest entry rather than grow the cache.
        cache.insert(v1_3_session(3, 3));
        assert_eq!(cache.len(), 2);

        assert!(cache.get(&test_addr(1)).is_none());
        assert!(cache.get(&test_addr(3)).is_some());
    }

    #[test]
    fn test_cache_expired_not_returned() {
        let peer = test_addr(5683);
        let mut cache = DtlsSessionCache::new(10, Duration::ZERO);

        cache.insert(v1_3_session(1, 5683));

        // The lookup both hides and purges the expired entry.
        assert!(cache.get(&peer).is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_cleanup_expired() {
        let mut cache = DtlsSessionCache::new(10, Duration::ZERO);
        for id in 1..=2u8 {
            cache.insert(v1_3_session(id, u16::from(id)));
        }

        let purged = cache.cleanup_expired();
        assert_eq!(purged, 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_update_existing() {
        let peer = test_addr(5683);
        let mut cache = DtlsSessionCache::new(10, LONG_TTL);

        cache.insert(DtlsSession::new(vec![1], peer, DtlsVersion::V1_2));
        cache.insert(DtlsSession::new(vec![2], peer, DtlsVersion::V1_3));

        // The second insert replaces the first instead of adding an entry.
        assert_eq!(cache.len(), 1);
        let current = cache.get(&peer).unwrap();
        assert_eq!(current.session_id(), &[2u8][..]);
        assert_eq!(current.protocol_version(), DtlsVersion::V1_3);
    }
}