Uni-MoE 2.0-Image

Uni-MoE 2.0 is a fully open-source omnimodal model that substantially advances the capabilities of Lychee's Uni-MoE series in language-centric multimodal understanding, reasoning, and generation.

Uni-MoE 2.0-Image is a visual generation model derived from Uni-MoE 2.0-Omni and fine-tuned specifically on visual generation data.


If you enjoy our work or want timely updates, please give us a like and follow us.

Getting Started

1. Clone this repository and navigate to the Uni-MoE-2 folder

git clone https://github.com/HITsz-TMG/Uni-MoE.git
cd Uni-MoE-2

2. Set up the environment

Create a conda environment and install the required dependencies:

conda create -n uni_moe_2 python=3.11
conda activate uni_moe_2
pip install torch==2.5.1 torchaudio==2.5.1 torchvision==0.20.1
pip install -r requirements.txt
pip install flash-attn==2.6.0.post1 --no-build-isolation
pip install "clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1"
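
Optionally, run a quick sanity check (a minimal sketch; it only verifies that the key packages import and that a CUDA device is visible) before loading the 31B-parameter model:

import torch
import flash_attn
import clip

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)
print("CLIP models:", clip.available_models()[:3])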

Example Usage

We provide a simple example of how to use this repo. For detailed usage, please refer to the cookbook.

import os
from typing import Any, Dict, List, Optional

import torch

from uni_moe.model.modeling_out import GrinQwen2VLOutForConditionalGeneration
from uni_moe.model.processing_qwen2_vl import Qwen2VLProcessor
from uni_moe.qwen_vl_utils import process_mm_info
from uni_moe.model import deepspeed_moe_inference_utils  # kept: may register MoE inference utilities on import


def load_unimoe(model_path: str):
    """Load the Uni-MoE 2.0-Image model and its processor onto the GPU."""
    processor = Qwen2VLProcessor.from_pretrained(model_path)
    model = GrinQwen2VLOutForConditionalGeneration.from_pretrained(
        model_path, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
    )
    model.cuda()

    # Sync the processor with the model config so multimodal preprocessing
    # matches what the model expects.
    processor.data_args = model.config

    return model, processor


EXAMPLES = [
    # text-to-image generation
    {
        "prompt": "<image>\nImage generation: In the art piece, a realistically depicted young girl with flowing blonde hair gazes intently into the distance, her eyes reflecting the vibrant hues of a spring forest. The verdant greens and soft pastels of the budding trees are captured in subtle brushstrokes, giving the scene a serene and tranquil atmosphere. The minimalist composition focuses on the girl's expression of wonder and the lush woodland background, while the texture of the oil paint adds depth and richness to the canvas.",
        "input_image": None,
        "out_name": "genarate.png",
    },
    # image editing
    {
        "prompt": "<image>\nAdd a dog standing near the fence in the foreground, close to the road.",
        "input_image": "examples/assets/visual_gen/input_images/edit.jpg",
        "out_name": "edit.png",
    }
]
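
# Note: each prompt begins with an "<image>\n" tag; run_batch (below) expands it
# into the model's vision special tokens, and make_message supplies either the
# user image or a white placeholder to fill that vision slot.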


def make_message(prompt: str, image_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Build a messages list compatible with processor.apply_chat_template.

    If image_path is given it becomes the first content item; otherwise a
    blank white placeholder image fills the prompt's vision slot.
    """
    user_items = []
    if image_path is not None:
        user_items.append({"type": "image", "image": image_path})
    else:
        # Pure text-to-image generation still needs an image slot, so use a blank placeholder.
        user_items.append({"type": "image", "image": "examples/assets/visual_gen/input_images/white.png"})
    user_items.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": user_items}]


def run_batch(model_path: str, examples: List[Dict[str, Any]], save_dir: str):
    os.makedirs(save_dir, exist_ok=True)
    model, processor = load_unimoe(model_path)

    for i, ex in enumerate(examples, start=1):
        print(f"\n=== [{i}/{len(examples)}]  prompt={ex['prompt']}")
        messages = make_message(ex['prompt'], ex.get('input_image'))
        print(messages)

        texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Expand the placeholder tags into the model's native multimodal special-token sequences.
        texts = (
            texts.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>")
            .replace("<audio>", "<|audio_start|><|audio_pad|><|audio_end|>")
            .replace("<video>", "<|vision_start|><|video_pad|><|vision_end|>")
        )

        # Load and preprocess any referenced images, videos, and audio clips.
        image_inputs, video_inputs, audio_inputs = process_mm_info(messages)

        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            audios=audio_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Guard against empty tokenization, then add the batch dimension
        # expected downstream by generate_visualgen.
        if inputs.get("input_ids") is None:
            print("Warning: input_ids missing, skipping example")
            continue
        inputs["input_ids"] = inputs["input_ids"].unsqueeze(0)

        # prepare save path
        base_out = os.path.splitext(ex['out_name'])[0]
        save_name = f"{base_out}.png"
        save_path = os.path.join(save_dir, save_name)

        # Run visual generation; the model writes the resulting image to save_path.
        output_ids = model.generate_visualgen(
            input_ids=inputs["input_ids"].to(device=model.device),
            pixel_values=inputs["pixel_values"].to(dtype=torch.bfloat16, device=model.device) if "pixel_values" in inputs else None,
            image_grid_thw=inputs.get("image_grid_thw", None),
            pixel_values_videos=inputs.get("pixel_values_videos", None),
            video_grid_thw=inputs.get("video_grid_thw", None),
            audio_features=inputs.get("audio_features", None),
            audio_grid_thw=inputs.get("audio_grid_thw", None),
            use_cache=True,
            attention_mask=inputs["input_ids"].ne(processor.tokenizer.pad_token_id),
            pad_token_id=processor.tokenizer.eos_token_id,
            golden_caption_emb=None,
            golden_task_emb=None,
            golden_visual_emb=None,
            image_path=ex.get("input_image", None),
            save_path=save_path,
            do_sample=False,
            num_beams=1,
            temperature=0.0,
            max_new_tokens=4096,
        )

        decoded = processor.batch_decode(output_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
        print("Generated text output:\n", decoded)
        print("Saved image to:", save_path)


if __name__ == "__main__":
    MODEL_PATH = "HIT-TMG/Uni-MoE-2.0-Image"
    SAVE_DIR = "Path to Save Images"
    run_batch(MODEL_PATH, EXAMPLES, SAVE_DIR)
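
As a minimal variation (hypothetical prompt and output paths; it reuses run_batch and the EXAMPLES entry format from the script above), a single custom generation could look like:

custom = [{
    "prompt": "<image>\nImage generation: a watercolor lighthouse at dusk.",
    "input_image": None,
    "out_name": "lighthouse.png",
}]
run_batch("HIT-TMG/Uni-MoE-2.0-Image", custom, "./outputs")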