更新代码
Browse files
LICENSE
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple
|
| 4 |
+
Inc. ("Apple") in consideration of your agreement to the following
|
| 5 |
+
terms, and your use, installation, modification or redistribution of
|
| 6 |
+
this Apple software constitutes acceptance of these terms. If you do
|
| 7 |
+
not agree with these terms, please do not use, install, modify or
|
| 8 |
+
redistribute this Apple software.
|
| 9 |
+
|
| 10 |
+
In consideration of your agreement to abide by the following terms, and
|
| 11 |
+
subject to these terms, Apple grants you a personal, non-exclusive
|
| 12 |
+
license, under Apple's copyrights in this original Apple software (the
|
| 13 |
+
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
| 14 |
+
Software, with or without modifications, in source and/or binary forms;
|
| 15 |
+
provided that if you redistribute the Apple Software in its entirety and
|
| 16 |
+
without modifications, you must retain this notice and the following
|
| 17 |
+
text and disclaimers in all such redistributions of the Apple Software.
|
| 18 |
+
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
| 19 |
+
be used to endorse or promote products derived from the Apple Software
|
| 20 |
+
without specific prior written permission from Apple. Except as
|
| 21 |
+
expressly stated in this notice, no other rights or licenses, express or
|
| 22 |
+
implied, are granted by Apple herein, including but not limited to any
|
| 23 |
+
patent rights that may be infringed by your derivative works or by other
|
| 24 |
+
works in which the Apple Software may be incorporated.
|
| 25 |
+
|
| 26 |
+
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
| 27 |
+
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
| 28 |
+
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
| 29 |
+
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
| 30 |
+
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
| 31 |
+
|
| 32 |
+
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
| 33 |
+
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 34 |
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 35 |
+
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
| 36 |
+
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
| 37 |
+
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
| 38 |
+
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
| 39 |
+
POSSIBILITY OF SUCH DAMAGE.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
-------------------------------------------------------------------------------
|
| 43 |
+
SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
|
| 44 |
+
|
| 45 |
+
This software includes a number of subcomponents with separate
|
| 46 |
+
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
|
| 47 |
+
-------------------------------------------------------------------------------
|
README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: onnxruntime
|
| 3 |
+
tags:
|
| 4 |
+
- onnx
|
| 5 |
+
- sharp
|
| 6 |
+
- view-synthesis
|
| 7 |
+
- 3d-gaussian-splatting
|
| 8 |
+
base_model: apple/Sharp
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# SHARP — ONNX Exports (INT8)
|
| 12 |
+
|
| 13 |
+
This repository provides ONNX exports of Apple’s **SHARP** model predictor, which outputs **NDC-space 3D Gaussians**.
|
| 14 |
+
|
| 15 |
+
## Export to ONNX (INT8)
|
| 16 |
+
|
| 17 |
+
Run the following command to export the model:
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
python export_onnx.py --int8
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## Launch Web UI
|
| 24 |
+
|
| 25 |
+
Start the web interface with:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
python webui.py \
|
| 29 |
+
--host 0.0.0.0 \
|
| 30 |
+
--port 7860 \
|
| 31 |
+
--ssl_cert cert.pem \
|
| 32 |
+
--ssl_key key.pem
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Notes
|
| 36 |
+
|
| 37 |
+
This project is for learning purposes. Please refer to the LICENSE for usage terms.
|
export_onnx.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import warnings
|
| 7 |
+
import torch
|
| 8 |
+
from sharp.models import PredictorParams, create_predictor
|
| 9 |
+
|
| 10 |
+
LOGGER = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Default checkpoint file name
|
| 13 |
+
DEFAULT_CHECKPOINT_NAME = "sharp_2572gikvuh.pt"
|
| 14 |
+
MODEL_URL = f"https://ml-site.cdn-apple.com/models/sharp/{DEFAULT_CHECKPOINT_NAME}"
|
| 15 |
+
|
| 16 |
+
def verify_model_precision(model_path: Path):
    """Log the data-type distribution of an ONNX model's weights.

    Loads the model at ``model_path``, counts its graph initializers per
    ONNX data-type name and logs one line per type, followed by the path
    and on-disk size of the ``.onnx`` file itself. This is a static check
    used to confirm that quantization actually changed the stored weights.

    The check is best-effort: a missing ``onnx`` package or any other
    failure is logged as a warning and never propagated to the caller.

    Args:
        model_path: Location of the ``.onnx`` file to inspect.
    """
    try:
        from collections import Counter

        import onnx

        # Large models may keep their tensors in a sibling ".data" file.
        model = onnx.load(str(model_path))

        # Histogram of initializer (weight) data types. ONNX enum values:
        # FLOAT=1, UINT8=2, INT8=3, FLOAT16=10, UINT4=27, INT4=28, FLOAT4=29.
        name_of = onnx.TensorProto.DataType.Name
        histogram = Counter(
            name_of(initializer.data_type) for initializer in model.graph.initializer
        )

        LOGGER.info(f"--- 权重类型分布报告 [{model_path.name}] ---")
        for type_name, tensor_count in histogram.items():
            LOGGER.info(f" - {type_name}: {tensor_count} 个张量")

        # Size of the .onnx file only; external data files are not included.
        size_in_bytes = os.path.getsize(model_path)
        LOGGER.info(f" [文件信息] 路径: {model_path}")
        LOGGER.info(f" [文件信息] 体积: {size_in_bytes / (1024*1024):.2f} MB")
        LOGGER.info("------------------------------------------------")
    except ImportError:
        LOGGER.warning("未安装 'onnx' 模块,跳过权重精度验证。")
    except Exception as error:
        LOGGER.warning(f"验证模型精度时发生异常: {error}")
