opentitanlib/util/
num_de.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
//! Deserialization utilities for certain values in OTP HJSON files.
//!
//! The OTP HJSON files contain some unusual values, notably integers that are sometimes
//! wrapped in strings and use inconsistent formatting or placeholder (meta) values, such as:
//!
//!   - `value: "0x739"`
//!   - `key_size: "16"`
//!   - `seed: "10556718629619452145"`
//!   - `seed: 01931961561863975174` (a decimal integer, not octal)
//!   - `value: "<random>"`
//!
//! Additionally, some values have sizes defined within the config files themselves, such as
//! the keys. This module exists to handle these peculiar cases.
18use anyhow::Result;
19use rand::RngCore;
20use serde::de::{self, Deserializer, Unexpected};
21use serde::ser::Serializer;
22use serde::{Deserialize, Serialize};
23use std::any::type_name;
24use std::fmt;
25use std::marker::PhantomData;
26use std::ops::Deref;
27
28use crate::util::parse_int::{ParseInt, ParseIntError};
29
30/// Deserialize numeric types from HJSON config files.
31pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
32where
33    D: Deserializer<'de>,
34    T: ParseInt,
35{
36    struct Visitor<U>(PhantomData<U>);
37
38    impl<'a, U: ParseInt> de::Visitor<'a> for Visitor<U> {
39        type Value = U;
40
41        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42            formatter.write_fmt(format_args!("a string that parses to {}", type_name::<U>()))
43        }
44
45        fn visit_string<E>(self, mut name: String) -> Result<Self::Value, E>
46        where
47            E: de::Error,
48        {
49            if name.starts_with("false") {
50                name = "0".to_owned()
51            } else if name.starts_with("true") {
52                name = "1".to_owned()
53            }
54
55            let trimmed = if name.starts_with("0x") {
56                &name
57            } else {
58                let trimmed = name[0..name.len() - 1].trim_start_matches('0');
59                &name[name.len() - trimmed.len() - 1..]
60            };
61
62            match U::from_str(trimmed) {
63                Ok(value) => Ok(value),
64                Err(_) => Err(de::Error::invalid_value(Unexpected::Str(trimmed), &self)),
65            }
66        }
67
68        fn visit_str<E>(self, name: &str) -> Result<Self::Value, E>
69        where
70            E: de::Error,
71        {
72            self.visit_string(name.to_owned())
73        }
74
75        fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
76        where
77            E: de::Error,
78        {
79            if v {
80                self.visit_str("1")
81            } else {
82                self.visit_str("0")
83            }
84        }
85    }
86
87    deserializer.deserialize_string(Visitor(PhantomData::<T>))
88}
89
/// Parsed form of a configuration value whose bytes may not be known at deserialization time.
#[derive(Clone, Debug, PartialEq)]
enum DeferredInit {
    /// Concrete bytes parsed from the configuration string.
    Initialized(Vec<u8>),
    /// The `<random>` placeholder; bytes are only produced when the value is resolved.
    Random,
}
96
97#[derive(Debug, Clone, Deserialize)]
98pub struct DeferredValue(#[serde(deserialize_with = "deserialize")] DeferredInit);
99
100impl DeferredValue {
101    pub fn resolve(&self, size: usize, rng: &mut dyn RngCore) -> Vec<u8> {
102        match self.0.clone() {
103            DeferredInit::Initialized(mut vec) => {
104                vec.resize(size, 0);
105                vec
106            }
107            DeferredInit::Random => {
108                let mut vec = vec![0u8; size];
109                rng.fill_bytes(&mut vec);
110                vec
111            }
112        }
113    }
114
115    pub fn is_initialized(&self) -> bool {
116        matches!(self.0, DeferredInit::Initialized(_))
117    }
118}
119
120impl ParseInt for DeferredInit {
121    type FromStrRadixErr = ParseIntError;
122
123    fn from_str_radix(src: &str, radix: u32) -> Result<Self, ParseIntError> {
124        Ok(DeferredInit::Initialized(Vec::<u8>::from_str_radix(
125            src, radix,
126        )?))
127    }
128
129    fn from_str(src: &str) -> Result<Self, ParseIntError> {
130        if src == "<random>" {
131            Ok(DeferredInit::Random)
132        } else {
133            Ok(DeferredInit::Initialized(Vec::<u8>::from_str(src)?))
134        }
135    }
136}
137
138impl Deref for DeferredValue {
139    type Target = [u8];
140
141    fn deref(&self) -> &Self::Target {
142        match &self.0 {
143            DeferredInit::Initialized(val) => val,
144            _ => panic!("Value has not been initialized"),
145        }
146    }
147}
148
149/// Wrapper type to force deserialization assuming octal encoding.
150#[derive(Clone, Deserialize, Debug, PartialEq)]
151pub struct OctEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
152where
153    T: ParseInt + fmt::Octal;
154
155/// Wrapper type to force deserialization assuming decimal encoding.
156#[derive(Clone, Deserialize, Debug, PartialEq)]
157pub struct DecEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
158where
159    T: ParseInt + fmt::Display;
160
161/// Wrapper type to force deserialization assuming hexadecimal encoding.
162#[derive(Clone, Deserialize, Debug, PartialEq)]
163pub struct HexEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
164where
165    T: ParseInt + fmt::LowerHex;
166
167macro_rules! impl_parse_int_enc {
168    ($ty:ident, $radix:expr, $fmt:path, $prefix:expr) => {
169        impl<T: ParseInt + $fmt> std::ops::Deref for $ty<T> {
170            type Target = T;
171
172            fn deref(&self) -> &Self::Target {
173                &self.0
174            }
175        }
176
177        impl<T: ParseInt + $fmt> ParseInt for $ty<T> {
178            type FromStrRadixErr = T::FromStrRadixErr;
179
180            fn from_str_radix(src: &str, radix: u32) -> Result<Self, T::FromStrRadixErr> {
181                Ok(Self(T::from_str_radix(src, radix)?))
182            }
183
184            fn from_str(src: &str) -> Result<Self, ParseIntError> {
185                Self::from_str_radix(src, $radix).map_err(|e| e.into())
186            }
187        }
188
189        impl<T: ParseInt + $fmt> fmt::Display for $ty<T> {
190            fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
191                write!(f, "{}", $prefix)?;
192                <_ as $fmt>::fmt(&self.0, f)
193            }
194        }
195
196        impl<T: ParseInt + $fmt> Serialize for $ty<T> {
197            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
198            where
199                S: Serializer,
200            {
201                serializer.serialize_str(&self.to_string())
202            }
203        }
204    };
205}
206
207impl_parse_int_enc!(OctEncoded, 8, fmt::Octal, "0o");
208impl_parse_int_enc!(DecEncoded, 10, fmt::Display, "");
209impl_parse_int_enc!(HexEncoded, 16, fmt::LowerHex, "0x");
210
211impl ParseInt for Vec<u8> {
212    type FromStrRadixErr = ParseIntError;
213
214    fn from_str_radix(src: &str, radix: u32) -> Result<Self, ParseIntError> {
215        let mut bytes = vec![];
216        for digit_bytes in src.as_bytes().rchunks(2) {
217            let digits = std::str::from_utf8(digit_bytes).unwrap();
218            bytes.push(u8::from_str_radix(digits, radix)?);
219        }
220        Ok(bytes)
221    }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;

    /// The same digits "77" parse differently depending on the wrapper's radix.
    #[test]
    fn de_u8() -> Result<()> {
        #[derive(Debug, Deserialize)]
        struct TestData {
            #[serde(deserialize_with = "deserialize")]
            oct: OctEncoded<u8>,
            #[serde(deserialize_with = "deserialize")]
            dec: DecEncoded<u8>,
            #[serde(deserialize_with = "deserialize")]
            hex: HexEncoded<u8>,
        }

        let parsed: TestData = deser_hjson::from_str(stringify!(
        {
            oct: "77",
            dec: "77",
            hex: "77"
        }))?;

        assert_eq!(*parsed.oct, 0o77);
        assert_eq!(*parsed.dec, 77);
        assert_eq!(*parsed.hex, 0x77);
        Ok(())
    }

    /// Hex strings become little-endian byte vectors.
    #[test]
    fn byte_field_test() {
        const WORD: &str = "0x0706050403020100";
        let ascending: Vec<u8> = (0u8..8).collect();

        assert_eq!(Vec::from_str("0x1"), Ok(vec![0x1]));
        assert_eq!(Vec::from_str(WORD), Ok(ascending));

        let word_bytes: [u8; 8] = Vec::from_str(WORD).unwrap().try_into().unwrap();
        assert_eq!(u64::from_ne_bytes(word_bytes), u64::from_str(WORD).unwrap());

        assert!(Vec::from_str("-1").is_err());
    }
}