diff --git a/scripts/tokenizer.py b/scripts/tokenizer.py index df69c16..02963b3 100644 --- a/scripts/tokenizer.py +++ b/scripts/tokenizer.py @@ -21,9 +21,13 @@ css = """ """ -def tokenize(text): +def tokenize(text, input_is_ids=False): clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped - tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] + + if input_is_ids: + tokens = [int(x.strip()) for x in text.split(",")] + else: + tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()} @@ -33,21 +37,41 @@ def tokenize(text): current_ids = [] class_index = 0 - def dump(): - nonlocal code, ids, current_ids, class_index + def dump(last=False): + nonlocal code, ids, current_ids words = [vocab.get(x, "") for x in current_ids] + def wordscode(ids, word): + nonlocal class_index + res = f"""{html.escape(word)}""" + class_index += 1 + return res + try: word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8") except UnicodeDecodeError: - return + if last: + word = "❌" * len(current_ids) + elif len(current_ids) > 4: + id = current_ids[0] + ids += [id] + local_ids = current_ids[1:] + code += wordscode([id], "❌") + + current_ids = [] + for id in local_ids: + current_ids.append(id) + dump() + + return + else: + return word = word.replace("", " ") - code += f"""{html.escape(word)}""" + code += wordscode(current_ids, word) ids += current_ids - class_index += 1 current_ids = [] @@ -57,9 +81,16 @@ def tokenize(text): dump() - dump() + dump(last=True) - return code, ids + ids_html = f""" +
+Token count: {len(ids)}
+{", ".join([str(x) for x in ids])}
+