acid/firmware/acid-firmware/src/db/mod.rs

277 lines
7.6 KiB
Rust
Raw Normal View History

2026-01-24 21:12:25 +01:00
use core::{
iter::Chain,
ops::{Deref, DerefMut, Range},
};
use alloc::{borrow::Cow, boxed::Box, vec::Vec};
use ekv::{Database, flash::PageID};
use embassy_embedded_hal::{adapter::BlockingAsync, flash::partition::Partition};
use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
use embedded_storage_async::nor_flash::{NorFlash, ReadNorFlash};
use esp_hal::rng::Trng;
use esp_storage::FlashStorage;
use log::info;
/// Flash partition type consumed by `AcidDatabase::mount`: a `'static` `FlashStorage`
/// driver adapted to the async flash traits via `BlockingAsync`, partitioned with a
/// `CriticalSectionRawMutex`.
pub type PartitionAcid =
    Partition<'static, CriticalSectionRawMutex, BlockingAsync<FlashStorage<'static>>>;
// Workaround for alignment requirements.
#[repr(C, align(4))]
struct AlignedBuf<const N: usize>(pub [u8; N]);
pub struct EkvFlash<T> {
flash: T,
buffer: Box<AlignedBuf<{ ekv::config::PAGE_SIZE }>>,
}
impl<T> EkvFlash<T> {
fn new(flash: T) -> Self {
Self {
flash,
buffer: {
// Allocate the buffer directly on the heap.
let buffer = Box::new_zeroed();
unsafe { buffer.assume_init() }
},
}
}
}
impl<T> Deref for EkvFlash<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.flash
}
}
impl<T> DerefMut for EkvFlash<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.flash
}
}
impl<T: NorFlash + ReadNorFlash> ekv::flash::Flash for EkvFlash<T> {
type Error = T::Error;
fn page_count(&self) -> usize {
ekv::config::MAX_PAGE_COUNT
}
async fn erase(
&mut self,
page_id: PageID,
) -> Result<(), <EkvFlash<T> as ekv::flash::Flash>::Error> {
self.flash
.erase(
(page_id.index() * ekv::config::PAGE_SIZE) as u32,
((page_id.index() + 1) * ekv::config::PAGE_SIZE) as u32,
)
.await
}
async fn read(
&mut self,
page_id: PageID,
offset: usize,
data: &mut [u8],
) -> Result<(), <EkvFlash<T> as ekv::flash::Flash>::Error> {
let address = page_id.index() * ekv::config::PAGE_SIZE + offset;
self.flash
.read(address as u32, &mut self.buffer.0[..data.len()])
.await?;
data.copy_from_slice(&self.buffer.0[..data.len()]);
Ok(())
}
async fn write(
&mut self,
page_id: PageID,
offset: usize,
data: &[u8],
) -> Result<(), <EkvFlash<T> as ekv::flash::Flash>::Error> {
let address = page_id.index() * ekv::config::PAGE_SIZE + offset;
self.buffer.0[..data.len()].copy_from_slice(data);
self.flash
.write(address as u32, &self.buffer.0[..data.len()])
.await
}
}
/// The firmware's EKV database, stored on a `PartitionAcid` flash partition.
///
/// Dereferences to the underlying `ekv::Database`, so its API is available directly.
pub struct AcidDatabase {
    db: Database<EkvFlash<PartitionAcid>, esp_sync::RawMutex>,
}

impl Deref for AcidDatabase {
    type Target = Database<EkvFlash<PartitionAcid>, esp_sync::RawMutex>;

    /// Borrows the underlying `ekv` database.
    fn deref(&self) -> &Self::Target {
        &self.db
    }
}

impl DerefMut for AcidDatabase {
    /// Mutably borrows the underlying `ekv` database.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.db
    }
}

impl AcidDatabase {
    /// Opens the database stored on `flash`.
    ///
    /// Seeds ekv's `random_seed` from the hardware TRNG and wraps `flash` in an `EkvFlash`.
    /// When the `format-db` feature is enabled, it formats the database first. It then mounts
    /// the database.
    ///
    /// # Panics
    ///
    /// Panics if no `TrngSource` has been initialized, if formatting fails (with
    /// `format-db`), or if mounting fails.
    pub async fn mount(flash: PartitionAcid) -> AcidDatabase {
        let mut db_config = ekv::Config::default();
        db_config.random_seed = Trng::try_new()
            .expect("A `TrngSource` was not initialized before constructing this `Trng`.")
            .random();
        let db = Database::<_, esp_sync::RawMutex>::new(EkvFlash::new(flash), db_config);

        // Only `info` is imported from `log` in this module, so the warning macro is
        // path-qualified. A bare `warn!` would not compile when `format-db` is enabled,
        // unless the crate root happens to `#[macro_use]` the `log` crate.
        #[cfg(feature = "format-db")]
        {
            log::warn!("Formatting EKV database...");
            db.format()
                .await
                .unwrap_or_else(|error| panic!("Failed to format the EKV database: {error:?}"));
            log::warn!("EKV database formatted successfully.");
        }

        if let Err(error) = db.mount().await {
            panic!("Failed to mount the EKV database: {error:?}");
        }
        info!("EKV database mounted.");

        Self { db }
    }
}
/// One component of a database path; borrowed where possible, owned otherwise.
type DbPathSegment<'a> = Cow<'a, str>;
/// An owned, growable sequence of path segments.
type DbPathBuf<'a> = Vec<DbPathSegment<'a>>;
/// A borrowed sequence of path segments.
type DbPath<'a> = [DbPathSegment<'a>];

