F-G Fernandez committed
Commit 546db7e · 1 Parent(s): bcbd897

feat: update gradio build

Files changed (5):
  1. Dockerfile +23 -0
  2. README.md +6 -3
  3. app.py +0 -71
  4. requirements.txt +5 -5
  5. src/gradio_app.py +91 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ FROM python:3.13.5-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt ./
+ RUN --mount=from=ghcr.io/astral-sh/uv:0.9.5,source=/uv,target=/bin/uv \
+     --mount=type=cache,target=/root/.cache/uv \
+     --mount=type=bind,source=src.txt,target=requirements.txt \
+     UV_TORCH_BACKEND='cpu' uv pip install --system -r requirements.txt
+
+ COPY src/ ./src/
+
+ EXPOSE 8501
+
+ HEALTHCHECK CMD curl --fail http://localhost:8501
+
+ ENTRYPOINT ["python", "src/gradio_app.py", "--port=3000", "--repo=frgfm/rexnet1_0x"]
README.md CHANGED
@@ -3,10 +3,13 @@ title: Holocron
  emoji: 👁️‍🗨️
  colorFrom: gray
  colorTo: purple
- sdk: gradio
- app_file: app.py
- pinned: false
+ sdk: docker
+ app_port: 8501
+ tags:
+ - gradio
+ pinned: true
  license: apache-2.0
+ short_description: PyTorch implementations of recent Computer Vision tricks (ReXNet, RepVGG, Unet3p, YOLOv4, CIoU loss, AdaBelief, PolyLoss, MobileOne).
  ---

  # Configuration
app.py DELETED
@@ -1,71 +0,0 @@
- # Copyright (C) 2022, François-Guillaume Fernandez.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
-
- import argparse
- import json
-
- import gradio as gr
- import numpy as np
- import onnxruntime
- from huggingface_hub import hf_hub_download
- from PIL import Image
-
-
- REPO = "frgfm/rexnet1_0x"
-
- # Download model config & checkpoint
- with open(hf_hub_download(REPO, filename="config.json"), "rb") as f:
-     cfg = json.load(f)
-
- ort_session = onnxruntime.InferenceSession(hf_hub_download(REPO, filename="model.onnx"))
-
- def preprocess_image(pil_img: Image.Image) -> np.ndarray:
-     """Preprocess an image for inference
-
-     Args:
-         pil_img: a valid pillow image
-
-     Returns:
-         the resized and normalized image of shape (1, C, H, W)
-     """
-
-     # Resizing (PIL takes (W, H) order for resizing)
-     img = pil_img.resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR)
-     # (H, W, C) --> (C, H, W)
-     img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
-     # Normalization
-     img -= np.array(cfg["mean"])[:, None, None]
-     img /= np.array(cfg["std"])[:, None, None]
-
-     return img[None, ...]
-
- def predict(image):
-     # Preprocessing
-     np_img = preprocess_image(image)
-     ort_input = {ort_session.get_inputs()[0].name: np_img}
-
-     # Inference
-     ort_out = ort_session.run(None, ort_input)
-     # Post-processing
-     out_exp = np.exp(ort_out[0][0])
-     probs = out_exp / out_exp.sum()
-
-     return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)}
-
- img = gr.inputs.Image(type="pil")
- outputs = gr.outputs.Label(num_top_classes=3)
-
- gr.Interface(
-     fn=predict,
-     inputs=[img],
-     outputs=outputs,
-     title="Holocron: image classification demo",
-     article=(
-         "<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>"
-         "Github Repo</a> | "
-         "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"
-     ),
-     live=True,
- ).launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
- gradio>=3.0.2,<4.0.0
- Pillow>=8.4.0
- onnxruntime>=1.10.0,<2.0.0
- huggingface-hub>=0.4.0,<1.0.0
- numpy>=1.19.5,<2.0.0
+ gradio>=5.0.0,<6.0.0
+ huggingface-hub>=1.0.0,<2.0.0
+ numpy>=1.19.5,<3.0.0
+ onnxruntime>=1.22.0,<2.0.0
+ Pillow>=8.4.0,!=9.2.0
src/gradio_app.py ADDED
@@ -0,0 +1,91 @@
+ # Copyright (C) 2022-2025, François-Guillaume Fernandez.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
+
+ # /// script
+ # requires-python = ">=3.11"
+ # dependencies = [
+ #     "gradio>=5.0.0,<6.0.0",
+ #     "huggingface-hub>=1.0.0,<2.0.0",
+ #     "numpy>=1.19.5,<3.0.0",
+ #     "onnxruntime>=1.22.0,<2.0.0",
+ #     "Pillow>=8.4.0,!=9.2.0",
+ # ]
+ # ///
+
+ import json
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+ from pathlib import Path
+
+ import gradio as gr
+ import numpy as np
+ import onnxruntime
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+
+ def main(args):
+     # Download model config & checkpoint
+     with Path(hf_hub_download(args.repo, filename="config.json")).open("rb") as f:
+         cfg = json.load(f)
+
+     ort_session = onnxruntime.InferenceSession(hf_hub_download(args.repo, filename="model.onnx"))
+
+     def preprocess_image(pil_img: Image.Image) -> np.ndarray:
+         """Preprocess an image for inference
+
+         Args:
+             pil_img: a valid pillow image
+
+         Returns:
+             the resized and normalized image of shape (1, C, H, W)
+         """
+         # Resizing (PIL takes (W, H) order for resizing)
+         img = pil_img.resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR)
+         # (H, W, C) --> (C, H, W)
+         img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
+         # Normalization
+         img -= np.array(cfg["mean"])[:, None, None]
+         img /= np.array(cfg["std"])[:, None, None]
+
+         return img[None, ...]
+
+     def predict(image):
+         # Preprocessing
+         np_img = preprocess_image(image)
+         ort_input = {ort_session.get_inputs()[0].name: np_img}
+
+         # Inference
+         ort_out = ort_session.run(None, ort_input)
+         # Post-processing
+         out_exp = np.exp(ort_out[0][0])
+         probs = out_exp / out_exp.sum()
+
+         return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs, strict=True)}
+
+     interface = gr.Interface(
+         fn=predict,
+         inputs=gr.Image(type="pil"),
+         outputs=gr.Label(num_top_classes=3),
+         title="Holocron: image classification demo",
+         article=(
+             "<p style='text-align: center'><a href='https://github.com/frgfm/holocron'>"
+             "Github Repo</a> | "
+             "<a href='https://frgfm.github.io/holocron/'>Documentation</a></p>"
+         ),
+         live=True,
+     )
+
+     interface.launch(server_port=args.port, show_error=True)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser(
+         description="Holocron image classification demo", formatter_class=ArgumentDefaultsHelpFormatter
+     )
+     parser.add_argument("--repo", type=str, default="frgfm/rexnet1_0x", help="HF Hub repo to use")
+     parser.add_argument("--port", type=int, default=8001, help="Port on which the webserver will be run")
+     args = parser.parse_args()
+
+     main(args)
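
For reference, a minimal headless sketch of the same download/preprocess/inference path that the new src/gradio_app.py wires into a Gradio Interface; it is not part of the commit, and the local test image "sample.jpg" is an assumption for illustration only.

# Headless sanity-check sketch mirroring the preprocessing/inference path above.
# Assumption (not in the commit): a local RGB test image at "sample.jpg".
import json
from pathlib import Path

import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image

REPO = "frgfm/rexnet1_0x"

with Path(hf_hub_download(REPO, filename="config.json")).open("rb") as f:
    cfg = json.load(f)
session = onnxruntime.InferenceSession(hf_hub_download(REPO, filename="model.onnx"))

# Resize to the model input size, scale to [0, 1], then normalize per channel
img = Image.open("sample.jpg").convert("RGB").resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR)
x = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
x -= np.array(cfg["mean"], dtype=np.float32)[:, None, None]
x /= np.array(cfg["std"], dtype=np.float32)[:, None, None]

# Run the ONNX model and softmax the logits
logits = session.run(None, {session.get_inputs()[0].name: x[None]})[0][0]
probs = np.exp(logits) / np.exp(logits).sum()
top = int(probs.argmax())
print(f"{cfg['classes'][top]}: {probs[top]:.2%}")

Run in the same environment as requirements.txt; it performs the same steps as the predict callback, minus the UI.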