Upload merge_sharded_safetensors.py
Browse files- merge_sharded_safetensors.py +58 -0
merge_sharded_safetensors.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from safetensors import safe_open
|
| 3 |
+
from safetensors.torch import save_file
|
| 4 |
+
import torch # Needed for torch.cat
|
| 5 |
+
|
| 6 |
+
def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
|
| 7 |
+
slices_dict = {}
|
| 8 |
+
metadata = {}
|
| 9 |
+
|
| 10 |
+
for idx, file in enumerate(sftsr_files):
|
| 11 |
+
with safe_open(file, framework="pt") as sf_tsr:
|
| 12 |
+
if idx == 0:
|
| 13 |
+
metadata = sf_tsr.metadata()
|
| 14 |
+
for key in sf_tsr.keys():
|
| 15 |
+
tensor = sf_tsr.get_tensor(key)
|
| 16 |
+
if key not in slices_dict:
|
| 17 |
+
slices_dict[key] = []
|
| 18 |
+
slices_dict[key].append(tensor)
|
| 19 |
+
|
| 20 |
+
merged_tensors = {}
|
| 21 |
+
for key, slices in slices_dict.items():
|
| 22 |
+
if len(slices) == 1:
|
| 23 |
+
merged_tensors[key] = slices[0]
|
| 24 |
+
else:
|
| 25 |
+
# Simple heuristic: find dim with mismatched size
|
| 26 |
+
ref_shape = slices[0].shape
|
| 27 |
+
concat_dim = None
|
| 28 |
+
for dim in range(len(ref_shape)):
|
| 29 |
+
dim_sizes = [s.shape[dim] for s in slices]
|
| 30 |
+
if len(set(dim_sizes)) > 1:
|
| 31 |
+
concat_dim = dim
|
| 32 |
+
break
|
| 33 |
+
if concat_dim is None:
|
| 34 |
+
concat_dim = 0 # fallback
|
| 35 |
+
merged_tensors[key] = torch.cat(slices, dim=concat_dim)
|
| 36 |
+
print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")
|
| 37 |
+
|
| 38 |
+
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
| 39 |
+
save_file(merged_tensors, output_file, metadata)
|
| 40 |
+
print(f"Merged {len(sftsr_files)} shards into {output_file}")
|
| 41 |
+
|
| 42 |
+
def get_safetensor_files(directory):
|
| 43 |
+
safetensors_files = []
|
| 44 |
+
for root, _, files in os.walk(directory):
|
| 45 |
+
for file in files:
|
| 46 |
+
if file.endswith(".safetensors"):
|
| 47 |
+
safetensors_files.append(os.path.join(root, file))
|
| 48 |
+
return safetensors_files
|
| 49 |
+
|
| 50 |
+
if __name__ == "__main__":
|
| 51 |
+
safetensor_files = get_safetensor_files("./shards")
|
| 52 |
+
print(f"The following shards/chunks will be merged: {safetensor_files}")
|
| 53 |
+
|
| 54 |
+
default_output = "./output/merged_model.safetensors"
|
| 55 |
+
user_output = input(f"Enter output file path [{default_output}]: ").strip()
|
| 56 |
+
output_file = user_output if user_output else default_output
|
| 57 |
+
|
| 58 |
+
merge_safetensor_files(safetensor_files, output_file=output_file)
|