Spaces:
Runtime error
Runtime error
| # https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline | |
| # https://github.com/huggingface/diffusers/blob/main/examples/community/regional_prompting_stable_diffusion.py | |
| import gradio as gr | |
| from diffusers.pipelines import pipeline_utils | |
| from modules import shared, devices, scripts, processing, sd_models, prompt_parser_diffusers | |
| def hijack_register_modules(self, **kwargs): | |
| for name, module in kwargs.items(): | |
| if module is None or isinstance(module, (tuple, list)) and module[0] is None: | |
| register_dict = {name: (None, None)} | |
| elif isinstance(module, bool): | |
| pass | |
| else: | |
| library, class_name = pipeline_utils._fetch_class_library_tuple(module) # pylint: disable=protected-access | |
| register_dict = {name: (library, class_name)} | |
| self.register_to_config(**register_dict) | |
| setattr(self, name, module) | |
| class Script(scripts.Script): | |
| def title(self): | |
| return 'Regional prompting' | |
| def show(self, is_img2img): | |
| return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False | |
| def change(self, mode): | |
| return [gr.update(visible='Col' in mode or 'Row' in mode), gr.update(visible='Prompt' in mode)] | |
| def ui(self, _is_img2img): | |
| with gr.Row(): | |
| gr.HTML('<a href="https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline">  Regional prompting</a><br>') | |
| with gr.Row(): | |
| mode = gr.Radio(label='Mode', choices=['None', 'Prompt', 'Prompt EX', 'Columns', 'Rows'], value='None') | |
| with gr.Row(): | |
| power = gr.Slider(label='Power', minimum=0, maximum=1, value=1.0, step=0.01) | |
| threshold = gr.Textbox('', label='Prompt thresholds:', default='', visible=False) | |
| grid = gr.Text('', label='Grid sections:', default='', visible=False) | |
| mode.change(fn=self.change, inputs=[mode], outputs=[grid, threshold]) | |
| return mode, grid, power, threshold | |
| def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, threshold): # pylint: disable=arguments-differ | |
| if mode is None or mode == 'None': | |
| return | |
| # backup pipeline and params | |
| orig_pipeline = shared.sd_model | |
| orig_dtype = devices.dtype | |
| orig_prompt_attention = shared.opts.prompt_attention | |
| # create pipeline | |
| if shared.sd_model_type != 'sd': | |
| shared.log.error(f'Regional prompting: incorrect base model: {shared.sd_model.__class__.__name__}') | |
| return | |
| pipeline_utils.DiffusionPipeline.register_modules = hijack_register_modules | |
| prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.orig_encode_token_ids_to_embeddings # pylint: disable=protected-access | |
| shared.sd_model = sd_models.switch_pipe('regional_prompting_stable_diffusion', shared.sd_model) | |
| if shared.sd_model.__class__.__name__ != 'RegionalPromptingStableDiffusionPipeline': # switch failed | |
| shared.log.error(f'Regional prompting: not a tiling pipeline: {shared.sd_model.__class__.__name__}') | |
| shared.sd_model = orig_pipeline | |
| return | |
| sd_models.set_diffuser_options(shared.sd_model) | |
| shared.opts.data['prompt_attention'] = 'Fixed attention' # this pipeline is not compatible with embeds | |
| processing.fix_seed(p) | |
| # set pipeline specific params, note that standard params are applied when applicable | |
| rp_args = { | |
| 'mode': mode.lower(), | |
| 'power': power, | |
| } | |
| if 'prompt' in mode.lower(): | |
| rp_args['th'] = threshold | |
| else: | |
| rp_args['div'] = grid | |
| p.task_args = { | |
| **p.task_args, | |
| 'prompt': p.prompt, | |
| 'rp_args': rp_args, | |
| } | |
| # run pipeline | |
| shared.log.debug(f'Regional: args={p.task_args}') | |
| processed: processing.Processed = processing.process_images(p) # runs processing using main loop | |
| # restore pipeline and params | |
| prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.compel_hijack # pylint: disable=protected-access | |
| shared.opts.data['prompt_attention'] = orig_prompt_attention | |
| shared.sd_model = orig_pipeline | |
| shared.sd_model.to(orig_dtype) | |
| return processed | |