Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import CamembertTokenizer, CamembertForSequenceClassification | |
| import torch | |
| import numpy as np | |
@st.cache_resource
def load_tokenizer():
    """Return the CamemBERT tokenizer used to encode text segments.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the tokenizer is built
    once and the same object is handed back on later reruns instead of being
    re-created each time.

    Returns:
        The ``CamembertTokenizer`` loaded from the ``camembert-base``
        checkpoint.
    """
    return CamembertTokenizer.from_pretrained("camembert-base")
@st.cache_resource
def load_model():
    """Return the fine-tuned CamemBERT sequence-classification model.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the checkpoint is loaded
    once and the same model object is reused on later reruns instead of
    being reloaded each time.

    Returns:
        The ``CamembertForSequenceClassification`` model loaded from the
        ``herelles/camembert-base-lupan`` checkpoint.
    """
    return CamembertForSequenceClassification.from_pretrained("herelles/camembert-base-lupan")
# Module-level tokenizer and model, read as globals by prediction().
tokenizer = load_tokenizer()

# Load the classifier and keep it on the CPU; Module.to() returns the module
# itself, so the move can be chained onto the load.
model = load_model().to('cpu')
def prediction(segment_text):
    """Return the predicted class id for one textual segment.

    The segment is tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model`` with gradients disabled.

    Args:
        segment_text: The text to classify.

    Returns:
        int: Position of the largest value in the model's output logits.
    """
    # The text comes from a free-form text area, so it can be arbitrarily
    # long. truncation=True clips it to the maximum length the tokenizer
    # accepts for this model instead of letting the forward pass fail.
    encoding = tokenizer(
        segment_text,
        padding="longest",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(
            input_ids=encoding['input_ids'].to('cpu'),
            attention_mask=encoding['attention_mask'].to('cpu'),
        )

    # np.argmax without an axis indexes the flattened logits; for a single
    # segment (presumably one row of logits) that index is the class id.
    return np.argmax(output.logits.cpu().numpy()).item()
# Human-readable names for the class ids returned by prediction().
PREDICTION_LABELS = {
    0: 'Not pertinent',
    1: 'Pertinent (Soft)',
    2: 'Pertinent (Strict, Non-verifiable)',
    3: 'Pertinent (Strict, Verifiable)',
}


def main():
    """Render the page and classify the entered segment on request.

    Shows a header and a text area pre-filled with a sample segment. When
    the 'Predict' button is pressed, the text is passed to ``prediction``
    and the label for the returned class id is written to the page.
    """
    st.header('Textual segments Hérelles prediction tool', divider='rainbow')

    segment_text = st.text_area(
        "Text to classify:",
        "Article 1 : Occupations ou utilisations du sol interdites\n\n"
        "1) Dans l’ensemble de la zone sont interdits :\n\n"
        "Les constructions destinées à l’habitation ne dépendant pas d’une exploitation agricole autres\n"
        "que celles visées à l’article 2 paragraphe 1).",
        height=170,
    )

    # Nothing to do until the user asks for a prediction.
    if not st.button('Predict'):
        return

    # A blank segment carries nothing to classify; ask for input instead of
    # reporting a prediction for it.
    if not segment_text.strip():
        st.warning("Please enter some text to classify.")
        return

    pred_id = prediction(segment_text)

    # Fall back to an explicit message for an id outside the known labels,
    # so the label is always defined when it is written out.
    pred_label = PREDICTION_LABELS.get(pred_id, f"Unknown class ({pred_id})")
    st.write("Predicted Class: ", pred_label)
# Build the page only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()