opentitanlib/util/
serde.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
9use serde::de::{self, MapAccess, Visitor};
10use serde::{Deserialize, Deserializer};
11
12/// Deserialize a type T from either a string or struct by forwarding
13/// string forms to `FromStr`.
14///
15/// The use-case for this is to allow specifying key material in ownership
16/// configuration files either directly or by filename:
17/// ```
18/// key: {
19///   Ecdsa: "some/path/to/key.pub.der"
20/// }
21///
22/// key: {
23///   Ecdsa: {
24///     x: "...",
25///     y: "..."
26///   }
27/// }
28/// ```
29// This function was taken nearly verbatim from the serde documentation.
30// The example in the serde documentation constrains the `FromStr` error
31// type to `Void`; we constrain to any type implementing std::error::Error.
32pub fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error>
33where
34    T: Deserialize<'de> + FromStr,
35    <T as FromStr>::Err: std::error::Error,
36    D: Deserializer<'de>,
37{
38    // This is a Visitor that forwards string types to T's `FromStr` impl and
39    // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
40    // keep the compiler from complaining about T being an unused generic type
41    // parameter. We need T in order to know the Value type for the Visitor
42    // impl.
43    struct StringOrStruct<T>(PhantomData<fn() -> T>);
44
45    impl<'de, T> Visitor<'de> for StringOrStruct<T>
46    where
47        T: Deserialize<'de> + FromStr,
48        <T as FromStr>::Err: std::error::Error,
49    {
50        type Value = T;
51
52        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53            formatter.write_str("string or map")
54        }
55
56        fn visit_str<E>(self, value: &str) -> Result<T, E>
57        where
58            E: de::Error,
59        {
60            FromStr::from_str(value).map_err(|e| E::custom(format!("{e:?}")))
61        }
62
63        fn visit_map<M>(self, map: M) -> Result<T, M::Error>
64        where
65            M: MapAccess<'de>,
66        {
67            Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
68        }
69    }
70    deserializer.deserialize_any(StringOrStruct(PhantomData))
71}