sharp-onnx-int8 / export_onnx.py
Olsc's picture
更新代码
04b19b4 verified
import argparse
import logging
import os
from pathlib import Path
import warnings
import torch
from sharp.models import PredictorParams, create_predictor
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# Default checkpoint (weights) file name.
DEFAULT_CHECKPOINT_NAME = "sharp_2572gikvuh.pt"
# Download URL for the default checkpoint, built from the file name above.
MODEL_URL = f"https://ml-site.cdn-apple.com/models/sharp/{DEFAULT_CHECKPOINT_NAME}"
def verify_model_precision(model_path: Path):
    """Log the weight data-type distribution and file size of an ONNX model.

    This is a static check used to confirm that quantization took effect:
    it counts how many graph initializers (weights) use each ONNX tensor
    data type and logs the on-disk size of the model file.

    The check is best-effort: if ``onnx`` is not installed, or anything goes
    wrong while inspecting the model, a warning is logged and the function
    returns normally.

    Args:
        model_path: Path to the ``.onnx`` file to inspect.
    """
    # Only the import itself should be reported as "onnx is not installed".
    try:
        import onnx
    except ImportError:
        LOGGER.warning("未安装 'onnx' 模块,跳过权重精度验证。")
        return

    from collections import Counter

    try:
        # The data type is stored in the model file itself, so the (possibly
        # multi-GB) external weight blobs do not need to be read into memory.
        model = onnx.load(str(model_path), load_external_data=False)

        # Resolve each initializer's numeric data type to its enum name.
        dtype_name = onnx.TensorProto.DataType.Name
        counter = Counter(
            dtype_name(initializer.data_type)
            for initializer in model.graph.initializer
        )

        LOGGER.info("--- 权重类型分布报告 [%s] ---", model_path.name)
        for dtype, count in counter.items():
            LOGGER.info(" - %s: %s 个张量", dtype, count)

        # Size of the model file itself (external data files are not included).
        file_size = os.path.getsize(model_path)
        LOGGER.info(" [文件信息] 路径: %s", model_path)
        LOGGER.info(" [文件信息] 体积: %.2f MB", file_size / (1024 * 1024))
        LOGGER.info("------------------------------------------------")
    except Exception as e:
        # Verification is informational only; never let it abort the export.
        LOGGER.warning("验证模型精度时发生异常: %s", e)
def export_onnx(output_path: Path, checkpoint_path: str = None, internal_shape: int = 1536, is_fp16: bool = False) -> None:
    """Build the predictor, load its weights and export it to ONNX.

    Weight resolution order when ``checkpoint_path`` is ``None``:
    the current directory, ``data/``, the torch hub checkpoint cache, and
    finally a download from ``MODEL_URL``.

    Args:
        output_path: Destination of the exported ``.onnx`` file. Missing
            parent directories are created.
        checkpoint_path: Optional path to a local ``.pt`` state dict. When
            ``None`` the default checkpoint is searched for / downloaded.
        internal_shape: Height and width of the square dummy image used to
            trace the model.
        is_fp16: When true, the model and the dummy inputs are converted to
            half precision before export.
    """
    device = torch.device("cpu")
    LOGGER.info("正在创建预测器模型...")
    predictor = create_predictor(PredictorParams())

    # FP16: convert the network parameters to half precision.
    if is_fp16:
        LOGGER.info("正在将模型转换为 FP16 (半精度) 以直接将其缩减进单一 < 2GB 的文件中...")
        predictor = predictor.half()

    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    else:
        candidates = (
            Path(DEFAULT_CHECKPOINT_NAME),
            Path("data") / DEFAULT_CHECKPOINT_NAME,
            Path(torch.hub.get_dir()) / "checkpoints" / DEFAULT_CHECKPOINT_NAME,
        )
        found_path = next((p for p in candidates if p.exists()), None)
        if found_path is not None:
            LOGGER.info(f"在本地找到模型权重: {found_path},加载本地权重...")
            state_dict = torch.load(found_path, map_location="cpu", weights_only=True)
        else:
            # No local copy: let torch hub download (and cache) the checkpoint.
            state_dict = torch.hub.load_state_dict_from_url(MODEL_URL, progress=True, map_location="cpu")

    predictor.load_state_dict(state_dict)
    predictor.eval()
    predictor.to(device)

    # Dummy inputs used for tracing; they must match the model precision.
    dummy_image = torch.randn(1, 3, internal_shape, internal_shape, device=device)
    dummy_disparity = torch.tensor([1.0], device=device)
    if is_fp16:
        dummy_image = dummy_image.half()
        dummy_disparity = dummy_disparity.half()
    dummy_inputs = (dummy_image, dummy_disparity)

    input_names = ['image', 'disparity_factor']
    output_names = ['mean_vectors', 'singular_values', 'quaternions', 'colors', 'opacities']

    # The exporter does not create missing directories on its own.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    LOGGER.info("正在导出为 ONNX 格式...")
    torch.onnx.export(
        predictor,
        dummy_inputs,
        str(output_path),
        export_params=True,
        opset_version=19,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        keep_initializers_as_inputs=False,
        # Axis 0 of every input and output is exported as a dynamic batch axis.
        dynamic_axes={name: {0: 'batch_size'} for name in input_names + output_names},
    )

    # Report the precision that was actually exported.
    precision_label = "半精度 (FP16)" if is_fp16 else "全精度"
    LOGGER.info("成功将%s ONNX 模型导出至: %s", precision_label, output_path)
