diff --git a/db.py b/db.py index 8488fc4..c95f90f 100644 --- a/db.py +++ b/db.py @@ -219,6 +219,9 @@ class ProjectDB: data.setdefault(str_key, 1.0) elif name_key in data and str_key not in data: data[str_key] = 1.0 + # Ensure strength is always a float (JSON may deserialize 1 as int) + if str_key in data: + data[str_key] = float(data[str_key]) return data def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None: @@ -252,6 +255,11 @@ class ProjectDB: return 0 return self.count_sequences(df["id"]) + _FLOAT_KEYS = frozenset( + f'lora {idx} {tier} strength' + for idx in range(1, 4) for tier in ('high', 'low') + ) + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: """Returns (keys, types) for a sequence's data dict.""" data = self.get_sequence(data_file_id, sequence_number) @@ -261,7 +269,9 @@ class ProjectDB: types = [] for k, v in data.items(): keys.append(k) - if isinstance(v, bool): + if k in self._FLOAT_KEYS: + types.append("FLOAT") + elif isinstance(v, bool): types.append("STRING") elif isinstance(v, int): types.append("INT") diff --git a/utils.py b/utils.py index d8d1083..d200b80 100644 --- a/utils.py +++ b/utils.py @@ -168,6 +168,9 @@ def _migrate_lora_keys(data: dict) -> None: item.setdefault(str_key, 1.0) elif name_key in item and str_key not in item: item[str_key] = 1.0 + # Ensure strength is always a float (JSON may deserialize 1 as int) + if str_key in item: + item[str_key] = float(item[str_key]) def load_json(path: str | Path) -> tuple[dict[str, Any], float]: