jaeikkim committed on
Commit
88f06d8
·
1 Parent(s): c0c8614

ti2ti model

Browse files
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED
@@ -1079,8 +1079,10 @@ class OmadaDemo:
1079
  if not prompt or not prompt.strip():
1080
  return None, "Please provide a text prompt."
1081
 
 
 
1082
  image_tokens = torch.full(
1083
- (1, self.image_seq_len),
1084
  self.mask_token_id,
1085
  dtype=torch.long,
1086
  device=self.device,
@@ -1107,7 +1109,7 @@ class OmadaDemo:
1107
  temperature=float(temperature),
1108
  timesteps=int(timesteps),
1109
  noise_schedule=self.image_noise_schedule,
1110
- seq_len=self.image_seq_len,
1111
  mask_token_id=self.mask_token_id,
1112
  codebook_size=self.codebook_size,
1113
  uni_prompting=self.uni_prompting,
@@ -2056,7 +2058,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
2056
  label="Sub-mode",
2057
  )
2058
  with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings:
2059
- t2i_timesteps = gr.Slider(4, 128, value=32, label="Timesteps", step=2)
2060
  t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
2061
  t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
2062
  with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings:
 
1079
  if not prompt or not prompt.strip():
1080
  return None, "Please provide a text prompt."
1081
 
1082
+ image_seq_len = 729
1083
+
1084
  image_tokens = torch.full(
1085
+ (1, image_seq_len),
1086
  self.mask_token_id,
1087
  dtype=torch.long,
1088
  device=self.device,
 
1109
  temperature=float(temperature),
1110
  timesteps=int(timesteps),
1111
  noise_schedule=self.image_noise_schedule,
1112
+ seq_len=image_seq_len,
1113
  mask_token_id=self.mask_token_id,
1114
  codebook_size=self.codebook_size,
1115
  uni_prompting=self.uni_prompting,
 
2058
  label="Sub-mode",
2059
  )
2060
  with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings:
2061
+ t2i_timesteps = gr.Slider(4, 128, value=64, label="Timesteps", step=2)
2062
  t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
2063
  t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
2064
  with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings:
app.py CHANGED
@@ -123,6 +123,17 @@ def download_style() -> Path:
123
 
124
  def download_checkpoint() -> Path:
125
  """Download checkpoint snapshot and return an `unwrapped_model` directory."""
 
 
 
 
 
 
 
 
 
 
 
126
  repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
127
  revision = os.getenv("MODEL_REVISION", "main")
128
  token = os.getenv("HF_TOKEN")
@@ -779,4 +790,4 @@ with gr.Blocks(
779
 
780
 
781
  if __name__ == "__main__":
782
- demo.launch()
 
123
 
124
  def download_checkpoint() -> Path:
125
  """Download checkpoint snapshot and return an `unwrapped_model` directory."""
126
+ local_override = os.getenv("MODEL_CHECKPOINT_PATH")
127
+ if local_override:
128
+ override_path = Path(local_override).expanduser()
129
+ if override_path.name != "unwrapped_model":
130
+ nested = override_path / "unwrapped_model"
131
+ if nested.is_dir():
132
+ override_path = nested
133
+ if not override_path.exists():
134
+ raise FileNotFoundError(f"MODEL_CHECKPOINT_PATH does not exist: {override_path}")
135
+ return override_path
136
+
137
  repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
138
  revision = os.getenv("MODEL_REVISION", "main")
139
  token = os.getenv("HF_TOKEN")
 
790
 
791
 
792
  if __name__ == "__main__":
793
+ demo.launch()