import os import sys import gc import tempfile from pathlib import Path # 권한 문제 해결을 위한 환경 변수 설정 (최상단에 위치) temp_dir = tempfile.gettempdir() os.environ["STREAMLIT_HOME"] = temp_dir os.environ["STREAMLIT_CONFIG_DIR"] = os.path.join(temp_dir, ".streamlit") os.environ["STREAMLIT_SERVER_HEADLESS"] = "true" os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false" # 캐시 디렉토리도 temp로 설정 os.environ["TRANSFORMERS_CACHE"] = os.path.join(temp_dir, "transformers_cache") os.environ["HF_HOME"] = os.path.join(temp_dir, "huggingface") # PyTorch 클래스 경로 충돌 해결 try: import torch import importlib.util torch_classes_path = os.path.join(os.path.dirname(importlib.util.find_spec("torch").origin), "classes") if hasattr(torch, "classes"): torch.classes.__path__ = [torch_classes_path] except Exception: pass import streamlit as st # transformers 라이브러리 import 및 상태 체크 try: from transformers import AutoTokenizer, AutoModelForCausalLM import torch TRANSFORMERS_AVAILABLE = True except ImportError as e: TRANSFORMERS_AVAILABLE = False st.error(f"Transformers 라이브러리를 불러올 수 없습니다: {e}") # 페이지 설정 st.set_page_config( page_title="TinyLlama Demo", page_icon="🦙", layout="wide", initial_sidebar_state="expanded" ) st.title("🦙 TinyLlama 1.1B (CPU 전용) 데모") if not TRANSFORMERS_AVAILABLE: st.error("필요한 라이브러리를 설치해주세요:") st.code(""" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install transformers pip install streamlit """, language="bash") st.stop() @st.cache_resource(show_spinner=False) def load_tinyllama_model(): """TinyLlama 1.1B 모델 로드 (CPU Only)""" try: # 여러 가능한 모델 이름 시도 model_options = [ "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "microsoft/DialoGPT-small" # 백업 옵션 ] for model_name in model_options: try: st.info(f"모델 시도 중: {model_name}") # 토크나이저 로드 tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE") ) # 모델 로드 model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE"), device_map="cpu" ) # CPU로 명시적 이동 및 평가 모드 model = model.to("cpu") model.eval() # 토크나이저 설정 if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # 메모리 정리 gc.collect() return model, tokenizer, f"✅ {model_name} 로드 성공!" except Exception as model_error: st.warning(f"{model_name} 로드 실패: {str(model_error)}") continue return None, None, "❌ 모든 모델 로드 실패" except Exception as e: return None, None, f"❌ 전체 로드 실패: {str(e)}" def generate_text(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7): """안전한 텍스트 생성 함수""" try: # 입력 길이 제한 max_input_length = 400 # 토큰화 inputs = tokenizer( prompt, return_tensors="pt", truncation=True, max_length=max_input_length, padding=False ) # CPU로 이동 inputs = {k: v.to("cpu") for k, v in inputs.items()} input_length = inputs['input_ids'].shape[1] # 안전한 생성 길이 계산 safe_max_tokens = min(max_new_tokens, 800 - input_length) if safe_max_tokens < 20: safe_max_tokens = 20 # 생성 설정 generation_kwargs = { "max_new_tokens": safe_max_tokens, "temperature": temperature, "do_sample": True, "top_p": 0.9, "top_k": 50, "repetition_penalty": 1.1, "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, "eos_token_id": tokenizer.eos_token_id, "use_cache": True, "early_stopping": True } # 메모리 정리 gc.collect() # 생성 실행 with st.spinner("텍스트 생성 중..."): with torch.no_grad(): outputs = model.generate( **inputs, **generation_kwargs ) # 새로 생성된 부분만 추출 new_tokens = outputs[0][input_length:] generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) return generated_text.strip() except Exception as e: raise Exception(f"생성 중 오류: {str(e)}") def main(): # 모델 로드 with st.spinner("TinyLlama 모델 로딩 중... (처음 실행 시 다운로드로 인해 시간이 걸릴 수 있습니다)"): model, tokenizer, status = load_tinyllama_model() st.info(status) if not (model and tokenizer): st.error("모델 로드에 실패했습니다. 인터넷 연결을 확인하고 다시 시도해주세요.") return # 사이드바 설정 st.sidebar.header("⚙️ 생성 설정") max_new_tokens = st.sidebar.slider("최대 새 토큰 수", 20, 200, 100) temperature = st.sidebar.slider("Temperature (창의성)", 0.1, 1.0, 0.7, 0.1) # 도움말 st.sidebar.header("📖 사용 가이드") st.sidebar.info(""" **Tips:** - 프롬프트는 명확하고 간결하게 - CPU 전용이므로 생성에 시간이 걸립니다 - 첫 실행 시 모델 다운로드로 시간이 더 걸립니다 """) # 메인 인터페이스 st.header("💬 텍스트 생성") # 예제 프롬프트 col1, col2 = st.columns([2, 1]) with col1: example_prompts = [ "사용자 정의 입력", "The future of artificial intelligence is", "Once upon a time in a magical forest,", "Python is a programming language that", "Climate change is an important issue because", "The benefits of reading books include" ] selected_prompt = st.selectbox("예제 프롬프트 선택:", example_prompts) with col2: st.write("") # 공간 확보 st.write("") # 공간 확보 if st.button("🎲 랜덤 예제", help="랜덤한 예제 프롬프트 선택"): import random random_prompt = random.choice(example_prompts[1:]) # 첫 번째 제외 st.session_state.random_prompt = random_prompt # 프롬프트 입력 if selected_prompt == "사용자 정의 입력": default_prompt = st.session_state.get('random_prompt', '') prompt = st.text_area( "프롬프트를 입력하세요:", value=default_prompt, height=100, placeholder="여기에 텍스트를 입력하세요..." ) else: prompt = st.text_area( "프롬프트:", value=selected_prompt, height=100 ) # 토큰 수 표시 if prompt and tokenizer: try: token_count = len(tokenizer.encode(prompt)) st.caption(f"현재 프롬프트 토큰 수: {token_count}") if token_count > 400: st.warning("⚠️ 프롬프트가 너무 깁니다. 400 토큰으로 자동 잘림됩니다.") except: pass # 생성 버튼 col1, col2, col3 = st.columns([1, 1, 2]) with col1: generate_btn = st.button("🚀 생성 시작", type="primary", use_container_width=True) with col2: clear_btn = st.button("🗑️ 결과 지우기", use_container_width=True) # 결과 지우기 if clear_btn: if 'generated_result' in st.session_state: del st.session_state['generated_result'] st.rerun() # 텍스트 생성 if generate_btn and prompt.strip(): try: # 생성 진행률 표시 progress_container = st.container() with progress_container: progress_bar = st.progress(0) status_text = st.empty() status_text.text("토큰화 중...") progress_bar.progress(20) status_text.text("텍스트 생성 중... (CPU에서 실행되므로 시간이 걸립니다)") progress_bar.progress(40) # 실제 생성 generated_text = generate_text( model, tokenizer, prompt.strip(), max_new_tokens, temperature ) progress_bar.progress(80) status_text.text("결과 처리 중...") # 결과 저장 full_result = prompt + generated_text st.session_state['generated_result'] = { 'prompt': prompt, 'generated': generated_text, 'full_text': full_result } progress_bar.progress(100) status_text.text("완료!") # 진행률 표시 제거 progress_bar.empty() status_text.empty() except Exception as e: st.error(f"생성 중 오류가 발생했습니다: {str(e)}") st.info("💡 다시 시도하거나 더 짧은 프롬프트를 사용해보세요.") elif generate_btn: st.warning("⚠️ 프롬프트를 입력해주세요.") # 결과 표시 if 'generated_result' in st.session_state: result = st.session_state['generated_result'] st.header("📝 생성 결과") # 탭으로 구분 tab1, tab2 = st.tabs(["🎯 생성된 텍스트만", "📄 전체 텍스트"]) with tab1: st.markdown("**새로 생성된 부분:**") st.markdown(f'