Is CUDA required? Could it work on macOS?
I tried to use Starflow on macOS and got these errors:
bash scripts/setup_conda.sh
Setting up Starflow development environment with conda...
Activating conda environment 'starflow'...
Installing PyTorch and related packages via conda...
Retrieving notices: ...working... done
Channels:
- pytorch
- nvidia
- defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): done
Solving environment: failed
LibMambaUnsatisfiableError: Encountered problems while solving:
- nothing provides cuda 11.6.* needed by pytorch-cuda-11.6-h867d48c_0
Could not solve for environment specs
The following packages are incompatible
└─ pytorch-cuda is not installable because there are no viable options
   ├─ pytorch-cuda 11.6 would require
   │  └─ cuda 11.6.* , which does not exist (perhaps a missing channel);
   ├─ pytorch-cuda 11.7 would require
   │  └─ cuda 11.7.* , which does not exist (perhaps a missing channel);
   └─ pytorch-cuda 11.8 would require
      └─ cuda 11.8.* , which does not exist (perhaps a missing channel).
I already tried it, and it can work on a Mac.
Article (Japanese): https://qiita.com/syun88/items/20cf758344825e9bea5e
Clone the repository:

git clone https://github.com/apple/ml-starflow
cd ml-starflow

Use Conda (recommended):

bash scripts/setup_conda.sh

Or install dependencies manually:

pip install -r requirements.txt

Potential error during installation. You might encounter this:

ERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; ...
ERROR: Could not find a version that satisfies the requirement decord (from versions: none)
ERROR: No matching distribution found for decord

Solution for macOS:
- Use Python 3.10 for compatibility.
- Replace the original decord with eva-decord (see the quick check after this list).
- Run the model on CPU/CUDA (not MPS). MPS can cause black images and poor performance, so I excluded it.
- Modify sample.py to choose CUDA if available, otherwise fall back to CPU.
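To verify the decord swap, here is a quick check, assuming eva-decord installs under the same decord import name (that is its usual drop-in behavior):

# Sanity check: eva-decord should be importable as plain `decord`.
import decord
print(decord.__version__)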
Updated requirements.txt:
jupyter>=1.0.0
transformers
accelerate
torchinfo
einops
scipy
webdataset
sentencepiece
wandb[media]
torchmetrics[image]
simple_parsing
eva-decord
opencv-python
psutil
git+https://github.com/KeKsBoTer/torch-dwt
git+https://github.com/huggingface/diffusers.git
pyyaml
av==12.3.0
Code modifications in sample.py:
@@ -14,6 +14,7 @@ Usage:
"""
import argparse
+import contextlib
import copy
import pathlib
import time
@@ -48,13 +49,17 @@ DEFAULT_CAPTIONS = {
'template5': "A realistic selfie of a llama standing in front of a classic Ivy League building on the Princeton University campus. He is smiling gently, wearing his iconic wild hair and mustache, dressed in a wool sweater and collared shirt. The photo has a vintage, slightly sepia tone, with soft natural lighting and leafy trees in the background, capturing an academic and historical vibe.",
}
-
+def resolve_device() -> torch.device:
+ """Choose the best available device: CUDA -> CPU (explicitly disable MPS)."""
+ if torch.cuda.is_available():
+ return torch.device("cuda")
+ return torch.device("cpu")
def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Module, Optional[torch.nn.Module], tuple]:
"""Initialize and load the model, VAE, and text encoder."""
dist = utils.Distributed()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ device = resolve_device()
# Set random seed
utils.set_random_seed(args.seed + dist.rank)
@@ -81,7 +86,9 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
print(f"Loading checkpoint from local path: {args.checkpoint_path}")
state_dict = torch.load(args.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict, strict=False)
- del state_dict; torch.cuda.empty_cache()
+ del state_dict
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
# Set model to eval mode and disable gradients
for p in model.parameters():
@@ -190,6 +197,9 @@ def main(args: argparse.Namespace) -> None:
trainer_dict = vars(trainer_args)
trainer_dict.update(vars(args))
args = argparse.Namespace(**trainer_dict)
+ device = resolve_device()
+ if device.type != "cuda":
+ args.fsdp = 0 # CPU/MPS fallback
# Handle target length configuration for video
if args.target_length is not None:
@@ -205,7 +215,8 @@ def main(args: argparse.Namespace) -> None:
args.context_length = args.local_attn_window - 1
# Override some settings for sampling
- args.fsdp = 1 # sampling using FSDP if available.
+ if device.type == "cuda":
+ args.fsdp = 1 # sampling using FSDP if available.
if args.use_pretrained_lm is not None:
args.text = args.use_pretrained_lm
@@ -223,19 +234,24 @@ def main(args: argparse.Namespace) -> None:
# Prepare captions and sampling parameters
fixed_y, fixed_idxs, num_samples, caption_name = prepare_captions(args, dist)
- print(f'Sampling {num_samples} from {args.caption} on {dist.world_size} GPU(s)')
+ print(f'Sampling {num_samples} from {args.caption} on {dist.world_size} device(s) [{device.type}]')
get_noise = get_noise_shape(args, vae)
sampling_kwargs = build_sampling_kwargs(args, caption_name)
noise_std = args.target_noise_std if args.target_noise_std else args.noise_std
# Start sampling
- print(f'Starting sampling with global batch size {args.sample_batch_size}x{dist.world_size} GPUs')
- torch.cuda.synchronize()
+ print(f'Starting sampling with global batch size {args.sample_batch_size}x{dist.world_size} devices')
+ if device.type == "cuda":
+ torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
- with torch.autocast(device_type='cuda', dtype=torch.float32):
+ if device.type == "cuda":
+ autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float32)
+ else:
+ autocast_ctx = contextlib.nullcontext()
+ with autocast_ctx:
for i in tqdm.tqdm(range(int(np.ceil(num_samples / (args.sample_batch_size * dist.world_size))))):
# Determine aspect ratio and image shape
x_aspect = args.aspect_ratio if args.mix_aspect else None
@@ -290,7 +306,9 @@ def main(args: argparse.Namespace) -> None:
# Generate samples
samples = model(noise, y, reverse=True, kv_caches=kv_caches, **sampling_kwargs)
- del kv_caches; torch.cuda.empty_cache() # free up memory
+ del kv_caches
+ if device.type == "cuda":
+ torch.cuda.empty_cache() # free up memory
# Apply denoising if enabled
samples = process_denoising(
@@ -330,7 +348,8 @@ def main(args: argparse.Namespace) -> None:
)
# Print timing statistics
- torch.cuda.synchronize()
+ if device.type == "cuda":
+ torch.cuda.synchronize()
elapsed_time = time.time() - start_time
print(f'{model_name} cfg {args.cfg:.2f}, bsz={args.sample_batch_size}x{dist.world_size}, '
f'time={elapsed_time:.2f}s, speed={num_samples / elapsed_time:.2f} images/s')
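Why the autocast change matters: torch.autocast(device_type='cuda') does not raise when CUDA is missing, it only warns and disables itself (you can see the UserWarning in the log further down). A standalone sketch of the fallback pattern, assuming the same device resolution as above:

import contextlib
import torch

def autocast_or_null(device: torch.device):
    """CUDA autocast when possible, otherwise a do-nothing context."""
    if device.type == "cuda":
        return torch.autocast(device_type="cuda", dtype=torch.float32)
    return contextlib.nullcontext()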
Code modifications in dataset.py:
@@
-# Initialize multiprocessing manager
-manager = torch.multiprocessing.Manager()
+# Lazy multiprocessing manager; creating at import breaks spawn on macOS
+manager = None
+
+
+def get_mp_manager():
+ """Create or return a global multiprocessing Manager."""
+ global manager
+ if manager is None:
+ manager = torch.multiprocessing.Manager()
+ return manager
@@
class OnlineImageTarDataset(ImageTarDataset):
max_retry_n = 20
max_read = 4096
- tar_keys_lock = manager.Lock() if manager is not None else None
def __init__(self, dataset_tsv, image_size, batch_size=None, **kwargs):
super().__init__(dataset_tsv, image_size, **kwargs)
+
+ mgr = get_mp_manager()
+ self.tar_keys_lock = mgr.Lock() if mgr is not None else None
@@
for key in self.tar_lists.keys():
repeat = int(self.weights.get(key, 1))
self.reset_tar_keys.extend([key] * repeat)
- self.tar_keys = manager.list(self.reset_tar_keys) if manager is not None else list(self.reset_tar_keys)
+ self.tar_keys = mgr.list(self.reset_tar_keys) if mgr is not None else list(self.reset_tar_keys)
@@
def _get_next_key(self):
- with self.tar_keys_lock:
+ lock = self.tar_keys_lock
+ if lock:
+ with lock:
+ if not self.tar_keys or len(self.tar_keys) == 0:
+ xprint(f'[WARN] all dataset exhausted... this should not happen usually')
+ self.tar_keys.extend(list(self.reset_tar_keys)) # reset
+ random.shuffle(self.tar_keys)
+ return self.tar_keys.pop(0) # remove and return the first key
+ else:
if not self.tar_keys or len(self.tar_keys) == 0:
xprint(f'[WARN] all dataset exhausted... this should not happen usually')
self.tar_keys.extend(list(self.reset_tar_keys)) # reset
random.shuffle(self.tar_keys)
- return self.tar_keys.pop(0) # remove and return the first key
+ return self.tar_keys.pop(0)
@@
# shuffle the image list
random.shuffle(self.tar_lists[key]) # shuffle the list
- with self.tar_keys_lock:
- self.tar_keys.append(key) # return the key to the list so other workers can use it
+ if self.tar_keys_lock:
+ with self.tar_keys_lock:
+ self.tar_keys.append(key) # return the key to the list so other workers can use it
+ else:
+ self.tar_keys.append(key)
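For context, the reason the manager must be lazy: macOS defaults to the 'spawn' start method, so every DataLoader worker re-imports dataset.py, and a Manager created at import time would try to start its own server process in each worker. A minimal standalone sketch of the pattern (names here are mine):

import torch.multiprocessing as mp

_manager = None

def get_mp_manager():
    """Create the shared Manager on first use, never at import time."""
    global _manager
    if _manager is None:
        _manager = mp.Manager()
    return _manager

if __name__ == "__main__":  # entry-point guard required under 'spawn'
    with get_mp_manager().Lock():
        print("manager ready")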
Code modifications in utils/inference.py:
- torch.cuda.empty_cache()
+ # Track the original device (CUDA/MPS/CPU) and clear cache when supported
+ device = samples.device
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.backends.mps.is_available():
+ torch.mps.empty_cache()
@@
- x_all = torch.clone(samples[j * db : (j + 1) * db]).detach().cuda()
+ x_all = torch.clone(samples[j * db : (j + 1) * db]).detach().to(device)
@@
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif torch.backends.mps.is_available():
+ torch.mps.empty_cache()
@@
- return torch.cat(denoised_samples, dim=0).cuda()
+ return torch.cat(denoised_samples, dim=0).to(device)
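Since this CUDA/MPS guard repeats in several places, it could be factored into one helper; a sketch (the function name is mine, not from the repo):

import torch

def empty_cache(device: torch.device) -> None:
    """Clear the caching allocator on backends that have one."""
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and torch.backends.mps.is_available():
        torch.mps.empty_cache()
    # CPU has no allocator cache to clear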
Code modifications in utils/training.py:
- if os.environ.get('MASTER_PORT'): # When running with torchrun
+ use_cuda = torch.cuda.is_available()
+ if os.environ.get('MASTER_PORT'): # When running with torchrun
@@
- torch.distributed.init_process_group(
- backend='nccl',
+ backend = 'nccl' if use_cuda else 'gloo'
+ torch.distributed.init_process_group(
+ backend=backend,
@@
- torch.cuda.set_device(self.local_rank)
+ if use_cuda:
+ torch.cuda.set_device(self.local_rank)
@@
def parallelize_model(args, model: nn.Module, dist: Distributed, device='cuda', block_names=['AttentionBlock']) -> nn.Module:
+ device_type = device.type if hasattr(device, "type") else str(device)
+
+ # FSDP/DP only make sense on CUDA
+ if (not torch.cuda.is_available()) or (device_type != 'cuda'):
+ args.fsdp = 0
+
+ requires_grad_exists = any(p.requires_grad for p in model.parameters())
@@
- if dist.distributed:
+ if dist.distributed and requires_grad_exists and device_type == 'cuda':
print(f"Using DDP")
- model_ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[dist.local_rank])
+ ddp_kwargs = {"device_ids": [dist.local_rank]} if device_type == 'cuda' else {"device_ids": None}
+ model_ddp = torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
@@
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
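The backend switch in isolation: NCCL only works with CUDA GPUs, while Gloo runs on CPU and is the practical choice on macOS. A minimal sketch, assuming the usual torchrun environment variables (RANK, WORLD_SIZE, LOCAL_RANK):

import os
import torch
import torch.distributed as dist

def init_distributed() -> None:
    """Initialize the process group with a backend this machine supports."""
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)  # reads rank/world size from env
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))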
Code modifications in transformer_flow.py (the full jacobi method after the changes):
def jacobi(self,
z: torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
rope=None,
kv_cache=None,
verbose=False,
jacobi_block_size: int = 32,
jacobi_max_iter: int = 32,
jacobi_th: float = 0.001,
context_length: int = None,
**unused_kwargs) -> torch.Tensor:
assert self.use_sos, "Jacobi iteration requires SOS token to be used"
assert self.pos_embed is None, "Jacobi iteration does not support positional embedding"
# Ensure sampling tensors are in float32 for numerical stability
original_dtype = z.dtype
z = z.float()
freqs_cis = self.get_freqs_cis(z, y, rope) if rope is not None else None
if guidance > 0:
z = torch.cat([z, z], 0)
# kv cache
reuse_kv_cache = kv_cache.prefix_cache is not None and kv_cache.kv_index[0] > 0
kv_cache = self.initialize_kv_cache(kv_cache, z, freqs_cis, reuse_kv_cache)
video_length = z.size(1) if z.dim() == 5 else 1
# permute the input
z = self.permutation(z)
# prepare input
x_full = torch.cat([self.get_sos_embed(z), z.clone()], dim=1)
if reuse_kv_cache:
x_full[:, 1: kv_cache.prefix_cache.size(1) + 1] = kv_cache.prefix_cache # fill the prefix cache
# conditioning
if self.txt_dim > 0:
if not reuse_kv_cache:
self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
txt_size = y.size(1) if self.txt_dim > 0 else 0
video_frame_size = z.size(1) // video_length
start_idx = 0
if reuse_kv_cache:
start_idx = kv_cache.kv_index[0] - txt_size # start from the last cached index
prog_bar = tqdm.tqdm(total=z.size(1), disable=not verbose, desc='Block-wise Jacobi Iteration', leave=False)
prog_bar.update(start_idx)
local_attn_window = self.local_attn_window * video_frame_size if self.local_attn_window is not None else None
target_frame_size = z.size(1) if local_attn_window is None else min(z.size(1), local_attn_window)
context_size = None if local_attn_window is None else context_length * video_frame_size
while target_frame_size <= z.size(1):
while start_idx < target_frame_size:
chunk_size = jacobi_block_size if start_idx <= video_frame_size else jacobi_block_size * 4
local_done = torch.zeros((), dtype=torch.bool, device=x_full.device)
for i in tqdm.tqdm(range(jacobi_max_iter), disable=True, desc='Jacobi Iteration', leave=False):
if start_idx + chunk_size >= target_frame_size:
chunk_size = target_frame_size - start_idx
if i == 0 and start_idx > video_frame_size: # optional to use past frame to initialize the current frame
x = x_full[:, start_idx - video_frame_size: start_idx + chunk_size - video_frame_size]
else:
x = x_full[:, start_idx: start_idx + chunk_size]
# main forward - convert to model dtype for neural network computation
if hasattr(self.proj_in, 'weight'):
target_dtype = self.proj_in.weight.dtype
x = x.to(target_dtype)
x = self.get_proj_in(x)
for it, block in enumerate(self.attn_blocks):
_kv_cache = partial(kv_cache, it) if kv_cache is not None else None
x = block(x, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]
if self.use_final_norm:
x = self.final_norm(x)
x = self.get_proj_out(x)
xa, xb = x.chunk(2, dim=-1)
# Convert back to float32 for sampling computations
xa, xb = xa.float(), xb.float()
if not self.use_softplus:
xa = xa.exp()
else:
xa = F.softplus(xa + INV_SOFTPLUS_1)
if guidance > 0:
xb, xa = self.guidance(xa, xb, guidance, 1.0, 'ab')
# compute the Jacobi Iteration - all in float32
new_x = xb + xa * z[:, start_idx: start_idx+chunk_size]
diff = ((new_x - x_full[:, start_idx+1: start_idx+1+chunk_size]) ** 2).mean() / (new_x ** 2).mean()
x_full[:, start_idx+1: start_idx+1+chunk_size] = new_x
if diff < jacobi_th or i == jacobi_max_iter - 1: # do not clean the cache on the last iteration
local_done.fill_(1)
global_done = local_done.clone()
# Single-process runs (e.g., MPS) might not initialize torch.distributed
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.all_reduce(global_done, op=torch.distributed.ReduceOp.MIN)
if int(global_done.item()) == 1:
break
kv_cache.backward_in_time(chunk_size)
start_idx += chunk_size
prog_bar.update(chunk_size)
if target_frame_size >= z.size(1):
break
target_frame_size += local_attn_window - context_size if local_attn_window is not None else video_frame_size
target_frame_size = min(target_frame_size, z.size(1))
# re-encode the context with attention blocks
print(f're-encoding the context {start_idx+1-context_size}:{start_idx+1}')
kv_cache.reset_kv_index()
if self.txt_dim > 0:
self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
x_context = x_full[:, start_idx+1-context_size: start_idx+1]
x_context = self.get_proj_in(x_context)
for it, block in enumerate(self.attn_blocks):
_kv_cache = partial(kv_cache, it) if kv_cache is not None else None
x_context = block(x_context, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]
x = x_full[:, 1:]
if guidance > 0:
x = x.chunk(2, dim=0)[0] # remove SOS token
x = self.permutation(x, inverse=True)
# Convert back to original dtype if needed
return x.to(original_dtype)
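The stopping rule inside the loop is just the relative mean-squared change between successive Jacobi iterates. A toy standalone version of that rule (my names, not the model's; the real update replaces f with the network's affine output xb + xa * z):

import torch

def fixed_point(x0: torch.Tensor, f, th: float = 1e-3, max_iter: int = 32):
    """Iterate x <- f(x) until the relative MSE change drops below th."""
    x = x0
    for _ in range(max_iter):
        new_x = f(x)
        diff = ((new_x - x) ** 2).mean() / (new_x ** 2).mean()
        x = new_x
        if diff < th:
            break
    return x

# Example: solve x = cos(x) elementwise (converges near 0.739).
print(fixed_point(torch.zeros(4), torch.cos))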
Use:
bash scripts/test_sample_image_mps.sh "a film still of a cat playing piano"
prompt: "A girl with backlighting, her silhouette against the sunset, bright halo effect"
prompt: "A anime girl with backlighting, her silhouette against the sunset, bright halo effect"
prompt: "A mysterious and ethereal figure with translucent wings, glowing eyes, and hair that flows like liquid silver"
This should help resolve the errors and allow you to run Starflow without the need for CUDA or MPS. Let me know if you need further assistance!
What is the speed per image? I've got 24 hours
I use a MacBook Pro M4 with 32 GB. It takes about 2 min to load the checkpoint and about 8 min per image: roughly 4 min to prepare and 4 min to generate (it runs in a loop, like in the video I uploaded).
I use test_sample_image_mps.sh. You can try changing the batch size (e.g., bsz=16), which may speed things up.
(.venv) syun@syunnoMacBook-Pro ml-starflow1 % bash scripts/test_sample_image_mps.sh "a film still of a cat playing piano"
caption=a film still of a cat playing piano
input_image=none
/Users/syun/python_project/apple/ml-starflow1/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py:270: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
warnings.warn(
W1205 06:26:44.264000 90588 .venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
objc[90588]: Class AVFFrameReceiver is implemented in both /Users/syun/python_project/apple/ml-starflow1/.venv/lib/python3.10/site-packages/av/.dylibs/libavdevice.60.3.100.dylib (0x10f744760) and /Users/syun/python_project/apple/ml-starflow1/.venv/lib/python3.10/site-packages/decord/.dylibs/libavdevice.59.7.100.dylib (0x132858a10). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[90588]: Class AVFAudioReceiver is implemented in both /Users/syun/python_project/apple/ml-starflow1/.venv/lib/python3.10/site-packages/av/.dylibs/libavdevice.60.3.100.dylib (0x10f7447b0) and /Users/syun/python_project/apple/ml-starflow1/.venv/lib/python3.10/site-packages/decord/.dylibs/libavdevice.59.7.100.dylib (0x132858a60). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
Loading text encoder google/flan-t5-xl...
Loading checkpoint shards: 100%|██████████████████████████████| 2/2 [00:00<00:00, 19.07it/s]
Loading VAE stabilityai/sd-vae-ft-ema...
Loading checkpoint from local path: ckpts/starflow_3B_t2i_256x256.pth
------------------------------------- Load ------------------------------------- starflow_3B_t2i_256x256
Sampling 1 from a film still of a cat playing piano on 1 device(s) [cpu]
Starting sampling with global batch size 1x1 devices
0%|          | 0/1 [00:00<?, ?it/s]
Saving samples ... logs/starflow_3B_t2i_256x256
Saved samples to logs/starflow_3B_t2i_256x256/a film still of a cat playing piano_256x256_video_000.png
100%|██████████████████████████████| 1/1 [08:27<00:00, 507.85s/it]
starflow_3B_t2i_256x256 cfg 3.60, bsz=1x1, time=507.91s, speed=0.00 images/s
Thanks, that's very long, though!
And T5 and SD VAE look sus. I thought they made an all-in-one model.