YongdongWang committed on
Commit 5f15f77 · verified · 1 Parent(s): 1f65ffa

Create Llama 3.1 8B fine-tuned chat space

Files changed (3)
  1. README.md +39 -6
  2. app.py +218 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,12 +1,45 @@
  ---
- title: Robot Task Planning
- emoji: 🦀
- colorFrom: gray
- colorTo: indigo
+ title: Robot Task Planning - Llama 3.1 8B
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 5.34.2
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
+ license: llama3.1
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Robot Task Planning - Llama 3.1 8B
+
+ This Space demonstrates a fine-tuned version of Meta's Llama 3.1 8B model, specialized for **robot task planning** with QLoRA.
+
+ The model converts natural-language commands into structured task sequences for construction robots such as excavators and dump trucks.
+
+ ## Model
+
+ The model is available at [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora).
+
+ ## Features
+
+ - **Robot Command Processing**: Convert natural language into structured robot tasks
+ - **Multi-Robot Coordination**: Handle complex scenarios with multiple excavators and dump trucks
+ - **Task Dependencies**: Generate task sequences with proper dependencies
+ - **Real-Time Planning**: Instant task generation served through Gradio
+
+ ## Usage
+
+ Enter a natural-language robot command such as "Deploy Excavator 1 to Soil Area 1", and the model generates a structured task sequence in JSON format for robot execution.
+
+ ## Technical Details
+
+ - **Base Model**: meta-llama/Llama-3.1-8B
+ - **Fine-Tuning**: QLoRA (4-bit quantization + LoRA)
+ - **Interface**: Gradio
+ - **Hosting**: Hugging Face Spaces
+ - **Input**: Natural-language robot commands
+ - **Output**: Structured JSON task sequences
+
+ ## Performance
+
+ ⚠️ **Note**: Model loading may take 3-5 minutes on first startup because of the model's size and the quantization step.
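The README promises JSON task sequences, but no schema appears in this commit. As a purely hypothetical post-processing sketch, the `parse_task_sequence` helper below and the field names in its docstring are illustrative assumptions, not the format the fine-tuned model actually emits:

```python
# Hypothetical post-processing sketch (not part of this commit): parse the
# model's generated text as JSON before handing it to a robot executor.
# The field names in the docstring are illustrative assumptions only; the
# real schema comes from the dart-qlora fine-tuning data.
import json

def parse_task_sequence(generated_text: str) -> list:
    """Parse model output expected to look roughly like:
    [{"task": "excavate", "robot": "excavator_1",
      "area": "soil_area_1", "depends_on": []}]
    """
    try:
        tasks = json.loads(generated_text)
    except json.JSONDecodeError as err:
        raise ValueError(f"Model output is not valid JSON: {err}") from err
    if not isinstance(tasks, list):
        raise ValueError("Expected a JSON list of task objects.")
    return tasks
```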
app.py ADDED
@@ -0,0 +1,218 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import PeftModel
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Model configuration
+ MODEL_NAME = "meta-llama/Llama-3.1-8B"
+ LORA_MODEL = "YongdongWang/llama-3.1-8b-dart-qlora"
+
+ def load_model():
+     """Load the model and tokenizer."""
+     print("🔄 Loading model...")
+
+     try:
+         # 4-bit quantization configuration
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+         # Load the tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load the quantized base model
+         base_model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             quantization_config=bnb_config,
+             device_map="auto",
+             torch_dtype=torch.float16,
+             trust_remote_code=True
+         )
+
+         # Attach the LoRA adapter
+         model = PeftModel.from_pretrained(base_model, LORA_MODEL)
+         model.eval()
+
+         print("✅ Model loaded successfully!")
+         return model, tokenizer
+
+     except Exception as load_error:
+         print(f"❌ Model loading failed: {load_error}")
+         return None, None
+
+ # Global variables that hold the model
+ model = None
+ tokenizer = None
+
+ def initialize_model():
+     """Initialize the model lazily on first use."""
+     global model, tokenizer
+     if model is None or tokenizer is None:
+         model, tokenizer = load_model()
+     return model is not None and tokenizer is not None
+
+ def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
+     """Generate a response for the given prompt."""
+     if not initialize_model():
+         return "❌ Model not loaded. Please check the logs or try again."
+
+     try:
+         # Format the input
+         formatted_prompt = prompt.strip()
+
+         # Encode the input
+         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+         # Generate a response
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 repetition_penalty=1.1,
+             )
+
+         # Decode only the newly generated tokens, dropping the prompt
+         prompt_length = inputs["input_ids"].shape[1]
+         response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
+
+         # Clean up any leftover chat markers
+         if "Assistant:" in response:
+             response = response.split("Assistant:")[-1].strip()
+
+         return response if response else "❌ No response generated. Please try again."
+
+     except Exception as generation_error:
+         return f"❌ Generation Error: {generation_error}"
+
+ def chat_interface(message, history, max_tokens, temperature, top_p):
+     """Chat interface callback."""
+     if not message.strip():
+         return history, ""
+
+     try:
+         response = generate_response(message, max_tokens, temperature, top_p)
+         history.append((message, response))
+         return history, ""
+     except Exception as chat_error:
+         error_msg = f"❌ Chat Error: {chat_error}"
+         history.append((message, error_msg))
+         return history, ""
+
+ # Build the Gradio app
+ with gr.Blocks(title="Robot Task Planning - Llama 3.1 8B", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🤖 Llama 3.1 8B - Robot Task Planning
+
+     This is a fine-tuned version of Meta's Llama 3.1 8B model, specialized for **robot task planning** with QLoRA.
+
+     **Capabilities**: Convert natural-language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.
+
+     **Model**: [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)
+
+     ⚠️ **Note**: Model loading may take a few minutes on first startup.
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(
+                 label="Task Planning Results",
+                 height=400,
+                 show_label=True,
+                 container=True,
+                 bubble_full_width=False
+             )
+
+             msg = gr.Textbox(
+                 label="Robot Command",
+                 placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
+                 lines=2,
+                 max_lines=5,
+                 show_label=True,
+                 container=True
+             )
+
+             with gr.Row():
+                 send_btn = gr.Button("Generate Tasks", variant="primary", size="sm")
+                 clear_btn = gr.Button("Clear", variant="secondary", size="sm")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Generation Settings")
+
+             max_tokens = gr.Slider(
+                 minimum=50,
+                 maximum=500,
+                 value=200,
+                 step=10,
+                 label="Max Tokens",
+                 info="Maximum number of tokens to generate"
+             )
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.7,
+                 step=0.1,
+                 label="Temperature",
+                 info="Controls randomness (lower = more focused)"
+             )
+
+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.9,
+                 step=0.05,
+                 label="Top-p",
+                 info="Nucleus sampling threshold"
+             )
+
+     # Example commands
+     gr.Examples(
+         examples=[
+             ["Deploy Excavator 1 to Soil Area 1 for excavation."],
+             ["Send Dump Truck 1 to collect material, then unload at storage area."],
+             ["Move all robots to avoid Puddle 1 after inspection."],
+             ["Deploy multiple excavators to different soil areas simultaneously."],
+             ["Coordinate dump trucks to transport materials from excavation site to storage."],
+             ["Send robot to inspect rock area, then avoid with all other robots."],
+             ["Return all robots to start position after completing tasks."],
+         ],
+         inputs=msg,
+         label="💡 Example Robot Commands"
+     )
+
+     # Event handlers
+     msg.submit(
+         chat_interface,
+         inputs=[msg, chatbot, max_tokens, temperature, top_p],
+         outputs=[chatbot, msg]
+     )
+
+     send_btn.click(
+         chat_interface,
+         inputs=[msg, chatbot, max_tokens, temperature, top_p],
+         outputs=[chatbot, msg]
+     )
+
+     clear_btn.click(
+         lambda: ([], ""),
+         outputs=[chatbot, msg]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
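Because model loading can take several minutes, it helps to exercise the generation path once without the UI. A minimal smoke-test sketch, assuming the file above is saved as `app.py` on a machine with a CUDA GPU; since `demo.launch()` is guarded by `__main__`, importing the module does not start the server:

```python
# Hypothetical smoke test (not part of this commit): call the generation
# function directly instead of launching the Gradio UI.
# Assumes the file above is saved as app.py and a CUDA GPU is available.
from app import generate_response

if __name__ == "__main__":
    reply = generate_response(
        "Deploy Excavator 1 to Soil Area 1 for excavation.",
        max_tokens=150,
        temperature=0.7,
        top_p=0.9,
    )
    print(reply)
```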
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==4.44.0
+ transformers==4.44.2
+ torch==2.1.0
+ peft==0.7.1
+ bitsandbytes==0.41.3
+ accelerate==0.24.1
+ scipy==1.11.4
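These pins assume a GPU runtime, since 4-bit bitsandbytes quantization requires CUDA. An illustrative sanity check, not part of the commit, that the pinned stack imports cleanly and sees a GPU:

```python
# Hypothetical environment check (not part of this commit): confirm the
# pinned packages import and that a CUDA device is visible, since the
# 4-bit bitsandbytes path in app.py needs a GPU.
import torch
import transformers
import peft
import gradio

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("peft:", peft.__version__)
print("gradio:", gradio.__version__)
```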