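"""Gradio app for Kartoffel-TTS: German long-text TTS with Chatterbox.

Loads the SebastianBodza/Kartoffelbox-v0.1 fine-tuned T3 checkpoint on top of
ChatterboxTTS, splits long input text into sentence chunks, synthesises each
chunk with an optional reference voice, and exposes everything through a
Gradio UI with savable parameter presets.
"""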
import os
import json
import random
import re

import torch
import numpy as np
import gradio as gr

from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import nn

# === Settings ===
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.1"
T3_CHECKPOINT_FILE = "t3_kartoffelbox.safetensors"
MAX_CHARS = 5000
CHUNK_CHAR_LIMIT = 300
SETTINGS_DIR = "settings"

# === Init ===
os.makedirs(SETTINGS_DIR, exist_ok=True)
MODEL = None
print(f"🚀 Running on device: {DEVICE}")


def get_or_load_model():
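    """Lazily load the ChatterboxTTS model exactly once.

    Downloads the fine-tuned T3 checkpoint from MODEL_REPO, loads it into the
    base model, and expands the text position embeddings so inputs up to
    MAX_CHARS positions are accepted. Note: the .t3/.s3gen/.text_pos_emb
    attribute layout is taken over from the original script and assumed to
    match the installed chatterbox version.
    """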
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=T3_CHECKPOINT_FILE,
            token=os.environ.get("HUGGING_FACE_HUB_TOKEN", "")
        )
        t3_state = load_file(checkpoint_path, device="cpu")
        MODEL.t3.load_state_dict(t3_state)

        # Expand the position embeddings so longer inputs fit
        pos_emb_module = MODEL.t3.text_pos_emb
        old_pos = pos_emb_module.emb.num_embeddings
        if MAX_CHARS > old_pos:
            emb_dim = pos_emb_module.emb.embedding_dim
            new_emb = nn.Embedding(MAX_CHARS, emb_dim)
            with torch.no_grad():
                new_emb.weight[:old_pos] = pos_emb_module.emb.weight
            pos_emb_module.emb = new_emb
            print(f"Expanded position embeddings: {old_pos} → {MAX_CHARS}")

        MODEL.t3.to(DEVICE)
        MODEL.s3gen.to(DEVICE)
        print(f"Model loaded. Device: {MODEL.device}")
    return MODEL


try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model: {e}")


def set_seed(seed: int):
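    """Seed torch, CUDA, random and numpy for reproducible generation."""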
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def split_text_into_chunks(text, max_length=CHUNK_CHAR_LIMIT):
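    """Split text into chunks of roughly max_length characters.

    The text is split on sentence boundaries (., !, ?) and sentences are
    greedily packed into chunks; a single sentence longer than max_length
    is kept as its own (oversized) chunk.
    """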
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks = []
    chunk = ""
    for sentence in sentences:
        if len(chunk) + len(sentence) < max_length:
            chunk += " " + sentence
        else:
            if chunk:
                chunks.append(chunk.strip())
            chunk = sentence
    if chunk:
        chunks.append(chunk.strip())
    return chunks


# === Save / load presets ===
def list_presets():
    return [f[:-5] for f in os.listdir(SETTINGS_DIR) if f.endswith(".json") and f != "last.json"]


def load_preset(name):
    path = os.path.join(SETTINGS_DIR, name + ".json")
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    return None


def save_preset(name, data):
    path = os.path.join(SETTINGS_DIR, name + ".json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    if name != "last":
        save_preset("last", data)  # also remember as "last used"; the guard avoids infinite recursion


def generate_tts_audio(text_input, audio_prompt_path_input, exaggeration_input, temperature_input, seed_num_input, cfgw_input):
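    """Generate speech for a (possibly long) text via chunked synthesis.

    The input is truncated to MAX_CHARS, split into sentence chunks, each
    chunk is synthesised with model.generate(), and the waveforms are
    concatenated. Returns (sample_rate, samples) as expected by gr.Audio.
    """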
    model = get_or_load_model()
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    full_audio = []
    chunks = split_text_into_chunks(text_input[:MAX_CHARS])
    print(f"Splitting text into {len(chunks)} chunks…")
    for i, chunk in enumerate(chunks):
        print(f"▶️ Chunk {i+1}/{len(chunks)}: {chunk[:60]}...")
        wav = model.generate(
            chunk,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        full_audio.append(wav.squeeze(0).cpu().numpy())
    audio_concat = np.concatenate(full_audio)
    return (model.sr, audio_concat)
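
# Example (manual use, assuming a local reference recording "ref.wav"):
#   sr, audio = generate_tts_audio("Hallo Welt. Das ist ein Test.", "ref.wav", 0.5, 0.6, 0, 0.3)
#   -> returns the sample rate and the concatenated numpy waveform.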


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# 🥔 Kartoffel-TTS (Chatterbox)\nLangtext → Sprachstil mit Profilen")
    with gr.Row():
        with gr.Column():
            preset_dropdown = gr.Dropdown(label="🔄 Preset wählen", choices=list_presets(), value=None)
            preset_name = gr.Textbox(label="📝 Name zum Speichern", value="mein-profil")
            text = gr.Textbox(
                value="Hier kannst du einen längeren deutschen Text eingeben…",
                label=f"Text (max {MAX_CHARS} Zeichen)",
                max_lines=12
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Referenz-Audiodatei (optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(0.25, 2, step=0.05, label="Exaggeration", value=0.5)
            cfg_weight = gr.Slider(0.2, 1, step=0.05, label="CFG/Pace", value=0.3)
            with gr.Accordion("Weitere Optionen", open=False):
                seed_num = gr.Number(value=0, label="Zufalls-Seed (0 = zufällig)")
                temp = gr.Slider(0.05, 5, step=0.05, label="Temperature", value=0.6)
            save_btn = gr.Button("💾 Einstellungen speichern")
            run_btn = gr.Button("🎤 Audio generieren")
        with gr.Column():
            audio_output = gr.Audio(label="🔊 Ergebnis")

    # Wire up the callbacks
    def on_preset_selected(name):
        if name:
            p = load_preset(name)
            if p:
                return p["exaggeration"], p["temperature"], p["seed"], p["cfg"]
        return gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        on_preset_selected,
        inputs=[preset_dropdown],
        outputs=[exaggeration, temp, seed_num, cfg_weight]
    )

    def save_current_settings(name, exaggeration, temperature, seed, cfg):
        save_preset(name, {
            "exaggeration": exaggeration,
            "temperature": temperature,
            "seed": seed,
            "cfg": cfg
        })
        return gr.update(choices=list_presets())

    save_btn.click(
        fn=save_current_settings,
        inputs=[preset_name, exaggeration, temp, seed_num, cfg_weight],
        outputs=[preset_dropdown]
    )

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )
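
# Note: setting component.value after the Blocks UI has been constructed is a
# best-effort way to restore defaults; depending on the installed Gradio
# version it may not refresh an already-rendered interface.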
# Load the last used profile on startup
if os.path.exists(os.path.join(SETTINGS_DIR, "last.json")):
    last = load_preset("last")
    if last:
        exaggeration.value = last["exaggeration"]
        temp.value = last["temperature"]
        seed_num.value = last["seed"]
        cfg_weight.value = last["cfg"]

# 👇 Robust startup – important for an .exe build without a console!
demo.launch(
    quiet=True,
    show_error=True,
    prevent_thread_lock=False
)