hsmtool/util/key/
mldsa.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use anyhow::{Context, Result, anyhow};
6use const_oid::ObjectIdentifier;
7use der::{Encode, EncodePem};
8use ml_dsa::{
9    EncodedSigningKey, EncodedVerifyingKey, MlDsa44, MlDsa65, MlDsa87, SigningKey, VerifyingKey,
10};
11use pem_rfc7468;
12use pkcs8::{DecodePrivateKey, PrivateKeyInfo};
13use spki::{AssociatedAlgorithmIdentifier, DecodePublicKey, EncodePublicKey};
14use std::convert::{AsRef, TryFrom};
15use std::path::Path;
16
17use super::KeyEncoding;
18use crate::error::HsmError;
19use crate::util::attribute::{AttrData, AttributeMap, AttributeType, KeyType, ObjectClass};
20
21pub enum MldsaSigningKey {
22    V44(Box<SigningKey<MlDsa44>>),
23    V65(Box<SigningKey<MlDsa65>>),
24    V87(Box<SigningKey<MlDsa87>>),
25}
26
27impl MldsaSigningKey {
28    pub fn encode(&self) -> Vec<u8> {
29        match self {
30            Self::V44(k) => k.encode().to_vec(),
31            Self::V65(k) => k.encode().to_vec(),
32            Self::V87(k) => k.encode().to_vec(),
33        }
34    }
35
36    pub fn parameter_set(&self) -> u64 {
37        match self {
38            Self::V44(_) => 1,
39            Self::V65(_) => 2,
40            Self::V87(_) => 3,
41        }
42    }
43
44    pub fn oid(&self) -> ObjectIdentifier {
45        match self {
46            Self::V44(_) => MlDsa44::ALGORITHM_IDENTIFIER.oid,
47            Self::V65(_) => MlDsa65::ALGORITHM_IDENTIFIER.oid,
48            Self::V87(_) => MlDsa87::ALGORITHM_IDENTIFIER.oid,
49        }
50    }
51}
52
53pub enum MldsaVerifyingKey {
54    V44(Box<VerifyingKey<MlDsa44>>),
55    V65(Box<VerifyingKey<MlDsa65>>),
56    V87(Box<VerifyingKey<MlDsa87>>),
57}
58
59impl MldsaVerifyingKey {
60    pub fn encode(&self) -> Vec<u8> {
61        match self {
62            Self::V44(k) => k.encode().to_vec(),
63            Self::V65(k) => k.encode().to_vec(),
64            Self::V87(k) => k.encode().to_vec(),
65        }
66    }
67
68    pub fn parameter_set(&self) -> u64 {
69        match self {
70            Self::V44(_) => 1,
71            Self::V65(_) => 2,
72            Self::V87(_) => 3,
73        }
74    }
75
76    pub fn oid(&self) -> ObjectIdentifier {
77        match self {
78            Self::V44(_) => MlDsa44::ALGORITHM_IDENTIFIER.oid,
79            Self::V65(_) => MlDsa65::ALGORITHM_IDENTIFIER.oid,
80            Self::V87(_) => MlDsa87::ALGORITHM_IDENTIFIER.oid,
81        }
82    }
83}
84
85fn _load_private_key(path: &Path) -> Result<MldsaSigningKey> {
86    let data = std::fs::read(path)?;
87    let der_bytes = if let Ok((_label, bytes)) = pem_rfc7468::decode_vec(&data) {
88        bytes
89    } else {
90        data
91    };
92
93    if let Ok(key) = SigningKey::<MlDsa44>::from_pkcs8_der(&der_bytes) {
94        Ok(MldsaSigningKey::V44(Box::new(key)))
95    } else if let Ok(key) = SigningKey::<MlDsa65>::from_pkcs8_der(&der_bytes) {
96        Ok(MldsaSigningKey::V65(Box::new(key)))
97    } else if let Ok(key) = SigningKey::<MlDsa87>::from_pkcs8_der(&der_bytes) {
98        Ok(MldsaSigningKey::V87(Box::new(key)))
99    } else {
100        Err(anyhow!(
101            "Could not decode MLDSA private key from PKCS#8 DER"
102        ))
103    }
104}
105
106pub fn load_private_key<P: AsRef<Path>>(path: P) -> Result<MldsaSigningKey> {
107    _load_private_key(path.as_ref())
108}
109
110impl TryFrom<&MldsaSigningKey> for AttributeMap {
111    type Error = HsmError;
112    fn try_from(k: &MldsaSigningKey) -> std::result::Result<Self, Self::Error> {
113        let mut attr = AttributeMap::default();
114        attr.insert(
115            AttributeType::Class,
116            AttrData::ObjectClass(ObjectClass::PrivateKey),
117        );
118        attr.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
119        attr.insert(
120            AttributeType::ParameterSet,
121            AttrData::from(k.parameter_set()),
122        );
123        attr.insert(AttributeType::Value, AttrData::from(k.encode().as_slice()));
124        Ok(attr)
125    }
126}
127
128impl TryFrom<&AttributeMap> for MldsaSigningKey {
129    type Error = HsmError;
130    fn try_from(a: &AttributeMap) -> std::result::Result<Self, Self::Error> {
131        let class: ObjectClass = a
132            .get(&AttributeType::Class)
133            .ok_or_else(|| HsmError::KeyError("missing class".into()))?
134            .try_into()
135            .map_err(HsmError::AttributeError)?;
136        let key_type: KeyType = a
137            .get(&AttributeType::KeyType)
138            .ok_or_else(|| HsmError::KeyError("missing key_type".into()))?
139            .try_into()
140            .map_err(HsmError::AttributeError)?;
141        if class != ObjectClass::PrivateKey || key_type != KeyType::MlDsa {
142            return Err(HsmError::KeyError("bad key type".into()));
143        }
144
145        let value: Vec<u8> = a
146            .get(&AttributeType::Value)
147            .ok_or_else(|| HsmError::KeyError("missing value".into()))?
148            .try_into()
149            .map_err(HsmError::AttributeError)?;
150
151        // Try to determine the parameter set from AttributeType::ParameterSet if available
152        let parameter_set = a
153            .get(&AttributeType::ParameterSet)
154            .and_then(|d| u64::try_from(d).ok());
155
156        match parameter_set {
157            Some(1) => {
158                let arr = EncodedSigningKey::<MlDsa44>::try_from(value.as_slice())
159                    .map_err(|_| HsmError::KeyError("invalid MLDSA-44 key length".into()))?;
160                Ok(MldsaSigningKey::V44(Box::new(
161                    SigningKey::<MlDsa44>::decode(&arr),
162                )))
163            }
164            Some(2) => {
165                let arr = EncodedSigningKey::<MlDsa65>::try_from(value.as_slice())
166                    .map_err(|_| HsmError::KeyError("invalid MLDSA-65 key length".into()))?;
167                Ok(MldsaSigningKey::V65(Box::new(
168                    SigningKey::<MlDsa65>::decode(&arr),
169                )))
170            }
171            Some(3) => {
172                let arr = EncodedSigningKey::<MlDsa87>::try_from(value.as_slice())
173                    .map_err(|_| HsmError::KeyError("invalid MLDSA-87 key length".into()))?;
174                Ok(MldsaSigningKey::V87(Box::new(
175                    SigningKey::<MlDsa87>::decode(&arr),
176                )))
177            }
178            _ => {
179                // If parameter set is missing or unknown, try to guess from length
180                if let Ok(arr) = EncodedSigningKey::<MlDsa44>::try_from(value.as_slice()) {
181                    Ok(MldsaSigningKey::V44(Box::new(
182                        SigningKey::<MlDsa44>::decode(&arr),
183                    )))
184                } else if let Ok(arr) = EncodedSigningKey::<MlDsa65>::try_from(value.as_slice()) {
185                    Ok(MldsaSigningKey::V65(Box::new(
186                        SigningKey::<MlDsa65>::decode(&arr),
187                    )))
188                } else if let Ok(arr) = EncodedSigningKey::<MlDsa87>::try_from(value.as_slice()) {
189                    Ok(MldsaSigningKey::V87(Box::new(
190                        SigningKey::<MlDsa87>::decode(&arr),
191                    )))
192                } else {
193                    Err(HsmError::KeyError(
194                        "Could not decode MLDSA private key".into(),
195                    ))
196                }
197            }
198        }
199    }
200}
201
202fn _load_public_key(path: &Path) -> Result<MldsaVerifyingKey> {
203    let data = std::fs::read(path)?;
204    let der_bytes = if let Ok((_label, bytes)) = pem_rfc7468::decode_vec(&data) {
205        bytes
206    } else {
207        data
208    };
209
210    if let Ok(key) = VerifyingKey::<MlDsa44>::from_public_key_der(&der_bytes) {
211        Ok(MldsaVerifyingKey::V44(Box::new(key)))
212    } else if let Ok(key) = VerifyingKey::<MlDsa65>::from_public_key_der(&der_bytes) {
213        Ok(MldsaVerifyingKey::V65(Box::new(key)))
214    } else if let Ok(key) = VerifyingKey::<MlDsa87>::from_public_key_der(&der_bytes) {
215        Ok(MldsaVerifyingKey::V87(Box::new(key)))
216    } else {
217        Err(anyhow!("Could not decode MLDSA public key from SPKI DER"))
218    }
219}
220
221pub fn load_public_key<P: AsRef<Path>>(path: P) -> Result<MldsaVerifyingKey> {
222    _load_public_key(path.as_ref())
223}
224
225impl TryFrom<&MldsaVerifyingKey> for AttributeMap {
226    type Error = HsmError;
227    fn try_from(k: &MldsaVerifyingKey) -> std::result::Result<Self, Self::Error> {
228        let mut attr = AttributeMap::default();
229        attr.insert(
230            AttributeType::Class,
231            AttrData::ObjectClass(ObjectClass::PublicKey),
232        );
233        attr.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
234        attr.insert(
235            AttributeType::ParameterSet,
236            AttrData::from(k.parameter_set()),
237        );
238        attr.insert(AttributeType::Value, AttrData::from(k.encode().as_slice()));
239        Ok(attr)
240    }
241}
242
243impl TryFrom<&AttributeMap> for MldsaVerifyingKey {
244    type Error = HsmError;
245    fn try_from(a: &AttributeMap) -> std::result::Result<Self, Self::Error> {
246        let class: ObjectClass = a
247            .get(&AttributeType::Class)
248            .ok_or_else(|| HsmError::KeyError("missing class".into()))?
249            .try_into()
250            .map_err(HsmError::AttributeError)?;
251        let key_type: KeyType = a
252            .get(&AttributeType::KeyType)
253            .ok_or_else(|| HsmError::KeyError("missing key_type".into()))?
254            .try_into()
255            .map_err(HsmError::AttributeError)?;
256        if class != ObjectClass::PublicKey || key_type != KeyType::MlDsa {
257            return Err(HsmError::KeyError("bad key type".into()));
258        }
259
260        let value: Vec<u8> = a
261            .get(&AttributeType::Value)
262            .ok_or_else(|| HsmError::KeyError("missing value".into()))?
263            .try_into()
264            .map_err(HsmError::AttributeError)?;
265
266        let parameter_set = a
267            .get(&AttributeType::ParameterSet)
268            .and_then(|d| u64::try_from(d).ok());
269
270        match parameter_set {
271            Some(1) => {
272                let arr = EncodedVerifyingKey::<MlDsa44>::try_from(value.as_slice())
273                    .map_err(|_| HsmError::KeyError("invalid MLDSA-44 key length".into()))?;
274                Ok(MldsaVerifyingKey::V44(Box::new(
275                    VerifyingKey::<MlDsa44>::decode(&arr),
276                )))
277            }
278            Some(2) => {
279                let arr = EncodedVerifyingKey::<MlDsa65>::try_from(value.as_slice())
280                    .map_err(|_| HsmError::KeyError("invalid MLDSA-65 key length".into()))?;
281                Ok(MldsaVerifyingKey::V65(Box::new(
282                    VerifyingKey::<MlDsa65>::decode(&arr),
283                )))
284            }
285            Some(3) => {
286                let arr = EncodedVerifyingKey::<MlDsa87>::try_from(value.as_slice())
287                    .map_err(|_| HsmError::KeyError("invalid MLDSA-87 key length".into()))?;
288                Ok(MldsaVerifyingKey::V87(Box::new(
289                    VerifyingKey::<MlDsa87>::decode(&arr),
290                )))
291            }
292            _ => {
293                if let Ok(arr) = EncodedVerifyingKey::<MlDsa44>::try_from(value.as_slice()) {
294                    Ok(MldsaVerifyingKey::V44(Box::new(
295                        VerifyingKey::<MlDsa44>::decode(&arr),
296                    )))
297                } else if let Ok(arr) = EncodedVerifyingKey::<MlDsa65>::try_from(value.as_slice()) {
298                    Ok(MldsaVerifyingKey::V65(Box::new(
299                        VerifyingKey::<MlDsa65>::decode(&arr),
300                    )))
301                } else if let Ok(arr) = EncodedVerifyingKey::<MlDsa87>::try_from(value.as_slice()) {
302                    Ok(MldsaVerifyingKey::V87(Box::new(
303                        VerifyingKey::<MlDsa87>::decode(&arr),
304                    )))
305                } else {
306                    Err(HsmError::KeyError(
307                        "Could not decode MLDSA public key".into(),
308                    ))
309                }
310            }
311        }
312    }
313}
314
315pub fn save_private_key<P: AsRef<Path>>(
316    path: P,
317    key: &MldsaSigningKey,
318    enc: KeyEncoding,
319) -> Result<()> {
320    let encoded = key.encode();
321    let pk_info = match key {
322        MldsaSigningKey::V44(_) => PrivateKeyInfo::new(MlDsa44::ALGORITHM_IDENTIFIER, &encoded),
323        MldsaSigningKey::V65(_) => PrivateKeyInfo::new(MlDsa65::ALGORITHM_IDENTIFIER, &encoded),
324        MldsaSigningKey::V87(_) => PrivateKeyInfo::new(MlDsa87::ALGORITHM_IDENTIFIER, &encoded),
325    };
326
327    let data = match enc {
328        KeyEncoding::Der | KeyEncoding::Pkcs8Der => pk_info.to_der()?,
329        KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => {
330            pk_info.to_pem(pkcs8::LineEnding::LF)?.as_bytes().to_vec()
331        }
332        _ => return Err(anyhow!("Unsupported format for MLDSA export: {:?}", enc)),
333    };
334    std::fs::write(path, data).context("Saving private key")
335}
336
337pub fn save_public_key<P: AsRef<Path>>(
338    path: P,
339    key: &MldsaVerifyingKey,
340    enc: KeyEncoding,
341) -> Result<()> {
342    let data = match key {
343        MldsaVerifyingKey::V44(k) => match enc {
344            KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
345            KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
346                .to_public_key_pem(pkcs8::LineEnding::LF)?
347                .as_bytes()
348                .to_vec(),
349            _ => return Err(anyhow!("Unsupported format for MLDSA export")),
350        },
351        MldsaVerifyingKey::V65(k) => match enc {
352            KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
353            KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
354                .to_public_key_pem(pkcs8::LineEnding::LF)?
355                .as_bytes()
356                .to_vec(),
357            _ => return Err(anyhow!("Unsupported format for MLDSA export")),
358        },
359        MldsaVerifyingKey::V87(k) => match enc {
360            KeyEncoding::Der | KeyEncoding::Pkcs8Der => k.to_public_key_der()?.as_bytes().to_vec(),
361            KeyEncoding::Pem | KeyEncoding::Pkcs8Pem => k
362                .to_public_key_pem(pkcs8::LineEnding::LF)?
363                .as_bytes()
364                .to_vec(),
365            _ => return Err(anyhow!("Unsupported format for MLDSA export")),
366        },
367    };
368    std::fs::write(path, data).context("Saving public key")
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use ml_dsa::{KeyGen, MlDsa44, MlDsa65, MlDsa87};
    use rand::thread_rng;

    // Builds a private-key attribute map that deliberately omits `ParameterSet`, so decoding
    // has to infer the parameter set from the length of `Value`.
    fn private_attrs_without_parameter_set(key: &MldsaSigningKey) -> AttributeMap {
        let mut attrs = AttributeMap::default();
        attrs.insert(
            AttributeType::Class,
            AttrData::ObjectClass(ObjectClass::PrivateKey),
        );
        attrs.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
        attrs.insert(AttributeType::Value, AttrData::from(key.encode().as_slice()));
        attrs
    }

    // Public-key counterpart of `private_attrs_without_parameter_set`.
    fn public_attrs_without_parameter_set(key: &MldsaVerifyingKey) -> AttributeMap {
        let mut attrs = AttributeMap::default();
        attrs.insert(
            AttributeType::Class,
            AttrData::ObjectClass(ObjectClass::PublicKey),
        );
        attrs.insert(AttributeType::KeyType, AttrData::KeyType(KeyType::MlDsa));
        attrs.insert(AttributeType::Value, AttrData::from(key.encode().as_slice()));
        attrs
    }

    // Checks that both keys round-trip through attribute maps without `ParameterSet`, and
    // that the inferred parameter set matches the original one.
    fn assert_parameter_set_inferred(sk: &MldsaSigningKey, vk: &MldsaVerifyingKey) -> Result<()> {
        let sk2 = MldsaSigningKey::try_from(&private_attrs_without_parameter_set(sk))?;
        assert_eq!(sk2.parameter_set(), sk.parameter_set());
        assert_eq!(sk2.encode(), sk.encode());

        let vk2 = MldsaVerifyingKey::try_from(&public_attrs_without_parameter_set(vk))?;
        assert_eq!(vk2.parameter_set(), vk.parameter_set());
        assert_eq!(vk2.encode(), vk.encode());
        Ok(())
    }

    #[test]
    fn test_mldsa44_convert() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa44::key_gen(&mut rng);
        let key = MldsaSigningKey::V44(Box::new(kp.signing_key().clone()));

        let hsm = AttributeMap::try_from(&key)?;
        let key2 = MldsaSigningKey::try_from(&hsm)?;

        assert_eq!(key.encode(), key2.encode());
        assert_eq!(key.parameter_set(), 1);

        let vk = MldsaVerifyingKey::V44(Box::new(kp.verifying_key().clone()));
        let hsm_pub = AttributeMap::try_from(&vk)?;
        let vk2 = MldsaVerifyingKey::try_from(&hsm_pub)?;

        assert_eq!(vk.encode(), vk2.encode());
        assert_eq!(vk.parameter_set(), 1);
        Ok(())
    }

    #[test]
    fn test_mldsa65_convert() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa65::key_gen(&mut rng);
        let key = MldsaSigningKey::V65(Box::new(kp.signing_key().clone()));

        let hsm = AttributeMap::try_from(&key)?;
        let key2 = MldsaSigningKey::try_from(&hsm)?;

        assert_eq!(key.encode(), key2.encode());
        assert_eq!(key.parameter_set(), 2);

        let vk = MldsaVerifyingKey::V65(Box::new(kp.verifying_key().clone()));
        let hsm_pub = AttributeMap::try_from(&vk)?;
        let vk2 = MldsaVerifyingKey::try_from(&hsm_pub)?;

        assert_eq!(vk.encode(), vk2.encode());
        assert_eq!(vk.parameter_set(), 2);
        Ok(())
    }

    #[test]
    fn test_mldsa87_convert() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa87::key_gen(&mut rng);
        let key = MldsaSigningKey::V87(Box::new(kp.signing_key().clone()));

        let hsm = AttributeMap::try_from(&key)?;
        let key2 = MldsaSigningKey::try_from(&hsm)?;

        assert_eq!(key.encode(), key2.encode());
        assert_eq!(key.parameter_set(), 3);

        let vk = MldsaVerifyingKey::V87(Box::new(kp.verifying_key().clone()));
        let hsm_pub = AttributeMap::try_from(&vk)?;
        let vk2 = MldsaVerifyingKey::try_from(&hsm_pub)?;

        assert_eq!(vk.encode(), vk2.encode());
        assert_eq!(vk.parameter_set(), 3);
        Ok(())
    }

    // One keypair per test keeps stack usage in line with the conversion tests above.
    #[test]
    fn test_mldsa44_infer_parameter_set() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa44::key_gen(&mut rng);
        assert_parameter_set_inferred(
            &MldsaSigningKey::V44(Box::new(kp.signing_key().clone())),
            &MldsaVerifyingKey::V44(Box::new(kp.verifying_key().clone())),
        )
    }

    #[test]
    fn test_mldsa65_infer_parameter_set() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa65::key_gen(&mut rng);
        assert_parameter_set_inferred(
            &MldsaSigningKey::V65(Box::new(kp.signing_key().clone())),
            &MldsaVerifyingKey::V65(Box::new(kp.verifying_key().clone())),
        )
    }

    #[test]
    fn test_mldsa87_infer_parameter_set() -> Result<()> {
        let mut rng = thread_rng();
        let kp = MlDsa87::key_gen(&mut rng);
        assert_parameter_set_inferred(
            &MldsaSigningKey::V87(Box::new(kp.signing_key().clone())),
            &MldsaVerifyingKey::V87(Box::new(kp.verifying_key().clone())),
        )
    }

    #[test]
    fn test_mldsa_parameter_set_length_mismatch() {
        let mut rng = thread_rng();
        let kp = MlDsa65::key_gen(&mut rng);
        let sk = MldsaSigningKey::V65(Box::new(kp.signing_key().clone()));
        let vk = MldsaVerifyingKey::V65(Box::new(kp.verifying_key().clone()));

        // Declaring ML-DSA-44 while carrying ML-DSA-65 sized values must be rejected rather
        // than silently falling back to length inference.
        let mut sk_attrs = private_attrs_without_parameter_set(&sk);
        sk_attrs.insert(AttributeType::ParameterSet, AttrData::from(1u64));
        assert!(MldsaSigningKey::try_from(&sk_attrs).is_err());

        let mut vk_attrs = public_attrs_without_parameter_set(&vk);
        vk_attrs.insert(AttributeType::ParameterSet, AttrData::from(1u64));
        assert!(MldsaVerifyingKey::try_from(&vk_attrs).is_err());
    }
}