diff --git a/tests/test_utfcn_core_generated.py b/tests/test_utfcn_core_generated.py index d4e16bd..5923109 100644 --- a/tests/test_utfcn_core_generated.py +++ b/tests/test_utfcn_core_generated.py @@ -1,11 +1,16 @@ import json import tempfile import unittest +from collections import defaultdict from pathlib import Path import utfcn_core +def _empty_generated(): + return {"sigs": {}, "meta": {}, "by_out": defaultdict(list)} + + class GeneratedSignatureLoaderTests(unittest.TestCase): def test_missing_generated_file_returns_empty_indexes(self): with tempfile.TemporaryDirectory() as tmp: @@ -130,6 +135,12 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): "outputs": ["MASK"], "output_names": ["mask"], }, + "CoreImagePassthrough": { + "inputs": {"image": "IMAGE"}, + "required": {"image"}, + "outputs": ["IMAGE"], + "output_names": ["image"], + }, "CuratedTarget": { "inputs": {"image": "IMAGE"}, "required": {"image"}, @@ -140,9 +151,10 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): sources = { "CoreImageSize": {"source": "core", "pack": "nodes", "display": "Core Image Size"}, "CoreMaskInvert": {"source": "core", "pack": "nodes", "display": "Core Mask Invert"}, + "CoreImagePassthrough": {"source": "core", "pack": "nodes", "display": "Core Image Passthrough"}, "CuratedTarget": {"source": "core", "pack": "nodes", "display": "Curated Target"}, } - by_out = utfcn_core.defaultdict(list) + by_out = defaultdict(list) for name, sig in live_sigs.items(): by_out[sig["outputs"][0]].append(name) return { @@ -150,11 +162,11 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): "sigs": live_sigs, "by_out": by_out, "rules": rules or {}, - "generated": generated or utfcn_core._empty_generated_signatures(), + "generated": generated or _empty_generated(), } def test_generated_exact_signature_matches_missing_node_as_verified(self): - generated = utfcn_core._empty_generated_signatures() + generated = _empty_generated() generated["sigs"]["SampleImageSize"] = { "inputs": {"image": "IMAGE"}, "required": {"image"}, @@ -177,7 +189,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): self.assertTrue(result["SampleImageSize"][0]["verified"]) def test_curated_rule_stays_first_before_generated_exact_match(self): - generated = utfcn_core._empty_generated_signatures() + generated = _empty_generated() generated["sigs"]["SampleImageSize"] = { "inputs": {"image": "IMAGE"}, "required": {"image"}, @@ -208,7 +220,7 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): self.assertTrue(result["SampleImageSize"][0]["verified"]) def test_generated_partial_signature_matches_but_is_not_verified(self): - generated = utfcn_core._empty_generated_signatures() + generated = _empty_generated() generated["sigs"]["SampleMaskInvert"] = { "inputs": {"masks": "MASK"}, "required": {"masks"}, @@ -230,6 +242,56 @@ class GeneratedSignatureMatchingTests(unittest.TestCase): self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"]) self.assertFalse(result["SampleMaskInvert"][0]["verified"]) + def test_contradictory_generated_signature_falls_back_to_serialized_signature(self): + generated = _empty_generated() + generated["sigs"]["SampleMaskInvert"] = { + "inputs": {"image": "IMAGE"}, + "required": {"image"}, + "outputs": ["IMAGE"], + "output_names": ["image"], + } + generated["meta"]["SampleMaskInvert"] = { + "source": "generated", + "pack": "sample-pack", + "display": "Sample Mask Invert", + "repository": "https://github.com/example/sample-pack", + "confidence": "static_exact", + } + generated["by_out"]["IMAGE"].append("SampleMaskInvert") + + result = utfcn_core.match( + self._ctx(generated=generated), + [ + { + "type": "SampleMaskInvert", + "inputs": {"masks": "MASK"}, + "outputs": ["MASK"], + "output_names": ["mask"], + } + ], + ) + + self.assertEqual("CoreMaskInvert", result["SampleMaskInvert"][0]["to"]) + self.assertEqual("partial", result["SampleMaskInvert"][0]["tier"]) + self.assertFalse(result["SampleMaskInvert"][0]["verified"]) + + def test_malformed_generated_context_falls_back_without_raising(self): + result = utfcn_core.match( + self._ctx(generated={"sigs": "bad", "meta": "bad"}), + [ + { + "type": "SerializedMaskInvert", + "inputs": {"masks": "MASK"}, + "outputs": ["MASK"], + "output_names": ["mask"], + } + ], + ) + + self.assertEqual("CoreMaskInvert", result["SerializedMaskInvert"][0]["to"]) + self.assertEqual("partial", result["SerializedMaskInvert"][0]["tier"]) + self.assertFalse(result["SerializedMaskInvert"][0]["verified"]) + def test_serialized_signature_fallback_still_handles_unknown_generated_node(self): result = utfcn_core.match( self._ctx(), diff --git a/utfcn_core.py b/utfcn_core.py index f4fff57..1cc1e46 100644 --- a/utfcn_core.py +++ b/utfcn_core.py @@ -362,6 +362,72 @@ def build_index(ctx): return {"sources": sources, "candidates": candidates, "stats": stats} +def _signature_from_item(it): + inputs_raw = it.get("inputs") or {} + if not isinstance(inputs_raw, dict): + inputs_raw = {} + outputs_raw = it.get("outputs") or [] + if not isinstance(outputs_raw, list): + outputs_raw = [] + output_names_raw = it.get("output_names") or [] + if not isinstance(output_names_raw, list): + output_names_raw = [] + + inputs = {str(k): str(v) for k, v in inputs_raw.items() if k is not None} + return { + "inputs": inputs, + "required": set(inputs), + "outputs": [str(x) for x in outputs_raw], + "output_names": [str(x) for x in output_names_raw], + } + + +def _generated_signature_usable(sig): + return isinstance(sig, dict) and isinstance(sig.get("inputs"), dict) and isinstance(sig.get("outputs"), list) + + +def _normalised_generated_signature(sig): + if not _generated_signature_usable(sig): + return None + + try: + inputs = {str(k): str(v) for k, v in sig["inputs"].items() if k is not None} + outputs = [str(x) for x in sig["outputs"]] + required_raw = sig.get("required") or [] + if not isinstance(required_raw, (list, set, tuple)): + required_raw = [] + output_names_raw = sig.get("output_names") or [] + if not isinstance(output_names_raw, list): + output_names_raw = [] + return { + "inputs": inputs, + "required": {str(v) for v in required_raw if str(v) in inputs}, + "outputs": outputs, + "output_names": [str(x) for x in output_names_raw], + } + except Exception: + return None + + +def _generated_signature_conflicts(serialized_sig, generated_sig): + if not serialized_sig["inputs"] and not serialized_sig["outputs"]: + return False + + generated_inputs = generated_sig["inputs"] + generated_input_types = set(generated_inputs.values()) + for name, typ in serialized_sig["inputs"].items(): + if name in generated_inputs: + if generated_inputs[name] != typ: + return True + elif typ not in generated_input_types: + return True + + if Counter(serialized_sig["outputs"]) - Counter(generated_sig["outputs"]): + return True + + return False + + def match(ctx, items): """ Match a batch of nodes given only their (possibly serialized) signature — @@ -375,30 +441,35 @@ def match(ctx, items): Returns a mapping from source node type to candidate list. """ out = {} - generated = ctx.get("generated") or _empty_generated_signatures() + generated = ctx.get("generated") or {} + if not isinstance(generated, dict): + generated = {} generated_sigs = generated.get("sigs") or {} + if not isinstance(generated_sigs, dict): + generated_sigs = {} generated_meta = generated.get("meta") or {} + if not isinstance(generated_meta, dict): + generated_meta = {} for it in items: + if not isinstance(it, dict): + continue t = it.get("type") if not t or t in out: continue - gen_sig = generated_sigs.get(t) - if gen_sig is not None: - gen_pack = (generated_meta.get(t) or {}).get("pack") + sig = _signature_from_item(it) + gen_sig = _normalised_generated_signature(generated_sigs.get(t)) + if gen_sig is not None and not _generated_signature_conflicts(sig, gen_sig): + gen_meta = generated_meta.get(t) or {} + if not isinstance(gen_meta, dict): + gen_meta = {} + gen_pack = gen_meta.get("pack") found = _candidates_for(t, gen_sig, gen_pack, ctx) if found: out[t] = found continue - inputs = {k: str(v) for k, v in (it.get("inputs") or {}).items()} - sig = { - "inputs": inputs, - "required": set(inputs), - "outputs": [str(x) for x in (it.get("outputs") or [])], - "output_names": list(it.get("output_names") or []), - } found = _candidates_for(t, sig, None, ctx) if found: out[t] = found