"""Kids' Magical Storytelling App.

A two-page Streamlit app. The welcome page introduces the app. The story page
turns an uploaded picture into three things:
- a caption (ViT-GPT2),
- a short child-friendly story (Qwen3-1.7B),
- narrated audio (MMS-TTS).
"""

import contextlib
import hashlib
import os
import tempfile
import uuid

import streamlit as st
import torch
import torchaudio
from PIL import Image
from transformers import pipeline

CAPTION_MODEL = "nlpconnect/vit-gpt2-image-captioning"
STORY_MODEL = "Qwen/Qwen3-1.7B"
TTS_MODEL = "facebook/mms-tts-eng"

# Per-grade (audience, writing style) fragments used to build the story prompt.
GRADE_STYLES = {
    "Kindergarten": ("kindergarten child", "very simple words and short sentences"),
    "Grades 1-2": ("Grades 1-2 child", "simple words and light adventure plots"),
    "Grades 3-4": ("Grades 3-4 child", "moderately complex words and adventure themes"),
    "Grades 5-6": (
        "Grades 5-6 child",
        "slightly longer sentences and themes of courage or teamwork",
    ),
}
# Used for the "None" grade option and any unknown value.
DEFAULT_GRADE_STYLE = ("child aged 3-10", "simple words")
# Gender option -> pronouns for the hero; "None" adds no pronoun instruction.
PRONOUNS = {"Boy": "he/him", "Girl": "she/her"}

# Must run before any other Streamlit command in the entrypoint script.
st.set_page_config(page_title="Kids' Magical Storytelling App", layout="wide")

# Custom CSS for kid-friendly styling. NOTE(review): the style block is empty
# in this copy of the file; put a <style>...</style> string here to enable it.
CUSTOM_CSS = ""
if CUSTOM_CSS.strip():
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)

# Initialize session state. `reset` is only ever written as False in this
# file, and `uploader_key` is never regenerated; both are kept so a future
# "start over" control can use them.
if "reset" not in st.session_state:
    st.session_state.reset = False
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = str(uuid.uuid4())


# Cache model loading to improve performance. The loaders let exceptions
# propagate: st.cache_resource does not cache a raised exception, so a
# transient failure (e.g. a network hiccup) is retried on the next run
# instead of being cached as None for the life of the process.
@st.cache_resource
def load_caption_generator():
    """Load the ViT-GPT2 image-captioning pipeline."""
    return pipeline("image-to-text", model=CAPTION_MODEL)


@st.cache_resource
def load_story_generator():
    """Load the Qwen3-1.7B text-generation pipeline."""
    return pipeline("text-generation", model=STORY_MODEL)


@st.cache_resource
def load_tts_generator():
    """Load the MMS-TTS (English) text-to-speech pipeline."""
    return pipeline("text-to-speech", model=TTS_MODEL)


def _load_model(loader, label):
    """Call a cached model *loader*, reporting a failure in the UI.

    Returns the pipeline, or None if loading raised.
    """
    try:
        return loader()
    except Exception as e:  # download/initialisation can fail in many ways
        st.error(f"Failed to load {label}: {e}")
        return None


def process_image(image):
    """Generate a caption for *image* (a PIL image) with the ViT-GPT2 model.

    Returns the caption string, or None (after showing an error) on failure.
    """
    caption_generator = _load_model(load_caption_generator, "caption generator")
    if caption_generator is None:
        return None
    try:
        return caption_generator(image)[0]["generated_text"]
    except Exception as e:
        st.error(f"Oops, couldn’t read the picture: {e}")
        return None


