1use acorn::{GenerateFlags, KeyEntry, KeyInfo, SpxInterface};
6use anyhow::{Context, Result, anyhow};
7use base64ct::{Base64, Encoding};
8use indexmap::IndexMap;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::Value;
11use sphincsplus::{SphincsPlus, SpxDomain, SpxPublicKey};
12use std::process::Command;
13use std::str::FromStr;
14use thiserror::Error;
15use zeroize::Zeroizing;
16
17use reqwest::blocking::Client;
18use reqwest::{IntoUrl, Url};
19
20use crate::error::HsmError;
21
22pub struct SpxKms {
24 keyring: Url,
25 project: String,
26 auth: Zeroizing<String>,
27}
28
29#[derive(Deserialize, Debug, Error)]
31#[error("api error: code={code} message={message:?}; details={details:?}")]
32#[serde(rename_all = "camelCase")]
33pub struct ApiError {
34 pub code: u32,
35 pub message: String,
36 pub status: String,
37 #[serde(flatten)]
38 pub details: IndexMap<String, Value>,
39}
40
41#[derive(Deserialize, Debug)]
44enum CloudResult<T> {
45 #[serde(rename = "error")]
46 Error(ApiError),
47 #[serde(untagged)]
48 Ok(T),
49}
50
51#[derive(Deserialize, Debug, Clone)]
52#[serde(rename_all = "camelCase")]
53struct KmsKeyList {
54 crypto_keys: Vec<KmsKeyRef>,
55}
56
57#[allow(dead_code)]
60#[derive(Serialize, Deserialize, Debug, Clone)]
61#[serde(rename_all = "camelCase")]
62struct VersionTemplate {
63 #[serde(default)]
64 protection_level: String,
65 #[serde(default)]
66 algorithm: String,
67}
68
69#[derive(Serialize, Deserialize, Debug, Clone)]
70#[serde(rename_all = "camelCase")]
71struct KmsCreateKey {
72 purpose: String,
73 version_template: VersionTemplate,
74}
75
76#[derive(Deserialize, Debug, Clone)]
77#[serde(rename_all = "camelCase")]
78struct KmsKeyRef {
79 name: String,
80 version_template: VersionTemplate,
81}
82
83#[derive(Deserialize, Debug, Clone)]
84#[serde(rename_all = "camelCase")]
85struct KmsKeyVersion {
86 name: String,
87 state: String,
88 #[serde(default)]
89 algorithm: String,
90}
91
92#[derive(Deserialize, Debug, Clone)]
93#[serde(rename_all = "camelCase")]
94struct KmsKeyVersions {
95 crypto_key_versions: Vec<KmsKeyVersion>,
96}
97
98#[derive(Deserialize, Debug, Clone)]
99#[serde(rename_all = "camelCase")]
100struct KmsPublicKeyData {
101 data: String,
102}
103
104#[allow(dead_code)]
107#[derive(Deserialize, Debug, Clone)]
108#[serde(rename_all = "camelCase")]
109struct KmsPublicKey {
110 algorithm: String,
111 name: String,
112 #[serde(default)]
113 protection_level: String,
114 #[serde(default)]
115 public_key_format: String,
116 pem: Option<String>,
117 public_key: Option<KmsPublicKeyData>,
118}
119
120#[derive(Serialize, Debug)]
121struct KmsDigest {
122 sha256: String,
123}
124
125#[derive(Serialize, Debug)]
126struct KmsSignRequest {
127 #[serde(skip_serializing_if = "Option::is_none")]
128 digest: Option<KmsDigest>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 data: Option<String>,
131}
132
133impl SpxKms {
134 const ALGORITHM: &'static str = "PQ_SIGN_SLH_DSA_SHA2_128S";
135
136 pub fn new(parameters: &str) -> Result<Box<Self>> {
137 let output = Command::new("gcloud")
138 .args(["auth", "print-access-token"])
139 .output()?;
140 if output.status.success() {
141 let mut auth = String::from_utf8(output.stdout)?;
143 let len = auth.trim_end().len();
144 auth.truncate(len);
145
146 let mut params = IndexMap::new();
147 params.extend(parameters.split(':').map(|p| {
148 p.split_once('=')
149 .expect("KMS parameters should be key=value")
150 }));
151
152 let project = params.get("project").ok_or(HsmError::Unsupported(
153 "KMS requires a project parameter".into(),
154 ))?;
155 let location = params.get("location").ok_or(HsmError::Unsupported(
156 "KMS requires a location parameter".into(),
157 ))?;
158 let keyring = params.get("keyring").ok_or(HsmError::Unsupported(
159 "KMS requires a keyring parameter".into(),
160 ))?;
161 let url = format!(
162 "https://cloudkms.googleapis.com/v1/projects/{project}/locations/{location}/keyRings/{keyring}/"
163 );
164 log::info!("keyring url: {url}");
165 Ok(Box::new(Self {
166 keyring: Url::parse(&url)?,
167 project: project.to_string(),
168 auth: auth.into(),
169 }))
170 } else {
171 let stderr = String::from_utf8_lossy(&output.stderr);
172 Err(anyhow!("gcloud error {:?}: {}", output.status, stderr))
173 }
174 }
175
176 fn get<RSP: DeserializeOwned>(&self, url: impl IntoUrl) -> Result<RSP> {
177 let client = Client::new();
178 log::debug!("GET {}", url.as_str());
179 let resp = client
180 .get(url)
181 .bearer_auth(&*self.auth)
182 .header("content-type", "application/json")
183 .header("X-Goog-User-Project", &self.project)
184 .send()?;
185 let data = resp.text()?;
186 log::debug!("data: {data}");
187 match serde_json::from_str::<CloudResult<RSP>>(&data)? {
188 CloudResult::Error(e) => Err(e.into()),
189 CloudResult::Ok(v) => Ok(v),
190 }
191 }
192
193 fn post<RSP: DeserializeOwned>(&self, url: impl IntoUrl, req: &impl Serialize) -> Result<RSP> {
194 let client = Client::new();
195 log::debug!("POST {}", url.as_str());
196 let resp = client
197 .post(url)
198 .bearer_auth(&*self.auth)
199 .header("content-type", "application/json")
200 .header("X-Goog-User-Project", &self.project)
201 .json(req)
202 .send()?;
203 let data = resp.text()?;
204 log::debug!("data: {data}");
205 match serde_json::from_str::<CloudResult<RSP>>(&data)? {
206 CloudResult::Error(e) => Err(e.into()),
207 CloudResult::Ok(v) => Ok(v),
208 }
209 }
210
211 fn get_key_version(&self, alias: &str) -> Result<KmsKeyVersion> {
212 let url = self
213 .keyring
214 .join(&format!("cryptoKeys/{alias}/cryptoKeyVersions"))?;
215 let versions = self.get::<KmsKeyVersions>(url)?;
216 match versions
217 .crypto_key_versions
218 .iter()
219 .filter(|v| v.state == "ENABLED" && v.algorithm == Self::ALGORITHM)
220 .next_back()
221 {
222 Some(key) => Ok(key.clone()),
223 None => Err(HsmError::ObjectNotFound(alias.into()).into()),
224 }
225 }
226
227 fn get_public_key(&self, alias: &str) -> Result<KmsPublicKey> {
228 let key = self.get_key_version(alias)?;
229 let url = self.keyring.join(&format!("/v1/{}/publicKey", key.name))?;
230 self.get(url)
231 }
232}
233
234impl SpxInterface for SpxKms {
235 fn get_version(&self) -> Result<String> {
237 Ok(String::from("CloudKMS 0.0.1"))
238 }
239
240 fn list_keys(&self) -> Result<Vec<KeyEntry>> {
242 let keys = self.keyring.join("cryptoKeys")?;
243 let keys = self.get::<KmsKeyList>(keys)?;
244 let mut result = Vec::new();
245
246 for k in keys.crypto_keys.iter() {
247 let (_, name) = k
248 .name
249 .rsplit_once('/')
250 .ok_or_else(|| HsmError::ParseError("could not parse key name".into()))
251 .with_context(|| format!("key name {:?}", k.name))?;
252 if k.version_template.algorithm != Self::ALGORITHM {
253 continue;
254 }
255 let key = self.get_key_version(name)?;
256 result.push(KeyEntry {
257 alias: name.into(),
258 hash: None,
259 algorithm: key.algorithm.clone(),
260 ..Default::default()
261 });
262 }
263 Ok(result)
264 }
265
266 fn get_key_info(&self, alias: &str) -> Result<KeyInfo> {
268 let key = self.get_public_key(alias)?;
269 let algorithm = key
270 .algorithm
271 .trim_start_matches("PQ_SIGN_")
272 .replace('_', "-");
273 let data = if let Some(pem) = &key.pem {
274 pem.as_str()
275 } else if let Some(public_key) = &key.public_key {
276 public_key.data.as_str()
277 } else {
278 return Err(HsmError::Unsupported("did not find public key material".into()).into());
279 };
280 Ok(KeyInfo {
281 hash: "".into(),
282 algorithm,
283 public_key: Base64::decode_vec(data)?,
284 private_blob: Vec::new(),
285 })
286 }
287
288 fn generate_key(
290 &self,
291 alias: &str,
292 _algorithm: &str,
293 _token: &str,
294 flags: GenerateFlags,
295 ) -> Result<KeyEntry> {
296 if flags.contains(GenerateFlags::EXPORT_PRIVATE) {
297 return Err(HsmError::Unsupported("export of private key material".into()).into());
298 }
299 let url = self
300 .keyring
301 .join(&format!("cryptoKeys?crypto_key_id={alias}"))?;
302 let template = KmsCreateKey {
303 purpose: "ASYMMETRIC_SIGN".into(),
304 version_template: VersionTemplate {
305 algorithm: Self::ALGORITHM.into(),
306 protection_level: "SOFTWARE".into(),
307 },
308 };
309 let resp = self.post::<KmsKeyRef>(url, &template)?;
310 Ok(KeyEntry {
311 alias: alias.into(),
312 algorithm: resp
313 .version_template
314 .algorithm
315 .trim_start_matches("PQ_SIGN_")
316 .replace('_', "-"),
317 ..Default::default()
318 })
319 }
320
321 fn import_keypair(
323 &self,
324 _alias: &str,
325 _algorithm: &str,
326 _token: &str,
327 _overwrite: bool,
328 _public_key: &[u8],
329 _private_key: &[u8],
330 ) -> Result<KeyEntry> {
331 Err(HsmError::Unsupported(format!(
332 "key import is not supported by {}",
333 self.get_version()?
334 ))
335 .into())
336 }
337
338 fn sign(
340 &self,
341 alias: Option<&str>,
342 key_hash: Option<&str>,
343 domain: SpxDomain,
344 message: &[u8],
345 ) -> Result<Vec<u8>> {
346 match domain {
347 SpxDomain::Pure => {}
348 _ => {
349 return Err(HsmError::Unsupported(format!(
350 "domain {domain} not supported by {}",
351 self.get_version()?
352 ))
353 .into());
354 }
355 };
356 let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
357 if key_hash.is_some() {
358 log::warn!("ignored key_hash {key_hash:?}");
359 }
360 let key = self.get_key_version(alias)?;
361 let url = self
362 .keyring
363 .join(&format!("/v1/{}:asymmetricSign", key.name))?;
364 let req = KmsSignRequest {
365 digest: None,
366 data: Some(Base64::encode_string(message)),
367 };
368 let resp = self.post::<IndexMap<String, String>>(url, &req)?;
369 let signature = Base64::decode_vec(&resp["signature"])?;
370 Ok(signature)
371 }
372
373 fn verify(
375 &self,
376 alias: Option<&str>,
377 key_hash: Option<&str>,
378 domain: SpxDomain,
379 message: &[u8],
380 signature: &[u8],
381 ) -> Result<bool> {
382 let alias = alias.ok_or(HsmError::NoSearchCriteria)?;
383 if key_hash.is_some() {
384 log::warn!("ignored key_hash {key_hash:?}");
385 }
386 let info = self.get_key_info(alias)?;
387 let pk =
388 SpxPublicKey::from_bytes(SphincsPlus::from_str(&info.algorithm)?, &info.public_key)?;
389 match pk.verify(domain, signature, message) {
390 Ok(_) => Ok(true),
391 Err(_) => Ok(false),
392 }
393 }
394}