Tổng quan
Các kết nối dư thừa tiêu chuẩn tích lũy đầu ra của tất cả các lớp với trọng số cố định. Khi mô hình mở rộng, sự tích lũy này làm giảm đóng góp của từng lớp và gây ra vấn đề về độ lớn của trạng thái ẩn.
Attention Residuals (AttnRes) thay thế sự tích lũy này bằng attention mềm mại trên đầu ra của các lớp trước:
h_l = \sum_{i=0}^{l-1} \alpha_{i \rightarrow l} \cdot v_i
Trong đó, các trọng số ( \alpha_{i \rightarrow l} ) được tính toán qua một query giả học được cho mỗi lớp, cho phép truy cập thông minh và có nhận thức nội dung đến tất cả các biểu diễn trước đó.
Block AttnRes
Full AttnRes đơn giản nhưng yêu cầu nhiều bộ nhớ với độ phức tạp O(Ld). Block AttnRes phân chia các lớp thành N khối, tích lũy trong từng khối qua kết nối dư tiêu chuẩn và áp dụng attention chỉ qua các biểu diễn cấp khối. Với khoảng 8 khối, Block AttnRes thu lại phần lớn sự cải tiến của Full AttnRes với chi phí bộ nhớ không đáng kể.
def block_attn_res(blocks: list[Tensor], partial_block: Tensor, proj: Linear, norm: RMSNorm) -> Tensor:
"""
Inter-block attention: attend over block reps + partial sum.
blocks:
N tensors of shape [B, T, D]: completed block representations for each previous block
partial_block:
[B, T, D]: intra-block partial sum (b_n^i)
"""
V = torch.stack(blocks + [partial_block]) # [N+1, B, T, D]
K = norm(V)
logits = torch.einsum('d, n b t d -> n b t', proj.weight.squeeze(), K)
h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V)
return h
def forward(self, blocks: list[Tensor], hidden_states: Tensor) -> tuple[list[Tensor], Tensor]:
partial_block = hidden_states
# apply block attnres before attn
# blocks already include token embedding
h = block_attn_res(blocks, partial_block, self.attn_res_proj, self.attn_res_norm)
# if reaches block boundary, start new block
# block_size counts ATTN + MLP; each transformer layer has 2
if self.layer_number % (self.block_size // 2) == 0:
blocks.append(partial_block)
partial_block = None
# self-attention layer
attn_out = self.attn(self.attn_norm(h))
partial_block = partial_block + attn_out if partial_block is not None else attn_out
# apply block attnres before MLP
h = block_attn_res(blocks, partial_block, self.mlp_res_proj, self.mlp_res_norm)
# MLP layer
mlp_out = self.mlp(self.mlp_norm(h))
partial_block = partial_block + mlp_out
return blocks, partial_block
Kết quả
Quy luật mở rộng
AttnRes luôn vượt trội hơn mô hình cơ bản trên mọi ngân sách tính toán. Block AttnRes đạt được hiệu suất tương đương với mô hình cơ bản nhưng chỉ cần 1.25x tính toán hơn.
Hiệu suất thực tế (Kimi Linear 48B / 3B activated, 1.4T tokens)
| Hạng mục | Benchmark | Baseline | AttnRes |
|---|---|---|---|
| Tổng quát | MMLU | 73.5 | 74.6 |
| GPQA-Diamond | 36.9 | 44.4 | |
| BBH | 76.3 | 78.0 | |
| TriviaQA | 69.9 | 71.8 | |
| Toán & lập trình | Math | 53.5 | 57.1 |
| HumanEval | 59.1 | 62.2 | |
| MBPP | 72.0 | 73.9 | |
| Tiếng Trung | CMMLU | 82.0 | 82.9 |
| C-Eval | 79.6 | 82.5 |
AttnRes cải thiện đáng kể, đặc biệt trong suy luận nhiều bước (+7.5 trên GPQA-Diamond) và tạo mã (+3.1 trên HumanEval).
Động lực đào tạo
AttnRes giảm thiểu sự pha loãng trong PreNorm: độ lớn đầu ra ổn định và độ phân phối của gradient được cân bằng trên các lớp.