Skip to content

Commit 2e60753

Browse files
scampanonicopybara-github
authored andcommitted
This commit extends the resource sharing pass to correctly handle multiplications used (directly or indirectly via def-use chains) across different select cases and when at least one of these multiplications is not post-dominated by the select.
In such scenarios, these multiplications are now correctly identified as not mutually exclusive. A new test is added to automatically verify this code pattern. Additionally, this commit includes minor code improvements: extracting the node reachability check into a common function, adding comments, and enhancing log outputs. PiperOrigin-RevId: 749563145
1 parent 1227aa1 commit 2e60753

File tree

3 files changed

+240
-120
lines changed

3 files changed

+240
-120
lines changed

xls/passes/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,7 @@ cc_library(
974974
":optimization_pass",
975975
":optimization_pass_registry",
976976
":pass_base",
977+
":post_dominator_analysis",
977978
"//xls/common/status:status_macros",
978979
"//xls/ir",
979980
"//xls/ir:op",
@@ -982,6 +983,7 @@ cc_library(
982983
"@com_google_absl//absl/container:flat_hash_set",
983984
"@com_google_absl//absl/log",
984985
"@com_google_absl//absl/log:check",
986+
"@com_google_absl//absl/status",
985987
"@com_google_absl//absl/status:statusor",
986988
"@com_google_absl//absl/types:span",
987989
"@com_google_ortools//ortools/graph",

xls/passes/resource_sharing_pass.cc

+106-22
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "xls/passes/optimization_pass.h"
4141
#include "xls/passes/optimization_pass_registry.h"
4242
#include "xls/passes/pass_base.h"
43+
#include "xls/passes/post_dominator_analysis.h"
4344
#include "ortools/graph/graph.h"
4445
#include "ortools/graph/cliques.h"
4546

@@ -260,12 +261,26 @@ bool AreMutuallyExclusive(
260261
return false;
261262
}
262263

264+
// This function returns true if @node_to_check reaches @point_in_the_graph,
265+
// false otherwise.
266+
bool DoesReach(const absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
267+
&reachability_result,
268+
Node *point_in_the_graph, Node *node_to_check) {
269+
// Fetch the set of nodes that reach @point_in_the_graph
270+
auto iter = reachability_result.find(point_in_the_graph);
271+
CHECK(iter != reachability_result.end());
272+
const absl::flat_hash_set<Node *> &reaching_nodes = iter->second;
273+
274+
// Check if the specified node reaches @point_in_the_graph
275+
return reaching_nodes.contains(node_to_check);
276+
}
277+
263278
std::optional<uint32_t> GetSelectCaseNumberOfNode(
264279
const absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
265280
&reachability_result,
266281
Node *node, Node *select_case, uint32_t select_case_number) {
267282
// Check if @node reaches the select case given as input
268-
if (!reachability_result.at(select_case).contains(node)) {
283+
if (!DoesReach(reachability_result, select_case, node)) {
269284
return {};
270285
}
271286

@@ -408,8 +423,8 @@ ComputeReachabilityAnalysis(FunctionBase *f, OptimizationContext &context) {
408423
// In other words, some mutually-exclusive pairs of instructions might not be
409424
// detected by this analysis.
410425
// Hence, this analysis can be improved in the future.
411-
absl::flat_hash_map<Node *,
412-
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
426+
absl::StatusOr<absl::flat_hash_map<
427+
Node *, absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>>
413428
ComputeMutualExclusionAnalysis(
414429
FunctionBase *f, OptimizationContext &context,
415430
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>
@@ -418,14 +433,30 @@ ComputeMutualExclusionAnalysis(
418433
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
419434
mutual_exclusivity_relation;
420435

436+
// Run the post-dominance analysis.
437+
//
438+
// The result of this analysis is used to determine whether a given node @n is
439+
// mutually exclusive with another node @m where both of them reach a select
440+
// node @s.
441+
//
442+
// In more detail, the post-dominance analysis result is used to guarantee the
443+
// following code will lead to conclude @n isn't mutually exclusive with @m:
444+
// n = ...
445+
// m = ...
446+
// s = priority_sel(selector, cases=[n, m])
447+
// i = add(n, s)
448+
// return i
449+
XLS_ASSIGN_OR_RETURN(std::unique_ptr<PostDominatorAnalysis> post_dominators,
450+
PostDominatorAnalysis::Run(f));
451+
421452
// Compute the mutual exclusion binary relation between instructions
422453
for (const auto &[n, s] : reachability_result) {
423454
// Find the next select
424455
if (!n->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) {
425456
continue;
426457
}
427458

428-
// Compute the mutually-exclusive instructions
459+
// Compute the mutually-exclusive instructions created by the select @n
429460
absl::Span<Node *const> cases = GetCases(n);
430461
for (uint32_t case_number = 0; case_number < cases.length();
431462
case_number++) {
@@ -435,31 +466,60 @@ ComputeMutualExclusionAnalysis(
435466
// are mutually exclusive with the nodes that reach the next cases
436467
for (Node *current_case_reaching_node :
437468
reachability_result[current_case]) {
469+
// Do not bother looking at nodes that we will not be able to fold
470+
if (!CanTarget(current_case_reaching_node)) {
471+
continue;
472+
}
473+
474+
// Only nodes that are post-dominated by the select are considered to be
475+
// mutually exclusive with another node
476+
if (!post_dominators->NodeIsPostDominatedBy(current_case_reaching_node,
477+
n)) {
478+
continue;
479+
}
480+
438481
// Check if the current reaching node reaches the other cases
439482
for (uint32_t case_number_2 = case_number + 1;
440483
case_number_2 < cases.length(); case_number_2++) {
441484
Node *current_case_2 = cases[case_number_2];
442-
if (reachability_result[current_case_2].contains(
443-
current_case_reaching_node)) {
485+
if (DoesReach(reachability_result, current_case_2,
486+
current_case_reaching_node)) {
444487
continue;
445488
}
446489

447-
// The current reaching node does not reach the current other case.
490+
// The current reaching node @current_case_reaching_node does not
491+
// reach the other case @current_case_2.
492+
//
448493
// Add as mutually-exclusive all reaching nodes of the current other
449-
// case that also do not reach @current_case_reaching_node
494+
// case @current_case_2 that also do not reach
495+
// @current_case_reaching_node.
450496
for (Node *other_case_reaching_node :
451497
reachability_result[current_case_2]) {
452-
if (reachability_result[current_case].contains(
453-
other_case_reaching_node)) {
498+
// Do not bother looking at nodes that we will not be able to fold
499+
if (!CanTarget(other_case_reaching_node)) {
500+
continue;
501+
}
502+
503+
// Only nodes that are post-dominated by the select are considered
504+
// to be mutually exclusive with another node
505+
if (!post_dominators->NodeIsPostDominatedBy(
506+
other_case_reaching_node, n)) {
454507
continue;
455508
}
456-
if (current_case_reaching_node < other_case_reaching_node) {
457-
mutual_exclusivity_relation[n][current_case_reaching_node].insert(
458-
other_case_reaching_node);
459-
} else {
460-
mutual_exclusivity_relation[n][other_case_reaching_node].insert(
461-
current_case_reaching_node);
509+
510+
// If @other_case_reaching_node reaches @current_case, then it
511+
// cannot be mutually exclusive with @current_case_reaching_node
512+
if (DoesReach(reachability_result, current_case,
513+
other_case_reaching_node)) {
514+
continue;
462515
}
516+
517+
// @current_case_reaching_node and @other_case_reaching_node are
518+
// mutually exclusive.
519+
mutual_exclusivity_relation[n][current_case_reaching_node].insert(
520+
other_case_reaching_node);
521+
mutual_exclusivity_relation[n][other_case_reaching_node].insert(
522+
current_case_reaching_node);
463523
}
464524
}
465525
}
@@ -473,7 +533,7 @@ ComputeMutualExclusionAnalysis(
473533
for (const auto &[n0, s0] : mer.second) {
474534
VLOG(4) << " " << n0->ToString();
475535
for (auto n1 : s0) {
476-
VLOG(4) << " -> " << n1->ToString();
536+
VLOG(4) << " <-> " << n1->ToString();
477537
}
478538
}
479539
}
@@ -491,14 +551,25 @@ std::vector<std::unique_ptr<BinaryFoldingAction>> ComputeFoldableActions(
491551
&reachability_result) {
492552
std::vector<std::unique_ptr<BinaryFoldingAction>> foldable_actions;
493553

494-
// Compute all possible foldable actions
554+
// Identify as many folding actions that are legal as possible by our current
555+
// analyses.
495556
for (const auto &mer : mutual_exclusivity_relation) {
496557
Node *select = mer.first;
497558
for (const auto &[n0, s0] : mer.second) {
559+
// Skip nodes that cannot be folded
498560
if (!CanTarget(n0)) {
499561
continue;
500562
}
501-
for (auto n1 : s0) {
563+
564+
// Find nodes that n0 can fold into
565+
for (Node *n1 : s0) {
566+
// Since the mutual exclusive relation is symmetric, use only one side
567+
// of it
568+
if (n1->id() < n0->id()) {
569+
continue;
570+
}
571+
572+
// Skip nodes that cannot be folded
502573
if (!CanTarget(n1)) {
503574
continue;
504575
}
@@ -647,6 +718,7 @@ absl::StatusOr<bool> PerformFoldingActions(
647718
// Fold
648719
//
649720
// - Step 0: Get the subset of the bits of the selector that are relevant
721+
VLOG(3) << " Step 0: generate the new selector";
650722
Node *selector = folding->GetSelector();
651723
std::vector<Node *> from_bits;
652724
from_bits.reserve(froms_to_use.size());
@@ -659,8 +731,11 @@ absl::StatusOr<bool> PerformFoldingActions(
659731
absl::c_reverse(from_bits);
660732
XLS_ASSIGN_OR_RETURN(Node * new_selector,
661733
f->MakeNode<Concat>(selector->loc(), from_bits));
734+
VLOG(3) << " " << new_selector->ToString();
662735

663736
// - Step 1: Create a new select for each input
737+
VLOG(3) << " Step 1: generate the priority selects, one per input of "
738+
"the folding target";
664739
std::vector<Node *> new_operands;
665740
for (uint32_t op_id = 0; op_id < to_node->operand_count(); op_id++) {
666741
// Fetch the current operand for the target of the folding action.
@@ -679,27 +754,34 @@ absl::StatusOr<bool> PerformFoldingActions(
679754
f->MakeNode<PrioritySelect>(selector->loc(), new_selector,
680755
operand_select_cases, to_operand));
681756
new_operands.push_back(operand_select);
757+
VLOG(3) << " " << operand_select->ToString();
682758
}
683759
CHECK_EQ(new_operands.size(), 2);
684760

685761
// - Step 2: Replace the operands of the @to_node to use the results of the
686762
// new selectors computed at Step 1.
687-
for (uint32_t op_id = 0; op_id < to_node->operand_count(); op_id++) {
763+
VLOG(3) << " Step 2: update the target of the folding transformation";
764+
for (int64_t op_id = 0LL; op_id < to_node->operand_count(); op_id++) {
688765
XLS_RETURN_IF_ERROR(
689766
to_node->ReplaceOperandNumber(op_id, new_operands[op_id], true));
690767
}
768+
VLOG(3) << " " << to_node->ToString();
691769

692770
// - Step 3: Replace every source of the folding action with the new
693771
// @to_node
772+
VLOG(3)
773+
<< " Step 3: update the def-use chains to use the new folded node";
694774
for (auto [from_node, from_node_case_number] : froms_to_use) {
695775
XLS_RETURN_IF_ERROR(from_node->ReplaceUsesWith(to_node));
696776
}
697777

698778
// - Step 4: Remove all the sources of the folding action as they are now
699779
// dead
780+
VLOG(3) << " Step 4: remove the sources of the folding transformation";
700781
for (auto [from_node, from_node_case_number] : froms_to_use) {
701782
XLS_RETURN_IF_ERROR(f->RemoveNode(from_node));
702783
}
784+
VLOG(3) << " Folding completed";
703785
}
704786

705787
return modified;
@@ -728,8 +810,10 @@ absl::StatusOr<bool> ResourceSharingPass::RunOnFunctionBaseInternal(
728810
// Compute the mutually exclusive binary relation between IR instructions
729811
absl::flat_hash_map<Node *,
730812
absl::flat_hash_map<Node *, absl::flat_hash_set<Node *>>>
731-
mutual_exclusivity_relation =
732-
ComputeMutualExclusionAnalysis(f, context, reachability_result);
813+
mutual_exclusivity_relation;
814+
XLS_ASSIGN_OR_RETURN(
815+
mutual_exclusivity_relation,
816+
ComputeMutualExclusionAnalysis(f, context, reachability_result));
733817

734818
// Identify the set of legal folding actions
735819
std::vector<std::unique_ptr<BinaryFoldingAction>> foldable_actions =

0 commit comments

Comments
 (0)