diff --git a/tests/test_generate_popular_node_signatures.py b/tests/test_generate_popular_node_signatures.py index a7abe41..3c42548 100644 --- a/tests/test_generate_popular_node_signatures.py +++ b/tests/test_generate_popular_node_signatures.py @@ -590,6 +590,34 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_nested_mutable_env_literal_skips_static_node(self): + source = ''' +REQ = { + "image": ("IMAGE",), +} +INPUTS = { + "required": REQ, +} +REQ.clear() + + +class NestedMutableEnvLiteralNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return INPUTS + + +NODE_CLASS_MAPPINGS = { + "NestedMutableEnvLiteralNode": NestedMutableEnvLiteralNode, +} +''' + result = self._extract_source(source, "nested-mutable-env-literal-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_post_class_input_reassignment_skips_static_node(self): source = ''' def build_inputs(): @@ -650,6 +678,36 @@ NODE_CLASS_MAPPINGS = { self.assertIn("LiteralInputTypesAfterEnvChangeNode", result["nodes"]) self.assertEqual("ok", result["pack"]["status"]) + def test_later_dynamic_input_types_binding_skips_node(self): + source = ''' +def build_inputs(): + return {"required": {"mask": ("MASK",)}} + + +class LaterDynamicInputTypesNode: + RETURN_TYPES = ("IMAGE",) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + def INPUT_TYPES(cls): + return build_inputs() + + +NODE_CLASS_MAPPINGS = { + "LaterDynamicInputTypesNode": LaterDynamicInputTypesNode, +} +''' + result = self._extract_source(source, "later-dynamic-input-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_dynamic_return_types_reassignment_skips_node(self): source = ''' def build_outputs(): @@ -827,6 +885,32 @@ NODE_CLASS_MAPPINGS = { self.assertEqual({}, result["nodes"]) self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_return_types_transitive_alias_mutation_skips_node(self): + source = ''' +class TransitiveAliasMutatedReturnTypesNode: + RETURN_TYPES = ["IMAGE"] + A = RETURN_TYPES + B = A + B.clear() + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + } + + +NODE_CLASS_MAPPINGS = { + "TransitiveAliasMutatedReturnTypesNode": TransitiveAliasMutatedReturnTypesNode, +} +''' + result = self._extract_source(source, "transitive-alias-mutated-return-types-pack") + + self.assertEqual({}, result["nodes"]) + self.assertEqual("no_static_nodes", result["pack"]["status"]) + def test_class_return_types_uses_definition_time_module_env(self): source = ''' RETURNS = ("IMAGE",) diff --git a/tools/generate_popular_node_signatures.py b/tools/generate_popular_node_signatures.py index 9583e74..b14b59c 100644 --- a/tools/generate_popular_node_signatures.py +++ b/tools/generate_popular_node_signatures.py @@ -35,22 +35,25 @@ _MUTATING_METHODS = { } -def _literal(node, env): +def _literal(node, env, allow_mutable_env=True): if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.List): - return [_literal(item, env) for item in node.elts] + return [_literal(item, env, allow_mutable_env=False) for item in node.elts] if isinstance(node, ast.Tuple): - return tuple(_literal(item, env) for item in node.elts) + return tuple(_literal(item, env, allow_mutable_env=False) for item in node.elts) if isinstance(node, ast.Dict): result = {} for key, value in zip(node.keys, node.values): if key is None: raise UnsupportedStaticExpression("dict unpacking is not supported") - result[_literal(key, env)] = _literal(value, env) + result[_literal(key, env, allow_mutable_env=False)] = _literal(value, env, allow_mutable_env=False) return result if isinstance(node, ast.Name) and node.id in env: - return env[node.id] + value = env[node.id] + if not allow_mutable_env and _is_mutable_static_value(value): + raise UnsupportedStaticExpression(f"mutable env reference {node.id!r} is not supported") + return value raise UnsupportedStaticExpression(type(node).__name__) @@ -413,7 +416,7 @@ def _class_attr(cls, name, env): and isinstance(stmt.targets[0], ast.Name) and stmt.targets[0].id != name and isinstance(stmt.value, ast.Name) - and stmt.value.id == name + and (stmt.value.id == name or stmt.value.id in aliases) ): aliases.add(stmt.targets[0].id) continue @@ -439,7 +442,7 @@ def _class_attr(cls, name, env): isinstance(stmt.target, ast.Name) and stmt.target.id != name and isinstance(stmt.value, ast.Name) - and stmt.value.id == name + and (stmt.value.id == name or stmt.value.id in aliases) ): aliases.add(stmt.target.id) continue @@ -505,17 +508,47 @@ def _class_attr(cls, name, env): def _input_types(cls, env): + value = _MISSING for stmt in cls.body: - if not isinstance(stmt, ast.FunctionDef) or stmt.name != "INPUT_TYPES": + if isinstance(stmt, ast.FunctionDef) and stmt.name == "INPUT_TYPES": + if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): + value = _INVALID + continue + try: + candidate = _literal(stmt.body[0].value, env) + except UnsupportedStaticExpression: + value = _INVALID + continue + value = candidate if isinstance(candidate, dict) else _INVALID continue - if len(stmt.body) != 1 or not isinstance(stmt.body[0], ast.Return): - return None - try: - value = _literal(stmt.body[0].value, env) - except UnsupportedStaticExpression: - return None - return value if isinstance(value, dict) else None - return None + if isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "INPUT_TYPES": + value = _INVALID + continue + if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)): + if "INPUT_TYPES" in _assignment_target_names(stmt): + value = _INVALID + continue + if isinstance(stmt, ast.Delete): + if "INPUT_TYPES" in _delete_target_names(stmt): + value = _INVALID + continue + if isinstance(stmt, ast.Expr): + if "INPUT_TYPES" in _mutating_call_target_names(stmt): + value = _INVALID + if "INPUT_TYPES" in _bound_names(stmt): + value = _INVALID + continue + if isinstance(stmt, (ast.If, ast.For, ast.AsyncFor, ast.While, ast.Try, ast.With, ast.AsyncWith, ast.Match)): + if "INPUT_TYPES" in _assigned_names_in_control_flow(stmt): + value = _INVALID + if _has_wildcard_import_in_control_flow(stmt): + value = _INVALID + continue + if "INPUT_TYPES" in _bound_names(stmt): + value = _INVALID + if value in (_MISSING, _INVALID): + return None + return value def _mapping_value_name(value):