diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index f1d81dd..3ba51ea 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1730,6 +1730,110 @@ PatchedInputTypesNode.INPUT_TYPES = build_inputs self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_duplicate_node_mapping_key_with_dynamic_value_skips_node(self): + source = ''' +def build_node(): + return object() + + +class DuplicateMappingKeyNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DuplicateMappingKeyNode": DuplicateMappingKeyNode, + "DuplicateMappingKeyNode": build_node(), +} +''' + result = self._extract_source(source, "duplicate-mapping-key-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_module_class_alias_patch_after_mapping_skips_node(self): + source = ''' +class AliasPatchedNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "AliasPatchedNode": AliasPatchedNode, +} +Alias = AliasPatchedNode +Alias.RETURN_TYPES = ("MASK",) +''' + result = self._extract_source(source, "alias-patched-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_definition_time_class_attribute_mutation_after_mapping_skips_node(self): + source = ''' +class DefinitionTimeMutatedMappedNode: + RETURN_TYPES = ["IMAGE"] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DefinitionTimeMutatedMappedNode": DefinitionTimeMutatedMappedNode, +} +def helper(x=DefinitionTimeMutatedMappedNode.RETURN_TYPES.clear()): + pass +''' + result = self._extract_source(source, "definition-time-mutated-mapped-node-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_unhashable_node_mapping_key_skips_repo_without_raising(self): + source = ''' +KEY = ["UnhashableMappingKeyNode"] + + +class UnhashableMappingKeyNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + KEY: UnhashableMappingKeyNode, +} +''' + result = self._extract_source(source, "unhashable-mapping-key-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_mutated_node_class_mapping_skips_node(self): source = ''' class MutatedMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index ebc7f85..b8ea88b 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -107,17 +107,35 @@ def _class_attribute_mutation_target_names(stmt): names = set() class AttributeMutationVisitor(ast.NodeVisitor): + def _visit_function_definition_expressions(self, node): + for decorator in node.decorator_list: + self.visit(decorator) + self.visit(node.args) + if node.returns is not None: + self.visit(node.returns) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + def visit_FunctionDef(self, node): - return None + self._visit_function_definition_expressions(node) def visit_AsyncFunctionDef(self, node): - return None + self._visit_function_definition_expressions(node) def visit_ClassDef(self, node): - return None + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for keyword in node.keywords: + self.visit(keyword.value) + for type_param in getattr(node, "type_params", ()): + self.visit(type_param) + for child in node.body: + self.visit(child) def visit_Lambda(self, node): - return None + self.visit(node.args) def visit_Assign(self, node): for target in node.targets: @@ -698,20 +716,73 @@ def _module_dict_entries(node, env, class_bindings, value_converter): for key, value in zip(node.keys, node.values): if key is None: raise UnsupportedStaticExpression("dict unpacking is not supported") + key_value = _literal(key, env) + try: + hash(key_value) + except TypeError as exc: + raise UnsupportedStaticExpression("unhashable dict key") from exc + if key_value in result: + raise UnsupportedStaticExpression("duplicate dict key") converted_value = value_converter(value, env, class_bindings) if converted_value is None: - continue - result[_literal(key, env)] = converted_value + raise UnsupportedStaticExpression("unsupported dict value") + result[key_value] = converted_value return result +def _class_alias_sources(value, class_aliases, class_bindings): + if not isinstance(value, ast.Name): + return set() + if value.id in class_aliases: + return set(class_aliases[value.id]) + if value.id in class_bindings: + return {value.id} + return set() + + +def _update_class_aliases(stmt, class_aliases, class_bindings): + rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) + for name in rebound_names: + class_aliases.pop(name, None) + + if isinstance(stmt, ast.ClassDef): + if stmt.name in class_bindings and not stmt.decorator_list: + class_aliases[stmt.name] = {stmt.name} + return + + if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + if sources: + class_aliases[stmt.targets[0].id] = sources + elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.value is not None: + sources = _class_alias_sources(stmt.value, class_aliases, class_bindings) + if sources: + class_aliases[stmt.target.id] = sources + + +def _expanded_class_attribute_names(names, class_aliases): + expanded = set(names) + for name in names: + expanded.update(class_aliases.get(name, ())) + return expanded + + def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=None): value_invalidated_by_names = value_invalidated_by_names or (lambda _value, _names: False) value = _MISSING env = {} class_bindings = {} + class_aliases = {} + + def advance_module_state(stmt): + _apply_module_stmt_to_env(stmt, env, class_bindings) + _update_class_aliases(stmt, class_aliases, class_bindings) + for stmt in tree.body: - class_attr_names = _class_attribute_mutation_target_names(stmt) + class_attr_names = _expanded_class_attribute_names( + _class_attribute_mutation_target_names(stmt), + class_aliases, + ) if ( value not in (_MISSING, _INVALID) and class_attr_names @@ -724,7 +795,7 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: @@ -733,13 +804,13 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N value = _INVALID else: value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt, ast.AnnAssign): if not _name_is_assigned(stmt, name): if isinstance(stmt.value, ast.Name) and stmt.value.id == name: value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt.target, ast.Name) and stmt.value is not None: try: @@ -748,39 +819,39 @@ def _final_module_dict(tree, name, value_converter, value_invalidated_by_names=N value = _INVALID else: value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt, ast.AugAssign): if _name_is_assigned(stmt, name): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt, ast.Delete): if name in _delete_target_names(stmt): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt, ast.Expr): if name in _mutating_call_target_names(stmt): value = _INVALID if name in _bound_names(stmt): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if isinstance(stmt, _CONTROL_FLOW_TYPES): if name in _assigned_names_in_control_flow(stmt): value = _INVALID if _has_wildcard_import_in_control_flow(stmt): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if _has_wildcard_import(stmt): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) continue if name in _bound_names(stmt): value = _INVALID - _apply_module_stmt_to_env(stmt, env, class_bindings) + advance_module_state(stmt) if value in (_MISSING, _INVALID): return {} return value