Skip to content

Commit debaa02

Browse files
richmckeevercopybara-github
authored andcommitted
Fix the speed of the slow builtin function unit tests in TIv2.
Some of the tests exercise a flaw in unification caching where an InferImplicitParametrics call may repeatedly resolve the same heavy type variable within the scope of one resolution call that has a filter. The general caching logic will refuse to capture the result that was influenced by the filter. This change makes it so that such globally useless results can be retained within the scope of one overarching resolution request, thereby getting the builtin test suite from ~23 sec to ~3 sec. PiperOrigin-RevId: 756087705
1 parent bcc33e5 commit debaa02

18 files changed

+454
-302
lines changed

xls/dslx/frontend/ast.cc

+16
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,22 @@ std::string ElementTypeAnnotation::ToString() const {
11011101
: container_type_->ToString());
11021102
}
11031103

1104+
// -- class SliceTypeAnnotation
1105+
1106+
SliceTypeAnnotation::SliceTypeAnnotation(
1107+
Module* owner, Span span, TypeAnnotation* source_type,
1108+
std::variant<Slice*, WidthSlice*> slice)
1109+
: TypeAnnotation(owner, span), source_type_(source_type), slice_(slice) {}
1110+
1111+
std::vector<AstNode*> SliceTypeAnnotation::GetChildren(bool want_types) const {
1112+
return std::vector<AstNode*>{source_type_, ToAstNode(slice_)};
1113+
}
1114+
1115+
std::string SliceTypeAnnotation::ToString() const {
1116+
return absl::Substitute("Slice ($0, $1)", source_type_->ToString(),
1117+
ToAstNode(slice_)->ToString());
1118+
}
1119+
11041120
// -- class FunctionTypeAnnotation
11051121

