use anyhow::{ensure, Context, Result};
use pem_rfc7468::{Decoder, Encoder, LineEnding};
use thiserror::Error;
use std::convert::TryInto;
use std::fs::File;
use std::io::{self, Read, Write};
use std::os::fd::{AsFd, BorrowedFd};
use std::path::Path;
use std::time::Duration;
#[derive(Debug, Error)]
enum PemError {
#[error("PEM type error; expecting {0:?} but got {1:?}")]
LabelError(&'static str, String),
}
pub trait PemSerilizable: ToWriter + FromReader {
fn label() -> &'static str;
fn write_pem_file(&self, path: &Path) -> Result<()> {
const MAX_PEM_SIZE: usize = 4096;
let mut bytes = Vec::<u8>::new();
self.to_writer(&mut bytes)?;
let mut buf = [0u8; MAX_PEM_SIZE];
let mut encoder = Encoder::new(Self::label(), LineEnding::LF, &mut buf)?;
encoder.encode(&bytes)?;
let len = encoder.finish()?;
let mut file = File::create(path)?;
Ok(file.write_all(&buf[..len])?)
}
fn read_pem_file(path: &Path) -> Result<Self> {
let mut file = File::open(path)?;
let mut pem = Vec::<u8>::new();
file.read_to_end(&mut pem)?;
let mut decoder = Decoder::new(&pem)?;
ensure!(
decoder.type_label() == Self::label(),
PemError::LabelError(Self::label(), decoder.type_label().to_owned()),
);
let mut buf = Vec::new();
decoder.decode_to_end(&mut buf)?;
Self::from_reader(buf.as_slice())
}
}
pub trait FromReader: Sized {
fn from_reader(r: impl Read) -> Result<Self>;
fn read_from_file(path: &Path) -> Result<Self> {
let file = File::open(path).with_context(|| format!("Failed to open {path:?}"))?;
Self::from_reader(file)
}
}
pub trait ToWriter: Sized {
fn to_writer(&self, w: &mut impl Write) -> Result<()>;
fn write_to_file(self, path: &Path) -> Result<()> {
let mut file = File::create(path).with_context(|| format!("Failed to create {path:?}"))?;
self.to_writer(&mut file)
}
}
/// Blocks until `fd` reports any of `events`, or until `timeout` elapses.
///
/// The timeout is rounded *up* to whole milliseconds, the granularity of
/// `poll(2)`, and saturates at `i32::MAX` ms. If the poll is interrupted by a
/// signal (`EINTR`), it is retried with whatever time remains. A signal
/// handler firing elsewhere in the process therefore does not surface here as
/// an error.
///
/// Returns `Ok(())` as soon as `poll` reports the descriptor ready. This also
/// covers error/hang-up conditions, which are left for the subsequent
/// read/write to report.
///
/// # Errors
/// Returns an [`io::Error`] of kind [`io::ErrorKind::TimedOut`] if the time
/// runs out first. Any other `poll` failure is propagated as-is.
pub fn wait_timeout(
    fd: BorrowedFd<'_>,
    events: rustix::event::PollFlags,
    timeout: Duration,
) -> Result<()> {
    // `None` when the deadline is not representable (e.g. `Duration::MAX`);
    // in that case every retry simply reuses the saturated timeout.
    let deadline = std::time::Instant::now().checked_add(timeout);
    let mut remaining = timeout;
    loop {
        let mut pfd = [rustix::event::PollFd::from_borrowed_fd(fd, events)];
        match rustix::event::poll(&mut pfd, poll_timeout_ms(remaining)) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "timed out waiting for fd to be ready",
                )
                .into())
            }
            Ok(_) => return Ok(()),
            // poll(2) is never auto-restarted after a signal handler runs,
            // even with SA_RESTART, so retry with the time that is left.
            Err(err) if err == rustix::io::Errno::INTR => {
                if let Some(deadline) = deadline {
                    remaining = deadline.saturating_duration_since(std::time::Instant::now());
                }
            }
            Err(err) => return Err(err.into()),
        }
    }
}

/// Converts `timeout` to a `poll(2)` millisecond timeout.
///
/// The value is rounded up so a non-zero sub-millisecond wait does not turn
/// into a non-blocking poll, and it saturates at `i32::MAX`.
fn poll_timeout_ms(timeout: Duration) -> i32 {
    let millis = timeout.as_nanos().saturating_add(999_999) / 1_000_000;
    millis.try_into().unwrap_or(i32::MAX)
}
pub fn wait_read_timeout(fd: &impl AsFd, timeout: Duration) -> Result<()> {
wait_timeout(fd.as_fd(), rustix::event::PollFlags::IN, timeout)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use rustix::io::{read, write};
    use rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType};

    /// Creates a connected pair of blocking Unix stream sockets.
    fn stream_pair() -> Result<(impl AsFd, impl AsFd)> {
        let pair = socketpair(
            AddressFamily::UNIX,
            SocketType::STREAM,
            SocketFlags::empty(),
            None,
        )?;
        Ok(pair)
    }

    /// Data already written to the peer makes the receiver readable in time.
    #[test]
    fn test_data_ready() -> Result<()> {
        let (snd, rcv) = stream_pair()?;
        let payload = b"abc123";
        let sent = write(&snd, payload)?;
        assert_eq!(sent, payload.len());

        wait_read_timeout(&rcv, Duration::from_millis(10))?;

        let mut received = [0u8; 6];
        let got = read(&rcv, &mut received)?;
        assert_eq!(got, payload.len());
        assert_eq!(payload, &received);
        Ok(())
    }

    /// With nothing written, the wait fails with an `io::ErrorKind::TimedOut`.
    #[test]
    fn test_timeout() -> Result<()> {
        // Keep the sending half alive: dropping it would make the receiver
        // readable (EOF) and the wait would succeed instead of timing out.
        let (_snd, rcv) = stream_pair()?;
        let result = wait_read_timeout(&rcv, Duration::from_millis(10));
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let Some(io_err) = err.downcast_ref::<io::Error>() {
            assert_eq!(io::ErrorKind::TimedOut, io_err.kind());
        } else {
            bail!("Unexpected error result {:?}", err);
        }
        Ok(())
    }
}