use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::fs;
use std::ops::{Deref, DerefMut, Range};
use std::path::Path;
use std::time::Duration;
use anyhow::Result;
use object::{Object, ObjectSymbol};
use crate::io::jtag::{Jtag, RiscvCsr, RiscvGpr, RiscvReg};
/// Symbol table extracted from an ELF file.
///
/// Maps each symbol name to the 32-bit address it was loaded with, so that
/// addresses can be written symbolically and resolved later.
pub struct ElfSymbols {
    /// Symbol name -> address.
    symbols: HashMap<String, u32>,
}
/// A JTAG session paired with an ELF symbol table.
///
/// The symbol table lets callers name addresses by symbol instead of by raw value.
pub struct ElfDebugger<'a> {
    /// Symbol table borrowed from the [`ElfSymbols`] this debugger was attached to.
    symbols: &'a ElfSymbols,
    /// The owned JTAG connection all target operations go through.
    jtag: Box<dyn Jtag + 'a>,
}
/// Gives shared access to the wrapped JTAG interface.
///
/// Any `&self` method of [`Jtag`] can therefore be called directly on the debugger.
impl<'a> Deref for ElfDebugger<'a> {
    type Target = dyn Jtag + 'a;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.jtag.as_ref()
    }
}
/// Gives mutable access to the wrapped JTAG interface.
///
/// Any `&mut self` method of [`Jtag`] can therefore be called directly on the debugger.
impl<'a> DerefMut for ElfDebugger<'a> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.jtag.as_mut()
    }
}
/// An address given either as a raw number or as an offset from a named symbol.
#[derive(Clone)]
pub enum SymbolicAddress {
    /// A fixed numeric address.
    Absolute(u32),
    /// `symbol + offset`, where the first field is the symbol name.
    SymbolRelative(String, u32),
}

impl From<String> for SymbolicAddress {
    /// Interprets the string as a symbol name with no offset.
    #[inline]
    fn from(name: String) -> Self {
        Self::SymbolRelative(name, 0)
    }
}

impl From<&str> for SymbolicAddress {
    /// Interprets the string as a symbol name with no offset.
    #[inline]
    fn from(name: &str) -> Self {
        Self::from(String::from(name))
    }
}

impl From<u32> for SymbolicAddress {
    /// Wraps a raw numeric address.
    #[inline]
    fn from(raw: u32) -> Self {
        Self::Absolute(raw)
    }
}

impl std::ops::Add<u32> for SymbolicAddress {
    type Output = Self;

    /// Moves the address forward by `rhs`.
    ///
    /// A symbolic address stays symbolic; only its offset grows.
    #[inline]
    fn add(self, rhs: u32) -> Self::Output {
        match self {
            Self::Absolute(base) => Self::Absolute(base + rhs),
            Self::SymbolRelative(name, off) => Self::SymbolRelative(name, off + rhs),
        }
    }
}

impl Debug for SymbolicAddress {
    /// Same output as the `Display` implementation.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Display>::fmt(self, f)
    }
}

impl Display for SymbolicAddress {
    /// Formats the address in one of these forms:
    ///
    /// - absolute: hexadecimal, e.g. `0x80000000`;
    /// - symbolic with zero offset: the bare symbol name;
    /// - symbolic otherwise: `symbol + 0x<offset>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Absolute(addr) => write!(f, "{addr:#x}"),
            Self::SymbolRelative(symbol, 0) => write!(f, "{symbol}"),
            Self::SymbolRelative(symbol, offset) => write!(f, "{symbol} + {offset:#x}"),
        }
    }
}
/// A [`SymbolicAddress`] together with the numeric address it resolved to.
#[derive(Clone)]
pub struct ResolvedAddress {
    /// The address as it was requested (absolute or symbol-relative).
    pub address: SymbolicAddress,
    /// The numeric value `address` resolved to.
    pub resolution: u32,
}

impl Debug for ResolvedAddress {
    /// Same output as the `Display` implementation.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Display>::fmt(self, f)
    }
}

