1use anyhow::{Context, Result, anyhow};
6use const_oid::ObjectIdentifier;
7use der::{Encode, EncodePem};
8use ml_dsa::{
9 EncodedSigningKey, EncodedVerifyingKey, MlDsa44, MlDsa65, MlDsa87, SigningKey, VerifyingKey,
10};
11use pem_rfc7468;
12use pkcs8::{DecodePrivateKey, PrivateKeyInfo};
13use spki::{AssociatedAlgorithmIdentifier, DecodePublicKey, EncodePublicKey};
14use std::convert::{AsRef, TryFrom};
15use std::path::Path;
16
17use super::KeyEncoding;
18use crate::error::HsmError;
19use crate::util::attribute::{AttrData, AttributeMap, AttributeType, KeyType, ObjectClass};
20
21pub enum MldsaSigningKey {
22 V44(Box<SigningKey<MlDsa44>>),
23 V65(Box<SigningKey<MlDsa65>>),
24 V87(Box<SigningKey<MlDsa87>>),
25}
26
27impl MldsaSigningKey {
28 pub fn encode(&self) -> Vec<u8> {
29 match self {
30 Self::V44(k) => k.encode().to_vec(),
31 Self::V65(k) => k.encode().to_vec(),
32 Self::V87(k) => k.encode().to_vec(),
33 }
34 }
35
36 pub fn parameter_set(&self) -> u64 {
37 match self {
38 Self::V44(_) => 1,
39 Self::V65(_) => 2,
40 Self::V87(_) => 3,
41 }
42 }
43
44 pub fn oid(&self) -> ObjectIdentifier {
45 match self {
46 Self::V44(_) => MlDsa44::ALGORITHM_IDENTIFIER.oid,
47 Self::V65(_) => MlDsa65::ALGORITHM_IDENTIFIER.oid,
48 Self::V87(_) => MlDsa87::ALGORITHM_IDENTIFIER.oid,
49 }
50 }
51}
52
53pub enum MldsaVerifyingKey {
54 V44(Box<VerifyingKey<MlDsa44>>),
55 V65(Box<VerifyingKey<MlDsa65>>),
56 V87(Box<VerifyingKey<MlDsa87>>),
57}
58
59impl MldsaVerifyingKey {
60 pub fn encode(&self) -> Vec<u8> {
61 match self {
62 Self::V44(k) => k.encode().to_vec(),
63 Self::V65(k) => k.encode().to_vec(),
64 Self::V87(k) => k.encode().to_vec(),
65 }
66 }
67
68 pub fn parameter_set(&self) -> u64 {
69 match self {
70 Self::V44(_) => 1,
71 Self::V65(_) => 2,
72 Self::V87(_) => 3,
73 }
74 }
75
76 pub fn oid(&self) -> ObjectIdentifier {
77 match self {
78 Self::V44(_) => MlDsa44::ALGORITHM_IDENTIFIER.oid,
79 Self::V65(_) => MlDsa65::ALGORITHM_IDENTIFIER.oid,
80 Self::V87(_) => MlDsa87::ALGORITHM_IDENTIFIER.oid,
81 }
82 }
83}
84
85fn _load_private_key(path: &Path) -> Result<MldsaSigningKey> {
86 let data = std::fs::read(path)?;
87 let der_bytes = if let Ok((_label, bytes)) = pem_rfc7468::decode_vec(&data) {
88 bytes
89 } else {
90 data
91 };
92
93 if let Ok(key) = SigningKey::<MlDsa44>::from_pkcs8_der(&der_bytes) {
94 Ok(MldsaSigningKey::V44(Box::new(key)))
95 } else if let Ok(key) = SigningKey::<MlDsa65>::from_pkcs8_der(&der_bytes) {
96 Ok(MldsaSigningKey::V65(Box::new(key)))
97 } else if let Ok(key) = SigningKey::<MlDsa87>::from_pkcs8_der(&der_bytes) {
98 Ok(MldsaSigningKey::V87(Box::new(key)))
99 } else {
100 Err(anyhow!(
101 "Could not decode MLDSA private key from PKCS#8 DER"
102 ))
103 }
104}
105
106pub fn load_private_key<P: AsRef<Path>>(path: P) -> Result<MldsaSigningKey> {
107 _load_private_key(path.as_ref())
108}
109
110impl TryFrom<&MldsaSigningKey> for AttributeMap {
111 type Error = HsmError;
112 fn try_from(k: &MldsaSigningKey) -> std::result::Result<Self, Self::Error> {
113 let mut attr = AttributeMap::default();
114 attr.insert(
115 AttributeType::Class,
116 AttrData::ObjectClass(ObjectClass::PrivateKey),
117 );
118 attr.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
119 attr.insert(
120 AttributeType::ParameterSet,
121 AttrData::from(k.parameter_set()),
122 );
123 attr.insert(AttributeType::Value, AttrData::from(k.encode().as_slice()));
124 Ok(attr)
125 }
126}
127
128impl TryFrom<&AttributeMap> for MldsaSigningKey {
129 type Error = HsmError;
130 fn try_from(a: &AttributeMap) -> std::result::Result<Self, Self::Error> {
131 let class: ObjectClass = a
132 .get(&AttributeType::Class)
133 .ok_or_else(|| HsmError::KeyError("missing class".into()))?
134 .try_into()
135 .map_err(HsmError::AttributeError)?;
136 let key_type: KeyType = a
137 .get(&AttributeType::KeyType)
138 .ok_or_else(|| HsmError::KeyError("missing key_type".into()))?
139 .try_into()
140 .map_err(HsmError::AttributeError)?;
141 if class != ObjectClass::PrivateKey || key_type != KeyType::MlDsa {
142 return Err(HsmError::KeyError("bad key type".into()));
143 }
144
145 let value: Vec<u8> = a
146 .get(&AttributeType::Value)
147 .ok_or_else(|| HsmError::KeyError("missing value".into()))?
148 .try_into()
149 .map_err(HsmError::AttributeError)?;
150
151 let parameter_set = a
153 .get(&AttributeType::ParameterSet)
154 .and_then(|d| u64::try_from(d).ok());
155
156 match parameter_set {
157 Some(1) => {
158 let arr = EncodedSigningKey::<MlDsa44>::try_from(value.as_slice())
159 .map_err(|_| HsmError::KeyError("invalid MLDSA-44 key length".into()))?;
160 Ok(MldsaSigningKey::V44(Box::new(
161 SigningKey::<MlDsa44>::decode(&arr),
162 )))
163 }
164 Some(2) => {
165 let arr = EncodedSigningKey::<MlDsa65>::try_from(value.as_slice())
166 .map_err(|_| HsmError::KeyError("invalid MLDSA-65 key length".into()))?;
167 Ok(MldsaSigningKey::V65(Box::new(
168 SigningKey::<MlDsa65>::decode(&arr),
169 )))
170 }
171 Some(3) => {
172 let arr = EncodedSigningKey::<MlDsa87>::try_from(value.as_slice())
173 .map_err(|_| HsmError::KeyError("invalid MLDSA-87 key length".into()))?;
174 Ok(MldsaSigningKey::V87(Box::new(
175 SigningKey::<MlDsa87>::decode(&arr),
176 )))
177 }
178 _ => {
179 if let Ok(arr) = EncodedSigningKey::<MlDsa44>::try_from(value.as_slice()) {
181 Ok(MldsaSigningKey::V44(Box::new(
182 SigningKey::<MlDsa44>::decode(&arr),
183 )))
184 } else if let Ok(arr) = EncodedSigningKey::<MlDsa65>::try_from(value.as_slice()) {
185 Ok(MldsaSigningKey::V65(Box::new(
186 SigningKey::<MlDsa65>::decode(&arr),
187 )))
188 } else if let Ok(arr) = EncodedSigningKey::<MlDsa87>::try_from(value.as_slice()) {
189 Ok(MldsaSigningKey::V87(Box::new(
190 SigningKey::<MlDsa87>::decode(&arr),
191 )))
192 } else {
193 Err(HsmError::KeyError(
194 "Could not decode MLDSA private key".into(),
195 ))
196 }
197 }
198 }
199 }
200}
201
202fn _load_public_key(path: &Path) -> Result<MldsaVerifyingKey> {
203 let data = std::fs::read(path)?;
204 let der_bytes = if let Ok((_label, bytes)) = pem_rfc7468::decode_vec(&data) {
205 bytes
206 } else {
207 data
208 };
209
210 if let Ok(key) = VerifyingKey::<MlDsa44>::from_public_key_der(&der_bytes) {
211 Ok(MldsaVerifyingKey::V44(Box::new(key)))
212 } else if let Ok(key) = VerifyingKey::<MlDsa65>::from_public_key_der(&der_bytes) {
213 Ok(MldsaVerifyingKey::V65(Box::new(key)))
214 } else if let Ok(key) = VerifyingKey::<MlDsa87>::from_public_key_der(&der_bytes) {
215 Ok(MldsaVerifyingKey::V87(Box::new(key)))
216 } else {
217 Err(anyhow!("Could not decode MLDSA public key from SPKI DER"))
218 }
219}
220
221pub fn load_public_key<P: AsRef<Path>>(path: P) -> Result<MldsaVerifyingKey> {
222 _load_public_key(path.as_ref())
223}
224
225impl TryFrom<&MldsaVerifyingKey> for AttributeMap {
226 type Error = HsmError;
227 fn try_from(k: &MldsaVerifyingKey) -> std::result::Result<Self, Self::Error> {
228 let mut attr = AttributeMap::default();
229 attr.insert(
230 AttributeType::Class,
231 AttrData::ObjectClass(ObjectClass::PublicKey),
232 );
233 attr.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
234 attr.insert(
235 AttributeType::ParameterSet,
236 AttrData::from(k.parameter_set()),
237 );
238 attr.insert(AttributeType::Value, AttrData::from(k.encode().as_slice()));
239 Ok(attr)
240 }
241}
242
243impl TryFrom<&AttributeMap> for MldsaVerifyingKey {
244 type Error = HsmError;
245 fn try_from(a: &AttributeMap) -> std::result::Result<Self, Self::Error> {
246 let class: ObjectClass = a
247 .get(&AttributeType::Class)
248 .ok_or_else(|| HsmError::KeyError("missing class".into()))?
249 .try_into()
250 .map_err(HsmError::AttributeError)?;
251 let key_type: KeyType = a
252 .get(&AttributeType::KeyType)
253 .ok_or_else(|| HsmError::KeyError("missing key_type".into()))?
254 .try_into()
255 .map_err(HsmError::AttributeError)?;
256 if class != ObjectClass::PublicKey || key_type != KeyType::MlDsa {
257 return Err(HsmError::KeyError("bad key type".into()));
258 }
259
260 let value: Vec<u8> = a
261 .get(&AttributeType::Value)
262 .ok_or_else(|| HsmError::KeyError("missing value".into()))?
263 .try_into()
264 .map_err(HsmError::AttributeError)?;
265
266 let parameter_set = a
267 .get(&AttributeType::ParameterSet)
268 .and_then(|d| u64::try_from(d).ok());
269
270 match parameter_set {
271 Some(1) => {
272 let arr = EncodedVerifyingKey::<MlDsa44>::try_from(value.as_slice())
273 .map_err(|_| HsmError::KeyError("invalid MLDSA-44 key length".into()))?;
274 Ok(MldsaVerifyingKey::V44(Box::new(
275 VerifyingKey::<MlDsa44>::decode(&arr),
276 )))
277 }
278 Some(2) => {
279 let arr = EncodedVerifyingKey::<MlDsa65>::try_from(value.as_slice())
280 .map_err(|_| HsmError::KeyError("invalid MLDSA-65 key length".into()))?;
281 Ok(MldsaVerifyingKey::V65(Box::new(
282 VerifyingKey::<MlDsa65>::decode(&arr),
283 )))
284 }
285 Some(3) => {
286 let arr = EncodedVerifyingKey::<MlDsa87>::try_from(value.as_slice())
287 .map_err(|_| HsmError::KeyError("invalid MLDSA-87 key length".into()))?;
288 Ok(MldsaVerifyingKey::V87(Box::new(
289 VerifyingKey::<MlDsa87>::decode(&arr),
290 )))
291 }
292 _ => {
293 if let Ok(arr) = EncodedVerifyingKey::<MlDsa44>::try_from(value.as_slice()) {
294 Ok(MldsaVerifyingKey::V44(Box::new(
295 VerifyingKey::<MlDsa44>::decode(&arr),
296 )))
297 } else if let Ok(arr) = EncodedVerifyingKey::<MlDsa65>::try_from(value.as_slice()) {
298 Ok(MldsaVerifyingKey::V65(Box::new(
299 VerifyingKey::<MlDsa65>::decode(&arr),
300 )))
301 } else if let Ok(arr) = EncodedVerifyingKey::<MlDsa87>::try_from(value.as_slice()) {
302 Ok(MldsaVerifyingKey::V87(Box::new(
303 VerifyingKey::<MlDsa87>::decode(&arr),
304 )))
305 } else {
306 Err(HsmError::KeyError(
307 "Could not decode MLDSA public key".into(),
308 ))
309 }
310 }
311 }
312 }
313}
314
315pub fn save_private_key<P: AsRef<Path>>(
316 path: P,
317 key: &MldsaSigningKey,
318 enc: KeyEncoding,
319) -> Result<()> {
320 let encoded = key.encode();
321 let pk_info = match key {
322 MldsaSigningKey::V44(_) => PrivateKeyInfo::new(MlDsa44::ALGORITHM_IDENTIFIER, &encoded),
323 MldsaSigningKey::V65(_) => PrivateKeyInfo::new(MlDsa65::ALGORITHM_IDENTIFIER, &encoded),
324 MldsaSigningKey::V87(_) => PrivateKeyInfo::new(MlDsa87::ALGORITHM_IDENTIFIER, &encoded),
325 };
326
327 let data = match enc {
328 KeyEncoding::Der | KeyEncoding::Pkcs8Der => pk_info.to_der()?,
329 KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => {
330 pk_info.to_pem(pkcs8::LineEnding::LF)?.as_bytes().to_vec()
331 }
332 _ => return Err(anyhow!("Unsupported format for MLDSA export: {:?}", enc)),
333 };
334 std::fs::write(path, data).context("Saving private key")
335}
336
337pub fn save_public_key<P: AsRef<Path>>(
338 path: P,
339 key: &MldsaVerifyingKey,
340 enc: KeyEncoding,
341) -> Result<()> {
342 let data = match key {
343 MldsaVerifyingKey::V44(k) => match enc {
344 KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
345 KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
346 .to_public_key_pem(pkcs8::LineEnding::LF)?
347 .as_bytes()
348 .to_vec(),
349 _ => return Err(anyhow!("Unsupported format for MLDSA export")),
350 },
351 MldsaVerifyingKey::V65(k) => match enc {
352 KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
353 KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
354 .to_public_key_pem(pkcs8::LineEnding::LF)?
355 .as_bytes()
356 .to_vec(),
357 _ => return Err(anyhow!("Unsupported format for MLDSA export")),
358 },
359 MldsaVerifyingKey::V87(k) => match enc {
360 KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
361 KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
362 .to_public_key_pem(pkcs8::LineEnding::LF)?
363 .as_bytes()
364 .to_vec(),
365 _ => return Err(anyhow!("Unsupported format for MLDSA export")),
366 },
367 };
368 std::fs::write(path, data).context("Saving public key")
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use ml_dsa::{KeyGen, MlDsa44, MlDsa65, MlDsa87};
    use rand::thread_rng;

    /// Sends both halves of a key pair through `AttributeMap` and back.
    ///
    /// Asserts that the encodings are unchanged and that both halves report
    /// `expected_set` as their parameter set.
    fn check_roundtrip(
        sk: MldsaSigningKey,
        vk: MldsaVerifyingKey,
        expected_set: u64,
    ) -> Result<()> {
        let sk_attrs = AttributeMap::try_from(&sk)?;
        let restored_sk = MldsaSigningKey::try_from(&sk_attrs)?;
        assert_eq!(sk.encode(), restored_sk.encode());
        assert_eq!(sk.parameter_set(), expected_set);

        let vk_attrs = AttributeMap::try_from(&vk)?;
        let restored_vk = MldsaVerifyingKey::try_from(&vk_attrs)?;
        assert_eq!(vk.encode(), restored_vk.encode());
        assert_eq!(vk.parameter_set(), expected_set);
        Ok(())
    }

    #[test]
    fn test_mldsa44_convert() -> Result<()> {
        let pair = MlDsa44::key_gen(&mut thread_rng());
        check_roundtrip(
            MldsaSigningKey::V44(Box::new(pair.signing_key().clone())),
            MldsaVerifyingKey::V44(Box::new(pair.verifying_key().clone())),
            1,
        )
    }

    #[test]
    fn test_mldsa65_convert() -> Result<()> {
        let pair = MlDsa65::key_gen(&mut thread_rng());
        check_roundtrip(
            MldsaSigningKey::V65(Box::new(pair.signing_key().clone())),
            MldsaVerifyingKey::V65(Box::new(pair.verifying_key().clone())),
            2,
        )
    }

    #[test]
    fn test_mldsa87_convert() -> Result<()> {
        let pair = MlDsa87::key_gen(&mut thread_rng());
        check_roundtrip(
            MldsaSigningKey::V87(Box::new(pair.signing_key().clone())),
            MldsaVerifyingKey::V87(Box::new(pair.verifying_key().clone())),
            3,
        )
    }
}