Fail closed on definition references and sticky mappings

This commit is contained in:
2026-07-02 19:04:25 +02:00
parent 39b991800a
commit 1b56798018
2 changed files with 170 additions and 3 deletions
@@ -1169,6 +1169,52 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_default_referencing_return_types_skips_node(self):
source = '''
class DefaultReferencesReturnTypesInputNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls, value=RETURN_TYPES):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DefaultReferencesReturnTypesInputNode": DefaultReferencesReturnTypesInputNode,
}
'''
result = self._extract_source(source, "default-references-return-types-input-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_return_annotation_referencing_input_types_skips_node(self):
source = '''
class AnnotationReferencesInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls) -> INPUT_TYPES:
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"AnnotationReferencesInputTypesNode": AnnotationReferencesInputTypesNode,
}
'''
result = self._extract_source(source, "annotation-references-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_decorated_input_types_skips_node(self): def test_decorated_input_types_skips_node(self):
source = ''' source = '''
def replace(fn): def replace(fn):
@@ -2163,6 +2209,37 @@ NODE_CLASS_MAPPINGS = build_mappings()
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_dynamic_node_class_mapping_assignment_stays_invalid_after_static_reassignment(self):
source = '''
def build_mappings():
return {"StickyDynamicMappingNode": StickyDynamicMappingNode}
class StickyDynamicMappingNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"StickyDynamicMappingNode": StickyDynamicMappingNode,
}
NODE_CLASS_MAPPINGS = build_mappings()
NODE_CLASS_MAPPINGS = {
"StickyDynamicMappingNode": StickyDynamicMappingNode,
}
'''
result = self._extract_source(source, "sticky-dynamic-mapping-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_rebound_node_class_name_skips_static_mapping(self): def test_rebound_node_class_name_skips_static_mapping(self):
source = ''' source = '''
def build_node(): def build_node():
@@ -3719,6 +3796,40 @@ NODE_DISPLAY_NAME_MAPPINGS = build_displays()
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_dynamic_display_mapping_assignment_stays_invalid_after_static_reassignment(self):
source = '''
def build_displays():
return {"StickyDisplayInvalidatedNode": "Dynamic Display"}
class StickyDisplayInvalidatedNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"StickyDisplayInvalidatedNode": StickyDisplayInvalidatedNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"StickyDisplayInvalidatedNode": "Stale Display",
}
NODE_DISPLAY_NAME_MAPPINGS = build_displays()
NODE_DISPLAY_NAME_MAPPINGS = {
"StickyDisplayInvalidatedNode": "Recovered Display",
}
'''
result = self._extract_source(source, "sticky-display-invalidated-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_non_string_display_mapping_value_skips_node(self): def test_non_string_display_mapping_value_skips_node(self):
source = ''' source = '''
class NonStringDisplayValueNode: class NonStringDisplayValueNode:
+59 -3
View File
@@ -648,6 +648,35 @@ def _arbitrary_call_observed_names(stmt):
return names return names
def _definition_time_referenced_names(stmt):
names = set()
def collect_function_definition_expressions(node):
for decorator in node.decorator_list:
names.update(_referenced_names(decorator))
names.update(_referenced_names(node.args))
if node.returns is not None:
names.update(_referenced_names(node.returns))
for type_param in getattr(node, "type_params", ()):
names.update(_referenced_names(type_param))
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
collect_function_definition_expressions(stmt)
elif isinstance(stmt, ast.ClassDef):
for decorator in stmt.decorator_list:
names.update(_referenced_names(decorator))
for base in stmt.bases:
names.update(_referenced_names(base))
for keyword in stmt.keywords:
names.update(_referenced_names(keyword.value))
for type_param in getattr(stmt, "type_params", ()):
names.update(_referenced_names(type_param))
elif isinstance(stmt, ast.Lambda):
names.update(_referenced_names(stmt.args))
return names
def _assigned_names_in_control_flow(stmt): def _assigned_names_in_control_flow(stmt):
names = _mutating_call_target_names(stmt) | _arbitrary_call_observed_names(stmt) names = _mutating_call_target_names(stmt) | _arbitrary_call_observed_names(stmt)
@@ -980,7 +1009,7 @@ def _update_class_attr_aliases_from_unpack(target, value, name, aliases):
def _input_types_alias_sources(value, aliases): def _input_types_alias_sources(value, aliases):
if isinstance(value, ast.Name): if isinstance(value, ast.Name):
return value.id == "INPUT_TYPES" or value.id in aliases return value.id in _CLASS_SIGNATURE_ATTRS or value.id in aliases
if isinstance(value, (ast.Tuple, ast.List)): if isinstance(value, (ast.Tuple, ast.List)):
return any(_input_types_alias_sources(item, aliases) for item in value.elts) return any(_input_types_alias_sources(item, aliases) for item in value.elts)
return False return False
@@ -1134,11 +1163,14 @@ def _input_types(cls, env, decorator_env):
for stmt in cls.body: for stmt in cls.body:
mutating_targets = _mutating_call_target_names(stmt) mutating_targets = _mutating_call_target_names(stmt)
observed_targets = _arbitrary_call_observed_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt)
definition_references = _definition_time_referenced_names(stmt)
protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases
input_types_invalidated = ( input_types_invalidated = (
"INPUT_TYPES" in mutating_targets "INPUT_TYPES" in mutating_targets
or bool(aliases.intersection(mutating_targets)) or bool(aliases.intersection(mutating_targets))
or "INPUT_TYPES" in observed_targets or "INPUT_TYPES" in observed_targets
or bool(aliases.intersection(observed_targets)) or bool(aliases.intersection(observed_targets))
or bool(definition_references.intersection(protected_definition_references))
) )
if input_types_invalidated: if input_types_invalidated:
value = _INVALID value = _INVALID
@@ -1669,6 +1701,7 @@ def _update_module_dict_aliases(stmt, name, aliases):
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False): def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING value = _MISSING
sticky_invalid = False
env = {} env = {}
class_bindings = {} class_bindings = {}
class_aliases = {} class_aliases = {}
@@ -1695,74 +1728,97 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
and value_invalidated_by_names(value, class_attr_names) and value_invalidated_by_names(value, class_attr_names)
): ):
value = _INVALID value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, _mutating_call_target_names(stmt)): if _name_invalidated_by(name, _mutating_call_target_names(stmt)):
value = _INVALID value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)): if _name_invalidated_by(name, _arbitrary_call_observed_names(stmt)):
value = _INVALID value = _INVALID
sticky_invalid = True
if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)): if _name_invalidated_by(name, _namespace_alias_mutation_target_names(stmt, namespace_aliases)):
value = _INVALID value = _INVALID
sticky_invalid = True
if _module_dict_alias_invalidated(stmt, module_dict_aliases): if _module_dict_alias_invalidated(stmt, module_dict_aliases):
value = _INVALID value = _INVALID
sticky_invalid = True
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
if not _name_is_assigned(stmt, name): if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): if sticky_invalid:
value = _INVALID
elif len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
try: try:
value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
value = _INVALID value = _INVALID
sticky_invalid = True
else: else:
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt, ast.AnnAssign): if isinstance(stmt, ast.AnnAssign):
if not _name_is_assigned(stmt, name): if not _name_is_assigned(stmt, name):
if isinstance(stmt.value, ast.Name) and stmt.value.id == name: if isinstance(stmt.value, ast.Name) and stmt.value.id == name:
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt.target, ast.Name) and stmt.value is not None: if sticky_invalid:
value = _INVALID
elif isinstance(stmt.target, ast.Name) and stmt.value is not None:
try: try:
value = _module_dict_entries(stmt.value, env, class_bindings, value_converter) value = _module_dict_entries(stmt.value, env, class_bindings, value_converter)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
value = _INVALID value = _INVALID
sticky_invalid = True
else: else:
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt, ast.AugAssign): if isinstance(stmt, ast.AugAssign):
if _name_is_assigned(stmt, name): if _name_is_assigned(stmt, name):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt, ast.Delete): if isinstance(stmt, ast.Delete):
if name in _delete_target_names(stmt): if name in _delete_target_names(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
if name in _mutating_call_target_names(stmt): if name in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
if name in _bound_names(stmt): if name in _bound_names(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if isinstance(stmt, _CONTROL_FLOW_TYPES): if isinstance(stmt, _CONTROL_FLOW_TYPES):
if name in _assigned_names_in_control_flow(stmt): if name in _assigned_names_in_control_flow(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
if _has_wildcard_import_in_control_flow(stmt): if _has_wildcard_import_in_control_flow(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if _has_wildcard_import(stmt): if _has_wildcard_import(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
continue continue
if name in _bound_names(stmt): if name in _bound_names(stmt):
value = _INVALID value = _INVALID
sticky_invalid = True
advance_module_state(stmt) advance_module_state(stmt)
if return_state: if return_state:
return value return value