diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 2d3599f..59947d6 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -1960,6 +1960,80 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_bare_return_types_reference_skips_node(self): + source = ''' +class BareReturnTypesReferenceNode: + RETURN_TYPES = ("IMAGE",) + RETURN_TYPES + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "BareReturnTypesReferenceNode": BareReturnTypesReferenceNode, +} +''' + result = self._extract_source(source, "bare-return-types-reference-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_bare_return_types_alias_reference_skips_node(self): + source = ''' +class BareReturnTypesAliasReferenceNode: + RETURN_TYPES = ("IMAGE",) + ALIAS = RETURN_TYPES + ALIAS + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "BareReturnTypesAliasReferenceNode": BareReturnTypesAliasReferenceNode, +} +''' + result = self._extract_source(source, "bare-return-types-alias-reference-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_bare_return_names_reference_skips_node(self): + source = ''' +class BareReturnNamesReferenceNode: + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image",) + RETURN_NAMES + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "BareReturnNamesReferenceNode": BareReturnNamesReferenceNode, +} +''' + result = self._extract_source(source, "bare-return-names-reference-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_chained_alias_mutation_skips_node(self): source = ''' class ChainedAliasMutatedReturnTypesNode: @@ -4595,6 +4669,56 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_bare_input_types_reference_skips_node(self): + source = ''' +class BareInputTypesReferenceNode: + RETURN_TYPES = ("IMAGE",) + INPUT_TYPES + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "BareInputTypesReferenceNode": BareInputTypesReferenceNode, +} +''' + result = self._extract_source(source, "bare-input-types-reference-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_bare_input_types_alias_reference_skips_node(self): + source = ''' +class BareInputTypesAliasReferenceNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + ALIAS = INPUT_TYPES + ALIAS + + +NODE_CLASS_MAPPINGS = { + "BareInputTypesAliasReferenceNode": BareInputTypesAliasReferenceNode, +} +''' + result = self._extract_source(source, "bare-input-types-alias-reference-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_input_types_default_observed_by_arbitrary_call_skips_node(self): source = ''' class DefaultObservedInputTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index bad507d..c9bdc24 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -677,6 +677,44 @@ def _definition_time_referenced_names(stmt): return names +def _class_body_expression_referenced_names(stmt): + if not isinstance(stmt, ast.Expr): + return set() + + names = set() + + class ClassBodyExpressionReferenceVisitor(ast.NodeVisitor): + def visit_Call(self, child): + name = _namespace_lookup_name(child) + if name is not None: + names.add(name) + self.generic_visit(child) + + def visit_Subscript(self, child): + name = _namespace_subscript_name(child) + if name is not None: + names.add(name) + self.generic_visit(child) + + def visit_Lambda(self, child): + self.visit(child.args) + + def visit_FunctionDef(self, child): + return None + + def visit_AsyncFunctionDef(self, child): + return None + + def visit_ClassDef(self, child): + return None + + def visit_Name(self, child): + names.add(child.id) + + ClassBodyExpressionReferenceVisitor().visit(stmt.value) + return names + + def _assigned_names_in_control_flow(stmt): names = _mutating_call_target_names(stmt) | _arbitrary_call_observed_names(stmt) @@ -1182,6 +1220,7 @@ def _class_attr(cls, name, env): for stmt in cls.body: mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) + expression_references = _class_body_expression_referenced_names(stmt) if aliases.intersection(mutating_targets): value = _INVALID if name in mutating_targets: @@ -1190,6 +1229,10 @@ def _class_attr(cls, name, env): value = _INVALID if name in observed_targets: value = _INVALID + if aliases.intersection(expression_references): + value = _INVALID + if name in expression_references: + value = _INVALID if isinstance(stmt, ast.Assign): target_names = _assignment_target_names(stmt) if len(stmt.targets) > 1 and _class_attr_alias_sources(stmt.value, name, aliases): @@ -1307,6 +1350,7 @@ def _class_attr(cls, name, env): def _input_types(cls, env, decorator_env): value = _MISSING + sticky_invalid = False aliases = set() classmethod_shadowed = "classmethod" in decorator_env namespace_mutations = _class_body_namespace_mutation_names(cls) @@ -1316,6 +1360,7 @@ def _input_types(cls, env, decorator_env): mutating_targets = _mutating_call_target_names(stmt) observed_targets = _arbitrary_call_observed_names(stmt) definition_references = _definition_time_referenced_names(stmt) + expression_references = _class_body_expression_referenced_names(stmt) protected_definition_references = _CLASS_SIGNATURE_ATTRS | aliases input_types_invalidated = ( "INPUT_TYPES" in mutating_targets @@ -1323,27 +1368,37 @@ def _input_types(cls, env, decorator_env): or "INPUT_TYPES" in observed_targets or bool(aliases.intersection(observed_targets)) or bool(definition_references.intersection(protected_definition_references)) + or bool(expression_references.intersection(protected_definition_references)) ) if input_types_invalidated: value = _INVALID + sticky_invalid = True if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": - if input_types_invalidated: + if input_types_invalidated or sticky_invalid: continue if not _input_types_decorators_are_supported(stmt.decorator_list, classmethod_shadowed): value = _INVALID + sticky_invalid = True continue if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): value = _INVALID + sticky_invalid = True continue try: candidate = _literal(stmt.body[0].value, env) except UnsupportedStaticExpression: value = _INVALID + sticky_invalid = True continue - value = candidate if isinstance(candidate, dict) else _INVALID + if isinstance(candidate, dict): + value = candidate + else: + value = _INVALID + sticky_invalid = True continue if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": value = _INVALID + sticky_invalid = True continue rebound_names = _assignment_target_names(stmt) | _delete_target_names(stmt) | _bound_names(stmt) aliases.difference_update(rebound_names)