import gradio as gr
print('Imported gradio', gr.__version__)
import sentence_transformers
print('Imported sentence_transformers', sentence_transformers.__version__)
import torch
print('Imported torch', torch.__version__)
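
# Map each UI language label to its LoRA adapter repository on the Hugging Face Hub.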
language_mapping = {
    'English (eng_Latn)': 'MonoLR_eng_Latn_PR',
    'Irish (gle_Latn)': 'MonoLR_gle_Latn_PR',
    'Maltese (mlt_Latn)': 'MonoLR_mlt_Latn_PR',
    'Russian (rus_Cyrl)': 'MonoLR_rus_Cyrl_PR',
    'Welsh (cym_Latn)': 'MonoLR_cym_Latn_PR',
    'Xhosa (xho_Latn)': 'MonoLR_xho_Latn_PR'
}
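
# Load the multilingual NLI cross-encoder as the base model and configure it to
# return a single sigmoid-activated score per (premise, hypothesis) pair.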
print('Loading base model...')
model = sentence_transformers.CrossEncoder('MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7')
print('Base model loaded!')
model.config.num_labels = 1
model.default_activation_function = torch.nn.Sigmoid()
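
# Load one LoRA adapter per supported language on top of the base model.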
print('Loading adapters...')
for language, adapter in language_mapping.items():
    model.model.load_adapter(f'WilliamSotoM/{adapter}', adapter)
    print(adapter, 'loaded!')

print('Defining evaluate function...')
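
# Activate the adapter for the selected language, score the (graph, text) pair in
# both directions, and combine the two directional scores into an F1.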
def evaluate(language, rdf_graph, generated_text):
    model.model.set_adapter(language_mapping[language])
    print(f'Enabled {language} LoRA')
    precision = model.predict([(rdf_graph, generated_text)])[0]
    recall = model.predict([(generated_text, rdf_graph)])[0]
    f1 = (2 * precision * recall) / (precision + recall)
    print('RDF Graph:', rdf_graph)
    print('Generated Text:', generated_text)
    print('-----')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1: {f1:.4f}')
    return precision, recall, f1

print('Evaluate function defined!')

print('Instantiating gradio interface...')
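# Build the Gradio interface: a language dropdown, two text inputs, and three
# numeric outputs for the semantic precision, recall, and F1 scores.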
demo = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Dropdown(label='Language', choices=list(language_mapping.keys()), value='English (eng_Latn)'),
        gr.Textbox(label='RDF Graph'),
        gr.Textbox(label='Generated Text')
    ],
    outputs=[
        gr.Number(label='Semantic Precision'),
        gr.Number(label='Semantic Recall'),
        gr.Number(label='Semantic F1')
    ],
    title='Semantic Evaluation of Multilingual D2T',
    description='''Select a language, then type in an input RDF Graph and its corresponding Generated Text to perform the evaluation.
Indicate the subject, property, and object of the RDF triples with the following tokens: [S], [P], [O].
Separate each triple with the following token: [T].
For example:
[S]Buzz_Aldrin[P]mission[O]Apollo_12[T][S]Buzz_Aldrin[P]birthPlace[O]Glen_Ridge,_New_Jersey'''
)
print('Gradio interface instantiated!')

print('Launching server...')
demo.launch()