Skip to content
tinAI
Go back

GitHub - MoonshotAI/Attention-Residuals

Bài gốc: GitHub - MoonshotAI/Attention-Residuals

Tác giả: Unknown

Ngày đăng: Dịch ngày:

TL;DR

Attention Residuals (AttnRes) là một cải tiến cho các kết nối dư thừa trong Transformer, giúp mỗi lớp tập trung linh hoạt hơn thông qua attention học được. Block AttnRes quản lý bộ nhớ hiệu quả mà vẫn duy trì hiệu suất cao.

Tổng quan

Các kết nối dư thừa tiêu chuẩn tích lũy đầu ra của tất cả các lớp với trọng số cố định. Khi mô hình mở rộng, sự tích lũy này làm giảm đóng góp của từng lớp và gây ra vấn đề về độ lớn của trạng thái ẩn.

Attention Residuals (AttnRes) thay thế sự tích lũy này bằng attention mềm mại trên đầu ra của các lớp trước:

h_l = \sum_{i=0}^{l-1} \alpha_{i \rightarrow l} \cdot v_i

Trong đó, các trọng số ( \alpha_{i \rightarrow l} ) được tính toán qua một query giả học được cho mỗi lớp, cho phép truy cập thông minh và có nhận thức nội dung đến tất cả các biểu diễn trước đó.

Block AttnRes

Full AttnRes đơn giản nhưng yêu cầu nhiều bộ nhớ với độ phức tạp O(Ld). Block AttnRes phân chia các lớp thành N khối, tích lũy trong từng khối qua kết nối dư tiêu chuẩn và áp dụng attention chỉ qua các biểu diễn cấp khối. Với khoảng 8 khối, Block AttnRes thu lại phần lớn sự cải tiến của Full AttnRes với chi phí bộ nhớ không đáng kể.

 def block_attn_res(blocks: list[Tensor], partial_block: Tensor, proj: Linear, norm: RMSNorm) -> Tensor:
     """
  Inter-block attention: attend over block reps + partial sum.
  blocks:
  N tensors of shape [B, T, D]: completed block representations for each previous block
  partial_block:
  [B, T, D]: intra-block partial sum (b_n^i)
  """
     V = torch.stack(blocks + [partial_block])  # [N+1, B, T, D]
     K = norm(V)
     logits = torch.einsum('d, n b t d -> n b t', proj.weight.squeeze(), K)
     h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V)
     return h 

 def forward(self, blocks: list[Tensor], hidden_states: Tensor) -> tuple[list[Tensor], Tensor]:
     partial_block = hidden_states
     # apply block attnres before attn
     # blocks already include token embedding
     h = block_attn_res(blocks, partial_block, self.attn_res_proj, self.attn_res_norm)

     # if reaches block boundary, start new block
     # block_size counts ATTN + MLP; each transformer layer has 2
     if self.layer_number % (self.block_size // 2) == 0:
         blocks.append(partial_block)
         partial_block = None

     # self-attention layer
     attn_out = self.attn(self.attn_norm(h))
     partial_block = partial_block + attn_out if partial_block is not None else attn_out

     # apply block attnres before MLP
     h = block_attn_res(blocks, partial_block, self.mlp_res_proj, self.mlp_res_norm)

     # MLP layer
     mlp_out = self.mlp(self.mlp_norm(h))
     partial_block = partial_block + mlp_out

     return blocks, partial_block 

Kết quả

Quy luật mở rộng

AttnRes luôn vượt trội hơn mô hình cơ bản trên mọi ngân sách tính toán. Block AttnRes đạt được hiệu suất tương đương với mô hình cơ bản nhưng chỉ cần 1.25x tính toán hơn.

Hiệu suất thực tế (Kimi Linear 48B / 3B activated, 1.4T tokens)

Hạng mụcBenchmarkBaselineAttnRes
Tổng quátMMLU73.574.6
GPQA-Diamond36.944.4
BBH76.378.0
TriviaQA69.971.8
Toán & lập trìnhMath53.557.1
HumanEval59.162.2
MBPP72.073.9
Tiếng TrungCMMLU82.082.9
C-Eval79.682.5

AttnRes cải thiện đáng kể, đặc biệt trong suy luận nhiều bước (+7.5 trên GPQA-Diamond) và tạo mã (+3.1 trên HumanEval).

Động lực đào tạo

AttnRes giảm thiểu sự pha loãng trong PreNorm: độ lớn đầu ra ổn định và độ phân phối của gradient được cân bằng trên các lớp.


Read Original (EN) Quay lại Newsletter