Harden static signature extraction
This commit is contained in:
@@ -12,6 +12,10 @@ from tools.generate_popular_node_signatures import (
|
|||||||
|
|
||||||
|
|
||||||
class StaticExtractionTests(unittest.TestCase):
|
class StaticExtractionTests(unittest.TestCase):
|
||||||
|
def _normalise_generated_at(self, text):
|
||||||
|
parsed = json.loads(text)
|
||||||
|
return text.replace(parsed["generated_at"], "<generated-at>")
|
||||||
|
|
||||||
def test_normalise_input_spec_reduces_combo_lists(self):
|
def test_normalise_input_spec_reduces_combo_lists(self):
|
||||||
self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],)))
|
self.assertEqual("COMBO", normalise_input_spec((["nearest", "bilinear"],)))
|
||||||
self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",)))
|
self.assertEqual("IMAGE", normalise_input_spec(("IMAGE",)))
|
||||||
@@ -99,15 +103,105 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
self.assertEqual({}, result["nodes"])
|
self.assertEqual({}, result["nodes"])
|
||||||
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_skips_unparseable_python_files_and_extracts_static_nodes(self):
|
||||||
|
good_source = '''
|
||||||
|
class GoodNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"GoodNode": GoodNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
Path(tmp, "bad.py").write_bytes(b"class Bad:\xff\n")
|
||||||
|
Path(tmp, "good.py").write_text(textwrap.dedent(good_source), encoding="utf-8")
|
||||||
|
result = extract_repo_signatures(
|
||||||
|
Path(tmp),
|
||||||
|
{
|
||||||
|
"id": "mixed-pack",
|
||||||
|
"title": "Mixed Pack",
|
||||||
|
"repository": "https://github.com/example/mixed-pack",
|
||||||
|
"rank": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("GoodNode", result["nodes"])
|
||||||
|
self.assertEqual("ok", result["pack"]["status"])
|
||||||
|
|
||||||
|
def test_unsupported_reassignment_invalidates_static_env_value(self):
|
||||||
|
source = '''
|
||||||
|
def build_inputs():
|
||||||
|
return {"required": {"image": ("IMAGE",)}}
|
||||||
|
|
||||||
|
|
||||||
|
INPUTS = {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
INPUTS = build_inputs()
|
||||||
|
|
||||||
|
|
||||||
|
class StaleEnvNode:
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return INPUTS
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"StaleEnvNode": StaleEnvNode,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
Path(tmp, "__init__.py").write_text(textwrap.dedent(source), encoding="utf-8")
|
||||||
|
result = extract_repo_signatures(
|
||||||
|
Path(tmp),
|
||||||
|
{
|
||||||
|
"id": "stale-env-pack",
|
||||||
|
"title": "Stale Env Pack",
|
||||||
|
"repository": "https://github.com/example/stale-env-pack",
|
||||||
|
"rank": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual({}, result["nodes"])
|
||||||
|
self.assertEqual("no_static_nodes", result["pack"]["status"])
|
||||||
|
|
||||||
def test_write_artifact_is_deterministic(self):
|
def test_write_artifact_is_deterministic(self):
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
out = Path(tmp, "popular_node_signatures.json")
|
out_one = Path(tmp, "one.json")
|
||||||
|
out_two = Path(tmp, "two.json")
|
||||||
write_artifact(
|
write_artifact(
|
||||||
out,
|
out_one,
|
||||||
sources={"manager_url": "https://example.invalid/manager.json", "limit": 1},
|
sources={
|
||||||
|
"manager_url": "https://example.invalid/manager.json",
|
||||||
|
"limit": 1,
|
||||||
|
"registry": {"z": "last", "a": "first"},
|
||||||
|
},
|
||||||
packs={
|
packs={
|
||||||
"b-pack": {"id": "b-pack", "title": "B Pack", "status": "ok"},
|
"b-pack": {
|
||||||
"a-pack": {"id": "a-pack", "title": "A Pack", "status": "ok"},
|
"id": "b-pack",
|
||||||
|
"title": "B Pack",
|
||||||
|
"status": "ok",
|
||||||
|
"metadata": {"z": 2, "a": 1},
|
||||||
|
},
|
||||||
|
"a-pack": {
|
||||||
|
"id": "a-pack",
|
||||||
|
"title": "A Pack",
|
||||||
|
"status": "ok",
|
||||||
|
"metadata": {"z": 4, "a": 3},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
nodes={
|
nodes={
|
||||||
"BNode": {
|
"BNode": {
|
||||||
@@ -115,7 +209,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"display": "B Node",
|
"display": "B Node",
|
||||||
"pack": "b-pack",
|
"pack": "b-pack",
|
||||||
"repository": "https://github.com/example/b-pack",
|
"repository": "https://github.com/example/b-pack",
|
||||||
"inputs": {},
|
"inputs": {"zeta": "FLOAT", "alpha": "IMAGE"},
|
||||||
"required": [],
|
"required": [],
|
||||||
"outputs": ["IMAGE"],
|
"outputs": ["IMAGE"],
|
||||||
"output_names": ["image"],
|
"output_names": ["image"],
|
||||||
@@ -126,7 +220,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"display": "A Node",
|
"display": "A Node",
|
||||||
"pack": "a-pack",
|
"pack": "a-pack",
|
||||||
"repository": "https://github.com/example/a-pack",
|
"repository": "https://github.com/example/a-pack",
|
||||||
"inputs": {},
|
"inputs": {"zeta": "FLOAT", "alpha": "IMAGE"},
|
||||||
"required": [],
|
"required": [],
|
||||||
"outputs": ["IMAGE"],
|
"outputs": ["IMAGE"],
|
||||||
"output_names": ["image"],
|
"output_names": ["image"],
|
||||||
@@ -134,10 +228,59 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
parsed = json.loads(out.read_text(encoding="utf-8"))
|
write_artifact(
|
||||||
|
out_two,
|
||||||
|
sources={
|
||||||
|
"registry": {"a": "first", "z": "last"},
|
||||||
|
"limit": 1,
|
||||||
|
"manager_url": "https://example.invalid/manager.json",
|
||||||
|
},
|
||||||
|
packs={
|
||||||
|
"a-pack": {
|
||||||
|
"metadata": {"a": 3, "z": 4},
|
||||||
|
"status": "ok",
|
||||||
|
"title": "A Pack",
|
||||||
|
"id": "a-pack",
|
||||||
|
},
|
||||||
|
"b-pack": {
|
||||||
|
"metadata": {"a": 1, "z": 2},
|
||||||
|
"status": "ok",
|
||||||
|
"title": "B Pack",
|
||||||
|
"id": "b-pack",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nodes={
|
||||||
|
"ANode": {
|
||||||
|
"confidence": "static_exact",
|
||||||
|
"output_names": ["image"],
|
||||||
|
"outputs": ["IMAGE"],
|
||||||
|
"required": [],
|
||||||
|
"inputs": {"alpha": "IMAGE", "zeta": "FLOAT"},
|
||||||
|
"repository": "https://github.com/example/a-pack",
|
||||||
|
"pack": "a-pack",
|
||||||
|
"display": "A Node",
|
||||||
|
"type": "ANode",
|
||||||
|
},
|
||||||
|
"BNode": {
|
||||||
|
"confidence": "static_exact",
|
||||||
|
"output_names": ["image"],
|
||||||
|
"outputs": ["IMAGE"],
|
||||||
|
"required": [],
|
||||||
|
"inputs": {"alpha": "IMAGE", "zeta": "FLOAT"},
|
||||||
|
"repository": "https://github.com/example/b-pack",
|
||||||
|
"pack": "b-pack",
|
||||||
|
"display": "B Node",
|
||||||
|
"type": "BNode",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
text_one = out_one.read_text(encoding="utf-8")
|
||||||
|
text_two = out_two.read_text(encoding="utf-8")
|
||||||
|
parsed = json.loads(text_one)
|
||||||
|
|
||||||
self.assertEqual(["a-pack", "b-pack"], list(parsed["packs"]))
|
self.assertEqual(["a-pack", "b-pack"], list(parsed["packs"]))
|
||||||
self.assertEqual(["ANode", "BNode"], list(parsed["nodes"]))
|
self.assertEqual(["ANode", "BNode"], list(parsed["nodes"]))
|
||||||
|
self.assertEqual(self._normalise_generated_at(text_one), self._normalise_generated_at(text_two))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -42,9 +42,11 @@ def _collect_module_env(tree):
|
|||||||
continue
|
continue
|
||||||
if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
|
if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
|
||||||
continue
|
continue
|
||||||
|
name = stmt.targets[0].id
|
||||||
try:
|
try:
|
||||||
env[stmt.targets[0].id] = _literal(stmt.value, env)
|
env[name] = _literal(stmt.value, env)
|
||||||
except UnsupportedStaticExpression:
|
except UnsupportedStaticExpression:
|
||||||
|
env.pop(name, None)
|
||||||
continue
|
continue
|
||||||
return env
|
return env
|
||||||
|
|
||||||
@@ -181,14 +183,23 @@ def _python_files(repo_dir):
|
|||||||
yield Path(root, filename)
|
yield Path(root, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_file(path):
|
||||||
|
try:
|
||||||
|
return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
try:
|
||||||
|
return ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
|
||||||
|
except SyntaxError:
|
||||||
|
return None
|
||||||
|
except SyntaxError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_repo_signatures(repo_dir, pack_meta):
|
def extract_repo_signatures(repo_dir, pack_meta):
|
||||||
nodes = {}
|
nodes = {}
|
||||||
for path in sorted(_python_files(repo_dir)):
|
for path in sorted(_python_files(repo_dir)):
|
||||||
try:
|
tree = _parse_python_file(path)
|
||||||
tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
if tree is None:
|
||||||
except UnicodeDecodeError:
|
|
||||||
tree = ast.parse(path.read_text(encoding="utf-8", errors="ignore"), filename=str(path))
|
|
||||||
except SyntaxError:
|
|
||||||
continue
|
continue
|
||||||
env = _collect_module_env(tree)
|
env = _collect_module_env(tree)
|
||||||
mappings = _node_class_mappings(tree, env)
|
mappings = _node_class_mappings(tree, env)
|
||||||
@@ -213,13 +224,21 @@ def extract_repo_signatures(repo_dir, pack_meta):
|
|||||||
return {"pack": pack, "nodes": nodes}
|
return {"pack": pack, "nodes": nodes}
|
||||||
|
|
||||||
|
|
||||||
|
def _sorted_json_value(value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {key: _sorted_json_value(value[key]) for key in sorted(value)}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_sorted_json_value(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def write_artifact(path, sources, packs, nodes):
|
def write_artifact(path, sources, packs, nodes):
|
||||||
payload = {
|
payload = {
|
||||||
"schema_version": SCHEMA_VERSION,
|
"schema_version": SCHEMA_VERSION,
|
||||||
"generated_at": datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z"),
|
"generated_at": datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z"),
|
||||||
"sources": sources,
|
"sources": _sorted_json_value(sources),
|
||||||
"packs": {key: packs[key] for key in sorted(packs)},
|
"packs": _sorted_json_value(packs),
|
||||||
"nodes": {key: nodes[key] for key in sorted(nodes)},
|
"nodes": _sorted_json_value(nodes),
|
||||||
}
|
}
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
path.write_text(json.dumps(payload, indent=2, sort_keys=False) + "\n", encoding="utf-8")
|
path.write_text(json.dumps(payload, indent=2, sort_keys=False) + "\n", encoding="utf-8")
|
||||||
|
|||||||
Reference in New Issue
Block a user