WaysAheadGlobal committed on
Commit
d78fda0
·
verified ·
1 Parent(s): ebdab01

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
"""Streamlit front end for TinyLLaVA 3.1B visual question answering.

The user uploads an image and types a question; the app runs TinyLLaVA on
the pair and shows the generated answer.

Streamlit re-executes this whole script on every widget interaction, so
anything expensive (loading the model) lives behind ``st.cache_resource``
and everything else is cheap to recompute.
"""

import streamlit as st
from PIL import Image
import torch

# Import TinyLLaVA modules (use local copy!)
from tinyllava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from tinyllava.conversation import conv_templates
from tinyllava.model.builder import load_pretrained_model
from tinyllava.utils import disable_torch_init
from tinyllava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

MODEL_PATH = "bczhou/TinyLLaVA-3.1B"
MODEL_NAME = "TinyLLaVA-3.1B"
# Conversation template used to wrap the question.
# NOTE(review): "phi" is the template upstream TinyLLaVA documents for the
# 3.1B checkpoint -- confirm it exists in the local tinyllava copy.
CONV_MODE = "phi"

device = torch.device("cpu")

# Must stay the first Streamlit command: the cached loader below renders a
# spinner, which would otherwise run before the page is configured.
st.set_page_config(page_title="TinyLLaVA 3.1B (Streamlit)", layout="centered")


@st.cache_resource
def load_model():
    """Load TinyLLaVA once per server process and move it to ``device``.

    Returns:
        The ``(tokenizer, model, image_processor, context_len)`` tuple
        produced by ``load_pretrained_model``.
    """
    # Disable torch default init for speed
    disable_torch_init()
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=MODEL_PATH,
        model_base=None,
        model_name=MODEL_NAME,
    )
    # NOTE(review): the loader's default dtype may be float16, which some
    # CPU kernels do not support -- confirm generation works on the target
    # host or cast the model to float32 here.
    model.to(device)
    return tokenizer, model, image_processor, context_len


tokenizer, model, image_processor, context_len = load_model()


def _build_prompt(question):
    """Wrap ``question`` in the chat template, prefixed by the image token."""
    conv = conv_templates[CONV_MODE].copy()
    # The image placeholder marks where the vision features are spliced in.
    conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{question}")
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _generate_answer(image, question):
    """Run TinyLLaVA on one RGB image and one question.

    Args:
        image: A PIL image already converted to RGB.
        question: The user's free-text question about the image.

    Returns:
        The decoded model output with surrounding whitespace removed.
    """
    image_tensor = process_images([image], image_processor, model.config)
    # Match the model's parameter dtype to avoid a dtype mismatch in the
    # vision tower.
    image_tensor = image_tensor.to(device, dtype=model.dtype)

    # tokenizer_image_token already returns token ids (with the image
    # placeholder mapped to IMAGE_TOKEN_INDEX); it must not be re-tokenized.
    # A single unsqueeze adds the batch dimension.
    input_ids = tokenizer_image_token(
        _build_prompt(question),
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors="pt",
    )
    input_ids = input_ids.unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=200,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()


# Streamlit UI
st.title("🦙 TinyLLaVA 3.1B — Vision-Language Q&A")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

prompt = st.text_input("Ask a question about the image:")

if uploaded_file is not None and prompt.strip():
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises an OSError subclass for unreadable or corrupt images.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Run inference
    with st.spinner("Generating answer..."):
        out_text = _generate_answer(image, prompt)

    st.subheader("Answer:")
    st.write(out_text)