opentitanlib/util/unknown.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use thiserror::Error;
6
7/// Creates C-like enums which preserve unknown (un-enumerated) values.
8///
9/// If you wanted an enum like:
10/// ```
11/// #[repr(u32)]
12/// pub enum HardenedBool {
13///     True = 0x739,
14///     False = 0x146,
15///     Unknown(u32),
16/// }
17/// ```
18///
19/// Where the `Unknown` discriminator would be the catch-all for any
20/// non-enumerated values, you can use `with_unknown!` as follows:
21///
22/// ```
23/// with_unknown! {
24///     pub enum HardenedBool: u32 {
25///         True = 0x739,
26///         False = 0x14d,
27///     }
28/// }
29/// ```
30///
31/// This "enum" can be used later in match statements:
32/// ```
33/// match foo {
34///     HardenedBool::True => do_the_thing(),
35///     HardenedBool::False => do_the_opposite_thing(),
36///     HardenedBool(x) => panic!("Oh noes! {} is neither True nor False!", x),
37/// }
38/// ```
39///
40/// Behind the scenes, `with_unknown!` implements a newtype struct and
41/// creates associated constants for each of the enumerated values.
42/// The struct also implements `Copy`, `Clone`, `PartialEq`, `Eq`,
43/// `PartialOrd`, `Ord`, `Hash`, `Debug` and `Display` (including the hex,
44/// octal and binary versions).
45///
46/// In addition, `serde::Serialize` and `serde::Deserialize` are
47/// implemented.  The serialized form is a string for known values and an
48/// integer for all unknown values.
49
50#[derive(Debug, Error)]
51pub enum ParseError {
52    #[error("Unknown enum variant: {0}")]
53    Unknown(String),
54}
55
56#[macro_export]
57macro_rules! with_unknown {
58    (
59        $(
60            $(#[$outer:meta])*
61            $vis:vis enum $Enum:ident: $type:ty $([default = $dfl:expr])? {
62                $(
63                    $(#[$inner:meta])*
64                    $enumerator:ident = $value:expr,
65                )*
66            }
67        )*
68    ) => {$(
69        $(#[$outer])*
70        #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
71        #[repr(transparent)]
72        $vis struct $Enum(pub $type);
73
74        #[allow(non_upper_case_globals)]
75        impl $Enum {
76            $(
77                $(#[$inner])*
78                $vis const $enumerator: $Enum = $Enum($value);
79            )*
80        }
81
82        #[allow(dead_code)]
83        impl $Enum {
84            pub const VARIANTS: &[&'static str] = &[
85                $(
86                    stringify!($enumerator),
87                )*
88            ];
89            pub fn is_known_value(&self) -> bool {
90                match *self {
91                    $(
92                        $Enum::$enumerator => true,
93                    )*
94                    _ => false,
95                }
96            }
97        }
98
99        impl From<$Enum> for $type {
100            fn from(v: $Enum) -> $type {
101                v.0
102            }
103        }
104
105        $crate::__impl_default!($Enum, $($dfl)*);
106
107        $crate::__impl_try_from!(i8, $Enum);
108        $crate::__impl_try_from!(i16, $Enum);
109        $crate::__impl_try_from!(i32, $Enum);
110        $crate::__impl_try_from!(i64, $Enum);
111        $crate::__impl_try_from!(u8, $Enum);
112        $crate::__impl_try_from!(u16, $Enum);
113        $crate::__impl_try_from!(u32, $Enum);
114        $crate::__impl_try_from!(u64, $Enum);
115
116        // Implement the various display traits.
117        $crate::__impl_fmt_unknown!(Display, "{}", "{}", $Enum { $($enumerator),* });
118        $crate::__impl_fmt_unknown!(LowerHex, "{:x}", "{:#x}", $Enum { $($enumerator),* });
119        $crate::__impl_fmt_unknown!(UpperHex, "{:X}", "{:#X}", $Enum { $($enumerator),* });
120        $crate::__impl_fmt_unknown!(Octal, "{:o}", "{:#o}", $Enum { $($enumerator),* });
121        $crate::__impl_fmt_unknown!(Binary, "{:b}", "{:#b}", $Enum { $($enumerator),* });
122
123        // Manually implement Serialize and Deserialize to have tight control over how
124        // the struct is serialized.
125        const _: () = {
126            use std::str::FromStr;
127            use serde::ser::{Serialize, Serializer};
128            use serde::de::{Deserialize, Deserializer, Error, Visitor};
129            use std::convert::TryFrom;
130            use $crate::util::unknown::ParseError;
131            use clap::ValueEnum;
132            use clap::builder::PossibleValue;
133
134            impl ValueEnum for $Enum {
135                fn value_variants<'a>() -> &'a [Self] {
136                    const VARIANTS: &[$Enum] = &[
137                        $($Enum::$enumerator),*
138                    ];
139                    VARIANTS
140                }
141
142                fn to_possible_value(&self) -> Option<PossibleValue> {
143                    let s = match *self {
144                        $(
145                            $Enum::$enumerator => stringify!($enumerator),
146                        )*
147                        _ => return None,
148
149                    };
150                    Some(PossibleValue::new(s))
151                }
152            }
153
154            impl FromStr for $Enum {
155                type Err = ParseError;
156                fn from_str(value: &str) -> Result<Self, Self::Err> {
157                    match value {
158                        $(
159                            stringify!($enumerator) => Ok($Enum::$enumerator),
160                        )*
161                        _ => Err(ParseError::Unknown(value.to_string()))
162                    }
163                }
164            }
165
166            impl Serialize for $Enum {
167                /// Serializes the enumerated values.  All named discriminants are
168                /// serialized to strings.  All unknown values are serialized as
169                /// integers.
170                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
171                where
172                    S: Serializer,
173                {
174                    match *self {
175                        $(
176                            $Enum::$enumerator => serializer.serialize_str(stringify!($enumerator)),
177                        )*
178                        $Enum(value) => value.serialize(serializer),
179                    }
180                }
181            }
182
183            // The `EnumVistor` assists in deserializing the value.
184            struct EnumVisitor;
185            impl<'de> Visitor<'de> for EnumVisitor {
186                type Value = $Enum;
187
188                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
189                    f.write_str(concat!("A valid enumerator of ", stringify!($Enum)))
190                }
191
192                fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
193                    match value {
194                        $(
195                            stringify!($enumerator) => Ok($Enum::$enumerator),
196                        )*
197                        _ => Err(E::custom(format!("unrecognized: {}", value))),
198                    }
199                }
200                $crate::__expand_visit_fn!(visit_i8, i8, $Enum, $type);
201                $crate::__expand_visit_fn!(visit_i16, i16, $Enum, $type);
202                $crate::__expand_visit_fn!(visit_i32, i32, $Enum, $type);
203                $crate::__expand_visit_fn!(visit_i64, i64, $Enum, $type);
204                $crate::__expand_visit_fn!(visit_u8, u8, $Enum, $type);
205                $crate::__expand_visit_fn!(visit_u16, u16, $Enum, $type);
206                $crate::__expand_visit_fn!(visit_u32, u32, $Enum, $type);
207                $crate::__expand_visit_fn!(visit_u64, u64, $Enum, $type);
208            }
209
210            impl<'de> Deserialize<'de> for $Enum {
211                /// Deserializes the value by forwarding to `deserialize_any`.
212                /// `deserialize_any` will forward strings to the string visitor
213                /// and forward integers to the appropriate integer visitor.
214                fn deserialize<D>(deserializer: D) -> Result<$Enum, D::Error>
215                where
216                    D: Deserializer<'de>,
217                {
218                    deserializer.deserialize_any(EnumVisitor)
219                }
220            }
221        };
222    )*};
223}
224
/// Internal helper for `with_unknown!`: implements `TryFrom<$from_type>` for
/// the newtype `$Enum`, failing with `TryFromIntError` when the value does
/// not fit the wrapped type.
///
/// Exported only so `with_unknown!` can reach it through `$crate`; it is not
/// part of the public API.
#[doc(hidden)]
#[macro_export]
macro_rules! __impl_try_from {
    ($from_type:ty, $Enum:ident) => {
        // Paths are fully qualified because they resolve at the invocation
        // site, which may shadow `Result` or predate `TryFrom` in the prelude.
        impl ::std::convert::TryFrom<$from_type> for $Enum {
            type Error = ::std::num::TryFromIntError;
            fn try_from(value: $from_type) -> ::std::result::Result<Self, Self::Error> {
                // The target type is inferred from `$Enum`'s field.
                Ok($Enum(::std::convert::TryInto::try_into(value)?))
            }
        }
    };
}
236
/// Internal helper for `with_unknown!`: expands to a serde `Visitor` method
/// named `$visit_func` that accepts a `$ser_type` integer. If the integer fits
/// in `$enum_type`, it is wrapped in `$Enum`; otherwise an `E::custom` error
/// describing the failed conversion is returned.
///
/// The expansion names a trait `Error` (providing `custom`) and calls
/// `<$enum_type>::try_from`, so both `Error` and `TryFrom` must be in scope
/// at the invocation site (`with_unknown!` imports `serde::de::Error` and
/// `std::convert::TryFrom` for this).
///
/// Exported only so `with_unknown!` can reach it through `$crate`; it is not
/// part of the public API.
#[doc(hidden)]
#[macro_export]
macro_rules! __expand_visit_fn {
    ($visit_func:ident, $ser_type:ty, $Enum:ident, $enum_type:ty) => {
        // `Result` is fully qualified so a `Result` alias at the invocation
        // site cannot change the signature.
        fn $visit_func<E>(self, value: $ser_type) -> ::std::result::Result<Self::Value, E>
        where
            E: Error,
        {
            match <$enum_type>::try_from(value) {
                Ok(v) => Ok($Enum(v)),
                Err(_) => Err(E::custom(format!(
                    "cannot convert {:?} to {}({})",
                    value,
                    stringify!($Enum),
                    stringify!($enum_type)
                ))),
            }
        }
    };
}
256
/// Internal helper for `with_unknown!`: implements `std::fmt::$Trait` for
/// `$Enum`.
///
/// Enumerated values print their name. Any other value prints as
/// `$Enum(<value>)`, formatting the value with `$Alt` when the alternate flag
/// (`#`) is set and with `$Fmt` otherwise.
///
/// Exported only so `with_unknown!` can reach it through `$crate`; it is not
/// part of the public API.
#[doc(hidden)]
#[macro_export]
macro_rules! __impl_fmt_unknown {
    (
        $Trait:ident, $Fmt:literal, $Alt:literal, $Enum:ident {
            $($enumerator:ident),*
        }
    ) => {
        impl ::std::fmt::$Trait for $Enum {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                match *self {
                    $(
                        $Enum::$enumerator => f.write_str(stringify!($enumerator)),
                    )*
                    $Enum(value) => {
                        if f.alternate() {
                            write!(f, concat!(stringify!($Enum), "(", $Alt, ")"), value)
                        } else {
                            write!(f, concat!(stringify!($Enum), "(", $Fmt, ")"), value)
                        }
                    }
                }
            }
        }
    };
}
283
/// Internal helper for `with_unknown!`: implements `Default` for `$Enum`
/// returning `$dfl` when a default expression is supplied, and expands to
/// nothing when the expression is omitted.
///
/// Exported only so `with_unknown!` can reach it through `$crate`; it is not
/// part of the public API.
#[doc(hidden)]
#[macro_export]
macro_rules! __impl_default {
    ($Enum:ident, $dfl:expr) => {
        impl ::std::default::Default for $Enum {
            fn default() -> Self {
                $dfl
            }
        }
    };

    ($Enum:ident, /*nothing*/ ) => {
        // No default defined, so no implementation.
    };
}
298
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use clap::ValueEnum;
    use serde::{Deserialize, Serialize};

    with_unknown! {
        enum HardenedBool: u32 {
            True = 0x739,
            False = 0x14d,
        }

        // Check creating a `Default` implementation.
        enum Misc: u8 [default = Self::Z] {
            X = 0,
            Y = 1,
            Z = 2,
        }
    }

    #[test]
    fn test_display() -> Result<()> {
        let t = HardenedBool::True;
        assert_eq!(t.to_string(), "True");
        assert!(t.is_known_value());

        let f = HardenedBool::False;
        assert_eq!(f.to_string(), "False");
        assert!(f.is_known_value());

        let j = HardenedBool(0x6A);
        assert!(!j.is_known_value());
        assert_eq!(j.to_string(), "HardenedBool(106)");
        assert_eq!(format!("{:x}", j), "HardenedBool(6a)");
        assert_eq!(format!("{:#x}", j), "HardenedBool(0x6a)");
        assert_eq!(format!("{:X}", j), "HardenedBool(6A)");
        assert_eq!(format!("{:o}", j), "HardenedBool(152)");
        assert_eq!(format!("{:b}", j), "HardenedBool(1101010)");
        assert_eq!(format!("{:#b}", j), "HardenedBool(0b1101010)");
        Ok(())
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct SomeBools {
        a: HardenedBool,
        b: HardenedBool,
        c: HardenedBool,
    }

    #[test]
    fn test_conversion() -> Result<()> {
        let t = HardenedBool::True;
        let x = HardenedBool(12345);
        assert_eq!(u32::from(t), 0x739);
        assert_eq!(u32::from(x), 12345);
        Ok(())
    }

    /// `TryFrom` accepts any integer that fits the wrapped type and rejects
    /// the rest.
    #[test]
    fn test_try_from() {
        assert_eq!(HardenedBool::try_from(0x739u64).unwrap(), HardenedBool::True);
        assert_eq!(Misc::try_from(1i8).unwrap(), Misc::Y);
        assert!(HardenedBool::try_from(-1i32).is_err());
        assert!(Misc::try_from(256u16).is_err());
    }

    #[test]
    fn test_default() -> Result<()> {
        let z = Misc::default();
        assert_eq!(z, Misc::Z);
        Ok(())
    }

    /// `VARIANTS`, `is_known_value` and `FromStr` only know the enumerated
    /// names, and parsing is case-sensitive.
    #[test]
    fn test_variants_and_from_str() -> Result<()> {
        assert_eq!(HardenedBool::VARIANTS, &["True", "False"]);
        assert_eq!(Misc::VARIANTS, &["X", "Y", "Z"]);
        assert!(!Misc(7).is_known_value());
        assert_eq!("True".parse::<HardenedBool>()?, HardenedBool::True);
        assert_eq!("Z".parse::<Misc>()?, Misc::Z);

        let err = "true".parse::<HardenedBool>().unwrap_err();
        assert!(matches!(&err, super::ParseError::Unknown(s) if s == "true"));
        assert_eq!(err.to_string(), "Unknown enum variant: true");
        Ok(())
    }

    /// clap only sees the enumerated values; unknown values have no name.
    #[test]
    fn test_value_enum() {
        assert_eq!(
            HardenedBool::value_variants(),
            &[HardenedBool::True, HardenedBool::False]
        );
        let name = HardenedBool::False
            .to_possible_value()
            .map(|v| v.get_name().to_string());
        assert_eq!(name.as_deref(), Some("False"));
        assert!(HardenedBool(0x6a).to_possible_value().is_none());
    }

    #[test]
    fn test_serde() -> Result<()> {
        let b = SomeBools {
            a: HardenedBool::True,
            b: HardenedBool::False,
            c: HardenedBool(0x6a),
        };
        let json = serde_json::to_string(&b)?;
        assert_eq!(json, r#"{"a":"True","b":"False","c":106}"#);

        let de = serde_json::from_str::<SomeBools>(&json)?;
        assert_eq!(de, b);
        Ok(())
    }

    #[test]
    fn test_serde_error() -> Result<()> {
        let json = r#"{"a":"True","b":"False","c":-1}"#;
        let de = serde_json::from_str::<SomeBools>(json);
        let err = de.unwrap_err().to_string();
        assert_eq!(
            err,
            "cannot convert -1 to HardenedBool(u32) at line 1 column 30"
        );
        Ok(())
    }

    /// Strings that are not an enumerated name are rejected on deserialization.
    #[test]
    fn test_serde_unknown_name() {
        let json = r#"{"a":"True","b":"Maybe","c":"False"}"#;
        let err = serde_json::from_str::<SomeBools>(json)
            .unwrap_err()
            .to_string();
        assert!(
            err.starts_with("unrecognized: Maybe"),
            "unexpected error: {}",
            err
        );
    }
}