def _build_story_prompt(caption, name, grade, gender):
    """Compose the story-generation prompt from the caption and user details."""
    audience, style = GRADE_STYLES.get(grade, DEFAULT_GRADE_STYLE)
    name = (name or "").strip()
    named = f" named {name}" if name else ""
    sentences = [
        f"Write a fun, happy, complete story for a {audience}{named}, "
        f"using {style}, based on: {caption}."
    ]
    if name:
        sentences.append(f'Start the story with "Hey {name},".')
    pronouns = PRONOUNS.get(gender)
    if pronouns:
        hero = name or "the child"
        sentences.append(f"Make {hero} the hero, using {pronouns} pronouns.")
    sentences.append("Use a cheerful tone, around 100 words, safe for kids.")
    # End on a paragraph break rather than a trailing space so the model's
    # continuation starts cleanly as the story text.
    return " ".join(sentences) + "\n\n"


def generate_story(caption, name, grade, gender):
    """Generate a child-friendly story about *caption* with Qwen3-1.7B.

    *name*, *grade* and *gender* are the optional sidebar values used to
    personalise the prompt. Returns only the generated story (the prompt is
    not included), or None (after showing an error) on failure.
    """
    story_generator = _load_model(load_story_generator, "story generator")
    if story_generator is None:
        return None
    prompt = _build_story_prompt(caption, name, grade, gender)
    try:
        outputs = story_generator(
            prompt,
            max_length=500,
            num_return_sequences=1,
            truncation=True,
            # temperature/top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            no_repeat_ngram_size=2,
            # Return only the continuation. (The old story.strip(prompt)
            # removed a *set of characters*, eating letters of the story.)
            return_full_text=False,
        )
    except Exception as e:
        st.error(f"Oops, couldn’t create the story: {e}")
        return None
    story = outputs[0]["generated_text"].strip()
    if not story:
        st.error("Oops, couldn’t create the story: the model returned no text.")
        return None
    return story


def text_to_speech(text):
    """Convert *text* to speech with facebook/mms-tts-eng and save it as WAV.

    Returns the path of a new temporary WAV file (the caller is responsible
    for deleting it), or None (after showing an error) on failure.
    """
    tts_generator = _load_model(load_tts_generator, "TTS generator")
    if tts_generator is None:
        return None
    audio_path = os.path.join(tempfile.gettempdir(), f"story_audio_{uuid.uuid4()}.wav")
    try:
        speech = tts_generator(text)
        audio_tensor = torch.as_tensor(speech["audio"], dtype=torch.float32)
        # torchaudio.save expects a (channels, frames) tensor; treat a 1-D
        # waveform as mono.
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        torchaudio.save(audio_path, audio_tensor, speech["sampling_rate"])
    except Exception as e:
        st.error(f"Oops, couldn’t make the audio: {e}")
        # Don't leave a partially written file behind.
        with contextlib.suppress(OSError):
            os.remove(audio_path)
        return None
    if not os.path.exists(audio_path):
        st.error("Oops, the audio didn’t work.")
        return None
    return audio_path


def _read_and_remove(audio_path):
    """Return the bytes of the WAV at *audio_path* and delete the file.

    Returns None (after showing an error) if the file is missing.
    """
    try:
        with open(audio_path, "rb") as audio_file:
            return audio_file.read()
    except FileNotFoundError:
        st.error("Oops, the audio file is missing. Try again!")
        return None
    finally:
        try:
            os.remove(audio_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            st.warning(f"Couldn’t clean up audio file: {e}")


# Page 1: Welcome and Introduction
def welcome_page():
    """Render the welcome page: app introduction and usage instructions."""
    st.title("🌟 Welcome to Kids' Magical Storytelling App 🌟")
    st.markdown(
        """
Welcome to the **Kids' Magical Storytelling App**! This is a magical world
designed for kids aged 3-10! 🎉

Upload a picture, and we’ll create a joyful story just for you, complete with
audio to listen to! This app combines Hugging Face’s AI magic with Streamlit’s
simple design to spark children’s imagination.

**Let’s get started!** Upload an image and embark on a fantastical adventure! ✨
"""
    )
    st.subheader("About This App")
    # TODO(owner): replace the "[Your Name]" placeholder below.
    st.write(
        """
This app is a special project for **ISOM5240**. It uses:

- **AI Image Captioning** to describe your picture.
- **AI Story Generation** to write a fun tale.
- **Text-to-Speech** to bring the story to life.

**How to Use**:

1. Go to the “Create Story” page and upload a JPG or PNG image.
2. Enter your name, grade, and gender (optional) to personalize your story.
3. Wait for the story to appear, then read and listen!

Made for kids, with ❤️ by [Your Name].
"""
    )
    st.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=100)
    st.write("[Visit Hugging Face](https://huggingface.co)")