def _unwrap_model_proto(quantizer):
    """Return the innermost ``.model`` held by a quantizer (or the object itself)."""
    holder = getattr(quantizer, "model", quantizer)
    return getattr(holder, "model", holder)


def _find_legacy_int4_quantizer(package):
    """Search a package tree for a callable whose name looks like a 4-bit/weight quantizer.

    Returns a ``(callable, module_name)`` tuple, or ``(None, None)`` when
    nothing suitable is found. Non-callable matches (modules, constants) are
    skipped so that they cannot end the search prematurely.
    """
    import importlib
    import pkgutil

    def first_matching_callable(module):
        for attr in dir(module):
            low = attr.lower()
            if ("quantiz" in low or "quantis" in low) and ("4bit" in low or "nbit" in low or "weight" in low):
                candidate = getattr(module, attr, None)
                if callable(candidate):
                    return candidate
        return None

    found = first_matching_callable(package)
    if found is not None:
        return found, package.__name__

    if hasattr(package, "__path__"):
        # walk_packages already descends into sub-packages; onerror keeps a
        # broken sub-package from aborting the whole search.
        walker = pkgutil.walk_packages(
            package.__path__, package.__name__ + ".", onerror=lambda _name: None
        )
        for _, modname, _ in walker:
            try:
                module = importlib.import_module(modname)
            except Exception:
                # Optional modules may fail to import; just skip them.
                continue
            found = first_matching_callable(module)
            if found is not None:
                return found, module.__name__
    return None, None


def _quantize_int4(model_source, quantized_output):
    """Run INT4 weight-only quantization, trying several onnxruntime APIs in turn.

    Order: the ``quantize`` + ``MatMulWeightOnlyQuantConfig`` API, then
    ``MatMulBnb4Quantizer``, then any legacy quantizer discovered by name.
    Failures are logged; the first strategy that succeeds ends the function.
    """
    import inspect

    import onnx
    import onnxruntime.quantization as oq
    from onnxruntime.quantization import QuantType

    # Preferred: the newer unified API, when this onnxruntime exposes it.
    quantize_func = getattr(oq, "quantize", None)
    weight_only_config_cls = getattr(oq, "MatMulWeightOnlyQuantConfig", None)
    if quantize_func and weight_only_config_cls:
        LOGGER.info("[INT4] 使用 MatMulWeightOnlyQuantConfig 执行量化...")
        try:
            model = onnx.load(str(model_source))
            # Falls back to QUInt8 when this onnxruntime has no QInt4 type.
            config = weight_only_config_cls(weight_type=getattr(QuantType, 'QInt4', QuantType.QUInt8))
            quantize_func(model, str(quantized_output), config)
            LOGGER.info("[INT4] Unified API 导出成功!")
            verify_model_precision(quantized_output)
            return
        except Exception as e:
            LOGGER.warning("[INT4] Unified API 失败: %s", e)

    # Fallback 1: the BnB 4-bit quantizer, if this onnxruntime ships it.
    try:
        from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
    except Exception:
        MatMulBnb4Quantizer = None
    if MatMulBnb4Quantizer:
        LOGGER.info("[INT4] 使用 MatMulBnb4Quantizer (BnB 4-bit) 执行量化...")
        try:
            model = onnx.load(str(model_source))
            kwargs = {"model": model, "block_size": 32}
            # Only pass quant_type when the installed version accepts it.
            # NOTE(review): 1 presumably selects NF4 - confirm against the
            # installed onnxruntime version.
            if "quant_type" in inspect.signature(MatMulBnb4Quantizer.__init__).parameters:
                kwargs["quant_type"] = 1
            quantizer = MatMulBnb4Quantizer(**kwargs)
            quantizer.process()
            LOGGER.info("正在保存单体 INT4 模型...")
            onnx.save(_unwrap_model_proto(quantizer), str(quantized_output))
            LOGGER.info("[INT4] BnB 导出成功!")
            verify_model_precision(quantized_output)
            return
        except Exception as e:
            LOGGER.warning("[INT4] BnB 失败: %s", e)

    # Fallback 2: brute-force search for an older class/function API.
    legacy_quantizer, origin = _find_legacy_int4_quantizer(oq)
    if legacy_quantizer is None:
        LOGGER.error("[INT4] 未找到有效量化器。所有属性: %s", dir(oq))
        return

    legacy_name = getattr(legacy_quantizer, "__name__", type(legacy_quantizer).__name__)
    LOGGER.info("[INT4] 找到量化类/函数: %s.%s", origin, legacy_name)
    try:
        model = onnx.load(str(model_source))
        if "Quantizer" in legacy_name:
            # Class-style API: construct, process, then save the result.
            quantizer = legacy_quantizer(model, block_size=32, is_symmetric=True)
            quantizer.process()
            onnx.save(_unwrap_model_proto(quantizer), str(quantized_output))
        else:
            # Function-style API that writes the output itself.
            legacy_quantizer(model, str(quantized_output), weight_type=QuantType.QInt4)
        LOGGER.info("[INT4] 导出成功!")
        verify_model_precision(quantized_output)
    except Exception as e:
        LOGGER.error("[INT4] 执行失败: %s", e)


