philipp-zettl commited on
Commit
caed98f
·
verified ·
1 Parent(s): 48aaa8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +444 -65
app.py CHANGED
@@ -1,70 +1,449 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  if __name__ == "__main__":
 
 
 
 
70
  demo.launch()
 
1
  import gradio as gr
2
+ from pydantic import BaseModel, field_validator
3
+ from typing import List, Optional, Dict, Any
4
+ import numpy as np
5
+ import random
6
+ import json
7
+
8
+ # --- Pydantic Models (from original app) ---
9
+ # We keep these for data validation and structure, even without FastAPI
10
+ class BaselineRequest(BaseModel):
11
+ task: str # "classification", "regression", "generation", "chess_moves"
12
+ dataset_size: int
13
+ output_format: str # "categorical", "continuous", "sequence"
14
+ classes: Optional[List[str]] = None
15
+ num_classes: Optional[int] = None
16
+ sequence_length: Optional[int] = None
17
+ target_distribution: Optional[Dict[str, float]] = None
18
+
19
+ @field_validator('dataset_size')
20
+ def size_must_be_positive(cls, v):
21
+ if v <= 0:
22
+ raise ValueError('Dataset size must be positive')
23
+ return v
24
+
25
+ class BaselineResponse(BaseModel):
26
+ task: str
27
+ baseline_type: str
28
+ metrics: Dict[str, Any] # Changed to Any to accommodate range list
29
+ sample_predictions: List[Any]
30
+ reality_check: str
31
+ advice: str
32
+
33
+ # --- Core Logic Functions (from original app) ---
34
+
35
+ def generate_random_classification(request: BaselineRequest):
36
+ """Generate random classification baseline"""
37
+ if request.classes:
38
+ num_classes = len(request.classes)
39
+ class_names = request.classes
40
+ else:
41
+ num_classes = request.num_classes or 2
42
+ class_names = [f"class_{i}" for i in range(num_classes)]
43
+
44
+ # Ensure num_classes is not zero
45
+ if num_classes == 0:
46
+ num_classes = 1
47
+ class_names = ["default_class"]
48
+
49
+ # Generate random predictions
50
+ if request.target_distribution:
51
+ # Use provided distribution
52
+ weights = [request.target_distribution.get(cls, 1/num_classes) for cls in class_names]
53
+ try:
54
+ predictions = random.choices(class_names, weights=weights, k=request.dataset_size)
55
+ except ValueError: # Handle all-zero weights
56
+ predictions = [random.choice(class_names) for _ in range(request.dataset_size)]
57
+ else:
58
+ # Uniform random
59
+ predictions = [random.choice(class_names) for _ in range(request.dataset_size)]
60
+
61
+ # Calculate expected accuracy for uniform random
62
+ expected_accuracy = 1 / num_classes
63
+
64
+ return {
65
+ "baseline_type": "uniform_random" if not request.target_distribution else "weighted_random",
66
+ "metrics": {
67
+ "expected_accuracy": round(expected_accuracy, 4),
68
+ "expected_f1": round(expected_accuracy, 4), # Simplified for uniform case
69
+ "num_classes": num_classes
70
+ },
71
+ "sample_predictions": predictions[:10],
72
+ "reality_check": f"Random guessing should get ~{expected_accuracy:.1%} accuracy. If your model doesn't beat this by a significant margin, it's probably garbage.",
73
+ "advice": "Train a simple baseline (logistic regression, random forest) before going neural. Save yourself the GPU bills."
74
+ }
75
+
76
+ def generate_random_regression(request: BaselineRequest):
77
+ """Generate random regression baseline"""
78
+ # Generate random continuous values
79
+ predictions = np.random.normal(0, 1, request.dataset_size)
80
+
81
+ return {
82
+ "baseline_type": "gaussian_random",
83
+ "metrics": {
84
+ "mean": round(float(np.mean(predictions)), 4),
85
+ "std": round(float(np.std(predictions)), 4),
86
+ "range": [round(float(np.min(predictions)), 4), round(float(np.max(predictions)), 4)]
87
+ },
88
+ "sample_predictions": predictions[:10].tolist(),
89
+ "reality_check": "Random regression predictions have infinite MSE against any reasonable target. If your model's MSE isn't dramatically better, you're wasting compute.",
90
+ "advice": "Start with mean prediction baseline, then linear regression. Neural networks are overkill for most regression problems."
91
+ }
92
+
93
+ def generate_random_sequence(request: BaselineRequest):
94
+ """Generate random sequence baseline (like text/chess moves)"""
95
+ vocab_size = len(request.classes) if request.classes else 1000
96
+ if vocab_size == 0: # Handle empty vocab
97
+ vocab_size = 1
98
+
99
+ seq_len = request.sequence_length or 50
100
+
101
+ sequences = []
102
+ for _ in range(min(10, request.dataset_size)):
103
+ if request.classes:
104
+ seq = [random.choice(request.classes) for _ in range(seq_len)]
105
+ else:
106
+ seq = [random.randint(0, vocab_size-1) for _ in range(seq_len)]
107
+ sequences.append(seq)
108
+
109
+ perplexity = vocab_size # Worst case perplexity for uniform random
110
+
111
+ return {
112
+ "baseline_type": "uniform_random_sequence",
113
+ "metrics": {
114
+ "perplexity": perplexity,
115
+ "sequence_length": seq_len,
116
+ "vocab_size": vocab_size
117
+ },
118
+ "sample_predictions": sequences,
119
+ "reality_check": f"Random sequences have perplexity ~{perplexity}. If your language model doesn't crush this, it learned nothing.",
120
+ "advice": "Even a bigram model should destroy random baselines. If it doesn't, check your data preprocessing."
121
+ }
122
+
123
+ # Special handlers (from original app)
124
+ TASK_HANDLERS = {
125
+ "chess_moves": lambda req: generate_random_sequence(BaselineRequest(
126
+ task="chess_moves",
127
+ dataset_size=req.dataset_size,
128
+ output_format="sequence",
129
+ classes=["e4", "d4", "Nf3", "c4", "g3", "Nc3", "f4", "e3"], # Common opening moves
130
+ sequence_length=1
131
+ )),
132
+ "sentiment": lambda req: generate_random_classification(BaselineRequest(
133
+ task="sentiment",
134
+ dataset_size=req.dataset_size,
135
+ output_format="categorical",
136
+ classes=["positive", "negative", "neutral"]
137
+ )),
138
+ "image_classification": lambda req: generate_random_classification(BaselineRequest(
139
+ task="image_classification",
140
+ dataset_size=req.dataset_size,
141
+ output_format="categorical",
142
+ num_classes=req.num_classes or 1000 # ImageNet default
143
+ ))
144
+ }
145
+
146
+ # Roast logic (from original app)
147
+ ROASTS = [
148
+ "Your neural network is just an expensive random number generator.",
149
+ "I bet your model's accuracy is 50.1% and you're calling it 'promising results'.",
150
+ "Random guessing doesn't need 8 GPUs and a PhD to run.",
151
+ "Your transformer probably learned to predict the dataset bias, not the actual task.",
152
+ "If random baseline beats your model, maybe try a different career?",
153
+ "Your model: 47% accuracy. Random baseline: 50%. Congratulations, you made it worse.",
154
+ ]
155
+
156
+ def get_roast():
157
+ """Get roasted for probably having a model worse than random"""
158
+ return random.choice(ROASTS)
159
+
160
+
161
+ # --- Gradio Interface Functions ---
162
+
163
+ def handle_classification(task_choice, dataset_size, num_classes, classes_str, dist_str):
164
+ """Gradio handler for the classification tab"""
165
+ try:
166
+ # 1. Parse Inputs
167
+ task_name = task_choice
168
+ if task_choice == "image_classification (1000 class)":
169
+ task_name = "image_classification"
170
+ num_classes = 1000 # Override
171
+
172
+ classes_list = [c.strip() for c in classes_str.split(',')] if classes_str else None
173
+
174
+ target_dist = None
175
+ if dist_str:
176
+ try:
177
+ target_dist = json.loads(dist_str)
178
+ if not isinstance(target_dist, dict):
179
+ raise ValueError("JSON must be an object/dictionary.")
180
+ except json.JSONDecodeError as e:
181
+ raise gr.Error(f"Invalid JSON in target distribution: {e}")
182
+ except ValueError as e:
183
+ raise gr.Error(str(e))
184
+
185
+ # 2. Build Request
186
+ request = BaselineRequest(
187
+ task=task_name,
188
+ dataset_size=int(dataset_size),
189
+ output_format="categorical",
190
+ classes=classes_list,
191
+ num_classes=int(num_classes) if num_classes else None,
192
+ target_distribution=target_dist
193
+ )
194
+
195
+ # 3. Get Result
196
+ if request.task in TASK_HANDLERS:
197
+ result = TASK_HANDLERS[request.task](request)
198
+ else: # "custom"
199
+ result = generate_random_classification(request)
200
+
201
+ # 4. Format Output
202
+ response = BaselineResponse(task=request.task, **result)
203
+ return (
204
+ response.task,
205
+ response.baseline_type,
206
+ response.metrics,
207
+ response.sample_predictions,
208
+ response.reality_check,
209
+ response.advice
210
+ )
211
+ except Exception as e:
212
+ raise gr.Error(str(e))
213
+
214
+
215
+ def handle_regression(dataset_size):
216
+ """Gradio handler for the regression tab"""
217
+ try:
218
+ request = BaselineRequest(
219
+ task="regression",
220
+ dataset_size=int(dataset_size),
221
+ output_format="continuous"
222
+ )
223
+ result = generate_random_regression(request)
224
+ response = BaselineResponse(task=request.task, **result)
225
+ return (
226
+ response.task,
227
+ response.baseline_type,
228
+ response.metrics,
229
+ response.sample_predictions,
230
+ response.reality_check,
231
+ response.advice
232
+ )
233
+ except Exception as e:
234
+ raise gr.Error(str(e))
235
+
236
+ def handle_sequence(task_choice, dataset_size, seq_len, vocab_str):
237
+ """Gradio handler for the generation/sequence tab"""
238
+ try:
239
+ vocab_list = [c.strip() for c in vocab_str.split(',')] if vocab_str else None
240
+
241
+ request = BaselineRequest(
242
+ task=task_choice,
243
+ dataset_size=int(dataset_size),
244
+ output_format="sequence",
245
+ classes=vocab_list,
246
+ sequence_length=int(seq_len) if seq_len else 50
247
+ )
248
+
249
+ if request.task in TASK_HANDLERS:
250
+ result = TASK_HANDLERS[request.task](request)
251
+ else: # "custom"
252
+ result = generate_random_sequence(request)
253
+
254
+ response = BaselineResponse(task=request.task, **result)
255
+ return (
256
+ response.task,
257
+ response.baseline_type,
258
+ response.metrics,
259
+ response.sample_predictions,
260
+ response.reality_check,
261
+ response.advice
262
+ )
263
+ except Exception as e:
264
+ raise gr.Error(str(e))
265
+
266
+
267
+ # --- Gradio UI Layout ---
268
+
269
+ with gr.Blocks(theme=gr.themes.Soft(), title="Random Baseline API") as demo:
270
+ gr.Markdown(
271
+ """
272
+ # Random Baseline API
273
+ **The most honest ML API in existence. Keeping researchers humble since 2025.**
274
+
275
+ Get a random baseline for your ML task. Because sometimes you need to know how bad 'bad' really is.
276
+ """
277
+ )
278
+
279
+ with gr.Tabs():
280
+ # --- Classification Tab ---
281
+ with gr.TabItem("Classification"):
282
+ with gr.Row():
283
+ with gr.Column(scale=1):
284
+ task_cls = gr.Radio(
285
+ ["sentiment", "image_classification (1000 class)", "custom"],
286
+ label="Task",
287
+ value="sentiment"
288
+ )
289
+ dataset_size_cls = gr.Number(label="Dataset Size", value=1000, minimum=1, step=1)
290
+
291
+ # Custom options
292
+ num_classes_cls = gr.Number(
293
+ label="Number of Classes (if classes not specified)",
294
+ value=10,
295
+ visible=False,
296
+ minimum=1,
297
+ step=1
298
+ )
299
+ classes_cls = gr.Textbox(
300
+ label="Comma-separated classes (e.g., cat,dog,fish)",
301
+ visible=False,
302
+ placeholder="cat, dog, fish"
303
+ )
304
+ dist_cls = gr.Textbox(
305
+ label='JSON target distribution (e.g., {"cat": 0.8})',
306
+ visible=False,
307
+ placeholder='{"cat": 0.8, "dog": 0.1, "fish": 0.1}'
308
+ )
309
+
310
+ btn_cls = gr.Button("Get Classification Baseline", variant="primary")
311
+
312
+ with gr.Column(scale=2):
313
+ out_task_cls = gr.Textbox(label="Task", interactive=False)
314
+ out_btype_cls = gr.Textbox(label="Baseline Type", interactive=False)
315
+ out_metrics_cls = gr.JSON(label="Metrics")
316
+ out_preds_cls = gr.JSON(label="Sample Predictions")
317
+ out_reality_cls = gr.Textbox(label="Reality Check", lines=3, interactive=False)
318
+ out_advice_cls = gr.Textbox(label="Advice", lines=3, interactive=False)
319
+
320
+ # --- Regression Tab ---
321
+ with gr.TabItem("Regression"):
322
+ with gr.Row():
323
+ with gr.Column(scale=1):
324
+ dataset_size_reg = gr.Number(label="Dataset Size", value=1000, minimum=1, step=1)
325
+ btn_reg = gr.Button("Get Regression Baseline", variant="primary")
326
+
327
+ with gr.Column(scale=2):
328
+ out_task_reg = gr.Textbox(label="Task", interactive=False)
329
+ out_btype_reg = gr.Textbox(label="Baseline Type", interactive=False)
330
+ out_metrics_reg = gr.JSON(label="Metrics")
331
+ out_preds_reg = gr.JSON(label="Sample Predictions")
332
+ out_reality_reg = gr.Textbox(label="Reality Check", lines=3, interactive=False)
333
+ out_advice_reg = gr.Textbox(label="Advice", lines=3, interactive=False)
334
+
335
+ # --- Generation/Sequence Tab ---
336
+ with gr.TabItem("Generation / Sequence"):
337
+ with gr.Row():
338
+ with gr.Column(scale=1):
339
+ task_seq = gr.Radio(
340
+ ["chess_moves", "custom"],
341
+ label="Task",
342
+ value="chess_moves"
343
+ )
344
+ dataset_size_seq = gr.Number(label="Dataset Size", value=1000, minimum=1, step=1)
345
+
346
+ # Custom options
347
+ seq_len_seq = gr.Number(label="Sequence Length", value=50, visible=False, minimum=1, step=1)
348
+ vocab_seq = gr.Textbox(
349
+ label="Comma-separated vocabulary (e.g., a,b,c)",
350
+ visible=False,
351
+ placeholder="a, b, c, <pad>, <eos>"
352
+ )
353
+
354
+ btn_seq = gr.Button("Get Sequence Baseline", variant="primary")
355
+
356
+ with gr.Column(scale=2):
357
+ out_task_seq = gr.Textbox(label="Task", interactive=False)
358
+ out_btype_seq = gr.Textbox(label="Baseline Type", interactive=False)
359
+ out_metrics_seq = gr.JSON(label="Metrics")
360
+ out_preds_seq = gr.JSON(label="Sample Predictions")
361
+ out_reality_seq = gr.Textbox(label="Reality Check", lines=3, interactive=False)
362
+ out_advice_seq = gr.Textbox(label="Advice", lines=3, interactive=False)
363
+
364
+ # --- Roast Tab ---
365
+ with gr.TabItem("Roast My Model"):
366
+ gr.Markdown("Feeling too good about your model's 98% accuracy on a balanced dataset? Let us fix that.")
367
+ btn_roast = gr.Button("Roast Me!", variant="stop")
368
+ out_roast = gr.Textbox(label="Your Roast", lines=3, interactive=False)
369
+
370
+
371
+ # --- UI Listeners ---
372
+
373
+ def update_cls_ui(task):
374
+ """Show/hide custom classification options"""
375
+ if task == "custom":
376
+ return {
377
+ num_classes_cls: gr.update(visible=True, value=10),
378
+ classes_cls: gr.update(visible=True),
379
+ dist_cls: gr.update(visible=True)
380
+ }
381
+ elif task == "image_classification (1000 class)":
382
+ return {
383
+ num_classes_cls: gr.update(visible=False, value=1000),
384
+ classes_cls: gr.update(visible=False),
385
+ dist_cls: gr.update(visible=False)
386
+ }
387
+ else: # sentiment
388
+ return {
389
+ num_classes_cls: gr.update(visible=False),
390
+ classes_cls: gr.update(visible=False),
391
+ dist_cls: gr.update(visible=False)
392
+ }
393
+
394
+ task_cls.change(
395
+ fn=update_cls_ui,
396
+ inputs=task_cls,
397
+ outputs=[num_classes_cls, classes_cls, dist_cls]
398
+ )
399
+
400
+ def update_seq_ui(task):
401
+ """Show/hide custom sequence options"""
402
+ if task == "custom":
403
+ return {
404
+ seq_len_seq: gr.update(visible=True),
405
+ vocab_seq: gr.update(visible=True)
406
+ }
407
+ else: # chess_moves
408
+ return {
409
+ seq_len_seq: gr.update(visible=False),
410
+ vocab_seq: gr.update(visible=False)
411
+ }
412
+
413
+ task_seq.change(
414
+ fn=update_seq_ui,
415
+ inputs=task_seq,
416
+ outputs=[seq_len_seq, vocab_seq]
417
+ )
418
+
419
+ # Button click handlers
420
+ cls_outputs = [out_task_cls, out_btype_cls, out_metrics_cls, out_preds_cls, out_reality_cls, out_advice_cls]
421
+ btn_cls.click(
422
+ fn=handle_classification,
423
+ inputs=[task_cls, dataset_size_cls, num_classes_cls, classes_cls, dist_cls],
424
+ outputs=cls_outputs
425
+ )
426
+
427
+ reg_outputs = [out_task_reg, out_btype_reg, out_metrics_reg, out_preds_reg, out_reality_reg, out_advice_reg]
428
+ btn_reg.click(
429
+ fn=handle_regression,
430
+ inputs=[dataset_size_reg],
431
+ outputs=reg_outputs
432
+ )
433
+
434
+ seq_outputs = [out_task_seq, out_btype_seq, out_metrics_seq, out_preds_seq, out_reality_seq, out_advice_seq]
435
+ btn_seq.click(
436
+ fn=handle_sequence,
437
+ inputs=[task_seq, dataset_size_seq, seq_len_seq, vocab_seq],
438
+ outputs=seq_outputs
439
+ )
440
+
441
+ btn_roast.click(fn=get_roast, inputs=None, outputs=out_roast)
442
 
443
 
444
  if __name__ == "__main__":
445
+ # To run this, save as a .py file and run:
446
+ # 1. pip install gradio pydantic numpy
447
+ # 2. python your_app_name.py
448
+ print("Starting Gradio app... Access it at http://127.0.0.1:7860 (or the URL shown below)")
449
  demo.launch()