Level Up Stable Diffusion with IP-Adapter

901 words · 5 minute read

ControlNet (which received the best paper prize at ICCV 2023) and T2I-Adapters are game changers for Stable Diffusion practitioners. And not without reason:

  1. They add a super effective level of control that reduces the need for hairy prompt engineering
  2. They have been designed as “adapters”, i.e. lightweight, composable and cheap units (a single copy of the base model is needed)

However, a good dose of prompt engineering is still required to infuse a specific style. Good news: the folks at Tencent AI Lab released IP-Adapter, a lightweight and super powerful adapter:

The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with image prompt.

The core of IP-Adapter consists of replacing each of the UNet’s cross-attention layers with a more capable version able to consume both text and image tokens, aka “decoupled cross-attention”, while keeping the rest unchanged. This new kind of cross-attention looks like this:

[Figure: decoupled cross-attention]
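Because the diagram does not survive in text form, here is the corresponding formula as stated in the IP-Adapter paper. Z denotes the query features (the UNet hidden states), c_t the text tokens and c_i the image tokens. λ is an optional weight on the image branch used at inference time, and λ = 1 gives the plain formulation. Wq, Wk and Wv are the original frozen weights, while only W'k and W'v are new:

```latex
\begin{aligned}
Q &= Z W_q, \quad K = c_t W_k, \quad V = c_t W_v, \quad K' = c_i W'_k, \quad V' = c_i W'_v \\
Z^{\text{new}} &= \operatorname{Softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V
  + \lambda \, \operatorname{Softmax}\!\left(\frac{Q K'^\top}{\sqrt{d}}\right) V'
\end{aligned}
```

The query Q is shared by both branches. The image branch therefore only adds two projections per cross-attention layer.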

The image tokens are derived from the input image prompt thanks to a pretrained CLIP model and an extra network which turns the features into a sequence (not depicted in the above diagram).
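This post does not detail the extra network. As a rough sketch, assume the simple projection used by the original (non-“Plus”) IP-Adapter. A single linear layer maps one pooled CLIP image embedding to a handful of tokens, and a LayerNorm then normalizes each token. The ImageProjection name and the default sizes (1280-dim embedding, 2048-dim tokens, 4 tokens) are illustrative assumptions, not the Refiners API:

```python
import torch
from torch import nn


class ImageProjection(nn.Module):
    """Hypothetical sketch: turn one pooled CLIP image embedding into a short
    sequence of image tokens for the decoupled cross-attention."""

    def __init__(self, clip_dim: int = 1280, token_dim: int = 2048, num_tokens: int = 4) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.token_dim = token_dim
        # A single linear layer produces all tokens at once...
        self.proj = nn.Linear(clip_dim, num_tokens * token_dim)
        # ...then each token is normalized independently.
        self.norm = nn.LayerNorm(token_dim)

    def forward(self, image_embedding: torch.Tensor) -> torch.Tensor:
        # (batch, clip_dim) -> (batch, num_tokens, token_dim)
        tokens = self.proj(image_embedding).reshape(-1, self.num_tokens, self.token_dim)
        return self.norm(tokens)


proj = ImageProjection()
fake_embedding = torch.randn(2, 1280)  # stand-in for pooled CLIP image embeddings
print(proj(fake_embedding).shape)  # torch.Size([2, 4, 2048])
```

The resulting tokens play the role of c_i: they are fed to the new W'k and W'v projections, alongside the text tokens.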

One key aspect is that IP-Adapter has been designed to be compatible and composable with ControlNet and the like. This makes it a perfect candidate for Refiners, the PyTorch-based microframework dedicated to foundation model adaptation that we are building in the open at Finegrain. Here is why and how:

(1) Swapping specific layers is native

How to retrieve all cross-attention layers?

In Refiners, models are expressed in a declarative way as a composition of basic layers, thanks to the Chain class. The UNet is no exception: it is a chain made of down/middle/up blocks which include attention and cross-attention layers at a deeper level.

💡 Refiners provides a collection of layers designed for this declarative approach, like Passthrough (aka PASS, useful to record state), Residual (aka RES, for skip connections) and many others, grouped under what we call fluxion layers. On the other hand, core layers like Conv2d, Linear, Embedding, etc. are direct subclasses of their PyTorch equivalents.
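The minimal plain-PyTorch sketch below gives an intuition of what Passthrough and Residual do. MiniChain, MiniPassthrough, MiniResidual and Recorder are hypothetical simplifications written for illustration. They are not the actual fluxion implementations:

```python
import torch
from torch import nn


class MiniChain(nn.Sequential):
    """Hypothetical stand-in for a Chain: run child layers one after another."""


class MiniResidual(MiniChain):
    """Skip connection: output = input + chain(input)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + super().forward(x)


class MiniPassthrough(MiniChain):
    """Run the child layers (e.g. for their side effects), then return the input unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        super().forward(x)
        return x


class Recorder(nn.Module):
    """Toy layer that keeps a copy of what it sees, standing in for 'record state'."""

    def __init__(self) -> None:
        super().__init__()
        self.seen: list[torch.Tensor] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.seen.append(x.detach().clone())
        return x


recorder = Recorder()
model = MiniChain(
    MiniPassthrough(nn.Linear(4, 4), recorder),  # records linear(x), forwards x untouched
    MiniResidual(nn.Linear(4, 4)),               # x + linear(x)
)
y = model(torch.randn(2, 4))
print(y.shape, len(recorder.seen))  # torch.Size([2, 4]) 1
```

Because each of these layers is itself a chain, they nest freely. This is what lets a whole UNet be written as one declarative tree.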

Here is a folded-and-truncated, bird’s-eye view of the SDXL UNet:

(CHAIN) SDXLUNet(in_channels=4)
    ├── (PASS) TimestepEncoder() ...
    ├── (CHAIN) DownBlocks(in_channels=4) ...
    ├── (CHAIN) MiddleBlock() ...
    ├── (RES) Residual() ...
    ├── (CHAIN) UpBlocks() ...
    └── (CHAIN) OutputBlock() ...

Behind the scenes, it is a tree-like structure, and the Chain class offers methods to easily walk the tree and/or target specific layers. So here is how you can build a list of all the cross-attention layers:

from refiners.foundationals.latent_diffusion import SDXLUNet
import refiners.fluxion.layers as fl

# Create a vanilla UNet
unet = SDXLUNet(in_channels=4)

# Retrieve all attention layers thanks to `unet.layers(...)` and
# filter out self-attentions which are out-of-scope of IP-Adapter
cross_attns = list(filter(
    lambda attn: type(attn) != fl.SelfAttention,
    unet.layers(fl.Attention),
))

# Sanity check: XL is made of 70 cross-attention layers in total
assert len(cross_attns) == 70
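For comparison, a rough plain-PyTorch equivalent of this traversal relies on nn.Module.modules(), which recursively walks all sub-modules. The toy Attention/SelfAttention classes and the nested layout below are hypothetical stand-ins, not the actual SDXL UNet:

```python
from torch import nn


class Attention(nn.Module):
    """Toy stand-in for a cross-attention layer (no forward needed for this demo)."""


class SelfAttention(Attention):
    """Toy stand-in for a self-attention layer, modeled as a subclass of Attention."""


def make_block() -> nn.Sequential:
    # One toy "transformer block": a self-attention, a cross-attention and a linear layer.
    return nn.Sequential(SelfAttention(), Attention(), nn.Linear(8, 8))


# A small nested tree mimicking down / middle / up blocks.
toy_unet = nn.Sequential(
    nn.Sequential(make_block(), make_block()),  # "down"
    nn.Sequential(make_block()),                # "middle"
    nn.Sequential(make_block(), make_block()),  # "up"
)

# `modules()` yields every sub-module of the tree (depth-first, each once);
# keep the attention layers that are not self-attentions.
cross_attns = [
    m for m in toy_unet.modules()
    if isinstance(m, Attention) and not isinstance(m, SelfAttention)
]
print(len(cross_attns))  # 5
```

The idea is the same as with unet.layers(...): walk the tree once and filter by layer type. The difference is that Refiners exposes this directly on Chain.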

How to create the adapter scaffold?

Refiners provides an Adapter class used to replace any target layer (here: a cross-attention layer) with another one (here: a decoupled cross-attention). The scaffold looks like this:

from refiners.fluxion.adapters.adapter import Adapter
import refiners.fluxion.layers as fl

# `CrossAttentionAdapter` is implemented as a subclass of `Chain`
# (= that way the decoupled cross-attention can be expressed as a
# chain of layers - for LoRA we would use `Sum` instead) and is made
# to target `Attention` layers (the ones to be replaced)
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    def __init__(
        self,
        target: fl.Attention,
        # ... add extra args here
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(
                # ... put decoupled cross-attention here
            )

    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
        # ... hook for extra logic triggered when the cross-attention
        # is replaced
        return super().inject(parent)

    def eject(self) -> None:
        # ... hook for extra logic triggered when the cross-attention
        # is restored
        super().eject()

Once implemented (more on this below), the adapter can simply be injected to take effect, and later ejected to get back to the original UNet. Thanks to this approach, the UNet implementation is left completely untouched.

💡 You can think of it as model surgery.
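Here is a hypothetical plain-PyTorch sketch of such model surgery that makes the inject/eject mechanics concrete without depending on Refiners. The adapter remembers where its target lives, takes its place on inject and puts it back on eject. ToyAdapter and DoubleOutputAdapter are invented for illustration. In Refiners, the Adapter class plays this role:

```python
import torch
from torch import nn


class ToyAdapter(nn.Module):
    """Hypothetical adapter: wraps a child layer of `parent` and can swap itself in or out."""

    def __init__(self, parent: nn.Module, name: str) -> None:
        super().__init__()
        # Stored in a list so the parent is NOT registered as a sub-module
        # (that would create a cycle once the adapter is injected).
        self.parent_ref = [parent]
        self.child_name = name
        self.target = getattr(parent, name)  # the original layer, reused (not copied)

    def inject(self) -> "ToyAdapter":
        setattr(self.parent_ref[0], self.child_name, self)  # adapter replaces target
        return self

    def eject(self) -> None:
        setattr(self.parent_ref[0], self.child_name, self.target)  # restore the original


class DoubleOutputAdapter(ToyAdapter):
    """Example adapter that doubles the output of the layer it wraps."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * self.target(x)


model = nn.Sequential(nn.Linear(3, 3), nn.ReLU())
x = torch.randn(1, 3)
before = model(x)

adapter = DoubleOutputAdapter(model, "0").inject()
assert model[0] is adapter            # the first layer has been swapped
adapter.eject()
assert model[0] is adapter.target     # the original layer is back
assert torch.equal(model(x), before)  # same behavior as before injection
```

The wrapped layer is reused rather than copied. The base weights therefore exist only once, whether the adapter is injected or not.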

How to implement decoupled cross-attentions?

The regular UNet’s cross-attention layers are implemented as the following Chain in Refiners:

[Figure: regular cross-attention]

And the decoupled cross-attention for IP-Adapter should be expressed as a Chain like this:

[Figure: decoupled cross-attention]

💡 The sole job of Lambda(Q, K, V) is to pick the right elements from the (Q, K, K', V, V') tuple.

At inject time, the placeholder weights should be filled with the frozen weights Wq, Wk and Wv taken from the original attention. Again, this can easily be done thanks to the Chain APIs.

And that’s it, the UNet now holds decoupled cross-attentions implementing IP-Adapter!

Please refer to the actual implementation for more details, in particular the parts not covered here, such as the image projection model that turns image features into a sequence.
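Here is a hypothetical plain-PyTorch sketch of what the injected layer computes, for readers curious about the mechanics outside of Refiners. Placeholder projections are filled with the original, frozen Wq, Wk and Wv, and two new projections W'k and W'v handle the image tokens, with the query shared between both branches. The sketch is deliberately simplified (single head, no output projection) and assumes PyTorch 2.x for scaled_dot_product_attention. CrossAttention, DecoupledCrossAttention and the scale argument are illustrative names, not the Refiners API:

```python
import torch
import torch.nn.functional as F
from torch import nn


class CrossAttention(nn.Module):
    """Toy single-head cross-attention standing in for an original UNet layer."""

    def __init__(self, dim: int, context_dim: int) -> None:
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)          # Wq
        self.to_k = nn.Linear(context_dim, dim, bias=False)  # Wk
        self.to_v = nn.Linear(context_dim, dim, bias=False)  # Wv

    def forward(self, x: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
        return F.scaled_dot_product_attention(self.to_q(x), self.to_k(text), self.to_v(text))


class DecoupledCrossAttention(nn.Module):
    """Attention(Q, K, V) + scale * Attention(Q, K', V'), with a shared query Q."""

    def __init__(self, original: CrossAttention, image_dim: int, scale: float = 1.0) -> None:
        super().__init__()
        for name in ("to_q", "to_k", "to_v"):
            src: nn.Linear = getattr(original, name)
            # Placeholder with the same shape as the original projection...
            dst = nn.Linear(src.in_features, src.out_features, bias=False)
            # ...filled with the original weights (Wq, Wk or Wv), then frozen.
            dst.load_state_dict(src.state_dict())
            dst.requires_grad_(False)
            setattr(self, name, dst)
        dim = original.to_q.out_features
        self.to_k_ip = nn.Linear(image_dim, dim, bias=False)  # W'k (new, trainable)
        self.to_v_ip = nn.Linear(image_dim, dim, bias=False)  # W'v (new, trainable)
        self.scale = scale

    def forward(self, x: torch.Tensor, text: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x)  # the same query feeds both branches
        text_out = F.scaled_dot_product_attention(q, self.to_k(text), self.to_v(text))
        image_out = F.scaled_dot_product_attention(q, self.to_k_ip(image), self.to_v_ip(image))
        return text_out + self.scale * image_out


base = CrossAttention(dim=64, context_dim=32)
decoupled = DecoupledCrossAttention(base, image_dim=32)
x = torch.randn(1, 16, 64)     # UNet hidden states (batch, tokens, dim)
text = torch.randn(1, 77, 32)  # text tokens
image = torch.randn(1, 4, 32)  # image tokens
print(decoupled(x, text, image).shape)  # torch.Size([1, 16, 64])
```

Setting scale=0 falls back to the original text-only cross-attention, which makes for a handy sanity check.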

(2) Composing compatible adapters is seamless

As stated above, one of the core design goals of IP-Adapter is:

[…] [to be] compatible with other controllable adapters such as ControlNet, allowing for an easy combination of image prompt with structure controls.

Refiners has built-in support for ControlNet and T2I-Adapter, among others. So ultimately, combining adapters is just a matter of injecting extra adapters in addition to IP-Adapter:

# ... create UNet and then... just stack adapters

# First inject IP-Adapter
ip_adapter = SDXLIPAdapter(target=unet).inject()

# And then inject e.g. a T2I-Adapter leveraging Canny edges
t2i_adapter_canny = SDXLT2IAdapter(target=unet, name="canny").inject()

# ... write the denoising loop

See https://refine.rs/guides/adapting_sdxl/ for an end-to-end example with SDXL.

Cédric from The Finegrain Team