opentitanlib/util/file.rs

1// Copyright lowRISC contributors (OpenTitan project).
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
5use anyhow::{Context, Result, ensure};
6use pem_rfc7468::{Decoder, Encoder, LineEnding};
7use thiserror::Error;
8
9use std::convert::TryInto;
10use std::fs::File;
11use std::io::{self, Read, Write};
12use std::os::fd::{AsFd, BorrowedFd};
13use std::path::Path;
14use std::time::Duration;
15
16/// Error type for errors related to PEM serialization.
17#[derive(Debug, Error)]
18enum PemError {
19    #[error("PEM type error; expecting {0:?} but got {1:?}")]
20    LabelError(&'static str, String),
21}
22
23/// Trait for data that can be written to and read from PEM files.
24pub trait PemSerilizable: ToWriter + FromReader {
25    /// The label for the PEM file.
26    ///
27    /// Appears around the base64 encoded data.
28    /// -----BEGIN MY_LABEL-----
29    /// ...
30    /// -----END MY_LABEL-----
31    fn label() -> &'static str;
32
33    /// Write to PEM file with label from `Self::label()`.
34    fn write_pem_file(&self, path: &Path) -> Result<()> {
35        const MAX_PEM_SIZE: usize = 4096;
36
37        let mut bytes = Vec::<u8>::new();
38        self.to_writer(&mut bytes)?;
39
40        let mut buf = [0u8; MAX_PEM_SIZE];
41        let mut encoder = Encoder::new(Self::label(), LineEnding::LF, &mut buf)?;
42        encoder.encode(&bytes)?;
43        let len = encoder.finish()?;
44
45        let mut file = File::create(path)?;
46        Ok(file.write_all(&buf[..len])?)
47    }
48
49    /// Read in from PEM file, ensuring the label matches `Self::label()`.
50    fn read_pem_file(path: &Path) -> Result<Self> {
51        let mut file = File::open(path)?;
52        let mut pem = Vec::<u8>::new();
53        file.read_to_end(&mut pem)?;
54
55        let mut decoder = Decoder::new(&pem)?;
56        ensure!(
57            decoder.type_label() == Self::label(),
58            PemError::LabelError(Self::label(), decoder.type_label().to_owned()),
59        );
60
61        let mut buf = Vec::new();
62        decoder.decode_to_end(&mut buf)?;
63
64        Self::from_reader(buf.as_slice())
65    }
66}
67
68/// Trait for data types that can be streamed to a reader.
69pub trait FromReader: Sized {
70    /// Reads in an instance of `Self`.
71    fn from_reader(r: impl Read) -> Result<Self>;
72
73    /// Reads an instance of `Self` from a binary file at `path`.
74    fn read_from_file(path: &Path) -> Result<Self> {
75        let file = File::open(path).with_context(|| format!("Failed to open {path:?}"))?;
76        Self::from_reader(file)
77    }
78}
79
80/// Trait for data types that can be written to a writer.
81pub trait ToWriter: Sized {
82    /// Writes out `self`.
83    fn to_writer(&self, w: &mut impl Write) -> Result<()>;
84
85    /// Writes `self` to a file at `path` in binary format.
86    fn write_to_file(self, path: &Path) -> Result<()> {
87        let mut file = File::create(path).with_context(|| format!("Failed to create {path:?}"))?;
88        self.to_writer(&mut file)
89    }
90}
91
92/// Waits for an event on `fd` or for `timeout` to expire.
93pub fn wait_timeout(
94    fd: BorrowedFd<'_>,
95    events: rustix::event::PollFlags,
96    timeout: Duration,
97) -> Result<()> {
98    let mut pfd = [rustix::event::PollFd::from_borrowed_fd(fd, events)];
99    match rustix::event::poll(&mut pfd, timeout.try_into().ok().as_ref())? {
100        0 => Err(io::Error::new(
101            io::ErrorKind::TimedOut,
102            "timed out waiting for fd to be ready",
103        )
104        .into()),
105        _ => Ok(()),
106    }
107}
108
109/// Waits for `fd` to become ready to read or `timeout` to expire.
110pub fn wait_read_timeout(fd: &impl AsFd, timeout: Duration) -> Result<()> {
111    wait_timeout(fd.as_fd(), rustix::event::PollFlags::IN, timeout)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use rustix::io::{read, write};
    use rustix::net::{AddressFamily, SocketFlags, SocketType, socketpair};

    /// Opens a connected pair of UNIX stream sockets, returned as `(sender, receiver)`.
    fn stream_pair() -> Result<(impl AsFd, impl AsFd)> {
        let pair = socketpair(
            AddressFamily::UNIX,
            SocketType::STREAM,
            SocketFlags::empty(),
            None,
        )?;
        Ok(pair)
    }

    #[test]
    fn test_data_ready() -> Result<()> {
        let (snd, rcv) = stream_pair()?;

        // Push a small payload through the socket pair.
        let payload = b"abc123";
        let sent = write(&snd, payload)?;
        assert_eq!(sent, payload.len());

        // The receiving end should become readable within the timeout.
        wait_read_timeout(&rcv, Duration::from_millis(10))?;

        // Read the payload back and check it arrived intact.
        let mut received = [0u8; 6];
        let got = read(&rcv, &mut received)?;
        assert_eq!(got, payload.len());
        assert_eq!(payload, &received);
        Ok(())
    }

    #[test]
    fn test_timeout() -> Result<()> {
        // Bind the sender as `_snd` (not `_`) so it stays open for the whole test; a closed
        // peer would presumably make `rcv` report as readable.
        let (_snd, rcv) = stream_pair()?;

        // Nothing was written, so the wait is expected to time out.
        let err = wait_read_timeout(&rcv, Duration::from_millis(10)).unwrap_err();
        let Some(io_err) = err.downcast_ref::<io::Error>() else {
            bail!("Unexpected error result {:?}", err);
        };
        assert_eq!(io::ErrorKind::TimedOut, io_err.kind());
        Ok(())
    }
}