File size: 4,378 Bytes
0cf02bf
 
 
 
 
0f412e0
 
a61bd5c
0f412e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a61bd5c
0cf02bf
 
 
 
 
 
 
 
 
a61bd5c
0cf02bf
 
 
 
 
 
 
 
 
 
 
 
298e502
0cf02bf
 
 
 
 
 
 
298e502
0cf02bf
 
 
298e502
0cf02bf
 
 
298e502
0cf02bf
 
 
 
 
 
 
0f412e0
298e502
0cf02bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import json
import gradio as gr

# --- Step 1: Create a "Smart" Vocabulary Loader ---
# This function will load the vocabularies and automatically fix any format mismatches.

def load_vocabularies():
    """
    Load the character and language vocabularies from the working directory.

    Reads 'char_to_int.json' and 'int_to_lang.json'. The language file is
    accepted in either orientation:

    * {"0": "Language"}  (int -> lang, keys are numeric strings), or
    * {"Language": 0}    (lang -> int), which is reversed on load.

    Returns:
        A ``(char_to_int_map, int_to_lang_map)`` tuple. ``char_to_int_map`` is
        the parsed JSON as-is; ``int_to_lang_map`` always has ``int`` keys.

    Raises:
        ValueError: if the language file is empty, or if neither its keys nor
            its values can be interpreted as integer class ids.
        OSError / json.JSONDecodeError: propagated from reading the files.
    """
    with open('char_to_int.json', 'r', encoding='utf-8') as f:
        char_to_int_map = json.load(f)

    # The file is named 'int_to_lang.json' but may hold either orientation.
    with open('int_to_lang.json', 'r', encoding='utf-8') as f:
        language_vocab = json.load(f)

    # Fail with a clear message instead of a bare StopIteration further down.
    if not language_vocab:
        raise ValueError("int_to_lang.json contains no language entries.")

    try:
        # Succeeds only when every key is a numeric string: {"0": "Language"}.
        int_to_lang_map = {int(k): v for k, v in language_vocab.items()}
        print("[INFO] Detected int->lang format. Loading directly.")
    except ValueError:
        # Keys are language names: {"Language": 0}. Reverse the mapping and
        # coerce the ids to int so lookups by integer class index always hit.
        print("[INFO] Detected lang->int format. Reversing dictionary to fix.")
        int_to_lang_map = {int(v): k for k, v in language_vocab.items()}

    return char_to_int_map, int_to_lang_map

# Load both vocabularies once, at import time. The resulting module-level
# names are read below when sizing the model (vocabulary size, padding index,
# number of output classes) and inside classify_code for encoding/decoding.
char_to_int, int_to_lang = load_vocabularies()


# --- Step 2: Re-define the Model Architecture ---
# This MUST be the exact same architecture as the one you trained.
class CodeClassifierRNN(nn.Module):
    """Sequence classifier: embedding -> LSTM -> dropout -> linear head.

    The linear head takes ``hidden_dim * 2`` features and ``forward`` joins
    the last two slices of the LSTM's final hidden state, i.e. the layout is
    written for a bidirectional LSTM regardless of the ``bidirectional`` flag.
    The attribute names (``embedding``, ``lstm``, ``dropout``, ``fc``) are the
    keys under which the layers' parameters are stored in a state dict.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 n_layers, bidirectional, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=pad_idx,
        )
        # Dropout between LSTM layers is only passed when layers are stacked.
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Twice the hidden size: two hidden-state slices are concatenated.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)

    def forward(self, text):
        """Return unnormalised class scores for ``text``.

        ``text`` is a tensor of vocabulary indices; with ``batch_first=True``
        the LSTM expects it as (batch, sequence length).
        """
        char_vectors = self.embedding(text)
        _, (final_hidden, _) = self.lstm(char_vectors)
        # Join the last two hidden-state slices along the feature axis.
        features = torch.cat([final_hidden[-2], final_hidden[-1]], dim=1)
        return self.fc(self.dropout(features))

# --- Step 3: Instantiate the model and load the trained weights ---
# Hyperparameters. The sizes derived from the vocabularies and the fixed
# values below determine the parameter shapes, so they need to agree with the
# checkpoint being loaded.
PAD_IDX = char_to_int['<PAD>']   # index of the padding token
VOCAB_SIZE = len(char_to_int)    # one embedding row per vocabulary entry
EMBEDDING_DIM = 128
HIDDEN_DIM = 192
OUTPUT_DIM = len(int_to_lang)    # one score per language
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5

model = CodeClassifierRNN(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    n_layers=N_LAYERS,
    bidirectional=BIDIRECTIONAL,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed torch version's `weights_only` default -- only load checkpoints
# from a trusted source.
model.load_state_dict(
    torch.load('polyglot_classifier.pt', map_location='cpu')
)
# Switch to evaluation mode for inference (disables dropout).
model.eval()

# --- Step 4: Create the prediction function ---
def classify_code(code_snippet):
    """
    Predict the programming language of a code snippet.

    The snippet is encoded character by character with ``char_to_int``
    (unknown characters map to the '<UNK>' index), run through ``model`` as a
    single-item batch, and the softmax probabilities of the most likely
    classes are returned.

    Args:
        code_snippet: The source text to classify. ``None``, empty, or
            whitespace-only input yields an empty result.

    Returns:
        A dict mapping language name to probability for the top classes
        (at most five, fewer if the model has fewer output classes), or
        ``{}`` for blank input.

    Raises:
        KeyError: if '<UNK>' is missing from ``char_to_int`` or a predicted
            class index is missing from ``int_to_lang``.
    """
    if not code_snippet or not code_snippet.strip():
        return {}

    # Resolve the fallback index once rather than on every character.
    unk_idx = char_to_int['<UNK>']
    indexed = [char_to_int.get(c, unk_idx) for c in code_snippet]
    # Add a leading batch dimension: shape (1, len(code_snippet)).
    tensor = torch.tensor(indexed, dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        prediction = model(tensor)
        probabilities = torch.softmax(prediction, dim=1)[0]

    # torch.topk raises if k exceeds the number of classes, so cap it.
    k = min(5, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    # .tolist() yields plain Python ints/floats, matching int_to_lang's keys.
    return {
        int_to_lang[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }

# --- Step 5: Create and launch the Gradio Interface ---
# Wire classify_code into a simple web UI: a code editor as input and a
# label widget showing the top predictions as output.
iface = gr.Interface(
    title="Polyglot Code Classifier",
    description="Enter a code snippet to see which programming language the AI thinks it is. This model was trained from scratch on a custom dataset.",
    fn=classify_code,
    inputs=gr.Code(label="Code Snippet", language=None),
    outputs=gr.Label(label="Predicted Language", num_top_classes=5),
    # One single-input example per row: Python, JavaScript, Java.
    examples=[
        [
            "def hello_world():\n    print('Hello from Python!')",
        ],
        [
            "function greet() {\n    console.log('Hello from JavaScript!');\n}",
        ],
        [
            "public class Main {\n    public static void main(String[] args) {\n        System.out.println(\"Hello, Java!\");\n    }\n}",
        ],
    ],
)

# Start the Gradio server as soon as the module is executed.
iface.launch()