|
|
|
|
|
|
|
class Conversation { |
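  // Chat-template state: system prompt, role names, and separators. The message
  // history itself lives in tvmjsGlobalEnv.workerHistoryMsg (shared with the worker);
  // convId tracks which worker conversation that history belongs to.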
|
constructor(config) { |
|
this.system = config.system; |
|
this.roles = config.roles; |
|
this.offset = config.offset; |
|
this.seps = config.seps; |
|
this.convId = null; |
|
this.contextWindowStart = 0; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
getPromptArray() { |
|
if (this.seps.length == 0) { |
|
throw Error("Need seps to work") |
|
} |
|
let ret = [this.system + this.seps[0]]; |
|
|
|
for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) { |
|
const item = tvmjsGlobalEnv.workerHistoryMsg[i]; |
|
const role = item[0]; |
|
const message = item[1]; |
|
if (message !== undefined && message != "") { |
|
ret.push(role + ": " + message + this.seps[i % this.seps.length]); |
|
} else { |
|
ret.push(role + ":"); |
|
} |
|
} |
|
return ret; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
  getPromptArrayUnprocessed() {
|
if (this.seps.length == 0) { |
|
throw Error("Need seps to work") |
|
} |
|
if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) { |
|
throw Error("needs to call getLastPromptArray for the first message"); |
|
} |
|
let ret = [this.seps[this.seps.length - 1]]; |
|
for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) { |
|
const item = tvmjsGlobalEnv.workerHistoryMsg[i]; |
|
const role = item[0]; |
|
const message = item[1]; |
|
if (message !== undefined && message != "") { |
|
ret.push(role + ": " + message + this.seps[i % this.seps.length]); |
|
} else { |
|
ret.push(role + ":"); |
|
} |
|
} |
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
getLastPromptArray() { |
|
if (this.seps.length == 0) { |
|
throw Error("Need seps to work") |
|
} |
|
let ret = [this.system + this.seps[0]]; |
|
|
|
for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) { |
|
const item = tvmjsGlobalEnv.workerHistoryMsg[i]; |
|
const role = item[0]; |
|
const message = item[1]; |
|
if (message !== undefined && message != "") { |
|
ret.push(role + ": " + message + this.seps[i % this.seps.length]); |
|
} else { |
|
ret.push(role + ":"); |
|
} |
|
} |
|
return ret; |
|
} |
|
|
|
reset() { |
|
tvmjsGlobalEnv.workerHistoryMsg = []; |
|
    this.convId = null;
|
} |
|
|
|
getStopStr() { |
|
return this.seps[this.seps.length - 1]; |
|
} |
|
|
|
appendMessage(role, message) { |
|
tvmjsGlobalEnv.workerHistoryMsg.push([role, message]); |
|
} |
|
|
|
switchConversation(message) { |
|
    tvmjsGlobalEnv.workerHistoryMsg = message;

    this.convId = tvmjsGlobalEnv.covId;
|
} |
|
} |
|
|
|
function defaultConversation(maxWindowLength = 2048) { |
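  // Default plain-text chat template: user turns are terminated with " ", assistant
  // turns with "</s>", and the final separator doubles as the stop string (getStopStr).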
|
return new Conversation({ |
|
system: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Follow the user's instructions carefully. Respond using markdown.", |
|
roles: ["user", "assistant"], |
|
maxWindowLength: maxWindowLength, |
|
offset: 0, |
|
seps: [" ", "</s>"], |
|
}); |
|
}
|
|
|
class LLMChatPipeline { |
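  // Runs the model: owns the TVM virtual machine, tokenizer, parameters, and KV cache,
  // and exposes generate()/evaluate() on top of the compiled "encoding"/"decoding" functions.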
|
constructor(tvm, tokenizer, cacheMetadata, config) { |
|
if (cacheMetadata == undefined) { |
|
throw Error("Expect cacheMetadata"); |
|
} |
|
this.tvm = tvm; |
|
this.logger = console.log; |
|
this.tokenizer = tokenizer; |
|
this.bosTokenId = 1; |
|
this.eosTokenId = 2; |
|
|
|
this.maxWindowLength = config.maxWindowLength; |
|
this.maxGenLength = config.maxGenLength; |
|
this.meanGenLength = config.meanGenLength; |
|
this.streamInterval = 1; |
|
|
|
this.decodingTotalTime = 0; |
|
this.decodingTotalTokens = 0; |
|
this.encodingTotalTime = 0; |
|
this.encodingTotalTokens = 0; |
|
|
|
this.conversation = defaultConversation(this.maxWindowLength); |
|
|
|
this.device = this.tvm.webgpu(); |
|
this.vm = this.tvm.detachFromCurrentScope( |
|
this.tvm.createVirtualMachine(this.device) |
|
); |
|
this.encoding = this.tvm.detachFromCurrentScope( |
|
this.vm.getFunction("encoding") |
|
); |
|
this.decoding = this.tvm.detachFromCurrentScope( |
|
this.vm.getFunction("decoding") |
|
); |
|
this.params = this.tvm.detachFromCurrentScope( |
|
this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize) |
|
); |
|
const fcreateCache = this.vm.getFunction("create_kv_cache"); |
|
this.fclearKVCaches = this.tvm.detachFromCurrentScope( |
|
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear") |
|
); |
|
|
|
|
|
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache()); |
|
|
|
this.logitsOnCPU = undefined; |
|
|
|
this.kvCacheLength = 0; |
|
this.clearCache = true |
|
} |
|
|
|
|
|
dispose() { |
|
|
|
this.params.dispose(); |
|
this.decoding.dispose(); |
|
this.encoding.dispose(); |
|
this.vm.dispose(); |
|
this.kvCache.dispose(); |
|
this.fclearKVCaches.dispose(); |
|
if (this.logitsOnCPU != undefined) { |
|
this.logitsOnCPU.dispose(); |
|
} |
|
} |
|
|
|
#clearKVCache() { |
|
this.fclearKVCaches(this.kvCache); |
|
this.kvCacheLength = 0; |
|
} |
|
|
|
#forward(inputs, curPos) { |
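    // Single model step: call "encoding" (prefill) when the input carries more than one
    // token, "decoding" for a single token; returns the logits NDArray to the caller's scope.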
|
this.tvm.beginScope(); |
|
var retValue; |
|
const seqLenShape = this.tvm.makeShapeTuple([curPos]); |
|
if (inputs.shape[1] > 1) { |
|
retValue = this.encoding( |
|
inputs, seqLenShape, this.kvCache, this.params |
|
); |
|
} else { |
|
retValue = this.decoding( |
|
inputs, seqLenShape, this.kvCache, this.params |
|
); |
|
} |
|
const logits = this.tvm.detachFromCurrentScope(retValue.get(0)); |
|
this.tvm.endScope(); |
|
this.tvm.attachToCurrentScope(logits); |
|
return logits; |
|
} |
|
|
|
|
|
#updateLogitsOnCPU(logits) { |
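    // Lazily allocate a CPU-side buffer of the same shape and copy the GPU logits into it.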
|
if (this.logitsOnCPU == undefined) { |
|
this.logitsOnCPU = this.tvm.detachFromCurrentScope( |
|
this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()) |
|
); |
|
} else { |
|
if (logits.shape[0] != this.logitsOnCPU.shape[0]) { |
|
throw Error("We expect the size of logits to remain unchanged"); |
|
} |
|
} |
|
this.logitsOnCPU.copyFrom(logits); |
|
} |
|
|
|
async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) { |
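    // Copy the logits to the CPU, wait for the GPU queue to drain, then sample with
    // temperature + nucleus (top-p) sampling.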
|
this.tvm.beginScope(); |
|
this.#updateLogitsOnCPU(logits); |
|
this.tvm.endScope(); |
|
await this.device.sync(); |
|
return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p); |
|
} |
|
|
|
async getInputTokens() { |
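    // Build the token ids to feed the model. The first exchange encodes the full prompt
    // (system + history); later calls encode only the latest round and rely on the KV
    // cache for earlier context. If the window would overflow, fall through to the
    // shift-window path below: drop the cache and re-encode a truncated history.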
|
let tokens = [this.bosTokenId]; |
|
let prompts = "" |
|
if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) { |
|
prompts = this.conversation.getPromptArray(); |
|
} else { |
|
tokens.pop(); |
|
      prompts = this.conversation.getPromptArrayUnprocessed();
|
} |
|
tokens.push(...await this.tokenizer.encodeIds(prompts[0])); |
|
let ctxLength = tokens.length; |
|
let context = []; |
|
let need_shift_window = false; |
|
for (let i = prompts.length - 1; i > 0; --i) { |
|
      const encoded = await this.tokenizer.encodeIds(prompts[i]);
|
ctxLength += encoded.length; |
|
if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) { |
|
need_shift_window = true; |
|
break; |
|
} |
|
context.unshift(encoded); |
|
} |
|
if (!need_shift_window) { |
|
for (const ctx of context) { |
|
tokens.push(...ctx); |
|
} |
|
return tokens; |
|
} |
|
|
|
this.logger("need shift window") |
|
this.kvCacheLength = 0; |
|
this.clearCache = true; |
|
|
|
tokens = [this.bosTokenId] |
|
let all_prompts = this.conversation.getPromptArray(); |
|
tokens.push(...await this.tokenizer.encodeIds(all_prompts[0])); |
|
context = []; |
|
ctxLength = tokens.length; |
|
|
|
const fill_factor = 0.1 |
|
for (let i = all_prompts.length - 1; i > 0; --i) { |
|
      const encoded = await this.tokenizer.encodeIds(all_prompts[i]);
|
ctxLength += encoded.length; |
|
if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) { |
|
break; |
|
} |
|
context.unshift(encoded); |
|
} |
|
for (const ctx of context) { |
|
tokens.push(...ctx); |
|
} |
|
if (tokens.length + this.meanGenLength >= this.maxWindowLength) { |
|
throw Error("Exceed max window length curr=" + tokens.length); |
|
} |
|
return tokens; |
|
} |
|
|
|
resetChat() { |
|
if (this.conversation) { |
|
this.conversation.reset(); |
|
} |
|
this.#clearKVCache(); |
|
this.decodingTotalTime = 0; |
|
this.encodingTotalTime = 0; |
|
this.decodingTotalTokens = 0; |
|
this.encodingTotalTokens = 0; |
|
} |
|
|
|
async generate(inputPrompt, callbackUpdateResponse) { |
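    // Main generation loop: record the user turn plus an empty assistant turn,
    // prefill the prompt, then decode one token at a time until EOS, the stop
    // string, or the window limit, streaming partial text through the callback.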
|
|
|
    if (this.conversation.convId !== tvmjsGlobalEnv.covId) {

      // A different conversation became active in the worker: adopt its history,
      // sync the tracked id, and rebuild the KV cache before generating.
      this.conversation.switchConversation(tvmjsGlobalEnv.workerHistoryMsg);

      this.kvCacheLength = 0;

      this.clearCache = true;

    }
|
this.conversation.appendMessage(this.conversation.roles[0], inputPrompt); |
|
this.conversation.appendMessage(this.conversation.roles[1], ""); |
|
const stopStr = this.conversation.getStopStr(); |
|
const tokens = await this.getInputTokens(); |
|
const inputTokenLength = tokens.length; |
|
|
|
var outputPrompt = ""; |
|
if (this.clearCache) { |
|
this.#clearKVCache(); |
|
this.clearCache = false; |
|
} |
|
const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length); |
|
if (maxGenLen < this.meanGenLength) { |
|
throw Error("Too small window size config"); |
|
} |
|
let step = 0; |
|
for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) { |
|
this.tvm.beginScope(); |
|
var inputData; |
|
|
|
let tstart = performance.now(); |
|
if (step == 0) { |
|
inputData = this.tvm.empty([1, tokens.length], "int32", this.device); |
|
inputData.copyFrom(tokens); |
|
} else { |
|
inputData = this.tvm.empty([1, 1], "int32", this.device); |
|
inputData.copyFrom(tokens.slice(tokens.length - 1)); |
|
} |
|
const logits = this.tvm.detachFromCurrentScope( |
|
this.#forward(inputData, this.kvCacheLength + inputTokenLength + step) |
|
); |
|
this.tvm.endScope(); |
|
|
|
const nextToken = await this.sampleTokenFromLogits(logits); |
|
logits.dispose(); |
|
|
|
tokens.push(nextToken); |
|
const outputTokens = tokens.slice(inputTokenLength); |
|
outputPrompt = this.tokenizer.decodeIds(outputTokens); |
|
|
|
if (nextToken == this.eosTokenId) break; |
|
|
|
const stopPos = outputPrompt.lastIndexOf(stopStr); |
|
if (stopPos != -1) { |
|
outputPrompt = outputPrompt.substring(0, stopPos); |
|
break; |
|
} |
|
let tend = performance.now(); |
|
if (step != 0) { |
|
this.decodingTotalTokens += 1; |
|
this.decodingTotalTime += (tend - tstart) / 1000; |
|
} else { |
|
this.encodingTotalTime += (tend - tstart) / 1000; |
|
this.encodingTotalTokens += inputTokenLength; |
|
} |
|
|
|
if (step % this.streamInterval == 0) { |
|
callbackUpdateResponse(step, outputPrompt); |
|
} |
|
} |
|
this.kvCacheLength += tokens.length - 1; |
|
tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt; |
|
return outputPrompt; |
|
} |
|
|
|
async evaluate() { |
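    // Debug helper: run one prefill and one decode step on a fixed prompt and log
    // the raw logits together with rough timings.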
|
|
|
this.#clearKVCache(); |
|
const testPrompt = "The capital of Canada is"; |
|
const ids = await this.tokenizer.encodeIds(testPrompt); |
|
const inputPromptSize = ids.length; |
|
const tokens = Array.from(ids); |
|
tokens.unshift(this.bosTokenId); |
|
if (tokens.length == 0) { |
|
throw Error("empty token"); |
|
} |
|
|
|
this.tvm.beginScope(); |
|
const inputData = this.tvm.empty([1, tokens.length], "int32", this.device); |
|
inputData.copyFrom(tokens); |
|
const encodingStart = performance.now(); |
|
this.#forward(inputData, tokens.length); |
|
this.tvm.endScope(); |
|
await this.device.sync(); |
|
|
|
const decodingStart = performance.now(); |
|
|
|
this.tvm.beginScope(); |
|
const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]); |
|
this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1)); |
|
await this.device.sync(); |
|
this.tvm.endScope(); |
|
|
|
const decodingEnd = performance.now(); |
|
const msg = ( |
|
      `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec, ` +

      `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
|
); |
|
|
|
|
|
console.log("Logits:"); |
|
console.log(this.logitsOnCPU.toArray()); |
|
console.log(msg); |
|
} |
|
|
|
|
|
|
|
|
|
async asyncLoadWebGPUPiplines() { |
|
await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule()); |
|
} |
|
|
|
runtimeStatsText() { |
|
return ( |
|
`encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` + |
|
`decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec` |
|
) |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
class LLMChatInstance { |
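  // Worker-side entry point: lazily loads the config, wasm module, WebGPU device,
  // weights, and tokenizer, then serves generate/reset requests posted from the page.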
|
constructor() { |
|
this.requestInProgress = false; |
|
this.config = undefined; |
|
this.tvm = undefined; |
|
this.pipeline = undefined; |
|
this.logger = console.log; |
|
this.debugTest = false; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
async #asyncInitTVM(wasmUrl, cacheUrl) { |
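    // Fetch and instantiate the model wasm, attach a WebGPU device (or fail loudly),
    // register a progress callback, and download the weight shards into the NDArray cache.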
|
if (this.tvm !== undefined) { |
|
return; |
|
} |
|
this.logger = console.log; |
|
|
|
const wasmSource = await ( |
|
await fetch(wasmUrl) |
|
).arrayBuffer(); |
|
const tvm = await tvmjs.instantiate( |
|
new Uint8Array(wasmSource), |
|
new EmccWASI(), |
|
this.logger |
|
); |
|
|
|
try { |
|
const output = await tvmjs.detectGPUDevice(); |
|
if (output !== undefined) { |
|
var label = "WebGPU"; |
|
if (output.adapterInfo.description.length != 0) { |
|
label += " - " + output.adapterInfo.description; |
|
} else { |
|
label += " - " + output.adapterInfo.vendor; |
|
} |
|
this.appendMessage("init", "Initialize GPU device: " + label); |
|
tvm.initWebGPU(output.device); |
|
} else { |
|
this.appendMessage("error", "This browser env do not support WebGPU"); |
|
this.reset(); |
|
throw Error("This browser env do not support WebGPU"); |
|
} |
|
} catch (err) { |
|
this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString()); |
|
console.log(err); |
|
this.reset(); |
|
throw Error("Find an error initializing WebGPU: " + err.toString()); |
|
} |
|
this.tvm = tvm; |
|
const initProgressCallback = (report) => { |
|
this.updateLastMessage("initing", report.text); |
|
} |
|
tvm.registerInitProgressCallback(initProgressCallback); |
|
|
|
await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu()); |
|
} |
|
|
|
|
|
|
|
async asyncInit() { |
|
if (this.pipeline !== undefined) return; |
|
await this.#asyncInitConfig(); |
|
await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl); |
|
await this.#asyncInitPipeline(); |
|
} |
|
|
|
|
|
|
|
|
|
async #asyncInitConfig() { |
|
if (this.config !== undefined) return; |
|
this.config = await (await fetch("/lib/WebLLM/config.json")).json(); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
async #asyncInitPipeline() { |
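    // Load the SentencePiece tokenizer, build the pipeline inside a fresh TVM scope,
    // and compile the WebGPU compute pipelines before first use.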
|
if (this.pipeline !== undefined) return; |
|
|
|
const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer); |
|
this.pipeline = this.tvm.withNewScope(() => { |
|
return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config); |
|
}); |
|
await this.pipeline.asyncLoadWebGPUPiplines(); |
|
this.appendMessage("initing", "All initialization finished.", true); |
|
} |
|
|
|
appendMessage(kind, text, ifFinish) { |
|
if (kind == "initing") { |
|
text = "[System Initalize] " + text; |
|
} |
|
console.log(`[${kind}] ${text}`); |
|
globalThis.postMessage({ |
|
type: 'initing', |
|
action: 'append', |
|
msg: text, |
|
ifError: kind == 'error', |
|
ifFinish: !!ifFinish |
|
}) |
|
} |
|
|
|
updateLastMessage(type, text, ifFinish) { |
|
if (type == "initing") { |
|
      text = `[System Initialize] ${text}`;
|
} |
|
globalThis.postMessage({ |
|
type, |
|
action: 'updateLast', |
|
msg: text, |
|
ifFinish: !!ifFinish |
|
}) |
|
} |
|
|
|
async respondTestMessage(repeat) { |
|
const testMessage = "I am a friendly bot. Please ask questions."; |
|
const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage); |
|
|
|
const currentIds = []; |
|
for (let k = 0; k < repeat; ++k) { |
|
for (let i = 0; i < encodedResult.length; ++i) { |
|
currentIds.push(encodedResult[i]); |
|
const msg = this.pipeline.tokenizer.decodeIds(currentIds); |
|
this.updateLastMessage("chatting", msg); |
|
await new Promise(resolve => setTimeout(resolve, 50)); |
|
} |
|
} |
|
} |
|
|
|
resetChat() { |
|
if (this.pipeline) { |
|
this.pipeline.resetChat(); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
async generate() { |
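    // Worker request handler: guard against reentrancy, make sure everything is
    // initialized, then either run the debug evaluation or stream a chat response.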
|
if (this.requestInProgress) { |
|
return; |
|
} |
|
|
|
this.requestInProgress = true; |
|
|
|
try { |
|
await this.asyncInit(); |
|
} catch (err) { |
|
this.appendMessage("error", "Init error, " + err.toString()); |
|
console.log(err); |
|
this.reset(); |
|
this.requestInProgress = false; |
|
return; |
|
} |
|
|
|
if (this.debugTest) { |
|
await this.pipeline.evaluate(); |
|
this.requestInProgress = false; |
|
return; |
|
} |
|
|
|
const prompt = tvmjsGlobalEnv.message; |
|
if (prompt == "") { |
|
this.requestInProgress = false; |
|
return; |
|
} |
|
|
|
const callbackUpdateResponse = (step, msg) => { |
|
if (msg.endsWith("##")) { |
|
msg = msg.substring(0, msg.length - 2); |
|
} else if (msg.endsWith("#")) { |
|
msg = msg.substring(0, msg.length - 1); |
|
} |
|
this.updateLastMessage("chatting", msg); |
|
}; |
|
try { |
|
const output = await this.pipeline.generate(prompt, callbackUpdateResponse); |
|
this.updateLastMessage("chatting", output, true); |
|
this.updateLastMessage("stats",this.pipeline.runtimeStatsText()) |
|
console.log(this.pipeline.runtimeStatsText()); |
|
} catch (err) { |
|
this.appendMessage("error", "Generate error, " + err.toString()); |
|
console.log(err); |
|
this.reset(); |
|
} |
|
this.requestInProgress = false; |
|
} |
|
|
|
|
|
|
|
|
|
reset() { |
|
this.tvm = undefined; |
|
if (this.pipeline !== undefined) { |
|
this.pipeline.dispose(); |
|
} |
|
this.pipeline = undefined; |
|
} |
|
} |
|
|
|
const localLLMChatInstance = new LLMChatInstance();
|
|
|
tvmjsGlobalEnv.asyncOnGenerate = async function () { |
|
  await localLLMChatInstance.generate();
|
}; |
|
|
|
tvmjsGlobalEnv.asyncOnReset = async function () { |
|
  await localLLMChatInstance.resetChat();
|
}; |
|
|