1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// Meta-graph representation of a DFIR graph, from which Rust code is generated.
///
/// Holds the graph's nodes (operators, handoffs, and module boundaries) and edges, plus optional
/// per-node, per-loop, and per-subgraph metadata (operator instances, loop nesting, subgraph
/// partitioning, strata, singleton references, and variable names).
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph: operator, handoff, or module boundary.
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Operator instance of each operator node.
    ///
    /// Skipped by serde, so this map is empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional tag of each operator node; when present it replaces the source location in the
    /// generated work-function name.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Edge topology between nodes.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// Ports of each edge, as `(src_port, dst_port)`.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Loop containing each node, for nodes that are inside a loop.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Nodes belonging to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Parent loop of each nested loop.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops that loop codegen starts from (presumably those without a parent — confirm).
    root_loops: Vec<GraphLoopId>,
    /// Child loops of each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Subgraph each node has been assigned to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Nodes of each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Stratum of each subgraph.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// Singleton references of each node; a `None` entry is a reference not yet resolved.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// Variable name bound to each node, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Laziness flag of each subgraph; subgraphs without an entry are treated as not lazy.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics_span,
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics_span,
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381
382 let mut dummy_diagnostics = Diagnostics::new();
383 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
384 assert!(dummy_diagnostics.is_empty());
385
386 Some(OperatorInstance {
387 op_constraints,
388 input_ports: vec![input_port],
389 output_ports: vec![output_port],
390 singletons_referenced: operator.singletons_referenced.clone(),
391 generics,
392 arguments_pre: operator.args.clone(),
393 arguments_raw: operator.args_raw.clone(),
394 })
395 };
396
397 let node_id = self.nodes.insert(new_node);
399 if let Some(op_inst) = op_inst_opt {
401 self.operator_instances.insert(node_id, op_inst);
402 }
403 let (e0, e1) = self
405 .graph
406 .insert_intermediate_vertex(node_id, edge_id)
407 .unwrap();
408
409 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
411 self.ports
412 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
413 self.ports
414 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
415
416 (node_id, e1)
417 }
418
419 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
422 assert_eq!(
423 1,
424 self.node_degree_in(node_id),
425 "Removed intermediate node must have one predecessor"
426 );
427 assert_eq!(
428 1,
429 self.node_degree_out(node_id),
430 "Removed intermediate node must have one successor"
431 );
432 assert!(
433 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
434 "Should not remove intermediate node after subgraph partitioning"
435 );
436
437 assert!(self.nodes.remove(node_id).is_some());
438 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
439 self.graph.remove_intermediate_vertex(node_id).unwrap();
440 self.operator_instances.remove(node_id);
441 self.node_varnames.remove(node_id);
442
443 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
444 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
445 self.ports.insert(new_edge_id, (src_port, dst_port));
446 }
447
448 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
454 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
455 return Some(Color::Hoff);
456 }
457
458 if let GraphNode::Operator(op) = self.node(node_id)
460 && (op.name_string() == "resolve_futures_blocking"
461 || op.name_string() == "resolve_futures_blocking_ordered")
462 {
463 return Some(Color::Push);
464 }
465
466 let inn_degree = self.node_predecessor_nodes(node_id).count();
468 let out_degree = self.node_successor_nodes(node_id).count();
470
471 match (inn_degree, out_degree) {
472 (0, 0) => None, (0, 1) => Some(Color::Pull),
474 (1, 0) => Some(Color::Push),
475 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
477 (0 | 1, _many) => Some(Color::Push),
478 (_many, _to_many) => Some(Color::Comp),
479 }
480 }
481
482 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
484 self.operator_tag.insert(node_id, tag);
485 }
486}
487
488impl DfirGraph {
490 pub fn set_node_singleton_references(
493 &mut self,
494 node_id: GraphNodeId,
495 singletons_referenced: Vec<Option<GraphNodeId>>,
496 ) -> Option<Vec<Option<GraphNodeId>>> {
497 self.node_singleton_references
498 .insert(node_id, singletons_referenced)
499 }
500
501 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
504 self.node_singleton_references
505 .get(node_id)
506 .map(std::ops::Deref::deref)
507 .unwrap_or_default()
508 }
509}
510
511impl DfirGraph {
513 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
521 let mod_bound_nodes = self
522 .nodes()
523 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
524 .map(|(nid, _node)| nid)
525 .collect::<Vec<_>>();
526
527 for mod_bound_node in mod_bound_nodes {
528 self.remove_module_boundary(mod_bound_node)?;
529 }
530
531 Ok(())
532 }
533
534 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
538 assert!(
539 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
540 "Should not remove intermediate node after subgraph partitioning"
541 );
542
543 let mut mod_pred_ports = BTreeMap::new();
544 let mut mod_succ_ports = BTreeMap::new();
545
546 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
547 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
548 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
549 }
550
551 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
552 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
553 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
554 }
555
556 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
557 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
558 {
559 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
561 panic!();
562 };
563
564 if *input {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports into the module did not match. input: {:?}, expected: {:?}",
570 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
572 ),
573 });
574 } else {
575 return Err(Diagnostic {
576 span: *import_expr,
577 level: Level::Error,
578 message: format!(
579 "The ports out of the module did not match. output: {:?}, expected: {:?}",
580 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
581 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
582 ),
583 });
584 }
585 }
586
587 for (port, (pred_edge, pred_port)) in mod_pred_ports {
588 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
589
590 let (src, _) = self.edge(pred_edge);
591 let (_, dst) = self.edge(succ_edge);
592 self.remove_edge(pred_edge);
593 self.remove_edge(succ_edge);
594
595 let new_edge_id = self.graph.insert_edge(src, dst);
596 self.ports.insert(new_edge_id, (pred_port, succ_port));
597 }
598
599 self.graph.remove_vertex(mod_bound_node);
600 self.nodes.remove(mod_bound_node);
601
602 Ok(())
603 }
604}
605
606impl DfirGraph {
608 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
610 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
611 (src, dst)
612 }
613
614 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
616 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
617 (src_port, dst_port)
618 }
619
620 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
622 self.graph.edge_ids()
623 }
624
625 pub fn edges(
627 &self,
628 ) -> impl '_
629 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
630 + FusedIterator
631 + Clone
632 + Debug {
633 self.graph.edges()
634 }
635
636 pub fn insert_edge(
638 &mut self,
639 src: GraphNodeId,
640 src_port: PortIndexValue,
641 dst: GraphNodeId,
642 dst_port: PortIndexValue,
643 ) -> GraphEdgeId {
644 let edge_id = self.graph.insert_edge(src, dst);
645 self.ports.insert(edge_id, (src_port, dst_port));
646 edge_id
647 }
648
649 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
651 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
652 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
653 }
654}
655
656impl DfirGraph {
658 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
660 self.subgraph_nodes
661 .get(subgraph_id)
662 .expect("Subgraph not found.")
663 }
664
665 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
667 self.subgraph_nodes.keys()
668 }
669
670 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
672 self.subgraph_nodes.iter()
673 }
674
675 pub fn insert_subgraph(
677 &mut self,
678 node_ids: Vec<GraphNodeId>,
679 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
680 for &node_id in node_ids.iter() {
682 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
683 return Err((node_id, old_sg_id));
684 }
685 }
686 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
687 for &node_id in node_ids.iter() {
688 self.node_subgraph.insert(node_id, sg_id);
689 }
690 node_ids
691 });
692
693 Ok(subgraph_id)
694 }
695
696 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
698 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
699 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
700 true
701 } else {
702 false
703 }
704 }
705
706 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
708 self.subgraph_stratum.get(sg_id).copied()
709 }
710
711 pub fn set_subgraph_stratum(
713 &mut self,
714 sg_id: GraphSubgraphId,
715 stratum: usize,
716 ) -> Option<usize> {
717 self.subgraph_stratum.insert(sg_id, stratum)
718 }
719
720 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
722 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
723 }
724
725 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
727 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
728 }
729
730 pub fn max_stratum(&self) -> Option<usize> {
732 self.subgraph_stratum.values().copied().max()
733 }
734
735 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
737 subgraph_nodes
738 .iter()
739 .position(|&node_id| {
740 self.node_color(node_id)
741 .is_some_and(|color| Color::Pull != color)
742 })
743 .unwrap_or(subgraph_nodes.len())
744 }
745}
746
747impl DfirGraph {
749 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
751 let name = match &self.nodes[node_id] {
752 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
753 GraphNode::Handoff { .. } => format!(
754 "hoff_{:?}_{}",
755 node_id.data(),
756 if is_pred { "recv" } else { "send" }
757 ),
758 GraphNode::ModuleBoundary { .. } => panic!(),
759 };
760 let span = match (is_pred, &self.nodes[node_id]) {
761 (_, GraphNode::Operator(operator)) => operator.span(),
762 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
763 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
764 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
765 };
766 Ident::new(&name, span)
767 }
768
769 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
771 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
772 }
773
774 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
776 self.node_singleton_references(node_id)
777 .iter()
778 .map(|singleton_node_id| {
779 self.node_as_singleton_ident(
781 singleton_node_id
782 .expect("Expected singleton to be resolved but was not, this is a bug."),
783 span,
784 )
785 })
786 .collect::<Vec<_>>()
787 }
788
789 fn helper_collect_subgraph_handoffs(
792 &self,
793 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
794 let mut subgraph_handoffs: SecondaryMap<
796 GraphSubgraphId,
797 (Vec<GraphNodeId>, Vec<GraphNodeId>),
798 > = self
799 .subgraph_nodes
800 .keys()
801 .map(|k| (k, Default::default()))
802 .collect();
803
804 for (hoff_id, node) in self.nodes() {
806 if !matches!(node, GraphNode::Handoff { .. }) {
807 continue;
808 }
809 for (_edge, succ_id) in self.node_successors(hoff_id) {
811 let succ_sg = self.node_subgraph(succ_id).unwrap();
812 subgraph_handoffs[succ_sg].0.push(hoff_id);
813 }
814 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
816 let pred_sg = self.node_subgraph(pred_id).unwrap();
817 subgraph_handoffs[pred_sg].1.push(hoff_id);
818 }
819 }
820
821 subgraph_handoffs
822 }
823
824 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
826 let mut out = TokenStream::new();
828 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
829 while let Some(loop_id) = queue.pop_front() {
830 let parent_opt = self
831 .loop_parent(loop_id)
832 .map(|loop_id| loop_id.as_ident(Span::call_site()))
833 .map(|ident| quote! { Some(#ident) })
834 .unwrap_or_else(|| quote! { None });
835 let loop_name = loop_id.as_ident(Span::call_site());
836 out.append_all(quote! {
837 let #loop_name = #df.add_loop(#parent_opt);
838 });
839 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
840 }
841 out
842 }
843
844 pub fn as_code(
848 &self,
849 root: &TokenStream,
850 include_type_guards: bool,
851 prefix: TokenStream,
852 diagnostics: &mut Diagnostics,
853 ) -> Result<TokenStream, Diagnostics> {
854 let df = Ident::new(GRAPH, Span::call_site());
855 let context = Ident::new(CONTEXT, Span::call_site());
856
857 let handoff_code = self
859 .nodes
860 .iter()
861 .filter_map(|(node_id, node)| match node {
862 GraphNode::Operator(_) => None,
863 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
864 GraphNode::ModuleBoundary { .. } => panic!(),
865 })
866 .map(|(node_id, (src_span, dst_span))| {
867 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
868 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
869 let span = src_span.join(dst_span).unwrap_or(src_span);
870 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
871 hoff_name.set_span(span);
872 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
873 quote_spanned! {span=>
874 let (#ident_send, #ident_recv) =
875 #df.make_edge::<_, #hoff_type>(#hoff_name);
876 }
877 });
878
879 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
880
881 let (subgraphs_without_preds, subgraphs_with_preds) = self
883 .subgraph_nodes
884 .iter()
885 .partition::<Vec<_>, _>(|(_, nodes)| {
886 nodes
887 .iter()
888 .any(|&node_id| self.node_degree_in(node_id) == 0)
889 });
890
891 let mut op_prologue_code = Vec::new();
892 let mut op_prologue_after_code = Vec::new();
893 let mut subgraphs = Vec::new();
894 {
895 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
896 .iter()
897 .chain(subgraphs_with_preds.iter())
898 {
899 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
900 let recv_ports: Vec<Ident> = recv_hoffs
901 .iter()
902 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
903 .collect();
904 let send_ports: Vec<Ident> = send_hoffs
905 .iter()
906 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
907 .collect();
908
909 let recv_port_code = recv_ports.iter().map(|ident| {
910 quote_spanned! {ident.span()=>
911 let mut #ident = #ident.borrow_mut_swap();
912 let #ident = #root::dfir_pipes::iter(#ident.drain(..));
913 }
914 });
915 let send_port_code = send_ports.iter().map(|ident| {
916 quote_spanned! {ident.span()=>
917 let #ident = #root::sinktools::for_each(|v| {
918 #ident.give(Some(v));
919 });
920 }
921 });
922
923 let loop_id = self
924 .node_loop(subgraph_nodes[0]);
926
927 let mut subgraph_op_iter_code = Vec::new();
928 let mut subgraph_op_iter_after_code = Vec::new();
929 {
930 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
931
932 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
933 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
934
935 for (idx, &node_id) in nodes_iter.enumerate() {
936 let node = &self.nodes[node_id];
937 assert!(
938 matches!(node, GraphNode::Operator(_)),
939 "Handoffs are not part of subgraphs."
940 );
941 let op_inst = &self.operator_instances[node_id];
942
943 let op_span = node.span();
944 let op_name = op_inst.op_constraints.name;
945 let root = change_spans(root.clone(), op_span);
947 let op_constraints = OPERATORS
949 .iter()
950 .find(|op| op_name == op.name)
951 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
952
953 let ident = self.node_as_ident(node_id, false);
954
955 {
956 let mut input_edges = self
959 .graph
960 .predecessor_edges(node_id)
961 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
962 .collect::<Vec<_>>();
963 input_edges.sort();
965
966 let inputs = input_edges
967 .iter()
968 .map(|&(_port, edge_id)| {
969 let (pred, _) = self.edge(edge_id);
970 self.node_as_ident(pred, true)
971 })
972 .collect::<Vec<_>>();
973
974 let mut output_edges = self
976 .graph
977 .successor_edges(node_id)
978 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
979 .collect::<Vec<_>>();
980 output_edges.sort();
982
983 let outputs = output_edges
984 .iter()
985 .map(|&(_port, edge_id)| {
986 let (_, succ) = self.edge(edge_id);
987 self.node_as_ident(succ, false)
988 })
989 .collect::<Vec<_>>();
990
991 let is_pull = idx < pull_to_push_idx;
992
993 let singleton_output_ident = &if op_constraints.has_singleton_output {
994 self.node_as_singleton_ident(node_id, op_span)
995 } else {
996 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
998 };
999
1000 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1009 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1010
1011 let singletons_resolved =
1012 self.helper_resolve_singletons(node_id, op_span);
1013 let arguments = &process_singletons::postprocess_singletons(
1014 op_inst.arguments_raw.clone(),
1015 singletons_resolved.clone(),
1016 context,
1017 );
1018 let arguments_handles =
1019 &process_singletons::postprocess_singletons_handles(
1020 op_inst.arguments_raw.clone(),
1021 singletons_resolved.clone(),
1022 );
1023
1024 let source_tag = 'a: {
1025 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1026 break 'a tag;
1027 }
1028
1029 #[cfg(nightly)]
1030 if proc_macro::is_available() {
1031 let op_span = op_span.unwrap();
1032 break 'a format!(
1033 "loc_{}_{}_{}_{}_{}",
1034 crate::pretty_span::make_source_path_relative(
1035 &op_span.file()
1036 )
1037 .display()
1038 .to_string()
1039 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1040 op_span.start().line(),
1041 op_span.start().column(),
1042 op_span.end().line(),
1043 op_span.end().column(),
1044 );
1045 }
1046
1047 format!(
1048 "loc_nopath_{}_{}_{}_{}",
1049 op_span.start().line,
1050 op_span.start().column,
1051 op_span.end().line,
1052 op_span.end().column
1053 )
1054 };
1055
1056 let work_fn = format_ident!(
1057 "{}__{}__{}",
1058 ident,
1059 op_name,
1060 source_tag,
1061 span = op_span
1062 );
1063 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1064
1065 let context_args = WriteContextArgs {
1066 root: &root,
1067 df_ident: df_local,
1068 context,
1069 subgraph_id,
1070 node_id,
1071 loop_id,
1072 op_span,
1073 op_tag: self.operator_tag.get(node_id).cloned(),
1074 work_fn: &work_fn,
1075 work_fn_async: &work_fn_async,
1076 ident: &ident,
1077 is_pull,
1078 inputs: &inputs,
1079 outputs: &outputs,
1080 singleton_output_ident,
1081 op_name,
1082 op_inst,
1083 arguments,
1084 arguments_handles,
1085 };
1086
1087 let write_result =
1088 (op_constraints.write_fn)(&context_args, diagnostics);
1089 let OperatorWriteOutput {
1090 write_prologue,
1091 write_prologue_after,
1092 write_iterator,
1093 write_iterator_after,
1094 } = write_result.unwrap_or_else(|()| {
1095 assert!(
1096 diagnostics.has_error(),
1097 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1098 op_name,
1099 );
1100 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1101 });
1102
1103 op_prologue_code.push(syn::parse_quote! {
1104 #[allow(non_snake_case)]
1105 #[inline(always)]
1106 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1107 thunk()
1108 }
1109
1110 #[allow(non_snake_case)]
1111 #[inline(always)]
1112 async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1113 thunk.await
1114 }
1115 });
1116 op_prologue_code.push(write_prologue);
1117 op_prologue_after_code.push(write_prologue_after);
1118 subgraph_op_iter_code.push(write_iterator);
1119
1120 if include_type_guards {
1121 let type_guard = if is_pull {
1123 quote_spanned! {op_span=>
1124 let #ident = {
1125 #[allow(non_snake_case)]
1126 #[inline(always)]
1127 pub fn #work_fn<Item, Input>(input: Input)
1128 -> impl #root::dfir_pipes::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1129 where
1130 Input: #root::dfir_pipes::Pull<Item = Item, Meta = ()>,
1131 {
1132 #root::pin_project_lite::pin_project! {
1133 #[repr(transparent)]
1134 struct Pull<Item, Input: #root::dfir_pipes::Pull<Item = Item>> {
1135 #[pin]
1136 inner: Input
1137 }
1138 }
1139
1140 impl<Item, Input> #root::dfir_pipes::Pull for Pull<Item, Input>
1141 where
1142 Input: #root::dfir_pipes::Pull<Item = Item>,
1143 {
1144 type Ctx<'ctx> = Input::Ctx<'ctx>;
1145
1146 type Item = Item;
1147 type Meta = Input::Meta;
1148 type CanPend = Input::CanPend;
1149 type CanEnd = Input::CanEnd;
1150
1151 #[inline(always)]
1152 fn pull(
1153 self: ::std::pin::Pin<&mut Self>,
1154 ctx: &mut Self::Ctx<'_>,
1155 ) -> #root::dfir_pipes::Step<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1156 #root::dfir_pipes::Pull::pull(self.project().inner, ctx)
1157 }
1158
1159 #[inline(always)]
1160 fn size_hint(self: ::std::pin::Pin<&Self>) -> (usize, Option<usize>) {
1161 #root::dfir_pipes::Pull::size_hint(self.project_ref().inner)
1162 }
1163 }
1164
1165 Pull {
1166 inner: input
1167 }
1168 }
1169 #work_fn::<_, _>( #ident )
1170 };
1171 }
1172 } else {
1173 quote_spanned! {op_span=>
1174 let #ident = {
1175 #[allow(non_snake_case)]
1176 #[inline(always)]
1177 pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
1178 where
1179 Si: #root::futures::sink::Sink<Item, Error = #root::Never>
1180 {
1181 #root::pin_project_lite::pin_project! {
1182 #[repr(transparent)]
1183 struct Push<Si> {
1184 #[pin]
1185 si: Si,
1186 }
1187 }
1188
1189 impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
1190 where
1191 Si: #root::futures::sink::Sink<Item>,
1192 {
1193 type Error = Si::Error;
1194
1195 fn poll_ready(
1196 self: ::std::pin::Pin<&mut Self>,
1197 cx: &mut ::std::task::Context<'_>,
1198 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1199 self.project().si.poll_ready(cx)
1200 }
1201
1202 fn start_send(
1203 self: ::std::pin::Pin<&mut Self>,
1204 item: Item,
1205 ) -> ::std::result::Result<(), Self::Error> {
1206 self.project().si.start_send(item)
1207 }
1208
1209 fn poll_flush(
1210 self: ::std::pin::Pin<&mut Self>,
1211 cx: &mut ::std::task::Context<'_>,
1212 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1213 self.project().si.poll_flush(cx)
1214 }
1215
1216 fn poll_close(
1217 self: ::std::pin::Pin<&mut Self>,
1218 cx: &mut ::std::task::Context<'_>,
1219 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1220 self.project().si.poll_close(cx)
1221 }
1222 }
1223
1224 Push {
1225 si
1226 }
1227 }
1228 #work_fn( #ident )
1229 };
1230 }
1231 };
1232 subgraph_op_iter_code.push(type_guard);
1233 }
1234 subgraph_op_iter_after_code.push(write_iterator_after);
1235 }
1236 }
1237
1238 {
1239 let pull_ident = if 0 < pull_to_push_idx {
1241 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1242 } else {
1243 recv_ports[0].clone()
1245 };
1246
1247 #[rustfmt::skip]
1248 let push_ident = if let Some(&node_id) =
1249 subgraph_nodes.get(pull_to_push_idx)
1250 {
1251 self.node_as_ident(node_id, false)
1252 } else if 1 == send_ports.len() {
1253 send_ports[0].clone()
1255 } else {
1256 diagnostics.push(Diagnostic::spanned(
1257 pull_ident.span(),
1258 Level::Error,
1259 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1260 ));
1261 continue;
1262 };
1263
1264 let pivot_span = pull_ident
1266 .span()
1267 .join(push_ident.span())
1268 .unwrap_or_else(|| push_ident.span());
1269 let pivot_fn_ident =
1270 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1271 let root = change_spans(root.clone(), pivot_span);
1272 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1273 #[inline(always)]
1274 fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
1275 -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
1276 where
1277 Pull: #root::dfir_pipes::Pull<Item = Item>,
1278 Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
1279 {
1280 #root::dfir_pipes::Pull::send_sink(pull, push)
1281 }
1282 (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap();
1283 });
1284 }
1285 };
1286
1287 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1288 let stratum = Literal::usize_unsuffixed(
1289 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1290 );
1291 let laziness = self.subgraph_laziness(subgraph_id);
1292
1293 let loop_id_opt = loop_id
1295 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1296 .map(|ident| quote! { Some(#ident) })
1297 .unwrap_or_else(|| quote! { None });
1298
1299 let sg_ident = subgraph_id.as_ident(Span::call_site());
1300
1301 subgraphs.push(quote! {
1302 let #sg_ident = #df.add_subgraph_full(
1303 #subgraph_name,
1304 #stratum,
1305 var_expr!( #( #recv_ports ),* ),
1306 var_expr!( #( #send_ports ),* ),
1307 #laziness,
1308 #loop_id_opt,
1309 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1310 #( #recv_port_code )*
1311 #( #send_port_code )*
1312 #( #subgraph_op_iter_code )*
1313 #( #subgraph_op_iter_after_code )*
1314 },
1315 );
1316 });
1317 }
1318 }
1319
1320 if diagnostics.has_error() {
1321 return Err(std::mem::take(diagnostics));
1322 }
1323 let _ = diagnostics; let loop_code = self.codegen_nested_loops(&df);
1326
1327 let code = quote! {
1332 #( #handoff_code )*
1333 #loop_code
1334 #( #op_prologue_code )*
1335 #( #subgraphs )*
1336 #( #op_prologue_after_code )*
1337 };
1338
1339 let meta_graph_json = serde_json::to_string(&self).unwrap();
1340 let meta_graph_json = Literal::string(&meta_graph_json);
1341
1342 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1343 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1344 let diagnostics_json = Literal::string(&diagnostics_json);
1345
1346 Ok(quote! {
1347 {
1348 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1349 {
1350 #prefix
1351
1352 use #root::{var_expr, var_args};
1353
1354 let mut #df = #root::scheduled::graph::Dfir::new();
1355 #df.__assign_meta_graph(#meta_graph_json);
1356 #df.__assign_diagnostics(#diagnostics_json);
1357
1358 #code
1359
1360 #df
1361 }
1362 }
1363 })
1364 }
1365
1366 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1369 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1370 .node_ids()
1371 .filter_map(|node_id| {
1372 let op_color = self.node_color(node_id)?;
1373 Some((node_id, op_color))
1374 })
1375 .collect();
1376
1377 for sg_nodes in self.subgraph_nodes.values() {
1379 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1380
1381 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1382 let is_pull = idx < pull_to_push_idx;
1383 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1384 }
1385 }
1386
1387 node_color_map
1388 }
1389
1390 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1392 let mut output = String::new();
1393 self.write_mermaid(&mut output, write_config).unwrap();
1394 output
1395 }
1396
1397 pub fn write_mermaid(
1399 &self,
1400 output: impl std::fmt::Write,
1401 write_config: &WriteConfig,
1402 ) -> std::fmt::Result {
1403 let mut graph_write = Mermaid::new(output);
1404 self.write_graph(&mut graph_write, write_config)
1405 }
1406
1407 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1409 let mut output = String::new();
1410 let mut graph_write = Dot::new(&mut output);
1411 self.write_graph(&mut graph_write, write_config).unwrap();
1412 output
1413 }
1414
1415 pub fn write_dot(
1417 &self,
1418 output: impl std::fmt::Write,
1419 write_config: &WriteConfig,
1420 ) -> std::fmt::Result {
1421 let mut graph_write = Dot::new(output);
1422 self.write_graph(&mut graph_write, write_config)
1423 }
1424
1425 pub(crate) fn write_graph<W>(
1427 &self,
1428 mut graph_write: W,
1429 write_config: &WriteConfig,
1430 ) -> Result<(), W::Err>
1431 where
1432 W: GraphWrite,
1433 {
1434 fn helper_edge_label(
1435 src_port: &PortIndexValue,
1436 dst_port: &PortIndexValue,
1437 ) -> Option<String> {
1438 let src_label = match src_port {
1439 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1440 PortIndexValue::Int(index) => Some(index.value.to_string()),
1441 _ => None,
1442 };
1443 let dst_label = match dst_port {
1444 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1445 PortIndexValue::Int(index) => Some(index.value.to_string()),
1446 _ => None,
1447 };
1448 let label = match (src_label, dst_label) {
1449 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1450 (Some(l1), None) => Some(l1),
1451 (None, Some(l2)) => Some(l2),
1452 (None, None) => None,
1453 };
1454 label
1455 }
1456
1457 let node_color_map = self.node_color_map();
1459
1460 graph_write.write_prologue()?;
1462
1463 let mut skipped_handoffs = BTreeSet::new();
1465 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1466 for (node_id, node) in self.nodes() {
1467 if matches!(node, GraphNode::Handoff { .. }) {
1468 if write_config.no_handoffs {
1469 skipped_handoffs.insert(node_id);
1470 continue;
1471 } else {
1472 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1473 let pred_sg = self.node_subgraph(pred_node);
1474 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1475 let succ_sg = self.node_subgraph(succ_node);
1476 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1477 && pred_sg == succ_sg
1478 {
1479 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1480 }
1481 }
1482 }
1483 graph_write.write_node_definition(
1484 node_id,
1485 &if write_config.op_short_text {
1486 node.to_name_string()
1487 } else if write_config.op_text_no_imports {
1488 let full_text = node.to_pretty_string();
1490 let mut output = String::new();
1491 for sentence in full_text.split('\n') {
1492 if sentence.trim().starts_with("use") {
1493 continue;
1494 }
1495 output.push('\n');
1496 output.push_str(sentence);
1497 }
1498 output.into()
1499 } else {
1500 node.to_pretty_string()
1501 },
1502 if write_config.no_pull_push {
1503 None
1504 } else {
1505 node_color_map.get(node_id).copied()
1506 },
1507 )?;
1508 }
1509
1510 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1512 if skipped_handoffs.contains(&src_id) {
1514 continue;
1515 }
1516
1517 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1518 if skipped_handoffs.contains(&dst_id) {
1519 let mut handoff_succs = self.node_successors(dst_id);
1520 assert_eq!(1, handoff_succs.len());
1521 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1522 dst_id = succ_node;
1523 dst_port = self.edge_ports(succ_edge).1;
1524 }
1525
1526 let label = helper_edge_label(src_port, dst_port);
1527 let delay_type = self
1528 .node_op_inst(dst_id)
1529 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1530 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1531 }
1532
1533 if !write_config.no_references {
1535 for dst_id in self.node_ids() {
1536 for src_ref_id in self
1537 .node_singleton_references(dst_id)
1538 .iter()
1539 .copied()
1540 .flatten()
1541 {
1542 let delay_type = Some(DelayType::Stratum);
1543 let label = None;
1544 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1545 }
1546 }
1547 }
1548
1549 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1560 let loop_id = if write_config.no_loops {
1561 None
1562 } else {
1563 self.subgraph_loop(sg_id)
1564 };
1565 (loop_id, sg_id)
1566 });
1567 let loop_subgraphs = into_group_map(loop_subgraphs);
1568 for (loop_id, subgraph_ids) in loop_subgraphs {
1569 if let Some(loop_id) = loop_id {
1570 graph_write.write_loop_start(loop_id)?;
1571 }
1572
1573 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1575 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1576 let opt_sg_id = if write_config.no_subgraphs {
1577 None
1578 } else {
1579 Some(sg_id)
1580 };
1581 (opt_sg_id, (self.node_varname(node_id), node_id))
1582 })
1583 });
1584 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1585 for (sg_id, varnames) in subgraph_varnames_nodes {
1586 if let Some(sg_id) = sg_id {
1587 let stratum = self.subgraph_stratum(sg_id).unwrap();
1588 graph_write.write_subgraph_start(sg_id, stratum)?;
1589 }
1590
1591 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1593 let varname = if write_config.no_varnames {
1594 None
1595 } else {
1596 varname
1597 };
1598 (varname, node)
1599 });
1600 let varname_nodes = into_group_map(varname_nodes);
1601 for (varname, node_ids) in varname_nodes {
1602 if let Some(varname) = varname {
1603 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1604 }
1605
1606 for node_id in node_ids {
1608 graph_write.write_node(node_id)?;
1609 }
1610
1611 if varname.is_some() {
1612 graph_write.write_varname_end()?;
1613 }
1614 }
1615
1616 if sg_id.is_some() {
1617 graph_write.write_subgraph_end()?;
1618 }
1619 }
1620
1621 if loop_id.is_some() {
1622 graph_write.write_loop_end()?;
1623 }
1624 }
1625
1626 graph_write.write_epilogue()?;
1628
1629 Ok(())
1630 }
1631
1632 pub fn surface_syntax_string(&self) -> String {
1634 let mut string = String::new();
1635 self.write_surface_syntax(&mut string).unwrap();
1636 string
1637 }
1638
1639 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1641 for (key, node) in self.nodes.iter() {
1642 match node {
1643 GraphNode::Operator(op) => {
1644 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1645 }
1646 GraphNode::Handoff { .. } => {
1647 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1648 }
1649 GraphNode::ModuleBoundary { .. } => panic!(),
1650 }
1651 }
1652 writeln!(write)?;
1653 for (_e, (src_key, dst_key)) in self.graph.edges() {
1654 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1655 }
1656 Ok(())
1657 }
1658
1659 pub fn mermaid_string_flat(&self) -> String {
1661 let mut string = String::new();
1662 self.write_mermaid_flat(&mut string).unwrap();
1663 string
1664 }
1665
1666 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1668 writeln!(write, "flowchart TB")?;
1669 for (key, node) in self.nodes.iter() {
1670 match node {
1671 GraphNode::Operator(operator) => writeln!(
1672 write,
1673 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1674 span = PrettySpan(node.span()),
1675 id = key.data(),
1676 row_col = PrettyRowCol(node.span()),
1677 code = operator
1678 .to_token_stream()
1679 .to_string()
1680 .replace('&', "&")
1681 .replace('<', "<")
1682 .replace('>', ">")
1683 .replace('"', """)
1684 .replace('\n', "<br>"),
1685 ),
1686 GraphNode::Handoff { .. } => {
1687 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1688 }
1689 GraphNode::ModuleBoundary { .. } => {
1690 writeln!(
1691 write,
1692 r#" {:?}{{"{}"}}"#,
1693 key.data(),
1694 MODULE_BOUNDARY_NODE_STR
1695 )
1696 }
1697 }?;
1698 }
1699 writeln!(write)?;
1700 for (_e, (src_key, dst_key)) in self.graph.edges() {
1701 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1702 }
1703 Ok(())
1704 }
1705}
1706
1707impl DfirGraph {
1709 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1711 self.loop_nodes.keys()
1712 }
1713
1714 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1716 self.loop_nodes.iter()
1717 }
1718
1719 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1721 let loop_id = self.loop_nodes.insert(Vec::new());
1722 self.loop_children.insert(loop_id, Vec::new());
1723 if let Some(parent_loop) = parent_loop {
1724 self.loop_parent.insert(loop_id, parent_loop);
1725 self.loop_children
1726 .get_mut(parent_loop)
1727 .unwrap()
1728 .push(loop_id);
1729 } else {
1730 self.root_loops.push(loop_id);
1731 }
1732 loop_id
1733 }
1734
1735 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1737 self.node_loops.get(node_id).copied()
1738 }
1739
1740 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1742 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1743 let out = self.node_loop(node_id);
1744 debug_assert!(
1745 self.subgraph(subgraph_id)
1746 .iter()
1747 .all(|&node_id| self.node_loop(node_id) == out),
1748 "Subgraph nodes should all have the same loop context."
1749 );
1750 out
1751 }
1752
1753 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1755 self.loop_parent.get(loop_id).copied()
1756 }
1757
1758 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1760 self.loop_children.get(loop_id).unwrap()
1761 }
1762}
1763
/// Options controlling how a `DfirGraph` is written out as a Mermaid or Dot graph.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes into their subgraphs.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes by their variable names.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes as pull or push.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Omit handoff nodes; edges through a handoff are drawn directly to its successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Do not draw singleton reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs into their loops.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label operators with only their name instead of their full source text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Omit `use` import lines from operator source text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1794
/// Output format for writing a graph.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graph.
    Mermaid,
    /// Dot graph.
    Dot,
}
1804
/// Groups `(key, value)` pairs by key into an ordered map.
///
/// Values are kept in their original iteration order within each group, and groups are
/// ordered by key.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}