1
use crate::entity::{EntityRef, Split, Tag, Transaction};
2
use crate::error::{Error, Result};
3
use crate::host;
4
use scripting_format::{
5
    ContextType, ENTITY_HEADER_SIZE, EntityHeader, EntityType, GLOBAL_HEADER_SIZE, GlobalHeader,
6
};
7

            
8
pub struct InputReader {
9
    buffer: &'static [u8],
10
    header: GlobalHeader,
11
    base_offset: u32,
12
}
13

            
14
impl InputReader {
15
    pub fn new() -> Result<Self> {
16
        let base_offset = host::input_offset();
17
        let base = base_offset as *const u8;
18

            
19
        // SAFETY: WASM linear memory is valid for the lifetime of the module.
20
        // The host guarantees the input buffer is properly initialized.
21
        let (buffer, header) = unsafe {
22
            let header_slice = core::slice::from_raw_parts(base, GLOBAL_HEADER_SIZE);
23
            let header = GlobalHeader::from_bytes(header_slice).ok_or(Error::InvalidMagic)?;
24
            let total_size =
25
                (header.strings_pool_offset - base_offset + header.strings_pool_size) as usize;
26
            let buffer = core::slice::from_raw_parts(base, total_size);
27
            (buffer, header)
28
        };
29

            
30
        Ok(Self {
31
            buffer,
32
            header,
33
            base_offset,
34
        })
35
    }
36

            
37
    pub fn context_type(&self) -> Result<ContextType> {
38
        ContextType::try_from(self.header.context_type).map_err(|()| Error::InvalidEntityType)
39
    }
40

            
41
    pub fn primary_entity_type(&self) -> Result<EntityType> {
42
        EntityType::try_from(self.header.primary_entity_type).map_err(|()| Error::InvalidEntityType)
43
    }
44

            
45
    #[must_use]
46
    pub fn entity_count(&self) -> u32 {
47
        self.header.input_entity_count
48
    }
49

            
50
    #[must_use]
51
    pub fn primary_entity_idx(&self) -> u32 {
52
        self.header.primary_entity_idx
53
    }
54

            
55
    fn strings_pool(&self) -> &[u8] {
56
        let offset = (self.header.strings_pool_offset - self.base_offset) as usize;
57
        let size = self.header.strings_pool_size as usize;
58
        self.buffer.get(offset..offset + size).unwrap_or_default()
59
    }
60

            
61
    fn entity_header_offset(&self, idx: u32) -> Result<usize> {
62
        if idx >= self.header.input_entity_count {
63
            return Err(Error::EntityNotFound);
64
        }
65
        let offset = (self.header.entities_offset - self.base_offset) as usize
66
            + idx as usize * ENTITY_HEADER_SIZE;
67
        if offset + ENTITY_HEADER_SIZE > self.buffer.len() {
68
            return Err(Error::InvalidHeader);
69
        }
70
        Ok(offset)
71
    }
72

            
73
    pub fn entity(&self, idx: u32) -> Result<EntityRef<'_>> {
74
        let offset = self.entity_header_offset(idx)?;
75
        let header_bytes = &self.buffer[offset..offset + ENTITY_HEADER_SIZE];
76
        let header = EntityHeader::from_bytes(header_bytes).ok_or(Error::InvalidHeader)?;
77
        let data_offset = (header.data_offset - self.base_offset) as usize;
78
        let data_size = header.data_size as usize;
79
        let data = self
80
            .buffer
81
            .get(data_offset..data_offset + data_size)
82
            .ok_or(Error::InvalidHeader)?;
83

            
84
        Ok(EntityRef {
85
            header,
86
            data,
87
            strings_pool: self.strings_pool(),
88
        })
89
    }
90

            
91
    pub fn primary(&self) -> Result<EntityRef<'_>> {
92
        self.entity(self.header.primary_entity_idx)
93
    }
94

            
95
    pub fn primary_transaction(&self) -> Result<Transaction> {
96
        self.primary()?.as_transaction()
97
    }
98

            
99
    pub fn transaction(&self, idx: u32) -> Result<Transaction> {
100
        self.entity(idx)?.as_transaction()
101
    }
102

            
103
    pub fn split(&self, idx: u32) -> Result<Split> {
104
        self.entity(idx)?.as_split()
105
    }
106

            
107
    pub fn tag(&self, idx: u32) -> Result<Tag<'_>> {
108
        self.entity(idx)?.as_tag()
109
    }
110

            
111
    pub fn splits_for(&self, parent_idx: u32) -> impl Iterator<Item = Result<Split>> + '_ {
112
        let count = self.header.input_entity_count;
113
        (0..count).filter_map(move |idx| {
114
            let entity = match self.entity(idx) {
115
                Ok(e) => e,
116
                Err(e) => return Some(Err(e)),
117
            };
118
            if entity.header.parent_idx == parent_idx as i32 {
119
                match entity.entity_type() {
120
                    Ok(EntityType::Split) => Some(entity.as_split()),
121
                    Ok(_) => None,
122
                    Err(e) => Some(Err(e)),
123
                }
124
            } else {
125
                None
126
            }
127
        })
128
    }
129

            
130
    pub fn tags_for(&self, parent_idx: u32) -> impl Iterator<Item = Result<Tag<'_>>> + '_ {
131
        let count = self.header.input_entity_count;
132
        (0..count).filter_map(move |idx| {
133
            let entity = match self.entity(idx) {
134
                Ok(e) => e,
135
                Err(e) => return Some(Err(e)),
136
            };
137
            if entity.header.parent_idx == parent_idx as i32 {
138
                match entity.entity_type() {
139
                    Ok(EntityType::Tag) => Some(entity.as_tag()),
140
                    Ok(_) => None,
141
                    Err(e) => Some(Err(e)),
142
                }
143
            } else {
144
                None
145
            }
146
        })
147
    }
148
}