Update app.py
app.py (CHANGED)
@@ -22,47 +22,33 @@ prompt_values = df.values.flatten()
 # Retrieve the token
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-# --- Model initialization (CPU only) ---
+# --- Model initialization (CPU only and WITHOUT initial quantization) ---
 device_cpu = "cpu"
 dtype = torch.bfloat16
 base_model = "black-forest-labs/FLUX.1-dev"
 
-# ---
-#
-#
-bnb_config_cpu = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_compute_dtype=dtype
-)
-
-print(f"Loading the LLM {GEMMA_MODEL_ID} on CPU with the NF4 config...")
+# --- CORRECTED STRATEGY ---
+# Do NOT quantize Gemma at startup, to avoid the conflict with the Spaces environment.
+# Quantization is applied later, on the GPU only.
+print(f"Loading the tokenizer for {GEMMA_MODEL_ID}...")
 gemma_tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_ID, token=HF_TOKEN)
-gemma_model = AutoModelForCausalLM.from_pretrained(
-    GEMMA_MODEL_ID,
-    quantization_config=bnb_config_cpu,  # use the corrected configuration
-    token=HF_TOKEN,
-    device_map={'':device_cpu}
-)
-print("Gemma model loaded.")
-
-# The rest of the loading is correct
-print("Loading the image model components on CPU...")
+print("Loading the image model on CPU...")
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device_cpu)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN).to(device_cpu)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN).to(device_cpu)
-print("
-
-print(f"Loading LoRA: {KRYPTO_LORA['repo']}")
+print("Loading the LoRA...")
 pipe.load_lora_weights(KRYPTO_LORA['repo'], low_cpu_mem_usage=True, adapter_name=KRYPTO_LORA['adapter_name'], token=HF_TOKEN)
-print("
+print("All models are pre-loaded on CPU.")
 
 MAX_SEED = 2**32 - 1
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
-
+# Keep a global reference so the model is only loaded once.
+gemma_model = None
+
+@spaces.GPU(duration=180)  # longer duration to accommodate the full Gemma load
 def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
+    global gemma_model
     if not prompt:
         raise gr.Error("Prompt cannot be empty.")
 
@@ -70,10 +56,24 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
     final_prompt = prompt
 
     if enhance_prompt:
-
-        gemma_model.to(device_gpu)
-
+        # --- DYNAMIC LOADING ON THE GPU ---
+        if gemma_model is None:
+            print(f"First call: loading {GEMMA_MODEL_ID} on the GPU with 4-bit quantization...")
+            bnb_config_gpu = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=dtype
+            )
+            gemma_model = AutoModelForCausalLM.from_pretrained(
+                GEMMA_MODEL_ID,
+                quantization_config=bnb_config_gpu,
+                token=HF_TOKEN,
+                device_map="auto"  # "auto" works because we are ALREADY in a GPU environment
+            )
+            print("Gemma model loaded on the GPU.")
 
+        print(f"Enhancing the prompt '{prompt}' with Gemma...")
         system_prompt = (
             "You are an expert prompt engineer for a text-to-image AI. "
             "Your task is to take a user's simple idea and transform it into a rich, detailed, and visually descriptive prompt. "
@@ -87,11 +87,7 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
         outputs = gemma_model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
         input_length = inputs["input_ids"].shape[1]
         final_prompt = gemma_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
-
         print(f"Enhanced prompt: {final_prompt}")
-        print("Freeing Gemma's memory (moving it back to the CPU)...")
-        gemma_model.to(device_cpu)
-        torch.cuda.empty_cache()
 
     prompt_mash = f"{KRYPTO_LORA['trigger']}, {final_prompt}"
     print("Final prompt sent to the image model:", prompt_mash)
@@ -99,17 +95,13 @@ def run_generation(prompt, enhance_prompt, lora_scale, cfg_scale, steps, randomize_seed, seed, aspect_ratio, base_resolution, progress=gr.Progress(track_tqdm=True)):
     print("Moving the image pipeline to the GPU...")
     pipe.to(device_gpu)
    good_vae.to(device_gpu)
-
     pipe.set_adapters([KRYPTO_LORA['adapter_name']], adapter_weights=[lora_scale])
 
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
+    if randomize_seed: seed = random.randint(0, MAX_SEED)
     width, height = calculate_dimensions(aspect_ratio, base_resolution)
     print(f"Generating a {width}x{height} image.")
 
     generator = torch.Generator(device=device_gpu).manual_seed(seed)
-
     image_generator = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
         prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale,
         width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae,