| import argparse |
| import logging |
| import os |
| from pathlib import Path |
|
|
| import warnings |
| import torch |
| from sharp.models import PredictorParams, create_predictor |
|
|
LOGGER = logging.getLogger(__name__)

# Pretrained SHARP checkpoint: this file name is used both for the local
# lookup in export_onnx and as the last path component of the download URL.
DEFAULT_CHECKPOINT_NAME = "sharp_2572gikvuh.pt"
_MODEL_CDN_BASE = "https://ml-site.cdn-apple.com/models/sharp"
MODEL_URL = "/".join((_MODEL_CDN_BASE, DEFAULT_CHECKPOINT_NAME))
|
|
def verify_model_precision(model_path: Path):
    """Log the dtype distribution of an ONNX model's initializers.

    Used as a static check that quantization actually changed the stored
    weight types. Only the graph protobuf is read (``load_external_data=False``),
    so the check stays cheap even when the weights live in a large external
    ``.data`` file. The sizes of any such referenced files are reported as well.

    Args:
        model_path: Path to the ``.onnx`` file (a ``str`` is also accepted).

    Returns:
        A ``collections.Counter`` mapping dtype name to initializer count, or
        ``None`` when ``onnx`` is not installed or the model could not be
        inspected. Failures are logged as warnings and never raised.
    """
    model_path = Path(model_path)
    try:
        import onnx
    except ImportError:
        LOGGER.warning("未安装 'onnx' 模块,跳过权重精度验证。")
        return None

    from collections import Counter

    try:
        # Dtype and storage-location metadata are part of the protobuf itself;
        # there is no need to pull the (possibly multi-GB) tensor payloads.
        model = onnx.load(str(model_path), load_external_data=False)
        initializers = model.graph.initializer
        counter = Counter(
            onnx.TensorProto.DataType.Name(init.data_type) for init in initializers
        )
        # Relative locations of side files holding externally stored tensors.
        external_files = {
            entry.value
            for init in initializers
            if init.data_location == onnx.TensorProto.EXTERNAL
            for entry in init.external_data
            if entry.key == "location"
        }
        file_size = model_path.stat().st_size
        external_size = sum(
            (model_path.parent / name).stat().st_size
            for name in external_files
            if (model_path.parent / name).exists()
        )
    except Exception as e:  # best-effort diagnostic: never abort the pipeline
        LOGGER.warning(f"验证模型精度时发生异常: {e}")
        return None

    LOGGER.info(f"--- 权重类型分布报告 [{model_path.name}] ---")
    for dtype, count in counter.most_common():
        LOGGER.info(f" - {dtype}: {count} 个张量")
    LOGGER.info(f" [文件信息] 路径: {model_path}")
    LOGGER.info(f" [文件信息] 体积: {file_size / (1024 * 1024):.2f} MB")
    if external_files:
        LOGGER.info(
            f" [文件信息] 外部数据: {external_size / (1024 * 1024):.2f} MB "
            f"({len(external_files)} 个文件)"
        )
    LOGGER.info("------------------------------------------------")
    return counter
|
|
def _load_checkpoint_state_dict(checkpoint_path=None):
    """Return the SHARP state dict from an explicit path, a local copy, or the CDN.

    With no explicit path, the working directory, ``data/`` and the torch hub
    checkpoint directory are searched for ``DEFAULT_CHECKPOINT_NAME`` before
    the file is downloaded from ``MODEL_URL``. Every source is loaded with
    ``weights_only=True``.
    """
    if checkpoint_path is not None:
        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    hub_checkpoints = Path(torch.hub.get_dir()) / "checkpoints"
    search_paths = (
        Path(DEFAULT_CHECKPOINT_NAME),
        Path("data") / DEFAULT_CHECKPOINT_NAME,
        hub_checkpoints / DEFAULT_CHECKPOINT_NAME,
    )
    found_path = next((p for p in search_paths if p.exists()), None)
    if found_path is not None:
        LOGGER.info(f"在本地找到模型权重: {found_path},加载本地权重...")
        return torch.load(found_path, map_location="cpu", weights_only=True)

    # Same restriction as the local loads above: never unpickle arbitrary
    # objects from a network download.
    return torch.hub.load_state_dict_from_url(
        MODEL_URL, progress=True, map_location="cpu", weights_only=True
    )


