import streamlit as st from transformers import pipeline from PIL import Image import os import tempfile import torch import torchaudio import uuid # Custom CSS for kid-friendly styling st.markdown(""" """, unsafe_allow_html=True) # Initialize session state for resetting if 'reset' not in st.session_state: st.session_state.reset = False if 'uploader_key' not in st.session_state: st.session_state.uploader_key = str(uuid.uuid4()) # Cache model loading to improve performance @st.cache_resource def load_caption_generator(): """Load ViT-GPT2 model for image captioning.""" try: return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning") except Exception as e: st.error(f"Failed to load caption generator: {str(e)}") return None @st.cache_resource def load_story_generator(): """Load Qwen2-7B-Instruct model for story generation.""" try: return pipeline("text-generation", model="Qwen/Qwen3-1.7B") except Exception as e: st.error(f"Failed to load story generator: {str(e)}") return None @st.cache_resource def load_tts_generator(): """Load MMS-TTS model for text-to-speech.""" try: return pipeline("text-to-speech", model="facebook/mms-tts-eng") except Exception as e: st.error(f"Failed to load TTS generator: {str(e)}") return None # Function to process image and generate caption def process_image(image): """Generate a caption for the uploaded image using ViT-GPT2 model.""" caption_generator = load_caption_generator() if not caption_generator: return None try: caption = caption_generator(image)[0]['generated_text'] return caption except Exception as e: st.markdown(f'Oops, couldn’t read the picture: {str(e)}', unsafe_allow_html=True) return None # Function to generate story from caption def generate_story(caption, name, grade, gender): """Generate a child-friendly story using Qwen2-7B-Instruct model.""" story_generator = load_story_generator() if not story_generator: return None try: protagonist = name if name else "a happy child" pronoun = "he" if gender == "Boy" else "she" if gender == "Girl" else "they" grade_text = { "Kindergarten": "kindergarten child, using very simple words and short sentences", "Grades 1-2": "Grades 1-2 child, using simple words and light adventure plots", "Grades 3-4": "Grades 3-4 child, using moderately complex words and adventure themes", "Grades 5-6": "Grades 5-6 child, using slightly longer sentences and themes of courage or teamwork", "None": "child aged 3-10, using simple words" }.get(grade, "child aged 3-10, using simple words") pro = 'Start the story with Hey {},'.format(protagonist) if name else '' prompt = ( f"Write a fun, happy, complete story for a {grade_text} named {protagonist} based on: {caption}. {pro}" f"Use a cheerful tone, around 100 words, safe for kids. " #f"Make {protagonist} the hero, using '{pronoun}' pronouns. " #f"Do not include the prompt or any other contents in your response, just the story itself is enough." #f"at the beginning of the story and end of the story, use **\start\ and **\end\ to mark the beginning and end of the story." ) story = story_generator( prompt, max_length=500, num_return_sequences=1, truncation=True, temperature=0.5, top_p=0.9, no_repeat_ngram_size=2 )[0]['generated_text'] print(story) return story.strip(prompt) except Exception as e: st.markdown(f'Oops, couldn’t create the story: {str(e)}', unsafe_allow_html=True) return None # Function to convert text to speech def text_to_speech(text): """Convert the generated story to audio using facebook/mms-tts-eng.""" tts_generator = load_tts_generator() if not tts_generator: return None try: speech = tts_generator(text) audio_data = speech['audio'] audio_tensor = torch.tensor(audio_data, dtype=torch.float32) sampling_rate = speech['sampling_rate'] temp_dir = tempfile.gettempdir() audio_path = os.path.join(temp_dir, f"story_audio_{uuid.uuid4()}.wav") torchaudio.save(audio_path, audio_tensor, sampling_rate) if os.path.exists(audio_path): return audio_path else: st.markdown('Oops, the audio didn’t work.', unsafe_allow_html=True) return None except Exception as e: st.markdown(f'Oops, couldn’t make the audio: {str(e)}', unsafe_allow_html=True) return None # Page 1: Welcome and Introduction def welcome_page(): st.markdown('