// opentitanlib/io/console/ext.rs

use std::time::Duration;

use anyhow::{Context, Result};

use super::ConsoleDevice;
use crate::io::console::{ConsoleError, Logged};

12pub trait ConsoleExt {
14 fn read(&self, buf: &mut [u8]) -> Result<usize>;
17
18 fn read_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<usize>;
22
23 fn logged(self) -> Logged<Self>
25 where
26 Self: Sized;
27
28 fn try_wait_for_line<P: MatchPattern>(
32 &self,
33 pattern: P,
34 timeout: Duration,
35 ) -> Result<Option<P::MatchResult>>;
36
37 fn wait_for_line<P: MatchPattern>(
53 &self,
54 pattern: P,
55 timeout: Duration,
56 ) -> Result<P::MatchResult>;
57}
58
59impl<T: ConsoleDevice + ?Sized> ConsoleExt for T {
60 fn read(&self, buf: &mut [u8]) -> Result<usize> {
61 crate::util::runtime::block_on(std::future::poll_fn(|cx| self.poll_read(cx, buf)))
62 }
63
64 fn read_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<usize> {
65 crate::util::runtime::block_on(async {
66 tokio::time::timeout(timeout, std::future::poll_fn(|cx| self.poll_read(cx, buf))).await
67 })
68 .unwrap_or(Ok(0))
69 }
70
71 fn logged(self) -> Logged<Self>
72 where
73 Self: Sized,
74 {
75 Logged::new(self)
76 }
77
78 fn try_wait_for_line<P: MatchPattern>(
79 &self,
80 pattern: P,
81 timeout: Duration,
82 ) -> Result<Option<P::MatchResult>> {
83 crate::util::runtime::block_on(async {
84 match tokio::time::timeout(timeout, async {
85 loop {
86 let line = read_line(self).await?;
87 if let Some(m) = pattern.perform_match(&line) {
88 return Ok(m);
89 }
90 }
91 })
92 .await
93 {
94 Ok(Ok(v)) => Ok(Some(v)),
95 Ok(Err(e)) => Err(e),
96 Err(_) => Ok(None),
97 }
98 })
99 }
100
101 fn wait_for_line<P: MatchPattern>(
102 &self,
103 pattern: P,
104 timeout: Duration,
105 ) -> Result<P::MatchResult> {
106 self.try_wait_for_line(pattern, timeout)?
107 .with_context(|| ConsoleError::TimedOut)
108 }
109}
110
111async fn read_line<T: ConsoleDevice + ?Sized>(console: &T) -> Result<Vec<u8>> {
112 let mut buf = Vec::new();
113
114 loop {
115 let mut ch = 0;
117 let len =
118 std::future::poll_fn(|cx| console.poll_read(cx, std::slice::from_mut(&mut ch))).await?;
119 if len == 0 {
120 break;
121 }
122
123 buf.push(ch);
124 if ch == b'\n' {
125 break;
126 }
127 }
128
129 Ok(buf)
130}
131
/// A pattern that can be tested against a chunk of console output.
///
/// In this module, `ConsoleExt::try_wait_for_line` passes one console line at a time,
/// including its trailing `\n` when one was read. Implementations are provided for `[u8]`
/// and `str` (substring search), `regex::Regex` and `regex::bytes::Regex` (capture groups),
/// references to any pattern, and [`PassFail`].
pub trait MatchPattern {
    /// Value produced when the pattern matches.
    type MatchResult;

    /// Tests `haystack` against this pattern.
    ///
    /// Returns `Some` with the match result if the pattern is found, or `None` otherwise.
    fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult>;
}

139impl<T: MatchPattern + ?Sized> MatchPattern for &T {
140 type MatchResult = T::MatchResult;
141
142 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
143 T::perform_match(self, haystack)
144 }
145}
146
147impl MatchPattern for [u8] {
148 type MatchResult = ();
149
150 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
151 memchr::memmem::find(haystack, self).map(|_| ())
152 }
153}
154
155impl MatchPattern for regex::bytes::Regex {
156 type MatchResult = Vec<Vec<u8>>;
157
158 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
159 Some(
160 self.captures(haystack)?
161 .iter()
162 .map(|x| x.map(|m| m.as_bytes().to_owned()).unwrap_or_default())
163 .collect(),
164 )
165 }
166}
167
168impl MatchPattern for str {
169 type MatchResult = ();
170
171 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
172 self.as_bytes().perform_match(haystack)
173 }
174}
175
176impl MatchPattern for regex::Regex {
177 type MatchResult = Vec<String>;
178
179 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
180 let haystack = String::from_utf8_lossy(haystack);
181 Some(
182 self.captures(&haystack)?
183 .iter()
184 .map(|x| x.map(|m| m.as_str().to_owned()).unwrap_or_default())
185 .collect(),
186 )
187 }
188}
189
/// A pair of patterns: `.0` signals success and `.0 .1`-style failure is signalled by `.1`.
///
/// When used as a [`MatchPattern`], the failure pattern (`.1`) is checked before the pass
/// pattern (`.0`), so a line containing both is reported as a failure.
#[derive(Debug, Clone)]
pub struct PassFail<T, E>(pub T, pub E);

/// Outcome of matching a [`PassFail`] pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PassFailResult<T, E> {
    /// The pass pattern matched; holds its match result.
    Pass(T),
    /// The failure pattern matched; holds its match result.
    Fail(E),
}

198impl<T: MatchPattern, E: MatchPattern> MatchPattern for PassFail<T, E> {
199 type MatchResult = PassFailResult<T::MatchResult, E::MatchResult>;
200
201 fn perform_match(&self, haystack: &[u8]) -> Option<Self::MatchResult> {
202 if let Some(m) = self.1.perform_match(haystack) {
203 return Some(PassFailResult::Fail(m));
204 }
205
206 Some(PassFailResult::Pass(self.0.perform_match(haystack)?))
207 }
208}