Why does the new Block Transformer Architecture seem promising?

A recent paper jointly published by researchers from DeepMind, LG, and KAIST introduced a new design for Transformer models: the global-to-local Block Transformer architecture. The paper reports up to a 20x improvement in inference throughput vis-à-vis conventional transformer models. While this claim still needs to be validated empirically by early adopters, the ideas put forward in the paper do seem intuitively robust.

The Problem with Standard Transformers

In standard autoregressive transformer models, self-attention at each step must attend to all previous tokens, which significantly increases the computational cost of token generation. To mitigate this, the key-value (KV) states of every token in every layer are cached during autoregressive decoding. Although each decoding step computes the KV state of only one new token, the KV states of all previous tokens in the sequence must still be read from the cache to compute the self-attention scores. This leads to heavy KV cache IO during inference and drives up inference costs.
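To make the IO pattern concrete, here is a toy single-head, single-layer attention loop with a KV cache, written in plain Python. The `KVCache` class, the `decode` helper, and the way each token's key, value, and query are derived are all hypothetical simplifications, not any library's API. The point is the counter: every step writes one new entry but reads all cached entries, so total reads grow quadratically with sequence length.

```python
import math
from typing import List


class KVCache:
    """Toy single-head, single-layer KV cache (illustrative, not a real API)."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []
        self.values: List[List[float]] = []
        self.entries_read = 0  # running count of cached (K, V) pairs read

    def append(self, k: List[float], v: List[float]) -> None:
        # Each decoding step computes K and V for the new token only.
        self.keys.append(k)
        self.values.append(v)

    def attend(self, q: List[float]) -> List[float]:
        # ...but the new query must still read every cached key and value.
        self.entries_read += len(self.keys)
        scale = 1.0 / math.sqrt(len(q))
        scores = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in self.keys]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        dim = len(self.values[0])
        return [
            sum(w * v[j] for w, v in zip(weights, self.values)) / total
            for j in range(dim)
        ]


def decode(num_steps: int, dim: int = 4) -> KVCache:
    cache = KVCache()
    for t in range(num_steps):
        # Hypothetical stand-in for the K/V/Q projections of token t.
        x = [float(t + j) for j in range(dim)]
        cache.append(k=x, v=x)
        cache.attend(q=x)
    return cache


print(decode(num_steps=5).entries_read)  # 15 = 1 + 2 + 3 + 4 + 5
```

Generating L tokens therefore reads L(L+1)/2 cached entries per layer, which is the quadratic KV cache IO the paper targets.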

The Architecture of the Block Transformer

The Block Transformer comprises three main components: Embedder, Block Decoder, and Token Decoder.

  • Embedder: Embeds the tokens of each input block into a single block embedding, which then serves as input to the Block Decoder.
  • Block Decoder: An autoregressive transformer that applies self-attention between blocks to produce a context embedding containing the information needed to predict the next block.
  • Token Decoder: Autoregressively decodes the tokens of the next block, applying local self-attention only among tokens within that block.
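The global-to-local split can be illustrated purely in terms of which inputs are visible when predicting a given token. The sketch below is a hypothetical simplification: it assumes the sequence length is a multiple of the block length and that the first block has no global context. The earlier blocks stand for what the Embedder and Block Decoder summarize into one context embedding, while the local tokens are what the Token Decoder attends to.

```python
from typing import List, Tuple


def split_into_blocks(tokens: List[int], block_len: int) -> List[List[int]]:
    # Assumption: len(tokens) is a multiple of block_len (no padding handled).
    assert len(tokens) % block_len == 0
    return [tokens[i:i + block_len] for i in range(0, len(tokens), block_len)]


def inputs_for_token(
    tokens: List[int], block_len: int, pos: int
) -> Tuple[List[List[int]], List[int]]:
    """Return (global_context_blocks, local_tokens) used to predict tokens[pos].

    global_context_blocks: earlier blocks, which the Embedder and Block Decoder
        would compress into a single context embedding (empty for block 0).
    local_tokens: earlier tokens inside the same block, seen by the Token Decoder.
    """
    blocks = split_into_blocks(tokens, block_len)
    b = pos // block_len        # index of the block that contains pos
    offset = pos % block_len    # position of the token within that block
    global_context = blocks[:b]  # Block Decoder attends causally over blocks
    local = blocks[b][:offset]   # Token Decoder attends only within block b
    return global_context, local


tokens = list(range(100, 108))  # 8 hypothetical token ids, block length 4
print(inputs_for_token(tokens, block_len=4, pos=5))
# ([[100, 101, 102, 103]], [104])
```

No matter how long the sequence grows, the local part never exceeds one block; only the coarse, block-level context grows with it.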

A reference diagram from one of the paper’s authors explains the architecture clearly. [Owner/Source: https://x.com/itsnamgyu/status/1807400615657803944]

Why does the concept of Block Architecture seem intuitively efficient?

Firstly, because the Block Decoder operates on coarse-grained block inputs rather than individual tokens, its context length shrinks by a factor equal to the block length. This reduces the FLOPs for both position-wise computation and attention-score computation, as well as KV cache memory usage and KV cache IO.
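A quick count shows why the saving is larger for attention than for position-wise work. Position-wise computation scales with the number of positions, so it drops by roughly the block length Lb; causal attention scores scale with the number of (query, key) pairs, so they drop by roughly Lb². The sketch below is our own counting exercise under a plain causal mask, not a figure from the paper.

```python
def causal_attention_pairs(num_positions: int) -> int:
    # Under a causal mask, position i attends to positions 0..i (i + 1 entries).
    return num_positions * (num_positions + 1) // 2


def block_decoder_savings(context_len: int, block_len: int) -> float:
    # The Block Decoder sees one position per block: L / Lb positions, not L.
    num_blocks = context_len // block_len
    vanilla = causal_attention_pairs(context_len)
    block = causal_attention_pairs(num_blocks)
    return vanilla / block


print(round(block_decoder_savings(2048, 4), 2))  # 15.98, close to Lb**2 = 16
```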

Secondly, the Token Decoder only needs to compute, store, and retrieve the KV states of tokens in the local context, which eliminates a large share of attention-related operations and reduces computational complexity. For instance, total KV cache IO grows linearly rather than quadratically with context length.
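The linear-versus-quadratic claim can be checked with a rough tally of KV entries read while generating L tokens. The accounting below is a hypothetical simplification: it assumes each Token Decoder step reads a fixed number of prefix entries derived from the context embedding (`prefix_len`, our own parameter) plus the earlier tokens of its own block, and it ignores the Block Decoder's separate, much smaller cache.

```python
def vanilla_kv_io(context_len: int) -> int:
    # Step t reads the cached K/V of all t tokens so far: 1 + 2 + ... + L.
    return context_len * (context_len + 1) // 2


def token_decoder_kv_io(context_len: int, block_len: int, prefix_len: int = 1) -> int:
    # Each step reads prefix entries plus the tokens of its own block up to
    # and including itself; nothing outside the block is touched.
    total = 0
    for pos in range(context_len):
        total += prefix_len + (pos % block_len) + 1
    return total


for L in (1024, 2048, 4096):
    print(L, vanilla_kv_io(L), token_decoder_kv_io(L, block_len=4))
# 1024 524800 3584
# 2048 2098176 7168
# 4096 8390656 14336
```

Doubling the context doubles the Token Decoder's reads but roughly quadruples the vanilla model's.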

Here’s a comparison of the inference bottleneck and computational complexity between the vanilla transformer and the block/token decoders, as mentioned in the paper:

Note: D = Dimensions; B = Batch size; L = Context length; Lb = Block length; N = Number of layers
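As a rough back-of-the-envelope illustration using these symbols (our own estimate, not the paper's table), assume standard multi-head attention that caches one D-dimensional key and one D-dimensional value per token per layer, the same N and D for every component, and ignore the Token Decoder's prefix positions. The number of cached values is then approximately:

```latex
\begin{aligned}
M_{\text{vanilla}} &\approx 2\,N\,B\,L\,D \\
M_{\text{block}}   &\approx 2\,N\,B\,\frac{L}{L_b}\,D \\
M_{\text{token}}   &\approx 2\,N\,B\,L_b\,D
\end{aligned}
```

Only the Block Decoder's term grows with L, and it does so Lb times more slowly than the vanilla model's; the Token Decoder's term is independent of L.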

The researchers experimented with three embedder strategies for creating the block embedding and recommended the lookup strategy as the most effective. Longer prefixes in the Token Decoder were found to improve performance with minimal overhead, whereas adding MSE or contrastive losses at the Block Decoder degraded performance.
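The two winning choices can be sketched in a few lines of plain Python. This is a minimal sketch under our own assumptions: the lookup embedder is modeled as fetching a small (D / Lb)-dimensional embedding per token and concatenating them into one D-dimensional block embedding, and the prefix is modeled as one hypothetical D x D projection per prefix slot applied to the context embedding. The random weights, sizes, and function names are illustrative only, and the block embedding is reused as a stand-in for the Block Decoder's output context.

```python
import random
from typing import Dict, List

Vector = List[float]


def lookup_embed(block: List[int], table: Dict[int, Vector]) -> Vector:
    # Fetch each token's small embedding (size D / Lb) and concatenate them
    # into a single D-dimensional block embedding.
    out: Vector = []
    for tok in block:
        out.extend(table[tok])
    return out


def matvec(w: List[Vector], x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def context_to_prefix(context: Vector, proj: List[List[Vector]]) -> List[Vector]:
    # Hypothetical prefix scheme: each D x D matrix yields one prefix vector
    # that the Token Decoder would prepend to its local token sequence.
    return [matvec(w, context) for w in proj]


D, LB, PREFIX_LEN, VOCAB = 8, 4, 2, 10
rng = random.Random(0)
table = {t: [rng.gauss(0, 1) for _ in range(D // LB)] for t in range(VOCAB)}
proj = [[[rng.gauss(0, 1) for _ in range(D)] for _ in range(D)]
        for _ in range(PREFIX_LEN)]

block_emb = lookup_embed([3, 1, 4, 1], table)
prefix = context_to_prefix(block_emb, proj)  # block_emb stands in for context
print(len(block_emb), len(prefix), len(prefix[0]))  # 8 2 8
```

In this framing, a longer prefix just means more projection matrices, i.e. a few extra positions in a sequence that is only one block long, which is consistent with the minimal overhead the paper reports.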

Closing Comments

The core idea behind this research is not entirely new. It builds on hierarchical global-to-local model architectures, in which global dependencies are encoded at a coarse level and fine details are encoded within local regions. The researchers extended this concept to mitigate the primary inference bottlenecks, which is of real practical significance for companies developing, training, or deploying LLMs on constrained budgets.
