1use std::sync::Arc;
24
25use axum::http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
26use axum::http::request::Parts;
27use axum::http::{HeaderValue, StatusCode};
28use axum::response::{IntoResponse, Response};
29use base64::Engine as _;
30use sha2::{Digest, Sha256};
31use subtle::ConstantTimeEq;
32use tracing::{debug, warn};
33
34use super::{AuthMethod, AuthResult};
35use crate::state::AppState;
36
37pub async fn try_extract_otp(
44 parts: &Parts,
45 app: &Arc<AppState>,
46) -> Option<Result<AuthResult, Response>> {
47 let auth_header = parts.headers.get(AUTHORIZATION)?.to_str().ok()?;
48
49 let credentials_b64 = auth_header.strip_prefix("Basic ")?;
51
52 let decoded = match base64::engine::general_purpose::STANDARD.decode(credentials_b64) {
53 Ok(d) => d,
54 Err(_) => {
55 return Some(Err(unauthorized_response("malformed Basic auth encoding")));
56 }
57 };
58
59 if decoded.contains(&0x00) {
61 return Some(Err(unauthorized_response(
62 "Basic auth credentials contain null byte (rejected for security)",
63 )));
64 }
65
66 let credentials = match String::from_utf8(decoded) {
68 Ok(s) => s,
69 Err(_) => {
70 return Some(Err(unauthorized_response(
71 "Basic auth credentials are not valid UTF-8 (RFC 7617 §2.1)",
72 )));
73 }
74 };
75
76 let (entity_id, otp_value) = match credentials.split_once(':') {
78 Some((u, p)) => (u.to_string(), p.to_string()),
79 None => {
80 return Some(Err(unauthorized_response(
81 "malformed Basic auth credentials (missing ':' separator, RFC 7617 §2)",
82 )));
83 }
84 };
85
86 if entity_id.is_empty() {
88 return Some(Err(unauthorized_response(
89 "entity-id must not be empty (RFC 7617 §2)",
90 )));
91 }
92
93 if otp_value.is_empty() {
94 return Some(Err(unauthorized_response("OTP value must not be empty")));
95 }
96
97 debug!(entity_id = %entity_id, "validating OTP for entity");
98
99 match validate_otp(app, &entity_id, &otp_value).await {
101 Ok(()) => {
102 Some(Ok(AuthResult {
104 identity: entity_id,
105 method: AuthMethod::Otp,
106 client_cert_der: None,
107 subject_dn: None,
108 subject_alt_names: Vec::new(),
109 extended_key_usage: Vec::new(),
110 }))
111 }
112 Err(e) => {
113 warn!(entity_id = %entity_id, error = %e, "OTP validation failed");
114
115 app.record_audit_event(
117 "otp_auth_failure",
118 &format!("entity_id={entity_id}, reason={e}"),
119 )
120 .await;
121
122 Some(Err(unauthorized_response("OTP authentication failed")))
123 }
124 }
125}
126
127fn unauthorized_response(detail: &str) -> Response {
133 let mut resp = (StatusCode::UNAUTHORIZED, detail.to_string()).into_response();
134 resp.headers_mut().insert(
135 WWW_AUTHENTICATE,
136 HeaderValue::from_static(kipuka_util::WWW_AUTHENTICATE_BASIC),
137 );
138 resp
139}
140
/// Subset of an `otp_tokens` row fetched by `validate_otp`.
///
/// Decoded via `sqlx::FromRow`, so each field name must match the column name
/// in the `SELECT` that loads it.
#[derive(sqlx::FromRow)]
struct OtpValidationRow {
    /// Row id of the token (`otp_tokens.id`).
    id: i64,
    /// Hex-encoded SHA-256 digest of the OTP value; the plaintext OTP is not stored here.
    token_hash: String,
    /// Number of uses already consumed at the time of the read.
    current_uses: i64,
    /// Maximum number of uses allowed for this token.
    max_uses: i64,
}
149
150async fn validate_otp(app: &Arc<AppState>, entity_id: &str, otp_value: &str) -> Result<(), String> {
165 let otp_config = &app.config.otp;
167 if !otp_config.enabled {
168 return Err("OTP authentication is not enabled".into());
169 }
170
171 let incoming_hash = hex::encode(Sha256::digest(otp_value.as_bytes()));
173
174 let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true);
175
176 let row: OtpValidationRow = sqlx::query_as(crate::db::pg_sql(
180 "SELECT id, token_hash, current_uses, max_uses FROM otp_tokens \
181 WHERE entity_id = ? AND token_hash = ? AND revoked = ? AND expires_at > ? AND current_uses < max_uses",
182 ))
183 .bind(entity_id)
184 .bind(&incoming_hash)
185 .bind(false)
186 .bind(&now)
187 .fetch_optional(&app.db_ro)
188 .await
189 .map_err(|e| format!("database error: {e}"))?
190 .ok_or_else(|| "no valid OTP found for this entity".to_string())?;
191
192 if incoming_hash
197 .as_bytes()
198 .ct_eq(row.token_hash.as_bytes())
199 .unwrap_u8()
200 == 0
201 {
202 return Err("no valid OTP found for this entity".to_string());
203 }
204
205 let result = sqlx::query(crate::db::pg_sql(
208 "UPDATE otp_tokens SET current_uses = current_uses + 1 \
209 WHERE entity_id = ? AND token_hash = ? AND revoked = ? AND expires_at > ? AND current_uses < max_uses",
210 ))
211 .bind(entity_id)
212 .bind(&incoming_hash)
213 .bind(false)
214 .bind(&now)
215 .execute(&app.db)
216 .await
217 .map_err(|e| format!("database error: {e}"))?;
218
219 if result.rows_affected() == 0 {
220 return Err("OTP was consumed by a concurrent request".to_string());
221 }
222
223 debug!(
224 entity_id = %entity_id,
225 otp_id = row.id,
226 current_uses = row.current_uses + 1,
227 max_uses = row.max_uses,
228 "OTP validated and consumed"
229 );
230
231 Ok(())
232}