openpine_vm/snapshot/
restore.rs

1//! `Instance::restore_state` and [`RestoreBuilder`] implementation.
2
3use std::collections::HashMap;
4
5use gc_arena::Arena;
6
7use crate::{
8    Instance, OutputMode,
9    currency::{CurrencyConverter, default_currency_converter},
10    data_provider::{DataProvider, InternalProvider},
11    gc_serde::{SerializedRawValue, build_gc_objects, fill_gc_objects},
12    raw_value::{RawValue, ToRawValue},
13    rollback::RollbackAction,
14    series::Series,
15    snapshot::types::{
16        InstanceSnapshot, SNAPSHOT_VERSION, SerializedRollbackAction, SerializedVariableValue,
17        SnapshotError,
18    },
19    state::{InputValue, SeriesVariable, State, VariableValue},
20};
21
22/// Builder returned by [`Instance::restore_state`] for configuring
23/// optional parameters before completing the restore.
24pub struct RestoreBuilder {
25    snapshot: InstanceSnapshot,
26    currency_converter: Box<dyn CurrencyConverter>,
27    data_provider: Option<Box<dyn InternalProvider>>,
28}
29
30impl std::fmt::Debug for RestoreBuilder {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("RestoreBuilder")
33            .field("version", &self.snapshot.version)
34            .finish_non_exhaustive()
35    }
36}
37
38impl RestoreBuilder {
39    /// Sets a custom currency converter for cross-currency backtesting.
40    ///
41    /// When not called, the default identity converter is used.
42    #[must_use]
43    pub fn currency_converter(mut self, converter: Box<dyn CurrencyConverter>) -> Self {
44        self.currency_converter = converter;
45        self
46    }
47
48    /// Injects a [`DataProvider`] to enable `request.security()` and to
49    /// supply symbol metadata.
50    #[must_use]
51    pub fn with_data_provider<P>(mut self, provider: P) -> Self
52    where
53        P: DataProvider + 'static,
54    {
55        self.data_provider = Some(Box::new(provider) as Box<dyn InternalProvider>);
56        self
57    }
58
59    /// Finishes restoring the instance from the snapshot.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`SnapshotError`] if the object table contains invalid
64    /// references or objects cannot be reconstructed.
65    pub fn build(self) -> Result<Instance, SnapshotError> {
66        let snap = self.snapshot;
67
68        // Extract Instance-level fields before moving GC-related data
69        // into the Arena closure.
70        let program = snap.program;
71        let timeframe = snap.timeframe;
72        let symbol_info = snap.symbol_info;
73        let candlesticks = snap.candlesticks;
74        let candlesticks_len = candlesticks.len();
75        let last_info = snap.last_info;
76        let bar_index = snap.bar_index;
77        let input_index = snap.input_index;
78        let script_info = snap.script_info;
79        let chart = snap.chart;
80        let events = snap.events;
81        let input_sessions = snap.input_sessions;
82        let strategy_state = snap.strategy_state;
83        let execution_limits = snap.execution_limits;
84
85        // These are consumed inside the arena closure.
86        let object_table = snap.object_table;
87        let ser_variables = snap.variables;
88        let ser_inputs = snap.inputs;
89        let ser_rollback_actions = snap.rollback_actions;
90        let var_initialized = snap.var_initialized;
91
92        // Clone program for State (the arena closure takes ownership of
93        // the clone; the original stays for Instance).
94        let state_program = program.clone();
95
96        // Create a new arena and reconstruct all GC state.
97        let arena = Arena::new(move |mc| {
98            let gc_objects = build_gc_objects(mc, &object_table, &state_program);
99
100            fill_gc_objects(&object_table, &gc_objects, |srv| match srv {
101                SerializedRawValue::Scalar(f) => f.to_raw_value(),
102                SerializedRawValue::Reference(id) => gc_objects[*id as usize],
103            });
104
105            let resolve = |srv: &SerializedRawValue| -> RawValue {
106                match srv {
107                    SerializedRawValue::Scalar(f) => f.to_raw_value(),
108                    SerializedRawValue::Reference(id) => gc_objects[*id as usize],
109                }
110            };
111
112            let variables: Vec<VariableValue> = ser_variables
113                .iter()
114                .map(|sv| match sv {
115                    SerializedVariableValue::Series {
116                        values,
117                        offset,
118                        max_length,
119                    } => {
120                        let queue = values.iter().map(&resolve).collect();
121                        VariableValue::Series(SeriesVariable {
122                            values: Series::from_queue(queue),
123                            offset: *offset,
124                            max_length: *max_length,
125                        })
126                    }
127                    SerializedVariableValue::Simple(srv) => VariableValue::Simple(resolve(srv)),
128                })
129                .collect();
130
131            let inputs: Vec<InputValue> = ser_inputs
132                .iter()
133                .map(|si| InputValue {
134                    value: resolve(&si.value),
135                    is_reference_type: si.is_reference_type,
136                    value_type: si.value_type.clone(),
137                })
138                .collect();
139
140            let rollback_actions: Vec<RollbackAction> = ser_rollback_actions
141                .iter()
142                .map(|sra| match sra {
143                    SerializedRollbackAction::Var {
144                        var_id,
145                        value,
146                        is_reference_type,
147                    } => RollbackAction::Var {
148                        var_id: *var_id,
149                        value: resolve(value),
150                        is_reference_type: *is_reference_type,
151                    },
152                    SerializedRollbackAction::UdtField {
153                        udt,
154                        field_id,
155                        value,
156                        is_reference_type,
157                    } => RollbackAction::UdtField {
158                        udt: resolve(udt),
159                        field_id: *field_id,
160                        value: resolve(value),
161                        is_reference_type: *is_reference_type,
162                    },
163                    SerializedRollbackAction::RemoveGraph { id } => {
164                        RollbackAction::RemoveGraph { id: *id }
165                    }
166                    SerializedRollbackAction::RestoreGraph { id, graph } => {
167                        RollbackAction::RestoreGraph {
168                            id: *id,
169                            graph: graph.clone(),
170                        }
171                    }
172                })
173                .collect();
174
175            State {
176                program: state_program,
177                inputs,
178                variables,
179                rollback_actions,
180                var_initialized,
181            }
182        });
183
184        Ok(Instance {
185            program,
186            arena,
187            timeframe,
188            symbol_info,
189            candlesticks,
190            last_info,
191            bar_index,
192            input_index,
193            script_info,
194            chart,
195            events,
196            input_sessions,
197            strategy_state,
198            currency_converter: self.currency_converter,
199            execution_limits,
200            data_provider: self.data_provider,
201            candlestick_buffers: HashMap::new(),
202            security_sub_states: HashMap::new(),
203            security_lower_tf_sub_states: HashMap::new(),
204            // Derive from existing state: after History or RealtimeConfirmed,
205            // bar_index == candlesticks.len() (bar_index has been advanced past
206            // the last element). After RealtimeNew/Update, bar_index points
207            // into the array, so bar_index < candlesticks.len().
208            last_bar_confirmed: bar_index == candlesticks_len,
209            pending_security_capture: None,
210
211            output_mode: OutputMode::default(),
212        })
213    }
214}
215
216impl Instance {
217    /// Deserializes a snapshot and returns a [`RestoreBuilder`] for
218    /// configuring optional parameters before completing the restore.
219    ///
220    /// # Errors
221    ///
222    /// Returns [`SnapshotError::Decode`] if decoding fails, or
223    /// [`SnapshotError::VersionMismatch`] if the snapshot version does not
224    /// match the current build.
225    pub fn restore_state(data: &[u8]) -> Result<RestoreBuilder, SnapshotError> {
226        let config = bincode::config::standard();
227        let (snapshot, _): (InstanceSnapshot, _) = bincode::serde::decode_from_slice(data, config)
228            .map_err(|e| SnapshotError::Decode(e.to_string()))?;
229
230        if snapshot.version != SNAPSHOT_VERSION {
231            return Err(SnapshotError::VersionMismatch {
232                expected: SNAPSHOT_VERSION,
233                got: snapshot.version,
234            });
235        }
236
237        Ok(RestoreBuilder {
238            snapshot,
239            currency_converter: default_currency_converter(),
240            data_provider: None,
241        })
242    }
243}