1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Copyright lowRISC contributors (OpenTitan project).
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

use anyhow::{Context, Result};
use cryptoki::context::CInitializeArgs;
use cryptoki::context::Pkcs11;
use cryptoki::session::Session;
use cryptoki::session::UserType;
use cryptoki::slot::Slot;
use cryptoki::types::AuthPin;
use serde::de::{Deserialize, Deserializer};

use crate::error::HsmError;
use acorn::Acorn;

/// A PKCS#11 library handle together with optional Acorn support and the
/// label of the currently connected token.
///
/// Normally constructed with [`Module::initialize`], which loads and
/// initializes the PKCS#11 library.
pub struct Module {
    /// Handle to the PKCS#11 library (loaded by [`Module::initialize`]).
    pub pkcs11: Pkcs11,
    /// `Acorn` instance built in [`Module::initialize`] when an `acorn`
    /// argument is supplied; `None` otherwise.
    pub acorn: Option<Acorn>,
    /// Label of the token most recently connected via [`Module::connect`];
    /// `None` until a connection succeeds.
    pub token: Option<String>,
}

impl Module {
    /// Loads the PKCS#11 library at `module` and initializes it.
    ///
    /// * `module` - path of the PKCS#11 shared library to load.
    /// * `acorn` - when `Some`, the value is passed to `Acorn::new` and the
    ///   resulting instance is stored in [`Module::acorn`]; when `None`, no
    ///   `Acorn` instance is created.
    ///
    /// The returned `Module` has no token recorded until [`Module::connect`]
    /// succeeds.
    ///
    /// # Errors
    ///
    /// Returns an error if the library cannot be loaded, if initializing it
    /// fails, or if `Acorn::new` fails.
    pub fn initialize(module: &str, acorn: Option<&str>) -> Result<Self> {
        let pkcs11 = Pkcs11::new(module)
            .with_context(|| format!("Failed to load PKCS#11 module {module:?}"))?;
        // `OsThreads` allows the library to use the operating system's native
        // threading primitives for locking.
        pkcs11
            .initialize(CInitializeArgs::OsThreads)
            .with_context(|| format!("Failed to initialize PKCS#11 module {module:?}"))?;
        let acorn = acorn.map(Acorn::new).transpose()?;
        Ok(Module {
            pkcs11,
            acorn,
            token: None,
        })
    }

    /// Returns the first slot whose present token has the label `label`.
    ///
    /// # Errors
    ///
    /// Returns `HsmError::TokenNotFound` if no slot's token matches. An error
    /// from enumerating slots, or from querying any slot's token info, aborts
    /// the search and is returned as-is.
    pub fn get_token(&self, label: &str) -> Result<Slot> {
        for slot in self.pkcs11.get_slots_with_token()? {
            if self.pkcs11.get_token_info(slot)?.label() == label {
                return Ok(slot);
            }
        }
        Err(HsmError::TokenNotFound(label.into()).into())
    }

    /// Opens a read/write session on the token labelled `token`.
    ///
    /// If `user` is `Some`, the session is logged in as that user with the
    /// optional `pin`. If `user` is `None`, no login happens and `pin` is
    /// ignored. On success, `token` is recorded in [`Module::token`].
    ///
    /// # Errors
    ///
    /// Returns an error if the token cannot be found, the session cannot be
    /// opened, or the login fails. In every failure case, `self.token` is
    /// left unchanged.
    pub fn connect(
        &mut self,
        token: &str,
        user: Option<UserType>,
        pin: Option<&str>,
    ) -> Result<Session> {
        let slot = self.get_token(token)?;
        let session = self
            .pkcs11
            .open_rw_session(slot)
            .with_context(|| format!("Failed to open HSM session on token {token:?}"))?;
        if let Some(user) = user {
            let pin = pin.map(|x| AuthPin::new(x.to_owned()));
            session
                .login(user, pin.as_ref())
                .with_context(|| format!("Failed HSM Login to token {token:?}"))?;
        }
        // Record the token only after the session is fully set up, so a
        // failed connect never leaves a stale label behind.
        self.token = Some(token.into());
        Ok(session)
    }
}

/// Parses a user-type name into a PKCS#11 [`UserType`].
///
/// Matching is ASCII case-insensitive:
/// * `"so"` or `"security_officer"` maps to [`UserType::So`];
/// * `"user"` maps to [`UserType::User`].
///
/// # Errors
///
/// Returns `HsmError::UnknownUser`, carrying the original (un-normalized)
/// input, for any other value.
pub fn parse_user_type(val: &str) -> Result<UserType> {
    // Normalize case so every capitalisation of an accepted name is treated
    // the same, e.g. "Security_Officer" or "uSeR".
    match val.to_ascii_lowercase().as_str() {
        "so" | "security_officer" => Ok(UserType::So),
        "user" => Ok(UserType::User),
        _ => Err(HsmError::UnknownUser(val.into()).into()),
    }
}

pub fn deserialize_user<'de, D>(deserializer: D) -> std::result::Result<UserType, D::Error>
where
    D: Deserializer<'de>,
{
    let user = String::deserialize(deserializer)?;
    parse_user_type(&user).map_err(serde::de::Error::custom)
}