|
| 43 |
+
|
| 44 |
+
def export_onnx(output_path: Path, checkpoint_path: str = None, internal_shape: int = 1536, is_fp16: bool = False):
    """Export the SHARP predictor to an ONNX file on CPU.

    Weights are taken from ``checkpoint_path`` when it is given. Otherwise
    the default checkpoint file is looked up in the current directory, in
    ``data/`` and in the torch hub checkpoint directory (in that order),
    and downloaded from ``MODEL_URL`` if none of those exist.

    Args:
        output_path: Destination of the exported ``.onnx`` file.
        checkpoint_path: Optional path to a local ``.pt`` state dict.
        internal_shape: Side length of the square dummy image used for
            tracing; the dummy input has shape (1, 3, internal_shape,
            internal_shape).
        is_fp16: When true, the model and dummy inputs are converted to
            half precision before export.
    """
    device = torch.device("cpu")

    LOGGER.info("正在创建预测器模型...")
    predictor = create_predictor(PredictorParams())

    # FP16: convert the parameters to half precision, halving the file size.
    if is_fp16:
        LOGGER.info("正在将模型转换为 FP16 (半精度) 以直接将其缩减进单一 < 2GB 的文件中...")
        predictor = predictor.half()

    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    else:
        candidates = (
            Path(DEFAULT_CHECKPOINT_NAME),
            Path("data") / DEFAULT_CHECKPOINT_NAME,
            Path(torch.hub.get_dir()) / "checkpoints" / DEFAULT_CHECKPOINT_NAME,
        )
        # First candidate that exists on disk, or None when there is none.
        found_path = next((path for path in candidates if path.exists()), None)

        if found_path is not None:
            LOGGER.info(f"在本地找到模型权重: {found_path},加载本地权重...")
            state_dict = torch.load(found_path, map_location="cpu", weights_only=True)
        else:
            state_dict = torch.hub.load_state_dict_from_url(MODEL_URL, progress=True, map_location="cpu")

    predictor.load_state_dict(state_dict)
    predictor.eval()
    predictor.to(device)

    dummy_image = torch.randn(1, 3, internal_shape, internal_shape, device=device)
    dummy_disparity = torch.tensor([1.0], device=device)

    # The dummy inputs must match the dtype of the (possibly halved) model.
    if is_fp16:
        dummy_image = dummy_image.half()
        dummy_disparity = dummy_disparity.half()

    input_names = ['image', 'disparity_factor']
    output_names = ['mean_vectors', 'singular_values', 'quaternions', 'colors', 'opacities']

    LOGGER.info("正在导出为 ONNX 格式...")
    torch.onnx.export(
        predictor,
        (dummy_image, dummy_disparity),
        str(output_path),
        export_params=True,
        opset_version=19,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        keep_initializers_as_inputs=False,
        # Axis 0 of every input and output is exported as a dynamic batch axis.
        dynamic_axes={name: {0: 'batch_size'} for name in input_names + output_names},
    )

    # Report the precision that was actually exported; the message used to
    # claim "full precision" even for FP16 exports.
    precision_label = "FP16 半精度" if is_fp16 else "全精度"
    LOGGER.info(f"成功将{precision_label} ONNX 模型导出至: {output_path}")
|
| 104 |
+
|
| 105 |
+
def _unwrap_onnx_model(quantizer):
    """Return the innermost ``.model`` attribute of a quantizer, at most two levels deep."""
    holder = getattr(quantizer, "model", quantizer)
    return getattr(holder, "model", holder)


