SHIKARICHACHA committed on
Commit
c6a27a9
·
verified ·
1 Parent(s): cb89e3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -83
app.py CHANGED
@@ -6,6 +6,7 @@ Generates custom musical exercises with LLM, perfectly fit to user-specified num
6
  AND time signature, guaranteeing exact durations in MIDI and in the UI!
7
 
8
  Major updates:
 
9
  - Added duration sum display in Exercise Data tab
10
  - Shows total duration units (16th notes) for verification
11
  - Added DeepSeek AI model option
@@ -19,6 +20,8 @@ Major updates:
19
  import sys
20
  import subprocess
21
  from typing import Dict, Optional, Tuple, List
 
 
22
 
23
  def install(packages: List[str]):
24
  for package in packages:
@@ -36,7 +39,6 @@ install([
36
  # -----------------------------------------------------------------------------
37
  # 2. Static imports
38
  # -----------------------------------------------------------------------------
39
- import random
40
  import requests
41
  import json
42
  import tempfile
@@ -55,14 +57,20 @@ import os
55
  import subprocess as sp
56
  import base64
57
  import shutil
58
- from openai import OpenAI # For DeepSeek API
59
 
60
  # -----------------------------------------------------------------------------
61
  # 3. Configuration & constants
62
  # -----------------------------------------------------------------------------
63
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
64
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key
65
- DEEPSEEK_API_KEY = "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"
 
 
 
 
 
 
66
 
67
  SOUNDFONT_URLS = {
68
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
@@ -261,7 +269,7 @@ def get_technique_based_on_level(level: str) -> str:
261
  return random.choice(techniques.get(level, ["with slurs"]))
262
 
263
  # -----------------------------------------------------------------------------
264
- # 9. LLM Query Function (supports Mistral and DeepSeek)
265
  # -----------------------------------------------------------------------------
266
  def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
267
  time_sig: str, measures: int) -> str:
@@ -295,7 +303,90 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
295
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
296
  )
297
 
298
- if model_name == "Mistral":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  headers = {
300
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
301
  "Content-Type": "application/json",
@@ -312,45 +403,12 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
312
  "frequency_penalty": 0.2,
313
  "presence_penalty": 0.2,
314
  }
315
- try:
316
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
317
- response.raise_for_status()
318
- content = response.json()["choices"][0]["message"]["content"]
319
- return content.replace("```json","").replace("```","").strip()
320
- except Exception as e:
321
- print(f"Error querying Mistral API: {e}")
322
- return get_fallback_exercise(instrument, level, key, time_sig, measures)
323
-
324
- elif model_name == "DeepSeek":
325
- try:
326
- client = OpenAI(
327
- base_url="https://openrouter.ai/api/v1",
328
- api_key=DEEPSEEK_API_KEY,
329
- )
330
-
331
- completion = client.chat.completions.create(
332
- extra_headers={
333
- "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
334
- "X-Title": "Music Exercise Generator",
335
- },
336
- model="deepseek/deepseek-chat-v3-0324:free",
337
- messages=[
338
- {"role": "system", "content": system_prompt},
339
- {"role": "user", "content": user_prompt},
340
- ],
341
- temperature=0.7 if level == "Advanced" else 0.5,
342
- max_tokens=1000,
343
- top_p=0.95,
344
- frequency_penalty=0.2,
345
- presence_penalty=0.2,
346
- )
347
- content = completion.choices[0].message.content
348
- return content.replace("```json","").replace("```","").strip()
349
- except Exception as e:
350
- print(f"Error querying DeepSeek API: {e}")
351
- return get_fallback_exercise(instrument, level, key, time_sig, measures)
352
-
353
- else:
354
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
355
 
356
  # -----------------------------------------------------------------------------
@@ -399,7 +457,7 @@ def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_si
399
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
400
 
401
  # -----------------------------------------------------------------------------
402
- # 12. Simple AI chat assistant (optional, shares LLM)
403
  # -----------------------------------------------------------------------------
404
  def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
405
  if not message.strip():
@@ -410,49 +468,84 @@ def handle_chat(message: str, history: List, instrument: str, level: str, ai_mod
410
  messages.append({"role": "assistant", "content": assistant_msg})
411
  messages.append({"role": "user", "content": message})
412
 
413
- if ai_model == "Mistral":
414
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
415
- payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
416
- try:
417
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
418
- response.raise_for_status()
419
- content = response.json()["choices"][0]["message"]["content"]
420
- history.append((message, content))
421
- return "", history
422
- except Exception as e:
423
- history.append((message, f"Error: {str(e)}"))
424
- return "", history
425
 
426
- elif ai_model == "DeepSeek":
427
  try:
428
- client = OpenAI(
429
- base_url="https://openrouter.ai/api/v1",
430
- api_key=DEEPSEEK_API_KEY,
431
- )
 
 
 
 
432
 
433
- completion = client.chat.completions.create(
434
- extra_headers={
435
- "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
436
- "X-Title": "Music Exercise Generator",
437
- },
438
- model="deepseek/deepseek-chat-v3-0324:free",
439
- messages=messages,
440
- temperature=0.7,
441
- max_tokens=500,
442
- )
443
- content = completion.choices[0].message.content
444
- history.append((message, content))
445
- return "", history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  except Exception as e:
447
- history.append((message, f"Error: {str(e)}"))
448
- return "", history
449
-
450
- else:
451
- history.append((message, "Error: Invalid AI model selected"))
452
- return "", history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
 
454
  # -----------------------------------------------------------------------------
455
- # 13. Gradio user interface definition (for humans!)
456
  # -----------------------------------------------------------------------------
457
  def create_ui() -> gr.Blocks:
458
  with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
@@ -466,7 +559,7 @@ def create_ui() -> gr.Blocks:
466
  with gr.Group(visible=True) as params_group:
467
  gr.Markdown("### Exercise Parameters")
468
  ai_model = gr.Radio(
469
- ["Mistral", "DeepSeek"],
470
  value="Mistral",
471
  label="AI Model"
472
  )
 
6
  AND time signature, guaranteeing exact durations in MIDI and in the UI!
7
 
8
  Major updates:
9
+ - Added Gemma, Kimi Dev 72b, and Llama 3.1 AI model options
10
  - Added duration sum display in Exercise Data tab
11
  - Shows total duration units (16th notes) for verification
12
  - Added DeepSeek AI model option
 
20
  import sys
21
  import subprocess
22
  from typing import Dict, Optional, Tuple, List
23
+ import time
24
+ import random
25
 
26
  def install(packages: List[str]):
27
  for package in packages:
 
39
  # -----------------------------------------------------------------------------
40
  # 2. Static imports
41
  # -----------------------------------------------------------------------------
 
42
  import requests
43
  import json
44
  import tempfile
 
57
  import subprocess as sp
58
  import base64
59
  import shutil
60
+ from openai import OpenAI # For API models
61
 
62
  # -----------------------------------------------------------------------------
63
  # 3. Configuration & constants
64
  # -----------------------------------------------------------------------------
65
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
66
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key
67
+ OPENROUTER_API_KEYS = {
68
+ "DeepSeek": "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",
69
+ "Claude": "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",
70
+ "Gemma": "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",
71
+ "Kimi": "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",
72
+ "Llama 3.1": "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"
73
+ }
74
 
75
  SOUNDFONT_URLS = {
76
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
 
269
  return random.choice(techniques.get(level, ["with slurs"]))
270
 
271
  # -----------------------------------------------------------------------------
272
+ # 9. LLM Query Function (with enhanced error handling)
273
  # -----------------------------------------------------------------------------
274
  def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
275
  time_sig: str, measures: int) -> str:
 
303
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
304
  )
305
 
306
+ # Retry up to 3 times for rate limited models
307
+ max_retries = 3
308
+ retry_delay = 5 # seconds
309
+
310
+ for attempt in range(max_retries):
311
+ try:
312
+ if model_name == "Mistral":
313
+ headers = {
314
+ "Authorization": f"Bearer {MISTRAL_API_KEY}",
315
+ "Content-Type": "application/json",
316
+ }
317
+ payload = {
318
+ "model": "mistral-medium",
319
+ "messages": [
320
+ {"role": "system", "content": system_prompt},
321
+ {"role": "user", "content": user_prompt},
322
+ ],
323
+ "temperature": 0.7 if level == "Advanced" else 0.5,
324
+ "max_tokens": 1000,
325
+ "top_p": 0.95,
326
+ "frequency_penalty": 0.2,
327
+ "presence_penalty": 0.2,
328
+ }
329
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
330
+ response.raise_for_status()
331
+ content = response.json()["choices"][0]["message"]["content"]
332
+ return content.replace("```json","").replace("```","").strip()
333
+
334
+ elif model_name in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
335
+ client = OpenAI(
336
+ base_url="https://openrouter.ai/api/v1",
337
+ api_key=OPENROUTER_API_KEYS[model_name],
338
+ )
339
+
340
+ model_map = {
341
+ "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
342
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
343
+ "Gemma": "google/gemma-3n-e2b-it:free",
344
+ "Kimi": "moonshotai/kimi-dev-72b:free",
345
+ "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
346
+ }
347
+
348
+ # Special handling for Gemma API structure
349
+ if model_name == "Gemma":
350
+ messages = [
351
+ {"role": "user", "content": user_prompt}
352
+ ]
353
+ else:
354
+ messages = [
355
+ {"role": "system", "content": system_prompt},
356
+ {"role": "user", "content": user_prompt},
357
+ ]
358
+
359
+ completion = client.chat.completions.create(
360
+ extra_headers={
361
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
362
+ "X-Title": "Music Exercise Generator",
363
+ },
364
+ model=model_map[model_name],
365
+ messages=messages,
366
+ temperature=0.7 if level == "Advanced" else 0.5,
367
+ max_tokens=1000,
368
+ top_p=0.95,
369
+ frequency_penalty=0.2,
370
+ presence_penalty=0.2,
371
+ )
372
+ content = completion.choices[0].message.content
373
+ return content.replace("```json","").replace("```","").strip()
374
+
375
+ else:
376
+ return get_fallback_exercise(instrument, level, key, time_sig, measures)
377
+
378
+ except Exception as e:
379
+ print(f"Error querying {model_name} API (attempt {attempt+1}): {e}")
380
+ if "429" in str(e) or "Rate limit" in str(e):
381
+ print(f"Rate limited, retrying in {retry_delay} seconds...")
382
+ time.sleep(retry_delay)
383
+ retry_delay *= 2 # Exponential backoff
384
+ else:
385
+ break
386
+
387
+ # Fallback to Mistral if other APIs fail
388
+ print(f"All attempts failed for {model_name}, using Mistral fallback")
389
+ try:
390
  headers = {
391
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
392
  "Content-Type": "application/json",
 
403
  "frequency_penalty": 0.2,
404
  "presence_penalty": 0.2,
405
  }
406
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
407
+ response.raise_for_status()
408
+ content = response.json()["choices"][0]["message"]["content"]
409
+ return content.replace("```json","").replace("```","").strip()
410
+ except Exception as e:
411
+ print(f"Error querying Mistral fallback: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
413
 
414
  # -----------------------------------------------------------------------------
 
457
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
458
 
459
  # -----------------------------------------------------------------------------
460
+ # 12. AI chat assistant with enhanced error handling
461
  # -----------------------------------------------------------------------------
462
  def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
463
  if not message.strip():
 
468
  messages.append({"role": "assistant", "content": assistant_msg})
469
  messages.append({"role": "user", "content": message})
470
 
471
+ max_retries = 3
472
+ retry_delay = 3 # seconds
 
 
 
 
 
 
 
 
 
 
473
 
474
+ for attempt in range(max_retries):
475
  try:
476
+ if ai_model == "Mistral":
477
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
478
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
479
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
480
+ response.raise_for_status()
481
+ content = response.json()["choices"][0]["message"]["content"]
482
+ history.append((message, content))
483
+ return "", history
484
 
485
+ elif ai_model in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
486
+ client = OpenAI(
487
+ base_url="https://openrouter.ai/api/v1",
488
+ api_key=OPENROUTER_API_KEYS[ai_model],
489
+ )
490
+
491
+ model_map = {
492
+ "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
493
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
494
+ "Gemma": "google/gemma-3n-e2b-it:free",
495
+ "Kimi": "moonshotai/kimi-dev-72b:free",
496
+ "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
497
+ }
498
+
499
+ # Special handling for Gemma API structure
500
+ if ai_model == "Gemma":
501
+ adjusted_messages = [{"role": "user", "content": msg["content"]} for msg in messages]
502
+ else:
503
+ adjusted_messages = messages
504
+
505
+ completion = client.chat.completions.create(
506
+ extra_headers={
507
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
508
+ "X-Title": "Music Exercise Generator",
509
+ },
510
+ model=model_map[ai_model],
511
+ messages=adjusted_messages,
512
+ temperature=0.7,
513
+ max_tokens=500,
514
+ )
515
+ content = completion.choices[0].message.content
516
+ history.append((message, content))
517
+ return "", history
518
+
519
+ else:
520
+ history.append((message, "Error: Invalid AI model selected"))
521
+ return "", history
522
+
523
  except Exception as e:
524
+ print(f"Chat error with {ai_model} (attempt {attempt+1}): {e}")
525
+ if "429" in str(e) or "Rate limit" in str(e):
526
+ print(f"Rate limited, retrying in {retry_delay} seconds...")
527
+ time.sleep(retry_delay)
528
+ retry_delay *= 2 # Exponential backoff
529
+ else:
530
+ # Fallback to Mistral
531
+ print(f"Using Mistral fallback for chat")
532
+ try:
533
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
534
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
535
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
536
+ response.raise_for_status()
537
+ content = response.json()["choices"][0]["message"]["content"]
538
+ history.append((message, content))
539
+ return "", history
540
+ except Exception as e:
541
+ history.append((message, f"Error: {str(e)}"))
542
+ return "", history
543
+
544
+ history.append((message, "Error: All API attempts failed"))
545
+ return "", history
546
 
547
  # -----------------------------------------------------------------------------
548
+ # 13. Gradio user interface definition
549
  # -----------------------------------------------------------------------------
550
  def create_ui() -> gr.Blocks:
551
  with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
 
559
  with gr.Group(visible=True) as params_group:
560
  gr.Markdown("### Exercise Parameters")
561
  ai_model = gr.Radio(
562
+ ["Mistral", "DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"],
563
  value="Mistral",
564
  label="AI Model"
565
  )