1use anyhow::Result;
19use rand::RngCore;
20use serde::de::{self, Deserializer, Unexpected};
21use serde::ser::Serializer;
22use serde::{Deserialize, Serialize};
23use std::any::type_name;
24use std::fmt;
25use std::marker::PhantomData;
26use std::ops::Deref;
27
28use crate::util::parse_int::{ParseInt, ParseIntError};
29
30pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
32where
33 D: Deserializer<'de>,
34 T: ParseInt,
35{
36 struct Visitor<U>(PhantomData<U>);
37
38 impl<'a, U: ParseInt> de::Visitor<'a> for Visitor<U> {
39 type Value = U;
40
41 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42 formatter.write_fmt(format_args!("a string that parses to {}", type_name::<U>()))
43 }
44
45 fn visit_string<E>(self, mut name: String) -> Result<Self::Value, E>
46 where
47 E: de::Error,
48 {
49 if name.starts_with("false") {
50 name = "0".to_owned()
51 } else if name.starts_with("true") {
52 name = "1".to_owned()
53 }
54
55 let trimmed = if name.starts_with("0x") {
56 &name
57 } else {
58 let trimmed = name[0..name.len() - 1].trim_start_matches('0');
59 &name[name.len() - trimmed.len() - 1..]
60 };
61
62 match U::from_str(trimmed) {
63 Ok(value) => Ok(value),
64 Err(_) => Err(de::Error::invalid_value(Unexpected::Str(trimmed), &self)),
65 }
66 }
67
68 fn visit_str<E>(self, name: &str) -> Result<Self::Value, E>
69 where
70 E: de::Error,
71 {
72 self.visit_string(name.to_owned())
73 }
74
75 fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
76 where
77 E: de::Error,
78 {
79 if v {
80 self.visit_str("1")
81 } else {
82 self.visit_str("0")
83 }
84 }
85 }
86
87 deserializer.deserialize_string(Visitor(PhantomData::<T>))
88}
89
/// Initial contents of a [`DeferredValue`]: explicit bytes, or a marker asking for random
/// bytes to be generated when the value is resolved.
#[derive(Clone, Debug, PartialEq)]
enum DeferredInit {
    /// Bytes parsed from the configuration string.
    Initialized(Vec<u8>),
    /// Produced by the literal string `<random>`.
    Random,
}
96
97#[derive(Debug, Clone, Deserialize)]
98pub struct DeferredValue(#[serde(deserialize_with = "deserialize")] DeferredInit);
99
100impl DeferredValue {
101 pub fn resolve(&self, size: usize, rng: &mut dyn RngCore) -> Vec<u8> {
102 match self.0.clone() {
103 DeferredInit::Initialized(mut vec) => {
104 vec.resize(size, 0);
105 vec
106 }
107 DeferredInit::Random => {
108 let mut vec = vec![0u8; size];
109 rng.fill_bytes(&mut vec);
110 vec
111 }
112 }
113 }
114
115 pub fn is_initialized(&self) -> bool {
116 matches!(self.0, DeferredInit::Initialized(_))
117 }
118}
119
120impl ParseInt for DeferredInit {
121 type FromStrRadixErr = ParseIntError;
122
123 fn from_str_radix(src: &str, radix: u32) -> Result<Self, ParseIntError> {
124 Ok(DeferredInit::Initialized(Vec::<u8>::from_str_radix(
125 src, radix,
126 )?))
127 }
128
129 fn from_str(src: &str) -> Result<Self, ParseIntError> {
130 if src == "<random>" {
131 Ok(DeferredInit::Random)
132 } else {
133 Ok(DeferredInit::Initialized(Vec::<u8>::from_str(src)?))
134 }
135 }
136}
137
138impl Deref for DeferredValue {
139 type Target = [u8];
140
141 fn deref(&self) -> &Self::Target {
142 match &self.0 {
143 DeferredInit::Initialized(val) => val,
144 _ => panic!("Value has not been initialized"),
145 }
146 }
147}
148
149#[derive(Clone, Deserialize, Debug, PartialEq)]
151pub struct OctEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
152where
153 T: ParseInt + fmt::Octal;
154
155#[derive(Clone, Deserialize, Debug, PartialEq)]
157pub struct DecEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
158where
159 T: ParseInt + fmt::Display;
160
161#[derive(Clone, Deserialize, Debug, PartialEq)]
163pub struct HexEncoded<T>(#[serde(deserialize_with = "deserialize")] pub T)
164where
165 T: ParseInt + fmt::LowerHex;
166
/// Implements `Deref`, `ParseInt`, `Display` and `Serialize` for an `*Encoded` wrapper.
///
/// * `$ty` — the tuple-struct wrapper around `T`;
/// * `$radix` — the radix used by `ParseInt::from_str`;
/// * `$fmt` — the formatting trait used to render the inner value;
/// * `$prefix` — written before the digits by `Display` (and therefore by `Serialize`),
///   and optionally accepted by `ParseInt::from_str`, so serialized values parse back.
macro_rules! impl_parse_int_enc {
    ($ty:ident, $radix:expr, $fmt:path, $prefix:expr) => {
        impl<T: ParseInt + $fmt> std::ops::Deref for $ty<T> {
            type Target = T;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl<T: ParseInt + $fmt> ParseInt for $ty<T> {
            type FromStrRadixErr = T::FromStrRadixErr;

            fn from_str_radix(src: &str, radix: u32) -> Result<Self, T::FromStrRadixErr> {
                Ok(Self(T::from_str_radix(src, radix)?))
            }

            fn from_str(src: &str) -> Result<Self, ParseIntError> {
                // `Display`/`Serialize` emit `$prefix` before the digits. Accept it here so a
                // serialized value round-trips; unprefixed input is still accepted. The prefix
                // characters are never valid digits in `$radix`, so this only turns former
                // errors into successes.
                let digits = src.strip_prefix($prefix).unwrap_or(src);
                Self::from_str_radix(digits, $radix).map_err(|e| e.into())
            }
        }

        impl<T: ParseInt + $fmt> fmt::Display for $ty<T> {
            fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
                f.write_str($prefix)?;
                <_ as $fmt>::fmt(&self.0, f)
            }
        }

        impl<T: ParseInt + $fmt> Serialize for $ty<T> {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                serializer.serialize_str(&self.to_string())
            }
        }
    };
}
206
207impl_parse_int_enc!(OctEncoded, 8, fmt::Octal, "0o");
208impl_parse_int_enc!(DecEncoded, 10, fmt::Display, "");
209impl_parse_int_enc!(HexEncoded, 16, fmt::LowerHex, "0x");
210
211impl ParseInt for Vec<u8> {
212 type FromStrRadixErr = ParseIntError;
213
214 fn from_str_radix(src: &str, radix: u32) -> Result<Self, ParseIntError> {
215 let mut bytes = vec![];
216 for digit_bytes in src.as_bytes().rchunks(2) {
217 let digits = std::str::from_utf8(digit_bytes).unwrap();
218 bytes.push(u8::from_str_radix(digits, radix)?);
219 }
220 Ok(bytes)
221 }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;

    /// Each wrapper parses the same digits "77" in its own radix.
    #[test]
    fn de_u8() -> Result<()> {
        #[derive(Debug, Deserialize)]
        struct TestData {
            #[serde(deserialize_with = "deserialize")]
            oct: OctEncoded<u8>,
            #[serde(deserialize_with = "deserialize")]
            dec: DecEncoded<u8>,
            #[serde(deserialize_with = "deserialize")]
            hex: HexEncoded<u8>,
        }

        let data: TestData = deser_hjson::from_str(stringify!(
        {
            oct: "77",
            dec: "77",
            hex: "77"
        }))?;

        assert_eq!(*data.oct, 63);
        assert_eq!(*data.dec, 77);
        assert_eq!(*data.hex, 119);
        Ok(())
    }

    /// Byte vectors are little-endian: the least-significant byte comes first.
    #[test]
    fn byte_field_test() {
        assert_eq!(Vec::from_str("0x1"), Ok(vec![0x1]));
        assert_eq!(
            Vec::from_str("0x0706050403020100"),
            Ok(vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07])
        );
        // Compare with `from_le_bytes`: the vector is least-significant byte first, so
        // `from_ne_bytes` would make this assertion fail on big-endian hosts.
        assert_eq!(
            u64::from_le_bytes(
                Vec::from_str("0x0706050403020100")
                    .unwrap()
                    .try_into()
                    .unwrap()
            ),
            u64::from_str("0x0706050403020100").unwrap()
        );
        assert!(Vec::from_str("-1").is_err());
    }

    /// `<random>` is recognized by name; anything else is parsed as bytes.
    #[test]
    fn deferred_init_from_str() {
        assert_eq!(DeferredInit::from_str("<random>"), Ok(DeferredInit::Random));
        assert_eq!(
            DeferredInit::from_str("0x0100"),
            Ok(DeferredInit::Initialized(vec![0x00, 0x01]))
        );
    }
}