def export_onnx(output_path: Path, checkpoint_path: "str | None" = None, internal_shape: int = 1536, is_fp16: bool = False):
    """Export the SHARP predictor to an ONNX file.

    Args:
        output_path: Destination ``.onnx`` path. Missing parent directories are created.
        checkpoint_path: Optional local ``.pt`` state dict. When ``None``, a few
            local locations are searched before downloading ``MODEL_URL``.
        internal_shape: Height and width of the square dummy image used for tracing.
        is_fp16: Convert the predictor and the dummy inputs to half precision
            before export.
    """
    device = torch.device("cpu")
    output_path = Path(output_path)

    LOGGER.info("正在创建预测器模型...")
    predictor = create_predictor(PredictorParams())

    if is_fp16:
        LOGGER.info("正在将模型转换为 FP16 (半精度) 以直接将其缩减进单一 < 2GB 的文件中...")
        # Halve before loading: load_state_dict copies the checkpoint values
        # into the existing (already half-precision) parameters, so an fp32
        # copy of the whole model is never materialised.
        # NOTE(review): tracing in FP16 on CPU needs CPU half kernels for every
        # op in the model; confirm with the installed torch version.
        predictor = predictor.half()

    predictor.load_state_dict(_load_checkpoint_state_dict(checkpoint_path))
    predictor.eval()
    predictor.to(device)

    dummy_image = torch.randn(1, 3, internal_shape, internal_shape, device=device)
    dummy_disparity = torch.tensor([1.0], device=device)
    if is_fp16:
        dummy_image = dummy_image.half()
        dummy_disparity = dummy_disparity.half()

    input_names = ["image", "disparity_factor"]
    output_names = ["mean_vectors", "singular_values", "quaternions", "colors", "opacities"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    LOGGER.info("正在导出为 ONNX 格式...")
    # Tracing only needs forward values; disabling autograd avoids holding the
    # backward graph for a full-resolution dummy input.
    with torch.no_grad():
        torch.onnx.export(
            predictor,
            (dummy_image, dummy_disparity),
            str(output_path),
            export_params=True,
            opset_version=19,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            keep_initializers_as_inputs=False,
            dynamic_axes={name: {0: "batch_size"} for name in input_names + output_names},
        )
    precision = "FP16 (半精度)" if is_fp16 else "全精度"
    LOGGER.info(f"成功将{precision} ONNX 模型导出至: {output_path}")
|
|
# 4-bit weight-only quantizers tried, in order, before the name-based search in
# _find_int4_quantizer. Presumably the module/class name differs between
# onnxruntime releases (TODO confirm); entries that fail to import are skipped.
_INT4_KNOWN_QUANTIZERS = (
    ("onnxruntime.quantization.matmul_nbits_quantizer", "MatMulNBitsQuantizer"),
    ("onnxruntime.quantization.matmul_4bits_quantizer", "MatMul4BitsQuantizer"),
)


def _extract_model_proto(quantizer):
    """Return the model held by a quantizer after ``process()``.

    Prefers ``quantizer.model.model``, then ``quantizer.model``, then the
    object itself.
    """
    if not hasattr(quantizer, "model"):
        return quantizer
    wrapped = quantizer.model
    return wrapped.model if hasattr(wrapped, "model") else wrapped


def _save_model(model, path: Path) -> None:
    """Save *model* as a single ``.onnx`` file, falling back to external data.

    Protobuf cannot serialise messages larger than 2 GB. When ``onnx.save``
    rejects the model with ``ValueError``, the tensors are written to
    ``<name>.data`` next to the graph instead of failing the export.
    """
    import onnx

    try:
        onnx.save(model, str(path))
    except ValueError as exc:
        LOGGER.warning(f"单文件保存失败 ({exc}),改用外部数据格式保存...")
        onnx.save(
            model,
            str(path),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=path.name + ".data",
        )


def _convert_to_external_data(output_path: Path) -> Path:
    """Return the model file that the quantizers should read.

    Re-saves the base model as ``<stem>_pre.onnx`` with all tensors in one
    ``.data`` side file. An existing copy is reused only if it is not older
    than the base model, so a re-exported base model is never quantized from
    a stale copy. Falls back to *output_path* if the conversion fails.
    """
    import onnx

    preprocessed_path = output_path.with_name(f"{output_path.stem}_pre.onnx")
    if preprocessed_path.exists() and (
        not output_path.exists()
        or preprocessed_path.stat().st_mtime >= output_path.stat().st_mtime
    ):
        return preprocessed_path

    LOGGER.info("正在将基础模型转换为外部数据格式 (已跳过导致崩溃的形状推导模式)...")
    data_name = preprocessed_path.name + ".data"
    try:
        # onnx appends external tensors to an existing data file, so drop any
        # leftover one from an earlier run first.
        (preprocessed_path.parent / data_name).unlink(missing_ok=True)
        model_raw = onnx.load(str(output_path))
        onnx.save(
            model_raw,
            str(preprocessed_path),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=data_name,
            size_threshold=1024,
            convert_attribute=True,
        )
    except Exception as e:
        LOGGER.warning(f"预分流模型保存失败: {e}")
        return output_path
    return preprocessed_path


def _quantize_int8(model_source: Path, output_path: Path) -> None:
    """Dynamically quantize *model_source* into ``<stem>_int8.onnx`` (QUInt8 weights)."""
    from onnxruntime.quantization import QuantType, quantize_dynamic

    int8_path = output_path.with_name(f"{output_path.stem}_int8.onnx")
    LOGGER.info("正在应用动态 INT8 量化...")
    quantize_dynamic(str(model_source), str(int8_path), weight_type=QuantType.QUInt8)
    verify_model_precision(int8_path)


def _find_int4_quantizer(oq):
    """Locate a callable 4-bit weight-only quantizer under ``onnxruntime.quantization``.

    The entries of ``_INT4_KNOWN_QUANTIZERS`` are tried first. Otherwise the
    package and every importable submodule are scanned for callables whose
    names look like quantizers. Names mentioning ``4bit``/``nbit`` win over
    generic ``weight`` matches, and module objects are never returned.

    Returns:
        ``(callable, module_name)``, or ``(None, None)`` when nothing is found.
    """
    import importlib
    import pkgutil

    for modname, attr in _INT4_KNOWN_QUANTIZERS:
        try:
            module = importlib.import_module(modname)
        except Exception:  # not present in this onnxruntime release
            continue
        candidate = getattr(module, attr, None)
        if callable(candidate):
            return candidate, modname

    modules = [oq]
    if hasattr(oq, "__path__"):
        # walk_packages already recurses into sub-packages. onerror swallows
        # failures of optional backends instead of aborting the whole walk.
        for info in pkgutil.walk_packages(oq.__path__, oq.__name__ + ".", onerror=lambda _name: None):
            try:
                modules.append(importlib.import_module(info.name))
            except Exception:
                continue

    def looks_like_quantizer(attr_name, tags):
        low = attr_name.lower()
        return ("quantiz" in low or "quantis" in low) and any(tag in low for tag in tags)

    # Explicit bit-width names (e.g. MatMul4BitsQuantizer) first, since
    # generic "WeightOnly" classes take a config rather than (model, block_size, ...).
    for tags in (("4bit", "nbit"), ("weight",)):
        for module in modules:
            for attr_name in dir(module):
                if not looks_like_quantizer(attr_name, tags):
                    continue
                candidate = getattr(module, attr_name, None)
                if callable(candidate):  # skips sub-module attributes
                    return candidate, module.__name__
    return None, None


def _int4_via_unified_api(oq, model_source: Path, quantized_output: Path) -> bool:
    """Quantize with ``oq.quantize`` + ``oq.MatMulWeightOnlyQuantConfig`` if both exist."""
    import onnx

    quantize_func = getattr(oq, "quantize", None)
    weight_only_config = getattr(oq, "MatMulWeightOnlyQuantConfig", None)
    if not (quantize_func and weight_only_config):
        return False

    LOGGER.info("[INT4] 使用 MatMulWeightOnlyQuantConfig 执行量化...")
    try:
        model = onnx.load(str(model_source))
        # NOTE(review): silently falls back to an 8-bit weight type when this
        # onnxruntime has no QuantType.QInt4.
        weight_type = getattr(oq.QuantType, "QInt4", oq.QuantType.QUInt8)
        quantize_func(model, str(quantized_output), weight_only_config(weight_type=weight_type))
    except Exception as e:
        LOGGER.warning(f"[INT4] Unified API 失败: {e}")
        return False
    LOGGER.info("[INT4] Unified API 导出成功!")
    return True


def _int4_via_bnb4(oq, model_source: Path, quantized_output: Path) -> bool:
    """Quantize MatMul weights with ``MatMulBnb4Quantizer`` (block size 32)."""
    import inspect

    import onnx

    try:
        from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
    except Exception as e:  # optional component; absence is not an error
        LOGGER.debug(f"[INT4] MatMulBnb4Quantizer 不可用: {e}")
        return False

    LOGGER.info("[INT4] 使用 MatMulBnb4Quantizer (BnB 4-bit) 执行量化...")
    try:
        model = onnx.load(str(model_source))
        params = inspect.signature(MatMulBnb4Quantizer.__init__).parameters
        kwargs = {"model": model, "block_size": 32}
        if "quant_type" in params:
            # Prefer the class's own NF4 constant; 1 is the value used previously.
            kwargs["quant_type"] = getattr(MatMulBnb4Quantizer, "NF4", 1)
        quantizer = MatMulBnb4Quantizer(**kwargs)
        quantizer.process()
        LOGGER.info("正在保存单体 INT4 模型...")
        _save_model(_extract_model_proto(quantizer), quantized_output)
    except Exception as e:
        LOGGER.warning(f"[INT4] BnB 失败: {e}")
        return False
    LOGGER.info("[INT4] BnB 导出成功!")
    return True


def _int4_via_nbits(oq, model_source: Path, quantized_output: Path) -> bool:
    """Quantize with the 4-bit/N-bit quantizer found by ``_find_int4_quantizer``."""
    import onnx

    quantizer_cls, src = _find_int4_quantizer(oq)
    if quantizer_cls is None:
        LOGGER.error(f"[INT4] 未找到有效量化器。所有属性: {dir(oq)}")
        return False

    name = getattr(quantizer_cls, "__name__", repr(quantizer_cls))
    LOGGER.info(f"[INT4] 找到量化类/函数: {src}.{name}")
    try:
        model = onnx.load(str(model_source))
        if "Quantizer" in name:
            quantizer = quantizer_cls(model, block_size=32, is_symmetric=True)
            quantizer.process()
            _save_model(_extract_model_proto(quantizer), quantized_output)
        else:
            # Function-style API: assumed to write the output file itself.
            quantizer_cls(model, str(quantized_output), weight_type=oq.QuantType.QInt4)
    except Exception as e:
        LOGGER.error(f"[INT4] 执行失败: {e}")
        return False
    LOGGER.info("[INT4] 导出成功!")
    return True


def _quantize_int4(model_source: Path, output_path: Path) -> bool:
    """Write ``<stem>_int4.onnx`` with the first INT4 strategy that succeeds.

    Strategies, in order:
      1. the unified ``quantize`` API with a weight-only config, if exposed;
      2. ``MatMulBnb4Quantizer``;
      3. a 4-bit/N-bit MatMul quantizer located at runtime.

    Returns:
        ``True`` if a model was written, ``False`` otherwise. Failures are logged.
    """
    import onnxruntime.quantization as oq

    LOGGER.info("正在准备进行真正的 INT4 权重压缩 (WOQ)...")
    quantized_output = output_path.with_name(f"{output_path.stem}_int4.onnx")
    for strategy in (_int4_via_unified_api, _int4_via_bnb4, _int4_via_nbits):
        if strategy(oq, model_source, quantized_output):
            verify_model_precision(quantized_output)
            return True
    return False


def main():
    """Command-line entry point: export SHARP to ONNX and optionally quantize it.

    With ``--int8`` and/or ``--int4``, an already existing base model at
    ``--output`` is reused instead of being re-exported. Quantized models are
    written next to it as ``<stem>_int8.onnx`` / ``<stem>_int4.onnx``.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    parser = argparse.ArgumentParser(description="将 ML-SHARP 模型导出为多种精度的 ONNX 格式。")
    parser.add_argument("--output", type=str, default="sharp.onnx", help="导出的基础 ONNX 模型保存路径。")
    parser.add_argument("--checkpoint", type=str, default=None, help="本地 .pt 权重文件路径。")
    parser.add_argument("--fp16", action="store_true", help="启用 FP16 半精度导出。")
    parser.add_argument("--int8", action="store_true", help="应用动态 INT8 量化。")
    parser.add_argument("--int4", action="store_true", help="应用真正的 INT4 仅权重量化 (WOQ)。")
    args = parser.parse_args()

    # Give the FP16 export its own default name so it does not overwrite the
    # FP32 model.
    if args.fp16 and args.output == "sharp.onnx":
        args.output = "sharp_fp16.onnx"

    output_path = Path(args.output)
    is_quant_requested = args.int8 or args.int4

    if is_quant_requested and output_path.exists():
        LOGGER.info(f"检测到基础模型 '{output_path}' 已存在,跳过导出并开始验证。")
    else:
        export_onnx(output_path, args.checkpoint, is_fp16=args.fp16)
    verify_model_precision(output_path)

    if not is_quant_requested:
        return

    try:
        # Availability probe only; the helpers import what they use lazily.
        import onnx  # noqa: F401
        from onnxruntime.quantization import QuantType, quantize_dynamic  # noqa: F401
    except ImportError:
        LOGGER.error("缺失依赖项:请安装 onnx 和 onnxruntime。")
        return

    model_source = _convert_to_external_data(output_path)

    if args.int8:
        _quantize_int8(model_source, output_path)
    if args.int4:
        _quantize_int4(model_source, output_path)


if __name__ == "__main__":
    main()
|
|