Harden generated signature matching
This commit is contained in:
@@ -1,11 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import utfcn_core
|
import utfcn_core
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_generated():
|
||||||
|
return {"sigs": {}, "meta": {}, "by_out": defaultdict(list)}
|
||||||
|
|
||||||
|
|
||||||
class GeneratedSignatureLoaderTests(unittest.TestCase):
|
class GeneratedSignatureLoaderTests(unittest.TestCase):
|
||||||
def test_missing_generated_file_returns_empty_indexes(self):
|
def test_missing_generated_file_returns_empty_indexes(self):
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
@@ -130,6 +135,12 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
"outputs": ["MASK"],
|
"outputs": ["MASK"],
|
||||||
"output_names": ["mask"],
|
"output_names": ["mask"],
|
||||||
},
|
},
|
||||||
|
"CoreImagePassthrough": {
|
||||||
|
"inputs": {"image": "IMAGE"},
|
||||||
|
"required": {"image"},
|
||||||
|
"outputs": ["IMAGE"],
|
||||||
|
"output_names": ["image"],
|
||||||
|
},
|
||||||
"CuratedTarget": {
|
"CuratedTarget": {
|
||||||
"inputs": {"image": "IMAGE"},
|
"inputs": {"image": "IMAGE"},
|
||||||
"required": {"image"},
|
"required": {"image"},
|
||||||
@@ -140,9 +151,10 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
sources = {
|
sources = {
|
||||||
"CoreImageSize": {"source": "core", "pack": "nodes", "display": "Core Image Size"},
|
"CoreImageSize": {"source": "core", "pack": "nodes", "display": "Core Image Size"},
|
||||||
"CoreMaskInvert": {"source": "core", "pack": "nodes", "display": "Core Mask Invert"},
|
"CoreMaskInvert": {"source": "core", "pack": "nodes", "display": "Core Mask Invert"},
|
||||||
|
"CoreImagePassthrough": {"source": "core", "pack": "nodes", "display": "Core Image Passthrough"},
|
||||||
"CuratedTarget": {"source": "core", "pack": "nodes", "display": "Curated Target"},
|
"CuratedTarget": {"source": "core", "pack": "nodes", "display": "Curated Target"},
|
||||||
}
|
}
|
||||||
by_out = utfcn_core.defaultdict(list)
|
by_out = defaultdict(list)
|
||||||
for name, sig in live_sigs.items():
|
for name, sig in live_sigs.items():
|
||||||
by_out[sig["outputs"][0]].append(name)
|
by_out[sig["outputs"][0]].append(name)
|
||||||
return {
|
return {
|
||||||
@@ -150,11 +162,11 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
"sigs": live_sigs,
|
"sigs": live_sigs,
|
||||||
"by_out": by_out,
|
"by_out": by_out,
|
||||||
"rules": rules or {},
|
"rules": rules or {},
|
||||||
"generated": generated or utfcn_core._empty_generated_signatures(),
|
"generated": generated or _empty_generated(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_generated_exact_signature_matches_missing_node_as_verified(self):
|
def test_generated_exact_signature_matches_missing_node_as_verified(self):
|
||||||
generated = utfcn_core._empty_generated_signatures()
|
generated = _empty_generated()
|
||||||
generated["sigs"]["SampleImageSize"] = {
|
generated["sigs"]["SampleImageSize"] = {
|
||||||
"inputs": {"image": "IMAGE"},
|
"inputs": {"image": "IMAGE"},
|
||||||
"required": {"image"},
|
"required": {"image"},
|
||||||
@@ -177,7 +189,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
self.assertTrue(result["SampleImageSize"][0]["verified"])
|
self.assertTrue(result["SampleImageSize"][0]["verified"])
|
||||||
|
|
||||||
def test_curated_rule_stays_first_before_generated_exact_match(self):
|
def test_curated_rule_stays_first_before_generated_exact_match(self):
|
||||||
generated = utfcn_core._empty_generated_signatures()
|
generated = _empty_generated()
|
||||||
generated["sigs"]["SampleImageSize"] = {
|
generated["sigs"]["SampleImageSize"] = {
|
||||||
"inputs": {"image": "IMAGE"},
|
"inputs": {"image": "IMAGE"},
|
||||||
"required": {"image"},
|
"required": {"image"},
|
||||||
@@ -208,7 +220,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
self.assertTrue(result["SampleImageSize"][0]["verified"])
|
self.assertTrue(result["SampleImageSize"][0]["verified"])
|
||||||
|
|
||||||
def test_generated_partial_signature_matches_but_is_not_verified(self):
|
def test_generated_partial_signature_matches_but_is_not_verified(self):
|
||||||
generated = utfcn_core._empty_generated_signatures()
|
generated = _empty_generated()
|
||||||
generated["sigs"]["SampleMaskInvert"] = {
|
generated["sigs"]["SampleMaskInvert"] = {
|
||||||
"inputs": {"masks": "MASK"},
|
"inputs": {"masks": "MASK"},
|
||||||
"required": {"masks"},
|
"required": {"masks"},
|
||||||
@@ -230,6 +242,56 @@ class GeneratedSignatureMatchingTests(unittest.TestCase):
|
|||||||
self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"])
|
self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"])
|
||||||
self.assertFalse(result["SampleMaskInvert"][0]["verified"])
|
self.assertFalse(result["SampleMaskInvert"][0]["verified"])
|
||||||
|
|
||||||
|
def test_contradictory_generated_signature_falls_back_to_serialized_signature(self):
|
||||||
|
generated = _empty_generated()
|
||||||
|
generated["sigs"]["SampleMaskInvert"] = {
|
||||||
|
"inputs": {"image": "IMAGE"},
|
||||||
|
"required": {"image"},
|
||||||
|
"outputs": ["IMAGE"],
|
||||||
|
"output_names": ["image"],
|
||||||
|
}
|
||||||
|
generated["meta"]["SampleMaskInvert"] = {
|
||||||
|
"source": "generated",
|
||||||
|
"pack": "sample-pack",
|
||||||
|
"display": "Sample Mask Invert",
|
||||||
|
"repository": "https://github.com/example/sample-pack",
|
||||||
|
"confidence": "static_exact",
|
||||||
|
}
|
||||||
|
generated["by_out"]["IMAGE"].append("SampleMaskInvert")
|
||||||
|
|
||||||
|
result = utfcn_core.match(
|
||||||
|
self._ctx(generated=generated),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "SampleMaskInvert",
|
||||||
|
"inputs": {"masks": "MASK"},
|
||||||
|
"outputs": ["MASK"],
|
||||||
|
"output_names": ["mask"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual("CoreMaskInvert", result["SampleMaskInvert"][0]["to"])
|
||||||
|
self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"])
|
||||||
|
self.assertFalse(result["SampleMaskInvert"][0]["verified"])
|
||||||
|
|
||||||
|
def test_malformed_generated_context_falls_back_without_raising(self):
|
||||||
|
result = utfcn_core.match(
|
||||||
|
self._ctx(generated={"sigs": "bad", "meta": "bad"}),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "SerializedMaskInvert",
|
||||||
|
"inputs": {"masks": "MASK"},
|
||||||
|
"outputs": ["MASK"],
|
||||||
|
"output_names": ["mask"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual("CoreMaskInvert", result["SerializedMaskInvert"][0]["to"])
|
||||||
|
self.assertEqual("partial", result["SerializedMaskInvert"][0]["tier"])
|
||||||
|
self.assertFalse(result["SerializedMaskInvert"][0]["verified"])
|
||||||
|
|
||||||
def test_serialized_signature_fallback_still_handles_unknown_generated_node(self):
|
def test_serialized_signature_fallback_still_handles_unknown_generated_node(self):
|
||||||
result = utfcn_core.match(
|
result = utfcn_core.match(
|
||||||
self._ctx(),
|
self._ctx(),
|
||||||
|
|||||||
+82
-11
@@ -362,6 +362,72 @@ def build_index(ctx):
|
|||||||
return {"sources": sources, "candidates": candidates, "stats": stats}
|
return {"sources": sources, "candidates": candidates, "stats": stats}
|
||||||
|
|
||||||
|
|
||||||
|
def _signature_from_item(it):
|
||||||
|
inputs_raw = it.get("inputs") or {}
|
||||||
|
if not isinstance(inputs_raw, dict):
|
||||||
|
inputs_raw = {}
|
||||||
|
outputs_raw = it.get("outputs") or []
|
||||||
|
if not isinstance(outputs_raw, list):
|
||||||
|
outputs_raw = []
|
||||||
|
output_names_raw = it.get("output_names") or []
|
||||||
|
if not isinstance(output_names_raw, list):
|
||||||
|
output_names_raw = []
|
||||||
|
|
||||||
|
inputs = {str(k): str(v) for k, v in inputs_raw.items() if k is not None}
|
||||||
|
return {
|
||||||
|
"inputs": inputs,
|
||||||
|
"required": set(inputs),
|
||||||
|
"outputs": [str(x) for x in outputs_raw],
|
||||||
|
"output_names": [str(x) for x in output_names_raw],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _generated_signature_usable(sig):
|
||||||
|
return isinstance(sig, dict) and isinstance(sig.get("inputs"), dict) and isinstance(sig.get("outputs"), list)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalised_generated_signature(sig):
|
||||||
|
if not _generated_signature_usable(sig):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
inputs = {str(k): str(v) for k, v in sig["inputs"].items() if k is not None}
|
||||||
|
outputs = [str(x) for x in sig["outputs"]]
|
||||||
|
required_raw = sig.get("required") or []
|
||||||
|
if not isinstance(required_raw, (list, set, tuple)):
|
||||||
|
required_raw = []
|
||||||
|
output_names_raw = sig.get("output_names") or []
|
||||||
|
if not isinstance(output_names_raw, list):
|
||||||
|
output_names_raw = []
|
||||||
|
return {
|
||||||
|
"inputs": inputs,
|
||||||
|
"required": {str(v) for v in required_raw if str(v) in inputs},
|
||||||
|
"outputs": outputs,
|
||||||
|
"output_names": [str(x) for x in output_names_raw],
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generated_signature_conflicts(serialized_sig, generated_sig):
|
||||||
|
if not serialized_sig["inputs"] and not serialized_sig["outputs"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
generated_inputs = generated_sig["inputs"]
|
||||||
|
generated_input_types = set(generated_inputs.values())
|
||||||
|
for name, typ in serialized_sig["inputs"].items():
|
||||||
|
if name in generated_inputs:
|
||||||
|
if generated_inputs[name] != typ:
|
||||||
|
return True
|
||||||
|
elif typ not in generated_input_types:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if Counter(serialized_sig["outputs"]) - Counter(generated_sig["outputs"]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def match(ctx, items):
|
def match(ctx, items):
|
||||||
"""
|
"""
|
||||||
Match a batch of nodes given only their (possibly serialized) signature —
|
Match a batch of nodes given only their (possibly serialized) signature —
|
||||||
@@ -375,30 +441,35 @@ def match(ctx, items):
|
|||||||
Returns a mapping from source node type to candidate list.
|
Returns a mapping from source node type to candidate list.
|
||||||
"""
|
"""
|
||||||
out = {}
|
out = {}
|
||||||
generated = ctx.get("generated") or _empty_generated_signatures()
|
generated = ctx.get("generated") or {}
|
||||||
|
if not isinstance(generated, dict):
|
||||||
|
generated = {}
|
||||||
generated_sigs = generated.get("sigs") or {}
|
generated_sigs = generated.get("sigs") or {}
|
||||||
|
if not isinstance(generated_sigs, dict):
|
||||||
|
generated_sigs = {}
|
||||||
generated_meta = generated.get("meta") or {}
|
generated_meta = generated.get("meta") or {}
|
||||||
|
if not isinstance(generated_meta, dict):
|
||||||
|
generated_meta = {}
|
||||||
|
|
||||||
for it in items:
|
for it in items:
|
||||||
|
if not isinstance(it, dict):
|
||||||
|
continue
|
||||||
t = it.get("type")
|
t = it.get("type")
|
||||||
if not t or t in out:
|
if not t or t in out:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
gen_sig = generated_sigs.get(t)
|
sig = _signature_from_item(it)
|
||||||
if gen_sig is not None:
|
gen_sig = _normalised_generated_signature(generated_sigs.get(t))
|
||||||
gen_pack = (generated_meta.get(t) or {}).get("pack")
|
if gen_sig is not None and not _generated_signature_conflicts(sig, gen_sig):
|
||||||
|
gen_meta = generated_meta.get(t) or {}
|
||||||
|
if not isinstance(gen_meta, dict):
|
||||||
|
gen_meta = {}
|
||||||
|
gen_pack = gen_meta.get("pack")
|
||||||
found = _candidates_for(t, gen_sig, gen_pack, ctx)
|
found = _candidates_for(t, gen_sig, gen_pack, ctx)
|
||||||
if found:
|
if found:
|
||||||
out[t] = found
|
out[t] = found
|
||||||
continue
|
continue
|
||||||
|
|
||||||
inputs = {k: str(v) for k, v in (it.get("inputs") or {}).items()}
|
|
||||||
sig = {
|
|
||||||
"inputs": inputs,
|
|
||||||
"required": set(inputs),
|
|
||||||
"outputs": [str(x) for x in (it.get("outputs") or [])],
|
|
||||||
"output_names": list(it.get("output_names") or []),
|
|
||||||
}
|
|
||||||
found = _candidates_for(t, sig, None, ctx)
|
found = _candidates_for(t, sig, None, ctx)
|
||||||
if found:
|
if found:
|
||||||
out[t] = found
|
out[t] = found
|
||||||
|
|||||||
Reference in New Issue
Block a user