use anyhow::{anyhow, bail, Result};
use once_cell::sync::Lazy;
use regex::Regex;
use serde_annotate::Annotate;
use std::any::Any;
use std::cell::RefCell;
use std::cmp::Ordering;
use crate::transport::{
Capabilities, Capability, ProgressIndicator, Transport, TransportError, UpdateFirmware,
};
use crate::util::usb::UsbBackend;
/// USB vendor ID (STMicroelectronics) under which the DFU bootloader enumerates.
const VID_ST_MICROELECTRONICS: u16 = 0x0483;
/// USB product ID under which the DFU bootloader enumerates.
const PID_DFU_BOOTLOADER: u16 = 0xdf11;
/// Transport for a HyperDebug whose only supported action is a firmware update over DFU.
pub struct HyperdebugDfu {
    /// Open USB connection: either to a device already in the DFU bootloader, or to the running
    /// HyperDebug application.
    usb_backend: RefCell<UsbBackend>,
    /// Version string reported by the running firmware; `None` when the device was opened in
    /// bootloader mode or its description string could not be read.
    current_firmware_version: Option<String>,
    /// USB vendor/product ID of the HyperDebug application, used to reconnect after flashing.
    usb_vid: u16,
    usb_pid: u16,
}
impl HyperdebugDfu {
    /// Opens a HyperDebug for firmware updating.
    ///
    /// A device already sitting in the DFU bootloader (matching `usb_serial`, if given) is
    /// preferred; in that case the running firmware version is unknown. Otherwise the HyperDebug
    /// application is opened by `usb_vid`/`usb_pid` (defaulting to `super::VID_GOOGLE` /
    /// `super::PID_HYPERDEBUG`). Its firmware version is taken from the active configuration's
    /// description string when that string exists and can be read.
    ///
    /// # Errors
    /// Fails if neither the bootloader nor the application can be opened, or if the
    /// application's active configuration descriptor cannot be read.
    pub fn open(
        usb_vid: Option<u16>,
        usb_pid: Option<u16>,
        usb_serial: Option<&str>,
    ) -> Result<Self> {
        let usb_vid = usb_vid.unwrap_or(super::VID_GOOGLE);
        let usb_pid = usb_pid.unwrap_or(super::PID_HYPERDEBUG);

        // Bootloader mode first: nothing to query, the version is simply unknown.
        if let Ok(bootloader) =
            UsbBackend::new(VID_ST_MICROELECTRONICS, PID_DFU_BOOTLOADER, usb_serial)
        {
            return Ok(Self {
                usb_backend: RefCell::new(bootloader),
                current_firmware_version: None,
                usb_vid,
                usb_pid,
            });
        }

        let application = UsbBackend::new(usb_vid, usb_pid, usb_serial)?;
        // A missing or unreadable description string just leaves the version unknown.
        let current_firmware_version = application
            .active_config_descriptor()?
            .description_string_index()
            .and_then(|idx| application.read_string_descriptor_ascii(idx).ok());

        Ok(Self {
            usb_backend: RefCell::new(application),
            current_firmware_version,
            usb_vid,
            usb_pid,
        })
    }
}
impl Transport for HyperdebugDfu {
    /// This transport exposes no interfaces; it only supports [`UpdateFirmware`].
    fn capabilities(&self) -> Result<Capabilities> {
        Ok(Capabilities::new(Capability::NONE))
    }

