ControlNet (which received the best paper prize at ICCV 2023) or T2I-Adapters are game changers for Stable Diffusion practitioners. And for good reason:
- They add a super effective level of control which mitigates hairy prompt engineering
- They have been designed as “adapters”, i.e. lightweight, composable and cheap units (a single copy of the base model is needed)
However, a good dose of prompt engineering is still required to infuse a specific style. Good news: the folks at Tencent AI Lab released a lightweight and super powerful IP-Adapter:
The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with image prompt.
The core of IP-Adapter consists in replacing each of the UNet's cross-attention layers with a more capable version able to consume both text and image tokens, aka "decoupled cross-attention", while keeping the rest unchanged. This new kind of cross-attention looks like this:
The image tokens are derived from the input image prompt thanks to a pretrained CLIP model and an extra network which turns the features into a sequence (not depicted in the above diagram).
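In the notation of the IP-Adapter paper, the decoupled cross-attention computes:

$$
Z^{\text{new}} = \operatorname{Softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V \;+\; \operatorname{Softmax}\!\left(\frac{Q (K')^\top}{\sqrt{d}}\right) V'
$$

where $Q = Z W_q$, $K = c_t W_k$ and $V = c_t W_v$ are computed from the query features $Z$ and the text tokens $c_t$, while $K' = c_i W'_k$ and $V' = c_i W'_v$ come from the image tokens $c_i$. Only $W'_k$ and $W'_v$ are new trainable weights.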
One key aspect is that IP-Adapter has been designed to be compatible and composable with ControlNet and the like. This makes it a perfect candidate for Refiners, the PyTorch-based microframework dedicated to foundation model adaptation that we are building in the open at Finegrain. Here is why and how:
(1) Swapping specific layers is native
How to retrieve all cross-attention layers?
In Refiners, models are expressed in a declarative way as a composition of basic layers, thanks to the `Chain` class. The UNet is no exception: it is a chain made of down/middle/up blocks which include attention and cross-attention layers at a deeper level.
💡 Refiners provides a collection of layers designed for this declarative approach, like `Passthrough` (aka `PASS`, useful to record state), `Residual` (aka `RES`, for skip connections) and many others grouped under what we call fluxion layers. On the other hand, core layers like `Conv2d`, `Linear`, `Embedding`, etc. are direct subclasses of their PyTorch equivalents.
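For instance, a toy `Chain` could look like this (an illustrative sketch, not part of the actual UNet):

```python
import refiners.fluxion.layers as fl

# A tiny model written declaratively: data flows through
# the chain from top to bottom.
tiny_mlp = fl.Chain(
    fl.Linear(in_features=16, out_features=64),
    fl.GeLU(),
    fl.Linear(in_features=64, out_features=4),
)
```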
Here is a folded-and-truncated, bird's-eye view of the SDXL UNet:
```
(CHAIN) SDXLUNet(in_channels=4)
    ├── (PASS) TimestepEncoder() ...
    ├── (CHAIN) DownBlocks(in_channels=4) ...
    ├── (CHAIN) MiddleBlock() ...
    ├── (RES) Residual() ...
    ├── (CHAIN) UpBlocks() ...
    └── (CHAIN) OutputBlock() ...
```
Behind the scenes, it is a tree-like structure and the `Chain` class offers methods to easily walk the tree and/or target specific layers. So here is how you can build a list of all the cross-attention layers:
```python
from refiners.foundationals.latent_diffusion import SDXLUNet
import refiners.fluxion.layers as fl

# Create a vanilla UNet
unet = SDXLUNet(in_channels=4)

# Retrieve all attention layers thanks to `unet.layers(...)` and
# filter out self-attentions, which are out of scope for IP-Adapter
cross_attns = list(filter(
    lambda attn: type(attn) != fl.SelfAttention,
    unet.layers(fl.Attention),
))

# Sanity check: SDXL has 70 cross-attention layers in total
assert len(cross_attns) == 70
```
How to create the adapter scaffold?
Refiners provides an `Adapter` class used to replace any target layer (here: a cross-attention layer) with another one (here: decoupled cross-attention). The scaffold looks like this:
```python
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters import Adapter

# `CrossAttentionAdapter` is implemented as a subclass of `Chain`
# (= that way the decoupled cross-attention can be expressed as a
# chain of layers - for LoRA we would use a sum instead) and is made
# to target `Attention` layers (the ones to be replaced)
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    def __init__(
        self,
        target: fl.Attention,
        # ... add extra args here
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(
                # ... put the decoupled cross-attention here
            )

    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
        # ... hook for extra logic triggered when the cross-attention
        # is replaced
        return super().inject(parent)

    def eject(self) -> None:
        # ... hook for extra logic triggered when the cross-attention
        # is restored
        super().eject()
```
Once implemented (more on this below), the adapter can just be `inject`-ed to take effect, and later on `eject`-ed to get back to the original UNet version. Thanks to this approach, the UNet implementation is left completely untouched.
💡 You can think of it as model surgery.
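For example, usage could look like this (a hypothetical snippet building on the layers retrieved earlier):

```python
# Adapt the first cross-attention layer, then restore it
adapter = CrossAttentionAdapter(target=cross_attns[0]).inject()
# ... run inference with the adapted UNet ...
adapter.eject()  # the vanilla cross-attention is back in place
```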
How to implement decoupled cross-attentions?
The regular UNet's cross-attention layers are implemented as the following `Chain` in Refiners:
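Schematically, and simplifying the real `fl.Attention` implementation a little (sizes are illustrative):

```python
import refiners.fluxion.layers as fl

dim = 1280  # illustrative embedding size

regular_cross_attention = fl.Chain(
    fl.Distribute(
        fl.Linear(dim, dim),  # Wq, applied to the UNet features
        fl.Linear(dim, dim),  # Wk, applied to the text tokens
        fl.Linear(dim, dim),  # Wv, applied to the text tokens
    ),
    fl.ScaledDotProductAttention(num_heads=20),
    fl.Linear(dim, dim),  # output projection
)
```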
And the decoupled cross-attention for IP-Adapter should be expressed as a `Chain` like this:
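As a rough sketch (names and sizes are illustrative, not the actual Refiners implementation; we assume the chain receives its five inputs already fanned out as query features, text tokens, image tokens, text tokens, image tokens):

```python
import refiners.fluxion.layers as fl

dim = 1280  # illustrative embedding size

decoupled_cross_attention = fl.Chain(
    # Project into the (Q, K, K', V, V') tuple: Wq consumes the UNet
    # features, Wk/Wv the text tokens, Wk'/Wv' the image tokens.
    fl.Distribute(
        fl.Linear(dim, dim),  # Wq
        fl.Linear(dim, dim),  # Wk   (text)
        fl.Linear(dim, dim),  # Wk'  (image)
        fl.Linear(dim, dim),  # Wv   (text)
        fl.Linear(dim, dim),  # Wv'  (image)
    ),
    # Two attention branches consuming that tuple, summed together
    fl.Sum(
        fl.Chain(  # text branch: attends over (Q, K, V)
            fl.Lambda(lambda q, k, k_i, v, v_i: (q, k, v)),
            fl.ScaledDotProductAttention(num_heads=20),
        ),
        fl.Chain(  # image branch: attends over (Q, K', V')
            fl.Lambda(lambda q, k, k_i, v, v_i: (q, k_i, v_i)),
            fl.ScaledDotProductAttention(num_heads=20),
        ),
    ),
    fl.Linear(dim, dim),  # output projection
)
```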
💡 The job of `Lambda(Q, K, V)` is just to pick the right elements in the `(Q, K, K', V, V')` tuple.
At `inject` time, the placeholder weights should be filled with the frozen weights `Wq`, `Wk` and `Wv` taken from the original attention. Again, this can easily be done thanks to the `Chain` APIs.
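For instance, a helper along these lines could do the trick (illustrative, not the actual Refiners code; it assumes the layer ordering of the sketch above):

```python
import refiners.fluxion.layers as fl

def reuse_frozen_projections(target: fl.Attention, decoupled: fl.Chain) -> None:
    # In `fl.Attention`, the first three `Linear` layers are the
    # Wq / Wk / Wv projections (the fourth is the output projection).
    wq, wk, wv = list(target.layers(fl.Linear))[:3]
    # In the sketch above, the decoupled projections are ordered
    # (Wq, Wk, Wk', Wv, Wv'): the originals map to indices 0, 1 and 3.
    new_linears = list(decoupled.layers(fl.Linear))
    for original, index in ((wq, 0), (wk, 1), (wv, 3)):
        new_linears[index].weight = original.weight  # share the frozen weight
```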
And that's it: the UNet now holds decoupled cross-attentions implementing IP-Adapter!
Please refer to the actual implementation for more details, in particular the parts not covered here, such as the image projection model turning image features into a sequence.
(2) Composing compatible adapters is seamless
As stated above, one of the core design goals of IP-Adapter is:
[…] [to be] compatible with other controllable adapters such as ControlNet, allowing for an easy combination of image prompt with structure controls.
And Refiners has built-in support for ControlNet and T2I-Adapter, among others. So ultimately, combining adapters is just a matter of `inject`-ing extra adapters in addition to IP-Adapter:
```python
# ... create UNet and then... just stack adapters

# First inject IP-Adapter
ip_adapter = SDXLIPAdapter(target=unet).inject()

# And then inject e.g. a T2I-Adapter leveraging Canny edges
t2i_adapter_canny = SDXLT2IAdapter(target=unet, name="canny").inject()

# ... write the denoising loop
```
See https://refine.rs/guides/adapting_sdxl/ for an end-to-end example with SDXL.
Cédric from The Finegrain Team