# Storytelling / app.py
# YSOMAAAD's picture
# Update app.py
# 70ec7ec verified
import streamlit as st
from transformers import pipeline
from PIL import Image
import os
import tempfile
import torch
import torchaudio
import uuid
# Stylesheet injected once at the top of every run: page background, the
# title/header/story/caption text classes, sidebar, buttons and progress bar.
_CUSTOM_CSS = """
<style>
.main {background-color: #f0f8ff;}
.title {font-size: 36px; color: #ff4500; font-weight: bold; text-align: center;}
.header {font-size: 24px; color: #4682b4; font-weight: bold;}
.story-text {font-size: 18px; color: #2f4f4f; background-color: #e6e6fa; padding: 10px; border-radius: 10px;}
.caption-text {font-size: 16px; color: #228b22; font-style: italic;}
.sidebar .sidebar-content {background-color: #fffacd;}
.stButton>button {background-color: #ff69b4; color: white; border-radius: 10px;}
.stProgress .st-bo {background-color: #90ee90;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Seed per-session defaults the first time the script runs for a session;
# values already present are left untouched on later reruns.
if "reset" not in st.session_state:
    st.session_state["reset"] = False
if "uploader_key" not in st.session_state:
    # Random key later handed to the file-uploader widget.
    st.session_state["uploader_key"] = str(uuid.uuid4())
# Model loaders are wrapped in st.cache_resource so each pipeline is built once
@st.cache_resource
def load_caption_generator():
    """Build the ViT-GPT2 image-captioning pipeline.

    Returns:
        The ``image-to-text`` pipeline, or ``None`` when construction raised
        (the failure is reported to the user with ``st.error``).
    """
    # NOTE(review): st.cache_resource presumably caches the ``None`` failure
    # result as well — confirm whether a retry on the next call is wanted.
    try:
        captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    except Exception as exc:
        st.error(f"Failed to load caption generator: {str(exc)}")
        return None
    return captioner
@st.cache_resource
def load_story_generator():
    """Load the Qwen3-1.7B text-generation pipeline used for story writing.

    Returns:
        The ``text-generation`` pipeline for ``Qwen/Qwen3-1.7B``, or ``None``
        when loading raised (the failure is reported with ``st.error``).
    """
    try:
        return pipeline("text-generation", model="Qwen/Qwen3-1.7B")
    except Exception as e:
        st.error(f"Failed to load story generator: {str(e)}")
        return None
@st.cache_resource
def load_tts_generator():
    """Build the ``facebook/mms-tts-eng`` text-to-speech pipeline.

    Returns:
        The ``text-to-speech`` pipeline, or ``None`` when construction raised
        (the failure is reported to the user with ``st.error``).
    """
    tts = None
    try:
        tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception as exc:
        st.error(f"Failed to load TTS generator: {str(exc)}")
    return tts
# Image -> caption
def process_image(image):
    """Describe *image* with the cached captioning pipeline.

    Args:
        image: The image object to caption (the page passes a PIL image).

    Returns:
        The ``generated_text`` of the first pipeline result, or ``None`` when
        the pipeline is unavailable or captioning raised (in which case a red
        error message is rendered on the page).
    """
    captioner = load_caption_generator()
    if not captioner:
        return None
    try:
        results = captioner(image)
        first_result = results[0]
        return first_result['generated_text']
    except Exception as exc:
        error_html = f'<span style="color: red;">Oops, couldn’t read the picture: {str(exc)}</span>'
        st.markdown(error_html, unsafe_allow_html=True)
        return None
# Function to generate story from caption
def generate_story(caption, name, grade, gender):
    """Generate a short child-friendly story from an image caption.

    Args:
        caption: Image description the story should be based on.
        name: Child's name. When empty, the protagonist is "a happy child"
            and no "Hey <name>," opening is requested.
        grade: Grade label such as "Kindergarten" or "Grades 1-2"; any
            unrecognised value falls back to the generic "child aged 3-10"
            wording.
        gender: Accepted for interface compatibility; it is not currently
            used when building the prompt.

    Returns:
        The generated story with the echoed prompt removed and surrounding
        whitespace trimmed, or ``None`` when the model is unavailable or
        generation raised (a red error message is rendered in that case).
    """
    story_generator = load_story_generator()
    if not story_generator:
        return None
    try:
        protagonist = name if name else "a happy child"
        grade_text = {
            "Kindergarten": "kindergarten child, using very simple words and short sentences",
            "Grades 1-2": "Grades 1-2 child, using simple words and light adventure plots",
            "Grades 3-4": "Grades 3-4 child, using moderately complex words and adventure themes",
            "Grades 5-6": "Grades 5-6 child, using slightly longer sentences and themes of courage or teamwork",
            "None": "child aged 3-10, using simple words"
        }.get(grade, "child aged 3-10, using simple words")
        pro = 'Start the story with Hey {},'.format(protagonist) if name else ''
        prompt = (
            f"Write a fun, happy, complete story for a {grade_text} named {protagonist} based on: {caption}. {pro}"
            f"Use a cheerful tone, around 100 words, safe for kids. "
        )
        story = story_generator(
            prompt,
            max_length=500,
            num_return_sequences=1,
            truncation=True,
            temperature=0.5,
            top_p=0.9,
            no_repeat_ngram_size=2
        )[0]['generated_text']
        # The pipeline output begins with the prompt itself. str.strip(prompt)
        # would treat the prompt as a *set of characters* and eat letters from
        # both ends of the story, so remove it as a literal prefix instead.
        if story.startswith(prompt):
            story = story[len(prompt):]
        return story.strip()
    except Exception as e:
        st.markdown(f'<span style="color: red;">Oops, couldn’t create the story: {str(e)}</span>', unsafe_allow_html=True)
        return None
# Function to convert text to speech
def text_to_speech(text):
    """Synthesize *text* to a WAV file using the facebook/mms-tts-eng pipeline.

    Args:
        text: The story text to read aloud.

    Returns:
        Path of the WAV file written to the system temp directory, or
        ``None`` when the TTS pipeline is unavailable, synthesis or saving
        raised, or the file is missing after saving (a red error message is
        rendered in the latter two cases). The caller is responsible for
        deleting the file.
    """
    tts_generator = load_tts_generator()
    if not tts_generator:
        return None
    try:
        speech = tts_generator(text)
        audio_tensor = torch.tensor(speech['audio'], dtype=torch.float32)
        # torchaudio.save requires a 2-D (channels, samples) tensor; promote a
        # bare 1-D waveform to a single channel so saving does not fail.
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        sampling_rate = speech['sampling_rate']
        temp_dir = tempfile.gettempdir()
        # Unique file name so concurrent sessions never overwrite each other.
        audio_path = os.path.join(temp_dir, f"story_audio_{uuid.uuid4()}.wav")
        torchaudio.save(audio_path, audio_tensor, sampling_rate)
        if os.path.exists(audio_path):
            return audio_path
        st.markdown('<span style="color: red;">Oops, the audio didn’t work.</span>', unsafe_allow_html=True)
        return None
    except Exception as e:
        st.markdown(f'<span style="color: red;">Oops, couldn’t make the audio: {str(e)}</span>', unsafe_allow_html=True)
        return None
# Page 1: landing page
def welcome_page():
    """Render the static welcome page: title, intro, usage notes and a logo link."""
    title_html = '<div class="title">🌟 Welcome to Kids\' Magical Storytelling App 🌟</div>'
    intro_md = """
