1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright lowRISC contributors (OpenTitan project).
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

use anyhow::{Context, Result};
use cryptoki::context::CInitializeArgs;
use cryptoki::context::Pkcs11;
use cryptoki::session::Session;
use cryptoki::session::UserType;
use cryptoki::slot::Slot;
use cryptoki::types::AuthPin;
use serde::de::{Deserialize, Deserializer};
use std::rc::Rc;
use std::str::FromStr;

use crate::error::HsmError;
use crate::spxef::SpxEf;
use acorn::{Acorn, SpxInterface};

/// Selects which SPHINCS+ implementation `Module::initialize_spx` installs.
///
/// The textual forms accepted by the `FromStr` impl are listed in
/// [`SpxModule::HELP`].
#[derive(Clone, Debug)]
pub enum SpxModule {
    /// Acorn backend, loaded from the shared-library path carried here.
    Acorn(String),
    /// Backend built on the module's open PKCS#11 session (via `SpxEf`).
    Pkcs11Ef,
}

impl SpxModule {
    /// Human-readable description of the accepted values, e.g. for CLI help output.
    pub const HELP: &'static str =
        "Type of sphincs+ module [allowed values: acorn:<libpath>, pkcs11-ef]";
}

impl FromStr for SpxModule {
    type Err = HsmError;

    /// Parses a SPHINCS+ module selector.
    ///
    /// Accepted forms (the keyword/prefix is matched ASCII case-insensitively):
    /// - `acorn:<libpath>` yields [`SpxModule::Acorn`] holding `<libpath>` verbatim
    ///   (possibly empty).
    /// - `pkcs11-ef` yields [`SpxModule::Pkcs11Ef`].
    ///
    /// # Errors
    /// Returns [`HsmError::ParseError`] for any other input, including inputs
    /// shorter than the `acorn:` prefix and non-ASCII inputs.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ACORN_PREFIX: &str = "acorn:";

        // `str::get` returns `None` when `s` is shorter than the prefix or when
        // the cut would land inside a multi-byte UTF-8 character, so unlike a
        // direct `s[..6]` slice this can never panic.
        let has_acorn_prefix = matches!(
            s.get(..ACORN_PREFIX.len()),
            Some(head) if head.eq_ignore_ascii_case(ACORN_PREFIX)
        );

        if has_acorn_prefix {
            // The `get` above succeeded, so this index is a valid char boundary.
            Ok(SpxModule::Acorn(s[ACORN_PREFIX.len()..].into()))
        } else if s.eq_ignore_ascii_case("pkcs11-ef") {
            Ok(SpxModule::Pkcs11Ef)
        } else {
            Err(HsmError::ParseError(format!("unknown SpxModule {s:?}")))
        }
    }
}

/// A loaded PKCS#11 library together with the optional state built on top of
/// it: an open session, a SPHINCS+ backend and the label of the connected token.
///
/// NOTE: Rust drops struct fields in declaration order, so this field order
/// also fixes the teardown order of the library handle and the session.
pub struct Module {
    /// The loaded PKCS#11 library; `Module::initialize` also runs its initializer.
    pub pkcs11: Pkcs11,
    /// Session stored by `Module::connect`; `None` until then. Wrapped in `Rc`
    /// so a clone can be handed to the `SpxEf` backend in `initialize_spx`.
    pub session: Option<Rc<Session>>,
    /// SPHINCS+ backend stored by `Module::initialize_spx`. Despite the field
    /// name, it may hold either an Acorn backend or an `SpxEf` backend.
    pub acorn: Option<Box<dyn SpxInterface>>,
    /// Label of the token passed to the last successful `Module::connect`.
    pub token: Option<String>,
}

impl Module {
    /// Loads the PKCS#11 library at path `module` and initializes it with
    /// `CInitializeArgs::OsThreads`.
    ///
    /// The returned handle has no session, no SPHINCS+ backend and no token.
    ///
    /// # Errors
    /// Fails if the library cannot be loaded or its initializer fails.
    pub fn initialize(module: &str) -> Result<Self> {
        let library = Pkcs11::new(module)?;
        library.initialize(CInitializeArgs::OsThreads)?;
        Ok(Self {
            pkcs11: library,
            session: None,
            acorn: None,
            token: None,
        })
    }

