Spaces:
Running
on
Zero
Running
on
Zero
ti2ti model
Browse files
- MMaDA/inference/gradio_multimodal_demo_inst.py +5 -3
- app.py +12 -1
MMaDA/inference/gradio_multimodal_demo_inst.py
CHANGED
|
@@ -1079,8 +1079,10 @@ class OmadaDemo:
|
|
| 1079 |
if not prompt or not prompt.strip():
|
| 1080 |
return None, "Please provide a text prompt."
|
| 1081 |
|
|
|
|
|
|
|
| 1082 |
image_tokens = torch.full(
|
| 1083 |
-
(1,
|
| 1084 |
self.mask_token_id,
|
| 1085 |
dtype=torch.long,
|
| 1086 |
device=self.device,
|
|
@@ -1107,7 +1109,7 @@ class OmadaDemo:
|
|
| 1107 |
temperature=float(temperature),
|
| 1108 |
timesteps=int(timesteps),
|
| 1109 |
noise_schedule=self.image_noise_schedule,
|
| 1110 |
-
seq_len=
|
| 1111 |
mask_token_id=self.mask_token_id,
|
| 1112 |
codebook_size=self.codebook_size,
|
| 1113 |
uni_prompting=self.uni_prompting,
|
|
@@ -2056,7 +2058,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
|
|
| 2056 |
label="Sub-mode",
|
| 2057 |
)
|
| 2058 |
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings:
|
| 2059 |
-
t2i_timesteps = gr.Slider(4, 128, value=
|
| 2060 |
t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
|
| 2061 |
t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
|
| 2062 |
with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings:
|
|
|
|
| 1079 |
if not prompt or not prompt.strip():
|
| 1080 |
return None, "Please provide a text prompt."
|
| 1081 |
|
| 1082 |
+
image_seq_len = 729
|
| 1083 |
+
|
| 1084 |
image_tokens = torch.full(
|
| 1085 |
+
(1, image_seq_len),
|
| 1086 |
self.mask_token_id,
|
| 1087 |
dtype=torch.long,
|
| 1088 |
device=self.device,
|
|
|
|
| 1109 |
temperature=float(temperature),
|
| 1110 |
timesteps=int(timesteps),
|
| 1111 |
noise_schedule=self.image_noise_schedule,
|
| 1112 |
+
seq_len=image_seq_len,
|
| 1113 |
mask_token_id=self.mask_token_id,
|
| 1114 |
codebook_size=self.codebook_size,
|
| 1115 |
uni_prompting=self.uni_prompting,
|
|
|
|
| 2058 |
label="Sub-mode",
|
| 2059 |
)
|
| 2060 |
with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings:
|
| 2061 |
+
t2i_timesteps = gr.Slider(4, 128, value=64, label="Timesteps", step=2)
|
| 2062 |
t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
|
| 2063 |
t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
|
| 2064 |
with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings:
|
app.py
CHANGED
|
@@ -123,6 +123,17 @@ def download_style() -> Path:
|
|
| 123 |
|
| 124 |
def download_checkpoint() -> Path:
|
| 125 |
"""Download checkpoint snapshot and return an `unwrapped_model` directory."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
|
| 127 |
revision = os.getenv("MODEL_REVISION", "main")
|
| 128 |
token = os.getenv("HF_TOKEN")
|
|
@@ -779,4 +790,4 @@ with gr.Blocks(
|
|
| 779 |
|
| 780 |
|
| 781 |
if __name__ == "__main__":
|
| 782 |
-
demo.launch()
|
|
|
|
| 123 |
|
| 124 |
def download_checkpoint() -> Path:
|
| 125 |
"""Download checkpoint snapshot and return an `unwrapped_model` directory."""
|
| 126 |
+
local_override = os.getenv("MODEL_CHECKPOINT_PATH")
|
| 127 |
+
if local_override:
|
| 128 |
+
override_path = Path(local_override).expanduser()
|
| 129 |
+
if override_path.name != "unwrapped_model":
|
| 130 |
+
nested = override_path / "unwrapped_model"
|
| 131 |
+
if nested.is_dir():
|
| 132 |
+
override_path = nested
|
| 133 |
+
if not override_path.exists():
|
| 134 |
+
raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}")
|
| 135 |
+
return override_path
|
| 136 |
+
|
| 137 |
repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
|
| 138 |
revision = os.getenv("MODEL_REVISION", "main")
|
| 139 |
token = os.getenv("HF_TOKEN")
|
|
|
|
| 790 |
|
| 791 |
|
| 792 |
if __name__ == "__main__":
|
| 793 |
+
demo.launch()
|