Fail closed on duplicate nodes and observed input types

This commit is contained in:
2026-07-02 18:08:48 +02:00
parent 3219ec0c39
commit bf46f9b389
2 changed files with 175 additions and 4 deletions
@@ -209,6 +209,57 @@ NODE_CLASS_MAPPINGS = {
self.assertIn("GoodUtf8Node", result["nodes"]) self.assertIn("GoodUtf8Node", result["nodes"])
self.assertEqual("ok", result["pack"]["status"]) self.assertEqual("ok", result["pack"]["status"])
def test_duplicate_node_ids_across_files_are_skipped(self):
source_a = '''
class FirstDupNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DupNode": FirstDupNode,
}
'''
source_b = '''
class SecondDupNode:
RETURN_TYPES = ("MASK",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
NODE_CLASS_MAPPINGS = {
"DupNode": SecondDupNode,
}
'''
with tempfile.TemporaryDirectory() as tmp:
Path(tmp, "a.py").write_text(textwrap.dedent(source_a), encoding="utf-8")
Path(tmp, "b.py").write_text(textwrap.dedent(source_b), encoding="utf-8")
result = extract_repo_signatures(
Path(tmp),
{
"id": "duplicate-node-pack",
"title": "Duplicate Node Pack",
"repository": "https://github.com/example/duplicate-node-pack",
"rank": 1,
},
)
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_unsupported_reassignment_invalidates_static_env_value(self): def test_unsupported_reassignment_invalidates_static_env_value(self):
source = ''' source = '''
def build_inputs(): def build_inputs():
@@ -3391,6 +3442,57 @@ NODE_CLASS_MAPPINGS = {
self.assertEqual({}, result["nodes"]) self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"]) self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_observed_by_arbitrary_call_skips_node(self):
source = '''
class ObservedInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
observe(INPUT_TYPES)
NODE_CLASS_MAPPINGS = {
"ObservedInputTypesNode": ObservedInputTypesNode,
}
'''
result = self._extract_source(source, "observed-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_input_types_alias_observed_by_arbitrary_call_skips_node(self):
source = '''
class AliasObservedInputTypesNode:
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
}
ALIAS = INPUT_TYPES
observe(ALIAS)
NODE_CLASS_MAPPINGS = {
"AliasObservedInputTypesNode": AliasObservedInputTypesNode,
}
'''
result = self._extract_source(source, "alias-observed-input-types-pack")
self.assertEqual({}, result["nodes"])
self.assertEqual("no_static_nodes", result["pack"]["status"])
def test_write_artifact_is_deterministic(self): def test_write_artifact_is_deterministic(self):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
out_one = Path(tmp, "one.json") out_one = Path(tmp, "one.json")
+73 -4
View File
@@ -955,6 +955,26 @@ def _update_class_attr_aliases_from_unpack(target, value, name, aliases):
return found return found
def _input_types_alias_sources(value, aliases):
if isinstance(value, ast.Name):
return value.id == "INPUT_TYPES" or value.id in aliases
if isinstance(value, (ast.Tuple, ast.List)):
return any(_input_types_alias_sources(item, aliases) for item in value.elts)
return False
def _update_input_types_aliases_from_unpack(target, value, aliases):
found = False
for target_item, value_item in _unpack_target_value_pairs(target, value):
target_name = _alias_target_name(target_item)
if target_name is None:
continue
if _input_types_alias_sources(value_item, aliases):
aliases.add(target_name)
found = True
return found
def _class_attr(cls, name, env): def _class_attr(cls, name, env):
value = _MISSING value = _MISSING
aliases = set() aliases = set()
@@ -1074,9 +1094,14 @@ def _class_attr(cls, name, env):
def _input_types(cls, env, decorator_env): def _input_types(cls, env, decorator_env):
value = _MISSING value = _MISSING
aliases = set()
classmethod_shadowed = "classmethod" in decorator_env classmethod_shadowed = "classmethod" in decorator_env
for stmt in cls.body: for stmt in cls.body:
if "INPUT_TYPES" in _mutating_call_target_names(stmt): mutating_targets = _mutating_call_target_names(stmt)
observed_targets = _arbitrary_call_observed_names(stmt)
if "INPUT_TYPES" in mutating_targets or aliases.intersection(mutating_targets):
value = _INVALID
if "INPUT_TYPES" in observed_targets or aliases.intersection(observed_targets):
value = _INVALID value = _INVALID
if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES":
if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed):
@@ -1095,14 +1120,48 @@ def _input_types(cls, env, decorator_env):
if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES":
value = _INVALID value = _INVALID
continue continue
rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt)
aliases.difference_update(rebound_names)
if "classmethod" in ( if "classmethod" in (
_assignment_target_names(stmt) _assignment_target_names(stmt)
| _delete_target_names(stmt) | _delete_target_names(stmt)
| _bound_names(stmt) | _bound_names(stmt)
| _mutating_call_target_names(stmt) | mutating_targets
): ):
classmethod_shadowed = True classmethod_shadowed = True
if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)): if isinstance(stmt, ast.Assign):
target_names = _assignment_target_names(stmt)
if (
len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and stmt.targets[0].id != "INPUT_TYPES"
and _input_types_alias_sources(stmt.value, aliases)
):
aliases.add(stmt.targets[0].id)
continue
if (
len(stmt.targets) == 1
and "INPUT_TYPES" not in target_names
and _update_input_types_aliases_from_unpack(stmt.targets[0], stmt.value, aliases)
):
continue
if "INPUT_TYPES" in target_names:
value = _INVALID
continue
if isinstance(stmt, ast.AnnAssign):
target_names = _assignment_target_names(stmt)
if (
isinstance(stmt.target, ast.Name)
and stmt.target.id != "INPUT_TYPES"
and stmt.value is not None
and _input_types_alias_sources(stmt.value, aliases)
):
aliases.add(stmt.target.id)
continue
if "INPUT_TYPES" in target_names:
value = _INVALID
continue
if isinstance(stmt, ast.AugAssign):
if "INPUT_TYPES" in _assignment_target_names(stmt): if "INPUT_TYPES" in _assignment_target_names(stmt):
value = _INVALID value = _INVALID
continue continue
@@ -1111,7 +1170,7 @@ def _input_types(cls, env, decorator_env):
value = _INVALID value = _INVALID
continue continue
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
if "INPUT_TYPES" in _mutating_call_target_names(stmt): if "INPUT_TYPES" in mutating_targets:
value = _INVALID value = _INVALID
if "INPUT_TYPES" in _bound_names(stmt): if "INPUT_TYPES" in _bound_names(stmt):
value = _INVALID value = _INVALID
@@ -1602,6 +1661,8 @@ def _parse_python_file(path):
def extract_repo_signatures(repo_dir, pack_meta): def extract_repo_signatures(repo_dir, pack_meta):
nodes = {} nodes = {}
node_sources = {}
duplicate_node_types = set()
for path in sorted(_python_files(repo_dir)): for path in sorted(_python_files(repo_dir)):
tree = _parse_python_file(path) tree = _parse_python_file(path)
if tree is None: if tree is None:
@@ -1612,6 +1673,14 @@ def extract_repo_signatures(repo_dir, pack_meta):
if displays is _INVALID: if displays is _INVALID:
continue continue
for node_type, binding in sorted(mappings.items()): for node_type, binding in sorted(mappings.items()):
prior_path = node_sources.get(node_type)
if prior_path is not None and prior_path != path:
duplicate_node_types.add(node_type)
nodes.pop(node_type, None)
continue
node_sources.setdefault(node_type, path)
if node_type in duplicate_node_types:
continue
cls, class_env = binding cls, class_env = binding
sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env) sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env, env)
if sig is not None: if sig is not None: