hsmtool/extra/
spxkms.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use acorn::{GenerateFlags, KeyEntry, KeyInfo, SpxInterface};
6use anyhow::{Context, Result, anyhow};
7use base64ct::{Base64, Encoding};
8use indexmap::IndexMap;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::Value;
11use sphincsplus::{SphincsPlus, SpxDomain, SpxPublicKey};
12use std::process::Command;
13use std::str::FromStr;
14use thiserror::Error;
15use zeroize::Zeroizing;
16
17use reqwest::blocking::Client;
18use reqwest::{IntoUrl, Url};
19
20use crate::error::HsmError;
21
22/// SpxEf implements SPHINCS+ signing via Google CloudKms.
23pub struct SpxKms {
24    keyring: Url,
25    project: String,
26    auth: Zeroizing<String>,
27}
28
29/// ApiError represents an error result from the cloud API.
30#[derive(Deserialize, Debug, Error)]
31#[error("api error: code={code} message={message:?}; details={details:?}")]
32#[serde(rename_all = "camelCase")]
33pub struct ApiError {
34    pub code: u32,
35    pub message: String,
36    pub status: String,
37    #[serde(flatten)]
38    pub details: IndexMap<String, Value>,
39}
40
41// CloudResult assists in deserializing the cloud API return into an error
42// or a specific type.
43#[derive(Deserialize, Debug)]
44enum CloudResult<T> {
45    #[serde(rename = "error")]
46    Error(ApiError),
47    #[serde(untagged)]
48    Ok(T),
49}
50
51#[derive(Deserialize, Debug, Clone)]
52#[serde(rename_all = "camelCase")]
53struct KmsKeyList {
54    crypto_keys: Vec<KmsKeyRef>,
55}
56
57// Note: dead_code is allowed because some of the fields defined in this struct
58// are not used, but are fields returned by the KMS json API.
59#[allow(dead_code)]
60#[derive(Serialize, Deserialize, Debug, Clone)]
61#[serde(rename_all = "camelCase")]
62struct VersionTemplate {
63    #[serde(default)]
64    protection_level: String,
65    #[serde(default)]
66    algorithm: String,
67}
68
69#[derive(Serialize, Deserialize, Debug, Clone)]
70#[serde(rename_all = "camelCase")]
71struct KmsCreateKey {
72    purpose: String,
73    version_template: VersionTemplate,
74}
75
76#[derive(Deserialize, Debug, Clone)]
77#[serde(rename_all = "camelCase")]
78struct KmsKeyRef {
79    name: String,
80    version_template: VersionTemplate,
81}
82
83#[derive(Deserialize, Debug, Clone)]
84#[serde(rename_all = "camelCase")]
85struct KmsKeyVersion {
86    name: String,
87    state: String,
88    #[serde(default)]
89    algorithm: String,
90}
91
92#[derive(Deserialize, Debug, Clone)]
93#[serde(rename_all = "camelCase")]
94struct KmsKeyVersions {
95    crypto_key_versions: Vec<KmsKeyVersion>,
96}
97
98#[derive(Deserialize, Debug, Clone)]
99#[serde(rename_all = "camelCase")]
100struct KmsPublicKeyData {
101    data: String,
102}
103
104// Note: dead_code is allowed because some of the fields defined in this struct
105// are not used, but are fields returned by the KMS json API.
106#[allow(dead_code)]
107#[derive(Deserialize, Debug, Clone)]
108#[serde(rename_all = "camelCase")]
109struct KmsPublicKey {
110    algorithm: String,
111    name: String,
112    #[serde(default)]
113    protection_level: String,
114    #[serde(default)]
115    public_key_format: String,
116    pem: Option<String>,
117    public_key: Option<KmsPublicKeyData>,
118}
119
120#[derive(Serialize, Debug)]
121struct KmsDigest {
122    sha256: String,
123}
124
125#[derive(Serialize, Debug, Default)]
126struct KmsSignRequest {
127    #[serde(skip_serializing_if = "Option::is_none")]
128    digest: Option<KmsDigest>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    data: Option<String>,
131}
132
133impl SpxKms {
134    const PURE_ALGORITHM: &'static str = "PQ_SIGN_SLH_DSA_SHA2_128S";
135    const PREHASH_ALGORITHM: &'static str = "PQ_SIGN_HASH_SLH_DSA_SHA2_128S_SHA256";
136
137    pub fn new(parameters: &str) -> Result<Box<Self>> {
138        let output = Command::new("gcloud")
139            .args(["auth", "print-access-token"])
140            .output()?;
141        if output.status.success() {
142            // Get the authorization token and strip trailing newlines.
143            let mut auth = String::from_utf8(output.stdout)?;
144            let len = auth.trim_end().len();
145            auth.truncate(len);
146
147            let mut params = IndexMap::new();
148            params.extend(parameters.split(':').map(|p| {
149                p.split_once('=')
150                    .expect("KMS parameters should be key=value")
151            }));
152
153            let project = params.get("project").ok_or(HsmError::Unsupported(
154                "KMS requires a project parameter".into(),
155            ))?;
156            let location = params.get("location").ok_or(HsmError::Unsupported(
157                "KMS requires a location parameter".into(),
158            ))?;
159            let keyring = params.get("keyring").ok_or(HsmError::Unsupported(
160                "KMS requires a keyring parameter".into(),
161            ))?;
162            let url = format!(
163                "https://cloudkms.googleapis.com/v1/projects/{project}/locations/{location}/keyRings/{keyring}/"
164            );
165            log::info!("keyring url: {url}");
166            Ok(Box::new(Self {
167                keyring: Url::parse(&url)?,
168                project: project.to_string(),
169                auth: auth.into(),
170            }))
171        } else {
172            let stderr = String::from_utf8_lossy(&output.stderr);
173            Err(anyhow!("gcloud error {:?}: {}", output.status, stderr))
174        }
175    }
176
177    fn kms_to_algorithm(kms_algo: &str) -> Result<String> {
178        match kms_algo {
179            Self::PURE_ALGORITHM | Self::PREHASH_ALGORITHM => Ok("SLH-DSA-SHA2-128S".into()),
180            _ => Err(HsmError::Unsupported(format!("algorithm {kms_algo}")).into()),
181        }
182    }
183
184    fn kms_to_domain(kms_algo: &str) -> Result<SpxDomain> {
185        match kms_algo {
186            Self::PURE_ALGORITHM => Ok(SpxDomain::Pure),
187            Self::PREHASH_ALGORITHM => Ok(SpxDomain::PreHashedSha256),
188            _ => Err(HsmError::Unsupported(format!("algorithm {kms_algo}")).into()),
189        }
190    }
191
192    fn get<RSP: DeserializeOwned>(&self, url: impl IntoUrl) -> Result<RSP> {
193        let client = Client::new();
194        log::debug!("GET {}", url.as_str());
195        let resp = client
196            .get(url)
197            .bearer_auth(&*self.auth)
198            .header("content-type", "application/json")
199            .header("X-Goog-User-Project", &self.project)
200            .send()?;
201        let data = resp.text()?;
202        log::debug!("data: {data}");
203        match serde_json::from_str::<CloudResult<RSP>>(&data)? {
204            CloudResult::Error(e) => Err(e.into()),
205            CloudResult::Ok(v) => Ok(v),
206        }
207    }
208
209    fn post<RSP: DeserializeOwned>(&self, url: impl IntoUrl, req: &impl Serialize) -> Result<RSP> {
210        let client = Client::new();
211        log::debug!("POST {}", url.as_str());
212        let resp = client
213            .post(url)
214            .bearer_auth(&*self.auth)
215            .header("content-type", "application/json")
216            .header("X-Goog-User-Project", &self.project)
217            .json(req)
218            .send()?;
219        let data = resp.text()?;
220        log::debug!("data: {data}");
221        match serde_json::from_str::<CloudResult<RSP>>(&data)? {
222            CloudResult::Error(e) => Err(e.into()),
223            CloudResult::Ok(v) => Ok(v),
224        }
225    }
226
227    fn get_key_version(&self, alias: &str) -> Result<KmsKeyVersion> {
228        let url = self
229            .keyring
230            .join(&format!("cryptoKeys/{alias}/cryptoKeyVersions"))?;
231        let versions = self.get::<KmsKeyVersions>(url)?;
232        match versions
233            .crypto_key_versions
234            .iter()
235            .filter(|v| v.state == "ENABLED" && Self::kms_to_algorithm(&v.algorithm).is_ok())
236            .next_back()
237        {
238            Some(key) => Ok(key.clone()),
239            None => Err(HsmError::ObjectNotFound(alias.into()).into()),
240        }
241    }
242
243    fn get_public_key(&self, alias: &str) -> Result<KmsPublicKey> {
244        let key = self.get_key_version(alias)?;
245        let mut url = self.keyring.join(&format!("/v1/{}/publicKey", key.name))?;
246        url.set_query(Some("public_key_format=NIST_PQC"));
247        self.get(url)
248    }
249}
250
251impl SpxInterface for SpxKms {
252    /// Get the version of the backend.
253    fn get_version(&self) -> Result<String> {
254        Ok(String::from("CloudKMS 0.0.1"))
255    }
256
257    /// List keys known to the backend.
258    fn list_keys(&self) -> Result<Vec<KeyEntry>> {
259        let keys = self.keyring.join("cryptoKeys")?;
260        let keys = self.get::<KmsKeyList>(keys)?;
261        let mut result = Vec::new();
262
263        for k in keys.crypto_keys.iter() {
264            let (_, name) = k
265                .name
266                .rsplit_once('/')
267                .ok_or_else(|| HsmError::ParseError("could not parse key name".into()))
268                .with_context(|| format!("key name {:?}", k.name))?;
269            if Self::kms_to_algorithm(&k.version_template.algorithm).is_err() {
270                continue;
271            }
272            let key = self.get_key_version(name)?;
273            result.push(KeyEntry {
274                alias: name.into(),
275                hash: None,
276                algorithm: Self::kms_to_algorithm(&key.algorithm)?,
277                domain: Some(Self::kms_to_domain(&key.algorithm)?),
278                ..Default::default()
279            });
280        }
281        Ok(result)
282    }
283
284    /// Get the public key info.
285    fn get_key_info(&self, alias: &str) -> Result<KeyInfo> {
286        let key = self.get_public_key(alias)?;
287        let data = if let Some(pem) = &key.pem {
288            pem.as_str()
289        } else if let Some(public_key) = &key.public_key {
290            public_key.data.as_str()
291        } else {
292            return Err(HsmError::Unsupported("did not find public key material".into()).into());
293        };
294        Ok(KeyInfo {
295            hash: "".into(),
296            algorithm: Self::kms_to_algorithm(&key.algorithm)?,
297            domain: Some(Self::kms_to_domain(&key.algorithm)?),
298            public_key: Base64::decode_vec(data)?,
299            private_blob: Vec::new(),
300        })
301    }
302
303    /// Generate a key pair.
304    fn generate_key(
305        &self,
306        alias: &str,
307        _algorithm: &str,
308        domain: SpxDomain,
309        _token: &str,
310        flags: GenerateFlags,
311    ) -> Result<KeyEntry> {
312        if flags.contains(GenerateFlags::EXPORT_PRIVATE) {
313            return Err(HsmError::Unsupported("export of private key material".into()).into());
314        }
315        let algorithm = match domain {
316            SpxDomain::Pure => Self::PURE_ALGORITHM,
317            SpxDomain::PreHashedSha256 => Self::PREHASH_ALGORITHM,
318            _ => return Err(HsmError::Unsupported(format!("domain {domain}")).into()),
319        };
320        let url = self
321            .keyring
322            .join(&format!("cryptoKeys?crypto_key_id={alias}"))?;
323        let template = KmsCreateKey {
324            purpose: "ASYMMETRIC_SIGN".into(),
325            version_template: VersionTemplate {
326                algorithm: algorithm.into(),
327                protection_level: "SOFTWARE".into(),
328            },
329        };
330        let resp = self.post::<KmsKeyRef>(url, &template)?;
331        Ok(KeyEntry {
332            alias: alias.into(),
333            algorithm: Self::kms_to_algorithm(&resp.version_template.algorithm)?,
334            domain: Some(Self::kms_to_domain(&resp.version_template.algorithm)?),
335            ..Default::default()
336        })
337    }
338
339    /// Import a key pair.
340    fn import_keypair(
341        &self,
342        _alias: &str,
343        _algorithm: &str,
344        _domain: SpxDomain,
345        _token: &str,
346        _overwrite: bool,
347        _public_key: &[u8],
348        _private_key: &[u8],
349    ) -> Result<KeyEntry> {
350        Err(HsmError::Unsupported(format!(
351            "key import is not supported by {}",
352            self.get_version()?
353        ))
354        .into())
355    }
356
357    /// Sign a message.
358    fn sign(
359        &self,
360        alias: Option<&str>,
361        key_hash: Option<&str>,
362        domain: SpxDomain,
363        message: &[u8],
364    ) -> Result<Vec<u8>> {
365        let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
366        if key_hash.is_some() {
367            log::warn!("ignored key_hash {key_hash:?}");
368        }
369        let key = self.get_key_version(alias)?;
370        let keydomain = Self::kms_to_domain(&key.algorithm)?;
371        if domain != keydomain {
372            return Err(HsmError::Unsupported(format!(
373                "domain {domain} not supported by key {alias}",
374            ))
375            .into());
376        }
377
378        let url = self
379            .keyring
380            .join(&format!("/v1/{}:asymmetricSign", key.name))?;
381
382        // Format the signing request:
383        // - For the "pure" domain, we place the message in the `data` field.
384        // - For the "prehashed" domain, we place the digest into the `digest` field.
385        let req = match keydomain {
386            SpxDomain::Pure => KmsSignRequest {
387                data: Some(Base64::encode_string(message)),
388                ..Default::default()
389            },
390            SpxDomain::PreHashedSha256 => KmsSignRequest {
391                digest: Some(KmsDigest {
392                    sha256: Base64::encode_string(message),
393                }),
394                ..Default::default()
395            },
396            _ => unreachable!(),
397        };
398
399        let resp = self.post::<IndexMap<String, String>>(url, &req)?;
400        let signature = Base64::decode_vec(&resp["signature"])?;
401        Ok(signature)
402    }
403
404    /// Verify a message.
405    fn verify(
406        &self,
407        alias: Option<&str>,
408        key_hash: Option<&str>,
409        domain: SpxDomain,
410        message: &[u8],
411        signature: &[u8],
412    ) -> Result<bool> {
413        let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
414        if key_hash.is_some() {
415            log::warn!("ignored key_hash {key_hash:?}");
416        }
417        let info = self.get_key_info(alias)?;
418        let keydomain = info.domain.expect("kms key domain");
419        if domain != keydomain {
420            return Err(HsmError::Unsupported(format!(
421                "domain {domain} not supported by key {alias}",
422            ))
423            .into());
424        }
425        let pk =
426            SpxPublicKey::from_bytes(SphincsPlus::from_str(&info.algorithm)?, &info.public_key)?;
427        match pk.verify(domain, signature, message) {
428            Ok(_) => Ok(true),
429            Err(_) => Ok(false),
430        }
431    }
432}