1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// The kind of delay an operator places on one of its input ports, as reported per port by
/// [`OperatorConstraints::input_delaytype_fn`] (`None` there means no delay).
///
/// `PartialOrd`/`Ord` are derived, so variants compare in declaration order; reordering the
/// variants changes comparison results.
// NOTE(review): the exact scheduling meaning of each variant is defined by the consumers of
// `input_delaytype_fn`, which are not visible here; the variant notes below are inferred from
// the names only — confirm.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum DelayType {
    // Presumably: the input must be completed in an earlier stratum.
    Stratum,
    // Presumably: a delay for monotonically accumulating inputs.
    MonotoneAccum,
    // Presumably: the input is delayed until the next tick.
    Tick,
    // Presumably: a lazy variant of `Tick`.
    TickLazy,
}
35
/// Specification of an operator's input or output ports, as produced by
/// [`OperatorConstraints::ports_inn`] / [`OperatorConstraints::ports_out`].
pub enum PortListSpec {
    /// No fixed list of ports (presumably any number of ports is accepted — confirm against the
    /// graph checker).
    Variadic,
    /// An explicit, comma-separated list of the allowed ports.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of one operator: its name, categories, arity limits, argument counts and
/// the function that generates its code.
///
/// Each operator module exposes one of these as a constant. They are collected into
/// `OPERATORS` by `declare_ops!` and looked up by `name` through [`operator_lookup`].
pub struct OperatorConstraints {
    /// Operator name; the key used by [`operator_lookup`].
    pub name: &'static str,
    /// Categories this operator is listed under.
    pub categories: &'static [OperatorCategory],