impl Display for ResolvedAddress {
    /// Prints the requested address.
    ///
    /// For symbol-relative addresses, the numeric resolution is appended in
    /// parentheses, e.g. `main + 0x4 (0x80000104)`. Absolute addresses are
    /// printed once, because the resolution would just repeat them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.address, f)?;
        match self.address {
            SymbolicAddress::Absolute(_) => Ok(()),
            SymbolicAddress::SymbolRelative(..) => write!(f, " ({:#x})", self.resolution),
        }
    }
}
impl ElfSymbols {
pub fn load_elf(path: impl AsRef<Path>) -> Result<Self> {
let elf_binary = fs::read(path)?;
let elf_file = object::File::parse(&*elf_binary)?;
let mut symbols = HashMap::new();
for sym in elf_file.symbols() {
symbols.insert(sym.name()?.to_owned(), sym.address() as u32);
}
Ok(Self { symbols })
}
pub fn resolve(&self, address: impl Into<SymbolicAddress>) -> Result<ResolvedAddress> {
let address = address.into();
let resolution = match address {
SymbolicAddress::Absolute(addr) => addr,
SymbolicAddress::SymbolRelative(ref symbol, offset) => {
let Some(symbol_addr) = self.symbols.get(symbol).copied() else {
anyhow::bail!("Cannot resolve symbol {symbol}");
};
symbol_addr + offset
}
};
Ok(ResolvedAddress {
address,
resolution,
})
}
pub fn attach<'a>(&'a self, jtag: Box<dyn Jtag + 'a>) -> ElfDebugger<'a> {
ElfDebugger {
symbols: self,
jtag,
}
}
}
impl<'a> ElfDebugger<'a> {
    /// Consumes the debugger and disconnects the underlying JTAG link.
    pub fn disconnect(self) -> Result<()> {
        self.jtag.disconnect()
    }

    /// Resolves `address` against the attached symbol table.
    pub fn resolve(&self, address: impl Into<SymbolicAddress>) -> Result<ResolvedAddress> {
        self.symbols.resolve(address)
    }

    /// Reads a RISC-V register (anything convertible to [`RiscvReg`], e.g. a GPR or CSR).
    pub fn read_reg(&mut self, reg: impl Into<RiscvReg>) -> Result<u32> {
        self.read_riscv_reg(&reg.into())
    }

    /// Writes `value` to a RISC-V register (anything convertible to [`RiscvReg`]).
    pub fn write_reg(&mut self, reg: impl Into<RiscvReg>, value: u32) -> Result<()> {
        self.write_riscv_reg(&reg.into(), value)
    }

    /// Reads one 32-bit word of target memory at `addr`.
    pub fn read_u32(&mut self, addr: u32) -> Result<u32> {
        let mut ret = [0];
        self.read_memory32(addr, &mut ret)?;
        Ok(ret[0])
    }

    /// Writes one 32-bit word of target memory at `addr`.
    pub fn write_u32(&mut self, addr: u32, value: u32) -> Result<()> {
        self.write_memory32(addr, &[value])
    }

    /// Returns the PC of the halted hart, read from the `dpc` CSR.
    pub fn get_pc(&mut self) -> Result<u32> {
        self.read_riscv_reg(&RiscvReg::Csr(RiscvCsr::DPC))
    }

    /// Sets the PC the hart will resume at by writing the `dpc` CSR.
    pub fn set_pc(&mut self, address: impl Into<SymbolicAddress>) -> Result<()> {
        let resolved = self.resolve(address)?;
        log::info!("Set PC to {}", resolved);
        self.write_reg(RiscvCsr::DPC, resolved.resolution)?;
        Ok(())
    }

    /// Installs a breakpoint at `address` through the JTAG layer.
    pub fn set_breakpoint(&mut self, address: impl Into<SymbolicAddress>) -> Result<()> {
        let resolved = self.resolve(address)?;
        log::info!("Set breakpoint at {}", resolved);
        self.jtag.set_breakpoint(resolved.resolution, true)?;
        Ok(())
    }

    /// Fails unless the current PC equals `address`.
    pub fn expect_pc(&mut self, address: impl Into<SymbolicAddress>) -> Result<()> {
        let resolved = self.resolve(address)?;
        let pc = self.get_pc()?;
        log::info!("PC = {:#x}, expected PC = {}", pc, resolved);
        if pc != resolved.resolution {
            anyhow::bail!("unexpected PC");
        }
        Ok(())
    }

    /// Fails unless the current PC lies in the half-open range `start..end`.
    pub fn expect_pc_range(&mut self, range: Range<impl Into<SymbolicAddress>>) -> Result<()> {
        let start = self.resolve(range.start)?;
        let end = self.resolve(range.end)?;
        let pc = self.get_pc()?;
        log::info!("PC = {:#x}, expected PC = {}..{}", pc, start, end);
        if !(start.resolution..end.resolution).contains(&pc) {
            anyhow::bail!("unexpected PC");
        }
        Ok(())
    }

