Spaces:
Build error
Build error
| from transformers import pipeline | |
| from transformers import AutoModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
| from textblob import TextBlob | |
| from hatesonar import Sonar | |
| import gradio as gr | |
| import torch | |
# Trained reframing model and its tokenizer, loaded from the local output directory.
MODEL_DIR = "output/reframer"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# The "summarization" pipeline serves as a generic seq2seq wrapper around the
# model; its result is a list of dicts whose "summary_text" holds the output.
reframer = pipeline(task="summarization", model=model, tokenizer=tokenizer)
# ---- Input-control thresholds ---------------------------------------------
# Minimum character length of the input text (shorter inputs lack content).
CHAR_LENGTH_LOWER_BOUND = 15
# Maximum character length of the input text (longer inputs lack focus).
CHAR_LENGTH_HIGHER_BOUND = 150
# Maximum TextBlob sentiment polarity allowed for the input text.
SENTIMENT_THRESHOLD = 0.2
# Confidence above which the input text is treated as offensive.
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8

# ---- User-facing error reasons --------------------------------------------
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."

# ---- History and decoding settings ----------------------------------------
# Most recent [input_text, strategy, reframed_text] records, oldest first.
CACHE = []
# Maximum number of history records kept in CACHE.
MAX_STORE = 5
# Number of candidate reframings produced by the "Show Best N Results" button.
BEST_N = 3
def input_error_message(error_type):
    # type: (str) -> str
    """Prefix ``error_type`` with the standard invalid-input header.

    Example: ``input_error_message("Too long.")`` returns
    ``"[Error]: Invalid Input. Too long."``.
    """
    header = "[Error]: Invalid Input. "
    return "".join((header, error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Return the history with ``new_record`` appended, keeping only the newest entries.

    Args:
        cache: Existing history records, oldest first. It is not modified.
        new_record: The ``[input_text, strategy, reframed_text]`` record to add.
        max_store: Maximum number of records to keep; ``None`` (the default)
            means ``MAX_STORE``.

    Returns:
        A new list holding at most ``max_store`` records, oldest first.
    """
    limit = MAX_STORE if max_store is None else max_store
    # Build a new list instead of appending to the caller's list in place.
    updated = list(cache) + [new_record]
    # `updated[-0:]` would return the whole list, so a non-positive limit
    # must be handled explicitly.
    if limit <= 0:
        return []
    # A negative-start slice keeps the newest `limit` records even if `cache`
    # was already over the limit (dropping just one element would not).
    return updated[-limit:]
def _get_sonar():
    # type: () -> Sonar
    """Return one shared ``Sonar`` instance, creating it on the first call.

    Previously a new ``Sonar()`` was built for every request. Construction
    presumably loads its classifier (TODO confirm), so the instance is cached
    on this function object and reused.
    """
    sonar = getattr(_get_sonar, "instance", None)
    if sonar is None:
        sonar = Sonar()
        _get_sonar.instance = sonar
    return sonar


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The strategy is appended to the input text as ``Strategy: ['<strategy>']``
    and the combined prompt is passed to the fine-tuned BART model.

    Args:
        input_text: The (negative) text to reframe.
        strategy: One of the strategy names offered in the UI.

    Returns:
        The reframed text. If no strategy is given, or the input fails the
        length, sentiment or offensiveness checks, it returns an error message
        built with ``input_error_message`` instead. Only successful reframings
        are recorded in ``CACHE``.
    """
    global CACHE
    # Input Control
    # An unselected strategy (Gradio presumably passes None for an empty Radio
    # — confirm) would make the prompt concatenation raise TypeError.
    if not strategy:
        return input_error_message("Please choose a reframing strategy.")
    # The input text cannot be too short to ensure it has substantial content to
    # be reframed. It also cannot be too long to ensure the text has a focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_error_message(LENGTH_ERROR)
    # The input text cannot be too positive to ensure it can be positively reframed.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_error_message(SENTIMENT_ERROR)
    # The input text cannot be offensive. sonar.ping(input_text) returns a dict;
    # the second entry under 'classes' holds the confidence that the text is
    # offensive language.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_error_message(OFFENSIVE_ERROR)
    # Reframing
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The reframer pipeline returns a list with one dict; 'summary_text' holds
    # the reframed output.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']
    # Record the successful reframing in the bounded history.
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Reframe the input and return a character-level diff for ``gr.HighlightedText``.

    Each returned tuple is ``(character, label)``:
    - ``"+"`` marks a character added in the reframed text.
    - ``"-"`` marks a character removed from the input text.
    - ``None`` marks an unchanged character.

    If ``reframe`` rejects the input, its error message is returned as a single
    unhighlighted span instead of being diffed against the input.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    # reframe() reports rejected input with a message built by
    # input_error_message(). Diffing that against the input would render the
    # whole input as deleted and the error as inserted text.
    if reframed_text.startswith(input_error_message("")):
        return [(reframed_text, None)]
    # Comparing two strings makes Differ work character by character. Each
    # token is "<tag> <char>", where tag is " ", "+" or "-".
    return [
        (token[2:], None if token[0] == " " else token[0])
        for token in Differ().compare(input_text, reframed_text)
    ]
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Generate ``BEST_N`` candidate reframings of ``input_text`` for ``strategy``.

    Returns the candidates as numbered lines ("1 <text>", "2 <text>", ...)
    joined by newlines.

    Note: candidates are drawn with sampling (``do_sample=True``), so they are
    not ranked and differ between calls.

    NOTE(review): unlike ``reframe``, this path applies no length, sentiment or
    offensiveness checks and does not record history. Consider sharing the
    validation.
    """
    # Concatenating an unselected (None) strategy into the prompt would raise TypeError.
    if not strategy:
        return input_error_message("Please choose a reframing strategy.")
    prompt = input_text + "Strategy: ['" + strategy + "']"
    # return_tensors="pt" yields input_ids and attention_mask as tensors, so both
    # are passed to generate() instead of a hand-built input_ids tensor.
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    decodes = model.generate(
        **inputs,
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    candidates = tokenizer.batch_decode(decodes, skip_special_tokens=True)
    return "\n".join(
        str(rank) + " " + candidate
        for rank, candidate in enumerate(candidates, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history as text and reveal the history textbox.

    Each ``[input_text, strategy, reframed_text]`` record in ``cache`` becomes
    one newline-terminated line, in the order the records are stored.

    Returns:
        A ``gr.Textbox.update`` that sets the textbox value to the rendered
        history and makes the textbox visible.
    """
    lines = [
        "Input text: " + input_text
        + " Strategy: " + strategy
        + " -> Reframed text: " + reframed_text
        + "\n"
        for input_text, strategy, reframed_text in cache
    ]
    return gr.Textbox.update(value="".join(lines), visible=True)
# Build Gradio interface: text and strategy inputs, a diff view of the
# reframing, a multi-candidate view, quick-start examples, and project links.
with gr.Blocks() as demo:
    # Instruction
    gr.Markdown(
'''
# Positive Reframing
Start inputing negative texts to see how you can see the same event from a positive angle.
''')
    # Input text to be reframed
    text = gr.Textbox(label="Original Text")
    # Input strategy for the reframing
    gr.Markdown(
'''
Choose one of the six strategies to carry out reframing: \n
**Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
**Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
**Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
**Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
**Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
**Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
''')
    strategy = gr.Radio(
        ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"],
        label="Strategy to use?",
    )
    # Reframe the input and show a character-level diff ("+" green, "-" red).
    reframe_btn = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"})
    reframe_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)
    # Show BEST_N candidate reframings, one per line. This button gets its own
    # name instead of reusing the "Reframe" button's variable.
    best_n_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    best_n_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)
    # Default examples of text and strategy pairs for user to have a quick start
    gr.Markdown("## Examples")
    gr.Examples(
        [["I have a lot of homework to do today.", "self_affirmation"], ["This has been the longest and most stressful week of my life!", "optimism"], ["So stressed about the midterms next week.", "thankfulness"]],
        [text, strategy],
        # The original passed the undefined name `output`, which raised
        # NameError while building the app; the diff view is the intended target.
        outputs=best_output,
        fn=show_reframe_change,
        cache_examples=False,
        run_on_click=False,
    )
    # Link to paper and Github repo
    gr.Markdown(
'''
For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
''')
demo.launch()