heesuuuuuu committed
Commit 981c0cd · verified · 1 Parent(s): 46a6091

Upload 5 files

Files changed (5)
  1. .gitattributes +1 -0
  2. app.py +151 -0
  3. proj2_voca/index.faiss +3 -0
  4. proj2_voca/index.pkl +3 -0
  5. requirements.txt +9 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  dream_interpreter/proj2_voca/index.faiss filter=lfs diff=lfs merge=lfs -text
+ proj2_voca/index.faiss filter=lfs diff=lfs merge=lfs -text
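(A line like this is what `git lfs track "proj2_voca/index.faiss"` appends, so the binary FAISS index is stored via Git LFS rather than directly in Git history.)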
app.py ADDED
@@ -0,0 +1,151 @@
+ from langchain_community.llms import HuggingFacePipeline
+ import torch
+ import os
+
+ os.environ["TRANSFORMERS_CACHE"] = "/data/heesu/hf_cache"
+
+ # Use the beomi/gemma-ko-2b model
+ model_id = "beomi/gemma-ko-2b"
+ llm = HuggingFacePipeline.from_model_id(
+     model_id=model_id,
+     task="text-generation",
+     device=None,  # Explicitly set to None to avoid conflicts with device_map
+     model_kwargs={"torch_dtype": torch.bfloat16, "device_map": "auto"}
+ )
+
+ # %%
+ import os
+
+ os.environ['TRANSFORMERS_CACHE'] = '/data/heesu/huggingface_cache'
+ import torch
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from langchain_community.vectorstores import FAISS
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
+ import gradio as gr  # added Gradio import
+
+ # --- 1. Load the vector DB and embedding model ---
+ print("Loading the vector DB and embedding model...")
+ index_path = "proj2_voca"
+ model_name = "jhgan/ko-sroberta-multitask"
+ model_kwargs = {'device': 'cuda'}
+ encode_kwargs = {'normalize_embeddings': True}
+ embeddings = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs
+ )
+ vectorstore = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)  # index.pkl is unpickled; only enable this for an index you built yourself
+ retriever = vectorstore.as_retriever(search_kwargs={'k': 2})
+
+ # --- 2. Load the language model (LLM) ---
+ print("Loading the language model (beomi/gemma-ko-2b). This may take a while...")
+ model_id = "beomi/gemma-ko-2b"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map=None  # disable automatic device placement
+ ).to("cuda:2")  # pin to an explicit GPU
+
+ # --- 2. Load the language model (LLM) ---
+ # ... (previous code omitted) ...
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     device=2,
+     max_new_tokens=170,  # reduced maximum number of generated tokens
+     temperature=0.7,
+     repetition_penalty=1.2,  # added penalty to prevent repetition
+     pad_token_id=tokenizer.eos_token_id
+ )
+ # ... (code below omitted) ...
+ llm = HuggingFacePipeline(pipeline=pipe)
+
+ # --- 3. Define the prompt template ---
+ template = """
+ You are a dream-interpretation expert who explains the user's dream clearly. Based on the 'retrieved dream interpretation information' below, answer the user's question by summarizing the most relevant content in a single sentence.
+
+
+ ### Retrieved dream interpretation information:
+ {context}
+
+
+
+ ### User's question:
+ {question}
+
+
+
+ ### Expert's answer:
+ """
+ prompt = PromptTemplate.from_template(template)
+
+ # --- 4. Build the RAG chain ---
+ rag_chain = (
+     {"context": retriever, "question": RunnablePassthrough()}
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
+
+
+ # --- Added response-cleanup function (taken from the original) ---
+ def clean_response(response):
+     # Split on "###" and return only the first part (removes unwanted trailing text)
+     cleaned = response.split("###")[0].strip()
+     return cleaned
+
+
+ # --- 5. Define the dream-interpretation function (called by Gradio) ---
+ def interpret_dream(query):
+     docs_with_scores = vectorstore.similarity_search_with_score(query, k=2)
+     score_threshold = 1.2  # FAISS scores here are L2 distances: lower means more similar
+     filtered_docs = [doc for doc, score in docs_with_scores if score < score_threshold]
+
+     if not filtered_docs:
+         return "Sorry, the database currently has too little information about this dream to offer an interpretation."
+     else:
+         context_texts = [f"- {d.page_content}: {d.metadata['meaning']}" for d in filtered_docs]
+
+         local_rag_chain = (
+             prompt
+             | llm
+             | StrOutputParser()
+         )
+
+         response = local_rag_chain.invoke({
+             "context": "\n".join(context_texts),
+             "question": query
+         })
+
+         try:
+             final_answer = response.split("### Expert's answer:")[1].strip()
+         except IndexError:
+             final_answer = response.strip()
+
+         # Apply the cleanup function (taken from the original)
+         final_answer = clean_response(final_answer)
+
+         return final_answer
+
+
+ # --- 6. Build the Gradio interface ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# ✨ Dream Interpretation Service")
+     gr.Markdown("Describe your dream and the AI will interpret it clearly!")
+
+     with gr.Row():
+         input_text = gr.Textbox(label="What was your dream?", placeholder="e.g. a dream of blue flames")
+         output_text = gr.Textbox(label="Interpretation result", interactive=False)
+
+     submit_button = gr.Button("Interpret")
+     submit_button.click(interpret_dream, inputs=input_text, outputs=output_text)
+
+ # Run the app
+ if __name__ == '__main__':
+     demo.launch(share=True)  # share=True creates a public URL (optional)
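For reference, the step-4 rag_chain is not used by the Gradio handler (interpret_dream builds its own context and adds score filtering), but it can be invoked directly as a smoke test. A minimal sketch, assuming the index and models above loaded successfully; the query string is illustrative:

    raw = rag_chain.invoke("a dream about losing a tooth")  # retriever fills {context}, the raw query fills {question}
    answer = raw.split("### Expert's answer:")[-1]          # drop the echoed prompt, if present
    print(clean_response(answer))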
proj2_voca/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7897c8a84e645b5a16d4f4787f221246909169487e835e7ddea76f06a34e9618
+ size 844845
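(These three lines are the Git LFS pointer committed in place of the binary: the 844,845-byte FAISS index itself is stored in LFS and materialized by `git lfs` on checkout.)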
proj2_voca/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:800c4f48d82dd66cc1e2e77033f72dc393b90def434f204050e6d16d3d258f07
+ size 62707
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ faiss-cpu  # required by the FAISS vector store
+ langchain_community
+ langchain
+ sentence-transformers  # required by HuggingFaceEmbeddings
+ torch
+ transformers
+ langchain_huggingface
+ langchain_core
+ gradio
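The app can then be run locally with `pip install -r requirements.txt` followed by `python app.py`; note that app.py pins the LLM to cuda:2, so at least three visible GPUs are assumed.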