File size: 1,049 Bytes
c24d815
f535f4a
c24d815
 
 
 
 
 
72f383a
 
 
 
 
 
 
 
 
c24d815
 
 
 
72f383a
 
 
9250959
72f383a
9250959
72f383a
9250959
72f383a
c24d815
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification as ASC 
from transformers import AutoTokenizer as AT 

# Hugging Face Hub repo id of the fine-tuned ALBERT NLI checkpoint.
MODEL_NAME = "rickxzo/albert-large-v2-s.a.m-nli"


@st.cache_resource
def _load_model_and_tokenizer(name=MODEL_NAME):
    """Load the sequence-classification model and its tokenizer for ``name``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one shared copy of the (large) model and
    tokenizer instead of reloading them via ``from_pretrained`` on each rerun.

    Args:
        name: Model identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    return ASC.from_pretrained(name), AT.from_pretrained(name)


model, tokenizer = _load_model_and_tokenizer()

def infer(sentence1, sentence2):
    """Predict the class index for a (premise, hypothesis) pair.

    Uses the module-level ``tokenizer`` and ``model``. Inference runs under
    ``torch.no_grad()`` since no gradients are needed.

    Args:
        sentence1: Premise text, or a list of premises for batched input.
        sentence2: Hypothesis text, or a list of hypotheses of the same length.

    Returns:
        For a single pair, the predicted class index as an ``int``. For batched
        (list) input, a ``list`` of ints, one per pair.
    """
    inputs = tokenizer(sentence1, sentence2, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class; reduce over the class axis only, so each pair in a batch gets its
    # own label instead of an index into the flattened tensor.
    predictions = logits.argmax(dim=-1)
    if predictions.numel() == 1:
        return predictions.item()
    return predictions.tolist()

st.title("Contradiction Detector using AlBERT model")
premise = st.text_area("Enter the premise: ")
hypothesis = st.text_area("Enter the hypothesis: ")

# Message shown for each class index returned by infer(). The mapping
# 0 = entailment, 1 = neutral, 2 = contradiction is this app's assumption
# about the model's label order.
_VERDICTS = {
    2: "#### **Contradicting Statements Detected!**",
    1: "#### **Neutral Statements Detected.**",
    0: "#### **Entailing Statements Detected.**",
}

# Only run the model once both boxes contain non-whitespace text; Streamlit
# reruns this script on every interaction.
if premise.strip() and hypothesis.strip():
    k = infer(premise, hypothesis)
    message = _VERDICTS.get(k)
    if message is None:
        # Surface an unexpected label instead of silently rendering nothing.
        st.warning(f"Model returned an unexpected label index: {k}")
    else:
        st.write(message)