Simplifying AI Code with Refiners

· 1472 words · 7 minute read

Ever embarked on a coding adventure, thinking you’ve struck gold with an AI repo, only to open model.py and be greeted with something that looks a bit like this:

class MyAmazingModel(nn.Module):
    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(config.num_layers):
            if config.layer_type == 'conv':
                self.layers.append(nn.Conv2d(config.in_channels, config.out_channels, kernel_size=3))
            else:
                num_linear_features = config.in_features * self.get_num_features()
                self.layers.append(nn.Linear(config.in_features, config.out_features))
            if config.use_dropout:
                self.layers.append(nn.Dropout(config.dropout_rate))
            if hasattr(config, f'extra_param_{i}'):
                setattr(self, f'extra_layer_{i}', nn.Module())
        # And so on for a few hundred lines

    def forward(self, x, aux, cond, emb, *args, **kwargs):
        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.Conv2d):
                x = self.prepare_encoder_input(x, aux, cond, emb)
                x = F.relu(layer(kwarg1=x, kwarg2=kwargs['image']))
            elif isinstance(layer, nn.Linear):
                x = F.sigmoid(layer(kwarg1=x, kwarg2=kwargs['text']))
            if hasattr(self, f'extra_layer_{i}'):
                extra_layer = getattr(self, f'extra_layer_{i}')
                x = extra_layer(x)
        # And so on for a few hundred lines
    # A bunch of very elliptic yet completely cryptic methods

Sound familiar? I thought so. And if you’re like me, you’ve probably wondered, “What’s happening here?” Don’t fret: I’ve got you covered. The root of this chaos? A conspicuous absence of Refiners.

“But wait,” you might say, “isn’t this the coder’s misstep rather than PyTorch’s?” And you’re right. The issue isn’t with PyTorch but with how the code is structured. Let’s delve deeper.

The problem

Consider writing a model. You start off with something straightforward like a self-attention layer:

import torch
from torch import nn, Tensor

class SelfAttention(nn.Module):
    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.query_linear = nn.Linear(embedding_dim, embedding_dim)
        self.key_linear = nn.Linear(embedding_dim, embedding_dim)
        self.value_linear = nn.Linear(embedding_dim, embedding_dim)
        self.out_linear = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)
        attention = torch.softmax(
            torch.matmul(query, key.transpose(1, 2)) / (key.shape[-1] ** 0.5),
            dim=-1,
        )
        out = torch.matmul(attention, value)
        out = self.out_linear(out)
        return out
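
For reference, here is a minimal usage sketch (shapes are illustrative): the layer takes a single tensor of shape (batch, sequence, embedding_dim) and returns one of the same shape.

attention = SelfAttention(embedding_dim=64)
x = torch.randn(2, 16, 64)  # (batch, sequence, embedding_dim)
y = attention(x)            # (2, 16, 64)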

Now, let’s spice things up by adding a new feature. This SelfAttention layer could be part of an LLM trained to generate text, and we would like to extend it so the model can “see” images and generate text conditioned on both the text and an image, much like GPT-4V does.

We can do this by adding a cross-attention layer that will take an image embedding as input and output a tensor that will be added to the key and value tensors of the self-attention layer.

import torch
from torch import nn, Tensor

class SelfAttention(nn.Module):
    def __init__(self, embedding_dim: int, cross_embedding_dim: int) -> None:
        super().__init__()
        self.query_linear = nn.Linear(embedding_dim, embedding_dim)
        self.key_linear = nn.Linear(embedding_dim, embedding_dim)
        self.value_linear = nn.Linear(embedding_dim, embedding_dim)
        self.out_linear = nn.Linear(embedding_dim, embedding_dim)
        # We add two new layers
        self.cross_key_linear = nn.Linear(cross_embedding_dim, embedding_dim)
        self.cross_value_linear = nn.Linear(cross_embedding_dim, embedding_dim)

    # We need to modify the signature of the forward method
    def forward(self, x: Tensor, cross_x: Tensor) -> Tensor:
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)
        # We modify the forward method to add cross attention
        cross_key = self.cross_key_linear(cross_x)
        cross_value = self.cross_value_linear(cross_x)
        key += cross_key
        value += cross_value
        attention = torch.softmax(
            torch.matmul(query, key.transpose(1, 2)) / (key.shape[-1] ** 0.5),
            dim=-1,
        )
        out = torch.matmul(attention, value)
        out = self.out_linear(out)
        return out
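
A minimal usage sketch (shapes are illustrative). Note that with this additive scheme, cross_x must broadcast against the token sequence, e.g. a single pooled image embedding:

attention = SelfAttention(embedding_dim=64, cross_embedding_dim=32)
x = torch.randn(2, 16, 64)       # text tokens: (batch, sequence, embedding_dim)
cross_x = torch.randn(2, 1, 32)  # pooled image embedding, broadcast over the sequence
y = attention(x, cross_x)        # (2, 16, 64)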

Simple, right? But when your project evolves, so does the complexity of your code. You start adding conditional statements and new parameters, and suddenly, your once neat code resembles a dense forest.

The solution

We (Finegrain) developed Refiners to allow ML engineers to write models in a composable way. Let’s reimagine our self-attention layer using Refiners:

from refiners.fluxion import layers as fl
from refiners.fluxion.layers.attentions import ScaledDotProductAttention

class SelfAttention(fl.Chain):
    def __init__(self, embedding_dim: int) -> None:
        super().__init__(
            fl.Parallel(
                fl.Linear(embedding_dim, embedding_dim),
                fl.Linear(embedding_dim, embedding_dim),
                fl.Linear(embedding_dim, embedding_dim),
            ),
            ScaledDotProductAttention(),
            fl.Linear(embedding_dim, embedding_dim),
        )
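
It is instantiated and called exactly like the PyTorch version; a minimal sketch with illustrative shapes:

attention = SelfAttention(embedding_dim=64)
x = torch.randn(2, 16, 64)  # (batch, sequence, embedding_dim)
y = attention(x)            # (2, 16, 64)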

Now, you might wonder about the forward method. Here’s a pleasant surprise: you don’t need to write one. The call method of the Chain class does the heavy lifting; at its core, a Chain is just a Sequential that calls each layer’s forward method in order.

fl.Chain(
    fl.Linear(2, 2),
    fl.SiLU(),
    fl.Linear(2, 2),
)
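
With hypothetical names for its three sublayers, calling this chain is the same as computing:

linear_2(silu(linear_1(x)))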

Since not all models can be written sequentially, we need to add more functionality. For example, we need to be able to run multiple layers in parallel.

fl.Parallel(
    fl.Linear(2, 2),
    fl.Linear(2, 2),
)

# is the same as computing
(linear_1(x), linear_2(x))

But you also have Sum, Residual, Concatenate, and many other Chain-s.

fl.Sum(
    fl.Linear(2, 2),
    fl.Linear(2, 2),
)

# is the same as computing
linear_1(x) + linear_2(x)

fl.Residual(
    fl.Linear(2, 2),
    fl.Linear(2, 2),
)

# is the same as computing
x + linear_2(linear_1(x))

fl.Concatenate(
    fl.Linear(2, 2),
    fl.Linear(2, 2),
    dim=-1,
)

# is the same as computing
torch.cat((linear_1(x), linear_2(x)), dim=-1)

Let’s revisit our self-attention layer with cross-attention, this time with the Refiners touch:

class SelfAttention(fl.Chain):
    def __init__(self, embedding_dim: int, cross_embedding_dim: int) -> None:
        super().__init__(
            fl.Parallel(
                # query: computed from the text tokens only
                fl.Chain(
                    fl.GetArg(0),
                    fl.Linear(embedding_dim, embedding_dim),
                ),
                # key: text key plus the projected cross key
                fl.Sum(
                    fl.Chain(
                        fl.GetArg(0),
                        fl.Linear(embedding_dim, embedding_dim),
                    ),
                    fl.Chain(
                        fl.GetArg(1),
                        fl.Linear(cross_embedding_dim, embedding_dim),
                    ),
                ),
                # value: text value plus the projected cross value
                fl.Sum(
                    fl.Chain(
                        fl.GetArg(0),
                        fl.Linear(embedding_dim, embedding_dim),
                    ),
                    fl.Chain(
                        fl.GetArg(1),
                        fl.Linear(cross_embedding_dim, embedding_dim),
                    ),
                ),
            ),
            ScaledDotProductAttention(),
            fl.Linear(embedding_dim, embedding_dim),
        )

It’s already a bit harder to follow: we need to track where each argument goes. We also won’t be able to stack these layers, since the second argument will not be passed on to the following layers. And the code will only get harder to read as more arguments are added.

But we can do better. Let’s see how.

The data flow through a model usually consists of a single tensor, plus keyword arguments that only certain sublayers need. That’s why you often see dictionaries of extra arguments passed down the layer hierarchy in model code.

Since all models are defined using nested Chain-s in Refiners, we can pass down a Context accessible by all sublayers. Let’s see how we can rewrite the cross-attention layer using the Context API.

class SelfAttention(fl.Chain):
    def __init__(self, embedding_dim: int, cross_embedding_dim: int) -> None:
        super().__init__(
            fl.Parallel(
                fl.Linear(embedding_dim, embedding_dim),
                fl.Sum(
                    fl.Linear(embedding_dim, embedding_dim),
                    fl.Chain(
                        fl.UseContext('self_attention', 'cross_x'),
                        fl.Linear(cross_embedding_dim, embedding_dim),
                    ),
                ),
                fl.Sum(
                    fl.Linear(embedding_dim, embedding_dim),
                    fl.Chain(
                        fl.UseContext('self_attention', 'cross_x'),
                        fl.Linear(cross_embedding_dim, embedding_dim),
                    )
                ),
            ),
            ScaledDotProductAttention(),
            fl.Linear(embedding_dim, embedding_dim),
        )

# let's say that we have a model that uses this layer, and we want to pass a 
# cross_x tensor to it. This will work even if deeply nested in the model.
model = fl.Chain(
    SelfAttention(1280, 512),
    SelfAttention(1280, 512),
    fl.MultiLinear(1280, 1280, inner_dim=5120, num_layers=3),
)
x = torch.randn(1, 1, 1280)      # (batch, sequence, embedding_dim)
cross_x = torch.randn(1, 1, 512)

# We can pass cross_x to the layer using a context
model.set_context('self_attention', {'cross_x': cross_x})

# And then we can call the model
model(x)

We now have a way to pass parameters down to sublayers. But we still had to rewrite the whole self-attention layer to modify it. Since we are using Refiners, we can instead change the model dynamically using the Adapter pattern.

from refiners.fluxion.adapters import Adapter

# We define a new class that will be used to modify the model
class CrossAttentionAdapter(fl.Chain, Adapter[SelfAttention]):
    # This adapter will specifically "target" the SelfAttention class
    def __init__(self, target: SelfAttention, cross_embedding_dim: int) -> None:
        self.embedding_dim = target.ensure_find(fl.Linear).in_features
        self.cross_embedding_dim = cross_embedding_dim
        # The setup_adapter method takes care of the boilerplate code
        with self.setup_adapter(target):
            # Since an Adapter is also always a Chain, the flow of the original 
            # model is preserved.
            super().__init__(target)

    # The inject method is called when the adapter is injected into the model
    def inject(self, parent: fl.Chain | None = None):
        # Let's grab the Parallel layer; in Refiners we avoid finding layers by
        # index or key, since that's not robust.
        parallel = self.target.ensure_find(fl.Parallel)
        assert len(parallel) == 3
        new_parallel = fl.Parallel(
            parallel[0],
            fl.Sum(
                parallel[1],
                fl.Chain(
                    fl.UseContext(context='self_attention', key='cross_x'),
                    fl.Linear(self.cross_embedding_dim, self.embedding_dim),
                )
            ),
            fl.Sum(
                parallel[2],
                fl.Chain(
                    fl.UseContext(context='self_attention', key='cross_x'),
                    fl.Linear(self.cross_embedding_dim, self.embedding_dim),
                )
            ),
        )
        self.target.replace(parallel, new_parallel)
        # You must specify a parent for the adapter when the target is not 
        # already in a Chain
        return super().inject(parent)

    # The eject method is called when the adapter is removed from the model
    def eject(self):
        parallel = self.target.ensure_find(fl.Parallel)
        assert len(parallel) == 3
        old_parallel = fl.Parallel(
            parallel[0],
            parallel[1].ensure_find(fl.Linear),
            parallel[2].ensure_find(fl.Linear),
        )
        self.target.replace(parallel, old_parallel)
        super().eject()

# We can now use the adapter
model = fl.Chain(
    SelfAttention(1280),
    SelfAttention(1280),
    fl.MultiLinear(1280, 1280, inner_dim=5120, num_layers=3),
)
for self_attention in model.layers(SelfAttention):
    adapter = CrossAttentionAdapter(self_attention, cross_embedding_dim=512)
    adapter.inject()

x = torch.rand(1, 1, 1280)
cross_x = torch.rand(1, 1, 512)

model.set_context('self_attention', {'cross_x': cross_x})

model(x)

The Adapter pattern is essential for achieving modularity and reusability in your code. Once a foundational model is implemented, its original code and API remain unchanged, yet you can still modify its behavior without rewriting the entire model. This approach enables external contributions and faster experimentation.

If you examine the Refiners codebase, you’ll notice the absence of any Dropout layer in the model definitions. That’s because we use the Adapter pattern to add Dropout before training. Working with Adapters lets us separate responsibilities and keep the codebase modular. Just imagine how many if/else statements you would otherwise need to experiment with different Dropout placement strategies!
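
To make that concrete, here is a rough sketch of what such a training-time adapter could look like. It is an illustration, not the actual Refiners implementation: the class name is made up, and it assumes a Dropout layer is available as fl.Dropout (or as a thin fl.Module wrapper around torch’s nn.Dropout).

# Hypothetical sketch: append a Dropout after each SelfAttention block for
# training; fl.Dropout is assumed here (or wrap torch's nn.Dropout yourself).
class SelfAttentionDropoutAdapter(fl.Chain, Adapter[SelfAttention]):
    def __init__(self, target: SelfAttention, probability: float = 0.1) -> None:
        with self.setup_adapter(target):
            super().__init__(target, fl.Dropout(probability))

# Injected just like the cross-attention adapter above, and ejectable once
# training is done.
for self_attention in model.layers(SelfAttention):
    SelfAttentionDropoutAdapter(self_attention, probability=0.1).inject()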

Whether you’re a seasoned machine learning engineer or just starting out, give Refiners a try. It could be the tool that transforms your AI coding journey from a daunting task into an enjoyable and creative process. Remember, the only limit is your imagination!

Benjamin from the Finegrain Team