diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 2ec63dd..62cbb02 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -666,6 +666,58 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_return_types_uses_definition_time_module_env(self): + source = ''' +RETURNS = ("IMAGE",) + + +class DefinitionTimeReturnTypesNode: + RETURN_TYPES = RETURNS + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +RETURNS = ("MASK",) + +NODE_CLASS_MAPPINGS = { + "DefinitionTimeReturnTypesNode": DefinitionTimeReturnTypesNode, +} +''' + result = self._extract_source(source, "definition-time-return-pack") + + self.assertEqual(["IMAGE"], result["nodes"]["DefinitionTimeReturnTypesNode"]["outputs"]) + self.assertEqual("ok", result["pack"]["status"]) + + def test_subscript_assignment_to_return_types_skips_node(self): + source = ''' +class SubscriptMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + RETURN_TYPES[0] = "MASK" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "SubscriptMutatedReturnTypesNode": SubscriptMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "subscript-mutated-return-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_final_static_return_types_assignment_wins(self): source = ''' class FinalReturnTypesNode: diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index f674bbb..12f9dd0 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -283,9 +283,11 @@ def _has_module_wildcard_import(tree): return False -def _collect_module_env(tree): +def _collect_module_env(tree, class_envs=None): env = {} for stmt in tree.body: + if class_envs is not None and isinstance(stmt, ast.ClassDef): + class_envs[stmt.name] = dict(env) if isinstance(stmt, ast.Assign): names = _assignment_target_names(stmt) if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): @@ -372,7 +374,7 @@ def _class_attr(cls, name, env): value = _MISSING for stmt in cls.body: if isinstance(stmt, ast.Assign): - if not any(isinstance(target, ast.Name) and target.id == name for target in stmt.targets): + if name not in _assignment_target_names(stmt): continue if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: @@ -383,17 +385,20 @@ def _class_attr(cls, name, env): value = _INVALID continue if isinstance(stmt, ast.AnnAssign): - if not isinstance(stmt.target, ast.Name) or stmt.target.id != name: + if name not in _assignment_target_names(stmt): continue - if stmt.value is None: + if isinstance(stmt.target, ast.Name) and stmt.value is None: continue - try: - value = _literal(stmt.value, env) - except UnsupportedStaticExpression: + if not isinstance(stmt.target, ast.Name): value = _INVALID + else: + try: + value = _literal(stmt.value, env) + except UnsupportedStaticExpression: + value = _INVALID continue if isinstance(stmt, ast.AugAssign): - if isinstance(stmt.target, ast.Name) and stmt.target.id == name: + if name in _assignment_target_names(stmt): value = _INVALID continue if isinstance(stmt, ast.Delete): @@ -589,7 +594,8 @@ def extract_repo_signatures(repo_dir, pack_meta): tree = _parse_python_file(path) if tree is None: continue - env = _collect_module_env(tree) + class_envs = {} + env = _collect_module_env(tree, class_envs) mappings = _node_class_mappings(tree, env) displays = _display_mappings(tree, env) classes = _class_defs(tree) @@ -597,7 +603,8 @@ def extract_repo_signatures(repo_dir, pack_meta): cls = classes.get(class_name) if cls is None: continue - sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, env) + class_env = class_envs.get(class_name, env) + sig = _signature_from_class(node_type, cls, displays.get(node_type), pack_meta, class_env) if sig is not None: nodes[node_type] = sig