1use std::collections::HashMap;
2
3use gc_arena::Arena;
4use openpine_compiler::{
5 CompileOptions,
6 loader::{CombinedLoader, LibraryLoader, Project},
7 program::Program,
8};
9use openpine_error::ErrorWithSourceFile;
10use serde::{Deserialize, Serialize};
11
12use crate::{
13 Error, Exception, ExecutionLimits, Instance, LastInfo, OutputMode, Series, StrategyConfig,
14 StrategyState, SymbolInfo, TimeFrame, TimeUnit, TradeSession,
15 bar_state::BarState,
16 context::ExecuteContext,
17 currency::{CurrencyConverter, default_currency_converter},
18 data_provider::{DataProvider, InternalProvider, PartialSymbolInfo},
19 inst_executor::Interrupt,
20 native_funcs::NativeFuncs,
21 script_info::{PartialScriptInfo, ScriptInfo, ScriptType},
22 state::{State, VariableValue},
23 visuals::{Chart, Color},
24};
25
26bitflags::bitflags! {
27 #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize)]
29 pub struct InputSessions: u8 {
30 const REGULAR = 0b0001;
32 const EXTENDED = 0b0010;
34 const OVERNIGHT = 0b0100;
36
37 const ALL = Self::REGULAR.bits() | Self::EXTENDED.bits() | Self::OVERNIGHT.bits();
39 }
40}
41
42impl InputSessions {
43 #[inline]
44 pub(crate) fn allow(&self, trade_session: TradeSession) -> bool {
45 match trade_session {
46 TradeSession::Regular => self.contains(InputSessions::REGULAR),
47 TradeSession::PreMarket | TradeSession::AfterHours => {
48 self.contains(InputSessions::EXTENDED)
49 }
50 TradeSession::Overnight => self.contains(InputSessions::OVERNIGHT),
51 }
52 }
53}
54
55pub struct InstanceBuilder<'a, L> {
57 compile_opts: CompileOptions<'a, L>,
58 input_values: HashMap<usize, serde_value::Value>,
59 timeframe: TimeFrame,
60 input_sessions: InputSessions,
61 symbol: String,
62 background_color: Option<Color>,
63 last_info: Option<LastInfo>,
64 currency_converter: Box<dyn CurrencyConverter>,
65 execution_limits: ExecutionLimits,
66 data_provider: Option<Box<dyn InternalProvider>>,
67 output_mode: OutputMode,
68}
69
70impl<'a, L> InstanceBuilder<'a, L> {
71 pub fn with_path(mut self, path: &'a str) -> Self {
73 self.compile_opts = self.compile_opts.with_path(path);
74 self
75 }
76
77 pub fn with_input_value(mut self, id: usize, value: impl Serialize) -> Self {
104 let value = serde_value::to_value(value).unwrap_or(serde_value::Value::Unit);
105 self.input_values.insert(id, value);
106 self
107 }
108
109 pub fn with_locale(mut self, locale: &'a str) -> Self {
111 self.compile_opts = self.compile_opts.with_locale(locale);
112 self
113 }
114
115 pub fn with_input_sessions(mut self, sessions: InputSessions) -> Self {
117 self.input_sessions = sessions;
118 self
119 }
120
121 pub fn with_background_color(mut self, color: impl Into<Option<Color>>) -> Self {
123 self.background_color = color.into();
124 self
125 }
126
127 pub fn with_last_info(mut self, last_info: impl Into<Option<LastInfo>>) -> Self {
129 self.last_info = last_info.into();
130 self
131 }
132
133 pub fn with_currency_converter(mut self, converter: Box<dyn CurrencyConverter>) -> Self {
139 self.currency_converter = converter;
140 self
141 }
142
143 pub fn with_execution_limits(mut self, limits: ExecutionLimits) -> Self {
148 self.execution_limits = limits;
149 self
150 }
151
152 pub fn with_output_mode(mut self, mode: OutputMode) -> Self {
159 self.output_mode = mode;
160 self
161 }
162}
163
164impl<'a, L> InstanceBuilder<'a, L>
165where
166 L: LibraryLoader,
167{
168 pub fn with_library_loader<Q>(
170 self,
171 library_loader: Q,
172 ) -> InstanceBuilder<'a, CombinedLoader<Q, L>>
173 where
174 Q: LibraryLoader,
175 {
176 InstanceBuilder {
177 compile_opts: self.compile_opts.with_library_loader(library_loader),
178 input_values: self.input_values,
179 timeframe: self.timeframe,
180 input_sessions: self.input_sessions,
181 symbol: self.symbol,
182 background_color: self.background_color,
183 last_info: self.last_info,
184 currency_converter: self.currency_converter,
185 execution_limits: self.execution_limits,
186 data_provider: self.data_provider,
187
188 output_mode: self.output_mode,
189 }
190 }
191
192 pub async fn build(self) -> Result<Instance, Error> {
220 let symbol_info = resolve_symbol_info(&self.symbol, self.data_provider.as_deref()).await?;
221 let program = Self::compile_program(self.compile_opts)?;
222 let script_info = get_script_info(&program, &symbol_info).await?;
223
224 if matches!(script_info.script_type, ScriptType::Library(_)) {
225 return Err(Error::LibraryScriptNotExecutable);
226 }
227
228 let mut arena = Arena::new(|mc| State::new(mc, &program, Some(&script_info)));
229
230 for (id, input_value) in self.input_values {
231 let input = script_info
232 .inputs
233 .get(id)
234 .ok_or(Error::InputValueNotFound(id))?;
235 arena.mutate_root(|mc, state: &mut State| {
236 let value = input
237 .serialize_value(mc, &program, &input_value)
238 .map_err(|err| Error::SetInputValue(id, err.to_string()))?;
239 state.inputs[id].value = value;
240 Ok::<_, Error>(())
241 })?;
242 }
243
244 let strategy_state = if let ScriptType::Strategy(strategy) = &script_info.script_type {
245 let config = StrategyConfig::new(Some(&symbol_info), strategy);
246 Some(Box::new(StrategyState::new(config)))
247 } else {
248 None
249 };
250
251 let mut chart = Chart::default();
252 chart.set_background_color(self.background_color);
253
254 Ok(Instance {
255 program,
256 arena,
257 timeframe: self.timeframe,
258 symbol_info,
259 candlesticks: Series::new(),
260 last_info: self.last_info,
261 bar_index: 0,
262 input_index: 0,
263 script_info,
264 chart,
265 events: vec![],
266 input_sessions: self.input_sessions,
267 strategy_state,
268 currency_converter: self.currency_converter,
269 execution_limits: self.execution_limits,
270 data_provider: self.data_provider,
271 candlestick_buffers: HashMap::new(),
272 security_sub_states: HashMap::new(),
273 security_lower_tf_sub_states: HashMap::new(),
274 last_bar_confirmed: true,
275 pending_security_capture: None,
276
277 output_mode: self.output_mode,
278 })
279 }
280
281 fn compile_program(compile_opts: CompileOptions<'a, L>) -> Result<Program, Error> {
282 let compile_opts = compile_opts
283 .with_library_loader(crate::builtins_loader())
284 .with_load_options(crate::builtins_load_options())
285 .with_native_func_registry(&NativeFuncs);
286 Ok(compile_opts.compile()?)
287 }
288}
289
290async fn resolve_symbol_info(
291 symbol: &str,
292 provider: Option<&dyn InternalProvider>,
293) -> Result<SymbolInfo, Error> {
294 let partial = match provider {
295 Some(p) => p
296 .symbol_info(symbol.to_string())
297 .await
298 .map_err(|e| Error::DataProvider(e.to_string()))?,
299 None => PartialSymbolInfo::default(),
300 };
301 Ok(SymbolInfo::from_partial(symbol, partial)?)
302}
303
304async fn get_script_info(program: &Program, symbol_info: &SymbolInfo) -> Result<ScriptInfo, Error> {
305 let mut script_info = PartialScriptInfo::default();
306
307 let mut candlesticks = Series::new();
308 candlesticks.append_new();
309
310 let mut arena = Arena::new(move |mc| {
311 let mut state = State::new(mc, program, None);
312 for var_value in state.variables.iter_mut() {
313 if let VariableValue::Series(series) = var_value {
314 series.values.append_new();
315 }
316 }
317 state
318 });
319
320 let res = crate::inst_executor::execute(
321 program.instruction(),
322 &mut ExecuteContext {
323 program,
324 arena: &mut arena,
325 candlesticks: &candlesticks,
326 last_info: None,
327 bar_state: BarState::History,
328 bar_index: 0,
329 input_index: 0,
330 partial_script_info: Some(&mut script_info),
331 script_info: None,
332 chart: &mut Chart::default(),
333 events: &mut vec![],
334 current_span: None,
335 module_stack: vec![],
336 timeframe: &TimeFrame::new(1, TimeUnit::Day),
337 symbol_info,
338 input_sessions: InputSessions::ALL,
339 strategy_state: None,
340 currency_converter: None,
341 loop_iterations_remaining: ExecutionLimits::default().max_loop_iterations_per_bar,
342 security_provider: None,
343 candlestick_buffers: &mut HashMap::new(),
344 security_sub_states: &mut HashMap::new(),
345 security_lower_tf_sub_states: &mut HashMap::new(),
346 security_depth: 0,
347 execution_limits: ExecutionLimits::default(),
348 security_capture: None,
349 output_mode: OutputMode::default(),
350 },
351 )
352 .await;
353
354 match res {
355 Ok(_) => {}
356 Err(Interrupt::RuntimeError { error, backtrace }) => {
357 return Err(Error::Exception(Exception::new(
358 ErrorWithSourceFile::new(
359 openpine_error::Error::new(vec![error.span], error.value),
360 program.source_files().clone(),
361 ),
362 backtrace,
363 )));
364 }
365 Err(_) => unreachable!("Unhandled interrupt in top-level execution"),
366 }
367
368 script_info.try_into()
369}
370
371pub fn script_info(source: &str) -> Result<ScriptInfo, Error> {
383 let compile_opts = CompileOptions::new(source)
384 .with_library_loader(crate::builtins_loader())
385 .with_load_options(crate::builtins_load_options())
386 .with_native_func_registry(&NativeFuncs);
387 let program = compile_opts.compile()?;
388 let symbol_info = SymbolInfo::from_partial("NASDAQ:AAPL", PartialSymbolInfo::default())?;
389 pollster::block_on(get_script_info(&program, &symbol_info))
390}
391
392pub fn script_info_from_project(project: &Project) -> Result<ScriptInfo, Error> {
398 let compile_opts = CompileOptions::new_from_project(project)
399 .with_library_loader(crate::builtins_loader())
400 .with_load_options(crate::builtins_load_options())
401 .with_native_func_registry(&NativeFuncs);
402 let program = compile_opts.compile()?;
403 let symbol_info = SymbolInfo::from_partial("NASDAQ:AAPL", PartialSymbolInfo::default())?;
404 pollster::block_on(get_script_info(&program, &symbol_info))
405}
406
407impl Instance {
408 #[inline]
455 pub fn builder<'a, P>(
456 provider: P,
457 source: &'a str,
458 timeframe: TimeFrame,
459 symbol: impl Into<String>,
460 ) -> InstanceBuilder<'a, ()>
461 where
462 P: DataProvider + 'static,
463 {
464 InstanceBuilder {
465 compile_opts: CompileOptions::new(source),
466 input_values: HashMap::new(),
467 timeframe,
468 input_sessions: InputSessions::ALL,
469 symbol: symbol.into(),
470 background_color: None,
471 last_info: None,
472 currency_converter: default_currency_converter(),
473 execution_limits: ExecutionLimits::default(),
474 data_provider: Some(Box::new(provider) as Box<dyn InternalProvider>),
475
476 output_mode: OutputMode::default(),
477 }
478 }
479
480 #[inline]
488 pub fn builder_from_project<'a, P>(
489 provider: P,
490 project: &'a Project,
491 timeframe: TimeFrame,
492 symbol: impl Into<String>,
493 ) -> InstanceBuilder<'a, ()>
494 where
495 P: DataProvider + 'static,
496 {
497 InstanceBuilder {
498 compile_opts: CompileOptions::new_from_project(project),
499 input_values: HashMap::new(),
500 timeframe,
501 input_sessions: InputSessions::ALL,
502 symbol: symbol.into(),
503 background_color: None,
504 last_info: None,
505 currency_converter: default_currency_converter(),
506 execution_limits: ExecutionLimits::default(),
507 data_provider: Some(Box::new(provider) as Box<dyn InternalProvider>),
508
509 output_mode: OutputMode::default(),
510 }
511 }
512}