Commit a2b7758 (parent: b637f0b)
bug fixes: logging of model loading
app.py CHANGED

@@ -17,7 +17,7 @@ from models import (
 
 model_labels_list = list(model_labels)
 
-#
+# load and warm up (compile) all the models
 models = []
 for preset in model_presets:
     model = load_model(preset)

@@ -32,7 +32,7 @@ for preset in model_presets:
 # model = keras_hub.models.Llama3CausalLM.from_preset(
 #     "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
 # )
-# models = [model, model]
+# models = [model, model, model, model, model]
 
 
 def chat_turn_assistant_1(

@@ -170,7 +170,7 @@ with gr.Blocks(fill_width=True, title="Keras demo") as demo:
     gr.HTML(
         "<H2> Battle of the Keras chatbots on TPU</H2>"
         + "All the models are loaded into the TPU memory. "
-        + "You can call
+        + "You can call any of them and compare their answers. <br/>"
         + "The entire chat history is fed to the models at every submission."
         + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
     )
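The comment added in the first hunk documents the startup pattern: every preset is loaded and warmed up before the Gradio UI starts serving. The warm-up itself happens outside this hunk, so the sketch below is an assumption rather than the Space's actual code: with keras_hub causal LMs on the JAX backend, the first generate() call triggers XLA compilation, so one throwaway call at startup keeps the first real chat turn from paying the compile cost.

# Sketch only; model_presets and load_model come from this repo, but the
# generate() warm-up call is assumed, not shown in the diff.
models = []
for preset in model_presets:
    model = load_model(preset)
    model.generate("warmup", max_length=16)  # compile once, off the request path
    models.append(model)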
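The third hunk is the actual crash fix in this commit: the removed line ended mid-literal (+ "You can call), an unterminated string and therefore a SyntaxError, which prevents app.py from being imported at all. The surrounding lines show the working pattern, adjacent string literals joined with +; each piece must be closed before the next + begins. An illustrative snippet mirroring the fixed code (the banner name is hypothetical):

# Each literal must be terminated; the + operators join them when the
# expression is evaluated.
banner = (
    "<H2> Battle of the Keras chatbots on TPU</H2>"
    + "All the models are loaded into the TPU memory. "
    + "You can call any of them and compare their answers. <br/>"
)
# A piece left open, e.g.  + "You can call  is a SyntaxError before anything runs.

(The typos in "This demno is runnig …" are in the committed string and survive this commit.)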
models.py CHANGED

@@ -39,24 +39,25 @@ def get_default_layout_map(preset_name, device_mesh):
 
 
 def log_applied_layout_map(model):
-    if "Gemma" in type(model):
+    if "Gemma" in type(model).__name__:
         transformer_decoder_block_name = "decoder_block_1"
-    elif "Llama3" in type(model) or "Mistral" in type(model):
+    elif "Llama3" in type(model).__name__ or "Mistral" in type(model).__name__:
         transformer_decoder_block_name = "transformer_layer_1"
     else:
         assert (0, "Model type not recognized. Cannot display model layout.")
-
-
-
-
-
-
-
-
-
-
-
-
+
+    # See how layer sharding was applied
+    embedding_layer = model.backbone.get_layer("token_embedding")
+    print(embedding_layer)
+    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
+    print(type(decoder_block))
+    for variable in embedding_layer.weights + decoder_block.weights:
+        print(
+            f"{variable.path:<58} \
+            {str(variable.shape):<16} \
+            {str(variable.value.sharding.spec):<35} \
+            {str(variable.dtype)}"
+        )
 
 
 def load_model(preset):
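This hunk is the fix named in the commit message: type(model) is a class object, and Python's in operator cannot iterate a class, so the old condition raised TypeError as soon as a freshly loaded model was logged. type(model).__name__ is a plain string, which supports substring tests. A minimal repro (the GemmaCausalLM stand-in below is hypothetical; keras_hub's real classes behave the same way):

class GemmaCausalLM:  # hypothetical stand-in for illustration
    pass

model = GemmaCausalLM()

# Old code: "Gemma" in type(model)
#   -> TypeError: argument of type 'type' is not iterable
print(type(model).__name__)             # GemmaCausalLM
print("Gemma" in type(model).__name__)  # True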
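Two details of the new logging block are worth noting. The backslash continuations sit inside the f-string, so the leading indentation of each continued line becomes literal spaces in the printed record; the :<58, :<16 and :<35 field widths are what actually align the path, shape, sharding-spec and dtype columns. And the untouched fallback branch asserts a tuple: a non-empty tuple is always truthy, so that assert can never fire (CPython warns "assertion is always true, perhaps remove parentheses?"). A small sketch of the difference:

# Always passes: the asserted expression is the two-element tuple itself.
assert (0, "Model type not recognized. Cannot display model layout.")

# Condition and message separated: raises as presumably intended.
try:
    assert 0, "Model type not recognized. Cannot display model layout."
except AssertionError as e:
    print(e)  # Model type not recognized. Cannot display model layout.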