Econogoat commited on
Commit
aa8484e
·
verified ·
1 Parent(s): 806ad66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -39
app.py CHANGED
@@ -22,47 +22,33 @@ prompt_values = df.values.flatten()
22
  # Récupérer le token
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
 
25
- # --- Initialisation des Modèles (sur CPU uniquement) ---
26
  device_cpu = "cpu"
27
  dtype = torch.bfloat16
28
  base_model = "black-forest-labs/FLUX.1-dev"
29
 
30
- # --- CORRECTION DÉFINITIVE, BASÉE SUR VOTRE ANALYSE ---
31
- # Création d'une configuration BitsAndBytes qui spécifie explicitement `quant_type="nf4"`.
32
- # C'est la seule configuration supportée par bitsandbytes pour la quantization 4-bit sur CPU.
33
- bnb_config_cpu = BitsAndBytesConfig(
34
- load_in_4bit=True,
35
- bnb_4bit_quant_type="nf4",
36
- bnb_4bit_use_double_quant=True,
37
- bnb_4bit_compute_dtype=dtype
38
- )
39
-
40
- print(f"Chargement du LLM {GEMMA_MODEL_ID} sur CPU avec la config NF4...")
41
  gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
42
- gemma_model = AutoModelForCausalLM.from_pretrained(
43
- GEMMA_MODEL_ID,
44
- quantization_config=bnb_config_cpu, # Utilisation de la configuration corrigée
45
- token=HF_TOKEN,
46
- device_map={'':device_cpu}
47
- )
48
- print("Modèle Gemma chargé.")
49
-
50
- # Le reste du chargement est correct
51
- print("Chargement des composants du modèle d'image sur CPU...")
52
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
53
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
54
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
55
- print("Modèles d'image chargés.")
56
-
57
- print(f"Chargement du LoRA : {KRYPTO_LORA['repo']}")
58
  pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
59
- print("LoRA chargé.")
60
 
61
  MAX_SEED = 2**32 - 1
62
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
63
 
64
- @spaces.GPU(duration=120)
 
 
 
65
  def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
 
66
  if not prompt:
67
  raise gr.Error("Prompt cannot be empty.")
68
 
@@ -70,10 +56,24 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomi
70
  final_prompt = prompt
71
 
72
  if enhance_prompt:
73
- print("Déplacement de Gemma sur le GPU...")
74
- gemma_model.to(device_gpu)
75
- print(f"Amélioration du prompt '{prompt}' avec Gemma...")
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
 
77
  system_prompt = (
78
  "You are an expert prompt engineer for a text-to-image AI. "
79
  "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
@@ -87,11 +87,7 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomi
87
  outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
88
  input_length = inputs["input_ids"].shape[1]
89
  final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
90
-
91
  print(f"Prompt amélioré : {final_prompt}")
92
- print("Libération de la mémoire de Gemma (déplacement vers CPU)...")
93
- gemma_model.to(device_cpu)
94
- torch.cuda.empty_cache()
95
 
96
  prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
97
  print("Prompt final envoyé au modèle d'image:", prompt_mash)
@@ -99,17 +95,13 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomi
99
  print("Déplacement du pipeline d'image sur le GPU...")
100
  pipe.to(device_gpu)
101
  good_vae.to(device_gpu)
102
-
103
  pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
104
 
105
- if randomize_seed:
106
- seed = random.randint(0, MAX_SEED)
107
-
108
  width, height = calculate_dimensions(aspect_ratio, base_resolution)
109
  print(f"Génération d'une image de {width}x{height} pixels.")
110
 
111
  generator = torch.Generator(device=device_gpu).manual_seed(seed)
112
-
113
  image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
114
  prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
115
  width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,
 
22
  # Récupérer le token
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
 
25
+ # --- Initialisation des Modèles (sur CPU uniquement et SANS QUANTIZATION initiale) ---
26
  device_cpu = "cpu"
27
  dtype = torch.bfloat16
28
  base_model = "black-forest-labs/FLUX.1-dev"
29
 
30
+ # --- STRATÉGIE CORRIGÉE ---
31
+ # On charge Gemma sur le CPU SANS le quantizer au démarrage pour éviter le conflit avec l'environnement de Spaces.
32
+ # La quantization sera appliquée plus tard, uniquement sur le GPU.
33
+ print(f"Chargement du tokenizer pour {GEMMA_MODEL_ID}...")
 
 
 
 
 
 
 
34
  gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
35
+ print("Chargement du modèle d'image sur CPU...")
 
 
 
 
 
 
 
 
 
36
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
37
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
38
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
39
+ print("Chargement du LoRA...")
 
 
40
  pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
41
+ print("Tous les modèles sont pré-chargés sur CPU.")
42
 
43
  MAX_SEED = 2**32 - 1
44
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
45
 
46
+ # On garde une référence globale pour ne recharger le modèle qu'une fois.
47
+ gemma_model = None
48
+
49
+ @spaces.GPU(duration=180) # Augmentation de la durée pour accommoder le chargement complet de Gemma
50
  def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
51
+ global gemma_model
52
  if not prompt:
53
  raise gr.Error("Prompt cannot be empty.")
54
 
 
56
  final_prompt = prompt
57
 
58
  if enhance_prompt:
59
+ # --- CHARGEMENT DYNAMIQUE SUR GPU ---
60
+ if gemma_model is None:
61
+ print(f"Premier appel : Chargement de {GEMMA_MODEL_ID} sur GPU avec quantization 4-bit...")
62
+ bnb_config_gpu = BitsAndBytesConfig(
63
+ load_in_4bit=True,
64
+ bnb_4bit_quant_type="nf4",
65
+ bnb_4bit_use_double_quant=True,
66
+ bnb_4bit_compute_dtype=dtype
67
+ )
68
+ gemma_model = AutoModelForCausalLM.from_pretrained(
69
+ GEMMA_MODEL_ID,
70
+ quantization_config=bnb_config_gpu,
71
+ token=HF_TOKEN,
72
+ device_map="auto" # "auto" fonctionnera car on est DÉJÀ sur un environnement GPU
73
+ )
74
+ print("Modèle Gemma chargé sur GPU.")
75
 
76
+ print(f"Amélioration du prompt '{prompt}' avec Gemma...")
77
  system_prompt = (
78
  "You are an expert prompt engineer for a text-to-image AI. "
79
  "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
 
87
  outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
88
  input_length = inputs["input_ids"].shape[1]
89
  final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
 
90
  print(f"Prompt amélioré : {final_prompt}")
 
 
 
91
 
92
  prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
93
  print("Prompt final envoyé au modèle d'image:", prompt_mash)
 
95
  print("Déplacement du pipeline d'image sur le GPU...")
96
  pipe.to(device_gpu)
97
  good_vae.to(device_gpu)
 
98
  pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
99
 
100
+ if randomize_seed: seed = random.randint(0, MAX_SEED)
 
 
101
  width, height = calculate_dimensions(aspect_ratio, base_resolution)
102
  print(f"Génération d'une image de {width}x{height} pixels.")
103
 
104
  generator = torch.Generator(device=device_gpu).manual_seed(seed)
 
105
  image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
106
  prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
107
  width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,