def _find_legacy_quantizer(pkg):
    """Search ``pkg`` and its submodules for a callable 4-bit/weight quantizer.

    An attribute matches when its lower-cased name contains "quantiz" or
    "quantis" together with "4bit", "nbit" or "weight". Non-callable
    matches (for example a submodule that merely carries such a name) are
    skipped so the search can continue into the module itself.

    Returns:
        ``(callable, defining_module_name)``, or ``(None, None)`` when
        nothing suitable was found.
    """
    import importlib
    import pkgutil

    for attr in dir(pkg):
        low = attr.lower()
        if ("quantiz" in low or "quantis" in low) and ("4bit" in low or "nbit" in low or "weight" in low):
            candidate = getattr(pkg, attr, None)
            if callable(candidate):
                return candidate, pkg.__name__

    if hasattr(pkg, "__path__"):
        # onerror swallows failures raised while walking into subpackages,
        # matching the best-effort nature of this search.
        for _, modname, _ in pkgutil.walk_packages(pkg.__path__, pkg.__name__ + ".", onerror=lambda _name: None):
            try:
                mod = importlib.import_module(modname)
                found, origin = _find_legacy_quantizer(mod)
            except Exception:
                # A submodule that cannot be imported or inspected is skipped.
                continue
            if found is not None:
                return found, origin
    return None, None


def _quantize_int4(model_source, quantized_output):
    """Write an INT4 weight-only quantized copy of ``model_source``.

    Three strategies are tried in order and the first success wins:
    the unified ``quantize`` + ``MatMulWeightOnlyQuantConfig`` API, the
    ``MatMulBnb4Quantizer`` class, and finally any quantizer discovered by
    name inside ``onnxruntime.quantization``. Failures are logged and are
    never raised to the caller.

    Args:
        model_source: Path of the ONNX model to quantize.
        quantized_output: Path the quantized model is written to.
    """
    import inspect

    import onnx
    import onnxruntime.quantization as oq
    from onnxruntime.quantization import QuantType

    LOGGER.info("正在准备进行真正的 INT4 权重压缩 (WOQ)...")

    # Strategy 1: the newer unified API, when this onnxruntime exposes it.
    quantize_func = getattr(oq, "quantize", None)
    WeightOnlyConfig = getattr(oq, "MatMulWeightOnlyQuantConfig", None)
    if quantize_func and WeightOnlyConfig:
        LOGGER.info("[INT4] 使用 MatMulWeightOnlyQuantConfig 执行量化...")
        try:
            model = onnx.load(str(model_source))
            # Falls back to QUInt8 when this QuantType has no QInt4 member.
            config = WeightOnlyConfig(weight_type=getattr(QuantType, 'QInt4', QuantType.QUInt8))
            quantize_func(model, str(quantized_output), config)
            LOGGER.info("[INT4] Unified API 导出成功!")
            verify_model_precision(quantized_output)
            return
        except Exception as e:
            LOGGER.warning(f"[INT4] Unified API 失败: {e}")

    # Strategy 2: the BnB 4-bit quantizer, when the module can be imported.
    try:
        from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer as Bnb4Quantizer
    except Exception:
        Bnb4Quantizer = None

    if Bnb4Quantizer:
        LOGGER.info("[INT4] 使用 MatMulBnb4Quantizer (BnB 4-bit) 执行量化...")
        try:
            model = onnx.load(str(model_source))
            kwargs = {"model": model, "block_size": 32}
            # Only pass quant_type to constructors that accept it.
            if "quant_type" in inspect.signature(Bnb4Quantizer.__init__).parameters:
                kwargs["quant_type"] = 1

            q = Bnb4Quantizer(**kwargs)
            q.process()
            LOGGER.info("正在保存单体 INT4 模型...")
            onnx.save(_unwrap_onnx_model(q), str(quantized_output))
            LOGGER.info("[INT4] BnB 导出成功!")
            verify_model_precision(quantized_output)
            return
        except Exception as e:
            LOGGER.warning(f"[INT4] BnB 失败: {e}")

    # Strategy 3: brute-force search for an older class/function API.
    LegacyQuantizer, src = _find_legacy_quantizer(oq)
    if LegacyQuantizer is None:
        LOGGER.error(f"[INT4] 未找到有效量化器。所有属性: {dir(oq)}")
        return

    legacy_name = getattr(LegacyQuantizer, "__name__", repr(LegacyQuantizer))
    LOGGER.info(f"[INT4] 找到量化类/函数: {src}.{legacy_name}")
    try:
        model = onnx.load(str(model_source))
        if "Quantizer" in legacy_name:
            q = LegacyQuantizer(model, block_size=32, is_symmetric=True)
            q.process()
            onnx.save(_unwrap_onnx_model(q), str(quantized_output))
        else:
            LegacyQuantizer(model, str(quantized_output), weight_type=QuantType.QInt4)
        LOGGER.info("[INT4] 导出成功!")
        verify_model_precision(quantized_output)
    except Exception as e:
        LOGGER.error(f"[INT4] 执行失败: {e}")