    // NOTE(review): "hard" ranges are presumably enforced as errors and "soft" ranges as
    // warnings by the graph checker — confirm against the validation code.
    /// Hard limit on the number of input edges.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Soft (expected) limit on the number of input edges.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Hard limit on the number of output edges.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Soft (expected) limit on the number of output edges.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments the operator takes.
    pub num_args: usize,
    /// Allowed number of persistence arguments (written like `'tick`).
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Allowed number of generic type arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether this operator is an external input (presumably fed from outside the graph —
    /// confirm).
    pub is_external_input: bool,
    /// Whether this operator has a singleton output (cf. `WriteContextArgs::singleton_output_ident`).
    pub has_singleton_output: bool,
    /// The operator's [`FloType`] classification, if any.
    pub flo_type: Option<FloType>,

    /// Named input ports, if the operator defines them.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Named output ports, if the operator defines them.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Returns the [`DelayType`] imposed on the given input port, or `None` for no delay.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Generates the operator's code; see [`WriteFn`].
    pub write_fn: WriteFn,
}
88
/// Code-generation function of an operator. Given the context of one operator instance, it
/// returns the generated code sections, or `Err(())` on failure. By convention (not enforced by
/// the type), the failure details are pushed to the [`Diagnostics`].
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Code generated for one operator instance, split into sections.
///
/// Every section defaults to an empty `TokenStream` (`#[derive(Default)]`). Write functions such
/// as `IDENTITY_WRITE_FN` fill only the sections they need and use `..Default::default()` for the
/// rest. `#[non_exhaustive]` means code outside this crate cannot build it with a struct literal.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
    /// Prologue code (presumably emitted once, before the operator runs — confirm).
    pub write_prologue: TokenStream,
    /// Prologue code placed after `write_prologue` (ordering inferred from the name — confirm).
    pub write_prologue_after: TokenStream,
    /// Code that binds the operator's `ident` to its pull or push value (see
    /// `identity_write_iterator_fn` for an example).
    pub write_iterator: TokenStream,
    /// Code placed after the `write_iterator` code (presumably run after the subgraph's
    /// pipeline — confirm).
    pub write_iterator_after: TokenStream,
}
134
135pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
137pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
139pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
170 where
171 Pull: #root::dfir_pipes::Pull<Item = Item>,
172 {
173 pull
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Si, Item>(sink: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
183 where
184 Si: #root::futures::sink::Sink<Item, Error = #root::Never>,
185 {
186 sink
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::dfir_pipes::poll_fn({
227 #(
228 let mut #inputs = ::std::boxed::Box::pin(#inputs);
229 )*
230 move |_cx| {
231 #(
235 let #inputs = #root::dfir_pipes::Pull::pull(
236 ::std::pin::Pin::as_mut(&mut #inputs),
237 <_ as #root::dfir_pipes::Context>::from_task(_cx),
238 );
239 )*
240 #(
241 if let #root::dfir_pipes::Step::Pending(_) = #inputs {
242 return #root::dfir_pipes::Step::Pending(#root::dfir_pipes::Yes);
243 }
244 )*
245 #root::dfir_pipes::Step::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
246 }
247 });
248 }
249 } else {
250 quote_spanned! {op_span=>
251 #[allow(clippy::let_unit_value)]
252 let _ = (#(#outputs),*);
253 let #ident = #root::sinktools::for_each::ForEach::new::<#iter_type>(::std::mem::drop::<#iter_type>);
254 }
255 }
256}
257
258pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
261 let write_iterator = null_write_iterator_fn(write_context_args);
262 Ok(OperatorWriteOutput {
263 write_iterator,
264 ..Default::default()
265 })
266};
267
/// Declares one `pub(crate) mod` per `module::CONSTANT` entry (each entry followed by a comma).
/// It also defines `OPERATORS`, which holds every listed constant in the order given.
macro_rules! declare_ops {
    ( $( $op_module:ident :: $op_constant:ident, )* ) => {
        $( pub(crate) mod $op_module; )*
        /// All operators, in the order they are listed in `declare_ops!`.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $op_module :: $op_constant, )*
        ];
    };
}
// One `module::CONSTANT` entry per operator: `module` is declared as a child module and
// `CONSTANT` is its `OperatorConstraints`. `OPERATORS` is built in exactly this order, so
// reordering entries reorders `OPERATORS`.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flatten::FLATTEN,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    next_stratum::NEXT_STRATUM,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
359
360pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
362 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
363 OnceLock::new();
364 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
365}
366pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
368 if let GraphNode::Operator(operator) = node {
369 find_op_op_constraints(operator)
370 } else {
371 None
372 }
373}
374pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
376 let name = &*operator.name_string();
377 operator_lookup().get(name).copied()
378}
379
/// Everything an operator's [`WriteFn`] receives when generating code for one operator instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path to the runtime crate; generated code refers to runtime items as `#root::...`.
    pub root: &'a TokenStream,
    // NOTE(review): presumably the identifier of the scheduler context visible to generated
    // code — confirm.
    pub context: &'a Ident,
    // NOTE(review): presumably the identifier of the dataflow graph being built — confirm.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator. Used by `make_ident` and for `Persistence::None`
    /// lifespans.
    pub subgraph_id: GraphSubgraphId,
    /// This operator's node id; used by `make_ident`.
    pub node_id: GraphNodeId,
    /// Enclosing loop, if any. Inside a loop the default persistence is `Persistence::None`
    /// instead of `Persistence::Tick`.
    pub loop_id: Option<GraphLoopId>,
    /// Span attached to generated tokens and to diagnostics.
    pub op_span: Span,
    // NOTE(review): optional tag for this operator; its consumers are not visible here.
    pub op_tag: Option<String>,
    // NOTE(review): presumably identifiers of work helpers available to generated code
    // (synchronous and async respectively) — confirm.
    pub work_fn: &'a Ident,
    pub work_fn_async: &'a Ident,

    /// Identifier that the generated `write_iterator` code binds this operator's value to.
    pub ident: &'a Ident,
    /// `true` to generate pull-based code over `inputs`, `false` for push-based code over
    /// `outputs`.
    pub is_pull: bool,
    /// Identifiers of the operator's inputs.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's outputs.
    pub outputs: &'a [Ident],
    // NOTE(review): presumably relates to `OperatorConstraints::has_singleton_output` — confirm.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostic messages.
    pub op_name: &'static str,
    /// The parsed operator instance, including its generic type and persistence arguments.
    pub op_inst: &'a OperatorInstance,
    /// The operator's arguments.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    // NOTE(review): an alternate form of `arguments` (presumably with handle references
    // rewritten) — confirm.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
