opentitanlib/transport/proxy/
mod.rs1use std::cell::RefCell;
6use std::collections::HashMap;
7use std::io;
8use std::io::{BufWriter, ErrorKind, Read, Write};
9use std::net::{TcpStream, ToSocketAddrs};
10use std::rc::Rc;
11
12use anyhow::{Context, Result, bail};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15use tokio::io::AsyncWriteExt;
16
17use crate::bootstrap::BootstrapOptions;
18use crate::impl_serializable_error;
19use crate::io::emu::Emulator;
20use crate::io::gpio::{GpioBitbanging, GpioMonitoring, GpioPin};
21use crate::io::i2c::Bus;
22use crate::io::spi::Target;
23use crate::io::uart::Uart;
24use crate::proxy::protocol::{
25 AsyncMessage, Message, ProxyRequest, ProxyResponse, Request, Response, UartRequest,
26 UartResponse,
27};
28use crate::transport::{Capabilities, Capability, ProxyOps, Transport, TransportError};
29
30mod emu;
31mod gpio;
32mod i2c;
33mod spi;
34mod uart;
35
36#[derive(Debug, Error, Serialize, Deserialize)]
37pub enum ProxyError {
38 #[error("Unexpected reply")]
39 UnexpectedReply(),
40 #[error("JSON encoding: {0}")]
41 JsonEncoding(String),
42 #[error("JSON decoding: {0}")]
43 JsonDecoding(String),
44}
45impl_serializable_error!(ProxyError);
46
47pub struct Proxy {
50 inner: Rc<Inner>,
51}
52
53impl Proxy {
54 pub fn open(host: Option<&str>, port: u16) -> Result<Self> {
56 let host = host.unwrap_or("localhost");
57 let addr = ToSocketAddrs::to_socket_addrs(&(host, port))
58 .map_err(|e| TransportError::ProxyLookupError(host.to_string(), e.to_string()))?
59 .next()
60 .unwrap();
61 let conn = TcpStream::connect(addr)
62 .map_err(|e| TransportError::ProxyConnectError(addr.to_string(), e.to_string()))?;
63 Ok(Self {
64 inner: Rc::new(Inner {
65 conn: RefCell::new(conn),
66 uarts: RefCell::new(HashMap::new()),
67 uart_channel_map: RefCell::new(HashMap::new()),
68 recv_buf: RefCell::new(Vec::new()),
69 }),
70 })
71 }
72}
73
74struct UartRecord {
75 pub uart: Rc<dyn Uart>,
76 pub pipe_sender: tokio::io::WriteHalf<tokio::io::SimplexStream>,
77 pub pipe_receiver: tokio::io::ReadHalf<tokio::io::SimplexStream>,
78}
79
80struct Inner {
81 conn: RefCell<TcpStream>,
82 pub uarts: RefCell<HashMap<String, UartRecord>>,
83 uart_channel_map: RefCell<HashMap<u32, String>>,
84 recv_buf: RefCell<Vec<u8>>,
85}
86
87impl Inner {
88 fn execute_command(&self, req: Request) -> Result<Response> {
91 self.send_json_request(req).context("json encoding")?;
92 loop {
93 match self.recv_json_response().context("json decoding")? {
94 Message::Res(res) => match res {
95 Ok(value) => return Ok(value),
96 Err(e) => return Err(anyhow::Error::from(e)),
97 },
98 Message::Async { channel, msg } => self.process_async_data(channel, msg)?,
99 _ => bail!(ProxyError::UnexpectedReply()),
100 }
101 }
102 }
103
104 fn poll_for_async_data(&self) -> Result<()> {
105 self.recv_nonblocking()?;
106 while let Some(msg) = self.dequeue_json_response()? {
107 match msg {
108 Message::Async { channel, msg } => self.process_async_data(channel, msg)?,
109 _ => bail!(ProxyError::UnexpectedReply()),
110 }
111 }
112 Ok(())
113 }
114
115 fn process_async_data(&self, channel: u32, msg: AsyncMessage) -> Result<()> {
116 match msg {
117 AsyncMessage::UartData { data } => {
118 if let Some(uart_instance) = self.uart_channel_map.borrow().get(&channel)
119 && let Some(uart_record) = self.uarts.borrow_mut().get_mut(uart_instance)
120 {
121 crate::util::runtime::block_on(async {
122 uart_record.pipe_sender.write_all(&data).await
123 })?;
124 }
125 }
126 }
127 Ok(())
128 }
129
130 fn send_json_request(&self, req: Request) -> Result<()> {
132 let conn: &mut std::net::TcpStream = &mut self.conn.borrow_mut();
133 let mut writer = BufWriter::new(conn);
134 serde_json::to_writer(&mut writer, &Message::Req(req))?;
135 writer.write_all(b"\n")?;
136 writer.flush()?;
137 Ok(())
138 }
139
140 fn recv_json_response(&self) -> Result<Message> {
142 if let Some(msg) = self.dequeue_json_response()? {
143 return Ok(msg);
144 }
145 let mut conn = self.conn.borrow_mut();
146 let mut buf = self.recv_buf.borrow_mut();
147 let mut idx: usize = buf.len();
148 loop {
149 buf.resize(idx + 2048, 0);
150 let rc = conn.read(&mut buf[idx..])?;
151 if rc == 0 {
152 anyhow::bail!(io::Error::new(
153 ErrorKind::UnexpectedEof,
154 "Server unexpectedly closed connection"
155 ))
156 }
157 idx += rc;
158 let Some(newline_pos) = buf[idx - rc..idx].iter().position(|b| *b == b'\n') else {
159 continue;
160 };
161 let result = serde_json::from_slice::<Message>(&buf[..idx - rc + newline_pos])?;
162 buf.resize(idx, 0u8);
163 buf.drain(..idx - rc + newline_pos + 1);
164 return Ok(result);
165 }
166 }
167
168 fn recv_nonblocking(&self) -> Result<()> {
169 let mut conn = self.conn.borrow_mut();
170 conn.set_nonblocking(true)?;
171 let mut buf = self.recv_buf.borrow_mut();
172 let mut idx: usize = buf.len();
173 loop {
174 buf.resize(idx + 2048, 0);
175 match conn.read(&mut buf[idx..]) {
176 Ok(0) => {
177 anyhow::bail!(io::Error::new(
178 ErrorKind::UnexpectedEof,
179 "Server unexpectedly closed connection"
180 ))
181 }
182 Ok(rc) => idx += rc,
183 Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
184 Err(e) => anyhow::bail!(e),
185 }
186 }
187 buf.resize(idx, 0);
188 conn.set_nonblocking(false)?;
189 Ok(())
190 }
191
192 fn dequeue_json_response(&self) -> Result<Option<Message>> {
193 let mut buf = self.recv_buf.borrow_mut();
194 let Some(newline_pos) = buf.iter().position(|b| *b == b'\n') else {
195 return Ok(None);
196 };
197 let result = serde_json::from_slice::<Message>(&buf[..newline_pos])?;
198 buf.drain(..newline_pos + 1);
199 Ok(Some(result))
200 }
201}
202
203pub struct ProxyOpsImpl {
204 inner: Rc<Inner>,
205}
206
207impl ProxyOpsImpl {
208 pub fn new(proxy: &Proxy) -> Result<Self> {
209 Ok(Self {
210 inner: Rc::clone(&proxy.inner),
211 })
212 }
213
214 fn execute_command(&self, command: ProxyRequest) -> Result<ProxyResponse> {
216 match self.inner.execute_command(Request::Proxy(command))? {
217 Response::Proxy(resp) => Ok(resp),
218 _ => bail!(ProxyError::UnexpectedReply()),
219 }
220 }
221}
222
223impl ProxyOps for ProxyOpsImpl {
224 fn provides_map(&self) -> Result<HashMap<String, String>> {
225 match self.execute_command(ProxyRequest::Provides {})? {
226 ProxyResponse::Provides { provides_map } => Ok(provides_map),
227 _ => bail!(ProxyError::UnexpectedReply()),
228 }
229 }
230
231 fn bootstrap(&self, options: &BootstrapOptions, payload: &[u8]) -> Result<()> {
232 match self.execute_command(ProxyRequest::Bootstrap {
233 options: options.clone(),
234 payload: payload.to_vec(),
235 })? {
236 ProxyResponse::Bootstrap => Ok(()),
237 _ => bail!(ProxyError::UnexpectedReply()),
238 }
239 }
240
241 fn apply_pin_strapping(&self, strapping_name: &str) -> Result<()> {
242 match self.execute_command(ProxyRequest::ApplyPinStrapping {
243 strapping_name: strapping_name.to_string(),
244 })? {
245 ProxyResponse::ApplyPinStrapping => Ok(()),
246 _ => bail!(ProxyError::UnexpectedReply()),
247 }
248 }
249
250 fn remove_pin_strapping(&self, strapping_name: &str) -> Result<()> {
251 match self.execute_command(ProxyRequest::RemovePinStrapping {
252 strapping_name: strapping_name.to_string(),
253 })? {
254 ProxyResponse::RemovePinStrapping => Ok(()),
255 _ => bail!(ProxyError::UnexpectedReply()),
256 }
257 }
258
259 fn apply_default_configuration_with_strap(&self, strapping_name: &str) -> Result<()> {
260 match self.execute_command(ProxyRequest::ApplyDefaultConfigurationWithStrapping {
261 strapping_name: strapping_name.to_string(),
262 })? {
263 ProxyResponse::ApplyDefaultConfigurationWithStrapping => Ok(()),
264 _ => bail!(ProxyError::UnexpectedReply()),
265 }
266 }
267}
268
269impl Transport for Proxy {
270 fn capabilities(&self) -> Result<Capabilities> {
271 match self.inner.execute_command(Request::GetCapabilities)? {
272 Response::GetCapabilities(capabilities) => Ok(capabilities.add(Capability::PROXY)),
273 _ => bail!(ProxyError::UnexpectedReply()),
274 }
275 }
276
277 fn apply_default_configuration(&self) -> Result<()> {
278 match self
279 .inner
280 .execute_command(Request::ApplyDefaultConfiguration)?
281 {
282 Response::ApplyDefaultConfiguration => Ok(()),
283 _ => bail!(ProxyError::UnexpectedReply()),
284 }
285 }
286
287 fn spi(&self, instance: &str) -> Result<Rc<dyn Target>> {
289 Ok(Rc::new(spi::ProxySpi::open(self, instance)?))
290 }
291
292 fn i2c(&self, instance: &str) -> Result<Rc<dyn Bus>> {
294 Ok(Rc::new(i2c::ProxyI2c::open(self, instance)?))
295 }
296
297 fn uart(&self, instance_name: &str) -> Result<Rc<dyn Uart>> {
299 if let Some(instance) = self.inner.uarts.borrow().get(instance_name) {
300 return Ok(Rc::clone(&instance.uart));
301 }
302
303 let Response::Uart(UartResponse::RegisterNonblockingRead { channel }) =
307 self.inner.execute_command(Request::Uart {
308 id: instance_name.to_owned(),
309 command: UartRequest::RegisterNonblockingRead,
310 })?
311 else {
312 bail!(ProxyError::UnexpectedReply())
313 };
314
315 let instance: Rc<dyn Uart> = Rc::new(uart::ProxyUart::open(self, instance_name)?);
316 let (pipe_receiver, pipe_sender) = tokio::io::simplex(65536);
317
318 self.inner
319 .uart_channel_map
320 .borrow_mut()
321 .insert(channel, instance_name.to_owned());
322 self.inner.uarts.borrow_mut().insert(
323 instance_name.to_owned(),
324 UartRecord {
325 uart: Rc::clone(&instance),
326 pipe_sender,
327 pipe_receiver,
328 },
329 );
330 Ok(instance)
331 }
332
333 fn gpio_pin(&self, pinname: &str) -> Result<Rc<dyn GpioPin>> {
335 Ok(Rc::new(gpio::ProxyGpioPin::open(self, pinname)?))
336 }
337
338 fn gpio_monitoring(&self) -> Result<Rc<dyn GpioMonitoring>> {
340 Ok(Rc::new(gpio::GpioMonitoringImpl::new(self)?))
341 }
342
343 fn gpio_bitbanging(&self) -> Result<Rc<dyn GpioBitbanging>> {
345 Ok(Rc::new(gpio::GpioBitbangingImpl::new(self)?))
346 }
347
348 fn emulator(&self) -> Result<Rc<dyn Emulator>> {
350 Ok(Rc::new(emu::ProxyEmu::open(self)?))
351 }
352
353 fn proxy_ops(&self) -> Result<Rc<dyn ProxyOps>> {
355 Ok(Rc::new(ProxyOpsImpl::new(self)?))
356 }
357}