Welcome to the **Kids' Magical Storytelling App**! This is a magical world designed for kids aged 3-10! 🎉
Upload a picture, and we’ll create a joyful story just for you, complete with audio to listen to!
This app combines Hugging Face’s AI magic with Streamlit’s simple design to spark children’s imagination.
**Let’s get started!** Upload an image and embark on a fantastical adventure! ✨
"""
    about_header_html = '<div class="header">About This App</div>'
    about_text = """
This app is a special project for **ISOM5240**. It uses:
- **AI Image Captioning** to describe your picture.
- **AI Story Generation** to write a fun tale.
- **Text-to-Speech** to bring the story to life.
**How to Use**:
1. Go to the “Create Story” page and upload a JPG or PNG image.
2. Enter your name, grade, and gender (optional) to personalize your story.
3. Wait for the story to appear, then read and listen!
Made for kids, with ❤️ by [Your Name].
"""
    logo_url = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"

    # Emit the elements top to bottom in page order.
    st.markdown(title_html, unsafe_allow_html=True)
    st.markdown(intro_md)
    st.markdown(about_header_html, unsafe_allow_html=True)
    st.write(about_text)
    st.image(logo_url, width=100)
    st.write("[Visit Hugging Face](https://huggingface.co)")
# Page 2: Story Creation
def story_page():
    """Render the story-creation page.

    The sidebar collects optional name, grade and gender. Once an image is
    uploaded the page runs three steps in order — caption the image, write a
    story from the caption, synthesize the story to audio — updating a
    progress bar after each one. If any step returns ``None`` the function
    returns early (the step itself has already rendered an error message).
    On success the story, an audio player and download buttons for both are
    shown, and the temporary audio file is deleted.
    """
    # Local import: only this page needs HTML escaping.
    from html import escape

    with st.sidebar:
        st.markdown('<div class="header">Personalize Your Story</div>', unsafe_allow_html=True)
        name = st.text_input("What’s Your Name? (Optional)", value="")
        grade = st.selectbox("Choose Your Grade (Optional)", ["None", "Kindergarten", "Grades 1-2", "Grades 3-4", "Grades 5-6"])
        gender = st.selectbox("Choose Your Gender (Optional)", ["None", "Boy", "Girl"])
    st.markdown('<div class="title">🌟 Create Your Magical Story 🌟</div>', unsafe_allow_html=True)
    st.markdown('<div class="header">📸 Upload Your Picture</div>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Choose a fun image (JPG/PNG)...", type=["jpg", "jpeg", "png"], key=st.session_state.uploader_key)
    if uploaded_file is not None and not st.session_state.reset:
        image = Image.open(uploaded_file)
        st.image(image, caption="Your Magical Image", use_container_width=True)
        progress_bar = st.progress(0)

        # Step 1: image -> caption
        with st.spinner("🖼️ Reading your picture..."):
            caption = process_image(image)
        if caption is None:
            return
        progress_bar.progress(33)
        # Caption and story are model output (influenced by the uploaded image
        # and the typed name), so escape them before embedding in raw HTML.
        st.markdown(f'<div class="caption-text">✨ Image Description: {escape(caption)}</div>', unsafe_allow_html=True)

        # Step 2: caption -> story
        with st.spinner("📝 Writing a magical story..."):
            story = generate_story(caption, name, grade, gender)
        if story is None:
            return
        progress_bar.progress(66)
        st.markdown('<div class="header">📚 Your Story</div>', unsafe_allow_html=True)
        st.markdown(f'<div class="story-text">{escape(story)}</div>', unsafe_allow_html=True)

        # Step 3: story -> audio
        with st.spinner("🎙️ Turning story into audio..."):
            audio_path = text_to_speech(story)
        if audio_path is None:
            return
        progress_bar.progress(100)
        st.markdown('<div class="header">🎵 Listen to Your Story</div>', unsafe_allow_html=True)
        if os.path.exists(audio_path):
            # Read the whole file into memory so the temp file can be removed
            # while the player and download button keep working.
            with open(audio_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
            st.audio(audio_bytes, format="audio/wav")
            # Three columns keep the two buttons narrow; the third stays empty.
            col0, col1, _ = st.columns(3)
            col0.download_button(
                label="Save Your Story!",
                data=story,
                file_name="story.txt",
                mime="text/plain"
            )
            col1.download_button(
                label="Save Your Audio!",
                data=audio_bytes,
                file_name="story_audio.wav",
                mime="audio/wav"
            )
            try:
                os.remove(audio_path)
            except Exception as e:
                st.markdown(f'<span style="color: orange;">Couldn’t clean up audio file: {str(e)}</span>', unsafe_allow_html=True)
        else:
            st.markdown('<span style="color: red;">Oops, the audio file is missing. Try again!</span>', unsafe_allow_html=True)
        st.balloons()
        st.write("🎉 All done! Try another picture for a new adventure!")
    st.session_state.reset = False
# Navigation setup
pages = {
    "Welcome": st.Page(welcome_page, title="Welcome"),
    "Create Story": st.Page(story_page, title="Create Story"),
}
# Pass the titled st.Page objects built above. Handing st.navigation the bare
# functions instead would bypass these objects, so their explicit titles
# (including the "Create Story" name the welcome text refers to) were unused.
pg = st.navigation(list(pages.values()))
pg.run()