def main():
    """Command-line entry point: export the base ONNX model and optionally quantize it.

    Flags: ``--output``, ``--checkpoint``, ``--fp16``, ``--int8``, ``--int4``.
    When quantization is requested and the base model already exists, the
    export step is skipped and the existing file is reused.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    parser = argparse.ArgumentParser(description="将 ML-SHARP 模型导出为多种精度的 ONNX 格式。")
    parser.add_argument("--output", type=str, default="sharp.onnx", help="导出的基础 ONNX 模型保存路径。")
    parser.add_argument("--checkpoint", type=str, default=None, help="本地 .pt 权重文件路径。")
    parser.add_argument("--fp16", action="store_true", help="启用 FP16 半精度导出。")
    parser.add_argument("--int8", action="store_true", help="应用动态 INT8 量化。")
    parser.add_argument("--int4", action="store_true", help="应用真正的 INT4 重量级量化 (WOQ)。")
    args = parser.parse_args()

    # In FP16 mode, redirect the default output name unless the user chose one.
    if args.fp16 and args.output == "sharp.onnx":
        args.output = "sharp_fp16.onnx"
    output_path = Path(args.output)

    is_quant_requested = args.int8 or args.int4
    if is_quant_requested and output_path.exists():
        LOGGER.info(f"检测到基础模型 '{output_path}' 已存在,跳过导出并开始验证。")
    else:
        export_onnx(output_path, args.checkpoint, is_fp16=args.fp16)
    verify_model_precision(output_path)

    if not is_quant_requested:
        return

    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType
        import onnx
    except ImportError:
        LOGGER.error("缺失依赖项:请安装 onnx 和 onnxruntime。")
        return

    # Re-save the base model with its tensors in one external data file and
    # quantize from that copy; fall back to the base model if saving fails.
    preprocessed_path = output_path.with_name(f"{output_path.stem}_pre.onnx")
    model_source = preprocessed_path
    if not preprocessed_path.exists():
        LOGGER.info("正在将基础模型转换为外部数据格式 (已跳过导致崩溃的形状推导模式)...")
        try:
            onnx.save(
                onnx.load(str(output_path)),
                str(preprocessed_path),
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=preprocessed_path.name + ".data",
                size_threshold=1024,
                convert_attribute=True,
            )
        except Exception as e:
            LOGGER.warning("预分流模型保存失败: %s", e)
            model_source = output_path

    # 1. Dynamic INT8 quantization.
    if args.int8:
        int8_path = output_path.with_name(f"{output_path.stem}_int8.onnx")
        LOGGER.info("正在应用动态 INT8 量化...")
        quantize_dynamic(str(model_source), str(int8_path), weight_type=QuantType.QUInt8)
        verify_model_precision(int8_path)

    # 2. INT4 weight-only quantization (WOQ).
    if args.int4:
        LOGGER.info("正在准备进行真正的 INT4 权重压缩 (WOQ)...")
        quantized_output = output_path.with_name(f"{output_path.stem}_int4.onnx")
        _quantize_int4(model_source, quantized_output)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()