Fail closed on dynamic patches and displays

This commit is contained in:
2026-07-02 16:27:48 +02:00
parent 9752248ee9
commit 7e4e85a0bd
2 changed files with 148 additions and 38 deletions
+60 -3
View File
@@ -1081,6 +1081,39 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_shadowed_classmethod_decorator_skips_node(self):
source = '''
def classmethod(fn):
def replacement(cls):
return {
"required": {
"mask": ("MASK",),
},
}
return replacement
class ShadowedClassmethodInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"ShadowedClassmethodInputTypesNode": ShadowedClassmethodInputTypesNode,
}
'''
result = self._extract_source(source, "shadowed-classmethod-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_with_present_non_dict_sections_skips_node(self): def test_input_types_with_present_non_dict_sections_skips_node(self):
source = ''' source = '''
class InvalidInputSectionsNode: class InvalidInputSectionsNode:
@@ -1860,6 +1893,30 @@ PatchedReturnTypesNode.RETURN_TYPES = ("MASK",)
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_setattr_patch_after_mapping_skips_node(self):
source = '''
class SetattrPatchedNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"SetattrPatchedNode": SetattrPatchedNode,
}
setattr(SetattrPatchedNode, "RETURN_TYPES", ("MASK",))
'''
result = self._extract_source(source, "setattr-patched-node-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_module_class_input_types_patch_after_mapping_skips_node(self): def test_module_class_input_types_patch_after_mapping_skips_node(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
@@ -2266,7 +2323,7 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_dynamic_display_mapping_reassignment_falls_back_to_node_type(self): def test_dynamic_display_mapping_reassignment_skips_node(self):
source = ''' source = '''
def build_displays(): def build_displays():
return {"DisplayInvalidatedNode": "Dynamic Display"} return {"DisplayInvalidatedNode": "Dynamic Display"}
@@ -2294,8 +2351,8 @@ NODE_DISPLAY_NAME_MAPPINGS = build_displays()
''' '''
result = self._extract_source(source, "dynamic-display-pack") result = self._extract_source(source, "dynamic-display-pack")
self.assertEqual("DisplayInvalidatedNode", result["nodes"]["DisplayInvalidatedNode"]["display"]) self.assertEqual({}, result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_with_dynamic_control_flow_is_skipped(self): def test_input_types_with_dynamic_control_flow_is_skipped(self):
source = ''' source = '''
+88 -35
View File
@@ -59,12 +59,26 @@ def _literal(node, env, allow_mutable_env=True):
return result return result
if isinstance(node, ast.Name) and node.id in env: if isinstance(node, ast.Name) and node.id in env:
value = env[node.id] value = env[node.id]
if value is _INVALID:
raise UnsupportedStaticExpression(f"unsupported env reference {node.id!r}")
if not allow_mutable_env and _is_mutable_static_value(value): if not allow_mutable_env and _is_mutable_static_value(value):
raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported") raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported")
return value return value
raise UnsupportedStaticExpression(type(node).__name__) raise UnsupportedStaticExpression(type(node).__name__)
def _invalidate_env_name(env, name):
if name == "classmethod":
env[name] = _INVALID
else:
env.pop(name, None)
def _invalidate_env_names(env, names):
for name in names:
_invalidate_env_name(env, name)
def _is_mutable_static_value(value): def _is_mutable_static_value(value):
return isinstance(value, (dict, list, set)) return isinstance(value, (dict, list, set))
@@ -108,6 +122,24 @@ def _attribute_target_base_names(target):
return set() return set()
def _setattr_delattr_target_names(node):
if not isinstance(node, ast.Call):
return set()
if not isinstance(node.func, ast.Name) or node.func.id not in {"delattr", "setattr"}:
return set()
if len(node.args) < 2:
return set()
attr = node.args[1]
if (
isinstance(attr, ast.Constant)
and isinstance(attr.value, str)
and attr.value not in _CLASS_SIGNATURE_ATTRS
):
return set()
name = _root_name(node.args[0])
return {name} if name else set()
def _class_attribute_mutation_target_names(stmt): def _class_attribute_mutation_target_names(stmt):
names = set() names = set()
@@ -161,6 +193,7 @@ def _class_attribute_mutation_target_names(stmt):
names.update(_attribute_target_base_names(target)) names.update(_attribute_target_base_names(target))
def visit_Call(self, node): def visit_Call(self, node):
names.update(_setattr_delattr_target_names(node))
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_attribute_target_base_names(node.func.value)) names.update(_attribute_target_base_names(node.func.value))
self.generic_visit(node) self.generic_visit(node)
@@ -328,6 +361,7 @@ def _mutating_call_target_names(stmt):
self.visit(node.args) self.visit(node.args)
def visit_Call(self, node): def visit_Call(self, node):
names.update(_setattr_delattr_target_names(node))
if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS: if isinstance(node.func, ast.Attribute) and node.func.attr in _MUTATING_METHODS:
names.update(_target_names(node.func.value)) names.update(_target_names(node.func.value))
self.generic_visit(node) self.generic_visit(node)
@@ -451,15 +485,14 @@ def _invalidate_class_bindings(class_bindings, names):
def _apply_module_stmt_to_env(stmt, env, class_bindings=None): def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
names = _mutating_call_target_names(stmt) names = _mutating_call_target_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
if isinstance(stmt, ast.ClassDef): if isinstance(stmt, ast.ClassDef):
if class_bindings is not None: if class_bindings is not None:
if stmt.decorator_list: if stmt.decorator_list:
class_bindings.pop(stmt.name, None) class_bindings.pop(stmt.name, None)
else: else:
class_bindings[stmt.name] = (stmt, dict(env)) class_bindings[stmt.name] = (stmt, dict(env))
env.pop(stmt.name, None) _invalidate_env_name(env, stmt.name)
return return
if isinstance(stmt, ast.Assign): if isinstance(stmt, ast.Assign):
names = _assignment_target_names(stmt) names = _assignment_target_names(stmt)
@@ -469,7 +502,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
subscript_root = _mutable_env_subscript_root(stmt.value, env) subscript_root = _mutable_env_subscript_root(stmt.value, env)
if subscript_root is not None: if subscript_root is not None:
env.pop(subscript_root, None) env.pop(subscript_root, None)
env.pop(name, None) _invalidate_env_name(env, name)
return return
if ( if (
isinstance(stmt.value, ast.Name) isinstance(stmt.value, ast.Name)
@@ -477,15 +510,14 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
and _is_mutable_static_value(env[stmt.value.id]) and _is_mutable_static_value(env[stmt.value.id])
): ):
env.pop(stmt.value.id, None) env.pop(stmt.value.id, None)
env.pop(name, None) _invalidate_env_name(env, name)
return return
try: try:
env[name] = _literal(stmt.value, env) env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
env.pop(name, None) _invalidate_env_name(env, name)
else: else:
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if isinstance(stmt, ast.AnnAssign): if isinstance(stmt, ast.AnnAssign):
names = _assignment_target_names(stmt) names = _assignment_target_names(stmt)
@@ -497,7 +529,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
subscript_root = _mutable_env_subscript_root(stmt.value, env) subscript_root = _mutable_env_subscript_root(stmt.value, env)
if subscript_root is not None: if subscript_root is not None:
env.pop(subscript_root, None) env.pop(subscript_root, None)
env.pop(name, None) _invalidate_env_name(env, name)
return return
if ( if (
isinstance(stmt.value, ast.Name) isinstance(stmt.value, ast.Name)
@@ -505,33 +537,29 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
and _is_mutable_static_value(env[stmt.value.id]) and _is_mutable_static_value(env[stmt.value.id])
): ):
env.pop(stmt.value.id, None) env.pop(stmt.value.id, None)
env.pop(name, None) _invalidate_env_name(env, name)
return return
try: try:
env[name] = _literal(stmt.value, env) env[name] = _literal(stmt.value, env)
except UnsupportedStaticExpression: except UnsupportedStaticExpression:
env.pop(name, None) _invalidate_env_name(env, name)
else: else:
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if isinstance(stmt, ast.AugAssign): if isinstance(stmt, ast.AugAssign):
names = _assignment_target_names(stmt) names = _assignment_target_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if isinstance(stmt, ast.Delete): if isinstance(stmt, ast.Delete):
names = _delete_target_names(stmt) names = _delete_target_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
names = _bound_names(stmt) names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if isinstance(stmt, _CONTROL_FLOW_TYPES): if isinstance(stmt, _CONTROL_FLOW_TYPES):
if _has_wildcard_import_in_control_flow(stmt): if _has_wildcard_import_in_control_flow(stmt):
@@ -541,8 +569,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
return return
names = _assigned_names_in_control_flow(stmt) names = _assigned_names_in_control_flow(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
return return
if _has_wildcard_import(stmt): if _has_wildcard_import(stmt):
env.clear() env.clear()
@@ -551,8 +578,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None):
return return
names = _bound_names(stmt) names = _bound_names(stmt)
_invalidate_class_bindings(class_bindings, names) _invalidate_class_bindings(class_bindings, names)
for name in names: _invalidate_env_names(env, names)
env.pop(name, None)
def _collect_module_env(tree, class_bindings=None): def _collect_module_env(tree, class_bindings=None):
@@ -586,8 +612,13 @@ def _mutable_env_subscript_root(node, env):
return None return None
def _input_types_decorators_are_supported(decorators): def _input_types_decorators_are_supported(decorators, classmethod_shadowed):
return all(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in decorators) for decorator in decorators:
if not isinstance(decorator, ast.Name) or decorator.id != "classmethod":
return False
if classmethod_shadowed:
return False
return True
def _class_attr_alias_sources(value, name, aliases): def _class_attr_alias_sources(value, name, aliases):
@@ -721,13 +752,14 @@ def _class_attr(cls, name, env):
return value return value
def _input_types(cls, env): def _input_types(cls, env, decorator_env):
value = _MISSING value = _MISSING
classmethod_shadowed = "classmethod" in decorator_env
for stmt in cls.body: for stmt in cls.body:
if "INPUT_TYPES" in _mutating_call_target_names(stmt): if "INPUT_TYPES" in _mutating_call_target_names(stmt):
value = _INVALID value = _INVALID
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if not _input_types_decorators_are_supported(stmt.decorator_list): if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
value = _INVALID value = _INVALID
continue continue
if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return):
@@ -743,6 +775,13 @@ def _input_types(cls, env):
if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
value = _INVALID value = _INVALID
continue continue
if "classmethod" in (
_assignment_target_names(stmt)
| _delete_target_names(stmt)
| _bound_names(stmt)
| _mutating_call_target_names(stmt)
):
classmethod_shadowed = True
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)): if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
if "INPUT_TYPES" in _assignment_target_names(stmt): if "INPUT_TYPES" in _assignment_target_names(stmt):
value = _INVALID value = _INVALID
@@ -894,7 +933,16 @@ def _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases
class_attribute_aliases[stmt.target.id] = sources class_attribute_aliases[stmt.target.id] = sources
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None): def _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases):
names = _expanded_class_attribute_names(
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
return names
def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None, return_state=False):
value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False)
value = _MISSING value = _MISSING
env = {} env = {}
@@ -905,18 +953,14 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
def advance_module_state(stmt): def advance_module_state(stmt):
_invalidate_class_bindings( _invalidate_class_bindings(
class_bindings, class_bindings,
_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases), _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases),
) )
_apply_module_stmt_to_env(stmt, env, class_bindings) _apply_module_stmt_to_env(stmt, env, class_bindings)
_update_class_aliases(stmt, class_aliases, class_bindings) _update_class_aliases(stmt, class_aliases, class_bindings)
_update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings) _update_class_attribute_aliases(stmt, class_attribute_aliases, class_aliases, class_bindings)
for stmt in tree.body: for stmt in tree.body:
class_attr_names = _expanded_class_attribute_names( class_attr_names = _module_class_attribute_invalidated_names(stmt, class_aliases, class_attribute_aliases)
_class_attribute_mutation_target_names(stmt),
class_aliases,
)
class_attr_names.update(_class_attribute_alias_invalidated_names(stmt, class_attribute_aliases))
if ( if (
value not in (_MISSING, _INVALID) value not in (_MISSING, _INVALID)
and class_attr_names and class_attr_names
@@ -986,6 +1030,8 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N
if name in _bound_names(stmt): if name in _bound_names(stmt):
value = _INVALID value = _INVALID
advance_module_state(stmt) advance_module_state(stmt)
if return_state:
return value
if value in (_MISSING, _INVALID): if value in (_MISSING, _INVALID):
return {} return {}
return value return value
@@ -1022,12 +1068,17 @@ def _display_mappings(tree):
tree, tree,
"NODE_DISPLAY_NAME_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS",
lambda value, env, _class_bindings: _literal(value, env), lambda value, env, _class_bindings: _literal(value, env),
return_state=True,
) )
if displays is _MISSING:
return {}
if displays is _INVALID:
return _INVALID
return {str(k): str(v) for k, v in displays.items()} return {str(k): str(v) for k, v in displays.items()}
def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_env): def _signature_from_class(node_type, cls, display, pack_meta, class_env, input_env):
input_types = _input_types(cls, input_env) input_types = _input_types(cls, input_env, class_env)
return_types = _class_attr(cls, "RETURN_TYPES", class_env) return_types = _class_attr(cls, "RETURN_TYPES", class_env)
return_names = _class_attr(cls, "RETURN_NAMES", class_env) return_names = _class_attr(cls, "RETURN_NAMES", class_env)
if return_types is _INVALID or return_names is _INVALID: if return_types is _INVALID or return_names is _INVALID:
@@ -1097,6 +1148,8 @@ def extract_repo_signatures(repo_dir, pack_meta):
env = _collect_module_env(tree) env = _collect_module_env(tree)
mappings = _node_class_mappings(tree) mappings = _node_class_mappings(tree)
displays = _display_mappings(tree) displays = _display_mappings(tree)
if displays is _INVALID:
continue
for node_type, binding in sorted(mappings.items()): for node_type, binding in sorted(mappings.items()):
cls, class_env = binding cls, class_env = binding
sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env) sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)