Spaces:
Runtime error
Runtime error
| import re | |
| import json | |
| import argparse | |
| import openai | |
| import gradio as gr | |
| from functools import partial | |
class GPT4News:
    """Run named prompt-template "functions" against the OpenAI chat API.

    Each prompt format is a dict keyed by ``name``.  The fields this class
    reads are ``pre_filter``, ``split_length``, ``split_round``,
    ``post_filter``, ``user``, ``user_keys``, ``system`` and ``system_keys``.
    Instances are callable and are used directly as a Gradio click handler.
    """

    # A speaker tag such as "Speaker 1 00:05": a speaker number followed by a
    # dd:dd timestamp.  The capturing group makes ``re.split`` keep the tags.
    _SPEAKER_PATTERN = re.compile(r'(Speaker \d+ \d\d:\d\d)')

    def __init__(self, prompt_formats):
        """Index *prompt_formats* (an iterable of dicts with a ``name`` key)."""
        self.name2prompt = {x['name']: x for x in prompt_formats}

    @staticmethod
    def _split_utterance(speaker, text, max_length):
        """Cut one utterance into (speaker, piece) pairs of <= max_length chars."""
        room = max_length - len(speaker)
        if room <= 0:
            # The tag alone fills the budget, so splitting cannot make
            # progress (the old loop spun forever here); keep it whole.
            return [(speaker, text)]
        pieces = []
        while len(text) > room:
            pieces.append((speaker, text[:room]))
            text = text[room:]
        pieces.append((speaker, text))
        return pieces

    def preprocess(self, function_name, input_txt):
        """Split a speaker-tagged transcript into prompt-sized chunks.

        When the prompt's ``pre_filter`` is falsy, the text is returned
        unchanged as a single chunk.  Otherwise:

        - labels of the form "说话人<n> <dd:dd>" are normalized to
          "Speaker <n> <dd:dd>";
        - each utterance (a tag plus the text up to the next tag, with
          newlines flattened to spaces) is cut so that tag + text fits in
          ``split_length`` characters;
        - utterances are packed into chunks shorter than ``split_length``
          with at most ``split_round`` speaker tags each.

        Text before the first tag is discarded.  Each returned chunk holds
        one "Speaker <n> <dd:dd>: text" line per utterance.  If no tagged
        utterance is found, ``[input_txt]`` is returned so the model still
        receives the input instead of an empty prompt.
        """
        prompt = self.name2prompt[function_name]
        if not prompt['pre_filter']:
            return [input_txt]
        max_length = prompt['split_length']
        max_convs = prompt['split_round']
        pattern = self._SPEAKER_PATTERN

        normalized = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
        parts = [x.strip().replace('\n', ' ')
                 for x in pattern.split(normalized)]

        conversations = []
        for idx, part in enumerate(parts):
            if not pattern.match(part):
                continue
            following = parts[idx + 1] if idx + 1 < len(parts) else ''
            if pattern.match(following):
                # Two tags in a row: this speaker has no text.
                following = ''
            conversations.extend(
                self._split_utterance(part, following, max_length))

        chunks = ['']
        for speaker, text in conversations:
            piece = speaker + text
            current = chunks[-1]
            # Only open a new chunk when the current one has content, so an
            # utterance that alone reaches max_length never leaves an empty
            # chunk behind (which would be sent to the model as a blank prompt).
            if current and (len(current) + len(piece) >= max_length
                            or len(pattern.findall(current)) >= max_convs):
                chunks.append('')
            chunks[-1] += piece

        processed = [pattern.sub(r'\n\1: ', chunk).strip()
                     for chunk in chunks if chunk]
        return processed or [input_txt]

    def chatgpt(self, messages, temperature=0.0, api_key=None,
                max_attempts=3, retry_delay=1.0):
        """Send *messages* to gpt-3.5-turbo and return the reply text.

        *api_key* is passed per request, so concurrent Gradio requests never
        share a process-global key; ``None`` falls back to ``openai.api_key``.
        A failing request is tried up to *max_attempts* times in total.
        Waits between tries start at *retry_delay* seconds and double each
        time.  The last exception is re-raised once attempts are exhausted,
        rather than recursing without bound.
        """
        import time  # only needed for the retry backoff

        attempts = max(1, int(max_attempts))
        for attempt in range(attempts):
            try:
                completion = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=temperature,
                    api_key=api_key)
            except Exception as err:  # network, rate-limit and API errors
                print(err)
                if attempt == attempts - 1:
                    raise
                time.sleep(retry_delay * 2 ** attempt)
            else:
                return completion.choices[0].message.content

    @staticmethod
    def _log_exchange(system, user, response):
        """Print one system/user/response exchange to stdout."""
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')

    def llm(self, function_name, temperature, api_key=None, **kwargs):
        """Fill the named prompt's templates from *kwargs* and query the model.

        Only the keys listed in the prompt's ``user_keys`` / ``system_keys``
        are taken from *kwargs*; a missing key raises ``KeyError``.
        """
        prompt = self.name2prompt[function_name]
        user = prompt['user'].format(
            **{key: kwargs[key] for key in prompt['user_keys']})
        system = prompt['system'].format(
            **{key: kwargs[key] for key in prompt['system_keys']})
        messages = [{'role': 'system', 'content': system},
                    {'role': 'user', 'content': user}]
        response = self.chatgpt(messages, temperature=temperature,
                                api_key=api_key)
        self._log_exchange(system, user, response)
        return response

    def translate(self, txt, output_lang, api_key=None):
        """Translate *txt* into *output_lang*; English is returned unchanged."""
        if output_lang == 'English':
            return txt
        system = 'You are a translator.'
        user = 'Translate the following text to {}:\n\n{}'.format(
            output_lang, txt)
        messages = [{'role': 'system', 'content': system},
                    {'role': 'user', 'content': user}]
        response = self.chatgpt(messages, api_key=api_key)
        self._log_exchange(system, user, response)
        return response

    def postprocess(self, function_name, input_txt, output_txt_list,
                    output_lang, api_key=None):
        """Merge the model outputs into the final text shown to the user.

        Without ``post_filter``, the outputs are joined with blank lines and
        translated as one block.  With it, only utterances whose speaker tag
        also occurs in *input_txt* are kept.  Each kept utterance is
        translated separately, and utterances are joined with blank lines.
        Text a chunk's output has before its first tag is dropped.  Each
        chunk is split on its own, so that text is never glued onto the
        previous chunk's last speaker.
        """
        if not self.name2prompt[function_name]['post_filter']:
            output_txt = '\n\n'.join(output_txt_list)
            return self.translate(output_txt, output_lang, api_key=api_key)
        pattern = self._SPEAKER_PATTERN
        speakers = set(pattern.findall(input_txt))
        results = []
        for output_txt in output_txt_list:
            # split() alternates text, tag, text, ... and always ends with a
            # text element, so tags sit at odd indices with their text after.
            parts = pattern.split(output_txt or '')
            for speaker, following in zip(parts[1::2], parts[2::2]):
                if speaker not in speakers:
                    continue
                res = self.translate(speaker + following, output_lang,
                                     api_key=api_key)
                results.append(res.strip())
        return '\n\n'.join(results)

    def __call__(self, api_key, function_name, temperature, output_lang,
                 input_txt, tags):
        """Gradio handler: run *function_name* on *input_txt*.

        Returns the text for the output box.  Validation and request failures
        are reported as a message string instead of an exception.
        """
        if api_key is None or api_key == '':
            return 'OPENAI API Key is not set.'
        if function_name is None or function_name == '':
            return 'Function is not selected.'
        try:
            input_txt_list = self.preprocess(function_name, input_txt)
            output_txt_list = [
                self.llm(function_name, temperature, api_key=api_key,
                         input_txt=txt, tags=tags)
                for txt in input_txt_list]
            return self.postprocess(function_name, '\n'.join(input_txt_list),
                                    output_txt_list, output_lang,
                                    api_key=api_key)
        except Exception as err:  # UI boundary: surface the error, don't hang
            print(err)
            return f'Error: {err}'

    def function_names(self):
        """Return a view of the configured function (prompt) names."""
        return self.name2prompt.keys()
