11use std:: borrow:: Cow ;
2+ use std:: cell:: RefCell ;
3+ use std:: rc:: Rc ;
24
35use itertools:: { Either , Itertools } ;
46use ruff_db:: diagnostic:: { Annotation , DiagnosticId , Severity } ;
@@ -222,6 +224,9 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
222224 /// The types of every expression in this region.
223225 expressions : FxHashMap < ExpressionNodeKey , Type < ' db > > ,
224226
227+ /// An expression cache shared across builders during multi-inference.
228+ expression_cache : Option < Rc < RefCell < ExpressionCache < ' db > > > > ,
229+
225230 /// Expressions that are string annotations
226231 string_annotations : FxHashSet < ExpressionNodeKey > ,
227232
@@ -306,6 +311,9 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
306311 dataclass_field_specifiers : SmallVec < [ Type < ' db > ; NUM_FIELD_SPECIFIERS_INLINE ] > ,
307312}
308313
314+ /// An expression cache shared across builders during multi-inference.
315+ type ExpressionCache < ' db > = FxHashMap < ( ExpressionNodeKey , TypeContext < ' db > ) , Type < ' db > > ;
316+
309317impl < ' db , ' ast > TypeInferenceBuilder < ' db , ' ast > {
310318 /// How big a string do we build before bailing?
311319 ///
@@ -332,6 +340,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
332340 deferred_state : DeferredExpressionState :: None ,
333341 inferring_vararg_annotation : false ,
334342 expressions : FxHashMap :: default ( ) ,
343+ expression_cache : None ,
335344 string_annotations : FxHashSet :: default ( ) ,
336345 bindings : VecMap :: default ( ) ,
337346 declarations : VecMap :: default ( ) ,
@@ -475,6 +484,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
475484 }
476485 }
477486
487+ /// Setup a shared expression cache for multi-inference.
488+ ///
489+ /// Returns `false` if the expression cache was already initialized.
490+ fn setup_expression_cache ( & mut self ) -> bool {
491+ if self . expression_cache . is_some ( ) {
492+ false
493+ } else {
494+ self . expression_cache = Some ( Rc :: new ( RefCell :: new ( FxHashMap :: default ( ) ) ) ) ;
495+ true
496+ }
497+ }
498+
499+ fn teardown_expression_cache ( & mut self ) {
500+ self . expression_cache = None ;
501+ }
502+
478503 /// Are we currently inferring types in file with deferred types?
479504 /// This is true for stub files, for files with `__future__.annotations`, and
480505 /// by default for all source files in Python 3.14 and later.
@@ -5331,6 +5356,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53315356
53325357 let mut seen = FxHashSet :: default ( ) ;
53335358
5359+ // Cache expressions inferred across speculative inference attempts.
5360+ //
5361+ // This is important to avoid exponential blowup for deeply nested generic calls,
5362+ // as inner expressions are repeatedly inferred with the same type context.
5363+ let teardown = self . setup_expression_cache ( ) ;
5364+
53345365 for ( parameter, parameter_tcx) in parameter_types {
53355366 if !seen. insert ( parameter. annotated_type ( ) ) {
53365367 continue ;
@@ -5345,6 +5376,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53455376 ) ;
53465377 argument_types. insert ( parameter. annotated_type ( ) , inferred_ty) ;
53475378 }
5379+
5380+ if teardown {
5381+ self . teardown_expression_cache ( ) ;
5382+ }
53485383 }
53495384 }
53505385 }
@@ -5435,6 +5470,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
54355470 expression : & ast:: Expr ,
54365471 tcx : TypeContext < ' db > ,
54375472 ) -> Type < ' db > {
5473+ if let Some ( ty) = self . expression_cache . as_ref ( ) . and_then ( |expression_cache| {
5474+ expression_cache
5475+ . borrow ( )
5476+ . get ( & ( expression. into ( ) , tcx) )
5477+ . copied ( )
5478+ } ) {
5479+ self . store_expression_type ( expression, ty) ;
5480+ return ty;
5481+ }
5482+
54385483 let mut ty = match expression {
54395484 ast:: Expr :: NoneLiteral ( ast:: ExprNoneLiteral {
54405485 range : _,
@@ -5497,6 +5542,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
54975542
54985543 self . store_expression_type ( expression, ty) ;
54995544
5545+ if let Some ( expression_cache) = & self . expression_cache {
5546+ expression_cache
5547+ . borrow_mut ( )
5548+ . insert ( ( expression. into ( ) , tcx) , ty) ;
5549+ }
5550+
55005551 ty
55015552 }
55025553
@@ -8989,6 +9040,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
89899040 undecorated_type : _,
89909041
89919042 // builder only state
9043+ expression_cache : _,
89929044 typevar_binding_context : _,
89939045 inference_flags : _,
89949046 deferred_state : _,
@@ -9067,6 +9119,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
90679119 expressions,
90689120 bindings,
90699121 called_functions,
9122+ expression_cache : _,
90709123 declarations : _,
90719124 deferred : _,
90729125 scope : _,
@@ -9111,7 +9164,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
91119164 cycle_recovery,
91129165 undecorated_type,
91139166 called_functions,
9167+
91149168 // builder only state
9169+ expression_cache : _,
91159170 dataclass_field_specifiers : _,
91169171 all_definitely_bound : _,
91179172 typevar_binding_context : _,
@@ -9193,6 +9248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
91939248 undecorated_type : _,
91949249
91959250 // Builder only state
9251+ expression_cache : _,
91969252 dataclass_field_specifiers : _,
91979253 all_definitely_bound : _,
91989254 typevar_binding_context : _,
@@ -9237,6 +9293,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92379293 inference_flags,
92389294 typevar_binding_context,
92399295 inferring_vararg_annotation,
9296+ ref expression_cache,
92409297 ref return_types_and_ranges,
92419298 ref dataclass_field_specifiers,
92429299
@@ -9265,6 +9322,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92659322 builder. typevar_binding_context = typevar_binding_context;
92669323 builder. inference_flags = inference_flags;
92679324 builder. inferring_vararg_annotation = inferring_vararg_annotation;
9325+ builder. expression_cache . clone_from ( expression_cache) ;
92689326 builder
92699327 . return_types_and_ranges
92709328 . clone_from ( return_types_and_ranges) ;
@@ -9292,6 +9350,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
92929350 undecorated_type : _,
92939351
92949352 // builder only state
9353+ expression_cache : _,
92959354 all_definitely_bound : _,
92969355 typevar_binding_context : _,
92979356 inference_flags : _,
0 commit comments