# Page 2: Story Creation
def story_page():
    """Render the story page: upload -> caption -> story -> audio -> downloads."""
    with st.sidebar:
        st.header("Personalize Your Story")
        name = st.text_input("What’s Your Name? (Optional)", value="")
        grade = st.selectbox(
            "Choose Your Grade (Optional)",
            ["None", "Kindergarten", "Grades 1-2", "Grades 3-4", "Grades 5-6"],
        )
        gender = st.selectbox("Choose Your Gender (Optional)", ["None", "Boy", "Girl"])

    st.title("🌟 Create Your Magical Story 🌟")
    st.subheader("📸 Upload Your Picture")
    uploaded_file = st.file_uploader(
        "Choose a fun image (JPG/PNG)...",
        type=["jpg", "jpeg", "png"],
        key=st.session_state.uploader_key,
    )
    if uploaded_file is None or st.session_state.reset:
        return

    try:
        # RGBA/palette PNGs must become 3-channel RGB for the ViT captioner.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Oops, couldn’t read the picture: {e}")
        return
    st.image(image, caption="Your Magical Image", use_container_width=True)

    # Streamlit reruns this page on every widget interaction, including a
    # download-button click. Cache each result for the current image and
    # settings so a rerun redisplays the same story instead of generating a
    # new one (and re-running all three models).
    cache_key = (hashlib.sha256(uploaded_file.getvalue()).hexdigest(), name, grade, gender)
    results = st.session_state.get("story_results")
    if results is None or results.get("key") != cache_key:
        results = {"key": cache_key}
        st.session_state.story_results = results
    first_completion = "audio" not in results

    progress_bar = st.progress(0)

    caption = results.get("caption")
    if caption is None:
        with st.spinner("🖼️ Reading your picture..."):
            caption = process_image(image)
        if caption is None:
            return
        results["caption"] = caption
    progress_bar.progress(33)
    # Plain markdown (no unsafe HTML): caption and story are model output.
    st.markdown(f"✨ Image Description: {caption}")

    story = results.get("story")
    if story is None:
        with st.spinner("📝 Writing a magical story..."):
            story = generate_story(caption, name, grade, gender)
        if story is None:
            return
        results["story"] = story
    progress_bar.progress(66)
    st.subheader("📚 Your Story")
    st.markdown(story)

    audio_bytes = results.get("audio")
    if audio_bytes is None:
        with st.spinner("🎙️ Turning story into audio..."):
            audio_path = text_to_speech(story)
        if audio_path is None:
            return
        audio_bytes = _read_and_remove(audio_path)
        if audio_bytes is None:
            return
        results["audio"] = audio_bytes
    progress_bar.progress(100)

    st.subheader("🎵 Listen to Your Story")
    st.audio(audio_bytes, format="audio/wav")
    story_col, audio_col, _ = st.columns(3)
    story_col.download_button(
        label="Save Your Story!",
        data=story,
        file_name="story.txt",
        mime="text/plain",
    )
    audio_col.download_button(
        label="Save Your Audio!",
        data=audio_bytes,
        file_name="story_audio.wav",
        mime="audio/wav",
    )

    # Celebrate only when the story was just produced, not on every rerun.
    if first_completion:
        st.balloons()
    st.write("🎉 All done! Try another picture for a new adventure!")
    st.session_state.reset = False


# Navigation setup: use the titled st.Page objects.
pages = {
    "Welcome": st.Page(welcome_page, title="Welcome"),
    "Create Story": st.Page(story_page, title="Create Story"),
}
pg = st.navigation(list(pages.values()))
pg.run()