1use anyhow::{Context, Result};
6use cryptoki::context::CInitializeArgs;
7use cryptoki::context::Pkcs11;
8use cryptoki::session::Session;
9use cryptoki::session::UserType;
10use cryptoki::slot::Slot;
11use cryptoki::types::AuthPin;
12use serde::de::{Deserialize, Deserializer};
13use std::rc::Rc;
14use std::str::FromStr;
15
16use crate::error::HsmError;
17use crate::extra::{SpxEf, SpxKms};
18use acorn::{Acorn, SpxInterface};
19
/// Selects which SPHINCS+ implementation `Module::initialize_spx` sets up.
///
/// Values are normally produced by the `FromStr` impl; `SpxModule::HELP`
/// lists the accepted spellings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SpxModule {
    /// Acorn library backend; the payload is the library path handed to `Acorn::new`.
    Acorn(String),
    /// SPHINCS+ via the PKCS#11 session (`SpxEf`); needs an open session on the `Module`.
    Pkcs11Ef,
    /// Cloud KMS backend; the payload is handed to `SpxKms::new` as keyring parameters.
    CloudKms(String),
}
26
27impl SpxModule {
28 pub const HELP: &'static str = "Type of sphincs+ module [allowed values: acorn:<libpath>, cloud-kms:<keyring-params>, pkcs11-ef]";
29}
30
31impl FromStr for SpxModule {
32 type Err = HsmError;
33 fn from_str(s: &str) -> Result<Self, Self::Err> {
34 if s[..6].eq_ignore_ascii_case("acorn:") {
35 Ok(SpxModule::Acorn(s[6..].into()))
36 } else if s.eq_ignore_ascii_case("pkcs11-ef") {
37 Ok(SpxModule::Pkcs11Ef)
38 } else if s[..10].eq_ignore_ascii_case("cloud-kms:") {
39 Ok(SpxModule::CloudKms(s[10..].into()))
40 } else {
41 Err(HsmError::ParseError(format!("unknown SpxModule {s:?}")))
42 }
43 }
44}
45
/// A PKCS#11 library handle plus the optional session, SPHINCS+ backend and
/// token label established through it.
///
/// NOTE(review): Rust drops fields in declaration order, so `pkcs11` is dropped
/// before `session` and `spx`. Keep this order unless the cryptoki version in
/// use is confirmed to keep the library alive while sessions still exist.
pub struct Module {
    /// PKCS#11 library handle; `Module::initialize` loads it and initializes it
    /// with `CInitializeArgs::OsThreads`.
    pub pkcs11: Pkcs11,
    /// Session opened by `Module::connect`. It is held in an `Rc` so that the
    /// `SpxModule::Pkcs11Ef` backend built by `Module::initialize_spx` can share it.
    pub session: Option<Rc<Session>>,
    /// SPHINCS+ backend installed by `Module::initialize_spx`, if any.
    pub spx: Option<Box<dyn SpxInterface>>,
    /// Label of the token passed to the last successful `Module::connect`.
    pub token: Option<String>,
}
52
53impl Module {
54 pub fn initialize(module: &str) -> Result<Self> {
55 let pkcs11 = Pkcs11::new(module)?;
56 pkcs11.initialize(CInitializeArgs::OsThreads)?;
57 Ok(Module {
58 pkcs11,
59 session: None,
60 spx: None,
61 token: None,
62 })
63 }
64
65 pub fn initialize_spx(&mut self, module: &SpxModule) -> Result<()> {
66 let module = match module {
67 SpxModule::Acorn(libpath) => Acorn::new(libpath)? as Box<dyn SpxInterface>,
68 SpxModule::CloudKms(keyring) => SpxKms::new(keyring)? as Box<dyn SpxInterface>,
69 SpxModule::Pkcs11Ef => {
70 let session = self.session.clone().ok_or(HsmError::SessionRequired)?;
71 SpxEf::new(session) as Box<dyn SpxInterface>
72 }
73 };
74 self.spx = Some(module);
75 Ok(())
76 }
77
78 pub fn get_session(&self) -> Option<&Session> {
79 self.session.as_ref().map(Rc::as_ref)
80 }
81
82 pub fn get_token(&self, label: &str) -> Result<Slot> {
83 let slots = self.pkcs11.get_slots_with_token()?;
84 for slot in slots {
85 let info = self.pkcs11.get_token_info(slot)?;
86 if label == info.label() {
87 return Ok(slot);
88 }
89 }
90 Err(HsmError::TokenNotFound(label.into()).into())
91 }
92
93 pub fn connect(
94 &mut self,
95 token: &str,
96 user: Option<UserType>,
97 pin: Option<&str>,
98 ) -> Result<()> {
99 let slot = self.get_token(token)?;
100 let session = self.pkcs11.open_rw_session(slot)?;
101 if let Some(user) = user {
102 let pin = pin.map(|x| AuthPin::new(x.to_owned()));
103 session
104 .login(user, pin.as_ref())
105 .context("Failed HSM Login")?;
106 }
107 self.token = Some(token.into());
108 self.session = Some(Rc::new(session));
109 Ok(())
110 }
111}
112
113pub fn parse_user_type(val: &str) -> Result<UserType> {
114 match val {
115 "So" | "SO" | "so" | "security_officer" => Ok(UserType::So),
116 "User" | "USER" | "user" => Ok(UserType::User),
117 _ => Err(HsmError::UnknownUser(val.into()).into()),
118 }
119}
120
121pub fn deserialize_user<'de, D>(deserializer: D) -> std::result::Result<UserType, D::Error>
122where
123 D: Deserializer<'de>,
124{
125 let user = String::deserialize(deserializer)?;
126 parse_user_type(&user).map_err(serde::de::Error::custom)
127}