1use crate::error::{HsmError, HsmResult};
4use crate::providers::HsmProviderConfig;
5use crate::slot::HsmSlot;
6use cryptoki::mechanism::{Mechanism, MechanismType};
7use cryptoki::object::{Attribute, ObjectClass, ObjectHandle};
8use cryptoki::session::Session;
9use cryptoki::types::Ulong;
10use serde::Deserialize;
11use std::collections::HashMap;
12use url::Url;
13
/// Numeric PKCS#11 mechanism IDs used to address post-quantum algorithms on an HSM.
///
/// Every field is optional. Each one carries `#[serde(default)]`, so a field that is
/// absent from the deserialized input becomes `None` (the `Option` default), which is
/// not the same as the values produced by this type's `Default` implementation.
#[derive(Debug, Clone, Deserialize)]
pub struct PqcMechanismIds {
    /// Mechanism ID for ML-DSA key-pair generation. ML-DSA generation reports
    /// `PqcNotSupported` when this is `None`.
    #[serde(default)]
    pub ml_dsa_keygen: Option<u64>,

    /// Presumably the mechanism ID for ML-DSA-44 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_dsa_44: Option<u64>,

    /// Presumably the mechanism ID for ML-DSA-65 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_dsa_65: Option<u64>,

    /// Presumably the mechanism ID for ML-DSA-87 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_dsa_87: Option<u64>,

    /// Mechanism ID for ML-KEM key-pair generation. ML-KEM generation reports
    /// `PqcNotSupported` when this is `None`.
    #[serde(default)]
    pub ml_kem_keygen: Option<u64>,

    /// Presumably the mechanism ID for ML-KEM-512 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_kem_512: Option<u64>,

    /// Presumably the mechanism ID for ML-KEM-768 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_kem_768: Option<u64>,

    /// Presumably the mechanism ID for ML-KEM-1024 operations; not read by the code in
    /// this file. TODO confirm against the consumer.
    #[serde(default)]
    pub ml_kem_1024: Option<u64>,
}
52
53impl Default for PqcMechanismIds {
54 fn default() -> Self {
55 Self {
58 ml_dsa_keygen: Some(0x8000_0001),
59 ml_dsa_44: Some(0x8000_0002),
60 ml_dsa_65: Some(0x8000_0003),
61 ml_dsa_87: Some(0x8000_0004),
62 ml_kem_keygen: Some(0x8000_0010),
63 ml_kem_512: Some(0x8000_0011),
64 ml_kem_768: Some(0x8000_0012),
65 ml_kem_1024: Some(0x8000_0013),
66 }
67 }
68}
69
/// Selects which kind of key pair an [`HsmKeyPair`] represents or should generate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyAlgorithm {
    /// RSA; the payload is passed on as the modulus size in bits when generating a key.
    Rsa(u32),
    /// ECDSA over the given named curve.
    Ecdsa(EcdsaCurve),
    /// ML-DSA at the given level.
    MlDsa(MlDsaLevel),
    /// ML-KEM at the given level.
    MlKem(MlKemLevel),
}
82
/// Named elliptic curves that can be requested for ECDSA key generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EcdsaCurve {
    /// Curve with OID 1.2.840.10045.3.1.7.
    P256,
    /// Curve with OID 1.3.132.0.34.
    P384,
    /// Curve with OID 1.3.132.0.35.
    P521,
}

