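"""TokenViz: a Streamlit app that runs a Hugging Face AutoTokenizer over input
text and visualizes the resulting tokens, color-grouping WordPiece subwords.

Run with e.g. `streamlit run app.py` (filename assumed; not given in the source).
Assumed dependencies: streamlit, pandas, transformers, plus fugashi and
unidic-lite for the default Japanese model.
"""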
import os

# On Hugging Face Spaces (where the runtime sets SPACE_ID), point the HF caches
# at the persistent /data volume. This must happen before importing transformers,
# because huggingface_hub resolves HF_HOME when it is first imported.
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

import pandas as pd
import streamlit as st
from transformers import AutoTokenizer

DEFAULT_TOKENIZER_NAME = os.environ.get(
    "DEFAULT_TOKENIZER_NAME", "tohoku-nlp/bert-base-japanese-v3"
)
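# The default model is a Japanese BERT whose WordPiece tokenizer marks
# non-initial subwords with "##", matching the default subword prefix below.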
DEFAULT_TEXT = """
hello world!
こんにちは、世界!
你好,世界
""".strip()
DEFAULT_COLOR = "gray"
COLORS_CYCLE = [
    "yellow",
    "cyan",
]
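# Two alternating colors suffice to keep adjacent subword groups visually
# distinct, since the cycle guarantees consecutive groups never share a color.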

def color_cycle_generator():
    # Yield highlight colors in an endless cycle (equivalent to itertools.cycle).
    while True:
        for color in COLORS_CYCLE:
            yield color

def get_tokenizer(tokenizer_name: str = DEFAULT_TOKENIZER_NAME):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer
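
# Design note: main() calls get_tokenizer on every Streamlit rerun; decorating
# it with @st.cache_resource would memoize the tokenizer if loading proves slow.
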
def main():
    st.set_page_config(
        page_title="TokenViz: AutoTokenizer Visualization Tool",
        layout="centered",
        initial_sidebar_state="auto",
    )
    st.title("TokenViz: AutoTokenizer Visualization Tool")
    st.text_input(
        "AutoTokenizer model name", key="tokenizer_name", value=DEFAULT_TOKENIZER_NAME
    )
    if st.session_state.tokenizer_name:
        tokenizer = get_tokenizer(st.session_state.tokenizer_name)
        st.text_input("subword prefix", key="subword_prefix", value="##")
        st.text_area("text", key="text", height=200, value=DEFAULT_TEXT)
        # Submit
        if st.button("tokenize"):
            text = st.session_state.text.strip()
            subword_prefix = st.session_state.subword_prefix.strip()
            token_ids = tokenizer.encode(text, add_special_tokens=True)
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            total_tokens = len(tokens)
            token_table_df = pd.DataFrame(
                {
                    "token_id": token_ids,
                    "token": tokens,
                }
            )

            st.subheader("visualized tokens")
            st.markdown(f"total tokens: **{total_tokens}**")
            tab_main, tab_token_table = st.tabs(["tokens", "table"])
            color_gen = color_cycle_generator()
            with tab_main:
                current_subword_color = next(color_gen)
                token_html = ""
                for idx, (token_id, token) in enumerate(zip(token_ids, tokens)):
                    if len(subword_prefix) == 0:
                        token_border = f"1px solid {DEFAULT_COLOR}"
                    else:
                        current_token_is_subword = token.startswith(subword_prefix)
                        next_token_is_subword = idx + 1 < total_tokens and tokens[
                            idx + 1
                        ].startswith(subword_prefix)
                        # Rotate to a fresh highlight color when this token starts a
                        # word that the next (subword) token continues.
                        if next_token_is_subword and not current_token_is_subword:
                            current_subword_color = next(color_gen)
                        if current_token_is_subword or next_token_is_subword:
                            token_border = f"1px solid {current_subword_color}"
                        else:
                            token_border = f"1px solid {DEFAULT_COLOR}"
                    # Escape angle brackets so tokens like "<s>" render as literal text.
                    html_escaped_token = token.replace("<", "&lt;").replace(">", "&gt;")
                    token_html += f'<span title="{str(token_id)}" style="border: {token_border}; border-radius: 3px; padding: 2px; margin: 2px;">{html_escaped_token}</span>'
                st.html(
                    f"<p style='line-height:2em;'>{token_html}</p>",
                )
                st.subheader("token_ids")
                token_ids_str = ",".join(map(str, token_ids))
                st.code(token_ids_str)
            with tab_token_table:
                st.table(token_table_df)

if __name__ == "__main__":
    main()