431impl WriteContextArgs<'_> {
432 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
438 Ident::new(
439 &format!(
440 "sg_{:?}_node_{:?}_{}",
441 self.subgraph_id.data(),
442 self.node_id.data(),
443 suffix.as_ref(),
444 ),
445 self.op_span,
446 )
447 }
448
449 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
452 let root = self.root;
453 let variant =
454 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
455 Some(quote_spanned! {self.op_span=>
456 #root::scheduled::graph::StateLifespan::#variant
457 })
458 }
459
460 pub fn persistence_args_disallow_mutable<const N: usize>(
462 &self,
463 diagnostics: &mut Diagnostics,
464 ) -> [Persistence; N] {
465 let len = self.op_inst.generics.persistence_args.len();
466 if 0 != len && 1 != len && N != len {
467 diagnostics.push(Diagnostic::spanned(
468 self.op_span,
469 Level::Error,
470 format!(
471 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
472 self.op_name, N
473 ),
474 ));
475 }
476
477 let default_persistence = if self.loop_id.is_some() {
478 Persistence::None
479 } else {
480 Persistence::Tick
481 };
482 let mut out = [default_persistence; N];
483 self.op_inst
484 .generics
485 .persistence_args
486 .iter()
487 .copied()
488 .cycle() .take(N)
490 .enumerate()
491 .filter(|&(_i, p)| {
492 if p == Persistence::Mutable {
493 diagnostics.push(Diagnostic::spanned(
494 self.op_span,
495 Level::Error,
496 format!(
497 "An implementation of `'{}` does not exist",
498 p.to_str_lowercase()
499 ),
500 ));
501 false
502 } else {
503 true
504 }
505 })
506 .for_each(|(i, p)| {
507 out[i] = p;
508 });
509 out
510 }
511}
512
/// Object-safe view of a range over `T`: its bounds, a membership test, and a human-readable
/// description.
///
/// A blanket impl below covers every [`RangeBounds`] type, so values like `&(0..)` or
/// `&(1..=1)` can be stored as `&'static dyn RangeTrait<usize>`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words for messages, e.g. `"exactly 1"`, `"at least 2"`,
    /// `"more than 1 and less than 4"`, or `"any number of"` when unbounded on both sides.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let (start, end) = (self.start_bound(), self.end_bound());

        // A closed range with equal endpoints reads as a single value.
        if let (Bound::Included(lo), Bound::Included(hi)) = (start, end) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower = match start {
            Bound::Included(lo) => Some(format!("at least {}", lo)),
            Bound::Excluded(lo) => Some(format!("more than {}", lo)),
            Bound::Unbounded => None,
        };
        let upper = match end {
            Bound::Included(hi) => Some(format!("at most {}", hi)),
            Bound::Excluded(hi) => Some(format!("less than {}", hi)),
            Bound::Unbounded => None,
        };

        match (lower, upper) {
            (Some(lower), Some(upper)) => format!("{} and {}", lower, upper),
            (Some(single), None) | (None, Some(single)) => single,
            (None, None) => "any number of".to_owned(),
        }
    }
}

/// Every [`RangeBounds`] type is a [`RangeTrait`]; each method forwards to the corresponding
/// [`RangeBounds`] method.
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    fn start_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
577
/// How long an operator's state persists. It is written as a persistence argument such as
/// `'tick`; see `Persistence::to_str_lowercase` for the spelled names.
///
/// `Ord` and serde `Serialize`/`Deserialize` are derived. Reordering variants changes comparison
/// order (and variant indices for index-based serializers); renaming them changes the serialized
/// names.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Maps to a `Subgraph(..)` state lifespan; the default inside a loop.
    None,
    /// Maps to a `Loop(..)` state lifespan; only valid inside a loop.
    Loop,
    /// Maps to the `Tick` state lifespan; the default outside a loop.
    Tick,
    /// Has no state lifespan.
    Static,
    /// Has no state lifespan; rejected by `WriteContextArgs::persistence_args_disallow_mutable`.
    Mutable,
}
592impl Persistence {
593 pub fn as_state_lifespan_variant(
595 self,
596 subgraph_id: GraphSubgraphId,
597 loop_id: Option<GraphLoopId>,
598 span: Span,
599 ) -> Option<TokenStream> {
600 match self {
601 Persistence::None => {
602 let sg_ident = subgraph_id.as_ident(span);
603 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
604 }
605 Persistence::Loop => {
606 let loop_ident = loop_id
607 .expect("`Persistence::Loop` outside of a loop context.")
608 .as_ident(span);
609 Some(quote_spanned!(span=> Loop(#loop_ident)))
610 }
611 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
612 Persistence::Static => None,
613 Persistence::Mutable => None,
614 }
615 }
616
617 pub fn to_str_lowercase(self) -> &'static str {
619 match self {
620 Persistence::None => "none",
621 Persistence::Tick => "tick",
622 Persistence::Loop => "loop",
623 Persistence::Static => "static",
624 Persistence::Mutable => "mutable",
625 }
626 }
627}
628
629fn make_missing_runtime_msg(op_name: &str) -> Literal {
631 Literal::string(&format!(
632 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
633 op_name
634 ))
635}
636
/// Category an operator is listed under (see `OperatorConstraints::categories`).
///
/// `OperatorCategory::name` and `OperatorCategory::description` split each variant's doc text
/// (read through the `DocumentedVariants` derive) at the first `:` and unwrap. Each variant's doc
/// is therefore expected to be written as `Name: description`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
671impl OperatorCategory {
672 pub fn name(self) -> &'static str {
674 self.get_variant_docs().split_once(":").unwrap().0
675 }
676 pub fn description(self) -> &'static str {
678 self.get_variant_docs().split_once(":").unwrap().1
679 }
680}
681
/// Flow classification of an operator, stored in `OperatorConstraints::flo_type`.
///
/// `PartialOrd`/`Ord` are derived, so variants compare in declaration order.
// NOTE(review): the meaning of each variant (presumably the operator's role relative to `loop`
// blocks, given the `Windowing`/`Unwindowing` categories) is not visible here — confirm.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    Source,
    Windowing,
    Unwindowing,
    NextIteration,
}