Giới thiệu về Mamba-3
Mamba-3 là một mô hình không gian trạng thái (SSM) mới, được thiết kế với mục tiêu chính là tăng cường hiệu suất suy diễn, khác biệt với Mamba-2 tập trung vào tốc độ huấn luyện. Cải tiến chính bao gồm công thức hồi quy phức tạp hơn, theo dõi trạng thái bằng giá trị phức, và biến thể MIMO (đa đầu vào, đa đầu ra) để nâng cao độ chính xác mà không làm chậm quá trình giải mã.
Cải tiến mấu chốt
- Công thức hồi quy biểu cảm hơn: Được phát triển từ sơ đồ phân đoạn lũy thừa mũ.
- Theo dõi trạng thái giá trị phức: Mô hình SSM với giá trị phức hợp.
- MIMO SSMs: Dùng nhiều SSMs song song, thay vì chỉ một.
Những cải tiến này giúp Mamba-3 đạt được hiệu suất vượt trội mà vẫn duy trì được độ trễ dự đoán tương tự.
Kiến trúc
Khác với Mamba-2, Mamba-3 sử dụng thêm QKNorm, giúp ổn định quá trình huấn luyện. Loại bỏ hoàn toàn sự cần thiết của short causal convolution bằng cách tích hợp trực tiếp các thành phần như RoPE và chiếu MIMO.
Kết quả thực nghiệm
Mô hình ngôn ngữ
Mô hình Mamba-3 vượt qua Mamba-2 và các mô hình hồi quy tuyến tính khác như GDN trong các bài kiểm tra mô hình ngôn ngữ trước huấn luyện. Biến thể MIMO còn nâng cao độ chính xác thêm hơn 1% so với Mamba-3 thông thường.
Nhiệm vụ truy xuất
Mặc dù các mô hình tuyến tính thường kém hơn Transformers trong nhiệm vụ tìm kiếm dữ liệu, Mamba-3 vẫn đạt kết quả khả quan. Biến thể MIMO giúp cải thiện hiệu suất mà không làm tăng kích thước trạng thái.
Hiệu suất và tốc độ xử lý
Các lõi của chúng tôi sử dụng Triton, TileLang, và CuTe DSL để đạt hiệu suất phần cứng tối đa. Triton hỗ trợ các bộ xử lý tensor của GPU cho tốc độ huấn luyện nhanh, trong khi CuTe DSL tối ưu hoá tốc độ giải mã.
| Mô hình | n=512 | 1024 | 2048 | 4096 | 16384 |
|---|---|---|---|---|---|
| Mamba-3 (SISO) | 4.39 | 8.78 | 17.57 | 35.11 | 140.61 |
| Mamba-3 (MIMO r=4) | 4.74 | 9.48 | 18.96 | 37.85 | 151.81 |
Kết luận
Các bạn có thể xem thêm chi tiết trong bài viết trên arXiv và mã nguồn tại mamba-ssm.