|
|
|
import os |
|
import sys |
|
# Make the ``python_lib`` directory that sits next to the current working
# directory importable, and echo the resulting path to stdout.
# NOTE(review): presumably this is where the ``midigpt`` module imported
# below lives -- confirm.
_python_lib_dir = os.path.dirname(os.getcwd()) + "/python_lib"
sys.path.append(_python_lib_dir)
print(_python_lib_dir)
|
import midigpt |
|
import time |
|
import json |
|
import numpy as np |
|
import torch |
|
import torch.quantization |
|
from transformers import GPT2LMHeadModel, GPT2Config |
|
from transformers.modeling_utils import Conv1D |
|
|
|
from custom_models import * |
|
|
|
from torch import nn |
|
|
|
class QuantWrapper(nn.Module):
    """Wrap a two-input module between a quantize and a dequantize stub.

    The wrapped module is invoked as ``module(X, P)``. Both inputs go
    through the same ``quant`` stub, and the module's output goes through
    ``dequant`` before being returned.
    """

    def __init__(self, module):
        super().__init__()
        # Reuse the wrapped module's qconfig when it carries one.
        stub_qconfig = getattr(module, 'qconfig', None)
        # Children are registered in the order quant, dequant, module.
        self.quant = torch.quantization.QuantStub(stub_qconfig)
        self.dequant = torch.quantization.DeQuantStub()
        self.module = module
        # Mirror the train/eval mode of the wrapped module.
        self.train(module.training)

    def forward(self, X, P):
        """Quantize both inputs, run the wrapped module, dequantize."""
        quantized_x = self.quant(X)
        quantized_p = self.quant(P)
        result = self.module(quantized_x, quantized_p)
        return self.dequant(result)
|
|
|
def _conv1d_to_linear(module): |
|
in_size, out_size = module.weight.shape |
|
linear = torch.nn.Linear(in_size, out_size) |
|
linear.weight.data = module.weight.data.T.contiguous() |
|
linear.bias.data = module.bias.data |
|
return linear |
|
|
|
def conv1d_to_linear(model):
    """Recursively replace every ``Conv1D`` submodule of *model*, in place.

    Each direct child that is a ``Conv1D`` is swapped for the layer built
    by ``_conv1d_to_linear``; every other child is descended into.
    Nothing is returned.
    """
    # Work on a snapshot so that replacing an entry cannot disturb iteration.
    for child_name, child in list(model._modules.items()):
        if not isinstance(child, Conv1D):
            conv1d_to_linear(child)
            continue
        model._modules[child_name] = _conv1d_to_linear(child)
|
|
|
def score_model(model):
    """Load the ``"data"`` array from ``target.npz`` in the working directory.

    NOTE(review): scoring itself is not implemented in this function --
    *model* is never used, the loaded array is discarded and the function
    returns ``None``.

    Raises whatever ``np.load`` raises when ``target.npz`` is missing or
    has no ``"data"`` entry.
    """
    # np.load on an .npz archive returns an object that keeps the file
    # open; the context manager closes it instead of leaking the handle.
    with np.load("target.npz") as archive:
        targets = archive["data"]
|
|
|
|
|
def time_model(model, steps=1000):
    """Time *steps* consecutive single-token forward passes of *model*.

    The model is called as ``model(input_ids, past_key_values=pkv)`` with
    a ``(1, 1)`` int64 tensor of ones. Element 1 of each call's output is
    fed back as ``past_key_values`` for the next call (``None`` on the
    first call). The elapsed time in seconds is printed; nothing is
    returned.

    Args:
        model: callable accepting ``input_ids`` and ``past_key_values``.
        steps: number of forward passes to run (defaults to 1000).
    """
    # The input never changes, so build it once outside the timed loop.
    input_ids = torch.ones(1, 1, dtype=torch.long)
    pkv = None
    # perf_counter is monotonic and intended for measuring intervals.
    start = time.perf_counter()
    # Without no_grad every step's outputs would keep an autograd graph
    # alive through the chained past_key_values.
    with torch.no_grad():
        for _ in range(steps):
            outputs = model(input_ids, past_key_values=pkv)
            pkv = outputs[1]
    print("BATCH TIME : {}".format(time.perf_counter() - start))
