"""Gradio chat app that writes silly Hebrew song lyrics with a fine-tuned Gemma 3 4B model.

Users may prompt with text, images and/or videos (videos are sampled into frames).
"""
import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time

TITLE = " מודל מבוסס גמה 3 ליצירת שירים מטופשים בעברית "
DESCRIPTION = """ ניתן לבקש שיר על בסיס טקסט, תמונה ווידאו בכל פעם, יווצר שיר שונה, אז אם לא אהבתם, אפשר לנסות שוב עם אותו הפרומפט [המודל זמין להורדה](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune) המודל כּוּיַּיל ע״י [דורון אדלר](https://linktr.ee/Norod78) """

# Model configuration: the checkpoint is loaded once at import time in bfloat16
# and switched to eval (inference) mode.
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"

model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)

# Placeholder a user can type in the prompt text; each occurrence is replaced by
# the next uploaded file, in upload order. Files not consumed by a placeholder
# are appended after the text.
IMAGE_PLACEHOLDER = "<image>"

# Extensions recognised as images / videos. Files with any other extension are ignored.
IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"})
VIDEO_SUFFIXES = frozenset({".mp4", ".mov", ".webm", ".avi", ".mkv"})


# TODO: attach a timestamp to each sampled frame (planned by the original author).
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` evenly spaced frames from a video as RGB images.

    Frames are read at indices ``0, step, 2*step, ...`` with
    ``step = max(total_frames // num_frames, 1)``. Frames that fail to decode are
    skipped, so fewer than ``num_frames`` images may be returned.

    Args:
        video_path: Path of a video file readable by OpenCV.
        num_frames: Number of frames to sample.

    Returns:
        list[PIL.Image.Image]: The sampled frames, converted from OpenCV's BGR
        order to RGB. Empty if the file cannot be opened or ``num_frames`` is
        not positive.
    """
    if num_frames <= 0:
        # Guard: the step computation below would divide by zero.
        return []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total_frames // num_frames, 1)
        frames = []
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ok, frame = cap.read()
            if ok:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames
    finally:
        # Release the capture even if decoding raises.
        cap.release()


def _file_path(file):
    """Return the path of an uploaded file given as a str path or an object with ``.name``."""
    return file if isinstance(file, str) else file.name


def _file_to_content(file):
    """Convert one uploaded file into a list of chat-template ``image`` entries.

    An image yields one entry, a video yields one entry per sampled frame, and
    a file with an unrecognised extension yields an empty list.
    """
    file_path = _file_path(file)
    suffix = Path(file_path).suffix.lower()
    if suffix in IMAGE_SUFFIXES:
        # convert() returns a fully decoded copy, so the file handle can be closed
        # right away; it also normalises RGBA / palette / grayscale to 3 channels.
        with Image.open(file_path) as img:
            return [{"type": "image", "image": img.convert("RGB")}]
    if suffix in VIDEO_SUFFIXES:
        return [{"type": "image", "image": frame} for frame in extract_video_frames(file_path)]
    return []


def format_message(content, files):
    """Build the content list of a user turn from prompt text and uploaded files.

    The text is split on ``IMAGE_PLACEHOLDER``. Each non-blank piece becomes a
    text entry. Each placeholder is followed by the media entries of the next
    uploaded file. Remaining files are appended after all of the text.

    Args:
        content: Prompt text (may be empty or None).
        files: Uploaded files (str paths or objects with ``.name``), or None.
            The caller's list is not modified.

    Returns:
        list[dict]: Entries of the form ``{"type": "text", "text": ...}`` or
        ``{"type": "image", "image": PIL.Image.Image}``.
    """
    # Work on a copy: popping from the caller's list would mutate Gradio's input dict.
    remaining = list(files or [])
    message_content = []
    if content:
        parts = content.split(IMAGE_PLACEHOLDER)
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            # A placeholder sits between parts[i] and parts[i + 1].
            if i < len(parts) - 1 and remaining:
                message_content.extend(_file_to_content(remaining.pop(0)))
    for file in remaining:
        message_content.extend(_file_to_content(file))
    return message_content


def format_conversation_history(chat_history):
    """Convert Gradio "messages"-style history into chat-template messages.

    Consecutive user entries are merged into a single user turn. Each assistant
    entry is emitted as its own turn, with its content converted via ``str()``.

    NOTE(review): user entries whose content is neither str nor list (e.g. file
    attachments, if Gradio represents them as tuples/dicts) are passed to the
    model as ``str(content)`` text, not as images. Confirm this is intended.

    Args:
        chat_history: List of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        list[dict]: Messages with ``"role"`` and a list-valued ``"content"``.
    """
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush pending user content before the assistant reply.
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages


@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """
    Creates silly song lyrics in Hebrew based on user input and conversation history.

    Args:
        input_data (dict or str):
            - If dict: must include 'text' (str) and optional 'files' (list of image/video file paths).
              An "<image>" marker in the text places the next file at that position.
            - If str: treated as plain text input.
        chat_history (list of dict):
            Sequence of past messages, each with keys 'role' and 'content'.
        max_new_tokens (int):
            Maximum number of tokens to generate for the response.
        system_prompt (str):
            Optional system-level instruction to guide the style and content of the response.
        temperature (float):
            Sampling temperature; higher values yield more diverse outputs.
        top_p (float):
            Nucleus sampling threshold for cumulative probability selection.
        top_k (int):
            Limits sampling to the top_k most likely tokens at each step.
        repetition_penalty (float):
            Penalty factor to discourage the model from repeating the same tokens.

    Yields:
        str: The response generated so far, re-yielded in full after each new chunk.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)

    system_message = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    messages = system_message + format_conversation_history(chat_history)
    if messages and messages[-1]["role"] == "user":
        # Fold the new input into a trailing user turn instead of adding a second
        # consecutive user message (Gemma chat templates expect alternating roles).
        messages[-1]["content"].extend(new_message_content)
    else:
        messages.append({"role": "user", "content": new_message_content})

    model = model_4b
    processor = processor_4b

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # Slider values may arrive as floats; these generation settings must be ints.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    # model.generate blocks until completion, so it runs on a worker thread while
    # this generator drains the streamer and yields the growing response.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)
    # The streamer is exhausted only after generation ends; reap the worker thread.
    thread.join()


chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="אתה משורר ישראלי, כותב שירים בעברית",
            lines=4,
            placeholder="שנה את ההגדרות של המודל",
            text_align="right",
            rtl=True,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        [{"text": "כתוב לי בבקשה שיר המתאר את התמונה", "files": ["examples/image1.jpg"]}],
        [{"text": "תאר את התמונה ואז כתוב על זה שיר", "files": ["examples/image2.jpg"]}],
        [{"text": "תפוח אדמה עם חרדה חברתית"}],
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="קלט",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="בקשו שיר ו/או העלו תמונה",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="הפסק",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()