use anyhow::Result;
use cryptoki::mechanism::Mechanism;
use rsa::pkcs1v15::Pkcs1v15Sign;
use serde::{Deserialize, Serialize};
use sha2::digest::const_oid::AssociatedOid;
use sha2::digest::Digest;
use sha2::Sha256;
use crate::error::HsmError;
use crate::util::attribute::KeyType;
/// Describes what the bytes passed to [`SignData::prepare`] represent. This determines
/// how those bytes are turned into the data handed to the HSM signing operation.
///
/// Selectable on the command line via `clap::ValueEnum` and (de)serialized with serde.
/// Each variant also accepts its kebab-case spelling as a serde alias.
#[derive(clap::ValueEnum, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SignData {
// Variant comments use `//` rather than `///` on purpose: clap's `ValueEnum`
// derive turns variant doc comments into CLI help text.
// Arbitrary message bytes; `prepare` SHA-256 hashes them first.
#[serde(alias = "plain-text")]
PlainText,
// A precomputed SHA-256 digest of the message.
#[serde(alias = "sha256-hash")]
Sha256Hash,
// Bytes that `prepare` returns unchanged.
#[serde(alias = "raw")]
Raw,
}
impl SignData {
pub fn prepare(&self, keytype: KeyType, input: &[u8]) -> Result<Vec<u8>> {
match keytype {
KeyType::Rsa => match self {
SignData::PlainText => Self::pkcs15sign::<Sha256>(&Self::data_plain_text(input)?),
SignData::Sha256Hash => Self::pkcs15sign::<Sha256>(input),
SignData::Raw => Self::data_raw(input),
},
KeyType::Ec => match self {
SignData::PlainText => Self::data_plain_text(input),
SignData::Sha256Hash => Self::data_raw(input),
SignData::Raw => Self::data_raw(input),
},
_ => Err(HsmError::Unsupported("SignData prepare for {keytype:?}".into()).into()),
}
}
pub fn mechanism(&self, keytype: KeyType) -> Result<Mechanism> {
match keytype {
KeyType::Rsa => match self {
SignData::PlainText => Ok(Mechanism::RsaPkcs),
SignData::Sha256Hash => Ok(Mechanism::RsaPkcs),
SignData::Raw => Err(HsmError::Unsupported(
"rust-cryptoki Mechanism doesn't include RSA_X_509".into(),
)
.into()),
},
KeyType::Ec => match self {
SignData::PlainText => Ok(Mechanism::Ecdsa),
SignData::Sha256Hash => Ok(Mechanism::Ecdsa),
SignData::Raw => Ok(Mechanism::Ecdsa),
},
_ => Err(HsmError::Unsupported("No mechanism for {keytype:?}".into()).into()),
}
}
fn data_raw(input: &[u8]) -> Result<Vec<u8>> {
let mut result = Vec::new();
result.extend_from_slice(input);
Ok(result)
}
fn pkcs15sign<D>(input: &[u8]) -> Result<Vec<u8>>
where
D: Digest + AssociatedOid,
{
let s = Pkcs1v15Sign::new::<D>();
let hash_len = s.hash_len.unwrap();
if hash_len != input.len() {
return Err(HsmError::HashSizeError(hash_len, input.len()).into());
}
let mut result = Vec::new();
result.extend_from_slice(&s.prefix);
result.extend_from_slice(input);
Ok(result)
}
fn data_plain_text(input: &[u8]) -> Result<Vec<u8>> {
Ok(Sha256::digest(input).as_slice().to_vec())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample message shared by the plain-text tests.
    const FOX: &[u8] = b"The quick brown fox jumped over the lazy dog";
    /// SHA-256 digest of [`FOX`], hex encoded.
    const FOX_SHA256: &str = "7d38b5cd25a2baf85ad3bb5b9311383e671a8a142eb302b324d4a5fba8748c69";
    /// PKCS#1 v1.5 SHA-256 `DigestInfo` wrapping [`FOX_SHA256`], hex encoded.
    const FOX_DIGEST_INFO: &str = "3031300d0609608648016503040201050004207d38b5cd25a2baf85ad3bb5b9311383e671a8a142eb302b324d4a5fba8748c69";

    #[test]
    fn test_raw() -> Result<()> {
        let result = SignData::Raw.prepare(KeyType::Rsa, b"abc123")?;
        assert_eq!(result, b"abc123");
        Ok(())
    }

    #[test]
    fn test_plain_text() -> Result<()> {
        let result = SignData::PlainText.prepare(KeyType::Rsa, FOX)?;
        assert_eq!(hex::encode(result), FOX_DIGEST_INFO);
        Ok(())
    }

    #[test]
    fn test_hashed() -> Result<()> {
        let input = hex::decode(FOX_SHA256)?;
        let result = SignData::Sha256Hash.prepare(KeyType::Rsa, &input)?;
        assert_eq!(hex::encode(result), FOX_DIGEST_INFO);
        assert!(SignData::Sha256Hash.prepare(KeyType::Rsa, b"").is_err());
        Ok(())
    }

    #[test]
    fn test_ec_plain_text() -> Result<()> {
        // EC signs the bare digest: no DigestInfo prefix.
        let result = SignData::PlainText.prepare(KeyType::Ec, FOX)?;
        assert_eq!(hex::encode(result), FOX_SHA256);
        Ok(())
    }

    #[test]
    fn test_ec_passthrough() -> Result<()> {
        let input = hex::decode(FOX_SHA256)?;
        assert_eq!(SignData::Sha256Hash.prepare(KeyType::Ec, &input)?, input);
        assert_eq!(SignData::Raw.prepare(KeyType::Ec, b"abc123")?, b"abc123");
        Ok(())
    }

    #[test]
    fn test_mechanism() -> Result<()> {
        assert!(matches!(
            SignData::PlainText.mechanism(KeyType::Rsa)?,
            Mechanism::RsaPkcs
        ));
        assert!(matches!(
            SignData::Sha256Hash.mechanism(KeyType::Rsa)?,
            Mechanism::RsaPkcs
        ));
        assert!(SignData::Raw.mechanism(KeyType::Rsa).is_err());
        for data in [SignData::PlainText, SignData::Sha256Hash, SignData::Raw] {
            assert!(matches!(data.mechanism(KeyType::Ec)?, Mechanism::Ecdsa));
        }
        Ok(())
    }
}