Skip to content

Avoid repeated work for identical parametric contexts in TIv2. #2154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ cc_test(
"//xls/common/status:status_macros",
"//xls/dslx:create_import_data",
"//xls/dslx:import_data",
"//xls/dslx:interp_value",
"//xls/dslx:warning_collector",
"//xls/dslx:warning_kind",
"//xls/dslx/frontend:ast",
Expand All @@ -139,7 +140,9 @@ cc_test(
"//xls/dslx/frontend:parser",
"//xls/dslx/frontend:pos",
"//xls/dslx/frontend:scanner",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type_info",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -363,6 +366,7 @@ cc_library(
"//xls/dslx/frontend:ast",
"//xls/dslx/type_system:type_info",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down
111 changes: 96 additions & 15 deletions xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ struct NodeData {

// The mutable data for a `ParametricContext` in an `InferenceTable`.
struct MutableParametricContextData {
// An arbitrarily-selected context that has the same `ParametricEnv` as this
// one. This is only set in cases where "duplicate" contexts are identified by
// the table.
std::optional<const ParametricContext*> canonical_context;

absl::flat_hash_map<const InferenceVariable*, ParametricContextScopedExpr>
parametric_values;
absl::flat_hash_map<const InferenceVariable*,
Expand Down Expand Up @@ -177,8 +182,8 @@ class UnificationCache {
unified_type;
}
for (const NameRef* dep : transitive_variable_dependencies) {
transitive_consumers_[std::get<const NameDef*>(dep->name_def())].insert(
variable_name_def);
const NameDef* dep_def = std::get<const NameDef*>(dep->name_def());
transitive_consumers_[dep_def].insert(variable_name_def);
}
}

Expand All @@ -201,22 +206,43 @@ class UnificationCache {
return it->second.cached_type;
}

void InvalidateVariable(const NameRef* variable) {
void InvalidateVariable(
std::optional<const ParametricContext*> parametric_context,
const NameRef* variable) {
VLOG(6) << "Invalidating unification cache for variable due to a direct "
"change: "
<< variable->ToString();
const auto* def = std::get<const NameDef*>(variable->name_def());
cache_.erase(def);
const auto consumers = transitive_consumers_.find(def);
if (consumers != transitive_consumers_.end()) {
for (const NameDef* consumer : consumers->second) {
VLOG(5) << "Invalidating unification cache for variable: "
<< consumer->ToString()
<< " due to dependency on changed variable: "
<< variable->ToString();
cache_.erase(consumer);
if (parametric_context.has_value()) {
const auto consumer_it = cache_.find(consumer);
if (consumer_it != cache_.end()) {
consumer_it->second.cached_type_per_parametric_context.erase(
*parametric_context);
}
} else {
// Note that if the invalidation is not scoped to a context, it could
// affect any context.
cache_.erase(consumer);
}
}
transitive_consumers_.erase(consumers);
}

if (parametric_context.has_value()) {
const auto it = cache_.find(def);
if (it != cache_.end()) {
it->second.cached_type_per_parametric_context.erase(
*parametric_context);
}
} else {
cache_.erase(def);
transitive_consumers_.erase(def);
}
}

Expand Down Expand Up @@ -275,7 +301,7 @@ class InferenceTableImpl : public InferenceTable {
return name_ref;
}

absl::StatusOr<const ParametricContext*> AddParametricInvocation(
absl::StatusOr<ParametricContext*> AddParametricInvocation(
const Invocation& node, const Function& callee,
std::optional<const Function*> caller,
std::optional<const ParametricContext*> parent_context,
Expand Down Expand Up @@ -325,12 +351,42 @@ class InferenceTableImpl : public InferenceTable {
context.get(), binding->type_annotation(), binding->expr()));
}
}
const ParametricContext* result = context.get();
ParametricContext* result = context.get();
parametric_contexts_.push_back(std::move(context));
mutable_parametric_context_data_.emplace(result, std::move(mutable_data));
return result;
}

bool MapToCanonicalInvocationTypeInfo(ParametricContext* parametric_context,
ParametricEnv env) override {
CHECK(parametric_context->is_invocation());

// `ParametricEnv` doesn't currently capture generic types, so for the time
// being, we can't canonicalize invocations that use them.
for (const ParametricBinding* binding :
parametric_context->parametric_bindings()) {
if (dynamic_cast<const GenericTypeAnnotation*>(
binding->type_annotation())) {
return false;
}
}

parametric_context->SetInvocationEnv(env);
const ParametricInvocationDetails& details =
std::get<ParametricInvocationDetails>(parametric_context->details());
const auto it = canonical_parametric_context_.find({details.callee, env});
if (it == canonical_parametric_context_.end()) {
canonical_parametric_context_.emplace_hint(
it, std::make_pair(details.callee, env), parametric_context);
return false;
}

mutable_parametric_context_data_.at(parametric_context).canonical_context =
it->second;
parametric_context->SetTypeInfo(it->second->type_info());
return true;
}

std::vector<const ParametricContext*> GetParametricInvocations()
const override {
std::vector<const ParametricContext*> result;
Expand Down Expand Up @@ -410,12 +466,16 @@ class InferenceTableImpl : public InferenceTable {
const auto it = type_annotations_per_type_variable_.find(variable);
if (it != type_annotations_per_type_variable_.end()) {
auto& annotations = it->second;
size_t count_before = annotations.size();
annotations.erase(std::remove_if(annotations.begin(), annotations.end(),
[&](const TypeAnnotation* annotation) {
return remove_predicate(annotation);
}),
annotations.end());
cache_.InvalidateVariable(variable->name_ref());
if (annotations.size() != count_before) {
cache_.InvalidateVariable(/*parametric_context=*/std::nullopt,
variable->name_ref());
}
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -647,18 +707,33 @@ class InferenceTableImpl : public InferenceTable {
const absl::flat_hash_set<const NameRef*>&
transitive_variable_dependencies,
const TypeAnnotation* unified_type) override {
cache_.SetUnifiedTypeForVariable(parametric_context, variable,
transitive_variable_dependencies,
cache_.SetUnifiedTypeForVariable(GetCanonicalContext(parametric_context),
variable, transitive_variable_dependencies,
unified_type);
}

std::optional<const TypeAnnotation*> GetCachedUnifiedTypeForVariable(
std::optional<const ParametricContext*> parametric_context,
const NameRef* variable) override {
return cache_.GetUnifiedTypeForVariable(parametric_context, variable);
return cache_.GetUnifiedTypeForVariable(
GetCanonicalContext(parametric_context), variable);
}

private:
std::optional<const ParametricContext*> GetCanonicalContext(
std::optional<const ParametricContext*> context) {
if (!context.has_value()) {
return std::nullopt;
}
const auto it = mutable_parametric_context_data_.find(*context);
if (it == mutable_parametric_context_data_.end()) {
return std::nullopt;
}
return it->second.canonical_context.has_value()
? it->second.canonical_context
: context;
}

void AddVariable(const NameDef* name_def,
std::unique_ptr<InferenceVariable> variable) {
variables_.emplace(name_def, std::move(variable));
Expand Down Expand Up @@ -713,10 +788,12 @@ class InferenceTableImpl : public InferenceTable {
}
}
if (old_variable.has_value()) {
cache_.InvalidateVariable((*old_variable)->name_ref());
cache_.InvalidateVariable(/*parametric_context=*/std::nullopt,
(*old_variable)->name_ref());
}
if (node_data.type_variable.has_value()) {
cache_.InvalidateVariable((*node_data.type_variable)->name_ref());
cache_.InvalidateVariable(/*parametric_context=*/std::nullopt,
(*node_data.type_variable)->name_ref());
}
return absl::OkStatus();
}
Expand All @@ -732,7 +809,8 @@ class InferenceTableImpl : public InferenceTable {
} else {
type_annotations_per_type_variable_[variable].push_back(annotation);
}
cache_.InvalidateVariable(variable->name_ref());
cache_.InvalidateVariable(GetCanonicalContext(context),
variable->name_ref());
}

Module& module_;
Expand All @@ -755,6 +833,9 @@ class InferenceTableImpl : public InferenceTable {
std::vector<std::unique_ptr<ParametricContext>> parametric_contexts_;
absl::flat_hash_map<const ParametricContext*, MutableParametricContextData>
mutable_parametric_context_data_;
absl::flat_hash_map<std::pair<const Function*, ParametricEnv>,
const ParametricContext*>
canonical_parametric_context_;
absl::flat_hash_set<const TypeAnnotation*> auto_literal_annotations_;
absl::flat_hash_map<const ColonRef*, const AstNode*> colon_ref_targets_;
absl::flat_hash_map<
Expand Down
30 changes: 26 additions & 4 deletions xls/dslx/type_system_v2/inference_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct ParametricInvocationDetails {
// The details for a `ParametricContext` that is for a struct.
struct ParametricStructDetails {
const StructDefBase* struct_or_proc_def;
ParametricEnv env;
const ParametricEnv env;
};

// Identifies either an invocation of a parametric function, or a
Expand Down Expand Up @@ -160,6 +160,10 @@ class ParametricContext {
type_info_ != nullptr ? type_info_->module()->name() : "none");
}

// Intended to be used only from the owning table, upon canonicalization.
void SetInvocationEnv(ParametricEnv env) { invocation_env_ = std::move(env); }
void SetTypeInfo(TypeInfo* type_info) { type_info_ = type_info; }

private:
static std::string DetailsToString(const Details& details) {
return absl::visit(
Expand All @@ -180,9 +184,10 @@ class ParametricContext {
const uint64_t id_; // Just for logging.
const AstNode* node_;
const Details details_;
TypeInfo* type_info_;
const std::optional<const ParametricContext*> parent_context_;
const std::optional<const TypeAnnotation*> self_type_;
TypeInfo* type_info_;
std::optional<ParametricEnv> invocation_env_;
};

inline std::string ToString(std::optional<const ParametricContext*> context) {
Expand Down Expand Up @@ -279,12 +284,29 @@ class InferenceTable {
// context. Note that the `caller` must only be `nullopt` if the invocation is
// not in a function (e.g. it may be in the RHS of a free constant
// declaration).
virtual absl::StatusOr<const ParametricContext*> AddParametricInvocation(
virtual absl::StatusOr<ParametricContext*> AddParametricInvocation(
const Invocation& invocation, const Function& callee,
std::optional<const Function*> caller,
std::optional<const ParametricContext*> parent_context,
std::optional<const TypeAnnotation*> self_type,
TypeInfo* invocation_type_info = nullptr) = 0;
TypeInfo* invocation_type_info) = 0;

// Finds an existing `ParametricContext` in the table that represents an
// invocation of the same function with the same parametrics as `context` and
// `env`.
//
// If an existing context is found, updates the `TypeInfo` pointer associated
// with `context` to be one from the existing context, and returns true. The
// implication is that there is then no point in spending effort to populate
// the original `TypeInfo` that was associated with `context`. Whether or not
// the caller subsequently uses this information, the table will make use of
// it for internal cache optimization.
//
// If an existing matching context is not found, then this function captures
// the `env`, makes `context` eligible to serve as a canonical context for
// that `env` in the future, and returns false.
virtual bool MapToCanonicalInvocationTypeInfo(ParametricContext* context,
ParametricEnv env) = 0;

// Retrieves all the parametric invocations that have been defined.
virtual std::vector<const ParametricContext*> GetParametricInvocations()
Expand Down
29 changes: 16 additions & 13 deletions xls/dslx/type_system_v2/inference_table_converter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,6 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
// The parametric invocation now gets its own data structure set up in both
// the `InferenceTable` and the `TypeInfo` hierarchy.

TypeInfo* invocation_type_info = nullptr;
std::optional<const ParametricContext*> caller_or_target_struct_context =
function_and_target_object.target_struct_context.has_value()
? function_and_target_object.target_struct_context
Expand All @@ -572,11 +571,11 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
TypeInfo * base_type_info,
GetTypeInfo(function->owner(), caller_or_target_struct_context));
XLS_ASSIGN_OR_RETURN(
invocation_type_info,
TypeInfo * invocation_type_info,
import_data_.type_info_owner().New(function->owner(), base_type_info));

XLS_ASSIGN_OR_RETURN(
const ParametricContext* invocation_context,
ParametricContext * invocation_context,
table_.AddParametricInvocation(
*invocation, *function, caller, caller_context,
function_and_target_object.target_struct_context.has_value()
Expand Down Expand Up @@ -615,9 +614,12 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
}

// Figure out any implicit parametrics and generate the `ParametricEnv`.
XLS_RETURN_IF_ERROR(GenerateParametricFunctionEnv(
function_and_target_object.target_struct_context, invocation_context,
invocation));
XLS_ASSIGN_OR_RETURN(ParametricEnv env,
GenerateParametricFunctionEnv(
function_and_target_object.target_struct_context,
invocation_context, invocation));
const bool canonicalized = table_.MapToCanonicalInvocationTypeInfo(
invocation_context, std::move(env));
XLS_RETURN_IF_ERROR(AddInvocationTypeInfo(invocation_context));

// For an instance method call like `some_object.parametric_fn(args)`, type
Expand Down Expand Up @@ -689,7 +691,10 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
XLS_ASSIGN_OR_RETURN(proc_type_info_frame,
PushProcTypeInfo(*function->proc()));
}
XLS_RETURN_IF_ERROR(ConvertSubtree(function, function, invocation_context));
if (!canonicalized) {
XLS_RETURN_IF_ERROR(
ConvertSubtree(function, function, invocation_context));
}
return GenerateTypeInfo(caller_context, invocation);
}

Expand Down Expand Up @@ -758,10 +763,6 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
std::optional<const TypeAnnotation*> pre_unified_type = std::nullopt,
TypeAnnotationFilter type_annotation_filter =
TypeAnnotationFilter::None()) {
// Don't generate type info for the rest of tuple wildcard.
if (node->kind() == AstNodeKind::kRestOfTuple) {
return absl::OkStatus();
}
TypeSystemTrace trace = tracer_->TraceConvertNode(node);
VLOG(5) << "GenerateTypeInfo for node: " << node->ToString()
<< " of kind: `" << AstNodeKindToString(node->kind()) << "`"
Expand Down Expand Up @@ -1130,13 +1131,15 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
file_table_);
}
ti->NoteConstExpr(value.value, *evaluated_value);
ti->NoteConstExpr(value.name_def, *evaluated_value);
}
// Evaluated enum value has numeric type, which needs to be converted
// to enum type.
XLS_ASSIGN_OR_RETURN(auto bits, evaluated_value->GetBits());
InterpValue enum_value = InterpValue::MakeEnum(
bits, evaluated_value->IsSigned(), *enum_def);
ti->NoteConstExpr(value.value, enum_value);
ti->NoteConstExpr(value.name_def, enum_value);
members.push_back(enum_value);
}
}
Expand Down Expand Up @@ -1893,7 +1896,7 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
// Determines any implicit parametric values in the given `invocation`, and
// generates its `ParametricEnv` in `converted_parametric_envs_`. Also
// populates `parametric_value_exprs_` for the invocation.
absl::Status GenerateParametricFunctionEnv(
absl::StatusOr<ParametricEnv> GenerateParametricFunctionEnv(
std::optional<const ParametricContext*> callee_struct_context,
const ParametricContext* invocation_context,
const Invocation* invocation) {
Expand Down Expand Up @@ -2010,7 +2013,7 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
std::move(actual_parametrics));
ParametricEnv env(std::move(values));
converted_parametric_envs_.emplace(invocation_context, env);
return absl::OkStatus();
return env;
}

// Attempts to infer the values of the specified implicit parametrics in an
Expand Down
Loading