hsmtool/extra/
spxkms.rs

// Copyright lowRISC contributors (OpenTitan project).
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
4
5use acorn::{GenerateFlags, KeyEntry, KeyInfo, SpxInterface};
6use anyhow::{Context, Result, anyhow};
7use base64ct::{Base64, Encoding};
8use indexmap::IndexMap;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::Value;
11use sphincsplus::{SphincsPlus, SpxDomain, SpxPublicKey};
12use std::process::Command;
13use std::str::FromStr;
14use thiserror::Error;
15use zeroize::Zeroizing;
16
17use reqwest::blocking::Client;
18use reqwest::{IntoUrl, Url};
19
20use crate::error::HsmError;
21
22/// SpxEf implements SPHINCS+ signing via Google CloudKms.
23pub struct SpxKms {
24    keyring: Url,
25    project: String,
26    auth: Zeroizing<String>,
27}
28
29/// ApiError represents an error result from the cloud API.
30#[derive(Deserialize, Debug, Error)]
31#[error("api error: code={code} message={message:?}; details={details:?}")]
32#[serde(rename_all = "camelCase")]
33pub struct ApiError {
34    pub code: u32,
35    pub message: String,
36    pub status: String,
37    #[serde(flatten)]
38    pub details: IndexMap<String, Value>,
39}
40
41// CloudResult assists in deserializing the cloud API return into an error
42// or a specific type.
43#[derive(Deserialize, Debug)]
44enum CloudResult<T> {
45    #[serde(rename = "error")]
46    Error(ApiError),
47    #[serde(untagged)]
48    Ok(T),
49}
50
51#[derive(Deserialize, Debug, Clone)]
52#[serde(rename_all = "camelCase")]
53struct KmsKeyList {
54    crypto_keys: Vec<KmsKeyRef>,
55}
56
57// Note: dead_code is allowed because some of the fields defined in this struct
58// are not used, but are fields returned by the KMS json API.
59#[allow(dead_code)]
60#[derive(Serialize, Deserialize, Debug, Clone)]
61#[serde(rename_all = "camelCase")]
62struct VersionTemplate {
63    #[serde(default)]
64    protection_level: String,
65    #[serde(default)]
66    algorithm: String,
67}
68
69#[derive(Serialize, Deserialize, Debug, Clone)]
70#[serde(rename_all = "camelCase")]
71struct KmsCreateKey {
72    purpose: String,
73    version_template: VersionTemplate,
74}
75
76#[derive(Deserialize, Debug, Clone)]
77#[serde(rename_all = "camelCase")]
78struct KmsKeyRef {
79    name: String,
80    version_template: VersionTemplate,
81}
82
83#[derive(Deserialize, Debug, Clone)]
84#[serde(rename_all = "camelCase")]
85struct KmsKeyVersion {
86    name: String,
87    state: String,
88    #[serde(default)]
89    algorithm: String,
90}
91
92#[derive(Deserialize, Debug, Clone)]
93#[serde(rename_all = "camelCase")]
94struct KmsKeyVersions {
95    crypto_key_versions: Vec<KmsKeyVersion>,
96}
97
98#[derive(Deserialize, Debug, Clone)]
99#[serde(rename_all = "camelCase")]
100struct KmsPublicKeyData {
101    data: String,
102}
103
104// Note: dead_code is allowed because some of the fields defined in this struct
105// are not used, but are fields returned by the KMS json API.
106#[allow(dead_code)]
107#[derive(Deserialize, Debug, Clone)]
108#[serde(rename_all = "camelCase")]
109struct KmsPublicKey {
110    algorithm: String,
111    name: String,
112    #[serde(default)]
113    protection_level: String,
114    #[serde(default)]
115    public_key_format: String,
116    pem: Option<String>,
117    public_key: Option<KmsPublicKeyData>,
118}
119
120#[derive(Serialize, Debug)]
121struct KmsDigest {
122    sha256: String,
123}
124
125#[derive(Serialize, Debug)]
126struct KmsSignRequest {
127    #[serde(skip_serializing_if = "Option::is_none")]
128    digest: Option<KmsDigest>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    data: Option<String>,
131}
132
133impl SpxKms {
134    const ALGORITHM: &'static str = "PQ_SIGN_SLH_DSA_SHA2_128S";
135
136    pub fn new(parameters: &str) -> Result<Box<Self>> {
137        let output = Command::new("gcloud")
138            .args(["auth", "print-access-token"])
139            .output()?;
140        if output.status.success() {
141            // Get the authorization token and strip trailing newlines.
142            let mut auth = String::from_utf8(output.stdout)?;
143            let len = auth.trim_end().len();
144            auth.truncate(len);
145
146            let mut params = IndexMap::new();
147            params.extend(parameters.split(':').map(|p| {
148                p.split_once('=')
149                    .expect("KMS parameters should be key=value")
150            }));
151
152            let project = params.get("project").ok_or(HsmError::Unsupported(
153                "KMS requires a project parameter".into(),
154            ))?;
155            let location = params.get("location").ok_or(HsmError::Unsupported(
156                "KMS requires a location parameter".into(),
157            ))?;
158            let keyring = params.get("keyring").ok_or(HsmError::Unsupported(
159                "KMS requires a keyring parameter".into(),
160            ))?;
161            let url = format!(
162                "https://cloudkms.googleapis.com/v1/projects/{project}/locations/{location}/keyRings/{keyring}/"
163            );
164            log::info!("keyring url: {url}");
165            Ok(Box::new(Self {
166                keyring: Url::parse(&url)?,
167                project: project.to_string(),
168                auth: auth.into(),
169            }))
170        } else {
171            let stderr = String::from_utf8_lossy(&output.stderr);
172            Err(anyhow!("gcloud error {:?}: {}", output.status, stderr))
173        }
174    }
175
176    fn get<RSP: DeserializeOwned>(&self, url: impl IntoUrl) -> Result<RSP> {
177        let client = Client::new();
178        log::debug!("GET {}", url.as_str());
179        let resp = client
180            .get(url)
181            .bearer_auth(&*self.auth)
182            .header("content-type", "application/json")
183            .header("X-Goog-User-Project", &self.project)
184            .send()?;
185        let data = resp.text()?;
186        log::debug!("data: {data}");
187        match serde_json::from_str::<CloudResult<RSP>>(&data)? {
188            CloudResult::Error(e) => Err(e.into()),
189            CloudResult::Ok(v) => Ok(v),
190        }
191    }
192
193    fn post<RSP: DeserializeOwned>(&self, url: impl IntoUrl, req: &impl Serialize) -> Result<RSP> {
194        let client = Client::new();
195        log::debug!("POST {}", url.as_str());
196        let resp = client
197            .post(url)
198            .bearer_auth(&*self.auth)
199            .header("content-type", "application/json")
200            .header("X-Goog-User-Project", &self.project)
201            .json(req)
202            .send()?;
203        let data = resp.text()?;
204        log::debug!("data: {data}");
205        match serde_json::from_str::<CloudResult<RSP>>(&data)? {
206            CloudResult::Error(e) => Err(e.into()),
207            CloudResult::Ok(v) => Ok(v),
208        }
209    }
210
211    fn get_key_version(&self, alias: &str) -> Result<KmsKeyVersion> {
212        let url = self
213            .keyring
214            .join(&format!("cryptoKeys/{alias}/cryptoKeyVersions"))?;
215        let versions = self.get::<KmsKeyVersions>(url)?;
216        match versions
217            .crypto_key_versions
218            .iter()
219            .filter(|v| v.state == "ENABLED" && v.algorithm == Self::ALGORITHM)
220            .next_back()
221        {
222            Some(key) => Ok(key.clone()),
223            None => Err(HsmError::ObjectNotFound(alias.into()).into()),
224        }
225    }
226
227    fn get_public_key(&self, alias: &str) -> Result<KmsPublicKey> {
228        let key = self.get_key_version(alias)?;
229        let url = self.keyring.join(&format!("/v1/{}/publicKey", key.name))?;
230        self.get(url)
231    }
232}
233
234impl SpxInterface for SpxKms {
235    /// Get the version of the backend.
236    fn get_version(&self) -> Result<String> {
237        Ok(String::from("CloudKMS 0.0.1"))
238    }
239
240    /// List keys known to the backend.
241    fn list_keys(&self) -> Result<Vec<KeyEntry>> {
242        let keys = self.keyring.join("cryptoKeys")?;
243        let keys = self.get::<KmsKeyList>(keys)?;
244        let mut result = Vec::new();
245
246        for k in keys.crypto_keys.iter() {
247            let (_, name) = k
248                .name
249                .rsplit_once('/')
250                .ok_or_else(|| HsmError::ParseError("could not parse key name".into()))
251                .with_context(|| format!("key name {:?}", k.name))?;
252            if k.version_template.algorithm != Self::ALGORITHM {
253                continue;
254            }
255            let key = self.get_key_version(name)?;
256            result.push(KeyEntry {
257                alias: name.into(),
258                hash: None,
259                algorithm: key.algorithm.clone(),
260                ..Default::default()
261            });
262        }
263        Ok(result)
264    }
265
266    /// Get the public key info.
267    fn get_key_info(&self, alias: &str) -> Result<KeyInfo> {
268        let key = self.get_public_key(alias)?;
269        let algorithm = key
270            .algorithm
271            .trim_start_matches("PQ_SIGN_")
272            .replace('_', "-");
273        let data = if let Some(pem) = &key.pem {
274            pem.as_str()
275        } else if let Some(public_key) = &key.public_key {
276            public_key.data.as_str()
277        } else {
278            return Err(HsmError::Unsupported("did not find public key material".into()).into());
279        };
280        Ok(KeyInfo {
281            hash: "".into(),
282            algorithm,
283            public_key: Base64::decode_vec(data)?,
284            private_blob: Vec::new(),
285        })
286    }
287
288    /// Generate a key pair.
289    fn generate_key(
290        &self,
291        alias: &str,
292        _algorithm: &str,
293        _token: &str,
294        flags: GenerateFlags,
295    ) -> Result<KeyEntry> {
296        if flags.contains(GenerateFlags::EXPORT_PRIVATE) {
297            return Err(HsmError::Unsupported("export of private key material".into()).into());
298        }
299        let url = self
300            .keyring
301            .join(&format!("cryptoKeys?crypto_key_id={alias}"))?;
302        let template = KmsCreateKey {
303            purpose: "ASYMMETRIC_SIGN".into(),
304            version_template: VersionTemplate {
305                algorithm: Self::ALGORITHM.into(),
306                protection_level: "SOFTWARE".into(),
307            },
308        };
309        let resp = self.post::<KmsKeyRef>(url, &template)?;
310        Ok(KeyEntry {
311            alias: alias.into(),
312            algorithm: resp
313                .version_template
314                .algorithm
315                .trim_start_matches("PQ_SIGN_")
316                .replace('_', "-"),
317            ..Default::default()
318        })
319    }
320
321    /// Import a key pair.
322    fn import_keypair(
323        &self,
324        _alias: &str,
325        _algorithm: &str,
326        _token: &str,
327        _overwrite: bool,
328        _public_key: &[u8],
329        _private_key: &[u8],
330    ) -> Result<KeyEntry> {
331        Err(HsmError::Unsupported(format!(
332            "key import is not supported by {}",
333            self.get_version()?
334        ))
335        .into())
336    }
337
338    /// Sign a message.
339    fn sign(
340        &self,
341        alias: Option<&str>,
342        key_hash: Option<&str>,
343        domain: SpxDomain,
344        message: &[u8],
345    ) -> Result<Vec<u8>> {
346        match domain {
347            SpxDomain::Pure => {}
348            _ => {
349                return Err(HsmError::Unsupported(format!(
350                    "domain {domain} not supported by {}",
351                    self.get_version()?
352                ))
353                .into());
354            }
355        };
356        let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
357        if key_hash.is_some() {
358            log::warn!("ignored key_hash {key_hash:?}");
359        }
360        let key = self.get_key_version(alias)?;
361        let url = self
362            .keyring
363            .join(&format!("/v1/{}:asymmetricSign", key.name))?;
364        let req = KmsSignRequest {
365            digest: None,
366            data: Some(Base64::encode_string(message)),
367        };
368        let resp = self.post::<IndexMap<String, String>>(url, &req)?;
369        let signature = Base64::decode_vec(&resp["signature"])?;
370        Ok(signature)
371    }
372
373    /// Verify a message.
374    fn verify(
375        &self,
376        alias: Option<&str>,
377        key_hash: Option<&str>,
378        domain: SpxDomain,
379        message: &[u8],
380        signature: &[u8],
381    ) -> Result<bool> {
382        let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
383        if key_hash.is_some() {
384            log::warn!("ignored key_hash {key_hash:?}");
385        }
386        let info = self.get_key_info(alias)?;
387        let pk =
388            SpxPublicKey::from_bytes(SphincsPlus::from_str(&info.algorithm)?, &info.public_key)?;
389        match pk.verify(domain, signature, message) {
390            Ok(_) => Ok(true),
391            Err(_) => Ok(false),
392        }
393    }
394}