Skip to content

Commit b942714

Browse files
meheffernancopybara-github
authored andcommitted
Improve metrics reporting of the optimization pipeline.
This includes a number of fixes/changes: * Gather metrics in un-aggregated form, one proto for each time a pass is run. The data is organized hierarchically mirroring the hiearchical structure of the pass pipeline. * Add a utility which dumps a table from the proto metrics data. The utility produces two tables: one aggregated by pass name, and one hierarchically organized. * Fix bug where the invariant checkers were not being run for some passes, specifically those within CapOptLevel or IfOptLevelAtLeast compound passes. As part of this fix clean up the pass_base code a bit. Make RunNested a method on PassBase which makes CompoundPassBase and PassBase a bit more interchangeable. * Add numerous test for various bits of PassBase functionality which were previously untested. PiperOrigin-RevId: 755096116
1 parent ccaa9aa commit b942714

23 files changed

+1108
-281
lines changed

xls/dev_tools/BUILD

+31
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,7 @@ cc_binary(
613613
srcs = ["benchmark_main.cc"],
614614
visibility = ["//xls:xls_users"],
615615
deps = [
616+
":pipeline_metrics",
616617
"//xls/codegen:module_signature",
617618
"//xls/common:exit_status",
618619
"//xls/common:init_xls",
@@ -838,6 +839,21 @@ cc_library(
838839
],
839840
)
840841

842+
cc_library(
843+
name = "pipeline_metrics",
844+
srcs = ["pipeline_metrics.cc"],
845+
hdrs = ["pipeline_metrics.h"],
846+
deps = [
847+
"//xls/ir",
848+
"//xls/passes:pass_metrics_cc_proto",
849+
"@com_google_absl//absl/container:flat_hash_map",
850+
"@com_google_absl//absl/strings",
851+
"@com_google_absl//absl/strings:str_format",
852+
"@com_google_absl//absl/time",
853+
"@com_google_protobuf//:duration_cc_proto",
854+
],
855+
)
856+
841857
cc_test(
842858
name = "extract_state_element_test",
843859
srcs = ["extract_state_element_test.cc"],
@@ -877,3 +893,18 @@ cc_binary(
877893
"@com_google_absl//absl/types:span",
878894
],
879895
)
896+
897+
cc_binary(
898+
name = "pipeline_metrics_main",
899+
srcs = ["pipeline_metrics_main.cc"],
900+
deps = [
901+
":pipeline_metrics",
902+
"//xls/common:exit_status",
903+
"//xls/common:init_xls",
904+
"//xls/common/file:filesystem",
905+
"//xls/common/status:status_macros",
906+
"//xls/passes:pass_metrics_cc_proto",
907+
"@com_google_absl//absl/log:check",
908+
"@com_google_absl//absl/status",
909+
],
910+
)

xls/dev_tools/benchmark_main.cc

+6-51
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "xls/common/status/ret_check.h"
5050
#include "xls/common/status/status_macros.h"
5151
#include "xls/data_structures/binary_decision_diagram.h"
52+
#include "xls/dev_tools/pipeline_metrics.h"
5253
#include "xls/estimators/delay_model/analyze_critical_path.h"
5354
#include "xls/estimators/delay_model/delay_estimator.h"
5455
#include "xls/estimators/delay_model/delay_estimators.h"
@@ -183,6 +184,7 @@ absl::Status RunOptimizationAndPrintStats(Package* package) {
183184
OptimizationPassOptions pass_options;
184185
int64_t convert_array_index_to_select =
185186
absl::GetFlag(FLAGS_convert_array_index_to_select);
187+
pass_options.record_metrics = true;
186188
pass_options.convert_array_index_to_select =
187189
(convert_array_index_to_select < 0)
188190
? std::nullopt
@@ -203,57 +205,10 @@ absl::Status RunOptimizationAndPrintStats(Package* package) {
203205
std::cout << absl::StreamFormat("Optimization time: %dms\n",
204206
DurationToMs(total_time));
205207
std::cout << absl::StreamFormat("Dynamic pass count: %d\n",
206-
pass_results.invocations.size());
207-
208-
// Aggregate run times by the pass name and print a table of the aggregate
209-
// execution time of each pass in descending order.
210-
absl::flat_hash_map<std::string, absl::Duration> pass_times;
211-
absl::flat_hash_map<std::string, int64_t> pass_counts;
212-
absl::flat_hash_map<std::string, int64_t> changed_counts;
213-
for (const PassInvocation& invocation : pass_results.invocations) {
214-
pass_times[invocation.pass_name] += invocation.run_duration;
215-
++pass_counts[invocation.pass_name];
216-
changed_counts[invocation.pass_name] += invocation.ir_changed ? 1 : 0;
217-
}
218-
std::vector<std::string> pass_names;
219-
for (const auto& pair : pass_times) {
220-
pass_names.push_back(pair.first);
221-
}
222-
std::sort(pass_names.begin(), pass_names.end(),
223-
[&](const std::string& a, const std::string& b) {
224-
// Sort by time (at the same resolution we show), breaking ties by
225-
// # of times run, # of times changed, and finally pass name.
226-
int64_t a_time = DurationToMs(pass_times.at(a));
227-
int64_t b_time = DurationToMs(pass_times.at(b));
228-
if (a_time > b_time) {
229-
return true;
230-
}
231-
if (a_time < b_time) {
232-
return false;
233-
}
234-
if (pass_counts.at(a) > pass_counts.at(b)) {
235-
return true;
236-
}
237-
if (pass_counts.at(a) < pass_counts.at(b)) {
238-
return false;
239-
}
240-
if (changed_counts.at(a) > changed_counts.at(b)) {
241-
return true;
242-
}
243-
if (changed_counts.at(a) < changed_counts.at(b)) {
244-
return false;
245-
}
246-
return a > b;
247-
});
248-
std::cout << "Pass run durations (# of times pass changed IR / # of times "
249-
"pass was run):"
250-
<< '\n';
251-
for (const std::string& name : pass_names) {
252-
std::cout << absl::StreamFormat(" %-20s : %-5dms (%3d / %3d)\n", name,
253-
DurationToMs(pass_times.at(name)),
254-
changed_counts.at(name),
255-
pass_counts.at(name));
256-
}
208+
pass_results.total_invocations);
209+
210+
// Print table(s) of pass metrics.
211+
std::cout << SummarizePipelineMetrics(pass_results.ToProto());
257212
return absl::OkStatus();
258213
}
259214

xls/dev_tools/check_ir_equivalence_main.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,13 @@ class ProcStateLegalizationPassShim : public OptimizationFunctionBasePass {
179179
SchedulingUnit sched = SchedulingUnit::CreateForSingleFunction(fb);
180180
SchedulingPassResults results;
181181
if (pass_results) {
182-
results.invocations = std::move(pass_results->invocations);
182+
results.invocation = std::move(pass_results->invocation);
183183
}
184184
XLS_ASSIGN_OR_RETURN(bool res,
185185
proc_state_sched_pass_.RunOnFunctionBase(
186186
fb, &sched, SchedulingPassOptions(), &results));
187187
if (pass_results) {
188-
pass_results->invocations = std::move(results.invocations);
188+
pass_results->invocation = std::move(results.invocation);
189189
}
190190
return res;
191191
}

xls/dev_tools/pipeline_metrics.cc

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
// Copyright 2025 The XLS Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "xls/dev_tools/pipeline_metrics.h"
16+
17+
#include <algorithm>
18+
#include <cstdint>
19+
#include <numeric>
20+
#include <optional>
21+
#include <string>
22+
#include <string_view>
23+
#include <tuple>
24+
#include <utility>
25+
#include <vector>
26+
27+
#include "google/protobuf/duration.pb.h"
28+
#include "absl/container/flat_hash_map.h"
29+
#include "absl/strings/str_cat.h"
30+
#include "absl/strings/str_format.h"
31+
#include "absl/strings/str_join.h"
32+
#include "absl/time/time.h"
33+
#include "xls/ir/package.h"
34+
#include "xls/passes/pass_metrics.pb.h"
35+
36+
namespace xls {
37+
namespace {
38+
39+
int64_t DurationToMs(absl::Duration duration) {
40+
return duration / absl::Milliseconds(1);
41+
}
42+
43+
// Struct holding the aggregation of multiple PassResultProtos.
44+
struct AggregateMetrics {
45+
std::string pass_name;
46+
int64_t run_count = 0;
47+
int64_t changed_count = 0;
48+
absl::Duration run_duration;
49+
TransformMetrics metrics;
50+
51+
AggregateMetrics operator+(const AggregateMetrics& other) const {
52+
AggregateMetrics result;
53+
result.pass_name = (pass_name.empty() || pass_name == other.pass_name)
54+
? other.pass_name
55+
: "";
56+
result.run_count = run_count + other.run_count;
57+
result.changed_count = changed_count + other.changed_count;
58+
result.run_duration = run_duration + other.run_duration;
59+
result.metrics = metrics + other.metrics;
60+
return result;
61+
}
62+
63+
static AggregateMetrics FromProto(const PassResultProto& proto) {
64+
AggregateMetrics metrics;
65+
metrics.pass_name = proto.pass_name();
66+
metrics.run_count = 1;
67+
metrics.changed_count = proto.changed() ? 1 : 0;
68+
metrics.run_duration = absl::Seconds(proto.pass_duration().seconds()) +
69+
absl::Nanoseconds(proto.pass_duration().nanos());
70+
metrics.metrics = TransformMetrics::FromProto(proto.metrics());
71+
return metrics;
72+
}
73+
};
74+
75+
void AggregatePassResultsInternal(
76+
const PassResultProto& proto,
77+
absl::flat_hash_map<std::string, AggregateMetrics>& metrics_map) {
78+
if (proto.nested_results().empty()) {
79+
// Non-compound pass.
80+
AggregateMetrics& metrics = metrics_map[proto.pass_name()];
81+
metrics = metrics + AggregateMetrics::FromProto(proto);
82+
} else {
83+
for (const PassResultProto& nested_proto : proto.nested_results()) {
84+
AggregatePassResultsInternal(nested_proto, metrics_map);
85+
}
86+
}
87+
}
88+
89+
// Recursively walk the pass results within `proto` and aggregate the metrics by
90+
// pass name. Returns a vector sorted (decreasing) by run time.
91+
std::vector<AggregateMetrics> AggregatePassResults(
92+
const PassResultProto& proto) {
93+
absl::flat_hash_map<std::string, AggregateMetrics> metrics_map;
94+
AggregatePassResultsInternal(proto, metrics_map);
95+
std::vector<AggregateMetrics> metrics;
96+
for (auto& [_, m] : metrics_map) {
97+
metrics.push_back(m);
98+
}
99+
std::sort(metrics.begin(), metrics.end(),
100+
[&](const AggregateMetrics& a, const AggregateMetrics& b) {
101+
// Sort by time (at the same resolution we show), breaking ties by
102+
// # of times run, # of times changed, and finally pass name.
103+
auto key = [](const AggregateMetrics& x) {
104+
return std::tuple(DurationToMs(x.run_duration), x.run_count,
105+
x.changed_count, x.pass_name);
106+
};
107+
// Sort high to low.
108+
return key(a) > key(b);
109+
});
110+
return metrics;
111+
}
112+
113+
void BuildHierarchicalTableInternal(
114+
const PassResultProto& proto, int64_t indent,
115+
std::vector<std::string>& lines,
116+
std::optional<AggregateMetrics>& collapsed_summary_metrics) {
117+
std::string indent_str(indent * 2, ' ');
118+
if (proto.nested_results().empty()) {
119+
// Collapse sequences of non-compound passes into a single line.
120+
if (!collapsed_summary_metrics.has_value()) {
121+
collapsed_summary_metrics = AggregateMetrics::FromProto(proto);
122+
} else {
123+
*collapsed_summary_metrics =
124+
*collapsed_summary_metrics + AggregateMetrics::FromProto(proto);
125+
}
126+
return;
127+
}
128+
129+
auto add_line = [&](std::string_view pass_name,
130+
const AggregateMetrics& metrics) {
131+
lines.push_back(absl::StrFormat(
132+
"%-55s %6dms %4d/%4d %8d(+)/%8d(-)/%8d(R) "
133+
" %8d(-)/%8d(R)",
134+
indent_str + std::string{pass_name}, DurationToMs(metrics.run_duration),
135+
metrics.changed_count, metrics.run_count, metrics.metrics.nodes_added,
136+
metrics.metrics.nodes_removed, metrics.metrics.nodes_replaced,
137+
metrics.metrics.operands_removed, metrics.metrics.operands_replaced));
138+
};
139+
140+
auto maybe_add_summary_line = [&](bool extra_indent) {
141+
if (collapsed_summary_metrics.has_value()) {
142+
add_line(absl::StrFormat("%s[%d passes run]", extra_indent ? " " : "",
143+
collapsed_summary_metrics->run_count),
144+
*collapsed_summary_metrics);
145+
collapsed_summary_metrics.reset();
146+
}
147+
};
148+
149+
maybe_add_summary_line(false);
150+
151+
std::vector<std::pair<int64_t, int64_t>> intervals;
152+
if (proto.fixed_point_iterations() > 0) {
153+
// Fixed-point pass. Break the nested results into iterations.
154+
int64_t end = 0;
155+
int64_t pass_count =
156+
proto.nested_results().size() / proto.fixed_point_iterations();
157+
while (end < proto.nested_results().size()) {
158+
int64_t next_end =
159+
std::min(int64_t{proto.nested_results().size()}, end + pass_count);
160+
intervals.push_back({end, next_end});
161+
end = next_end;
162+
}
163+
} else {
164+
// Non-fixed point pass. Aggregate all nested results together.
165+
intervals.push_back({0, proto.nested_results().size()});
166+
}
167+
168+
int64_t iteration = 0;
169+
for (auto [start, end] : intervals) {
170+
AggregateMetrics interval_metrics;
171+
for (int64_t i = start; i < end; ++i) {
172+
interval_metrics = interval_metrics +
173+
AggregateMetrics::FromProto(proto.nested_results()[i]);
174+
}
175+
std::string pass_name =
176+
proto.fixed_point_iterations() > 0
177+
? absl::StrFormat("%s [iter #%d]", proto.pass_name(), iteration)
178+
: proto.pass_name();
179+
add_line(pass_name, interval_metrics);
180+
for (int64_t i = start; i < end; ++i) {
181+
const PassResultProto& nested_proto = proto.nested_results()[i];
182+
BuildHierarchicalTableInternal(nested_proto, indent + 1, lines,
183+
collapsed_summary_metrics);
184+
}
185+
186+
maybe_add_summary_line(true);
187+
188+
++iteration;
189+
}
190+
}
191+
192+
// Returns the lines of a table which mirrors the hierarchical structure of the
193+
// (compound) passes which generated the metrics in `proto`.
194+
std::string BuildHierarchicalTable(const PassResultProto& proto) {
195+
std::vector<std::string> lines;
196+
std::optional<AggregateMetrics> collapsed_summary_metrics;
197+
BuildHierarchicalTableInternal(proto, 0, lines, collapsed_summary_metrics);
198+
return absl::StrCat(absl::StrJoin(lines, "\n"), "\n");
199+
}
200+
201+
} // namespace
202+
203+
std::string SummarizePipelineMetrics(const PipelineMetricsProto& metrics) {
204+
// The metrics object is recursive. Aggregate the results by pass name.
205+
std::vector<AggregateMetrics> aggregate_metrics =
206+
AggregatePassResults(metrics.pass_results());
207+
std::string str = "Aggregate pass statistics:\n\n";
208+
absl::StrAppendFormat(
209+
&str,
210+
"%-30s Duration Runs: changed/total Nodes "
211+
"added(+)/removed(-)/replaced(R) Operands removed(-)/replaced(R)\n",
212+
"Pass name");
213+
std::string divider = std::string(135, '-') + "\n";
214+
absl::StrAppend(&str, divider);
215+
auto make_line = [](const AggregateMetrics& metric) {
216+
return absl::StrFormat(
217+
" %-30s %6dms %4d/%4d %8d(+)/%8d(-)/%8d(R) "
218+
" %8d(-)/%8d(R)\n",
219+
metric.pass_name, DurationToMs(metric.run_duration),
220+
metric.changed_count, metric.run_count, metric.metrics.nodes_added,
221+
metric.metrics.nodes_removed, metric.metrics.nodes_replaced,
222+
metric.metrics.operands_removed, metric.metrics.operands_replaced);
223+
};
224+
for (const AggregateMetrics& metric : aggregate_metrics) {
225+
absl::StrAppend(&str, make_line(metric));
226+
}
227+
absl::StrAppend(&str, divider);
228+
AggregateMetrics total = std::accumulate(
229+
aggregate_metrics.begin(), aggregate_metrics.end(), AggregateMetrics());
230+
total.pass_name = "Total";
231+
absl::StrAppend(&str, make_line(total));
232+
233+
absl::StrAppend(&str, "\n\nHierarchical pass statistics:\n\n");
234+
absl::StrAppendFormat(
235+
&str,
236+
"%-55s Duration Runs: changed/total Nodes "
237+
"added(+)/removed(-)/replaced(R) Operands removed(-)/replaced(R)\n",
238+
"Pass name");
239+
absl::StrAppend(&str, std::string(161, '-') + "\n");
240+
absl::StrAppend(&str, BuildHierarchicalTable(metrics.pass_results()));
241+
return str;
242+
}
243+
244+
} // namespace xls

0 commit comments

Comments
 (0)