def function_name_select_callback(componments, name2prompt, function_name):
    """Compute Gradio visibility updates after a function is selected.

    Returns one ``gr.update`` per name in *componments*, in the same order.
    Each update is visible exactly when that name appears in the selected
    prompt's ``user_keys``.  (The parameter spelling is kept for interface
    compatibility.)
    """
    wanted_keys = name2prompt[function_name]['user_keys']
    return [gr.update(visible=(name in wanted_keys)) for name in componments]
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='prompts/interview.json',
                        help='path to the prompt file')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='temperature for the llm model')
    args = parser.parse_args()

    # Read the prompt file as UTF-8 explicitly (the platform default encoding
    # may differ) and close it as soon as it has been parsed.
    with open(args.prompt, 'r', encoding='utf-8') as prompt_file:
        prompt_formats = json.load(prompt_file)
    gpt4news = GPT4News(prompt_formats)

    languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
                 'Chinese (Traditional)', 'Dutch', 'English', 'French',
                 'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
                 'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
                 'Urdu']

    # function_names is a method and must be called; iterating the bound
    # method itself raised "TypeError: 'method' object is not iterable".
    function_choices = sorted(gpt4news.function_names())
    if not function_choices:
        parser.error(f'no prompt formats found in {args.prompt}')
    default_func = function_choices[0]
    default_user_keys = gpt4news.name2prompt[default_func]['user_keys']

    with gr.Blocks() as demo:
        with gr.Row():
            # Column scale is an integer weight; 3:7 keeps the intended
            # 30% / 70% split between controls and output.
            with gr.Column(scale=3):
                with gr.Row():
                    api_key = gr.Textbox(
                        lines=1,
                        label='OPENAI API Key',
                        elem_id='api_key_textbox',
                        placeholder='Enter your OPENAI API Key')
                with gr.Row():
                    function_name = gr.Dropdown(
                        function_choices,
                        value=default_func,
                        elem_id='function_dropdown',
                        label='Function',
                        info='choose a function to run')
                with gr.Row():
                    output_lang = gr.Dropdown(
                        languages,
                        value='English',
                        elem_id='output_lang_dropdown',
                        label='Output Language',
                        info='choose a language to output')
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.temperature,
                        step=0.1,
                        interactive=True,
                        label='Temperature',
                        info='higher temperature means more creative')
                with gr.Row():
                    tags = gr.Textbox(
                        lines=1,
                        visible='tags' in default_user_keys,
                        label='Tags',
                        elem_id='tags_textbox',
                        placeholder='Enter tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=4,
                        visible='input_txt' in default_user_keys,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            with gr.Column(scale=7):
                output_txt = gr.Textbox(
                    lines=8,
                    label='Output',
                    elem_id='output_textbox')

        # Show only the inputs the selected prompt actually uses.
        function_name.select(
            partial(function_name_select_callback, ['input_txt', 'tags'],
                    gpt4news.name2prompt),
            [function_name],
            [input_txt, tags]
        )
        submit.click(
            gpt4news,
            [api_key, function_name, temperature, output_lang,
             input_txt, tags],
            [output_txt])
        # Outputs must be a single list matching the three returned values;
        # the original passed input_txt positionally as the api_name argument.
        clear.click(
            lambda: ['', '', ''],
            None,
            [tags, input_txt, output_txt])

    demo.queue(concurrency_count=6)
    demo.launch()