    /// Builds the SPHINCS+ backend selected by `module` and stores it in
    /// `self.acorn`, replacing any backend stored earlier.
    ///
    /// # Errors
    /// - Propagates failures from `Acorn::new` for [`SpxModule::Acorn`].
    /// - Returns `HsmError::SessionRequired` for [`SpxModule::Pkcs11Ef`] when
    ///   `self.session` is `None`.
    pub fn initialize_spx(&mut self, module: &SpxModule) -> Result<()> {
        let backend = match module {
            SpxModule::Acorn(libpath) => Acorn::new(libpath)? as Box<dyn SpxInterface>,
            SpxModule::Pkcs11Ef => {
                // The EF backend receives a clone of this module's session `Rc`.
                let shared = match &self.session {
                    Some(session) => Rc::clone(session),
                    None => return Err(HsmError::SessionRequired.into()),
                };
                SpxEf::new(shared) as Box<dyn SpxInterface>
            }
        };
        self.acorn = Some(backend);
        Ok(())
    }

    /// Borrows the currently held session, or `None` if there is none.
    pub fn get_session(&self) -> Option<&Session> {
        self.session.as_deref()
    }

    /// Returns the slot whose token label equals `label` exactly.
    ///
    /// Slots are checked in the order `get_slots_with_token` reports them, and
    /// the first match is returned.
    ///
    /// # Errors
    /// Propagates PKCS#11 errors from listing slots or reading token info, and
    /// returns `HsmError::TokenNotFound` when no token carries `label`.
    pub fn get_token(&self, label: &str) -> Result<Slot> {
        for candidate in self.pkcs11.get_slots_with_token()? {
            let token_info = self.pkcs11.get_token_info(candidate)?;
            if label == token_info.label() {
                return Ok(candidate);
            }
        }
        Err(HsmError::TokenNotFound(label.into()).into())
    }

    /// Opens a read/write session on the token labelled `token`. If `user` is
    /// given, it also logs in with the optional `pin`. On success the token
    /// label and session are stored on `self`.
    ///
    /// `self` is modified only after every step has succeeded; if login fails,
    /// the freshly opened session is dropped.
    ///
    /// # Errors
    /// Propagates token-lookup and session-open errors. A failed login is
    /// wrapped with the context "Failed HSM Login".
    pub fn connect(
        &mut self,
        token: &str,
        user: Option<UserType>,
        pin: Option<&str>,
    ) -> Result<()> {
        let session = self.pkcs11.open_rw_session(self.get_token(token)?)?;
        if let Some(user_type) = user {
            let auth = pin.map(|secret| AuthPin::new(secret.to_owned()));
            session
                .login(user_type, auth.as_ref())
                .context("Failed HSM Login")?;
        }
        self.token = Some(token.to_owned());
        self.session = Some(Rc::new(session));
        Ok(())
    }
}

/// Converts a user-type name into a PKCS#11 [`UserType`].
///
/// Only the exact spellings below are accepted; there is no general case folding:
/// - security officer: `So`, `SO`, `so`, `security_officer`
/// - normal user: `User`, `USER`, `user`
///
/// # Errors
/// Returns `HsmError::UnknownUser` carrying `val` for any other string.
pub fn parse_user_type(val: &str) -> Result<UserType> {
    const SO_NAMES: [&str; 4] = ["So", "SO", "so", "security_officer"];
    const USER_NAMES: [&str; 3] = ["User", "USER", "user"];

    if SO_NAMES.contains(&val) {
        Ok(UserType::So)
    } else if USER_NAMES.contains(&val) {
        Ok(UserType::User)
    } else {
        Err(HsmError::UnknownUser(val.into()).into())
    }
}

pub fn deserialize_user<'de, D>(deserializer: D) -> std::result::Result<UserType, D::Error>
where
    D: Deserializer<'de>,
{
    let user = String::deserialize(deserializer)?;
    parse_user_type(&user).map_err(serde::de::Error::custom)
}