11061122
FunctionTypeAnnotation::FunctionTypeAnnotation(

xls/dslx/frontend/ast.h

+29
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
X(BuiltinTypeAnnotation) \
126126
X(ChannelTypeAnnotation) \
127127
X(ElementTypeAnnotation) \
128+
X(SliceTypeAnnotation) \
128129
X(FunctionTypeAnnotation) \
129130
X(MemberTypeAnnotation) \
130131
X(ParamTypeAnnotation) \
@@ -547,6 +548,34 @@ class ElementTypeAnnotation : public TypeAnnotation {
547548
const bool allow_bit_vector_destructuring_;
548549
};
549550

551+
// An indirect type annotation for a slice, expressed in terms of the type of
552+
// the source entity and the slice node. This is used only within type
553+
// inference.
554+
class SliceTypeAnnotation : public TypeAnnotation {
555+
public:
556+
SliceTypeAnnotation(Module* owner, Span span, TypeAnnotation* source_type,
557+
std::variant<Slice*, WidthSlice*> slice);
558+
559+
absl::Status Accept(AstNodeVisitor* v) const override {
560+
return v->HandleSliceTypeAnnotation(this);
561+
}
562+
563+
std::string_view GetNodeTypeName() const override {
564+
return "SliceTypeAnnotation";
565+
}
566+
567+
TypeAnnotation* source_type() const { return source_type_; }
568+
std::variant<Slice*, WidthSlice*> slice() const { return slice_; }
569+
570+
std::vector<AstNode*> GetChildren(bool want_types) const override;
571+
572+
std::string ToString() const override;
573+
574+
private:
575+
TypeAnnotation* source_type_;
576+
std::variant<Slice*, WidthSlice*> slice_;
577+
};
578+
550579
// Represents a function signature with a return type and parameter types. The
551580
// signature elements are all non-nullable; a function with no return should use
552581
// a unit tuple type annotation for the return type.

xls/dslx/frontend/ast_cloner.cc

+17
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,23 @@ class AstCloner : public AstNodeVisitor {
10331033
return absl::OkStatus();
10341034
}
10351035

1036+
absl::Status HandleSliceTypeAnnotation(
1037+
const SliceTypeAnnotation* n) override {
1038+
XLS_RETURN_IF_ERROR(ReplaceOrVisit(n->source_type()));
1039+
XLS_RETURN_IF_ERROR(ReplaceOrVisit(ToAstNode(n->slice())));
1040+
AstNode* new_slice_node = old_to_new_[ToAstNode(n->slice())];
1041+
std::variant<Slice*, WidthSlice*> new_slice;
1042+
if (Slice* slice = dynamic_cast<Slice*>(new_slice_node)) {
1043+
new_slice = slice;
1044+
} else {
1045+
new_slice = down_cast<WidthSlice*>(new_slice_node);
1046+
}
1047+
old_to_new_[n] = module_->Make<SliceTypeAnnotation>(
1048+
n->span(), down_cast<TypeAnnotation*>(old_to_new_[n->source_type()]),
1049+
new_slice);
1050+
return absl::OkStatus();
1051+
}
1052+
10361053
absl::Status HandleFunctionTypeAnnotation(
10371054
const FunctionTypeAnnotation* n) override {
10381055
XLS_RETURN_IF_ERROR(ReplaceOrVisit(n->return_type()));

xls/dslx/ir_convert/function_converter.cc

+1
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ class FunctionConverterVisitor : public AstNodeVisitor {
416416
INVALID(BuiltinTypeAnnotation)
417417
INVALID(ChannelTypeAnnotation)
418418
INVALID(ElementTypeAnnotation)
419+
INVALID(SliceTypeAnnotation)
419420
INVALID(FunctionTypeAnnotation)
420421
INVALID(GenericTypeAnnotation)
421422
INVALID(MemberTypeAnnotation)

xls/dslx/type_system/deduce.cc

+4-29
Original file line numberDiff line numberDiff line change
@@ -902,35 +902,6 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceStatementBlock(
902902
return last;
903903
}
904904

905-
// Returns (start, width), resolving indices via DSLX bit slice semantics.
906-
static absl::StatusOr<StartAndWidth> ResolveBitSliceIndices(
907-
int64_t bit_count, std::optional<int64_t> start_opt,
908-
std::optional<int64_t> limit_opt) {
909-
XLS_RET_CHECK_GE(bit_count, 0);
910-
int64_t start = 0;
911-
int64_t limit = bit_count;
912-
913-
if (start_opt.has_value()) {
914-
start = *start_opt;
915-
}
916-
if (limit_opt.has_value()) {
917-
limit = *limit_opt;
918-
}
919-
920-
if (start < 0) {
921-
start += bit_count;
922-
}
923-
if (limit < 0) {
924-
limit += bit_count;
925-
}
926-
927-
limit = std::min(std::max(limit, int64_t{0}), bit_count);
928-
start = std::min(std::max(start, int64_t{0}), limit);
929-
XLS_RET_CHECK_GE(start, 0);
930-
XLS_RET_CHECK_GE(limit, start);
931-
return StartAndWidth{.start = start, .width = limit - start};
932-
}
933-
934905
static absl::StatusOr<std::unique_ptr<Type>> DeduceWidthSliceType(
935906
const Index* node, const Type& subject_type,
936907
const BitsLikeProperties& subject_bits_like, const WidthSlice& width_slice,
@@ -2176,6 +2147,10 @@ class DeduceVisitor : public AstNodeVisitor {
21762147
const ElementTypeAnnotation* n) override {
21772148
return Fatal(n);
21782149
}
2150+
absl::Status HandleSliceTypeAnnotation(
2151+
const SliceTypeAnnotation* n) override {
2152+
return Fatal(n);
2153+
}
21792154
absl::Status HandleFunctionTypeAnnotation(
21802155
const FunctionTypeAnnotation* n) override {
21812156
return Fatal(n);

xls/dslx/type_system/deduce_utils.cc

+29
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "xls/dslx/type_system/deduce_utils.h"
1616

17+
#include <algorithm>
1718
#include <cstdint>
1819
#include <functional>
1920
#include <memory>
@@ -718,6 +719,34 @@ absl::StatusOr<Proc*> ResolveProc(Expr* callee, const TypeInfo* type_info) {
718719
return imported_info->module->GetMemberOrError<Proc>(colon_ref->attr());
719720
}
720721

722+
absl::StatusOr<StartAndWidth> ResolveBitSliceIndices(
723+
int64_t bit_count, std::optional<int64_t> start_opt,
724+
std::optional<int64_t> limit_opt) {
725+
XLS_RET_CHECK_GE(bit_count, 0);
726+
int64_t start = 0;
727+
int64_t limit = bit_count;
728+
729+
if (start_opt.has_value()) {
730+
start = *start_opt;
731+
}
732+
if (limit_opt.has_value()) {
733+
limit = *limit_opt;
734+
}
735+
736+
if (start < 0) {
737+
start += bit_count;
738+
}
739+
if (limit < 0) {
740+
limit += bit_count;
741+
}
742+
743+
limit = std::min(std::max(limit, int64_t{0}), bit_count);
744+
start = std::min(std::max(start, int64_t{0}), limit);
745+
XLS_RET_CHECK_GE(start, 0);
746+
XLS_RET_CHECK_GE(limit, start);
747+
return StartAndWidth{.start = start, .width = limit - start};
748+
}
749+
721750
absl::StatusOr<std::unique_ptr<Type>> ParametricBindingToType(
722751
const ParametricBinding& binding, DeduceCtx* ctx) {
723752
Module* binding_module = binding.owner();

xls/dslx/type_system/deduce_utils.h

+8
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ absl::StatusOr<Function*> ResolveFunction(Expr* callee,
134134
// The target proc must have been typechecked prior to this call.
135135
absl::StatusOr<Proc*> ResolveProc(Expr* callee, const TypeInfo* type_info);
136136

137+
// Normalizes the given optional start and width values for a slice of a bit
138+
// vector of size `bit_count`. One or both values may be omitted (to indicate
139+
// the absolute start or end), negative (i.e. an end-based index), or out of
140+
// range. This function will produce positive, in-range values.
141+
absl::StatusOr<StartAndWidth> ResolveBitSliceIndices(
142+
int64_t bit_count, std::optional<int64_t> start_opt,
143+
std::optional<int64_t> limit_opt);
144+
137145
// Returns an AST node typed T from module "m", resolved via name "name".
138146
//
139147
// Errors are attributed to span "span".

xls/dslx/type_system_v2/BUILD

+5-1
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,12 @@ cc_library(
252252
":import_utils",
253253
":inference_table",
254254
":type_annotation_utils",
255+
"//xls/common:visitor",
255256
"//xls/common/status:status_macros",
256257
"//xls/dslx:channel_direction",
257258
"//xls/dslx:errors",
258259
"//xls/dslx:import_data",
259260
"//xls/dslx:import_routines",
260-
"//xls/dslx:warning_collector",
261261
"//xls/dslx/frontend:ast",
262262
"//xls/dslx/frontend:ast_node",
263263
"//xls/dslx/frontend:ast_node_visitor_with_default",
@@ -276,6 +276,7 @@ cc_library(
276276
"@com_google_absl//absl/status:statusor",
277277
"@com_google_absl//absl/strings",
278278
"@com_google_absl//absl/strings:str_format",
279+
"@com_google_absl//absl/types:variant",
279280
],
280281
)
281282

@@ -550,7 +551,10 @@ cc_library(
550551
"//xls/dslx/frontend:ast_cloner",
551552
"//xls/dslx/frontend:ast_node",
552553
"//xls/dslx/frontend:pos",
554+
"//xls/dslx/type_system:deduce_utils",
555+
"//xls/dslx/type_system:type_info",
553556
"@com_google_absl//absl/algorithm:container",
557+
"@com_google_absl//absl/container:flat_hash_map",
554558
"@com_google_absl//absl/container:flat_hash_set",
555559
"@com_google_absl//absl/functional:function_ref",
556560
"@com_google_absl//absl/log",

xls/dslx/type_system_v2/evaluator.h

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class Evaluator {
3939
std::optional<const ParametricContext*> parametric_context,
4040
std::variant<int64_t, const Expr*> value_or_expr) = 0;
4141

42+
virtual absl::StatusOr<int64_t> EvaluateS32OrExpr(
43+
std::optional<const ParametricContext*> parametric_context,
44+
std::variant<int64_t, const Expr*> value_or_expr) = 0;
45+
4246
virtual absl::StatusOr<InterpValue> Evaluate(
4347
const ParametricContextScopedExpr& scoped_expr) = 0;
4448
};

xls/dslx/type_system_v2/inference_table.cc

+9-2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ absl::StatusOr<InferenceVariableKind> TypeAnnotationToInferenceVariableKind(
9797
annotation->ToString()));
9898
}
9999

100+
std::string ValueOrExprToString(
101+
std::variant<int64_t, const Expr*> expr_or_value) {
102+
return std::holds_alternative<int64_t>(expr_or_value)
103+
? absl::StrCat(std::get<int64_t>(expr_or_value))
104+
: std::get<const Expr*>(expr_or_value)->ToString();
105+
}
106+
100107
// Represents the immutable metadata for a variable in an `InferenceTable`.
101108
class InferenceVariable {
102109
public:
@@ -562,10 +569,10 @@ class InferenceTableImpl : public InferenceTable {
562569
if (data.slice_start_and_width_exprs.has_value()) {
563570
absl::StrAppendFormat(
564571
&result, " Start: %s\n",
565-
data.slice_start_and_width_exprs->start->ToString());
572+
ValueOrExprToString(data.slice_start_and_width_exprs->start));
566573
absl::StrAppendFormat(
567574
&result, " Width: %s\n",
568-
data.slice_start_and_width_exprs->width->ToString());
575+
ValueOrExprToString(data.slice_start_and_width_exprs->width));
569576
}
570577
const std::vector<const ParametricContext*>& contexts =
571578
contexts_per_node[node];

xls/dslx/type_system_v2/inference_table_converter_impl.cc

+31-57
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,6 @@ class ConversionOrderVisitor : public AstNodeVisitorWithDefault {
119119
return DefaultHandler(node);
120120
}
121121

122-
absl::Status HandleSlice(const Slice* slice) override {
123-
// Slices get replaced with `StartAndWith` objects in `TypeInfo`, so there
124-
// is no point in trying to compute the type info of the slice itself.
125-
return absl::OkStatus();
126-
}
127-
128122
absl::Status HandleInvocation(const Invocation* node) override {
129123
// Exclude the arguments of invocations, but otherwise do the equivalent of
130124
// DefaultHandler. We exclude the arguments, because when an argument should
@@ -1264,9 +1258,20 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
12641258
// A `Slice` actually has its bounds stored in `TypeInfo` out-of-band from
12651259
// the real type info, mirroring the `StartAndWidthExprs` that we store in
12661260
// the `InferenceTable`.
1267-
if (std::holds_alternative<Slice*>(index->rhs()) ||
1268-
std::holds_alternative<WidthSlice*>(index->rhs())) {
1269-
XLS_RETURN_IF_ERROR(ConcretizeSlice(parametric_context, index, ti));
1261+
if (std::holds_alternative<Slice*>(index->rhs())) {
1262+
std::optional<StartAndWidthExprs> start_and_width_exprs =
1263+
table_.GetSliceStartAndWidthExprs(ToAstNode(index->rhs()));
1264+
CHECK(start_and_width_exprs.has_value());
1265+
StartAndWidth start_and_width;
1266+
XLS_ASSIGN_OR_RETURN(start_and_width.start,
1267+
EvaluateU32OrExpr(parametric_context,
1268+
start_and_width_exprs->start));
1269+
XLS_ASSIGN_OR_RETURN(start_and_width.width,
1270+
EvaluateU32OrExpr(parametric_context,
1271+
start_and_width_exprs->width));
1272+
ti->AddSliceStartAndWidth(std::get<Slice*>(index->rhs()),
1273+
GetParametricEnv(parametric_context),
1274+
start_and_width);
12701275
}
12711276
}
12721277
if (const auto* const_assert = dynamic_cast<const ConstAssert*>(node)) {
@@ -1295,53 +1300,6 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
12951300
return absl::OkStatus();
12961301
}
12971302

1298-
// Adds the concrete start and width value of the slice requested by the given
1299-
// `index` node to the given `TypeInfo`.
1300-
absl::Status ConcretizeSlice(
1301-
std::optional<const ParametricContext*> parametric_context,
1302-
const Index* index, TypeInfo* ti) {
1303-
std::optional<StartAndWidthExprs> start_and_width_exprs =
1304-
table_.GetSliceStartAndWidthExprs(index);
1305-
CHECK(start_and_width_exprs.has_value());
1306-
absl::StatusOr<int32_t> start =
1307-
EvaluateU32OrExpr(parametric_context, start_and_width_exprs->start);
1308-
XLS_ASSIGN_OR_RETURN(
1309-
uint32_t width,
1310-
EvaluateU32OrExpr(parametric_context, start_and_width_exprs->width));
1311-
const Type& array_type = **ti->GetItem(index->lhs());
1312-
int64_t array_size;
1313-
if (array_type.IsArray()) {
1314-
XLS_ASSIGN_OR_RETURN(array_size,
1315-
array_type.AsArray().size().GetAsInt64());
1316-
} else {
1317-
std::optional<BitsLikeProperties> bits_like = GetBitsLike(array_type);
1318-
CHECK(bits_like.has_value());
1319-
XLS_ASSIGN_OR_RETURN(array_size, bits_like->size.GetAsInt64());
1320-
}
1321-
1322-
// A generic `Slice` must have a constexpr start value. A `WidthSlice` can
1323-
// have a constexpr or dynamic start value. If it's constexpr, we validate
1324-
// it.
1325-
const bool is_generic_slice = std::holds_alternative<Slice*>(index->rhs());
1326-
if (is_generic_slice && !start.ok()) {
1327-
return start.status();
1328-
}
1329-
if (start.ok() && (*start < 0 || *start + width > array_size)) {
1330-
return TypeInferenceErrorStatus(
1331-
index->span(), nullptr,
1332-
absl::StrCat("Slice range out of bounds for array of size ",
1333-
array_size),
1334-
file_table_);
1335-
}
1336-
1337-
if (is_generic_slice) {
1338-
ti->AddSliceStartAndWidth(std::get<Slice*>(index->rhs()),
1339-
GetParametricEnv(parametric_context),
1340-
StartAndWidth{*start, width});
1341-
}
1342-
return absl::OkStatus();
1343-
}
1344-
13451303
// Used when a `ConstAssert` is concretized, to actually check if the asserted
13461304
// expression holds.
13471305
absl::Status CheckConstAssert(
@@ -2259,6 +2217,20 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
22592217
absl::StatusOr<int64_t> EvaluateU32OrExpr(
22602218
std::optional<const ParametricContext*> parametric_context,
22612219
std::variant<int64_t, const Expr*> value_or_expr) override {
2220+
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
2221+
/*is_signed=*/false);
2222+
}
2223+
2224+
absl::StatusOr<int64_t> EvaluateS32OrExpr(
2225+
std::optional<const ParametricContext*> parametric_context,
2226+
std::variant<int64_t, const Expr*> value_or_expr) override {
2227+
return Evaluate32BitIntOrExpr(parametric_context, value_or_expr,
2228+
/*is_signed=*/true);
2229+
}
2230+
2231+
absl::StatusOr<int64_t> Evaluate32BitIntOrExpr(
2232+
std::optional<const ParametricContext*> parametric_context,
2233+
std::variant<int64_t, const Expr*> value_or_expr, bool is_signed) {
22622234
if (std::holds_alternative<int64_t>(value_or_expr)) {
22632235
return std::get<int64_t>(value_or_expr);
22642236
}
@@ -2270,7 +2242,9 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
22702242
std::optional<const TypeAnnotation*> type_annotation =
22712243
table_.GetTypeAnnotation(expr);
22722244
if (!type_annotation.has_value()) {
2273-
type_annotation = CreateU32Annotation(*expr->owner(), expr->span());
2245+
type_annotation = is_signed
2246+
? CreateS32Annotation(*expr->owner(), expr->span())
2247+
: CreateU32Annotation(*expr->owner(), expr->span());
22742248
}
22752249
XLS_ASSIGN_OR_RETURN(InterpValue value,
22762250
Evaluate(ParametricContextScopedExpr(

0 commit comments

Comments
 (0)