Skip to content

Commit 466167e

Browse files
authored
[ty] Used shared expression cache during generic call inference (astral-sh#24219)
Resolves astral-sh/ty#3123. The alternative here would be to make call arguments standalone expressions, but I suspect this is cheaper, though it has the downside that nested standalone expressions won't benefit from the cache.
1 parent 64c4c96 commit 466167e

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

  • crates/ty_python_semantic/src/types/infer

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use std::borrow::Cow;
2+
use std::cell::RefCell;
3+
use std::rc::Rc;
24

35
use itertools::{Either, Itertools};
46
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
@@ -222,6 +224,9 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
222224
/// The types of every expression in this region.
223225
expressions: FxHashMap<ExpressionNodeKey, Type<'db>>,
224226

227+
/// An expression cache shared across builders during multi-inference.
228+
expression_cache: Option<Rc<RefCell<ExpressionCache<'db>>>>,
229+
225230
/// Expressions that are string annotations
226231
string_annotations: FxHashSet<ExpressionNodeKey>,
227232

@@ -306,6 +311,9 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
306311
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
307312
}
308313

314+
/// An expression cache shared across builders during multi-inference.
315+
type ExpressionCache<'db> = FxHashMap<(ExpressionNodeKey, TypeContext<'db>), Type<'db>>;
316+
309317
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
310318
/// How big a string do we build before bailing?
311319
///
@@ -332,6 +340,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
332340
deferred_state: DeferredExpressionState::None,
333341
inferring_vararg_annotation: false,
334342
expressions: FxHashMap::default(),
343+
expression_cache: None,
335344
string_annotations: FxHashSet::default(),
336345
bindings: VecMap::default(),
337346
declarations: VecMap::default(),
@@ -475,6 +484,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
475484
}
476485
}
477486

487+
/// Setup a shared expression cache for multi-inference.
488+
///
489+
/// Returns `false` if the expression cache was already initialized.
490+
fn setup_expression_cache(&mut self) -> bool {
491+
if self.expression_cache.is_some() {
492+
false
493+
} else {
494+
self.expression_cache = Some(Rc::new(RefCell::new(FxHashMap::default())));
495+
true
496+
}
497+
}
498+
499+
fn teardown_expression_cache(&mut self) {
500+
self.expression_cache = None;
501+
}
502+
478503
/// Are we currently inferring types in file with deferred types?
479504
/// This is true for stub files, for files with `__future__.annotations`, and
480505
/// by default for all source files in Python 3.14 and later.
@@ -5331,6 +5356,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53315356

53325357
let mut seen = FxHashSet::default();
53335358

5359+
// Cache expressions inferred across speculative inference attempts.
5360+
//
5361+
// This is important to avoid exponential blowup for deeply nested generic calls,
5362+
// as inner expressions are repeatedly inferred with the same type context.
5363+
let teardown = self.setup_expression_cache();
5364+
53345365
for (parameter, parameter_tcx) in parameter_types {
53355366
if !seen.insert(parameter.annotated_type()) {
53365367
continue;
@@ -5345,6 +5376,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53455376
);
53465377
argument_types.insert(parameter.annotated_type(), inferred_ty);
53475378
}
5379+
5380+
if teardown {
5381+
self.teardown_expression_cache();
5382+
}
53485383
}
53495384
}
53505385
}
@@ -5435,6 +5470,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
54355470
expression: &ast::Expr,
54365471
tcx: TypeContext<'db>,
54375472
) -> Type<'db> {
5473+
if let Some(ty) = self.expression_cache.as_ref().and_then(|expression_cache| {
5474+
expression_cache
5475+
.borrow()
5476+
.get(&(expression.into(), tcx))
5477+
.copied()
5478+
}) {
5479+
self.store_expression_type(expression, ty);
5480+
return ty;
5481+
}
5482+
54385483
let mut ty = match expression {
54395484
ast::Expr::NoneLiteral(ast::ExprNoneLiteral {
54405485
range: _,
@@ -5497,6 +5542,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
54975542

54985543
self.store_expression_type(expression, ty);
54995544

5545+
if let Some(expression_cache) = &self.expression_cache {
5546+
expression_cache
5547+
.borrow_mut()
5548+
.insert((expression.into(), tcx), ty);
5549+
}
5550+
55005551
ty
55015552
}
55025553

@@ -8989,6 +9040,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
89899040
undecorated_type: _,
89909041

89919042
// builder only state
9043+
expression_cache: _,
89929044
typevar_binding_context: _,
89939045
inference_flags: _,
89949046
deferred_state: _,
@@ -9067,6 +9119,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
90679119
expressions,
90689120
bindings,
90699121
called_functions,
9122+
expression_cache: _,
90709123
declarations: _,
90719124
deferred: _,
90729125
scope: _,
@@ -9111,7 +9164,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
91119164
cycle_recovery,
91129165
undecorated_type,
91139166
called_functions,
9167+
91149168
// builder only state
9169+
expression_cache: _,
91159170
dataclass_field_specifiers: _,
91169171
all_definitely_bound: _,
91179172
typevar_binding_context: _,
@@ -9193,6 +9248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
91939248
undecorated_type: _,
91949249

91959250
// Builder only state
9251+
expression_cache: _,
91969252
dataclass_field_specifiers: _,
91979253
all_definitely_bound: _,
91989254
typevar_binding_context: _,
@@ -9237,6 +9293,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92379293
inference_flags,
92389294
typevar_binding_context,
92399295
inferring_vararg_annotation,
9296+
ref expression_cache,
92409297
ref return_types_and_ranges,
92419298
ref dataclass_field_specifiers,
92429299

@@ -9265,6 +9322,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92659322
builder.typevar_binding_context = typevar_binding_context;
92669323
builder.inference_flags = inference_flags;
92679324
builder.inferring_vararg_annotation = inferring_vararg_annotation;
9325+
builder.expression_cache.clone_from(expression_cache);
92689326
builder
92699327
.return_types_and_ranges
92709328
.clone_from(return_types_and_ranges);
@@ -9292,6 +9350,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92929350
undecorated_type: _,
92939351

92949352
// builder only state
9353+
expression_cache: _,
92959354
all_definitely_bound: _,
92969355
typevar_binding_context: _,
92979356
inference_flags: _,

0 commit comments

Comments
 (0)