|
|
|
def print_size_of_model(model):
    """Print the serialized size of *model*'s ``state_dict`` in megabytes.

    The state dict is written with ``torch.save`` to a scratch file, whose
    size in bytes divided by 1e6 is printed. Nothing is returned.
    """
    import os
    import tempfile

    # A private temporary directory avoids clobbering (and then deleting)
    # an unrelated "temp.p" in the working directory, and is cleaned up
    # even when saving or measuring fails.
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        print('Size (MB):', os.path.getsize(scratch_path)/1e6)
|
|
|
def quantize_model(model):
    """Return a dynamically quantized (qint8) version of *model*.

    ``Conv1D`` layers are first replaced in place through
    ``conv1d_to_linear``; ``torch.nn.Linear`` layers are then the ones
    handed to ``quantize_dynamic``.
    """
    conv1d_to_linear(model)
    target_layers = {torch.nn.Linear}
    return torch.quantization.quantize_dynamic(
        model, target_layers, dtype=torch.qint8)
|
|
|
def static_quantize_model(model, calibrate=None):
    """Statically quantize *model* in place and return it.

    ``Conv1D`` layers are replaced through ``conv1d_to_linear``, the
    default qconfig is attached, and the model is prepared and converted
    in place.

    Args:
        model: the module to quantize; it is modified in place.
        calibrate: optional callable invoked as ``calibrate(model)``
            between ``prepare`` and ``convert``, meant to run
            representative inputs through the prepared model. When
            omitted, ``convert`` runs straight after ``prepare``, as
            before.

    Returns:
        The same *model* object.
    """
    conv1d_to_linear(model)
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    # The observers inserted by prepare() only record statistics while
    # data flows through the model, so calibration has to happen here.
    if calibrate is not None:
        calibrate(model)
    torch.quantization.convert(model, inplace=True)
    return model
|
|
|
def prune_model(model, amount=.8):
    """Apply L1 unstructured pruning to every ``torch.nn.Linear`` in *model*.

    ``Conv1D`` layers are first replaced in place through
    ``conv1d_to_linear``. Each Linear layer's ``weight`` is then pruned
    and the pruning re-parametrization is removed, which makes the
    zeroed weights permanent.

    Args:
        model: the module to prune; it is modified in place.
        amount: value forwarded to ``prune.l1_unstructured`` as its
            ``amount`` argument (defaults to 0.8).

    Returns:
        The same *model* object.
    """
    import torch.nn.utils.prune as prune

    conv1d_to_linear(model)
    # named_modules() already walks the whole module tree, so one pass
    # reaches every Linear layer. A second, nested walk would prune the
    # same layers again; with an unchanged amount that only re-selects
    # entries the first pass has already set to zero.
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")

    return model
|
|
|
def inject_metadata(path, metadata_path, encoder, new_state):
    """Re-save a TorchScript model with an updated ``metadata.json``.

    The model at *path* is loaded with ``torch.jit.load``. The JSON object
    read from *metadata_path* gets its ``"encoder"`` and ``"new_state"``
    keys set to the given values, and is stored as the ``metadata.json``
    extra file of a copy saved next to the original, with ``_WMETA.pt``
    replacing the original extension.

    Args:
        path: path of the TorchScript file to load.
        metadata_path: path of the JSON file holding the base metadata.
        encoder: value stored under ``"encoder"``.
        new_state: value stored under ``"new_state"``.
    """
    model = torch.jit.load(path)
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    metadata["encoder"] = encoder
    metadata["new_state"] = new_state
    # A plain dict is what ``convert`` in this file passes as
    # ``_extra_files``; it avoids the private ``torch._C.ExtraFilesMap``
    # binding (NOTE(review): believed absent from recent PyTorch -- confirm).
    extra_files = {'metadata.json': json.dumps(metadata)}
    out_path = os.path.splitext(path)[0] + "_WMETA.pt"
    torch.jit.save(model, out_path, _extra_files=extra_files)