impl EcdsaCurve {
    /// Returns the DER encoding (tag `0x06`, length, content) of the curve's object
    /// identifier, in the form used for the `EcParams` key-generation attribute.
    ///
    /// The bytes are compile-time constants, so the slice is `'static` and may outlive
    /// the `EcdsaCurve` value it was obtained from.
    pub fn oid(&self) -> &'static [u8] {
        match self {
            // 1.2.840.10045.3.1.7
            Self::P256 => &[0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07],
            // 1.3.132.0.34
            Self::P384 => &[0x06, 0x05, 0x2b, 0x81, 0x04, 0x00, 0x22],
            // 1.3.132.0.35
            Self::P521 => &[0x06, 0x05, 0x2b, 0x81, 0x04, 0x00, 0x23],
        }
    }
}
101
/// Level selector for ML-DSA keys.
///
/// NOTE(review): the variant names look like NIST security categories 2, 3 and 5
/// (ML-DSA-44 / ML-DSA-65 / ML-DSA-87) -- confirm; nothing in this file maps them yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlDsaLevel {
    /// Presumably security category 2 -- TODO confirm.
    L2,
    /// Presumably security category 3 -- TODO confirm.
    L3,
    /// Presumably security category 5 -- TODO confirm.
    L5,
}
112
/// Level selector for ML-KEM keys.
///
/// NOTE(review): the variant names look like NIST security categories 1, 3 and 5
/// (ML-KEM-512 / ML-KEM-768 / ML-KEM-1024) -- confirm; nothing in this file maps them yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlKemLevel {
    /// Presumably security category 1 -- TODO confirm.
    L1,
    /// Presumably security category 3 -- TODO confirm.
    L3,
    /// Presumably security category 5 -- TODO confirm.
    L5,
}
123
/// Handles to a private/public key-object pair on an HSM, bundled with the session
/// through which those handles were obtained.
pub struct HsmKeyPair {
    /// Session that produced the object handles below; owned by this value.
    session: Session,
    /// Handle of the private-key object.
    private_key: ObjectHandle,
    /// Handle of the public-key object.
    public_key: ObjectHandle,
    /// Algorithm supplied by the caller when the pair was created or looked up; the
    /// lookup constructors store it as given without checking it against the HSM objects.
    algorithm: KeyAlgorithm,
}
135
136impl HsmKeyPair {
137 pub fn generate(
158 slot: &HsmSlot,
159 algorithm: KeyAlgorithm,
160 label: &str,
161 id: &[u8],
162 provider_config: &HsmProviderConfig,
163 pqc_mechanisms: &PqcMechanismIds,
164 ) -> HsmResult<Self> {
165 let session = slot.open_rw_session()?;
166
167 let (private_key, public_key) = match algorithm {
168 KeyAlgorithm::Rsa(bits) => {
169 Self::generate_rsa(&session, bits, label, id, provider_config)?
170 }
171 KeyAlgorithm::Ecdsa(curve) => {
172 Self::generate_ecdsa(&session, curve, label, id, provider_config)?
173 }
174 KeyAlgorithm::MlDsa(level) => {
175 Self::generate_ml_dsa(&session, level, label, id, provider_config, pqc_mechanisms)?
176 }
177 KeyAlgorithm::MlKem(level) => {
178 Self::generate_ml_kem(&session, level, label, id, provider_config, pqc_mechanisms)?
179 }
180 };
181
182 Ok(Self {
183 session,
184 private_key,
185 public_key,
186 algorithm,
187 })
188 }
189
190 fn generate_rsa(
192 session: &Session,
193 bits: u32,
194 label: &str,
195 id: &[u8],
196 config: &HsmProviderConfig,
197 ) -> HsmResult<(ObjectHandle, ObjectHandle)> {
198 if !config
199 .supported_mechanisms
200 .contains(&MechanismType::RSA_PKCS_KEY_PAIR_GEN)
201 {
202 return Err(HsmError::UnsupportedMechanism(
203 "RSA key generation not supported by HSM".to_string(),
204 ));
205 }
206
207 let mechanism = Mechanism::RsaPkcsKeyPairGen;
208
209 let public_key_template = vec![
210 Attribute::Token(true),
211 Attribute::Label(label.as_bytes().to_vec()),
212 Attribute::Id(id.to_vec()),
213 Attribute::Encrypt(true),
214 Attribute::Verify(true),
215 Attribute::ModulusBits(Ulong::from(bits as u64)),
216 Attribute::PublicExponent(vec![0x01, 0x00, 0x01]), ];
218
219 let private_key_template = vec![
220 Attribute::Token(true),
221 Attribute::Label(label.as_bytes().to_vec()),
222 Attribute::Id(id.to_vec()),
223 Attribute::Private(true),
224 Attribute::Sensitive(true), Attribute::Extractable(false), Attribute::Decrypt(true),
227 Attribute::Sign(true),
228 ];
229
230 session
231 .generate_key_pair(&mechanism, &public_key_template, &private_key_template)
232 .map_err(|e| HsmError::KeyGeneration(format!("RSA key generation failed: {e}")))
233 }
234
235 fn generate_ecdsa(
237 session: &Session,
238 curve: EcdsaCurve,
239 label: &str,
240 id: &[u8],
241 config: &HsmProviderConfig,
242 ) -> HsmResult<(ObjectHandle, ObjectHandle)> {
243 if !config
244 .supported_mechanisms
245 .contains(&MechanismType::ECC_KEY_PAIR_GEN)
246 {
247 return Err(HsmError::UnsupportedMechanism(
248 "ECDSA key generation not supported by HSM".to_string(),
249 ));
250 }
251
252 let mechanism = Mechanism::EccKeyPairGen;
253
254 let public_key_template = vec![
255 Attribute::Token(true),
256 Attribute::Label(label.as_bytes().to_vec()),
257 Attribute::Id(id.to_vec()),
258 Attribute::Verify(true),
259 Attribute::EcParams(curve.oid().to_vec()),
260 ];
261
262 let private_key_template = vec![
263 Attribute::Token(true),
264 Attribute::Label(label.as_bytes().to_vec()),
265 Attribute::Id(id.to_vec()),
266 Attribute::Private(true),
267 Attribute::Sensitive(true), Attribute::Extractable(false), Attribute::Sign(true),
270 ];
271
272 session
273 .generate_key_pair(&mechanism, &public_key_template, &private_key_template)
274 .map_err(|e| HsmError::KeyGeneration(format!("ECDSA key generation failed: {e}")))
275 }
276
277 fn generate_ml_dsa(
279 _session: &Session,
280 _level: MlDsaLevel,
281 _label: &str,
282 _id: &[u8],
283 _config: &HsmProviderConfig,
284 pqc_mechanisms: &PqcMechanismIds,
285 ) -> HsmResult<(ObjectHandle, ObjectHandle)> {
286 let mechanism_id = pqc_mechanisms.ml_dsa_keygen.ok_or_else(|| {
287 HsmError::PqcNotSupported("ML-DSA mechanism ID not configured".to_string())
288 })?;
289
290 tracing::warn!(
293 "Attempting ML-DSA key generation with vendor mechanism ID 0x{:08x}",
294 mechanism_id
295 );
296
297 Err(HsmError::PqcNotSupported(
300 "ML-DSA key generation requires vendor-specific PKCS#11 extensions not available in cryptoki 0.7".to_string()
301 ))
302 }
303
304 fn generate_ml_kem(
306 _session: &Session,
307 _level: MlKemLevel,
308 _label: &str,
309 _id: &[u8],
310 _config: &HsmProviderConfig,
311 pqc_mechanisms: &PqcMechanismIds,
312 ) -> HsmResult<(ObjectHandle, ObjectHandle)> {
313 let mechanism_id = pqc_mechanisms.ml_kem_keygen.ok_or_else(|| {
314 HsmError::PqcNotSupported("ML-KEM mechanism ID not configured".to_string())
315 })?;
316
317 tracing::warn!(
318 "Attempting ML-KEM key generation with vendor mechanism ID 0x{:08x}",
319 mechanism_id
320 );
321
322 Err(HsmError::PqcNotSupported(
325 "ML-KEM key generation requires vendor-specific PKCS#11 extensions not available in cryptoki 0.7".to_string()
326 ))
327 }
328
329 pub fn find_by_label(slot: &HsmSlot, label: &str, algorithm: KeyAlgorithm) -> HsmResult<Self> {
331 let session = slot.open_ro_session()?;
332
333 let template = vec![
334 Attribute::Label(label.as_bytes().to_vec()),
335 Attribute::Class(ObjectClass::PRIVATE_KEY),
336 ];
337
338 session.find_objects(&template).map_err(|e| {
339 HsmError::KeyNotFound(format!("Failed to search for key '{label}': {e}"))
340 })?;
341
342 let private_key = session
343 .find_objects(&template)
344 .map_err(|e| HsmError::KeyNotFound(format!("Find operation failed: {e}")))?
345 .into_iter()
346 .next()
347 .ok_or_else(|| HsmError::KeyNotFound(format!("Key '{label}' not found")))?;
348
349 let public_template = vec![
351 Attribute::Label(label.as_bytes().to_vec()),
352 Attribute::Class(ObjectClass::PUBLIC_KEY),
353 ];
354
355 let public_key = session
356 .find_objects(&public_template)
357 .map_err(|e| HsmError::KeyNotFound(format!("Public key search failed: {e}")))?
358 .into_iter()
359 .next()
360 .ok_or_else(|| HsmError::KeyNotFound(format!("Public key '{label}' not found")))?;
361
362 Ok(Self {
363 session,
364 private_key,
365 public_key,
366 algorithm,
367 })
368 }
369
370 pub fn find_by_id(slot: &HsmSlot, id: &[u8], algorithm: KeyAlgorithm) -> HsmResult<Self> {
372 let session = slot.open_ro_session()?;
373
374 let template = vec![
375 Attribute::Id(id.to_vec()),
376 Attribute::Class(ObjectClass::PRIVATE_KEY),
377 ];
378
379 let private_key = session
380 .find_objects(&template)
381 .map_err(|e| HsmError::KeyNotFound(format!("Find operation failed: {e}")))?
382 .into_iter()
383 .next()
384 .ok_or_else(|| {
385 HsmError::KeyNotFound(format!("Key with ID {} not found", hex::encode(id)))
386 })?;
387
388 let public_template = vec![
389 Attribute::Id(id.to_vec()),
390 Attribute::Class(ObjectClass::PUBLIC_KEY),
391 ];
392
393 let public_key = session
394 .find_objects(&public_template)
395 .map_err(|e| HsmError::KeyNotFound(format!("Public key search failed: {e}")))?
396 .into_iter()
397 .next()
398 .ok_or_else(|| {
399 HsmError::KeyNotFound(format!("Public key with ID {} not found", hex::encode(id)))
400 })?;
401
402 Ok(Self {
403 session,
404 private_key,
405 public_key,
406 algorithm,
407 })
408 }
409
410 pub fn from_uri(slot: &HsmSlot, uri: &str, algorithm: KeyAlgorithm) -> HsmResult<Self> {
422 let url = Url::parse(uri).map_err(|e| HsmError::UriParse(e.to_string()))?;
423
424 if url.scheme() != "pkcs11" {
425 return Err(HsmError::UriParse(format!(
426 "Invalid scheme '{}', expected 'pkcs11'",
427 url.scheme()
428 )));
429 }
430
431 let params: HashMap<String, String> = url
433 .query_pairs()
434 .map(|(k, v)| (k.to_string(), v.to_string()))
435 .collect();
436
437 if let Some(id_hex) = params.get("id") {
439 let id = hex::decode(id_hex)
440 .map_err(|e| HsmError::UriParse(format!("Invalid hex ID '{id_hex}': {e}")))?;
441 return Self::find_by_id(slot, &id, algorithm);
442 }
443
444 if let Some(label) = params.get("object") {
446 return Self::find_by_label(slot, label, algorithm);
447 }
448
449 Err(HsmError::UriParse(
450 "URI must contain 'id' or 'object' attribute".to_string(),
451 ))
452 }
453
454 pub fn private_key(&self) -> ObjectHandle {
456 self.private_key
457 }
458
459 pub fn public_key(&self) -> ObjectHandle {
461 self.public_key
462 }
463
464 pub fn session(&self) -> &Session {
466 &self.session
467 }
468
469 pub fn algorithm(&self) -> KeyAlgorithm {
471 self.algorithm
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Each curve's encoded OID keeps its expected total length (tag + length + content).
    #[test]
    fn test_ecdsa_curve_oids() {
        let expected_lengths = [
            (EcdsaCurve::P256, 10_usize),
            (EcdsaCurve::P384, 7),
            (EcdsaCurve::P521, 7),
        ];

        for &(curve, expected_len) in &expected_lengths {
            assert_eq!(curve.oid().len(), expected_len);
        }
    }

    /// The default configuration populates both key-generation mechanism IDs.
    #[test]
    fn test_pqc_mechanism_ids_default() {
        let PqcMechanismIds {
            ml_dsa_keygen,
            ml_kem_keygen,
            ..
        } = PqcMechanismIds::default();

        assert!(ml_dsa_keygen.is_some());
        assert!(ml_kem_keygen.is_some());
    }

    /// A `pkcs11:` URI with `;`-separated attributes parses and reports its scheme.
    #[test]
    fn test_pkcs11_uri_parsing() {
        let parsed = Url::parse("pkcs11:token=MyToken;object=MyKey;type=private").unwrap();
        assert_eq!(parsed.scheme(), "pkcs11");
    }
}