Trong bài viết này, mình chia sẻ cách một LLM được thực thi (inference) trên các inference server, từ đó giúp chúng ta hiểu rõ hơn về nguyên lý hoạt động của LLM, ý nghĩa của các tham số khi gọi model và cách các inference server tối ưu quá trình inference.
Để dễ dàng theo dõi bài viết này, bạn nên có:
Kiến thức cơ bản về đại số tuyến tính và các phép biến đổi ma trận, vector như nhân vô hướng/có hướng, chuyển vị,…
Hiểu biết nền tảng về machine learning và deep learning, chẳng hạn như perceptron, cách mô hình hoạt động và sinh ra giá trị dự đoán,…
Trải nghiệm sử dụng LLM (Large Language Model).
OK, giờ thì hãy bắt đầu.
Một trong những bài toán cơ bản của LLM là hoàn thiện phần còn lại của 1 câu văn. Ví dụ: "Tôi đi học..." thì output của LLM có thể là "Tôi đi học 5 ngày trong tuần".
I. Hiểu về quá trình inference
1. Transformer
Ở bước này, chúng ta sẽ tính toán mối quan hệ giữa các từ trong prompt, xác định từ nào có liên hệ với từ nào và mức độ liên hệ đó ra sao.
Để thực hiện điều này, chúng ta forward đầu vào X0 - ma trận embedding của các token trong prompt qua các Transformer block.
Như trong hình, mỗi Transformer block gồm hai thành phần chính:
Attention: tính toán mối quan hệ giữa các từ trong prompt.
Feed Forward Neural Network (MLP – Multi-Layer Perceptron): thực hiện các biến đổi phi tuyến tính để trích xuất và biểu diễn đặc trưng ở mức cao hơn.
Ri là hàng thứ i của ma trận R∈Rn,dmodel — vector ẩn của tokeni∈Rdmodel
μi=dmodel1∑jRi,j
σi2=dmodel1∑j(Ri,j−μi)2
γ,β∈Rdmodel là hai vector trainable parameters (scale và shift)
💡 Nhận xét cá nhân: Các bước trên có thể được xem như quá trình LLM tự động trích xuất đặc trưng về mối quan hệ giữa các token trong prompt. Cách làm này khá tương đồng với các kỹ thuật Feature Extraction (tự động) như PCA, SIFT, HOG,… hoặc Feature Engineering (thủ công) như thống kê, tính xác suất, v.v.
💡 Nhận xét cá nhân: Ở bước FFN, LLM thực hiện biến đổi phi tuyến tính trên dữ liệu đầu vào (kết quả attention giữa các token).
Quá trình này tương tự như:
Các lớp ẩn (hidden layer) trong mạng neural, nơi dữ liệu được biến đổi phi tuyến để học biểu diễn trừu tượng hơn.
Các phép biến đổi phi tuyến (kernel trick) trong một số thuật toán machine learning, chẳng hạn như Support Vector Machine.
Layer Norm
X1=LN(F3+R)trong đoˊF3,R∈Rn,dmodel(10)
Sau khi đi qua một Transformer block, ta thu được đầu ra X1.
Tuy nhiên, trong thực tế, một LLM thường bao gồm từ 12 đến 120 Transformer block, và quá trình tính toán được thực hiện tuần tự như sau:
Giả sử LLM được huấn luyện trên một bộ từ vựng gồm nvocab từ.
Khi đó, ta có một ma trận biểu diễn không gian embedding của toàn bộ từ vựng, được định nghĩa như sau:
Wvocab∈Rdvocab,dmodelbvocab∈Rnvocab
Ta chiếu vector đặc trưng của token cuối cùng trong prompt lên không gian từ vựng để tính giá trị logits, tức là mức độ phù hợp của token đó với từng từ trong bộ từ vựng.
Sau đó, ta tính phân phối xác suất trên toàn bộ từ vựng bằng cách chuẩn hóa các giá trị logits (thường thông qua hàm softmax) để xác định xác suất xuất hiện của từng từ.
Sau bước này, ta tiếp tục lặp lại quy trình tương tự như trong phần Transformer để sinh các token tiếp theo, cho đến khi đạt điều kiện dừng (ví dụ: sinh ra token kết thúc hoặc đạt độ dài tối đa).
II. Một số phương pháp tối ưu
Ở phần trước, chúng ta đã tìm hiểu các bước giúp LLM sinh ra token mới.
Trong phần này, mình sẽ chia sẻ những phương pháp tối ưu hóa mà các inference server thường sử dụng để tăng tốc và giảm chi phí cho quá trình inference.
1. KV cache
Quan sát lại công thức (16), để tính Attention của token sinh thứ t, ta cần thực hiện phép nhân Q.KT có kích thước là (n+t,n+t).
Do đó, độ phức tạp tính toán của phép nhân ma trận này là O(n2), và sẽ tăng nhanh theo độ dài của prompt.
Để giải quyết vấn đề này, các inference server thường dùng một phần bộ nhớ (RAM hoặc VRAM) để lưu trữ cache cho các giá trị K và V của prompt cùng các token đã sinh, nhằm tái sử dụng kết quả cũ và tránh tính toán dư thừa ở các bước tiếp theo.
2. Flash Attention
Nhìn lại công thứ 4, kết quả của O=Q.KTvớiO∈Rn,n
Dưới đây là bảng số liệu cho mức độ tương quan giữa số lượng token trong prompt, độ phức tạp tính toán O, và lượng bộ nhớ cần sử dụng trong quá trình inference:
n_tokens
shape(0)
size(0) in FP16 (Gi)
512
262.144
0.00048
2048
4.194.304
0.0078
32.768
1.073.741.824
2
Với các model hỗ trợ 32k token, một head khi tính toán Attention có thể chiếm khoảng 2 GiB bộ nhớ.
Các model như GPT-3.5 hoặc GPT-4o có khoảng 100 head, nên chi phí lưu trữ và tính toán trở nên rất lớn.
Một số vấn đề phát sinh bao gồm:
I/O lớn do phải tải và truy xuất lượng thông tin khổng lồ.
Độ phức tạp tính toán cao khi số lượng token tăng.
Bộ nhớ yêu cầu rất lớn để lưu trữ các giá trị trung gian (đặc biệt là K, V cache).
Để giải quyết vấn đề này, người ta sử dụng thuật toán Flash Attention, trong đó các ma trận Q, K, V được chia nhỏ thành các khối (block) có kích thước cố định (thường là 128 × 128).
Việc tính toán Attention trên từng khối nhỏ giúp giảm đáng kể lượng bộ nhớ tạm cần dùng, tăng khả năng song song hóa trên GPU, và giữ kết quả trung gian trong SRAM thay vì DRAM, từ đó tăng tốc độ và giảm chi phí tính toán.
Như vậy, ta có Qi,Ki,Vi∈R128,128vớii=1...m. Ta gọi tập Qi,Ki,Vi là các tensor.
Để tính toán các tensor một cách hiệu quả, người ta sử dụng một thành phần chuyên dụng trong GPU gọi là Tensor Core.
Tensor Core đảm nhiệm hai chức năng chính:
Tính toán hiệu quả trên 2 ma trận có precision khác nhau, ví dụ FP16 x INT8, phù hợp cho các quantization model
Trên CPU hoặc CUDA core, cơ chế vectorization cho phép thực hiện song song nhiều phép tính trên các phần tử của một vector (tương tự SIMD – Single Instruction, Multiple Data).
Trong khi đó, Tensor Core nâng cấp mức song song này — thay vì xử lý từng vector riêng lẻ, nó thực hiện phép nhân giữa các ma trận nhỏ (tức là nhiều vector cùng lúc) chỉ trong một chu kỳ lệnh, giúp tăng tốc độ xử lý vượt trội trong các bài toán như attention và matrix multiplication.
Trên đây là hai tối ưu quan trọng liên quan trực tiếp đến các công thức tính toán của LLM.
Tuy nhiên, trong thực tế, các inference server còn áp dụng nhiều kỹ thuật tối ưu khác như Paged KV Cache, Weight Sharing, Operator Fusion,…
Những nội dung này mình sẽ trình bày chi tiết hơn trong các bài viết tiếp theo.
II. Tổng kết
Vậy mà mình đã chia sẻ cách mà các Large Language Model hoạt động để chúng ta thấy được những pain point trong quá trình inference cũng như cách các inference server và GPU tối ưu chúng.
Trong những bài viết sau, mình sẽ chia sẻ về ý nghĩa toán học ở high level (low level thì chắc không do mình không phải dân chuyên 😄) cách LLM được thiết kế cũng như training/triển khai 1 LLM bằng Tensorflow code.