opentitanlib/util/
file.rs1use anyhow::{Context, Result, ensure};
6use pem_rfc7468::{Decoder, Encoder, LineEnding};
7use thiserror::Error;
8
9use std::convert::TryInto;
10use std::fs::File;
11use std::io::{self, Read, Write};
12use std::os::fd::{AsFd, BorrowedFd};
13use std::path::Path;
14use std::time::Duration;
15
16#[derive(Debug, Error)]
18enum PemError {
19 #[error("PEM type error; expecting {0:?} but got {1:?}")]
20 LabelError(&'static str, String),
21}
22
23pub trait PemSerilizable: ToWriter + FromReader {
25 fn label() -> &'static str;
32
33 fn write_pem_file(&self, path: &Path) -> Result<()> {
35 const MAX_PEM_SIZE: usize = 4096;
36
37 let mut bytes = Vec::<u8>::new();
38 self.to_writer(&mut bytes)?;
39
40 let mut buf = [0u8; MAX_PEM_SIZE];
41 let mut encoder = Encoder::new(Self::label(), LineEnding::LF, &mut buf)?;
42 encoder.encode(&bytes)?;
43 let len = encoder.finish()?;
44
45 let mut file = File::create(path)?;
46 Ok(file.write_all(&buf[..len])?)
47 }
48
49 fn read_pem_file(path: &Path) -> Result<Self> {
51 let mut file = File::open(path)?;
52 let mut pem = Vec::<u8>::new();
53 file.read_to_end(&mut pem)?;
54
55 let mut decoder = Decoder::new(&pem)?;
56 ensure!(
57 decoder.type_label() == Self::label(),
58 PemError::LabelError(Self::label(), decoder.type_label().to_owned()),
59 );
60
61 let mut buf = Vec::new();
62 decoder.decode_to_end(&mut buf)?;
63
64 Self::from_reader(buf.as_slice())
65 }
66}
67
68pub trait FromReader: Sized {
70 fn from_reader(r: impl Read) -> Result<Self>;
72
73 fn read_from_file(path: &Path) -> Result<Self> {
75 let file = File::open(path).with_context(|| format!("Failed to open {path:?}"))?;
76 Self::from_reader(file)
77 }
78}
79
80pub trait ToWriter: Sized {
82 fn to_writer(&self, w: &mut impl Write) -> Result<()>;
84
85 fn write_to_file(self, path: &Path) -> Result<()> {
87 let mut file = File::create(path).with_context(|| format!("Failed to create {path:?}"))?;
88 self.to_writer(&mut file)
89 }
90}
91
92pub fn wait_timeout(
94 fd: BorrowedFd<'_>,
95 events: rustix::event::PollFlags,
96 timeout: Duration,
97) -> Result<()> {
98 let mut pfd = [rustix::event::PollFd::from_borrowed_fd(fd, events)];
99 match rustix::event::poll(&mut pfd, timeout.try_into().ok().as_ref())? {
100 0 => Err(io::Error::new(
101 io::ErrorKind::TimedOut,
102 "timed out waiting for fd to be ready",
103 )
104 .into()),
105 _ => Ok(()),
106 }
107}
108
109pub fn wait_read_timeout(fd: &impl AsFd, timeout: Duration) -> Result<()> {
111 wait_timeout(fd.as_fd(), rustix::event::PollFlags::IN, timeout)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use rustix::io::{read, write};
    use rustix::net::{AddressFamily, SocketFlags, SocketType, socketpair};

    /// Opens a connected pair of Unix-domain stream sockets.
    fn unix_stream_pair() -> Result<(std::os::fd::OwnedFd, std::os::fd::OwnedFd)> {
        Ok(socketpair(
            AddressFamily::UNIX,
            SocketType::STREAM,
            SocketFlags::empty(),
            None,
        )?)
    }

    #[test]
    fn test_data_ready() -> Result<()> {
        let (tx, rx) = unix_stream_pair()?;

        // Write before waiting so the receiving end already has data when it is polled.
        let payload = b"abc123";
        let sent = write(&tx, payload)?;
        assert_eq!(sent, payload.len());

        wait_read_timeout(&rx, Duration::from_millis(10))?;

        let mut received = [0u8; 6];
        let got = read(&rx, &mut received)?;
        assert_eq!(got, payload.len());
        assert_eq!(payload, &received);
        Ok(())
    }

    #[test]
    fn test_timeout() -> Result<()> {
        // Binding the sender to `_tx` (not `_`) keeps it open until the end of the test; nothing
        // is ever written to it, so the receiving end is expected to stay idle.
        let (_tx, rx) = unix_stream_pair()?;

        let outcome = wait_read_timeout(&rx, Duration::from_millis(10));
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        let Some(io_err) = err.downcast_ref::<io::Error>() else {
            bail!("Unexpected error result {:?}", err);
        };
        assert_eq!(io::ErrorKind::TimedOut, io_err.kind());
        Ok(())
    }
}