diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index bb0d381..5216ed5 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -29,6 +29,12 @@ class StaticExtractionTests(unittest.TestCase): }, ) + def _skip_if_syntax_unsupported(self, source): + try: + compile(textwrap.dedent(source), "", "exec") + except SyntaxError as exc: + self.skipTest(f"syntax unsupported by this Python: {exc.msg}") + def test_normalise_input_spec_reduces_combo_lists(self): self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],))) self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",))) @@ -563,6 +569,99 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_except_handler_binding_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +try: + pass +except Exception as INPUTS: + pass + + +class ExceptHandlerBoundInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ExceptHandlerBoundInputEnvNode": ExceptHandlerBoundInputEnvNode, +} +''' + result = self._extract_source(source, "except-handler-bound-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_trystar_assignment_invalidates_static_env_value(self): + source = ''' +def build_inputs(): + return {"required": {"mask": ("MASK",)}} + + +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +try: + pass +except* RuntimeError: + INPUTS = build_inputs() + + +class TryStarRebindInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "TryStarRebindInputEnvNode": TryStarRebindInputEnvNode, +} +''' + self._skip_if_syntax_unsupported(source) + result = self._extract_source(source, "trystar-rebind-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_type_alias_binding_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +type INPUTS = int + + +class TypeAliasBoundInputEnvNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "TypeAliasBoundInputEnvNode": TypeAliasBoundInputEnvNode, +} +''' + self._skip_if_syntax_unsupported(source) + result = self._extract_source(source, "type-alias-bound-input-env-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_delete_invalidates_static_env_value(self): source = ''' INPUTS = { @@ -1023,6 +1122,33 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_except_handler_binding_to_return_types_skips_node(self): + source = ''' +class ExceptHandlerBoundReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + try: + pass + except Exception as RETURN_TYPES: + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "ExceptHandlerBoundReturnTypesNode": ExceptHandlerBoundReturnTypesNode, +} +''' + result = self._extract_source(source, "except-handler-bound-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_alias_mutation_skips_node(self): source = ''' class AliasMutatedReturnTypesNode: @@ -1597,6 +1723,31 @@ ALIAS.clear() self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_type_alias_binding_invalidates_static_node_mapping(self): + source = ''' +class TypeAliasBoundMappingNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "TypeAliasBoundMappingNode": TypeAliasBoundMappingNode, +} +type NODE_CLASS_MAPPINGS = dict +''' + self._skip_if_syntax_unsupported(source) + result = self._extract_source(source, "type-alias-bound-mapping-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_wildcard_import_invalidates_static_node_mapping(self): source = ''' class WildcardImportMappingNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index f1266f4..51bbd5e 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -33,6 +33,9 @@ _MUTATING_METHODS = { "sort", "update", } +_CONTROL_FLOW_TYPES = (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match) +if hasattr(ast, "TryStar"): + _CONTROL_FLOW_TYPES += (ast.TryStar,) def _literal(node, env, allow_mutable_env=True): @@ -133,6 +136,8 @@ def _bound_names(stmt): names = set() if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): names.add(stmt.name) + elif hasattr(ast, "TypeAlias") and isinstance(stmt, ast.TypeAlias): + names.update(_target_names(stmt.name)) elif isinstance(stmt, ast.Import): for alias in stmt.names: names.add(alias.asname or alias.name.split(".", 1)[0]) @@ -147,6 +152,9 @@ def _bound_names(stmt): elif isinstance(stmt, ast.Match): for case in stmt.cases: names.update(_pattern_bound_names(case.pattern)) + elif isinstance(stmt, ast.ExceptHandler): + if stmt.name: + names.add(stmt.name) names.update(_named_expr_target_names(stmt)) return names @@ -254,6 +262,13 @@ def _assigned_names_in_control_flow(stmt): def visit_Delete(self, node): names.update(_delete_target_names(node)) + def visit_ExceptHandler(self, node): + names.update(_bound_names(node)) + self.generic_visit(node) + + def visit_TypeAlias(self, node): + names.update(_bound_names(node)) + def visit_Expr(self, node): names.update(_mutating_call_target_names(node)) names.update(_named_expr_target_names(node)) @@ -312,7 +327,7 @@ def _has_module_wildcard_import(tree): for stmt in tree.body: if _has_wildcard_import(stmt): return True - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if isinstance(stmt, _CONTROL_FLOW_TYPES): if _has_wildcard_import_in_control_flow(stmt): return True return False @@ -397,7 +412,7 @@ def _apply_module_stmt_to_env(stmt, env, class_bindings=None): for name in names: env.pop(name, None) return - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if isinstance(stmt, _CONTROL_FLOW_TYPES): if _has_wildcard_import_in_control_flow(stmt): env.clear() if class_bindings is not None: @@ -530,7 +545,7 @@ def _class_attr(cls, name, env): if name in _bound_names(stmt): value = _INVALID continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if isinstance(stmt, _CONTROL_FLOW_TYPES): target_names = _assigned_names_in_control_flow(stmt) if aliases.intersection(target_names): value = _INVALID @@ -581,7 +596,7 @@ def _input_types(cls, env): if "INPUT_TYPES" in _bound_names(stmt): value = _INVALID continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if isinstance(stmt, _CONTROL_FLOW_TYPES): if "INPUT_TYPES" in _assigned_names_in_control_flow(stmt): value = _INVALID if _has_wildcard_import_in_control_flow(stmt): @@ -674,7 +689,7 @@ def _final_module_dict(tree, name, value_converter): value = _INVALID _apply_module_stmt_to_env(stmt, env, class_bindings) continue - if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if isinstance(stmt, _CONTROL_FLOW_TYPES): if name in _assigned_names_in_control_flow(stmt): value = _INVALID if _has_wildcard_import_in_control_flow(stmt):