Harden generated signature matching

This commit is contained in:
2026-07-02 11:47:07 +02:00
parent 38c142d42e
commit 9c17083298
2 changed files with 149 additions and 16 deletions
+67 -5
View File
@@ -1,11 +1,16 @@
import json import json
import tempfile import tempfile
import unittest import unittest
from collections import defaultdict
from pathlib import Path from pathlib import Path
import utfcn_core import utfcn_core
def _empty_generated():
    """Return a fresh, empty generated-signature context for the tests.

    The mapping carries the keys ``sigs``, ``meta`` and ``by_out``;
    ``by_out`` is a ``defaultdict(list)`` so a test can append to a key
    without creating it first.
    """
    by_out = defaultdict(list)
    return {"sigs": {}, "meta": {}, "by_out": by_out}
class GeneratedSignatureLoaderTests(unittest.TestCase): class GeneratedSignatureLoaderTests(unittest.TestCase):
def test_missing_generated_file_returns_empty_indexes(self): def test_missing_generated_file_returns_empty_indexes(self):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
@@ -130,6 +135,12 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
"outputs": ["MASK"], "outputs": ["MASK"],
"output_names": ["mask"], "output_names": ["mask"],
}, },
"CoreImagePassthrough": {
"inputs": {"image": "IMAGE"},
"required": {"image"},
"outputs": ["IMAGE"],
"output_names": ["image"],
},
"CuratedTarget": { "CuratedTarget": {
"inputs": {"image": "IMAGE"}, "inputs": {"image": "IMAGE"},
"required": {"image"}, "required": {"image"},
@@ -140,9 +151,10 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
sources = { sources = {
"CoreImageSize": {"source": "core", "pack": "nodes", "display": "Core Image Size"}, "CoreImageSize": {"source": "core", "pack": "nodes", "display": "Core Image Size"},
"CoreMaskInvert": {"source": "core", "pack": "nodes", "display": "Core Mask Invert"}, "CoreMaskInvert": {"source": "core", "pack": "nodes", "display": "Core Mask Invert"},
"CoreImagePassthrough": {"source": "core", "pack": "nodes", "display": "Core Image Passthrough"},
"CuratedTarget": {"source": "core", "pack": "nodes", "display": "Curated Target"}, "CuratedTarget": {"source": "core", "pack": "nodes", "display": "Curated Target"},
} }
by_out = utfcn_core.defaultdict(list) by_out = defaultdict(list)
for name, sig in live_sigs.items(): for name, sig in live_sigs.items():
by_out[sig["outputs"][0]].append(name) by_out[sig["outputs"][0]].append(name)
return { return {
@@ -150,11 +162,11 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
"sigs": live_sigs, "sigs": live_sigs,
"by_out": by_out, "by_out": by_out,
"rules": rules or {}, "rules": rules or {},
"generated": generated or utfcn_core._empty_generated_signatures(), "generated": generated or _empty_generated(),
} }
def test_generated_exact_signature_matches_missing_node_as_verified(self): def test_generated_exact_signature_matches_missing_node_as_verified(self):
generated = utfcn_core._empty_generated_signatures() generated = _empty_generated()
generated["sigs"]["SampleImageSize"] = { generated["sigs"]["SampleImageSize"] = {
"inputs": {"image": "IMAGE"}, "inputs": {"image": "IMAGE"},
"required": {"image"}, "required": {"image"},
@@ -177,7 +189,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
self.assertTrue(result["SampleImageSize"][0]["verified"]) self.assertTrue(result["SampleImageSize"][0]["verified"])
def test_curated_rule_stays_first_before_generated_exact_match(self): def test_curated_rule_stays_first_before_generated_exact_match(self):
generated = utfcn_core._empty_generated_signatures() generated = _empty_generated()
generated["sigs"]["SampleImageSize"] = { generated["sigs"]["SampleImageSize"] = {
"inputs": {"image": "IMAGE"}, "inputs": {"image": "IMAGE"},
"required": {"image"}, "required": {"image"},
@@ -208,7 +220,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
self.assertTrue(result["SampleImageSize"][0]["verified"]) self.assertTrue(result["SampleImageSize"][0]["verified"])
def test_generated_partial_signature_matches_but_is_not_verified(self): def test_generated_partial_signature_matches_but_is_not_verified(self):
generated = utfcn_core._empty_generated_signatures() generated = _empty_generated()
generated["sigs"]["SampleMaskInvert"] = { generated["sigs"]["SampleMaskInvert"] = {
"inputs": {"masks": "MASK"}, "inputs": {"masks": "MASK"},
"required": {"masks"}, "required": {"masks"},
@@ -230,6 +242,56 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"]) self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"])
self.assertFalse(result["SampleMaskInvert"][0]["verified"]) self.assertFalse(result["SampleMaskInvert"][0]["verified"])
def test_contradictory_generated_signature_falls_back_to_serialized_signature(self):
    """Generated IMAGE->IMAGE data is registered for a node serialized as MASK->MASK.

    The assertions require the first candidate to be ``CoreMaskInvert``
    with tier ``partial`` and ``verified`` false.
    """
    generated = _empty_generated()
    # Generated entry that disagrees with the serialized node below.
    generated["sigs"]["SampleMaskInvert"] = {
        "inputs": {"image": "IMAGE"},
        "required": {"image"},
        "outputs": ["IMAGE"],
        "output_names": ["image"],
    }
    generated["meta"]["SampleMaskInvert"] = {
        "source": "generated",
        "pack": "sample-pack",
        "display": "Sample Mask Invert",
        "repository": "https://github.com/example/sample-pack",
        "confidence": "static_exact",
    }
    generated["by_out"]["IMAGE"].append("SampleMaskInvert")

    serialized_node = {
        "type": "SampleMaskInvert",
        "inputs": {"masks": "MASK"},
        "outputs": ["MASK"],
        "output_names": ["mask"],
    }
    ctx = self._ctx(generated=generated)
    result = utfcn_core.match(ctx, [serialized_node])

    self.assertEqual("CoreMaskInvert", result["SampleMaskInvert"][0]["to"])
    self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"])
    self.assertFalse(result["SampleMaskInvert"][0]["verified"])
def test_malformed_generated_context_falls_back_without_raising(self):
    """``sigs`` and ``meta`` are plain strings instead of mappings.

    The assertions require the call to return and the first candidate to
    be ``CoreMaskInvert`` with tier ``partial`` and ``verified`` false.
    """
    malformed = {"sigs": "bad", "meta": "bad"}
    serialized_node = {
        "type": "SerializedMaskInvert",
        "inputs": {"masks": "MASK"},
        "outputs": ["MASK"],
        "output_names": ["mask"],
    }
    ctx = self._ctx(generated=malformed)
    result = utfcn_core.match(ctx, [serialized_node])

    self.assertEqual("CoreMaskInvert", result["SerializedMaskInvert"][0]["to"])
    self.assertEqual("partial", result["SerializedMaskInvert"][0]["tier"])
    self.assertFalse(result["SerializedMaskInvert"][0]["verified"])
def test_serialized_signature_fallback_still_handles_unknown_generated_node(self): def test_serialized_signature_fallback_still_handles_unknown_generated_node(self):
result = utfcn_core.match( result = utfcn_core.match(
self._ctx(), self._ctx(),
+82 -11
View File
@@ -362,6 +362,72 @@ def build_index(ctx):
return {"sources": sources, "candidates": candidates, "stats": stats} return {"sources": sources, "candidates": candidates, "stats": stats}
def _signature_from_item(it):
    """Build a normalised signature dict from a serialized node item.

    ``it`` must offer ``.get``. The ``inputs`` field is used only when it
    is a non-empty dict, and ``outputs`` / ``output_names`` only when they
    are non-empty lists; anything else is treated as empty. Keys and values
    are coerced with ``str``, and inputs whose key is ``None`` are dropped.

    Returns a dict with ``inputs`` (name -> type), ``required`` (the set of
    all input names), ``outputs`` and ``output_names``.
    """

    def _field(key, kind):
        # Keep the raw value only when it is truthy and of the expected
        # container type; otherwise fall back to an empty container.
        raw = it.get(key)
        if raw and isinstance(raw, kind):
            return raw
        return kind()

    typed_inputs = {}
    for key, value in _field("inputs", dict).items():
        if key is not None:
            typed_inputs[str(key)] = str(value)

    return {
        "inputs": typed_inputs,
        "required": set(typed_inputs),
        "outputs": list(map(str, _field("outputs", list))),
        "output_names": list(map(str, _field("output_names", list))),
    }
def _generated_signature_usable(sig):
    """Return True when ``sig`` has the minimal shape of a signature.

    That means ``sig`` is a dict whose ``inputs`` value is a dict and whose
    ``outputs`` value is a list.
    """
    if not isinstance(sig, dict):
        return False
    has_input_map = isinstance(sig.get("inputs"), dict)
    return has_input_map and isinstance(sig.get("outputs"), list)
def _normalised_generated_signature(sig):
    """Return a string-normalised copy of a generated signature, or None.

    ``None`` is returned when ``sig`` fails ``_generated_signature_usable``
    or when any step of the normalisation raises.

    In the result, input names/types, outputs and output names are coerced
    with ``str``; inputs keyed by ``None`` are dropped. ``required`` is
    honoured only when it is a list, set or tuple, and is restricted to
    names that are present in ``inputs``. ``output_names`` is honoured only
    when it is a list.
    """
    if not _generated_signature_usable(sig):
        return None
    try:
        inputs = {}
        for key, value in sig["inputs"].items():
            if key is not None:
                inputs[str(key)] = str(value)
        outputs = list(map(str, sig["outputs"]))

        required_raw = sig.get("required")
        if not required_raw or not isinstance(required_raw, (list, set, tuple)):
            required_raw = []

        names_raw = sig.get("output_names")
        if not names_raw or not isinstance(names_raw, list):
            names_raw = []

        # Only names that really are inputs may be marked as required.
        required = {name for name in map(str, required_raw) if name in inputs}
        return {
            "inputs": inputs,
            "required": required,
            "outputs": outputs,
            "output_names": list(map(str, names_raw)),
        }
    except Exception:
        # Best-effort: malformed generated data must never abort matching.
        return None
def _generated_signature_conflicts(serialized_sig, generated_sig):
    """Return True when the serialized signature contradicts the generated one.

    A serialized signature with neither inputs nor outputs never conflicts.
    Otherwise a conflict is reported when:

    * a serialized input shares its name with a generated input but the
      types differ;
    * a serialized input has no same-named generated input and its type is
      not the type of any generated input; or
    * the serialized outputs contain some type more often than the
      generated outputs do (multiset comparison).
    """
    seen_inputs = serialized_sig["inputs"]
    seen_outputs = serialized_sig["outputs"]
    if not (seen_inputs or seen_outputs):
        return False

    declared = generated_sig["inputs"]
    declared_types = set(declared.values())
    for name, typ in seen_inputs.items():
        if name in declared:
            mismatch = declared[name] != typ
        else:
            # Input may have been renamed; accept it if the type is known.
            mismatch = typ not in declared_types
        if mismatch:
            return True

    # Counter returns 0 for missing keys, so an unknown output type counts
    # as exceeding what the generated signature provides.
    available = Counter(generated_sig["outputs"])
    needed = Counter(seen_outputs)
    return any(count > available[kind] for kind, count in needed.items())
def match(ctx, items): def match(ctx, items):
""" """
Match a batch of nodes given only their (possibly serialized) signature — Match a batch of nodes given only their (possibly serialized) signature —
@@ -375,30 +441,35 @@ def match(ctx, items):
Returns a mapping from source node type to candidate list. Returns a mapping from source node type to candidate list.
""" """
out = {} out = {}
generated = ctx.get("generated") or _empty_generated_signatures() generated = ctx.get("generated") or {}
if not isinstance(generated, dict):
generated = {}
generated_sigs = generated.get("sigs") or {} generated_sigs = generated.get("sigs") or {}
if not isinstance(generated_sigs, dict):
generated_sigs = {}
generated_meta = generated.get("meta") or {} generated_meta = generated.get("meta") or {}
if not isinstance(generated_meta, dict):
generated_meta = {}
for it in items: for it in items:
if not isinstance(it, dict):
continue
t = it.get("type") t = it.get("type")
if not t or t in out: if not t or t in out:
continue continue
gen_sig = generated_sigs.get(t) sig = _signature_from_item(it)
if gen_sig is not None: gen_sig = _normalised_generated_signature(generated_sigs.get(t))
gen_pack = (generated_meta.get(t) or {}).get("pack") if gen_sig is not None and not _generated_signature_conflicts(sig, gen_sig):
gen_meta = generated_meta.get(t) or {}
if not isinstance(gen_meta, dict):
gen_meta = {}
gen_pack = gen_meta.get("pack")
found = _candidates_for(t, gen_sig, gen_pack, ctx) found = _candidates_for(t, gen_sig, gen_pack, ctx)
if found: if found:
out[t] = found out[t] = found
continue continue
inputs = {k: str(v) for k, v in (it.get("inputs") or {}).items()}
sig = {
"inputs": inputs,
"required": set(inputs),
"outputs": [str(x) for x in (it.get("outputs") or [])],
"output_names": list(it.get("output_names") or []),
}
found = _candidates_for(t, sig, None, ctx) found = _candidates_for(t, sig, None, ctx)
if found: if found:
out[t] = found out[t] = found