1use acorn::{GenerateFlags, KeyEntry, KeyInfo, SpxInterface};
6use anyhow::{Context, Result, anyhow};
7use base64ct::{Base64, Encoding};
8use indexmap::IndexMap;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::Value;
11use sphincsplus::{SphincsPlus, SpxDomain, SpxPublicKey};
12use std::process::Command;
13use std::str::FromStr;
14use thiserror::Error;
15use zeroize::Zeroizing;
16
17use reqwest::blocking::Client;
18use reqwest::{IntoUrl, Url};
19
20use crate::error::HsmError;
21
22pub struct SpxKms {
24 keyring: Url,
25 project: String,
26 auth: Zeroizing<String>,
27}
28
29#[derive(Deserialize, Debug, Error)]
31#[error("api error: code={code} message={message:?}; details={details:?}")]
32#[serde(rename_all = "camelCase")]
33pub struct ApiError {
34 pub code: u32,
35 pub message: String,
36 pub status: String,
37 #[serde(flatten)]
38 pub details: IndexMap<String, Value>,
39}
40
41#[derive(Deserialize, Debug)]
44enum CloudResult<T> {
45 #[serde(rename = "error")]
46 Error(ApiError),
47 #[serde(untagged)]
48 Ok(T),
49}
50
51#[derive(Deserialize, Debug, Clone)]
52#[serde(rename_all = "camelCase")]
53struct KmsKeyList {
54 crypto_keys: Vec<KmsKeyRef>,
55}
56
57#[allow(dead_code)]
60#[derive(Serialize, Deserialize, Debug, Clone)]
61#[serde(rename_all = "camelCase")]
62struct VersionTemplate {
63 #[serde(default)]
64 protection_level: String,
65 #[serde(default)]
66 algorithm: String,
67}
68
69#[derive(Serialize, Deserialize, Debug, Clone)]
70#[serde(rename_all = "camelCase")]
71struct KmsCreateKey {
72 purpose: String,
73 version_template: VersionTemplate,
74}
75
76#[derive(Deserialize, Debug, Clone)]
77#[serde(rename_all = "camelCase")]
78struct KmsKeyRef {
79 name: String,
80 version_template: VersionTemplate,
81}
82
83#[derive(Deserialize, Debug, Clone)]
84#[serde(rename_all = "camelCase")]
85struct KmsKeyVersion {
86 name: String,
87 state: String,
88 #[serde(default)]
89 algorithm: String,
90}
91
92#[derive(Deserialize, Debug, Clone)]
93#[serde(rename_all = "camelCase")]
94struct KmsKeyVersions {
95 crypto_key_versions: Vec<KmsKeyVersion>,
96}
97
98#[derive(Deserialize, Debug, Clone)]
99#[serde(rename_all = "camelCase")]
100struct KmsPublicKeyData {
101 data: String,
102}
103
104#[allow(dead_code)]
107#[derive(Deserialize, Debug, Clone)]
108#[serde(rename_all = "camelCase")]
109struct KmsPublicKey {
110 algorithm: String,
111 name: String,
112 #[serde(default)]
113 protection_level: String,
114 #[serde(default)]
115 public_key_format: String,
116 pem: Option<String>,
117 public_key: Option<KmsPublicKeyData>,
118}
119
120#[derive(Serialize, Debug)]
121struct KmsDigest {
122 sha256: String,
123}
124
125#[derive(Serialize, Debug, Default)]
126struct KmsSignRequest {
127 #[serde(skip_serializing_if = "Option::is_none")]
128 digest: Option<KmsDigest>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 data: Option<String>,
131}
132
133impl SpxKms {
134 const PURE_ALGORITHM: &'static str = "PQ_SIGN_SLH_DSA_SHA2_128S";
135 const PREHASH_ALGORITHM: &'static str = "PQ_SIGN_HASH_SLH_DSA_SHA2_128S_SHA256";
136
137 pub fn new(parameters: &str) -> Result<Box<Self>> {
138 let output = Command::new("gcloud")
139 .args(["auth", "print-access-token"])
140 .output()?;
141 if output.status.success() {
142 let mut auth = String::from_utf8(output.stdout)?;
144 let len = auth.trim_end().len();
145 auth.truncate(len);
146
147 let mut params = IndexMap::new();
148 params.extend(parameters.split(':').map(|p| {
149 p.split_once('=')
150 .expect("KMS parameters should be key=value")
151 }));
152
153 let project = params.get("project").ok_or(HsmError::Unsupported(
154 "KMS requires a project parameter".into(),
155 ))?;
156 let location = params.get("location").ok_or(HsmError::Unsupported(
157 "KMS requires a location parameter".into(),
158 ))?;
159 let keyring = params.get("keyring").ok_or(HsmError::Unsupported(
160 "KMS requires a keyring parameter".into(),
161 ))?;
162 let url = format!(
163 "https://cloudkms.googleapis.com/v1/projects/{project}/locations/{location}/keyRings/{keyring}/"
164 );
165 log::info!("keyring url: {url}");
166 Ok(Box::new(Self {
167 keyring: Url::parse(&url)?,
168 project: project.to_string(),
169 auth: auth.into(),
170 }))
171 } else {
172 let stderr = String::from_utf8_lossy(&output.stderr);
173 Err(anyhow!("gcloud error {:?}: {}", output.status, stderr))
174 }
175 }
176
177 fn kms_to_algorithm(kms_algo: &str) -> Result<String> {
178 match kms_algo {
179 Self::PURE_ALGORITHM | Self::PREHASH_ALGORITHM => Ok("SLH-DSA-SHA2-128S".into()),
180 _ => Err(HsmError::Unsupported(format!("algorithm {kms_algo}")).into()),
181 }
182 }
183
184 fn kms_to_domain(kms_algo: &str) -> Result<SpxDomain> {
185 match kms_algo {
186 Self::PURE_ALGORITHM => Ok(SpxDomain::Pure),
187 Self::PREHASH_ALGORITHM => Ok(SpxDomain::PreHashedSha256),
188 _ => Err(HsmError::Unsupported(format!("algorithm {kms_algo}")).into()),
189 }
190 }
191
192 fn get<RSP: DeserializeOwned>(&self, url: impl IntoUrl) -> Result<RSP> {
193 let client = Client::new();
194 log::debug!("GET {}", url.as_str());
195 let resp = client
196 .get(url)
197 .bearer_auth(&*self.auth)
198 .header("content-type", "application/json")
199 .header("X-Goog-User-Project", &self.project)
200 .send()?;
201 let data = resp.text()?;
202 log::debug!("data: {data}");
203 match serde_json::from_str::<CloudResult<RSP>>(&data)? {
204 CloudResult::Error(e) => Err(e.into()),
205 CloudResult::Ok(v) => Ok(v),
206 }
207 }
208
209 fn post<RSP: DeserializeOwned>(&self, url: impl IntoUrl, req: &impl Serialize) -> Result<RSP> {
210 let client = Client::new();
211 log::debug!("POST {}", url.as_str());
212 let resp = client
213 .post(url)
214 .bearer_auth(&*self.auth)
215 .header("content-type", "application/json")
216 .header("X-Goog-User-Project", &self.project)
217 .json(req)
218 .send()?;
219 let data = resp.text()?;
220 log::debug!("data: {data}");
221 match serde_json::from_str::<CloudResult<RSP>>(&data)? {
222 CloudResult::Error(e) => Err(e.into()),
223 CloudResult::Ok(v) => Ok(v),
224 }
225 }
226
227 fn get_key_version(&self, alias: &str) -> Result<KmsKeyVersion> {
228 let url = self
229 .keyring
230 .join(&format!("cryptoKeys/{alias}/cryptoKeyVersions"))?;
231 let versions = self.get::<KmsKeyVersions>(url)?;
232 match versions
233 .crypto_key_versions
234 .iter()
235 .filter(|v| v.state == "ENABLED" && Self::kms_to_algorithm(&v.algorithm).is_ok())
236 .next_back()
237 {
238 Some(key) => Ok(key.clone()),
239 None => Err(HsmError::ObjectNotFound(alias.into()).into()),
240 }
241 }
242
243 fn get_public_key(&self, alias: &str) -> Result<KmsPublicKey> {
244 let key = self.get_key_version(alias)?;
245 let mut url = self.keyring.join(&format!("/v1/{}/publicKey", key.name))?;
246 url.set_query(Some("public_key_format=NIST_PQC"));
247 self.get(url)
248 }
249}
250
251impl SpxInterface for SpxKms {
252 fn get_version(&self) -> Result<String> {
254 Ok(String::from("CloudKMS 0.0.1"))
255 }
256
257 fn list_keys(&self) -> Result<Vec<KeyEntry>> {
259 let keys = self.keyring.join("cryptoKeys")?;
260 let keys = self.get::<KmsKeyList>(keys)?;
261 let mut result = Vec::new();
262
263 for k in keys.crypto_keys.iter() {
264 let (_, name) = k
265 .name
266 .rsplit_once('/')
267 .ok_or_else(|| HsmError::ParseError("could not parse key name".into()))
268 .with_context(|| format!("key name {:?}", k.name))?;
269 if Self::kms_to_algorithm(&k.version_template.algorithm).is_err() {
270 continue;
271 }
272 let key = self.get_key_version(name)?;
273 result.push(KeyEntry {
274 alias: name.into(),
275 hash: None,
276 algorithm: Self::kms_to_algorithm(&key.algorithm)?,
277 domain: Some(Self::kms_to_domain(&key.algorithm)?),
278 ..Default::default()
279 });
280 }
281 Ok(result)
282 }
283
284 fn get_key_info(&self, alias: &str) -> Result<KeyInfo> {
286 let key = self.get_public_key(alias)?;
287 let data = if let Some(pem) = &key.pem {
288 pem.as_str()
289 } else if let Some(public_key) = &key.public_key {
290 public_key.data.as_str()
291 } else {
292 return Err(HsmError::Unsupported("did not find public key material".into()).into());
293 };
294 Ok(KeyInfo {
295 hash: "".into(),
296 algorithm: Self::kms_to_algorithm(&key.algorithm)?,
297 domain: Some(Self::kms_to_domain(&key.algorithm)?),
298 public_key: Base64::decode_vec(data)?,
299 private_blob: Vec::new(),
300 })
301 }
302
303 fn generate_key(
305 &self,
306 alias: &str,
307 _algorithm: &str,
308 domain: SpxDomain,
309 _token: &str,
310 flags: GenerateFlags,
311 ) -> Result<KeyEntry> {
312 if flags.contains(GenerateFlags::EXPORT_PRIVATE) {
313 return Err(HsmError::Unsupported("export of private key material".into()).into());
314 }
315 let algorithm = match domain {
316 SpxDomain::Pure => Self::PURE_ALGORITHM,
317 SpxDomain::PreHashedSha256 => Self::PREHASH_ALGORITHM,
318 _ => return Err(HsmError::Unsupported(format!("domain {domain}")).into()),
319 };
320 let url = self
321 .keyring
322 .join(&format!("cryptoKeys?crypto_key_id={alias}"))?;
323 let template = KmsCreateKey {
324 purpose: "ASYMMETRIC_SIGN".into(),
325 version_template: VersionTemplate {
326 algorithm: algorithm.into(),
327 protection_level: "SOFTWARE".into(),
328 },
329 };
330 let resp = self.post::<KmsKeyRef>(url, &template)?;
331 Ok(KeyEntry {
332 alias: alias.into(),
333 algorithm: Self::kms_to_algorithm(&resp.version_template.algorithm)?,
334 domain: Some(Self::kms_to_domain(&resp.version_template.algorithm)?),
335 ..Default::default()
336 })
337 }
338
339 fn import_keypair(
341 &self,
342 _alias: &str,
343 _algorithm: &str,
344 _domain: SpxDomain,
345 _token: &str,
346 _overwrite: bool,
347 _public_key: &[u8],
348 _private_key: &[u8],
349 ) -> Result<KeyEntry> {
350 Err(HsmError::Unsupported(format!(
351 "key import is not supported by {}",
352 self.get_version()?
353 ))
354 .into())
355 }
356
357 fn sign(
359 &self,
360 alias: Option<&str>,
361 key_hash: Option<&str>,
362 domain: SpxDomain,
363 message: &[u8],
364 ) -> Result<Vec<u8>> {
365 let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
366 if key_hash.is_some() {
367 log::warn!("ignored key_hash {key_hash:?}");
368 }
369 let key = self.get_key_version(alias)?;
370 let keydomain = Self::kms_to_domain(&key.algorithm)?;
371 if domain != keydomain {
372 return Err(HsmError::Unsupported(format!(
373 "domain {domain} not supported by key {alias}",
374 ))
375 .into());
376 }
377
378 let url = self
379 .keyring
380 .join(&format!("/v1/{}:asymmetricSign", key.name))?;
381
382 let req = match keydomain {
386 SpxDomain::Pure => KmsSignRequest {
387 data: Some(Base64::encode_string(message)),
388 ..Default::default()
389 },
390 SpxDomain::PreHashedSha256 => KmsSignRequest {
391 digest: Some(KmsDigest {
392 sha256: Base64::encode_string(message),
393 }),
394 ..Default::default()
395 },
396 _ => unreachable!(),
397 };
398
399 let resp = self.post::<IndexMap<String, String>>(url, &req)?;
400 let signature = Base64::decode_vec(&resp["signature"])?;
401 Ok(signature)
402 }
403
404 fn verify(
406 &self,
407 alias: Option<&str>,
408 key_hash: Option<&str>,
409 domain: SpxDomain,
410 message: &[u8],
411 signature: &[u8],
412 ) -> Result<bool> {
413 let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
414 if key_hash.is_some() {
415 log::warn!("ignored key_hash {key_hash:?}");
416 }
417 let info = self.get_key_info(alias)?;
418 let keydomain = info.domain.expect("kms key domain");
419 if domain != keydomain {
420 return Err(HsmError::Unsupported(format!(
421 "domain {domain} not supported by key {alias}",
422 ))
423 .into());
424 }
425 let pk =
426 SpxPublicKey::from_bytes(SphincsPlus::from_str(&info.algorithm)?, &info.public_key)?;
427 match pk.verify(domain, signature, message) {
428 Ok(_) => Ok(true),
429 Err(_) => Ok(false),
430 }
431 }
432}