1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
// Copyright lowRISC contributors (OpenTitan project).
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
use std::fmt;
use std::marker::PhantomData;
use std::str::FromStr;
use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};
/// Deserialize a type T from either a string or struct by forwarding
/// string forms to `FromStr`.
///
/// The use-case for this is to allow specifying key material in ownership
/// configuration files either directly or by filename:
/// ```
/// key: {
/// Ecdsa: "some/path/to/key.pub.der"
/// }
///
/// key: {
/// Ecdsa: {
/// x: "...",
/// y: "..."
/// }
/// }
/// ```
// This function was taken nearly verbatim from the serde documentation.
// The example in the serde documentation constrains the `FromStr` error
// type to `Void`; we constrain to any type implementing std::error::Error.
pub fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: Deserialize<'de> + FromStr,
<T as FromStr>::Err: std::error::Error,
D: Deserializer<'de>,
{
// This is a Visitor that forwards string types to T's `FromStr` impl and
// forwards map types to T's `Deserialize` impl. The `PhantomData` is to
// keep the compiler from complaining about T being an unused generic type
// parameter. We need T in order to know the Value type for the Visitor
// impl.
struct StringOrStruct<T>(PhantomData<fn() -> T>);
impl<'de, T> Visitor<'de> for StringOrStruct<T>
where
T: Deserialize<'de> + FromStr,
<T as FromStr>::Err: std::error::Error,
{
type Value = T;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string or map")
}
fn visit_str<E>(self, value: &str) -> Result<T, E>
where
E: de::Error,
{
FromStr::from_str(value).map_err(|e| E::custom(format!("{e:?}")))
}
fn visit_map<M>(self, map: M) -> Result<T, M::Error>
where
M: MapAccess<'de>,
{
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
}
}
deserializer.deserialize_any(StringOrStruct(PhantomData))
}