import spaces
import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

import json
from datetime import datetime
from uuid import uuid4
import os
from pathlib import Path
from huggingface_hub import CommitScheduler
from utils import hide_code, hide_css

# TODO make it so that feedback is only saved on prev. example if user makes another obfuscation
# and changes slider but doesn't hit obfuscate
# TODO maybe make it save and reset if user hits submit feedback
# TODO sampling params for models
# TODO obfuscation ID?

# Converts text to the correct format for LoRA adapters in StyleRemix
def convert_data_to_format(text):
    output = f"### Original: {text}\n ### Rewrite:"
    return output
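# e.g. convert_data_to_format("The fox ran.") -> "### Original: The fox ran.\n ### Rewrite:"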

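# One LoRA adapter per style axis and direction; these are combined per request based on the slider values.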
MODEL_PATHS = {
    "length_more": "hallisky/lora-length-long-llama-3-8b",
    "length_less": "hallisky/lora-length-short-llama-3-8b",
    "function_more": "hallisky/lora-function-more-llama-3-8b",
    "function_less": "hallisky/lora-function-less-llama-3-8b",
    "grade_more": "hallisky/lora-grade-highschool-llama-3-8b", 
    "grade_less": "hallisky/lora-grade-elementary-llama-3-8b",
    "formality_more": "hallisky/lora-formality-formal-llama-3-8b", 
    "formality_less": "hallisky/lora-formality-informal-llama-3-8b", 
    "sarcasm_more": "hallisky/lora-sarcasm-more-llama-3-8b", 
    "sarcasm_less": "hallisky/lora-sarcasm-less-llama-3-8b", 
    "voice_passive": "hallisky/lora-voice-passive-llama-3-8b", 
    "voice_active": "hallisky/lora-voice-active-llama-3-8b", 
    "type_persuasive": "hallisky/lora-type-persuasive-llama-3-8b",
    "type_expository": "hallisky/lora-type-expository-llama-3-8b",
    "type_narrative": "hallisky/lora-type-narrative-llama-3-8b",
    "type_descriptive": "hallisky/lora-type-descriptive-llama-3-8b",
}
FIRST_MODEL = list(MODEL_PATHS.keys())[0]
MAX_NEW_TOKENS=1024

DESCRIPTION = """\
# Authorship Obfuscation with StyleRemix
This Space demonstrates StyleRemix, a controllable and interpretable method for authorship obfuscation. At its core, it uses an 8B-parameter Llama-3 model with LoRA adapters fine-tuned to rewrite text towards specific stylistic attributes (such as making the text longer or shorter). Feel free to play with it, or duplicate the Space to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
<br> 🕵️ Want to learn more? Check out our paper [here](https://arxiv.org/abs/2408.15666v1) and our code [here](https://github.com/jfisher52/StyleRemix)!
<br> 🧐 Have questions about our work or issues with the demo? Feel free to email us at [email protected] and [email protected].
<br> <b>Disclaimer</b>: <em>We may collect and use your queries for further research and development purposes. The data collected will be anonymized and used to enhance our understanding of desired stylistic attributes when rewriting text.</em>
"""

import subprocess
def print_nvidia_smi():
    try:
        # Run the nvidia-smi command
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=True)
        print(result.stdout)
    except subprocess.CalledProcessError as e:
        # Handle errors in the subprocess
        print(f"Failed to run nvidia-smi: {e}")
    except FileNotFoundError:
        # Handle the case where nvidia-smi is not installed
        print("nvidia-smi is not installed or not in the PATH.")

# Load models
if not torch.cuda.is_available():
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    device = "cuda"
    model_id = "meta-llama/Meta-Llama-3-8B"

    tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True, add_eos_token=False, padding_side="left")
    tokenizer.add_special_tokens({'pad_token': '<padding_token>'})

    base_model = AutoModelForCausalLM.from_pretrained(model_id).to(device) # device_map="auto" requires accelerate
    base_model.resize_token_embeddings(len(tokenizer)) # Resize embeddings for the added pad token (its initial values don't matter)
    # Load in the first model
    model = PeftModel.from_pretrained(base_model, MODEL_PATHS[FIRST_MODEL], adapter_name=FIRST_MODEL).to(device)
    # Load in the rest of the models
    for cur_adapter in MODEL_PATHS.keys():
        if cur_adapter != FIRST_MODEL:
            model.load_adapter(MODEL_PATHS[cur_adapter], adapter_name=cur_adapter)
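    # All adapters stay resident on the base model; weighted combinations are created on the fly per request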

    # print(model.device) # Seems it re-allocates to CPU
    model.to(device)
    model.eval()

# Logging setup: a unique ID for this app instance and a local JSON log synced to the Hub
user_id = str(uuid4())  # Generated once at startup, so it is shared by all requests to this instance

JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{user_id}.json"

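# Periodically push json_dataset/ to the dataset repo; CommitScheduler's "every" is in minutes (0.5 = 30 seconds)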
scheduler = CommitScheduler(
    repo_id="authorship-obfuscation-demo-data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    every=0.5
)

def save_data(data):
    with scheduler.lock:
        with JSON_DATASET_PATH.open("a") as f:
            json.dump(data, f)
            f.write("\n")

def save_feedback(feedback_rating, feedback_text, latest_obfuscation):
    latest_obfuscation["feedback_rating"] = feedback_rating
    latest_obfuscation["feedback_text"] = feedback_text
    save_data(latest_obfuscation)
    return "No Feedback Selected", "", gr.update(visible=True)

@spaces.GPU
def obfuscate(input_text, length, function_words, grade_level, formality, sarcasm, voice, persuasive, descriptive, narrative, expository):
    """Rewrite input_text using a weighted combination of the LoRA adapters selected by the sliders."""
    global latest_obfuscation, user_id
    current_time = datetime.now().isoformat()
    
    sliders_dict = {}
    cur_keys = []
    cur_keys.append(("length_more" if length > 0 else (None if length == 0 else "length_less"), abs(length)))
    cur_keys.append(("function_more" if function_words > 0 else (None if function_words == 0 else "function_less"), abs(function_words)))
    cur_keys.append(("grade_more" if grade_level > 0 else (None if grade_level == 0 else "grade_less"), abs(grade_level)))
    cur_keys.append(("sarcasm_more" if sarcasm > 0 else (None if sarcasm == 0 else "sarcasm_less"), abs(sarcasm)))
    cur_keys.append(("formality_more" if formality > 0 else (None if formality == 0 else "formality_less"), abs(formality)))
    cur_keys.append(("voice_active" if voice > 0 else (None if voice == 0 else "voice_passive"),abs(voice)))
    cur_keys.append(("type_persuasive" if persuasive != 0 else None, abs(persuasive)))
    cur_keys.append(("type_descriptive" if descriptive != 0 else None, abs(descriptive)))
    cur_keys.append(("type_narrative" if narrative != 0 else None, abs(narrative)))
    cur_keys.append(("type_expository" if expository != 0 else None, abs(expository)))

    for cur_key in cur_keys:
        if cur_key[0] is not None:
            sliders_dict[cur_key[0]] = cur_key[1]
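    # sliders_dict now maps each selected adapter name to the absolute slider value used as its weight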

    # Make the adapter and switch to it
    print(sliders_dict)

    if len(sliders_dict) > 0:
        combo_adapter_name = ""
        for slider_key in sliders_dict:
            print(slider_key)
            print(sliders_dict[slider_key])
            combo_adapter_name += slider_key + str(int(100*sliders_dict[slider_key])) + "-"
        combo_adapter_name = combo_adapter_name[:-1]
        print(combo_adapter_name)
        print(list(sliders_dict.values()))
        print(list(sliders_dict.keys()))
        print(list(model.peft_config.keys()))

        # Build a weighted combination of the selected adapters and switch to it ("cat" concatenates their LoRA matrices)
        model.add_weighted_adapter(
            list(sliders_dict.keys()),
            weights = list(sliders_dict.values()),
            adapter_name = combo_adapter_name,
            combination_type = "cat"
        )
        model.set_adapter(combo_adapter_name)
        
        # Convert the list of strings in data to a list of model inputs
        converted_text = convert_data_to_format(input_text)
        inputs = tokenizer(converted_text, return_tensors="pt", max_length=2048, truncation=True).to(device)
        input_length = inputs.input_ids.shape[1]
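        # input_length is used to decode only the newly generated tokens (the rewrite), not the echoed prompt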
        with torch.no_grad(): 
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, top_p=0.95)  # top_p only takes effect when sampling is enabled
        response = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True).strip()
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    else:
        response = input_text # If no sliders are set, return the input unchanged
        full_output = response

    # print_nvidia_smi() # Print GPU usage

    # Save the new obfuscation result and reset feedback
    latest_obfuscation = {
        "datetime": current_time,
        "user_id": user_id,
        "input_text": input_text,
        "sliders": {
            "length": length,
            "function_words": function_words,
            "grade_level": grade_level,
            "sarcasm": sarcasm,
            "formality": formality,
            "voice": voice,
            "persuasive": persuasive,
            "descriptive": descriptive,
            "narrative": narrative,
            "expository": expository
        },
        "input": input_text,
        "output": response,
        "full_output": full_output,
        "feedback_rating": "No Feedback Selected",
        "feedback_text": ""
    }
    
    # Save the obfuscation result
    save_data(latest_obfuscation)
    
    return response, gr.update(interactive=True), gr.update(interactive=True), latest_obfuscation

def auto_sliders():
    # Default preset: 0.5 for the six style sliders, 0 for the four writing-type sliders
    # (assumes the writing-type sliders should stay under the control of the radio selection)
    return [0.5] * 6 + [0] * 4

def reset_sliders():
    return [0] * 10  # zero out all ten sliders

def toggle_slider(checked, value):
    if checked:
        return gr.update(value=value, interactive=True)
    else:
        return gr.update(value=0, interactive=False)

def reset_writing_type_sliders(selected_type):
    reset_values = [gr.update(value=0, interactive=False) for _ in range(4)]
    if selected_type != "None":
        index = ["Persuasive", "Descriptive", "Narrative", "Expository"].index(selected_type)
        reset_values[index] = gr.update(value=0, interactive=True)
    return reset_values

def update_save_feedback_button(feedback_rating, feedback_text):
    if feedback_rating != "No Feedback Selected" or feedback_text.strip() != "":
        return gr.update(interactive=True), gr.update(visible=False)
    else:
        return gr.update(interactive=False), gr.update(visible=True)

def update_obfuscate_button(input_text):
    if input_text.strip() == "":
        return gr.update(interactive=False), gr.update(visible=True)
    else:
        return gr.update(interactive=True), gr.update(visible=False)

def check_initial_feedback_state(feedback_rating, feedback_text):
    return update_save_feedback_button(feedback_rating, feedback_text)

demo = gr.Blocks()

with demo:
    latest_obfuscation = gr.State({})
    gr.Markdown(DESCRIPTION)
    gr.HTML(hide_css)
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("# 1) Input Text\n### Enter the text to be obfuscated. We recommend *full sentences* or *paragraphs*.")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="The quick brown fox jumped over the lazy dogs."
            )
            gr.Markdown("# 2) Style Element Sliders\n### Adjust the style element sliders to the desired levels to steer the obfuscation.")
            
            with gr.Row():
                auto_button = gr.Button("Choose slider values automatically (based on input text)")
                reset_button = gr.Button("Reset slider values")
            
            sliders = []
            slider_values = [
                ("Length (Shorter \u2192 Longer)", -1, 1, 0),
                ("Function Words (Fewer \u2192 More)", -1, 1, 0),
                ("Grade Level (Lower \u2192 Higher)", -1, 1, 0),
                ("Formality (Less \u2192 More)", -1, 1, 0),
                ("Sarcasm (Less \u2192 More)", -1, 1, 0),
                ("Voice (Passive \u2192 Active)", -1, 1, 0),
                ("Writing Type: Persuasive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Descriptive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Narrative (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Expository (None \u2192 More)", 0, 1, 0)
            ]
            
            non_writing_type_sliders = []
            writing_type_sliders = []
            
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" not in label:
                    with gr.Row():
                        # with gr.Column(scale=1, min_width=25):
                        checkbox = gr.Checkbox(label=label.split("(")[0], scale=1)
                        #with gr.Column(scale=2, min_width=50):
                        slider = gr.Slider(label=label.split("(")[1][:-1], minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False, scale=3)
                        checkbox.change(fn=toggle_slider, inputs=[checkbox, gr.State(default)], outputs=slider)
                        non_writing_type_sliders.append(slider)
                        sliders.append(slider)
            
            writing_type_radio = gr.Radio(
                label="Writing Type",
                choices=["None", "Persuasive", "Descriptive", "Narrative", "Expository"],
                value="None"
            )
            
            writing_type_radio.change(fn=reset_writing_type_sliders, inputs=writing_type_radio, outputs=writing_type_sliders)
            
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" in label:
                    with gr.Row():
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                        writing_type_sliders.append(slider)
                        sliders.append(slider)
            
            obfuscate_button = gr.Button("Obfuscate Text", interactive=False)
            warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please enter text before obfuscating. ⚠️</div>", visible=True
            )
            
            auto_button.click(fn=auto_sliders, inputs=[], outputs=sliders)
            reset_button.click(fn=reset_sliders, inputs=[], outputs=sliders)

            input_text.change(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])
            # Initialize the button and warning message state on page load
            demo.load(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

        # with gr.Column(variant="panel"):
        #     gr.Markdown("# 3) Obfuscated Output")

        with gr.Column(variant="panel"):
            gr.Markdown("# 3) Obfuscated Output")

            output = gr.Textbox(label="Output", lines=3)
            
            gr.Markdown("## Feedback [Optional]")

            # Add thumbs up / thumbs down
            gr.Markdown("### Is the response good or bad?")
            feedback_rating = gr.Radio(choices=["No Feedback Selected", "Good 👍", "Bad 👎"], value="No Feedback Selected", interactive=False, label="Rate the Response")
            
            # Add feedback box
            gr.Markdown("### Provide any feedback on the obfuscation")
            feedback_text = gr.Textbox(label="Feedback", lines=3, interactive=False)

            obfuscate_button.click(
                fn=obfuscate,
                inputs=[input_text] + sliders, 
                outputs=[output, feedback_rating, feedback_text, latest_obfuscation])

            save_feedback_button = gr.Button("Submit Feedback", interactive=False)

            confirmation_message = gr.Markdown(
                "<div id='confirmation-message' style='text-align: center; color: green;'>πŸ₯³ Feedback has been submitted successfully! 🎊</div>", visible=False
            )

            feedback_warning_message = gr.Markdown(
                "<div id='feedback-warning' style='text-align: center; color: red;'>⚠️ Please provide feedback or a rating before submitting. ⚠️</div>", visible=True
            )

            # Update the interactivity of the save_feedback_button based on feedback_rating and feedback_text
            feedback_rating.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])
            feedback_text.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

            save_feedback_button.click(
                fn=save_feedback,
                inputs=[feedback_rating, feedback_text, latest_obfuscation],
                outputs=[feedback_rating, feedback_text, confirmation_message]
            )

            save_feedback_button.click(
                fn=None,
                inputs=[],
                outputs=None,
                js=hide_code
            )
            
            # Initialize the save feedback button and warning message state on page load
            demo.load(fn=check_initial_feedback_state, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

demo.launch()