40
40
#include " xls/passes/optimization_pass.h"
41
41
#include " xls/passes/optimization_pass_registry.h"
42
42
#include " xls/passes/pass_base.h"
43
+ #include " xls/passes/post_dominator_analysis.h"
43
44
#include " ortools/graph/graph.h"
44
45
#include " ortools/graph/cliques.h"
45
46
@@ -260,12 +261,26 @@ bool AreMutuallyExclusive(
260
261
return false ;
261
262
}
262
263
264
+ // This function returns true if @node_to_check reaches @point_in_the_graph,
265
+ // false otherwise.
266
+ bool DoesReach (const absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
267
+ &reachability_result,
268
+ Node *point_in_the_graph, Node *node_to_check) {
269
+ // Fetch the set of nodes that reach @point_in_the_graph
270
+ auto iter = reachability_result.find (point_in_the_graph);
271
+ CHECK (iter != reachability_result.end ());
272
+ const absl::flat_hash_set<Node *> &reaching_nodes = iter->second ;
273
+
274
+ // Check if the specified node reaches @point_in_the_graph
275
+ return reaching_nodes.contains (node_to_check);
276
+ }
277
+
263
278
std::optional<uint32_t > GetSelectCaseNumberOfNode (
264
279
const absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
265
280
&reachability_result,
266
281
Node *node, Node *select_case, uint32_t select_case_number) {
267
282
// Check if @node reaches the select case given as input
268
- if (!reachability_result. at ( select_case). contains ( node)) {
283
+ if (!DoesReach (reachability_result, select_case, node)) {
269
284
return {};
270
285
}
271
286
@@ -408,8 +423,8 @@ ComputeReachabilityAnalysis(FunctionBase *f, OptimizationContext &context) {
408
423
// In other words, some mutually-exclusive pairs of instructions might not be
409
424
// detected by this analysis.
410
425
// Hence, this analysis can be improved in the future.
411
- absl::flat_hash_map<Node *,
412
- absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
426
+ absl::StatusOr<absl:: flat_hash_map<
427
+ Node *, absl::flat_hash_map<Node *, absl::flat_hash_set<Node *> >>>
413
428
ComputeMutualExclusionAnalysis (
414
429
FunctionBase *f, OptimizationContext &context,
415
430
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
@@ -418,14 +433,30 @@ ComputeMutualExclusionAnalysis(
418
433
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
419
434
mutual_exclusivity_relation;
420
435
436
+ // Run the post-dominance analysis.
437
+ //
438
+ // The result of this analysis is used to determine whether a given node @n is
439
+ // mutually exclusive with another node @m where both of them reach a select
440
+ // node @s.
441
+ //
442
+ // In more detail, the post-dominance analysis result is used to guarantee the
443
+ // following code will lead to conclude @n isn't mutually exclusive with @m:
444
+ // n = ...
445
+ // m = ...
446
+ // s = priority_sel(selector, cases=[n, m])
447
+ // i = add(n, s)
448
+ // return i
449
+ XLS_ASSIGN_OR_RETURN (std::unique_ptr<PostDominatorAnalysis> post_dominators,
450
+ PostDominatorAnalysis::Run (f));
451
+
421
452
// Compute the mutual exclusion binary relation between instructions
422
453
for (const auto &[n, s] : reachability_result) {
423
454
// Find the next select
424
455
if (!n->OpIn ({Op::kSel , Op::kPrioritySel , Op::kOneHotSel })) {
425
456
continue ;
426
457
}
427
458
428
- // Compute the mutually-exclusive instructions
459
+ // Compute the mutually-exclusive instructions created by the select @n
429
460
absl::Span<Node *const > cases = GetCases (n);
430
461
for (uint32_t case_number = 0 ; case_number < cases.length ();
431
462
case_number++) {
@@ -435,31 +466,60 @@ ComputeMutualExclusionAnalysis(
435
466
// are mutually exclusive with the nodes that reach the next cases
436
467
for (Node *current_case_reaching_node :
437
468
reachability_result[current_case]) {
469
+ // Do not bother looking at nodes that we will not be able to fold
470
+ if (!CanTarget (current_case_reaching_node)) {
471
+ continue ;
472
+ }
473
+
474
+ // Only nodes that are post-dominated by the select are considered to be
475
+ // mutually exclusive with another node
476
+ if (!post_dominators->NodeIsPostDominatedBy (current_case_reaching_node,
477
+ n)) {
478
+ continue ;
479
+ }
480
+
438
481
// Check if the current reaching node reaches the other cases
439
482
for (uint32_t case_number_2 = case_number + 1 ;
440
483
case_number_2 < cases.length (); case_number_2++) {
441
484
Node *current_case_2 = cases[case_number_2];
442
- if (reachability_result[ current_case_2]. contains (
443
- current_case_reaching_node)) {
485
+ if (DoesReach ( reachability_result, current_case_2,
486
+ current_case_reaching_node)) {
444
487
continue ;
445
488
}
446
489
447
- // The current reaching node does not reach the current other case.
490
+ // The current reaching node @current_case_reaching_node does not
491
+ // reach the other case @current_case_2.
492
+ //
448
493
// Add as mutually-exclusive all reaching nodes of the current other
449
- // case that also do not reach @current_case_reaching_node
494
+ // case @current_case_2 that also do not reach
495
+ // @current_case_reaching_node.
450
496
for (Node *other_case_reaching_node :
451
497
reachability_result[current_case_2]) {
452
- if (reachability_result[current_case].contains (
453
- other_case_reaching_node)) {
498
+ // Do not bother looking at nodes that we will not be able to fold
499
+ if (!CanTarget (other_case_reaching_node)) {
500
+ continue ;
501
+ }
502
+
503
+ // Only nodes that are post-dominated by the select are considered
504
+ // to be mutually exclusive with another node
505
+ if (!post_dominators->NodeIsPostDominatedBy (
506
+ other_case_reaching_node, n)) {
454
507
continue ;
455
508
}
456
- if (current_case_reaching_node < other_case_reaching_node) {
457
- mutual_exclusivity_relation[n][current_case_reaching_node]. insert (
458
- other_case_reaching_node);
459
- } else {
460
- mutual_exclusivity_relation[n][ other_case_reaching_node]. insert (
461
- current_case_reaching_node) ;
509
+
510
+ // If @other_case_reaching_node reaches @current_case, then it
511
+ // cannot be mutually exclusive with @current_case_reaching_node
512
+ if ( DoesReach (reachability_result, current_case,
513
+ other_case_reaching_node)) {
514
+ continue ;
462
515
}
516
+
517
+ // @current_case_reaching_node and @other_case_reaching_node are
518
+ // mutually exclusive.
519
+ mutual_exclusivity_relation[n][current_case_reaching_node].insert (
520
+ other_case_reaching_node);
521
+ mutual_exclusivity_relation[n][other_case_reaching_node].insert (
522
+ current_case_reaching_node);
463
523
}
464
524
}
465
525
}
@@ -473,7 +533,7 @@ ComputeMutualExclusionAnalysis(
473
533
for (const auto &[n0, s0] : mer.second ) {
474
534
VLOG (4 ) << " " << n0->ToString ();
475
535
for (auto n1 : s0) {
476
- VLOG (4 ) << " -> " << n1->ToString ();
536
+ VLOG (4 ) << " < -> " << n1->ToString ();
477
537
}
478
538
}
479
539
}
@@ -491,14 +551,25 @@ std::vector<std::unique_ptr<BinaryFoldingAction>> ComputeFoldableActions(
491
551
&reachability_result) {
492
552
std::vector<std::unique_ptr<BinaryFoldingAction>> foldable_actions;
493
553
494
- // Compute all possible foldable actions
554
+ // Identify as many folding actions that are legal as possible by our current
555
+ // analyses.
495
556
for (const auto &mer : mutual_exclusivity_relation) {
496
557
Node *select = mer.first ;
497
558
for (const auto &[n0, s0] : mer.second ) {
559
+ // Skip nodes that cannot be folded
498
560
if (!CanTarget (n0)) {
499
561
continue ;
500
562
}
501
- for (auto n1 : s0) {
563
+
564
+ // Find nodes that n0 can fold into
565
+ for (Node *n1 : s0) {
566
+ // Since the mutual exclusive relation is symmetric, use only one side
567
+ // of it
568
+ if (n1->id () < n0->id ()) {
569
+ continue ;
570
+ }
571
+
572
+ // Skip nodes that cannot be folded
502
573
if (!CanTarget (n1)) {
503
574
continue ;
504
575
}
@@ -647,6 +718,7 @@ absl::StatusOr<bool> PerformFoldingActions(
647
718
// Fold
648
719
//
649
720
// - Step 0: Get the subset of the bits of the selector that are relevant
721
+ VLOG (3 ) << " Step 0: generate the new selector" ;
650
722
Node *selector = folding->GetSelector ();
651
723
std::vector<Node *> from_bits;
652
724
from_bits.reserve (froms_to_use.size ());
@@ -659,8 +731,11 @@ absl::StatusOr<bool> PerformFoldingActions(
659
731
absl::c_reverse (from_bits);
660
732
XLS_ASSIGN_OR_RETURN (Node * new_selector,
661
733
f->MakeNode <Concat>(selector->loc (), from_bits));
734
+ VLOG (3 ) << " " << new_selector->ToString ();
662
735
663
736
// - Step 1: Create a new select for each input
737
+ VLOG (3 ) << " Step 1: generate the priority selects, one per input of "
738
+ " the folding target" ;
664
739
std::vector<Node *> new_operands;
665
740
for (uint32_t op_id = 0 ; op_id < to_node->operand_count (); op_id++) {
666
741
// Fetch the current operand for the target of the folding action.
@@ -679,27 +754,34 @@ absl::StatusOr<bool> PerformFoldingActions(
679
754
f->MakeNode <PrioritySelect>(selector->loc (), new_selector,
680
755
operand_select_cases, to_operand));
681
756
new_operands.push_back (operand_select);
757
+ VLOG (3 ) << " " << operand_select->ToString ();
682
758
}
683
759
CHECK_EQ (new_operands.size (), 2 );
684
760
685
761
// - Step 2: Replace the operands of the @to_node to use the results of the
686
762
// new selectors computed at Step 1.
687
- for (uint32_t op_id = 0 ; op_id < to_node->operand_count (); op_id++) {
763
+ VLOG (3 ) << " Step 2: update the target of the folding transformation" ;
764
+ for (int64_t op_id = 0LL ; op_id < to_node->operand_count (); op_id++) {
688
765
XLS_RETURN_IF_ERROR (
689
766
to_node->ReplaceOperandNumber (op_id, new_operands[op_id], true ));
690
767
}
768
+ VLOG (3 ) << " " << to_node->ToString ();
691
769
692
770
// - Step 3: Replace every source of the folding action with the new
693
771
// @to_node
772
+ VLOG (3 )
773
+ << " Step 3: update the def-use chains to use the new folded node" ;
694
774
for (auto [from_node, from_node_case_number] : froms_to_use) {
695
775
XLS_RETURN_IF_ERROR (from_node->ReplaceUsesWith (to_node));
696
776
}
697
777
698
778
// - Step 4: Remove all the sources of the folding action as they are now
699
779
// dead
780
+ VLOG (3 ) << " Step 4: remove the sources of the folding transformation" ;
700
781
for (auto [from_node, from_node_case_number] : froms_to_use) {
701
782
XLS_RETURN_IF_ERROR (f->RemoveNode (from_node));
702
783
}
784
+ VLOG (3 ) << " Folding completed" ;
703
785
}
704
786
705
787
return modified;
@@ -728,8 +810,10 @@ absl::StatusOr<bool> ResourceSharingPass::RunOnFunctionBaseInternal(
728
810
// Compute the mutually exclusive binary relation between IR instructions
729
811
absl::flat_hash_map<Node *,
730
812
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
731
- mutual_exclusivity_relation =
732
- ComputeMutualExclusionAnalysis (f, context, reachability_result);
813
+ mutual_exclusivity_relation;
814
+ XLS_ASSIGN_OR_RETURN (
815
+ mutual_exclusivity_relation,
816
+ ComputeMutualExclusionAnalysis (f, context, reachability_result));
733
817
734
818
// Identify the set of legal folding actions
735
819
std::vector<std::unique_ptr<BinaryFoldingAction>> foldable_actions =
0 commit comments