    /// Resumes the hart and runs until it halts at `address`.
    ///
    /// Uses a temporary breakpoint at `address`. After halting, the breakpoint
    /// is removed and the PC is checked against `address`.
    pub fn run_until(
        &mut self,
        address: impl Into<SymbolicAddress>,
        timeout: Duration,
    ) -> Result<()> {
        let resolved = self.resolve(address)?;
        log::info!("Run until {}", resolved);
        self.jtag.set_breakpoint(resolved.resolution, true)?;
        self.jtag.resume()?;
        // NOTE(review): if this times out, the temporary breakpoint stays installed
        // (the hart is presumably still running).
        self.jtag.wait_halt(timeout)?;
        // Remove the temporary breakpoint before checking where we stopped, so a
        // halt at an unexpected location does not leave it behind.
        self.jtag.remove_breakpoint(resolved.resolution)?;
        self.expect_pc(resolved.address)?;
        Ok(())
    }

    /// Runs until the address currently held in `ra` is reached, i.e. until the
    /// current function returns.
    pub fn finish(&mut self, timeout: Duration) -> Result<()> {
        let ra = self.read_reg(RiscvGpr::RA)?;
        self.run_until(ra, timeout)
    }

    /// Calls the function at `address` on the target and returns `(a0, a1)`.
    ///
    /// The steps are:
    ///
    /// 1. Save `sp` and the caller-saved registers (`ra`, `t0`-`t6`, `a0`-`a7`).
    /// 2. Align `sp` down to 16 bytes.
    /// 3. Load `args` into `a0`, `a1`, ...
    /// 4. Point `ra` at the current PC and jump to `address`.
    /// 5. Run until execution returns to the original PC.
    /// 6. Read `a0`/`a1`, then restore the saved registers.
    ///
    /// # Errors
    ///
    /// - Fails before touching the target if more than 8 arguments are given, or
    ///   if `address` cannot be resolved.
    /// - Fails if any JTAG operation fails or the hart does not return in time.
    ///   In that case the saved registers are not restored.
    pub fn call(
        &mut self,
        address: impl Into<SymbolicAddress>,
        args: &[u32],
        timeout: Duration,
    ) -> Result<(u32, u32)> {
        // Registers the callee may clobber, plus `sp`, which may be realigned below.
        const REGS_TO_SAVE: &[RiscvGpr] = &[
            RiscvGpr::RA,
            RiscvGpr::SP,
            RiscvGpr::T0,
            RiscvGpr::T1,
            RiscvGpr::T2,
            RiscvGpr::A0,
            RiscvGpr::A1,
            RiscvGpr::A2,
            RiscvGpr::A3,
            RiscvGpr::A4,
            RiscvGpr::A5,
            RiscvGpr::A6,
            RiscvGpr::A7,
            RiscvGpr::T3,
            RiscvGpr::T4,
            RiscvGpr::T5,
            RiscvGpr::T6,
        ];
        // Position of `sp` within `REGS_TO_SAVE`; keep in sync with the list above.
        const SP_INDEX: usize = 1;
        // Integer argument registers: the RISC-V calling convention passes up to
        // eight arguments in a0-a7.
        const ARG_GPRS: &[RiscvGpr] = &[
            RiscvGpr::A0,
            RiscvGpr::A1,
            RiscvGpr::A2,
            RiscvGpr::A3,
            RiscvGpr::A4,
            RiscvGpr::A5,
            RiscvGpr::A6,
            RiscvGpr::A7,
        ];

        // Validate inputs and resolve the target before modifying any target
        // state, so a rejected call leaves the hart untouched.
        if args.len() > ARG_GPRS.len() {
            anyhow::bail!(
                "too many arguments: {} given, at most {} supported",
                args.len(),
                ARG_GPRS.len()
            );
        }
        let target = self.resolve(address)?;

        let mut saved = [0; REGS_TO_SAVE.len()];
        for (slot, gpr) in saved.iter_mut().zip(REGS_TO_SAVE.iter().copied()) {
            *slot = self.read_reg(gpr)?;
        }

        // The standard RISC-V ABI expects a 16-byte aligned stack on entry.
        let sp = saved[SP_INDEX];
        if sp % 16 != 0 {
            self.write_reg(RiscvGpr::SP, sp & !15)?;
        }
        for (gpr, arg) in ARG_GPRS.iter().copied().zip(args.iter().copied()) {
            self.write_reg(gpr, arg)?;
        }

        // Return to the current PC; run_until places its temporary breakpoint there.
        let pc = self.get_pc()?;
        self.write_reg(RiscvGpr::RA, pc)?;
        self.set_pc(target.address)?;
        self.run_until(pc, timeout)?;

        let a0 = self.read_reg(RiscvGpr::A0)?;
        let a1 = self.read_reg(RiscvGpr::A1)?;
        for (gpr, value) in REGS_TO_SAVE.iter().copied().zip(saved) {
            self.write_reg(gpr, value)?;
        }
        Ok((a0, a1))
    }
}