    /// Handles an [`UpdateFirmware`] action; any other action is unsupported.
    fn dispatch(&self, action: &dyn Any) -> Result<Option<Box<dyn Annotate>>> {
        let Some(request) = action.downcast_ref::<UpdateFirmware>() else {
            bail!(TransportError::UnsupportedOperation)
        };
        let mut backend = self.usb_backend.borrow_mut();
        update_firmware(
            &mut backend,
            self.current_firmware_version.as_deref(),
            &request.firmware,
            request.progress.as_ref(),
            request.force,
            self.usb_vid,
            self.usb_pid,
        )
    }
}
/// Interface class code which, together with `USB_SUBCLASS_DFU`, identifies the DFU interface
/// when scanning descriptors.
const USB_CLASS_APP: u8 = 0xFE;
/// Interface subclass code for Device Firmware Upgrade.
const USB_SUBCLASS_DFU: u8 = 0x01;
/// DfuSe command byte (first byte of a 5-byte DNLOAD payload, followed by a little-endian
/// address) used to erase the flash page at that address.
const DFUSE_ERASE_PAGE: u8 = 0x41;
/// DfuSe command byte sent with a little-endian address before each data block is downloaded.
/// NOTE(review): ST's DfuSe documentation calls 0x21 "Set Address Pointer" — confirm naming.
const DFUSE_PROGRAM_PAGE: u8 = 0x21;
/// `bStatus` value (byte 0 of a DFU_GETSTATUS response) meaning "no error".
const DFU_STATUS_OK: u8 = 0x00;
// `bState` values (byte 4 of a DFU_GETSTATUS response), named after the DFU 1.1 states
// appIDLE, dfuIDLE, dfuDNBUSY and dfuDNLOAD-IDLE.
const DFU_STATE_APP_IDLE: u8 = 0x00;
const DFU_STATE_DFU_IDLE: u8 = 0x02;
const DFU_STATE_DOWNLOAD_BUSY: u8 = 0x04;
const DFU_STATE_DOWNLOAD_IDLE: u8 = 0x05;
// DFU class-specific request codes (bRequest).
const USB_DFU_DETACH: u8 = 0;
const USB_DFU_DNLOAD: u8 = 1;
const USB_DFU_GETSTATUS: u8 = 3;
// Official HyperDebug firmware image embedded at compile time. It is present only when the
// `include_hyperdebug_firmware` feature is enabled, in which case the image is read from the
// path in the `hyperdebug_firmware` environment variable at build time.
#[cfg(not(feature = "include_hyperdebug_firmware"))]
const OFFICIAL_FIRMWARE: Option<&'static [u8]> = None;
#[cfg(feature = "include_hyperdebug_firmware")]
const OFFICIAL_FIRMWARE: Option<&'static [u8]> = Some(include_bytes!(env!("hyperdebug_firmware")));
pub fn official_firmware_version() -> Result<Option<&'static str>> {
if let Some(fw) = OFFICIAL_FIRMWARE {
Ok(Some(get_hyperdebug_firmware_version(fw)?))
} else {
Ok(None)
}
}
/// Checks that `firmware` looks like a HyperDebug image, i.e. that a version string can be
/// extracted from it; the version itself is discarded.
fn validate_firmware_image(firmware: &[u8]) -> Result<()> {
    get_hyperdebug_firmware_version(firmware).map(|_version| ())
}
/// Marker searched for (at 4-byte-aligned offsets within the first 1024 bytes) in a firmware
/// image; the firmware name field immediately follows it.
const EC_COOKIE: [u8; 4] = [0x99, 0x88, 0x77, 0xce];
/// Size in bytes of the zero-padded firmware name field that follows `EC_COOKIE`.
const EC_FIRMWARE_NAME_LEN: usize = 32;
fn get_hyperdebug_firmware_version(firmware: &[u8]) -> Result<&str> {
let Some(pos) = firmware[0..1024]
.chunks(4)
.position(|c| c[0..4] == EC_COOKIE)
else {
bail!(TransportError::FirmwareProgramFailed(
"File is not a HyperDebug firmware image".to_string()
));
};
let firmware_name_field = &firmware[(pos + 1) * 4..(pos + 1) * 4 + EC_FIRMWARE_NAME_LEN];
let end = firmware_name_field
.iter()
.rev()
.position(|b| *b != 0x00)
.map(|j| EC_FIRMWARE_NAME_LEN - j)
.unwrap_or(0);
Ok(std::str::from_utf8(&firmware_name_field[0..end])?)
}
/// Flashes a HyperDebug with `firmware`, or with the built-in official image when `None`.
///
/// A user-supplied image is first checked to contain a HyperDebug version string. When `force`
/// is false and the running version is known, the update is skipped (returning `Ok(None)` after
/// logging a warning) if the new image has the same version or an older one.
///
/// If the device's DFU interface reports the appIDLE state, a DFU_DETACH request is sent and the
/// DFU bootloader is reopened by serial number. Otherwise the device is programmed directly.
/// After programming, this waits for the device to re-enumerate as `usb_vid`/`usb_pid`.
///
/// # Errors
/// Fails if no image is available, the image is invalid, the version comparison fails, or any
/// USB/DFU step fails.
pub fn update_firmware(
    usb_device: &mut UsbBackend,
    current_firmware_version: Option<&str>,
    firmware: &Option<Vec<u8>>,
    progress: &dyn ProgressIndicator,
    force: bool,
    usb_vid: u16,
    usb_pid: u16,
) -> Result<Option<Box<dyn Annotate>>> {
    let firmware: &[u8] = match firmware.as_deref() {
        Some(image) => {
            validate_firmware_image(image)?;
            image
        }
        None => OFFICIAL_FIRMWARE
            .ok_or_else(|| anyhow!("No built-in firmware, use --filename"))?,
    };

    if !force {
        if let Some(current_version) = current_firmware_version {
            let new_version = get_hyperdebug_firmware_version(firmware)?;
            if new_version == current_version {
                log::warn!(
                    "HyperDebug already running firmware version {}. Consider --force.",
                    new_version,
                );
                return Ok(None);
            }
            if is_older_than(new_version, current_version)? {
                log::warn!(
                    "Will not downgrade from {} to {}. Consider --force.",
                    current_version,
                    new_version,
                );
                return Ok(None);
            }
        }
    }

    let dfu_desc = scan_usb_descriptor(usb_device)?;
    usb_device.claim_interface(dfu_desc.dfu_interface)?;
    if wait_for_idle(usb_device, dfu_desc.dfu_interface)? != DFU_STATE_APP_IDLE {
        // The interface already reports a DFU-mode state, so program it directly.
        do_update_firmware(usb_device, dfu_desc, firmware, progress)?;
        restablish_connection(usb_vid, usb_pid, usb_device.get_serial_number())?;
        return Ok(None);
    }

    log::info!("Requesting switch to DFU mode...");
    // Errors are deliberately ignored; presumably the device may drop off the bus while
    // switching to its bootloader. The wValue of 1000 is the DFU_DETACH timeout in milliseconds.
    let _ = usb_device
        .write_control(
            rusb::request_type(
                rusb::Direction::Out,
                rusb::RequestType::Class,
                rusb::Recipient::Interface,
            ),
            USB_DFU_DETACH,
            1000,
            dfu_desc.dfu_interface as u16,
            &[],
        )
        .and_then(|_| wait_for_idle(usb_device, dfu_desc.dfu_interface));
    // Give the bootloader time to enumerate before trying to open it.
    std::thread::sleep(std::time::Duration::from_millis(1000));

    log::info!("Connecting to DFU bootloader...");
    let mut dfu_device = UsbBackend::new(
        VID_ST_MICROELECTRONICS,
        PID_DFU_BOOTLOADER,
        Some(usb_device.get_serial_number()),
    )?;
    log::info!("Connected to DFU bootloader");
    let dfu_desc = scan_usb_descriptor(&dfu_device)?;
    dfu_device.claim_interface(dfu_desc.dfu_interface)?;
    do_update_firmware(&dfu_device, dfu_desc, firmware, progress)?;
    restablish_connection(usb_vid, usb_pid, usb_device.get_serial_number())?;
    Ok(None)
}
/// Waits for the freshly flashed firmware to show up on USB.
///
/// Makes up to 10 attempts, sleeping 500 ms before each one, to open a device with the given
/// VID/PID and serial number. Succeeds as soon as one attempt opens the device; the opened
/// handle is dropped immediately.
///
/// # Errors
/// `TransportError::FirmwareProgramFailed` if the device never reappears.
fn restablish_connection(usb_vid: u16, usb_pid: u16, serial_number: &str) -> Result<()> {
    log::info!("Connecting to newly flashed firmware...");
    let reappeared = (0..10).any(|_attempt| {
        std::thread::sleep(std::time::Duration::from_millis(500));
        UsbBackend::new(usb_vid, usb_pid, Some(serial_number)).is_ok()
    });
    if !reappeared {
        bail!(TransportError::FirmwareProgramFailed(
            "Unable to establish connection after flashing. Possibly bad image.".to_string()
        ));
    }
    Ok(())
}
fn do_update_firmware(
usb_device: &UsbBackend,
dfu_desc: DfuDescriptor,
firmware: &[u8],
progress: &dyn ProgressIndicator,
) -> Result<()> {
let DfuDescriptor {
dfu_interface,
xfer_size,
page_size,
flash_size,
base_address,
} = dfu_desc;
if page_size == 0 || flash_size != 0x80000 || xfer_size == 0 {
bail!(TransportError::UsbOpenError(
"Unrecognized DFU layout (not a Nucleo-L552ZE?)".to_string()
));
}
log::info!("Erasing flash storage...");
let firmware_len = firmware.len() as u32;
progress.new_stage("Erasing", firmware_len as usize);
let mut bytes_erased: u32 = 0;
while bytes_erased < firmware_len {
let mut request = [0u8; 5];
request[0] = DFUSE_ERASE_PAGE;
request[1..5].copy_from_slice(&(base_address + bytes_erased).to_le_bytes());
usb_device.write_control(
rusb::request_type(
rusb::Direction::Out,
rusb::RequestType::Class,
rusb::Recipient::Interface,
),
USB_DFU_DNLOAD,
0,
dfu_interface as u16,
&request,
)?;
wait_for_idle(usb_device, dfu_interface)?;
bytes_erased += page_size;
progress.progress(bytes_erased as usize);
}
log::info!("Programming flash storage...");
progress.new_stage("Writing", firmware_len as usize);
let mut bytes_sent: u32 = 0;
while bytes_sent < firmware_len {
let chunk_size = std::cmp::min(firmware_len - bytes_sent, xfer_size);
let mut request = [0u8; 5];
request[0] = DFUSE_PROGRAM_PAGE;
request[1..5].copy_from_slice(&(base_address + bytes_sent).to_le_bytes());
usb_device.write_control(
rusb::request_type(
rusb::Direction::Out,
rusb::RequestType::Class,
rusb::Recipient::Interface,
),
USB_DFU_DNLOAD,
0,
dfu_interface as u16,
&request,
)?;
wait_for_idle(usb_device, dfu_interface)?;
usb_device.write_control(
rusb::request_type(
rusb::Direction::Out,
rusb::RequestType::Class,
rusb::Recipient::Interface,
),
USB_DFU_DNLOAD,
2,
dfu_interface as u16,
&firmware[bytes_sent as usize..(bytes_sent + chunk_size) as usize],
)?;
wait_for_idle(usb_device, dfu_interface)?;
bytes_sent += chunk_size;
progress.progress(bytes_sent as usize);
}
usb_device.write_control(
rusb::request_type(
rusb::Direction::Out,
rusb::RequestType::Class,
rusb::Recipient::Interface,
),
USB_DFU_DNLOAD,
0,
dfu_interface as u16,
&[],
)?;
let _ = wait_for_idle(usb_device, dfu_interface);
Ok(())
}
/// DFU interface details discovered from a device's USB descriptors.
///
/// Fields that could not be discovered are left as 0 by `scan_usb_descriptor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct DfuDescriptor {
    /// Interface number of the DFU interface.
    dfu_interface: u8,
    /// `wTransferSize` from the DFU functional descriptor: maximum bytes per DNLOAD.
    xfer_size: u32,
    /// Size in bytes of one flash page (erase granularity).
    page_size: u32,
    /// Total flash size in bytes (`num_pages * page_size`).
    flash_size: u32,
    /// Start address of the internal flash region.
    base_address: u32,
}
fn scan_usb_descriptor(usb_device: &UsbBackend) -> Result<DfuDescriptor> {
let mut dfu_interface = 0;
let mut xfer_size = 0;
let mut page_size = 0;
let mut flash_size = 0;
let mut base_address = 0;
let config_desc = usb_device.active_config_descriptor()?;
for interface in config_desc.interfaces() {
for interface_desc in interface.descriptors() {
let idx = match interface_desc.description_string_index() {
Some(idx) => idx,
None => continue,
};
let interface_name = match usb_device.read_string_descriptor_ascii(idx) {
Ok(interface_name) => interface_name,
_ => continue,
};
if interface_desc.class_code() != USB_CLASS_APP
|| interface_desc.sub_class_code() != USB_SUBCLASS_DFU
|| (interface_desc.protocol_code() != 0x01
&& interface_desc.protocol_code() != 0x02)
{
continue;
}
dfu_interface = interface.number();
let extra_bytes = interface_desc.extra();
if extra_bytes.len() >= 9 {
xfer_size = extra_bytes[5] as u32 | (extra_bytes[6] as u32) << 8;
}
static DFU_SECTION_REGEX: Lazy<Regex> = Lazy::new(|| {
Regex::new("^@([^/]*)/0x([0-9a-fA-F]+)/([0-9]+)\\*([0-9]+)(..)").unwrap()
});
let Some(captures) = DFU_SECTION_REGEX.captures(&interface_name) else {
continue;
};
let section_name = captures.get(1).unwrap().as_str().trim();
if section_name != "Internal Flash" {
continue;
}
base_address = u32::from_str_radix(captures.get(2).unwrap().as_str(), 16).unwrap();
let num_pages = captures.get(3).unwrap().as_str().parse::<u32>().unwrap();
page_size = captures.get(4).unwrap().as_str().parse::<u32>().unwrap();
let suffix = captures.get(5).unwrap().as_str();
if suffix.starts_with('K') {
page_size *= 1024;
}
flash_size = num_pages * page_size;
}
}
Ok(DfuDescriptor {
dfu_interface,
xfer_size,
page_size,
flash_size,
base_address,
})
}
fn wait_for_idle(dfu_device: &UsbBackend, dfu_interface: u8) -> Result<u8> {
loop {
let mut response = [0u8; 6];
let rc = dfu_device.read_control(
rusb::request_type(
rusb::Direction::In,
rusb::RequestType::Class,
rusb::Recipient::Interface,
),
USB_DFU_GETSTATUS,
0,
dfu_interface as u16,
&mut response,
)?;
if rc != response.len() {
bail!(TransportError::FirmwareProgramFailed("".to_string()));
}
let command_status = response[0];
let minimum_delay_ms =
u64::from_le_bytes([response[1], response[2], response[3], 0, 0, 0, 0, 0]);
let device_state = response[4];
if command_status != DFU_STATUS_OK {
bail!(TransportError::FirmwareProgramFailed(format!(
"Unexpected DFU status {}",
response[0]
)));
}
if device_state == DFU_STATE_APP_IDLE
|| device_state == DFU_STATE_DFU_IDLE
|| device_state == DFU_STATE_DOWNLOAD_IDLE
{
return Ok(device_state);
} else if device_state == DFU_STATE_DOWNLOAD_BUSY {
std::thread::sleep(std::time::Duration::from_millis(minimum_delay_ms));
} else {
bail!(TransportError::FirmwareProgramFailed(format!(
"Unexpected DFU state {}",
response[4]
)));
}
}
}
/// Returns whether version string `version_a` is strictly older than `version_b`.
///
/// Both strings are read as alternating non-numeric and numeric runs.
/// - The leading non-numeric runs must match exactly; otherwise the versions are treated as
///   unrelated and the result is `false`.
/// - Numeric runs are compared by value, so `1.2` < `1.11` and `01` equals `1`.
/// - On equal values, comparison continues with the remainders of both strings.
/// - If either string runs out first, the result is `false`.
///
/// # Errors
/// Returns an error if a numeric run does not fit in a `u64`.
fn is_older_than(version_a: &str, version_b: &str) -> Result<bool> {
    let apos = version_a
        .find(|ch: char| ch.is_ascii_digit())
        .unwrap_or(version_a.len());
    let bpos = version_b
        .find(|ch: char| ch.is_ascii_digit())
        .unwrap_or(version_b.len());
    if version_a[..apos] != version_b[..bpos] {
        return Ok(false);
    }
    let version_a = &version_a[apos..];
    let version_b = &version_b[bpos..];
    if version_a.is_empty() || version_b.is_empty() {
        return Ok(false);
    }
    // Measure each numeric run separately: equal values may have different widths ("01" vs
    // "1"), so each remainder must be sliced with its own run length.
    let alen = version_a
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(version_a.len());
    let blen = version_b
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(version_b.len());
    let aval = version_a[..alen].parse::<u64>()?;
    let bval = version_b[..blen].parse::<u64>()?;
    match aval.cmp(&bval) {
        Ordering::Less => Ok(true),
        Ordering::Greater => Ok(false),
        Ordering::Equal => is_older_than(&version_a[alen..], &version_b[blen..]),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_older_than() {
        // Date-stamped versions compare numerically, run by run.
        assert!(is_older_than("hyp_20240101_99", "hyp_20240801_01").unwrap());
        assert!(!is_older_than("hyp_20240801_01", "hyp_20240101_99").unwrap());
        assert!(is_older_than("hyp_20240101_01", "hyp_20240101_02").unwrap());
        assert!(!is_older_than("hyp_20240101_02", "hyp_20240101_01").unwrap());
        assert!(!is_older_than("hyp_20240101_01", "hyp_20240101_01").unwrap());

        // Dotted versions compare numerically, not lexicographically.
        assert!(is_older_than("fancy_1.2.5", "fancy_1.11.1").unwrap());
        assert!(!is_older_than("fancy_1.11.1", "fancy_1.2.5").unwrap());
        assert!(is_older_than("fancy_1.2.2", "fancy_1.2.11").unwrap());
        assert!(!is_older_than("fancy_1.2.11", "fancy_1.2.2").unwrap());
        assert!(!is_older_than("fancy_1.2.11", "fancy_1.2.11").unwrap());

        // Versions with different prefixes are unrelated and never "older".
        assert!(!is_older_than("fancy_1.2.5", "hyperdebug_20240101_02").unwrap());
        assert!(!is_older_than("hyperdebug_20240101_02", "fancy_1.2.5").unwrap());
    }
}