Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
from PIL import Image | |
import os | |
import tempfile | |
import torch | |
import torchaudio | |
import uuid | |
# Custom CSS for kid-friendly styling.
# The stylesheet is kept in a named constant so the markup is separate from
# the call that injects it into the page.
_KID_FRIENDLY_CSS = """
<style>
.main {background-color: #f0f8ff;}
.title {font-size: 36px; color: #ff4500; font-weight: bold; text-align: center;}
.header {font-size: 24px; color: #4682b4; font-weight: bold;}
.story-text {font-size: 18px; color: #2f4f4f; background-color: #e6e6fa; padding: 10px; border-radius: 10px;}
.caption-text {font-size: 16px; color: #228b22; font-style: italic;}
.sidebar .sidebar-content {background-color: #fffacd;}
.stButton>button {background-color: #ff69b4; color: white; border-radius: 10px;}
.stProgress .st-bo {background-color: #90ee90;}
</style>
"""
st.markdown(_KID_FRIENDLY_CSS, unsafe_allow_html=True)
# Initialize session state for resetting.
# Each default is produced by a small factory so that a fresh UUID is only
# generated when the key is actually missing from the session.
_SESSION_DEFAULTS = {
    'reset': lambda: False,
    'uploader_key': lambda: str(uuid.uuid4()),
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Cache model loading to improve performance
@st.cache_resource(show_spinner=False)
def _cached_caption_pipeline():
    """Build the ViT-GPT2 image-captioning pipeline.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same object is reused on later calls instead of being reloaded every time.
    Exceptions propagate to the caller, so a failed load is not cached.
    """
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")


def load_caption_generator():
    """Load ViT-GPT2 model for image captioning.

    Returns:
        The cached ``image-to-text`` pipeline, or ``None`` when loading fails
        (the failure is reported in the UI via ``st.error``).
    """
    try:
        return _cached_caption_pipeline()
    except Exception as e:
        st.error(f"Failed to load caption generator: {str(e)}")
        return None
@st.cache_resource(show_spinner=False)
def _cached_story_pipeline():
    """Build the Qwen/Qwen3-1.7B text-generation pipeline.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same object is reused on later calls instead of being reloaded every time.
    Exceptions propagate to the caller, so a failed load is not cached.
    """
    return pipeline("text-generation", model="Qwen/Qwen3-1.7B")


def load_story_generator():
    """Load the Qwen/Qwen3-1.7B model for story generation.

    Returns:
        The cached ``text-generation`` pipeline, or ``None`` when loading
        fails (the failure is reported in the UI via ``st.error``).
    """
    try:
        return _cached_story_pipeline()
    except Exception as e:
        st.error(f"Failed to load story generator: {str(e)}")
        return None
@st.cache_resource(show_spinner=False)
def _cached_tts_pipeline():
    """Build the facebook/mms-tts-eng text-to-speech pipeline.

    Wrapped in ``st.cache_resource`` so the model is constructed once and the
    same object is reused on later calls instead of being reloaded every time.
    Exceptions propagate to the caller, so a failed load is not cached.
    """
    return pipeline("text-to-speech", model="facebook/mms-tts-eng")


def load_tts_generator():
    """Load MMS-TTS model for text-to-speech.

    Returns:
        The cached ``text-to-speech`` pipeline, or ``None`` when loading fails
        (the failure is reported in the UI via ``st.error``).
    """
    try:
        return _cached_tts_pipeline()
    except Exception as e:
        st.error(f"Failed to load TTS generator: {str(e)}")
        return None
# Function to process image and generate caption
def process_image(image):
    """Describe the uploaded image in one sentence.

    Obtains the captioning pipeline from ``load_caption_generator`` and
    returns the ``generated_text`` of its first result. Returns ``None`` when
    the pipeline is unavailable or captioning raises; in the latter case a red
    error message is rendered in the page.
    """
    captioner = load_caption_generator()
    if not captioner:
        return None
    try:
        results = captioner(image)
        first_result = results[0]
        return first_result['generated_text']
    except Exception as err:
        st.markdown(f'<span style="color: red;">Oops, couldn’t read the picture: {str(err)}</span>', unsafe_allow_html=True)
        return None
# Function to generate story from caption
def generate_story(caption, name, grade, gender):
    """Generate a child-friendly story from an image caption.

    Args:
        caption: Text description of the uploaded image.
        name: Child's name. When empty, the protagonist is "a happy child"
            and no greeting is requested in the prompt.
        grade: Grade label used to pick the audience description; unknown
            labels fall back to "child aged 3-10, using simple words".
        gender: Accepted for interface compatibility; not currently used when
            building the prompt.

    Returns:
        The generated story with the echoed prompt removed and surrounding
        whitespace trimmed, or ``None`` when the story model is unavailable or
        generation fails (an error message is rendered in the page).
    """
    story_generator = load_story_generator()
    if not story_generator:
        return None
    try:
        protagonist = name if name else "a happy child"
        grade_text = {
            "Kindergarten": "kindergarten child, using very simple words and short sentences",
            "Grades 1-2": "Grades 1-2 child, using simple words and light adventure plots",
            "Grades 3-4": "Grades 3-4 child, using moderately complex words and adventure themes",
            "Grades 5-6": "Grades 5-6 child, using slightly longer sentences and themes of courage or teamwork",
            "None": "child aged 3-10, using simple words",
        }.get(grade, "child aged 3-10, using simple words")
        # Trailing space keeps the greeting from running into the next sentence.
        greeting = f"Start the story with Hey {protagonist}, " if name else ""
        prompt = (
            f"Write a fun, happy, complete story for a {grade_text} named {protagonist} based on: {caption}. {greeting}"
            f"Use a cheerful tone, around 100 words, safe for kids. "
        )
        story = story_generator(
            prompt,
            max_length=500,
            num_return_sequences=1,
            truncation=True,
            temperature=0.5,
            top_p=0.9,
            no_repeat_ngram_size=2,
        )[0]['generated_text']
        # Drop the echoed prompt as a *prefix*. str.strip(prompt) would treat
        # the prompt as a set of characters and eat matching characters from
        # both ends of the story.
        if story.startswith(prompt):
            story = story[len(prompt):]
        return story.strip()
    except Exception as e:
        st.markdown(f'<span style="color: red;">Oops, couldn’t create the story: {str(e)}</span>', unsafe_allow_html=True)
        return None
# Function to convert text to speech
def text_to_speech(text):
    """Convert the generated story to a WAV file using facebook/mms-tts-eng.

    Args:
        text: The story text to synthesize.

    Returns:
        Path to a uniquely named WAV file in the system temp directory, or
        ``None`` when the TTS model is unavailable, synthesis or saving
        raises, or the file is missing after saving (an error message is
        rendered in the page in the latter two cases). The caller is
        responsible for deleting the file.
    """
    tts_generator = load_tts_generator()
    if not tts_generator:
        return None
    try:
        speech = tts_generator(text)
        audio_tensor = torch.tensor(speech['audio'], dtype=torch.float32)
        # torchaudio.save requires a 2-D (channels, samples) tensor. If the
        # pipeline hands back a flat 1-D waveform, treat it as one channel.
        # NOTE(review): the exact output shape depends on the pipeline/model
        # version -- confirm against the deployed transformers release.
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        sampling_rate = speech['sampling_rate']
        # A UUID in the file name keeps concurrent requests from colliding.
        audio_path = os.path.join(tempfile.gettempdir(), f"story_audio_{uuid.uuid4()}.wav")
        torchaudio.save(audio_path, audio_tensor, sampling_rate)
        if os.path.exists(audio_path):
            return audio_path
        st.markdown('<span style="color: red;">Oops, the audio didn’t work.</span>', unsafe_allow_html=True)
        return None
    except Exception as e:
        st.markdown(f'<span style="color: red;">Oops, couldn’t make the audio: {str(e)}</span>', unsafe_allow_html=True)
        return None
# Page 1: Welcome and Introduction
def welcome_page():
    """Render the landing page.

    Shows, in order: the title banner, the introductory text, the
    "About This App" header with usage instructions, the Hugging Face logo,
    and a link to the Hugging Face site.
    """
    banner_html = '<div class="title">🌟 Welcome to Kids\' Magical Storytelling App 🌟</div>'
    intro_markdown = """
Welcome to the **Kids' Magical Storytelling App**! This is a magical world designed for kids aged 3-10! 🎉
Upload a picture, and we’ll create a joyful story just for you, complete with audio to listen to!
This app combines Hugging Face’s AI magic with Streamlit’s simple design to spark children’s imagination.
**Let’s get started!** Upload an image and embark on a fantastical adventure! ✨
"""
    about_header_html = '<div class="header">About This App</div>'
    about_text = """
This app is a special project for **ISOM5240**. It uses:
- **AI Image Captioning** to describe your picture.
- **AI Story Generation** to write a fun tale.
- **Text-to-Speech** to bring the story to life.
**How to Use**:
1. Go to the “Create Story” page and upload a JPG or PNG image.
2. Enter your name, grade, and gender (optional) to personalize your story.
3. Wait for the story to appear, then read and listen!
Made for kids, with ❤️ by [Your Name].
"""
    logo_url = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"

    st.markdown(banner_html, unsafe_allow_html=True)
    st.markdown(intro_markdown)
    st.markdown(about_header_html, unsafe_allow_html=True)
    st.write(about_text)
    st.image(logo_url, width=100)
    st.write("[Visit Hugging Face](https://huggingface.co)")
# Page 2: Story Creation
def story_page():
    """Render the story-creation page and run the image -> story -> audio flow.

    The sidebar collects an optional name, grade and gender. Once an image is
    uploaded (and the session is not flagged as reset), the page:

    1. captions the image with ``process_image``,
    2. writes a story with ``generate_story``,
    3. synthesizes audio with ``text_to_speech``,

    updating a progress bar after each step and offering the story and audio
    as downloads. If any step returns ``None`` the function returns early;
    those helpers render their own error messages.
    """
    # Local import: only this function needs it, to escape text that is
    # embedded into HTML rendered with unsafe_allow_html=True.
    import html

    with st.sidebar:
        st.markdown('<div class="header">Personalize Your Story</div>', unsafe_allow_html=True)
        name = st.text_input("What’s Your Name? (Optional)", value="")
        grade = st.selectbox("Choose Your Grade (Optional)", ["None", "Kindergarten", "Grades 1-2", "Grades 3-4", "Grades 5-6"])
        gender = st.selectbox("Choose Your Gender (Optional)", ["None", "Boy", "Girl"])

    st.markdown('<div class="title">🌟 Create Your Magical Story 🌟</div>', unsafe_allow_html=True)
    st.markdown('<div class="header">📸 Upload Your Picture</div>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Choose a fun image (JPG/PNG)...", type=["jpg", "jpeg", "png"], key=st.session_state.uploader_key)

    if uploaded_file is not None and not st.session_state.reset:
        # A file with an image extension can still be corrupt or not an image
        # at all; show a friendly message instead of a traceback.
        try:
            image = Image.open(uploaded_file)
        except (OSError, ValueError) as e:
            st.markdown(f'<span style="color: red;">Oops, couldn’t read the picture: {html.escape(str(e))}</span>', unsafe_allow_html=True)
            return
        st.image(image, caption="Your Magical Image", use_container_width=True)
        progress_bar = st.progress(0)

        with st.spinner("🖼️ Reading your picture..."):
            caption = process_image(image)
        if caption is None:
            return
        progress_bar.progress(33)
        # Model output is escaped so stray '<' / '&' cannot inject markup.
        st.markdown(f'<div class="caption-text">✨ Image Description: {html.escape(caption, quote=False)}</div>', unsafe_allow_html=True)

        with st.spinner("📝 Writing a magical story..."):
            story = generate_story(caption, name, grade, gender)
        if story is None:
            return
        progress_bar.progress(66)
        st.markdown('<div class="header">📚 Your Story</div>', unsafe_allow_html=True)
        st.markdown(f'<div class="story-text">{html.escape(story, quote=False)}</div>', unsafe_allow_html=True)

        with st.spinner("🎙️ Turning story into audio..."):
            audio_path = text_to_speech(story)
        if audio_path is None:
            return
        progress_bar.progress(100)
        st.markdown('<div class="header">🎵 Listen to Your Story</div>', unsafe_allow_html=True)

        if os.path.exists(audio_path):
            try:
                with open(audio_path, "rb") as audio_file:
                    audio_bytes = audio_file.read()
            finally:
                # The temp file is only needed until its bytes are in memory;
                # remove it here so it is cleaned up even if reading fails.
                try:
                    os.remove(audio_path)
                except OSError as e:
                    st.markdown(f'<span style="color: orange;">Couldn’t clean up audio file: {html.escape(str(e))}</span>', unsafe_allow_html=True)
            st.audio(audio_bytes, format="audio/wav")
            # Three columns keep the two buttons compact; the third is spacing.
            col0, col1, _ = st.columns(3)
            col0.download_button(
                label="Save Your Story!",
                data=story,
                file_name="story.txt",
                mime="text/plain",
            )
            col1.download_button(
                label="Save Your Audio!",
                data=audio_bytes,
                file_name="story_audio.wav",
                mime="audio/wav",
            )
        else:
            st.markdown('<span style="color: red;">Oops, the audio file is missing. Try again!</span>', unsafe_allow_html=True)

        st.balloons()
        st.write("🎉 All done! Try another picture for a new adventure!")

    # Clear the reset flag so the next run processes uploads again.
    # NOTE(review): placed at function level; confirm against the original
    # indentation, which was lost in this copy of the file.
    st.session_state.reset = False
# Navigation setup
# Each page is wrapped in st.Page so it carries an explicit title; the welcome
# text tells users to open the "Create Story" page, so the navigation label
# must match that wording.
pages = {
    "Welcome": st.Page(welcome_page, title="Welcome"),
    "Create Story": st.Page(story_page, title="Create Story"),
}
# Pass the titled st.Page objects (in insertion order) rather than the raw
# callables, so the titles defined above are the ones shown in the menu.
pg = st.navigation(list(pages.values()))
pg.run()