/// A database key encoded from a path of segments.
///
/// Every segment is followed by the separator `[0, 0]`. The very last byte of the encoding
/// is then rewritten to `1`, so the buffer always ends in `[0, 1]`. For the path
/// `["a", "b"]` this yields `a\0\0b\0\x01`, from which three byte strings are derived:
///
/// - dropping the final two bytes gives the exact key, `a\0\0b` (the `Deref` target);
/// - dropping only the final byte gives `a\0\0b\0`, the inclusive start of the range that
///   holds every key nested below this path;
/// - the full buffer, `a\0\0b\0\x01`, is the exclusive end of that range.
///
/// Keeping the `[0, 1]` suffix in the buffer lets all three views be plain subslices, with
/// no reallocation.
struct DbKey(Vec<u8>);

impl Deref for DbKey {
    type Target = [u8];

    /// Returns the exact key: the encoding without its trailing `[0, 1]`.
    fn deref(&self) -> &Self::Target {
        let (key, _suffix) = self.0.split_at(self.0.len() - 2);
        key
    }
}

impl DbKey {
    /// Wraps an encoded key that lacks the `[0, 1]` suffix, appending the suffix.
    ///
    /// Presumably used for keys read back from the database (which are stored in their
    /// `Deref` form) — confirm against callers.
    fn from_raw(key: Vec<u8>) -> Self {
        let mut encoded = key;
        encoded.push(0);
        encoded.push(1);
        Self(encoded)
    }

    /// Encodes `path` into a key (see the type-level docs for the layout).
    ///
    /// # Panics
    ///
    /// Panics if any segment contains a null byte, or if `path` yields no segments.
    fn new<'a>(path: impl IntoIterator<Item = DbPathSegment<'a>>) -> Self {
        let mut encoded = Vec::new();
        path.into_iter().for_each(|segment| {
            let text = segment.as_bytes();
            // A null byte inside a segment would be indistinguishable from a separator.
            assert!(
                !text.contains(&0x00),
                "A path segment must not contain null bytes."
            );
            encoded.extend_from_slice(text);
            encoded.extend_from_slice(&[0, 0]);
        });
        // Turn the final separator `[0, 0]` into the `[0, 1]` suffix.
        let Some(last) = encoded.last_mut() else {
            panic!("An empty path is not a valid path.");
        };
        *last = 1;
        Self(encoded)
    }

    /// Byte range covering every key nested below this path (but not the key itself).
    fn range_of_children(&self) -> Range<&[u8]> {
        let full = self.0.as_slice();
        let start = &full[..full.len() - 1];
        start..full
    }

    /// Decodes the key back into its path segments, borrowing from the key's buffer.
    ///
    /// # Panics
    ///
    /// Panics (while iterating) if a segment is not valid UTF-8.
    fn segments(&self) -> impl Iterator<Item = DbPathSegment<'_>> {
        /// Cursor over the not-yet-decoded tail of an encoded key.
        struct Segments<'k> {
            remaining: &'k [u8],
        }

        impl<'k> Iterator for Segments<'k> {
            type Item = DbPathSegment<'k>;

            fn next(&mut self) -> Option<Self::Item> {
                // The first zero byte opens either a `[0, 0]` separator or the `[0, 1]` suffix.
                let boundary = self.remaining.iter().position(|&byte| byte == 0)?;
                let text = core::str::from_utf8(&self.remaining[..boundary]).unwrap();
                // Step over both bytes of the separator/suffix.
                self.remaining = &self.remaining[boundary + 2..];
                Some(DbPathSegment::Borrowed(text))
            }
        }

        Segments {
            remaining: self.0.as_slice(),
        }
    }
}
/// Database path `spectre / users`.
pub struct DbPathSpectreUsers;

// The impl previously declared an unused lifetime parameter (`impl<'a>`); every segment
// here is a `'static` literal, so no lifetime is needed.
impl IntoIterator for DbPathSpectreUsers {
    type Item = DbPathSegment<'static>;
    type IntoIter = core::array::IntoIter<DbPathSegment<'static>, 2>;

    /// Yields the path segments `"spectre"`, `"users"`.
    fn into_iter(self) -> Self::IntoIter {
        [
            DbPathSegment::Borrowed("spectre"),
            DbPathSegment::Borrowed("users"),
        ]
        .into_iter()
    }
}
pub struct DbPathSpectreUserSites<'a> {
username: DbPathSegment<'a>,
}
impl<'a> IntoIterator for DbPathSpectreUserSites<'a> {
type Item = DbPathSegment<'a>;
type IntoIter = core::array::IntoIter<DbPathSegment<'a>, 4>;
fn into_iter(self) -> Self::IntoIter {
[
DbPathSegment::Borrowed("spectre"),
DbPathSegment::Borrowed("user"),
self.username,
DbPathSegment::Borrowed("site"),
]
.into_iter()
}
}
pub struct DbPathSpectreUserSite<'a> {
user_sites: DbPathSpectreUserSites<'a>,
site: DbPathSegment<'a>,
}
impl<'a> IntoIterator for DbPathSpectreUserSite<'a> {
type Item = DbPathSegment<'a>;
type IntoIter =
Chain<core::array::IntoIter<DbPathSegment<'a>, 4>, core::iter::Once<DbPathSegment<'a>>>;
fn into_iter(self) -> Self::IntoIter {
self.user_sites
.into_iter()
.chain(core::iter::once(self.site))
}
}