1use std::collections::HashMap;
4
5use gc_arena::Arena;
6
7use crate::{
8 Instance, OutputMode,
9 currency::{CurrencyConverter, default_currency_converter},
10 data_provider::{DataProvider, InternalProvider},
11 gc_serde::{SerializedRawValue, build_gc_objects, fill_gc_objects},
12 raw_value::{RawValue, ToRawValue},
13 rollback::RollbackAction,
14 series::Series,
15 snapshot::types::{
16 InstanceSnapshot, SNAPSHOT_VERSION, SerializedRollbackAction, SerializedVariableValue,
17 SnapshotError,
18 },
19 state::{InputValue, SeriesVariable, State, VariableValue},
20};
21
/// Builder that turns a decoded [`InstanceSnapshot`] back into an [`Instance`].
///
/// Values are produced by [`Instance::restore_state`]; the optional
/// collaborators can then be swapped in before calling
/// [`RestoreBuilder::build`].
pub struct RestoreBuilder {
    /// Decoded snapshot that `build` consumes to reconstruct the instance.
    snapshot: InstanceSnapshot,
    /// Converter handed to the restored instance unchanged. Initialised with
    /// `default_currency_converter()` by `Instance::restore_state`.
    currency_converter: Box<dyn CurrencyConverter>,
    /// Optional data provider handed to the restored instance unchanged;
    /// `None` unless `with_data_provider` was called.
    data_provider: Option<Box<dyn InternalProvider>>,
}
29
30impl std::fmt::Debug for RestoreBuilder {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("RestoreBuilder")
33 .field("version", &self.snapshot.version)
34 .finish_non_exhaustive()
35 }
36}
37
38impl RestoreBuilder {
39 #[must_use]
43 pub fn currency_converter(mut self, converter: Box<dyn CurrencyConverter>) -> Self {
44 self.currency_converter = converter;
45 self
46 }
47
48 #[must_use]
51 pub fn with_data_provider<P>(mut self, provider: P) -> Self
52 where
53 P: DataProvider + 'static,
54 {
55 self.data_provider = Some(Box::new(provider) as Box<dyn InternalProvider>);
56 self
57 }
58
59 pub fn build(self) -> Result<Instance, SnapshotError> {
66 let snap = self.snapshot;
67
68 let program = snap.program;
71 let timeframe = snap.timeframe;
72 let symbol_info = snap.symbol_info;
73 let candlesticks = snap.candlesticks;
74 let candlesticks_len = candlesticks.len();
75 let last_info = snap.last_info;
76 let bar_index = snap.bar_index;
77 let input_index = snap.input_index;
78 let script_info = snap.script_info;
79 let chart = snap.chart;
80 let events = snap.events;
81 let input_sessions = snap.input_sessions;
82 let strategy_state = snap.strategy_state;
83 let execution_limits = snap.execution_limits;
84
85 let object_table = snap.object_table;
87 let ser_variables = snap.variables;
88 let ser_inputs = snap.inputs;
89 let ser_rollback_actions = snap.rollback_actions;
90 let var_initialized = snap.var_initialized;
91
92 let state_program = program.clone();
95
96 let arena = Arena::new(move |mc| {
98 let gc_objects = build_gc_objects(mc, &object_table, &state_program);
99
100 fill_gc_objects(&object_table, &gc_objects, |srv| match srv {
101 SerializedRawValue::Scalar(f) => f.to_raw_value(),
102 SerializedRawValue::Reference(id) => gc_objects[*id as usize],
103 });
104
105 let resolve = |srv: &SerializedRawValue| -> RawValue {
106 match srv {
107 SerializedRawValue::Scalar(f) => f.to_raw_value(),
108 SerializedRawValue::Reference(id) => gc_objects[*id as usize],
109 }
110 };
111
112 let variables: Vec<VariableValue> = ser_variables
113 .iter()
114 .map(|sv| match sv {
115 SerializedVariableValue::Series {
116 values,
117 offset,
118 max_length,
119 } => {
120 let queue = values.iter().map(&resolve).collect();
121 VariableValue::Series(SeriesVariable {
122 values: Series::from_queue(queue),
123 offset: *offset,
124 max_length: *max_length,
125 })
126 }
127 SerializedVariableValue::Simple(srv) => VariableValue::Simple(resolve(srv)),
128 })
129 .collect();
130
131 let inputs: Vec<InputValue> = ser_inputs
132 .iter()
133 .map(|si| InputValue {
134 value: resolve(&si.value),
135 is_reference_type: si.is_reference_type,
136 value_type: si.value_type.clone(),
137 })
138 .collect();
139
140 let rollback_actions: Vec<RollbackAction> = ser_rollback_actions
141 .iter()
142 .map(|sra| match sra {
143 SerializedRollbackAction::Var {
144 var_id,
145 value,
146 is_reference_type,
147 } => RollbackAction::Var {
148 var_id: *var_id,
149 value: resolve(value),
150 is_reference_type: *is_reference_type,
151 },
152 SerializedRollbackAction::UdtField {
153 udt,
154 field_id,
155 value,
156 is_reference_type,
157 } => RollbackAction::UdtField {
158 udt: resolve(udt),
159 field_id: *field_id,
160 value: resolve(value),
161 is_reference_type: *is_reference_type,
162 },
163 SerializedRollbackAction::RemoveGraph { id } => {
164 RollbackAction::RemoveGraph { id: *id }
165 }
166 SerializedRollbackAction::RestoreGraph { id, graph } => {
167 RollbackAction::RestoreGraph {
168 id: *id,
169 graph: graph.clone(),
170 }
171 }
172 })
173 .collect();
174
175 State {
176 program: state_program,
177 inputs,
178 variables,
179 rollback_actions,
180 var_initialized,
181 }
182 });
183
184 Ok(Instance {
185 program,
186 arena,
187 timeframe,
188 symbol_info,
189 candlesticks,
190 last_info,
191 bar_index,
192 input_index,
193 script_info,
194 chart,
195 events,
196 input_sessions,
197 strategy_state,
198 currency_converter: self.currency_converter,
199 execution_limits,
200 data_provider: self.data_provider,
201 candlestick_buffers: HashMap::new(),
202 security_sub_states: HashMap::new(),
203 security_lower_tf_sub_states: HashMap::new(),
204 last_bar_confirmed: bar_index == candlesticks_len,
209 pending_security_capture: None,
210
211 output_mode: OutputMode::default(),
212 })
213 }
214}
215
216impl Instance {
217 pub fn restore_state(data: &[u8]) -> Result<RestoreBuilder, SnapshotError> {
226 let config = bincode::config::standard();
227 let (snapshot, _): (InstanceSnapshot, _) = bincode::serde::decode_from_slice(data, config)
228 .map_err(|e| SnapshotError::Decode(e.to_string()))?;
229
230 if snapshot.version != SNAPSHOT_VERSION {
231 return Err(SnapshotError::VersionMismatch {
232 expected: SNAPSHOT_VERSION,
233 got: snapshot.version,
234 });
235 }
236
237 Ok(RestoreBuilder {
238 snapshot,
239 currency_converter: default_currency_converter(),
240 data_provider: None,
241 })
242 }
243}