Olsc commited on
Commit
04b19b4
·
verified ·
1 Parent(s): df782d9

更新代码

Browse files
Files changed (4) hide show
  1. LICENSE +47 -0
  2. README.md +37 -0
  3. export_onnx.py +259 -0
  4. webui.py +399 -0
LICENSE ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
2
+
3
+ Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple
4
+ Inc. ("Apple") in consideration of your agreement to the following
5
+ terms, and your use, installation, modification or redistribution of
6
+ this Apple software constitutes acceptance of these terms. If you do
7
+ not agree with these terms, please do not use, install, modify or
8
+ redistribute this Apple software.
9
+
10
+ In consideration of your agreement to abide by the following terms, and
11
+ subject to these terms, Apple grants you a personal, non-exclusive
12
+ license, under Apple's copyrights in this original Apple software (the
13
+ "Apple Software"), to use, reproduce, modify and redistribute the Apple
14
+ Software, with or without modifications, in source and/or binary forms;
15
+ provided that if you redistribute the Apple Software in its entirety and
16
+ without modifications, you must retain this notice and the following
17
+ text and disclaimers in all such redistributions of the Apple Software.
18
+ Neither the name, trademarks, service marks or logos of Apple Inc. may
19
+ be used to endorse or promote products derived from the Apple Software
20
+ without specific prior written permission from Apple. Except as
21
+ expressly stated in this notice, no other rights or licenses, express or
22
+ implied, are granted by Apple herein, including but not limited to any
23
+ patent rights that may be infringed by your derivative works or by other
24
+ works in which the Apple Software may be incorporated.
25
+
26
+ The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27
+ MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28
+ THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29
+ FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30
+ OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31
+
32
+ IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33
+ OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35
+ INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36
+ MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37
+ AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38
+ STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39
+ POSSIBILITY OF SUCH DAMAGE.
40
+
41
+
42
+ -------------------------------------------------------------------------------
43
+ SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
44
+
45
+ This software includes a number of subcomponents with separate
46
+ copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
47
+ -------------------------------------------------------------------------------
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: onnxruntime
3
+ tags:
4
+ - onnx
5
+ - sharp
6
+ - view-synthesis
7
+ - 3d-gaussian-splatting
8
+ base_model: apple/Sharp
9
+ ---
10
+
11
+ # SHARP — ONNX Exports (INT8)
12
+
13
+ This repository provides ONNX exports of Apple’s **SHARP** model predictor, which outputs **NDC-space 3D Gaussians**.
14
+
15
+ ## Export to ONNX (INT8)
16
+
17
+ Run the following command to export the model:
18
+
19
+ ```bash
20
+ python export_onnx.py --int8
21
+ ````
22
+
23
+ ## Launch Web UI
24
+
25
+ Start the web interface with:
26
+
27
+ ```bash
28
+ python webui.py \
29
+ --host 0.0.0.0 \
30
+ --port 7860 \
31
+ --ssl_cert cert.pem \
32
+ --ssl_key key.pem
33
+ ```
34
+
35
+ ## Notes
36
+
37
+ This project is for learning purposes. Please refer to the LICENSE for usage terms.
export_onnx.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import warnings
7
+ import torch
8
+ from sharp.models import PredictorParams, create_predictor
9
+
10
+ LOGGER = logging.getLogger(__name__)
11
+
12
+ # 默认权重文件名
13
+ DEFAULT_CHECKPOINT_NAME = "sharp_2572gikvuh.pt"
14
+ MODEL_URL = f"https://ml-site.cdn-apple.com/models/sharp/{DEFAULT_CHECKPOINT_NAME}"
15
+
16
+ def verify_model_precision(model_path: Path):
17
+ """
18
+ 静态检测 ONNX 模型的权重数据类型分布,用于验证量化是否生效。
19
+ """
20
+ try:
21
+ import onnx
22
+ from collections import Counter
23
+ # 尝试加载模型,注意大模型可能带 .data 后缀
24
+ model = onnx.load(str(model_path))
25
+ # 统计所有初始化器(权重)的数据类型
26
+ # 在 ONNX 中,FLOAT=1, UINT8=2, INT8=3, FLOAT16=10, UINT4=27, INT4=28, FLOAT4=29
27
+ dtypes = [onnx.TensorProto.DataType.Name(i.data_type) for i in model.graph.initializer]
28
+ counter = Counter(dtypes)
29
+
30
+ LOGGER.info(f"--- 权重类型分布报告 [{model_path.name}] ---")
31
+ for dtype, count in counter.items():
32
+ LOGGER.info(f" - {dtype}: {count} 个张量")
33
+
34
+ # 计算文件大小并打印
35
+ file_size = os.path.getsize(model_path)
36
+ LOGGER.info(f" [文件信息] 路径: {model_path}")
37
+ LOGGER.info(f" [文件信息] 体积: {file_size / (1024*1024):.2f} MB")
38
+ LOGGER.info("------------------------------------------------")
39
+ except ImportError:
40
+ LOGGER.warning("未安装 'onnx' 模块,跳过权重精度验证。")
41
+ except Exception as e:
42
+ LOGGER.warning(f"验证模型精度时发生异常: {e}")
43
+
44
+ def export_onnx(output_path: Path, checkpoint_path: str = None, internal_shape: int = 1536, is_fp16: bool = False):
45
+ device = torch.device("cpu")
46
+
47
+ LOGGER.info("正在创建预测器模型...")
48
+ predictor = create_predictor(PredictorParams())
49
+
50
+ # FP16 处理:将网络参数转为半精度,体积减半
51
+ if is_fp16:
52
+ LOGGER.info("正在将模型转换为 FP16 (半精度) 以直接将其缩减进单一 < 2GB 的文件中...")
53
+ predictor = predictor.half()
54
+
55
+ if checkpoint_path is None:
56
+ torch_hub_checkpoints = Path(torch.hub.get_dir()) / "checkpoints"
57
+ search_paths = [
58
+ Path(DEFAULT_CHECKPOINT_NAME),
59
+ Path("data") / DEFAULT_CHECKPOINT_NAME,
60
+ torch_hub_checkpoints / DEFAULT_CHECKPOINT_NAME,
61
+ ]
62
+
63
+ found_path = None
64
+ for p in search_paths:
65
+ if p.exists():
66
+ found_path = p
67
+ break
68
+
69
+ if found_path:
70
+ LOGGER.info(f"在本地找到模型权重: {found_path},加载本地权重...")
71
+ state_dict = torch.load(found_path, map_location="cpu", weights_only=True)
72
+ else:
73
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_URL, progress=True, map_location="cpu")
74
+ else:
75
+ state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
76
+
77
+ predictor.load_state_dict(state_dict)
78
+ predictor.eval()
79
+ predictor.to(device)
80
+
81
+ dummy_image = torch.randn(1, 3, internal_shape, internal_shape, device=device)
82
+ dummy_disparity = torch.tensor([1.0], device=device)
83
+
84
+ if is_fp16:
85
+ dummy_image = dummy_image.half()
86
+ dummy_disparity = dummy_disparity.half()
87
+
88
+ dummy_inputs = (dummy_image, dummy_disparity)
89
+
90
+ LOGGER.info("正在导出为 ONNX 格式...")
91
+ torch.onnx.export(
92
+ predictor,
93
+ dummy_inputs,
94
+ str(output_path),
95
+ export_params=True,
96
+ opset_version=19,
97
+ do_constant_folding=True,
98
+ input_names=['image', 'disparity_factor'],
99
+ output_names=['mean_vectors', 'singular_values', 'quaternions', 'colors', 'opacities'],
100
+ keep_initializers_as_inputs=False,
101
+ dynamic_axes={k: {0: 'batch_size'} for k in ['image', 'disparity_factor', 'mean_vectors', 'singular_values', 'quaternions', 'colors', 'opacities']}
102
+ )
103
+ LOGGER.info(f"成功将全精度 ONNX 模型导出至: {output_path}")
104
+
105
+ def main():
106
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
107
+ parser = argparse.ArgumentParser(description="将 ML-SHARP 模型导出为多种精度的 ONNX 格式。")
108
+ parser.add_argument("--output", type=str, default="sharp.onnx", help="导出的基础 ONNX 模型保存路径。")
109
+ parser.add_argument("--checkpoint", type=str, default=None, help="本地 .pt 权重文件路径。")
110
+ parser.add_argument("--fp16", action="store_true", help="启用 FP16 半精度导出。")
111
+ parser.add_argument("--int8", action="store_true", help="应用动态 INT8 量化。")
112
+ parser.add_argument("--int4", action="store_true", help="应用真正的 INT4 重量级量化 (WOQ)。")
113
+ args = parser.parse_args()
114
+
115
+ # 如果启用了 FP16 模式且用户未指定 output,则重定向默认输出名
116
+ if args.fp16 and args.output == "sharp.onnx":
117
+ args.output = "sharp_fp16.onnx"
118
+
119
+ output_path = Path(args.output)
120
+ is_quant_requested = args.int8 or args.int4
121
+
122
+ if is_quant_requested and output_path.exists():
123
+ LOGGER.info(f"检测到基础模型 '{output_path}' 已存在,跳过导出并开始验证。")
124
+ verify_model_precision(output_path)
125
+ else:
126
+ export_onnx(output_path, args.checkpoint, is_fp16=args.fp16)
127
+ verify_model_precision(output_path)
128
+
129
+ if is_quant_requested:
130
+ try:
131
+ from onnxruntime.quantization import quantize_dynamic, QuantType
132
+ import onnx
133
+ from onnx import shape_inference
134
+ except ImportError:
135
+ LOGGER.error("缺失依赖项:请安装 onnx 和 onnxruntime。")
136
+ return
137
+
138
+ preprocessed_path = output_path.with_name(f"{output_path.stem}_pre.onnx")
139
+ if not preprocessed_path.exists():
140
+ LOGGER.info("正在将基础模型转换为外部数据格式 (已跳过导致崩溃的形状推导模式)...")
141
+ try:
142
+ model_raw = onnx.load(str(output_path))
143
+ onnx.save(
144
+ model_raw,
145
+ str(preprocessed_path),
146
+ save_as_external_data=True,
147
+ all_tensors_to_one_file=True,
148
+ location=preprocessed_path.name + ".data",
149
+ size_threshold=1024,
150
+ convert_attribute=True
151
+ )
152
+ model_source = preprocessed_path
153
+ except Exception as e:
154
+ LOGGER.warning(f"预分流模型保存失败: {e}")
155
+ model_source = output_path
156
+ else:
157
+ model_source = preprocessed_path
158
+
159
+ tasks = []
160
+ if args.int8: tasks.append("int8")
161
+ if args.int4: tasks.append("int4")
162
+
163
+ # 1. INT8 量化
164
+ int8_path = output_path.with_name(f"{output_path.stem}_int8.onnx")
165
+ if "int8" in tasks:
166
+ LOGGER.info("正在应用动态 INT8 量化...")
167
+ quantize_dynamic(str(model_source), str(int8_path), weight_type=QuantType.QUInt8)
168
+ verify_model_precision(int8_path)
169
+
170
+ # 2. INT4 重量级量化 (WOQ)
171
+ if "int4" in tasks:
172
+ LOGGER.info("正在准备进行真正的 INT4 权重压缩 (WOQ)...")
173
+ quantized_output = output_path.with_name(f"{output_path.stem}_int4.onnx")
174
+
175
+ import onnxruntime.quantization as oq
176
+ import importlib
177
+ import pkgutil
178
+
179
+ # 优先尝试新版 Unified API
180
+ quantize_func = getattr(oq, "quantize", None)
181
+ WeightOnlyConfig = getattr(oq, "MatMulWeightOnlyQuantConfig", None)
182
+
183
+ if quantize_func and WeightOnlyConfig:
184
+ LOGGER.info("[INT4] 使用 MatMulWeightOnlyQuantConfig 执行量化...")
185
+ try:
186
+ model = onnx.load(str(model_source))
187
+ config = WeightOnlyConfig(weight_type=getattr(QuantType, 'QInt4', QuantType.QUInt8))
188
+ quantize_func(model, str(quantized_output), config)
189
+ LOGGER.info(f"[INT4] Unified API 导出成功!")
190
+ verify_model_precision(quantized_output)
191
+ return
192
+ except Exception as e:
193
+ LOGGER.warning(f"[INT4] Unified API 失败: {e}")
194
+
195
+ # Fallback 1: 尝试 Bnb4Quantizer
196
+ Bnb4Quantizer = None
197
+ try:
198
+ from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
199
+ Bnb4Quantizer = MatMulBnb4Quantizer
200
+ except: pass
201
+
202
+ if Bnb4Quantizer:
203
+ LOGGER.info("[INT4] 使用 MatMulBnb4Quantizer (BnB 4-bit) 执行量化...")
204
+ try:
205
+ import inspect
206
+ model = onnx.load(str(model_source))
207
+ sig = inspect.signature(Bnb4Quantizer.__init__)
208
+ params = sig.parameters
209
+ kwargs = {"model": model, "block_size": 32}
210
+ if "quant_type" in params:
211
+ kwargs["quant_type"] = 1
212
+
213
+ q = Bnb4Quantizer(**kwargs)
214
+ q.process()
215
+ m = q.model.model if hasattr(q, 'model') and hasattr(q.model, 'model') else (q.model if hasattr(q, 'model') else q)
216
+ LOGGER.info(f"正在保存单体 INT4 模型...")
217
+ onnx.save(m, str(quantized_output))
218
+ LOGGER.info(f"[INT4] BnB 导出成功!")
219
+ verify_model_precision(quantized_output)
220
+ return
221
+ except Exception as e:
222
+ LOGGER.warning(f"[INT4] BnB 失败: {e}")
223
+
224
+ # Fallback 2: 暴力搜索兼容旧版 Class API
225
+ def search_legacy(pkg):
226
+ for attr in dir(pkg):
227
+ low = attr.lower()
228
+ if ("quantiz" in low or "quantis" in low) and ("4bit" in low or "nbit" in low or "weight" in low):
229
+ return getattr(pkg, attr), pkg.__name__
230
+ if hasattr(pkg, "__path__"):
231
+ for _, modname, _ in pkgutil.walk_packages(pkg.__path__, pkg.__name__ + "."):
232
+ try:
233
+ mod = importlib.import_module(modname)
234
+ res, src = search_legacy(mod)
235
+ if res: return res, src
236
+ except: continue
237
+ return None, None
238
+
239
+ LegacyQuantizer, src = search_legacy(oq)
240
+ if LegacyQuantizer and (isinstance(LegacyQuantizer, type) or callable(LegacyQuantizer)):
241
+ LOGGER.info(f"[INT4] 找到量化类/函数: {src}.{LegacyQuantizer.__name__}")
242
+ try:
243
+ model = onnx.load(str(model_source))
244
+ if "Quantizer" in LegacyQuantizer.__name__:
245
+ q = LegacyQuantizer(model, block_size=32, is_symmetric=True)
246
+ q.process()
247
+ m = q.model.model if hasattr(q, 'model') and hasattr(q.model, 'model') else (q.model if hasattr(q, 'model') else q)
248
+ onnx.save(m, str(quantized_output))
249
+ else:
250
+ LegacyQuantizer(model, str(quantized_output), weight_type=QuantType.QInt4)
251
+ LOGGER.info(f"[INT4] 导出成功!")
252
+ verify_model_precision(quantized_output)
253
+ except Exception as e:
254
+ LOGGER.error(f"[INT4] 执行失败: {e}")
255
+ else:
256
+ LOGGER.error(f"[INT4] 未找到有效量化器。所有属性: {dir(oq)}")
257
+
258
+ if __name__ == "__main__":
259
+ main()
webui.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import gradio as gr
3
+ import onnxruntime as ort
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import tempfile
8
+ import os
9
+ import time
10
+ import logging
11
+ import argparse
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ sys.path.append(os.path.join(os.getcwd(), 'src'))
16
+
17
+ from sharp.utils import io
18
+ from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians
19
+
20
+ warnings.filterwarnings("ignore")
21
+ logging.basicConfig(level=logging.INFO)
22
+ LOGGER = logging.getLogger(__name__)
23
+
24
+ SESSION = None
25
+ CURRENT_MODEL_PATH = None
26
+
27
+ def load_model(model_path):
28
+ global SESSION, CURRENT_MODEL_PATH
29
+ if SESSION is not None and CURRENT_MODEL_PATH == model_path:
30
+ return SESSION
31
+
32
+ if not Path(model_path).exists():
33
+ LOGGER.error(f"Model file not found: {model_path}")
34
+ return None
35
+
36
+ try:
37
+ LOGGER.info(f"Loading model: {model_path}...")
38
+ options = ort.SessionOptions()
39
+ options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
40
+ options.intra_op_num_threads = os.cpu_count() or 4
41
+ options.inter_op_num_threads = min(4, os.cpu_count() or 4)
42
+ options.enable_mem_pattern = True
43
+ options.enable_cpu_mem_arena = True
44
+ providers = ['CPUExecutionProvider']
45
+
46
+ # Free old session memory if possible
47
+ if SESSION is not None:
48
+ del SESSION
49
+
50
+ SESSION = ort.InferenceSession(model_path, sess_options=options, providers=providers)
51
+ CURRENT_MODEL_PATH = model_path
52
+ LOGGER.info(f"Model loaded successfully: {model_path}")
53
+ return SESSION
54
+ except Exception as e:
55
+ LOGGER.error(f"Failed to load model {model_path}: {e}")
56
+ return None
57
+
58
+ def get_available_models():
59
+ models = list(Path('.').glob('*.onnx'))
60
+ return [str(m) for m in models]
61
+
62
+ def process_image(image_filepath, opacity_threshold, downsample_rate):
63
+ if not image_filepath:
64
+ return None
65
+ if SESSION is None:
66
+ gr.Warning("Model not loaded. Using dummy processing or check console.")
67
+ return None
68
+
69
+ start_time = time.perf_counter()
70
+ img, _, f_px = io.load_rgb(Path(image_filepath), auto_rotate=True, remove_alpha=True)
71
+ height, width = img.shape[:2]
72
+ image_pt = torch.from_numpy(img.copy()).float().permute(2, 0, 1) / 255.0
73
+ disparity_factor = torch.tensor([f_px / width]).float()
74
+ internal_shape = (1536, 1536)
75
+ image_resized_pt = F.interpolate(
76
+ image_pt[None], size=(internal_shape[1], internal_shape[0]), mode="bilinear", align_corners=True
77
+ )
78
+ model_inputs = SESSION.get_inputs()
79
+ if model_inputs[0].type == 'tensor(float16)':
80
+ image_resized_pt = image_resized_pt.half()
81
+ disparity_factor = disparity_factor.half()
82
+ inputs = {'image': image_resized_pt.numpy(), 'disparity_factor': disparity_factor.numpy()}
83
+ outputs = SESSION.run(None, inputs)
84
+ gaussians_ndc = Gaussians3D(
85
+ mean_vectors=torch.from_numpy(outputs[0]).float(),
86
+ singular_values=torch.from_numpy(outputs[1]).float(),
87
+ quaternions=torch.from_numpy(outputs[2]).float(),
88
+ colors=torch.from_numpy(outputs[3]).float(),
89
+ opacities=torch.from_numpy(outputs[4]).float()
90
+ )
91
+ mask = gaussians_ndc.opacities[0] > opacity_threshold
92
+ sampler = slice(0, None, int(downsample_rate))
93
+ def apply_mask_and_sampling(tensor):
94
+ return tensor[:, mask][:, sampler]
95
+ filtered_gaussians_ndc = Gaussians3D(
96
+ mean_vectors=apply_mask_and_sampling(gaussians_ndc.mean_vectors),
97
+ singular_values=apply_mask_and_sampling(gaussians_ndc.singular_values),
98
+ quaternions=apply_mask_and_sampling(gaussians_ndc.quaternions),
99
+ colors=apply_mask_and_sampling(gaussians_ndc.colors),
100
+ opacities=apply_mask_and_sampling(gaussians_ndc.opacities)
101
+ )
102
+ intrinsics = torch.tensor([
103
+ [f_px, 0, width / 2, 0],
104
+ [0, f_px, height / 2, 0],
105
+ [0, 0, 1, 0],
106
+ [0, 0, 0, 1],
107
+ ]).float()
108
+ intrinsics_resized = intrinsics.clone()
109
+ intrinsics_resized[0] *= internal_shape[0] / width
110
+ intrinsics_resized[1] *= internal_shape[1] / height
111
+ gaussians = unproject_gaussians(filtered_gaussians_ndc, torch.eye(4), intrinsics_resized, internal_shape)
112
+ out_dir = Path(tempfile.mkdtemp())
113
+ out_path = out_dir / "output.ply"
114
+ save_ply(gaussians, f_px, (height, width), out_path)
115
+ return str(out_path)
116
+
117
+ custom_css = """
118
+ body, .gradio-container {
119
+ background: radial-gradient(circle at top left, #0d0d12 0%, #000000 100%) !important;
120
+ color: #e0e0e0 !important;
121
+ font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
122
+ margin: 0 !important;
123
+ padding: 0 !important;
124
+ }
125
+ .panel-box {
126
+ background: rgba(20, 20, 25, 0.8) !important;
127
+ backdrop-filter: blur(10px);
128
+ border: 1px solid rgba(255, 255, 255, 0.1) !important;
129
+ border-radius: 20px !important;
130
+ padding: 24px;
131
+ box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.8);
132
+ transition: all 0.3s ease;
133
+ margin-bottom: 16px;
134
+ }
135
+ #spark-container {
136
+ width: 100%;
137
+ height: 70vh; /* Responsive height */
138
+ min-height: 400px;
139
+ max-height: 720px;
140
+ background: #000;
141
+ border-radius: 12px;
142
+ border: 1px solid rgba(255, 255, 255, 0.1);
143
+ position: relative;
144
+ overflow: hidden;
145
+ }
146
+ #generate-btn {
147
+ background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%) !important;
148
+ color: white !important;
149
+ font-weight: 700 !important;
150
+ border-radius: 12px !important;
151
+ border: none !important;
152
+ margin-top: 10px;
153
+ padding: 16px 24px !important; /* Larger for touch */
154
+ text-transform: uppercase;
155
+ letter-spacing: 1px;
156
+ font-size: 1.1rem !important;
157
+ transition: transform 0.2s, box-shadow 0.2s !important;
158
+ }
159
+ header h1 {
160
+ background: linear-gradient(to right, #fff, #a5a5a5);
161
+ -webkit-background-clip: text;
162
+ -webkit-text-fill-color: transparent;
163
+ font-size: 2rem !important;
164
+ font-weight: 900 !important;
165
+ text-align: center;
166
+ margin: 20px 0 !important;
167
+ }
168
+
169
+ /* Mobile Optimizations */
170
+ @media (max-width: 768px) {
171
+ .panel-box {
172
+ padding: 16px;
173
+ border-radius: 16px !important;
174
+ }
175
+ #spark-container {
176
+ height: 50vh; /* Shorter on mobile to leave room for controls */
177
+ min-height: 300px;
178
+ }
179
+ header h1 {
180
+ font-size: 1.5rem !important;
181
+ }
182
+ .gr-row {
183
+ flex-direction: column !important;
184
+ }
185
+ /* Make inputs full width on mobile */
186
+ .gr-form {
187
+ width: 100% !important;
188
+ }
189
+ }
190
+ """
191
+
192
+ head_content = """
193
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
194
+ <link rel="preconnect" href="https://fonts.googleapis.com">
195
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
196
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;700;900&display=swap" rel="stylesheet">
197
+ <script type="importmap">
198
+ {
199
+ "imports": {
200
+ "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
201
+ "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.10/spark.module.js"
202
+ }
203
+ }
204
+ </script>
205
+ <script type="module">
206
+ import * as THREE from "three";
207
+ import { OrbitControls } from "https://unpkg.com/three@0.178.0/examples/jsm/controls/OrbitControls.js";
208
+ import { SplatMesh } from "@sparkjsdev/spark";
209
+
210
+ let renderer, scene, camera, controls, splat, container;
211
+ let startTime, timerInterval;
212
+
213
+ window.initSpark = function() {
214
+ container = document.getElementById('spark-container');
215
+ if (!container || window.sparkInitialized) return;
216
+ scene = new THREE.Scene();
217
+ camera = new THREE.PerspectiveCamera(60, container.clientWidth / container.clientHeight, 0.1, 1000);
218
+ camera.position.set(0, 1, 4);
219
+ renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true, logarithmicDepthBuffer: true });
220
+ renderer.setSize(container.clientWidth, container.clientHeight);
221
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
222
+ container.appendChild(renderer.domElement);
223
+ controls = new OrbitControls(camera, renderer.domElement);
224
+ controls.enableDamping = true;
225
+ function animate() {
226
+ requestAnimationFrame(animate);
227
+ controls.update();
228
+ renderer.render(scene, camera);
229
+ }
230
+ animate();
231
+ window.addEventListener('resize', () => {
232
+ if (!container) return;
233
+ camera.aspect = container.clientWidth / container.clientHeight;
234
+ camera.updateProjectionMatrix();
235
+ renderer.setSize(container.clientWidth, container.clientHeight);
236
+ });
237
+ window.sparkInitialized = true;
238
+ };
239
+
240
+ window.loadSplat = async function(url) {
241
+ if (!window.sparkInitialized) window.initSpark();
242
+ if (splat) { scene.remove(splat); splat.dispose(); }
243
+ try {
244
+ splat = new SplatMesh({ url: url });
245
+ splat.rotation.x = Math.PI;
246
+ scene.add(splat);
247
+ setTimeout(window.focusModel, 500);
248
+ } catch (e) { console.error(e); }
249
+ };
250
+
251
+ window.focusModel = function() {
252
+ if (!splat || !controls || !camera) return;
253
+ const box = new THREE.Box3();
254
+ let pointsFound = 0;
255
+ splat.traverse((obj) => {
256
+ if (obj.geometry && obj.geometry.attributes.position) {
257
+ const pos = obj.geometry.attributes.position;
258
+ const count = pos.count;
259
+ const step = Math.max(1, Math.floor(count / 5000));
260
+ for (let i = 0; i < count; i += step) {
261
+ const p = new THREE.Vector3(pos.getX(i), pos.getY(i), pos.getZ(i));
262
+ p.applyMatrix4(obj.matrixWorld);
263
+ box.expandByPoint(p);
264
+ }
265
+ pointsFound += count;
266
+ }
267
+ });
268
+ let center = new THREE.Vector3();
269
+ let size = new THREE.Vector3();
270
+ if (pointsFound === 0 || box.isEmpty()) { center.set(0, 1.5, -3); size.set(2, 2, 2); }
271
+ else { box.getCenter(center); box.getSize(size); }
272
+ const maxDim = Math.max(size.x, size.y, size.z);
273
+ const fovRad = camera.fov * (Math.PI / 180);
274
+ let distance = (maxDim / 2) / Math.tan(fovRad / 2) * 1.5;
275
+ controls.target.copy(center);
276
+ camera.position.set(center.x, center.y, center.z + distance);
277
+ controls.update();
278
+ };
279
+
280
+ function getBtn() {
281
+ return document.getElementById('generate-btn') || document.querySelector('#generate-btn button');
282
+ }
283
+
284
+ window.startTimer = function() {
285
+ const btn = getBtn();
286
+ if (!btn) return;
287
+ btn.disabled = true;
288
+ btn.style.opacity = "0.6";
289
+ btn.style.cursor = "wait";
290
+ startTime = Date.now();
291
+ timerInterval = setInterval(() => {
292
+ const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
293
+ btn.innerText = `Generating... ${elapsed}s`;
294
+ }, 100);
295
+ };
296
+
297
+ window.stopTimer = function() {
298
+ if (timerInterval) {
299
+ clearInterval(timerInterval);
300
+ const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
301
+ const btn = getBtn();
302
+ if (btn) btn.innerText = `Done in ${elapsed}s`;
303
+ }
304
+ };
305
+
306
+ window.resetBtn = function(hasImage) {
307
+ const btn = getBtn();
308
+ if (btn) {
309
+ btn.disabled = !hasImage;
310
+ btn.style.opacity = hasImage ? "1.0" : "0.5";
311
+ btn.style.cursor = hasImage ? "pointer" : "default";
312
+ btn.innerText = "Generate 3D Gaussians";
313
+ }
314
+ if (timerInterval) clearInterval(timerInterval);
315
+ };
316
+ </script>
317
+ """
318
+
319
+ if __name__ == "__main__":
320
+ parser = argparse.ArgumentParser()
321
+ parser.add_argument("--model", type=str, default="ml-sharp_int4.onnx")
322
+ parser.add_argument("--host", type=str, default="0.0.0.0")
323
+ parser.add_argument("--port", type=int, default=7860)
324
+ parser.add_argument("--ssl_cert", type=str, default="cert.pem")
325
+ parser.add_argument("--ssl_key", type=str, default="key.pem")
326
+ args = parser.parse_args()
327
+
328
+ # Pre-load initial model
329
+ load_model(args.model)
330
+
331
+ # Define UI inside __main__ or pass args to it
332
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default(), head=head_content, title="SHARP 3D Recon") as demo:
333
+ gr.HTML("<header><h1>SHARP 3D RECONSTRUCTION</h1></header>")
334
+ with gr.Row():
335
+ with gr.Column(scale=1):
336
+ with gr.Group(elem_classes="panel-box"):
337
+ available_models = get_available_models()
338
+ model_selector = gr.Dropdown(
339
+ choices=available_models,
340
+ value=args.model if args.model in available_models else (available_models[0] if available_models else None),
341
+ label="Select ONNX Model (Precision)",
342
+ interactive=True
343
+ )
344
+
345
+ input_image = gr.Image(
346
+ type="filepath",
347
+ label="Capture or Upload Image",
348
+ height=400,
349
+ sources=["upload", "webcam"]
350
+ )
351
+
352
+ with gr.Accordion("Advanced Settings", open=False):
353
+ opacity_val = gr.Slider(0.0, 1.0, value=0.0, label="Opacity Threshold")
354
+ downsample_val = gr.Slider(1, 10, step=1, value=1, label="Downsample Rate")
355
+
356
+ submit_btn = gr.Button("Generate 3D Gaussians", variant="primary", elem_id="generate-btn", interactive=False)
357
+
358
+ gr.Markdown("Capture a photo from your phone or upload an image to start the real-time 3D conversion.")
359
+
360
+ with gr.Column(scale=2):
361
+ with gr.Group(elem_classes="panel-box"):
362
+ gr.HTML("<div id='spark-container'></div>")
363
+ output_file = gr.File(label="Output Model", visible=False)
364
+
365
+ demo.load(fn=None, inputs=None, outputs=None, js="() => { setTimeout(window.initSpark, 500); window.resetBtn(false); }")
366
+
367
+ model_selector.change(fn=load_model, inputs=[model_selector], outputs=None)
368
+
369
+ input_image.change(fn=lambda x: x is not None, inputs=[input_image], outputs=None, js="(img) => { window.resetBtn(!!img); }")
370
+
371
+ submit_btn.click(
372
+ fn=process_image,
373
+ inputs=[input_image, opacity_val, downsample_val],
374
+ outputs=[output_file],
375
+ js="(img, op, down) => { window.startTimer(); return [img, op, down]; }"
376
+ )
377
+
378
+ input_image.upload(
379
+ fn=process_image,
380
+ inputs=[input_image, opacity_val, downsample_val],
381
+ outputs=[output_file],
382
+ js="(img, op, down) => { window.startTimer(); return [img, op, down]; }"
383
+ )
384
+
385
+ output_file.change(
386
+ fn=None,
387
+ inputs=[output_file],
388
+ js="(f) => { window.stopTimer(); if (f && f.url) { window.loadSplat(f.url); } }"
389
+ )
390
+
391
+ # Launch with HTTPS support
392
+ demo.queue().launch(
393
+ server_name=args.host,
394
+ server_port=args.port,
395
+ share=False,
396
+ ssl_certfile=args.ssl_cert if os.path.exists(args.ssl_cert) else None,
397
+ ssl_keyfile=args.ssl_key if os.path.exists(args.ssl_key) else None,
398
+ ssl_verify=False
399
+ )