1use anyhow::{Context, Result};
6use cryptoki::context::CInitializeArgs;
7use cryptoki::context::CInitializeFlags;
8use cryptoki::context::Pkcs11;
9use cryptoki::session::Session;
10use cryptoki::session::UserType;
11use cryptoki::slot::Slot;
12use cryptoki::types::AuthPin;
13use serde::de::{Deserialize, Deserializer};
14use std::rc::Rc;
15use std::str::FromStr;
16
17use crate::error::HsmError;
18use crate::extra::{SpxEf, SpxKms};
19use acorn::{Acorn, SpxInterface};
20
/// Selects which SPHINCS+ implementation backs a [`Module`].
///
/// Built from text via its `FromStr` implementation; `SpxModule::HELP`
/// lists the accepted spellings.
#[derive(Clone, Debug)]
pub enum SpxModule {
    /// Acorn implementation, loaded from the contained library path.
    Acorn(String),
    /// PKCS#11-backed implementation that reuses the module's open session.
    Pkcs11Ef,
    /// Cloud KMS implementation, configured by the contained keyring parameters.
    CloudKms(String),
}
27
28impl SpxModule {
29 pub const HELP: &'static str = "Type of sphincs+ module [allowed values: acorn:<libpath>, cloud-kms:<keyring-params>, pkcs11-ef]";
30}
31
32impl FromStr for SpxModule {
33 type Err = HsmError;
34 fn from_str(s: &str) -> Result<Self, Self::Err> {
35 if s[..6].eq_ignore_ascii_case("acorn:") {
36 Ok(SpxModule::Acorn(s[6..].into()))
37 } else if s.eq_ignore_ascii_case("pkcs11-ef") {
38 Ok(SpxModule::Pkcs11Ef)
39 } else if s[..10].eq_ignore_ascii_case("cloud-kms:") {
40 Ok(SpxModule::CloudKms(s[10..].into()))
41 } else {
42 Err(HsmError::ParseError(format!("unknown SpxModule {s:?}")))
43 }
44 }
45}
46
/// Handle to a loaded PKCS#11 library, plus the session, token label and
/// SPHINCS+ backend that may be attached to it later.
// NOTE(review): fields are dropped in declaration order; keep this order
// unless the PKCS#11 teardown sequence has been verified for a new one.
pub struct Module {
    /// The loaded PKCS#11 library, initialized by `Module::initialize`.
    pub pkcs11: Pkcs11,
    /// Read/write session opened by `Module::connect`; `None` until then.
    /// Held in an `Rc` so a `SpxModule::Pkcs11Ef` backend can share it.
    pub session: Option<Rc<Session>>,
    /// SPHINCS+ backend chosen by `Module::initialize_spx`; `None` until then.
    pub spx: Option<Box<dyn SpxInterface>>,
    /// Label of the token used by the last successful `Module::connect`.
    pub token: Option<String>,
}
53
54impl Module {
55 pub fn initialize(module: &str) -> Result<Self> {
56 let pkcs11 = Pkcs11::new(module)?;
57 pkcs11.initialize(CInitializeArgs::new(CInitializeFlags::OS_LOCKING_OK))?;
58 Ok(Module {
59 pkcs11,
60 session: None,
61 spx: None,
62 token: None,
63 })
64 }
65
66 pub fn initialize_spx(&mut self, module: &SpxModule) -> Result<()> {
67 let module = match module {
68 SpxModule::Acorn(libpath) => Acorn::new(libpath)? as Box<dyn SpxInterface>,
69 SpxModule::CloudKms(keyring) => SpxKms::new(keyring)? as Box<dyn SpxInterface>,
70 SpxModule::Pkcs11Ef => {
71 let session = self.session.clone().ok_or(HsmError::SessionRequired)?;
72 SpxEf::new(session) as Box<dyn SpxInterface>
73 }
74 };
75 self.spx = Some(module);
76 Ok(())
77 }
78
79 pub fn get_session(&self) -> Option<&Session> {
80 self.session.as_ref().map(Rc::as_ref)
81 }
82
83 pub fn get_token(&self, label: &str) -> Result<Slot> {
84 let slots = self.pkcs11.get_slots_with_token()?;
85 for slot in slots {
86 let info = self.pkcs11.get_token_info(slot)?;
87 if label == info.label() {
88 return Ok(slot);
89 }
90 }
91 Err(HsmError::TokenNotFound(label.into()).into())
92 }
93
94 pub fn connect(
95 &mut self,
96 token: &str,
97 user: Option<UserType>,
98 pin: Option<&str>,
99 ) -> Result<()> {
100 let slot = self.get_token(token)?;
101 let session = self.pkcs11.open_rw_session(slot)?;
102 if let Some(user) = user {
103 let pin = pin.map(|x| AuthPin::new(x.to_owned().into()));
104 session
105 .login(user, pin.as_ref())
106 .context("Failed HSM Login")?;
107 }
108 self.token = Some(token.into());
109 self.session = Some(Rc::new(session));
110 Ok(())
111 }
112}
113
114pub fn parse_user_type(val: &str) -> Result<UserType> {
115 match val {
116 "So" | "SO" | "so" | "security_officer" => Ok(UserType::So),
117 "User" | "USER" | "user" => Ok(UserType::User),
118 _ => Err(HsmError::UnknownUser(val.into()).into()),
119 }
120}
121
122pub fn deserialize_user<'de, D>(deserializer: D) -> std::result::Result<UserType, D::Error>
123where
124 D: Deserializer<'de>,
125{
126 let user = String::deserialize(deserializer)?;
127 parse_user_type(&user).map_err(serde::de::Error::custom)
128}