diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index 2096c0f..4bb286c 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -16,6 +16,19 @@ class StaticExtractionTests(unittest.TestCase): parsed = json.loads(text) return text.replace(parsed["generated_at"], "") + def _extract_source(self, source, pack_id="sample-pack"): + with tempfile.TemporaryDirectory() as tmp: + Path(tmp, "__init__.py").write_text(textwrap.dedent(source), encoding="utf-8") + return extract_repo_signatures( + Path(tmp), + { + "id": pack_id, + "title": "Sample Pack", + "repository": f"https://github.com/example/{pack_id}", + "rank": 1, + }, + ) + def test_normalise_input_spec_reduces_combo_lists(self): self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],))) self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",))) @@ -178,6 +191,179 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_annotated_reassignment_invalidates_static_env_value(self): + source = ''' +def build_inputs(): + return {"required": {"image": ("IMAGE",)}} + + +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +INPUTS: dict = build_inputs() + + +class AnnotatedRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "AnnotatedRebindNode": AnnotatedRebindNode, +} +''' + result = self._extract_source(source, "annotated-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_multi_target_reassignment_invalidates_static_env_value(self): + source = ''' +def build_inputs(): + return {"required": {"image": ("IMAGE",)}} + + +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +OTHER = INPUTS = build_inputs() + + +class MultiTargetRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "MultiTargetRebindNode": MultiTargetRebindNode, +} +''' + result = self._extract_source(source, "multi-target-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_augmented_assignment_invalidates_static_env_value(self): + source = ''' +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +INPUTS += ({},) + + +class AugmentedRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "AugmentedRebindNode": AugmentedRebindNode, +} +''' + result = self._extract_source(source, "augmented-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_control_flow_assignment_invalidates_static_env_value(self): + source = ''' +def build_inputs(): + return {"required": {"image": ("IMAGE",)}} + + +INPUTS = { + "required": { + "image": ("IMAGE",), + }, +} +if True: + INPUTS = build_inputs() + + +class ControlFlowRebindNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "ControlFlowRebindNode": ControlFlowRebindNode, +} +''' + result = self._extract_source(source, "control-flow-rebind-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_dynamic_return_types_reassignment_skips_node(self): + source = ''' +def build_outputs(): + return ("MASK",) + + +class DynamicReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + RETURN_TYPES = build_outputs() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "DynamicReturnTypesNode": DynamicReturnTypesNode, +} +''' + result = self._extract_source(source, "dynamic-return-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + + def test_final_static_return_types_assignment_wins(self): + source = ''' +class FinalReturnTypesNode: + RETURN_TYPES = ("IMAGE",) + RETURN_TYPES = ("MASK",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "FinalReturnTypesNode": FinalReturnTypesNode, +} +''' + result = self._extract_source(source, "final-return-pack") + + self.assertEqual(["MASK"], result["nodes"]["FinalReturnTypesNode"]["outputs"]) + self.assertEqual("ok", result["pack"]["status"]) + def test_write_artifact_is_deterministic(self): with tempfile.TemporaryDirectory() as tmp: out_one = Path(tmp, "one.json") diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index f90c56f..9e904ac 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -16,6 +16,10 @@ class UnsupportedStaticExpression(Exception): pass +_MISSING = object() +_INVALID = object() + + def _literal(node, env): if isinstance(node, ast.Constant): return node.value @@ -35,19 +39,104 @@ def _literal(node, env): raise UnsupportedStaticExpression(type(node).__name__) +def _target_names(target): + if isinstance(target, ast.Name): + return {target.id} + if isinstance(target, (ast.List, ast.Tuple)): + names = set() + for item in target.elts: + names.update(_target_names(item)) + return names + if isinstance(target, ast.Starred): + return _target_names(target.value) + if isinstance(target, (ast.Attribute, ast.Subscript)): + return _target_names(target.value) + return set() + + +def _assignment_target_names(stmt): + if isinstance(stmt, ast.Assign): + names = set() + for target in stmt.targets: + names.update(_target_names(target)) + return names + if isinstance(stmt, (ast.AnnAssign, ast.AugAssign)): + return _target_names(stmt.target) + if isinstance(stmt, (ast.For, ast.AsyncFor)): + return _target_names(stmt.target) + return set() + + +def _assigned_names_in_control_flow(stmt): + names = set() + + class AssignmentVisitor(ast.NodeVisitor): + def visit_FunctionDef(self, node): + return None + + def visit_AsyncFunctionDef(self, node): + return None + + def visit_ClassDef(self, node): + return None + + def visit_Assign(self, node): + names.update(_assignment_target_names(node)) + + def visit_AnnAssign(self, node): + names.update(_assignment_target_names(node)) + + def visit_AugAssign(self, node): + names.update(_assignment_target_names(node)) + + def visit_For(self, node): + names.update(_assignment_target_names(node)) + self.generic_visit(node) + + def visit_AsyncFor(self, node): + names.update(_assignment_target_names(node)) + self.generic_visit(node) + + AssignmentVisitor().visit(stmt) + return names + + def _collect_module_env(tree): env = {} for stmt in tree.body: - if not isinstance(stmt, ast.Assign): + if isinstance(stmt, ast.Assign): + names = _assignment_target_names(stmt) + if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): + name = stmt.targets[0].id + try: + env[name] = _literal(stmt.value, env) + except UnsupportedStaticExpression: + env.pop(name, None) + else: + for name in names: + env.pop(name, None) continue - if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name): + if isinstance(stmt, ast.AnnAssign): + names = _assignment_target_names(stmt) + if stmt.value is None: + continue + if isinstance(stmt.target, ast.Name): + name = stmt.target.id + try: + env[name] = _literal(stmt.value, env) + except UnsupportedStaticExpression: + env.pop(name, None) + else: + for name in names: + env.pop(name, None) continue - name = stmt.targets[0].id - try: - env[name] = _literal(stmt.value, env) - except UnsupportedStaticExpression: - env.pop(name, None) + if isinstance(stmt, ast.AugAssign): + for name in _assignment_target_names(stmt): + env.pop(name, None) continue + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): + for name in _assigned_names_in_control_flow(stmt): + env.pop(name, None) return env @@ -63,16 +152,39 @@ def _class_defs(tree): def _class_attr(cls, name, env): + value = _MISSING for stmt in cls.body: - if not isinstance(stmt, ast.Assign): - continue - for target in stmt.targets: - if isinstance(target, ast.Name) and target.id == name: + if isinstance(stmt, ast.Assign): + if not any(isinstance(target, ast.Name) and target.id == name for target in stmt.targets): + continue + if len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name): try: - return _literal(stmt.value, env) + value = _literal(stmt.value, env) except UnsupportedStaticExpression: - return None - return None + value = _INVALID + else: + value = _INVALID + continue + if isinstance(stmt, ast.AnnAssign): + if not isinstance(stmt.target, ast.Name) or stmt.target.id != name: + continue + if stmt.value is None: + continue + try: + value = _literal(stmt.value, env) + except UnsupportedStaticExpression: + value = _INVALID + continue + if isinstance(stmt, ast.AugAssign): + if isinstance(stmt.target, ast.Name) and stmt.target.id == name: + value = _INVALID + continue + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try)): + if name in _assigned_names_in_control_flow(stmt): + value = _INVALID + if value in (_MISSING, _INVALID): + return None + return value def _input_types(cls, env):