1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright lowRISC contributors (OpenTitan project).
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

use anyhow::{ensure, Context, Result};
use pem_rfc7468::{Decoder, Encoder, LineEnding};
use thiserror::Error;

use std::convert::TryInto;
use std::fs::File;
use std::io::{self, Read, Write};
use std::os::fd::{AsFd, BorrowedFd};
use std::path::Path;
use std::time::Duration;

/// Error type for errors related to PEM serialization.
#[derive(Debug, Error)]
enum PemError {
    #[error("PEM type error; expecting {0:?} but got {1:?}")]
    LabelError(&'static str, String),
}

/// Trait for data that can be written to and read from PEM files.
pub trait PemSerilizable: ToWriter + FromReader {
    /// The label for the PEM file.
    ///
    /// Appears around the base64 encoded data.
    /// -----BEGIN MY_LABEL-----
    /// ...
    /// -----END MY_LABEL-----
    fn label() -> &'static str;

    /// Write to PEM file with label from `Self::label()`.
    fn write_pem_file(&self, path: &Path) -> Result<()> {
        const MAX_PEM_SIZE: usize = 4096;

        let mut bytes = Vec::<u8>::new();
        self.to_writer(&mut bytes)?;

        let mut buf = [0u8; MAX_PEM_SIZE];
        let mut encoder = Encoder::new(Self::label(), LineEnding::LF, &mut buf)?;
        encoder.encode(&bytes)?;
        let len = encoder.finish()?;

        let mut file = File::create(path)?;
        Ok(file.write_all(&buf[..len])?)
    }

    /// Read in from PEM file, ensuring the label matches `Self::label()`.
    fn read_pem_file(path: &Path) -> Result<Self> {
        let mut file = File::open(path)?;
        let mut pem = Vec::<u8>::new();
        file.read_to_end(&mut pem)?;

        let mut decoder = Decoder::new(&pem)?;
        ensure!(
            decoder.type_label() == Self::label(),
            PemError::LabelError(Self::label(), decoder.type_label().to_owned()),
        );

        let mut buf = Vec::new();
        decoder.decode_to_end(&mut buf)?;

        Self::from_reader(buf.as_slice())
    }
}

/// Trait for data types that can be streamed to a reader.
pub trait FromReader: Sized {
    /// Reads in an instance of `Self`.
    fn from_reader(r: impl Read) -> Result<Self>;

    /// Reads an instance of `Self` from a binary file at `path`.
    fn read_from_file(path: &Path) -> Result<Self> {
        let file = File::open(path).with_context(|| format!("Failed to open {path:?}"))?;
        Self::from_reader(file)
    }
}

/// Trait for data types that can be written to a writer.
pub trait ToWriter: Sized {
    /// Writes out `self`.
    fn to_writer(&self, w: &mut impl Write) -> Result<()>;

    /// Writes `self` to a file at `path` in binary format.
    fn write_to_file(self, path: &Path) -> Result<()> {
        let mut file = File::create(path).with_context(|| format!("Failed to create {path:?}"))?;
        self.to_writer(&mut file)
    }
}

/// Waits for an event on `fd` or for `timeout` to expire.
///
/// `events` selects the poll flags to wait for. The timeout is rounded *up*
/// to whole milliseconds, so a sub-millisecond timeout still blocks instead
/// of degenerating into a non-blocking check. Timeouts longer than
/// `i32::MAX` milliseconds are clamped to that value.
///
/// If `poll` is interrupted by a signal (`EINTR`), the wait resumes with the
/// time remaining until the original deadline rather than failing.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind [`io::ErrorKind::TimedOut`] if `poll`
/// reports no ready descriptors before the timeout. Any other `poll` error
/// is passed through.
pub fn wait_timeout(
    fd: BorrowedFd<'_>,
    events: rustix::event::PollFlags,
    timeout: Duration,
) -> Result<()> {
    /// Converts a duration to the millisecond count `poll` expects. Any
    /// sub-millisecond remainder is rounded up, and the result is clamped
    /// to `i32::MAX`.
    fn poll_timeout_ms(timeout: Duration) -> i32 {
        let round_up = u128::from(timeout.subsec_nanos() % 1_000_000 != 0);
        (timeout.as_millis() + round_up)
            .try_into()
            .unwrap_or(i32::MAX)
    }

    // `None` when the deadline is not representable. Such a timeout is
    // clamped to `i32::MAX` ms anyway, so a retry simply reuses `timeout`.
    let deadline = std::time::Instant::now().checked_add(timeout);
    let mut remaining = timeout;
    loop {
        let mut pfd = [rustix::event::PollFd::from_borrowed_fd(fd, events)];
        match rustix::event::poll(&mut pfd, poll_timeout_ms(remaining)) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "timed out waiting for fd to be ready",
                )
                .into());
            }
            Ok(_) => return Ok(()),
            // poll(2) is never restarted automatically after a signal handler
            // runs, so retry with whatever time is left.
            Err(err) if err == rustix::io::Errno::INTR => {
                remaining = match deadline {
                    Some(deadline) => {
                        deadline.saturating_duration_since(std::time::Instant::now())
                    }
                    None => timeout,
                };
            }
            Err(err) => return Err(err.into()),
        }
    }
}

/// Blocks until `fd` is readable (`PollFlags::IN`) or until `timeout`
/// elapses, whichever happens first.
///
/// Convenience wrapper around [`wait_timeout`] that borrows the descriptor
/// from any [`AsFd`] implementor.
pub fn wait_read_timeout(fd: &impl AsFd, timeout: Duration) -> Result<()> {
    let readable = rustix::event::PollFlags::IN;
    let borrowed = fd.as_fd();
    wait_timeout(borrowed, readable, timeout)
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use rustix::io::{read, write};
    use rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType};

    /// Bytes written to one end of a Unix stream socket pair should make the
    /// other end readable within the timeout and arrive unchanged.
    #[test]
    fn test_data_ready() -> Result<()> {
        let (tx, rx) = socketpair(
            AddressFamily::UNIX,
            SocketType::STREAM,
            SocketFlags::empty(),
            None,
        )?;

        let payload = b"abc123";
        let written = write(&tx, payload)?;
        assert_eq!(written, payload.len());

        wait_read_timeout(&rx, Duration::from_millis(10))?;

        let mut received = [0u8; 6];
        let count = read(&rx, &mut received)?;
        assert_eq!(count, payload.len());
        assert_eq!(payload, &received);
        Ok(())
    }

    /// With nothing written, waiting for readability must fail with an
    /// `io::ErrorKind::TimedOut` error.
    #[test]
    fn test_timeout() -> Result<()> {
        // Keep the sending end bound (not `_`) so it stays open. Otherwise
        // the receiver would see a hang-up and poll would report it ready.
        let (_tx, rx) = socketpair(
            AddressFamily::UNIX,
            SocketType::STREAM,
            SocketFlags::empty(),
            None,
        )?;

        let err = wait_read_timeout(&rx, Duration::from_millis(10))
            .expect_err("waiting on an empty socket should time out");
        if let Some(io_err) = err.downcast_ref::<io::Error>() {
            assert_eq!(io::ErrorKind::TimedOut, io_err.kind());
        } else {
            bail!("Unexpected error result {:?}", err);
        }
        Ok(())
    }
}