|
|
|
def convert(model, path, quantize=False, prune=False, force=False, control=False, ckpt_path=None, encoderX=None):
    """Trace *model* with TorchScript and save it to *path* with metadata.

    Nothing is done when *path* already exists, unless *force* is true.
    Otherwise the model is put in eval mode, optionally quantized and/or
    pruned, its size is printed, and it is traced on all-zero example
    inputs together with the element-1 output of a first forward pass.
    The traced module is saved with a ``metadata.json`` extra file.

    Args:
        model: the model to export.
        path: destination file of the traced module.
        quantize: run ``quantize_model`` first.
        prune: run ``prune_model`` first.
        force: export even when *path* already exists.
        control: also feed a ``control_ids`` example input.
        ckpt_path: accepted but not used by this function.
        encoderX: value stored under ``"encoder"`` in the metadata.
    """
    already_exported = os.path.exists(path)
    if already_exported and not force:
        return

    model.eval()
    if quantize:
        model = quantize_model(model)
    if prune:
        model = prune_model(model)
    print_size_of_model(model)

    example_input = torch.zeros(1,300).type(torch.LongTensor)
    example_control = torch.zeros(1,300,3).type(torch.FloatTensor)
    if control:
        outputs = model(input_ids=example_input, control_ids=example_control, past_key_values=None)
        print(len(outputs[1]))
        trace_inputs = [example_input, example_control, outputs[1]]
    else:
        outputs = model(input_ids=example_input)
        trace_inputs = [example_input, outputs[1]]
    traced_script_module = torch.jit.trace(model, trace_inputs)

    # Element 1 of the outputs is indexed per layer; the head count and
    # hidden size are read off positions 1 and 3 of a 4-D entry's shape.
    past = outputs[1]
    num_layers = len(past)
    _, num_heads, _, num_hidden = past[0][0].detach().numpy().shape

    model_metadata = {
        "encoder" : encoderX,
        "num_heads" : num_heads,
        "num_hidden" : num_hidden,
        "num_layers" : num_layers,
        "model_dim" : -1,
        "new_state" : True
    }
    print(model_metadata)

    torch.jit.save(
        traced_script_module,
        path,
        _extra_files={'metadata.json': json.dumps(model_metadata)})
|
|
|
|
|
class GPT2LMHeadModelWMeta(GPT2LMHeadModel):
    """``GPT2LMHeadModel`` subclass whose only override is ``extra_repr``."""

    def extra_repr(self) -> str:
        """Return the fixed extra text for this module's representation."""
        tag = "trent is the man"
        return tag
|
|
|
if __name__ == "__main__":
    # Command-line entry point with two modes:
    #   --inject : re-save an existing TorchScript file (--ckpt_path) with
    #              metadata taken from --metadata_path.
    #   default  : build (--init) or load (--ckpt_path) a model and export
    #              it to --output through convert().

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--metadata_path", type=str, default="")
    parser.add_argument("--config", type=str, default="")
    parser.add_argument("--encoder", type=str, default="NONE")
    parser.add_argument("--init", action="store_true")
    parser.add_argument("--inject", action="store_true")
    parser.add_argument("--new_state", action="store_true")
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--prune", action="store_true")
    parser.add_argument("--control", action="store_true")

    args = parser.parse_args()

    if args.inject:
        # parser.error() instead of assert: asserts are stripped under -O
        # and give the user a traceback rather than a usage message.
        if not args.metadata_path:
            parser.error("--metadata_path is required with --inject")
        inject_metadata(
            args.ckpt_path, args.metadata_path, args.encoder, args.new_state)

    else:
        if not args.output:
            parser.error("--output is required unless --inject is given")

        if args.init:
            # Build a freshly initialised model from a JSON config.
            encoder_mode = midigpt.getEncoderType(args.encoder)
            # Compared by value rather than identity, so the check does
            # not depend on getEncoderType returning the very same enum
            # object (NOTE(review): confirm how midigpt exposes its enums).
            if encoder_mode == midigpt.ENCODER_TYPE.NO_ENCODER:
                parser.error(
                    "--encoder {!r} does not name a known encoder".format(args.encoder))
            encoder = midigpt.getEncoder(encoder_mode)
            vocab_size = encoder.vocab_size()
            if args.control:
                config = GPT2LMHeadModelContConfig().from_json_file(args.config)
                # NOTE(review): vocab_size is not applied on this branch;
                # confirm the JSON config already carries the right value.
                config.n_control_dim = encoder.config.embed_dim
                model_cls = GPT2LMHeadModelCont
            else:
                config = GPT2Config().from_json_file(args.config)
                config.vocab_size = vocab_size
                model_cls = GPT2LMHeadModel

            model = model_cls(config)
        else:
            # Load trained weights from the checkpoint.
            if args.control:
                model = GPT2LMHeadModelCont.from_pretrained(args.ckpt_path, torchscript=True)
            else:
                model = GPT2LMHeadModel.from_pretrained(args.ckpt_path, torchscript=True)

        convert(model, args.output, quantize=args.quantize, prune=args.prune, control=args.control, ckpt_path=args.ckpt_path, encoderX=args.encoder)
|
|
|
|