opentitanlib/util/serde.rs
1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::str::FromStr;
8
9use serde::de::{self, MapAccess, Visitor};
10use serde::{Deserialize, Deserializer};
11
12/// Deserialize a type T from either a string or struct by forwarding
13/// string forms to `FromStr`.
14///
15/// The use-case for this is to allow specifying key material in ownership
16/// configuration files either directly or by filename:
17/// ```
18/// key: {
19/// Ecdsa: "some/path/to/key.pub.der"
20/// }
21///
22/// key: {
23/// Ecdsa: {
24/// x: "...",
25/// y: "..."
26/// }
27/// }
28/// ```
29// This function was taken nearly verbatim from the serde documentation.
30// The example in the serde documentation constrains the `FromStr` error
31// type to `Void`; we constrain to any type implementing std::error::Error.
32pub fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error>
33where
34 T: Deserialize<'de> + FromStr,
35 <T as FromStr>::Err: std::error::Error,
36 D: Deserializer<'de>,
37{
38 // This is a Visitor that forwards string types to T's `FromStr` impl and
39 // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
40 // keep the compiler from complaining about T being an unused generic type
41 // parameter. We need T in order to know the Value type for the Visitor
42 // impl.
43 struct StringOrStruct<T>(PhantomData<fn() -> T>);
44
45 impl<'de, T> Visitor<'de> for StringOrStruct<T>
46 where
47 T: Deserialize<'de> + FromStr,
48 <T as FromStr>::Err: std::error::Error,
49 {
50 type Value = T;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("string or map")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<T, E>
57 where
58 E: de::Error,
59 {
60 FromStr::from_str(value).map_err(|e| E::custom(format!("{e:?}")))
61 }
62
63 fn visit_map<M>(self, map: M) -> Result<T, M::Error>
64 where
65 M: MapAccess<'de>,
66 {
67 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
68 }
69 }
70 deserializer.deserialize_any(StringOrStruct(PhantomData))
71}