def main():
    """Command-line entry point: export the model and optionally quantize it.

    The base model is exported unless quantization was requested and the
    output file already exists, in which case the existing file is reused.
    With ``--int8`` and/or ``--int4`` the base model is first re-saved in
    external-data format (``<stem>_pre.onnx``) and then quantized into
    ``<stem>_int8.onnx`` / ``<stem>_int4.onnx`` next to the output file.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    parser = argparse.ArgumentParser(description="将 ML-SHARP 模型导出为多种精度的 ONNX 格式。")
    parser.add_argument("--output", type=str, default="sharp.onnx", help="导出的基础 ONNX 模型保存路径。")
    parser.add_argument("--checkpoint", type=str, default=None, help="本地 .pt 权重文件路径。")
    parser.add_argument("--fp16", action="store_true", help="启用 FP16 半精度导出。")
    parser.add_argument("--int8", action="store_true", help="应用动态 INT8 量化。")
    parser.add_argument("--int4", action="store_true", help="应用真正的 INT4 重量级量化 (WOQ)。")
    args = parser.parse_args()

    # In FP16 mode with the default output name, switch to a dedicated file name.
    if args.fp16 and args.output == "sharp.onnx":
        args.output = "sharp_fp16.onnx"

    output_path = Path(args.output)
    is_quant_requested = args.int8 or args.int4

    if is_quant_requested and output_path.exists():
        LOGGER.info(f"检测到基础模型 '{output_path}' 已存在,跳过导出并开始验证。")
    else:
        export_onnx(output_path, args.checkpoint, is_fp16=args.fp16)
    verify_model_precision(output_path)

    if not is_quant_requested:
        return

    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
        import onnx
    except ImportError:
        LOGGER.error("缺失依赖项:请安装 onnx 和 onnxruntime。")
        return

    # Re-save the base model with its tensors in one external data file;
    # fall back to the base model itself when that conversion fails.
    preprocessed_path = output_path.with_name(f"{output_path.stem}_pre.onnx")
    model_source = preprocessed_path
    if not preprocessed_path.exists():
        LOGGER.info("正在将基础模型转换为外部数据格式 (已跳过导致崩溃的形状推导模式)...")
        try:
            onnx.save(
                onnx.load(str(output_path)),
                str(preprocessed_path),
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=preprocessed_path.name + ".data",
                size_threshold=1024,
                convert_attribute=True,
            )
        except Exception as e:
            LOGGER.warning(f"预分流模型保存失败: {e}")
            model_source = output_path

    # 1. Dynamic INT8 quantization.
    if args.int8:
        int8_path = output_path.with_name(f"{output_path.stem}_int8.onnx")
        LOGGER.info("正在应用动态 INT8 量化...")
        quantize_dynamic(str(model_source), str(int8_path), weight_type=QuantType.QUInt8)
        verify_model_precision(int8_path)

    # 2. INT4 weight-only quantization (WOQ).
    if args.int4:
        _quantize_int4(model_source, output_path.with_name(f"{output_path.stem}_int4.onnx"))
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
|
| 259 |
+
main()
|
webui.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import onnxruntime as ort
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tempfile
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
import logging
|
| 11 |
+
import argparse
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
sys.path.append(os.path.join(os.getcwd(), 'src'))
|
| 16 |
+
|
| 17 |
+
from sharp.utils import io
|
| 18 |
+
from sharp.utils.gaussians import Gaussians3D, save_ply, unproject_gaussians
|
| 19 |
+
|
| 20 |
+
warnings.filterwarnings("ignore")
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
LOGGER = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
SESSION = None
|
| 25 |
+
CURRENT_MODEL_PATH = None
|
| 26 |
+
|
| 27 |
+
def load_model(model_path):
    """Load an ONNX model into the module-level inference session.

    The session is cached in the globals ``SESSION`` / ``CURRENT_MODEL_PATH``
    and reused when the same path is requested again.

    Args:
        model_path: Path of the ``.onnx`` file to load.

    Returns:
        The ``onnxruntime.InferenceSession`` for ``model_path``, or ``None``
        when the file does not exist or the session could not be created.
        After a failed creation no session is left loaded.
    """
    global SESSION, CURRENT_MODEL_PATH
    if SESSION is not None and CURRENT_MODEL_PATH == model_path:
        return SESSION

    if not model_path or not Path(model_path).exists():
        LOGGER.error(f"Model file not found: {model_path}")
        return None

    try:
        LOGGER.info(f"Loading model: {model_path}...")
        cpu_count = os.cpu_count() or 4
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.intra_op_num_threads = cpu_count
        options.inter_op_num_threads = min(4, cpu_count)
        options.enable_mem_pattern = True
        options.enable_cpu_mem_arena = True
        providers = ['CPUExecutionProvider']

        # Release the previous session before creating the new one. Rebind to
        # None instead of `del SESSION`: deleting a global unbinds the module
        # name, so a failure below would make every later `SESSION is None`
        # check raise NameError.
        SESSION = None
        CURRENT_MODEL_PATH = None

        SESSION = ort.InferenceSession(model_path, sess_options=options, providers=providers)
        CURRENT_MODEL_PATH = model_path
        LOGGER.info(f"Model loaded successfully: {model_path}")
        return SESSION
    except Exception as e:
        LOGGER.error(f"Failed to load model {model_path}: {e}")
        return None
|
| 57 |
+
|
| 58 |
+
def get_available_models():
    """Return the paths of all ``*.onnx`` files in the current directory as strings."""
    return list(map(str, Path('.').glob('*.onnx')))
|
| 61 |
+
|
| 62 |
+
def process_image(image_filepath, opacity_threshold, downsample_rate):
    """Run the loaded ONNX model on an image and save the result as a PLY file.

    Args:
        image_filepath: Path of the input image; a falsy value returns ``None``.
        opacity_threshold: Gaussians whose opacity is not above this value
            are dropped.
        downsample_rate: Keep every n-th remaining Gaussian. Values below 1
            are treated as 1 (keep all).

    Returns:
        Path (as ``str``) of ``output.ply`` inside a fresh temporary
        directory, or ``None`` when no image was given or no model is loaded.
    """
    if not image_filepath:
        return None

    # Read the global once so the check and the later calls use the same session.
    session = SESSION
    if session is None:
        gr.Warning("Model not loaded. Using dummy processing or check console.")
        return None

    # NOTE(review): f_px is presumably the focal length in pixels as returned
    # by io.load_rgb — confirm against that helper.
    img, _, f_px = io.load_rgb(Path(image_filepath), auto_rotate=True, remove_alpha=True)
    height, width = img.shape[:2]
    image_pt = torch.from_numpy(img.copy()).float().permute(2, 0, 1) / 255.0
    disparity_factor = torch.tensor([f_px / width]).float()

    internal_shape = (1536, 1536)
    image_resized_pt = F.interpolate(
        image_pt[None], size=(internal_shape[1], internal_shape[0]), mode="bilinear", align_corners=True
    )

    # Feed half-precision tensors to models whose first input is float16.
    if session.get_inputs()[0].type == 'tensor(float16)':
        image_resized_pt = image_resized_pt.half()
        disparity_factor = disparity_factor.half()

    inputs = {'image': image_resized_pt.numpy(), 'disparity_factor': disparity_factor.numpy()}
    outputs = session.run(None, inputs)

    # The first five model outputs map onto the Gaussians3D fields in this order.
    field_names = ("mean_vectors", "singular_values", "quaternions", "colors", "opacities")
    gaussians_ndc = Gaussians3D(
        **{name: torch.from_numpy(array).float() for name, array in zip(field_names, outputs)}
    )

    mask = gaussians_ndc.opacities[0] > opacity_threshold
    # A step of 0 makes slicing raise ValueError and a negative step changes
    # the selection entirely, so clamp the rate to at least 1.
    sampler = slice(0, None, max(1, int(downsample_rate)))

    def apply_mask_and_sampling(tensor):
        # Filter by opacity first, then keep every n-th entry along axis 1.
        return tensor[:, mask][:, sampler]

    filtered_gaussians_ndc = Gaussians3D(
        **{name: apply_mask_and_sampling(getattr(gaussians_ndc, name)) for name in field_names}
    )

    intrinsics = torch.tensor([
        [f_px, 0, width / 2, 0],
        [0, f_px, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]).float()
    # Rescale the first two intrinsics rows from image size to the internal size.
    intrinsics_resized = intrinsics.clone()
    intrinsics_resized[0] *= internal_shape[0] / width
    intrinsics_resized[1] *= internal_shape[1] / height

    gaussians = unproject_gaussians(filtered_gaussians_ndc, torch.eye(4), intrinsics_resized, internal_shape)

    # The temporary directory is not removed here; the returned path points into it.
    out_path = Path(tempfile.mkdtemp()) / "output.ply"
    save_ply(gaussians, f_px, (height, width), out_path)
    return str(out_path)
|
| 116 |
+
|
| 117 |
+
custom_css = """
|
| 118 |
+
body, .gradio-container {
|
| 119 |
+
background: radial-gradient(circle at top left, #0d0d12 0%, #000000 100%) !important;
|
| 120 |
+
color: #e0e0e0 !important;
|
| 121 |
+
font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
|
| 122 |
+
margin: 0 !important;
|
| 123 |
+
padding: 0 !important;
|
| 124 |
+
}
|
| 125 |
+
.panel-box {
|
| 126 |
+
background: rgba(20, 20, 25, 0.8) !important;
|
| 127 |
+
backdrop-filter: blur(10px);
|
| 128 |
+
border: 1px solid rgba(255, 255, 255, 0.1) !important;
|
| 129 |
+
border-radius: 20px !important;
|
| 130 |
+
padding: 24px;
|
| 131 |
+
box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.8);
|
| 132 |
+
transition: all 0.3s ease;
|
| 133 |
+
margin-bottom: 16px;
|
| 134 |
+
}
|
| 135 |
+
#spark-container {
|
| 136 |
+
width: 100%;
|
| 137 |
+
height: 70vh; /* Responsive height */
|
| 138 |
+
min-height: 400px;
|
| 139 |
+
max-height: 720px;
|
| 140 |
+
background: #000;
|
| 141 |
+
border-radius: 12px;
|
| 142 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 143 |
+
position: relative;
|
| 144 |
+
overflow: hidden;
|
| 145 |
+
}
|
| 146 |
+
#generate-btn {
|
| 147 |
+
background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%) !important;
|
| 148 |
+
color: white !important;
|
| 149 |
+
font-weight: 700 !important;
|
| 150 |
+
border-radius: 12px !important;
|
| 151 |
+
border: none !important;
|
| 152 |
+
margin-top: 10px;
|
| 153 |
+
padding: 16px 24px !important; /* Larger for touch */
|
| 154 |
+
text-transform: uppercase;
|
| 155 |
+
letter-spacing: 1px;
|
| 156 |
+
font-size: 1.1rem !important;
|
| 157 |
+
transition: transform 0.2s, box-shadow 0.2s !important;
|
| 158 |
+
}
|
| 159 |
+
header h1 {
|
| 160 |
+
background: linear-gradient(to right, #fff, #a5a5a5);
|
| 161 |
+
-webkit-background-clip: text;
|
| 162 |
+
-webkit-text-fill-color: transparent;
|
| 163 |
+
font-size: 2rem !important;
|
| 164 |
+
font-weight: 900 !important;
|
| 165 |
+
text-align: center;
|
| 166 |
+
margin: 20px 0 !important;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/* Mobile Optimizations */
|
| 170 |
+
@media (max-width: 768px) {
|
| 171 |
+
.panel-box {
|
| 172 |
+
padding: 16px;
|
| 173 |
+
border-radius: 16px !important;
|
| 174 |
+
}
|
| 175 |
+
#spark-container {
|
| 176 |
+
height: 50vh; /* Shorter on mobile to leave room for controls */
|
| 177 |
+
min-height: 300px;
|
| 178 |
+
}
|
| 179 |
+
header h1 {
|
| 180 |
+
font-size: 1.5rem !important;
|
| 181 |
+
}
|
| 182 |
+
.gr-row {
|
| 183 |
+
flex-direction: column !important;
|
| 184 |
+
}
|
| 185 |
+
/* Make inputs full width on mobile */
|
| 186 |
+
.gr-form {
|
| 187 |
+
width: 100% !important;
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
head_content = """
|
| 193 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
| 194 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 195 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 196 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;700;900&display=swap" rel="stylesheet">
|
| 197 |
+
<script type="importmap">
|
| 198 |
+
{
|
| 199 |
+
"imports": {
|
| 200 |
+
"three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
|
| 201 |
+
"@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.10/spark.module.js"
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
</script>
|
| 205 |
+
<script type="module">
|
| 206 |
+
import * as THREE from "three";
|
| 207 |
+
import { OrbitControls } from "https://unpkg.com/three@0.178.0/examples/jsm/controls/OrbitControls.js";
|
| 208 |
+
import { SplatMesh } from "@sparkjsdev/spark";
|
| 209 |
+
|
| 210 |
+
let renderer, scene, camera, controls, splat, container;
|
| 211 |
+
let startTime, timerInterval;
|
| 212 |
+
|
| 213 |
+
window.initSpark = function() {
|
| 214 |
+
container = document.getElementById('spark-container');
|
| 215 |
+
if (!container || window.sparkInitialized) return;
|
| 216 |
+
scene = new THREE.Scene();
|
| 217 |
+
camera = new THREE.PerspectiveCamera(60, container.clientWidth / container.clientHeight, 0.1, 1000);
|
| 218 |
+
camera.position.set(0, 1, 4);
|
| 219 |
+
renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true, logarithmicDepthBuffer: true });
|
| 220 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 221 |
+
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
|
| 222 |
+
container.appendChild(renderer.domElement);
|
| 223 |
+
controls = new OrbitControls(camera, renderer.domElement);
|
| 224 |
+
controls.enableDamping = true;
|
| 225 |
+
function animate() {
|
| 226 |
+
requestAnimationFrame(animate);
|
| 227 |
+
controls.update();
|
| 228 |
+
renderer.render(scene, camera);
|
| 229 |
+
}
|
| 230 |
+
animate();
|
| 231 |
+
window.addEventListener('resize', () => {
|
| 232 |
+
if (!container) return;
|
| 233 |
+
camera.aspect = container.clientWidth / container.clientHeight;
|
| 234 |
+
camera.updateProjectionMatrix();
|
| 235 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 236 |
+
});
|
| 237 |
+
window.sparkInitialized = true;
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
window.loadSplat = async function(url) {
|
| 241 |
+
if (!window.sparkInitialized) window.initSpark();
|
| 242 |
+
if (splat) { scene.remove(splat); splat.dispose(); }
|
| 243 |
+
try {
|
| 244 |
+
splat = new SplatMesh({ url: url });
|
| 245 |
+
splat.rotation.x = Math.PI;
|
| 246 |
+
scene.add(splat);
|
| 247 |
+
setTimeout(window.focusModel, 500);
|
| 248 |
+
} catch (e) { console.error(e); }
|
| 249 |
+
};
|
| 250 |
+
|
| 251 |
+
// Point the orbit controls at the current splat and move the camera back far
// enough along +Z for the model's largest dimension to fit in view.
// Does nothing until a splat, the controls and the camera all exist.
window.focusModel = function() {
    if (!splat || !controls || !camera) return;

    // Make sure matrixWorld is current even if no frame has been rendered
    // since the splat was added or rotated.
    splat.updateWorldMatrix(true, true);

    const box = new THREE.Box3();
    const point = new THREE.Vector3(); // reused for every sample (no per-point allocation)
    let pointsFound = 0;

    splat.traverse((obj) => {
        const pos = obj.geometry && obj.geometry.attributes && obj.geometry.attributes.position;
        if (!pos) return;
        const count = pos.count;
        // Sample roughly 5000 positions per object at most.
        const step = Math.max(1, Math.floor(count / 5000));
        for (let i = 0; i < count; i += step) {
            point.fromBufferAttribute(pos, i).applyMatrix4(obj.matrixWorld);
            box.expandByPoint(point);
        }
        pointsFound += count;
    });

    const center = new THREE.Vector3();
    const size = new THREE.Vector3();
    if (pointsFound === 0 || box.isEmpty()) {
        // Nothing measurable: fall back to a fixed default framing.
        center.set(0, 1.5, -3);
        size.set(2, 2, 2);
    } else {
        box.getCenter(center);
        box.getSize(size);
    }

    // A zero-extent (or non-finite) box would yield distance 0 and put the
    // camera exactly on the orbit target; use the default extent instead.
    let maxDim = Math.max(size.x, size.y, size.z);
    if (!Number.isFinite(maxDim) || maxDim <= 0) maxDim = 2;

    const fovRad = camera.fov * (Math.PI / 180);
    // Distance at which maxDim spans the field of view, times a 1.5 margin.
    const distance = (maxDim / 2) / Math.tan(fovRad / 2) * 1.5;

    controls.target.copy(center);
    camera.position.set(center.x, center.y, center.z + distance);
    controls.update();
};
|
| 279 |
+
|
| 280 |
+
// Return the "generate" button element, or null if it is not in the DOM.
// A <button> nested inside #generate-btn is tried first so that a wrapper
// element carrying the id still resolves to the real button; otherwise the
// element with that id is returned. (In the opposite order the nested lookup
// could never run: it only matches when the id lookup already succeeded.)
function getBtn() {
    return document.querySelector('#generate-btn button') || document.getElementById('generate-btn');
}
|
| 283 |
+
|
| 284 |
+
// Put the generate button into its "busy" state and start a 100 ms ticker
// that shows the elapsed time on the button. No-op if the button is missing.
window.startTimer = function() {
    const btn = getBtn();
    if (!btn) return;

    // Stop any ticker left over from an earlier call; otherwise its handle is
    // overwritten below and that interval can never be cleared.
    if (timerInterval) clearInterval(timerInterval);

    btn.disabled = true;
    btn.style.opacity = "0.6";
    btn.style.cursor = "wait";
    startTime = Date.now();
    timerInterval = setInterval(() => {
        const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
        btn.innerText = `Generating... ${elapsed}s`;
    }, 100);
};
|
| 296 |
+
|
| 297 |
+
// Stop the elapsed-time ticker, show the final duration on the generate
// button and make the button usable again. No-op when no ticker is running.
window.stopTimer = function() {
    if (!timerInterval) return;

    clearInterval(timerInterval);
    // Reset the handle so a repeated call does not rewrite the label with a
    // larger, meaningless elapsed time.
    timerInterval = null;

    const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
    const btn = getBtn();
    if (btn) {
        btn.innerText = `Done in ${elapsed}s`;
        // Undo the busy state applied by startTimer so the user can run the
        // generation again (e.g. with different settings).
        btn.disabled = false;
        btn.style.opacity = "1.0";
        btn.style.cursor = "pointer";
    }
};
|
| 305 |
+
|
| 306 |
+
// Restore the generate button to its idle label and enable it only when an
// image is present (`hasImage` truthy). Also cancels a running ticker.
window.resetBtn = function(hasImage) {
    // Cancel the ticker first and forget its handle, so it cannot overwrite
    // the label set below and a later stopTimer() does not act on a stale id.
    if (timerInterval) {
        clearInterval(timerInterval);
        timerInterval = null;
    }

    const btn = getBtn();
    if (btn) {
        btn.disabled = !hasImage;
        btn.style.opacity = hasImage ? "1.0" : "0.5";
        btn.style.cursor = hasImage ? "pointer" : "default";
        btn.innerText = "Generate 3D Gaussians";
    }
};
|
| 316 |
+
</script>
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
if __name__ == "__main__":
    # Command-line options: model to pre-load, bind address/port, and the
    # certificate/key pair used to serve over HTTPS when both files exist.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="ml-sharp_int4.onnx")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--ssl_cert", type=str, default="cert.pem")
    parser.add_argument("--ssl_key", type=str, default="key.pem")
    args = parser.parse_args()

    # Pre-load initial model
    load_model(args.model)

    # Client-side hook shared by both generation triggers: starts the button
    # timer and forwards the inputs unchanged to the Python handler.
    start_timer_js = "(img, op, down) => { window.startTimer(); return [img, op, down]; }"

    # The UI is built here so that it can read the parsed `args`.
    with gr.Blocks(css=custom_css, theme=gr.themes.Default(), head=head_content, title="SHARP 3D Recon") as demo:
        gr.HTML("<header><h1>SHARP 3D RECONSTRUCTION</h1></header>")
        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): the source this was recovered from had lost its
                # indentation; placing the button and help text inside this
                # group is an assumption -- confirm against the intended layout.
                with gr.Group(elem_classes="panel-box"):
                    available_models = get_available_models()
                    # Prefer the CLI-selected model, else the first available
                    # one, else no selection.
                    if args.model in available_models:
                        initial_model = args.model
                    else:
                        initial_model = available_models[0] if available_models else None
                    model_selector = gr.Dropdown(
                        choices=available_models,
                        value=initial_model,
                        label="Select ONNX Model (Precision)",
                        interactive=True,
                    )

                    input_image = gr.Image(
                        type="filepath",
                        label="Capture or Upload Image",
                        height=400,
                        sources=["upload", "webcam"],
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        opacity_val = gr.Slider(0.0, 1.0, value=0.0, label="Opacity Threshold")
                        downsample_val = gr.Slider(1, 10, step=1, value=1, label="Downsample Rate")

                    submit_btn = gr.Button(
                        "Generate 3D Gaussians",
                        variant="primary",
                        elem_id="generate-btn",
                        interactive=False,
                    )

                    gr.Markdown("Capture a photo from your phone or upload an image to start the real-time 3D conversion.")

            with gr.Column(scale=2):
                with gr.Group(elem_classes="panel-box"):
                    # Mount point for the in-browser 3D viewer.
                    gr.HTML("<div id='spark-container'></div>")
                    # Hidden file output; its change event hands the result
                    # to the viewer script.
                    output_file = gr.File(label="Output Model", visible=False)

        # On page load: initialise the viewer after 500 ms and put the button
        # into its "no image" state.
        demo.load(fn=None, inputs=None, outputs=None, js="() => { setTimeout(window.initSpark, 500); window.resetBtn(false); }")

        model_selector.change(fn=load_model, inputs=[model_selector], outputs=None)

        # Purely client-side: toggle the button when an image appears or is
        # cleared. No Python callback is attached because there are no outputs
        # for a return value to go to.
        input_image.change(fn=None, inputs=[input_image], outputs=None, js="(img) => { window.resetBtn(!!img); }")

        submit_btn.click(
            fn=process_image,
            inputs=[input_image, opacity_val, downsample_val],
            outputs=[output_file],
            js=start_timer_js,
        )

        input_image.upload(
            fn=process_image,
            inputs=[input_image, opacity_val, downsample_val],
            outputs=[output_file],
            js=start_timer_js,
        )

        # When a new result file arrives, stop the timer and load it into the viewer.
        output_file.change(
            fn=None,
            inputs=[output_file],
            js="(f) => { window.stopTimer(); if (f && f.url) { window.loadSplat(f.url); } }",
        )

    # Launch with HTTPS support.
    # TLS is enabled only when BOTH files are present; passing just one of
    # them to launch() would be a misconfiguration, so fall back to HTTP.
    cert_found = os.path.exists(args.ssl_cert)
    key_found = os.path.exists(args.ssl_key)
    use_ssl = cert_found and key_found
    if cert_found != key_found:
        print(
            f"Only one of --ssl_cert ({args.ssl_cert}) / --ssl_key ({args.ssl_key}) "
            "was found; starting without HTTPS."
        )

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        ssl_certfile=args.ssl_cert if use_ssl else None,
        ssl_keyfile=args.ssl_key if use_ssl else None,
        ssl_verify=False,
    )
|