์๋ ํ์ธ์! ์ฌ๋ฌ๋ถ, ์ค๋์ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๊ทน๋ํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ด์ผ๊ธฐํด๋ณด๊ฒ ์ต๋๋ค.
์๋ก : ์ธ๊ณต์ง๋ฅ์ ๋ฐ์ ๊ณผ ์ต์ ํ
์ธ๊ณต์ง๋ฅ(AI)์ ์ฐ๋ฆฌ์ ์ถ์ ํฐ ๋ณํ๋ฅผ ๊ฐ์ ธ์์ต๋๋ค.
ํนํ ๋ฅ๋ฌ๋๊ณผ ๊ฐ์ ๊ธฐ์ ์ ๋ค์ํ ๋ถ์ผ์์ ํ์ ์ ์ด๋ฃจ๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ๊ธฐ์ ์ ์ต๋ํ ํ์ฉํ๊ธฐ ์ํด์๋ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ต์ ํํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ์ค๋์ ๊ทธ ์ค์์๋ Torch.compile๊ณผ FlashAttention์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
Torch.compile: PyTorch ์ฝ๋ ์ต์ ํ
Torch.compile์ PyTorch 2.0์์ ์ ๊ณตํ๋ ๊ธฐ๋ฅ์ผ๋ก, PyTorch ์ฝ๋๋ฅผ ์ต์ ํ๋ ์ปค๋๋ก ์ปดํ์ผํ์ฌ ์คํ ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํต๋๋ค. ์ด ๊ธฐ๋ฅ์ ๋๋ถ๋ถ์ ๊ณผ์ ์์ ๋จ ํ ์ค์ ์ฝ๋ ์์ ๋ง์ผ๋ก๋ ๊ฐ๋ฅํฉ๋๋ค.
model = torch.compile(model)
์ด ๊ฐ๋จํ ์ฝ๋ ํ ์ค๋ก ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ต๋๋ค.
์ค์ ๋ก, PyTorch 2.0์ ์ฌ์ฉํ๋ ๋ง์ ์ฐ๊ตฌ์๋ค์ด ์ด ๊ธฐ๋ฅ์ ํตํด ๋ชจ๋ธ์ ์ฒ๋ฆฌ ์๋๋ฅผ ํ๊ธฐ์ ์ผ๋ก ๊ฐ์ ํ๊ณ ์์ต๋๋ค.
์ถ์ฒ: QuickAITutorial
FlashAttention: ํจ์จ์ ์ธ ์ฃผ์ ๋ฉ์ปค๋์ฆ
FlashAttention์ Transformer ๋ชจ๋ธ์์ ์ค์ํ ์ญํ ์ ํ๋ Scaled Dot-Product Attention(SDPA)์ ๋์ฑ ํจ์จ์ ์ผ๋ก ๊ตฌํํ ๊ธฐ๋ฅ์ ๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ฟผ๋ฆฌ, ํค, ๊ฐ ๋ฒกํฐ ๊ฐ์ ์ฃผ์ ์ ์๋ฅผ ๊ณ์ฐํ๋ ๊ณผ์ ์์ ๋ฐ์ํ ์ ์๋ ๊ทธ๋๋์ธํธ ์์ค์ด๋ ํญ๋ฐ์ ๋ฐฉ์งํฉ๋๋ค.
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
max_new_tokens=256,
temperature=0.8,
top_p=0.95,
top_k=50,
repetition_penalty=1.10,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
์ด ์ฝ๋๋ฅผ ํตํด FlashAttention์ ์ ์ฉํ๋ฉด, ๋ชจ๋ธ์ ์ฐ์ฐ ํจ์จ์ฑ์ ํฌ๊ฒ ๋์ผ ์ ์์ต๋๋ค. ์ด๋ฌํ ์ต์ ํ๋ ํนํ ๋๊ท๋ชจ ๋ฐ์ดํฐ์ ์ ๋ค๋ฃฐ ๋ ๋งค์ฐ ์ ์ฉํฉ๋๋ค.
์ถ์ฒ: QuickAITutorial
Five Technique : VLLM + Torch + Flash_Attention =Super Local LLM
As LLms Boom, The model size of LLM increases according to a scaling law to improve performance, and recent LLMs
quickaitutorial.com
VLLM: ์๋ก์ด ํ ์คํธ ์์ฑ ๋ฐฉ๋ฒ
VLLM์ ํ ์คํธ ์์ฑ์ ์ํ ์๋ก์ด ์ ๊ทผ ๋ฐฉ์์ผ๋ก, SamplingParams ํด๋์ค๋ฅผ ํตํด ํ ์คํธ ์์ฑ ๊ณผ์ ์์์ ๋ฌด์์์ฑ๊ณผ ์ ํ์ ์ ์ดํฉ๋๋ค. ์ด๋ ๋ค์ํ ํ ์คํธ ์์ฑ ์๋๋ฆฌ์ค์์ ๋งค์ฐ ์ ์ฉํ๊ฒ ์ฌ์ฉ๋ ์ ์์ต๋๋ค.
from vllm import LLM, SamplingParams
model = LLM(model=model_name, tokenizer=tokenizer_name, dtype='float16')
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=50)
outputs = model.generate(self.prompt, sampling_params)
์ด ์ ๊ทผ ๋ฐฉ์์ ํ ์คํธ ์์ฑ์ ๋ค์์ฑ์ ๋์ด๋ฉฐ, ๋ณด๋ค ์์ฐ์ค๋ฌ์ด ๊ฒฐ๊ณผ๋ฌผ์ ์์ฑํ ์ ์๊ฒ ํฉ๋๋ค.
์ถ์ฒ: QuickAITutorial
๊ฒฐ๋ก : ์ต์ ํ์ ์ค์์ฑ
์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ต์ ํํ๋ ๊ฒ์ ๋งค์ฐ ์ค์ํฉ๋๋ค. Torch.compile๊ณผ FlashAttention, VLLM๊ณผ ๊ฐ์ ๊ธฐ์ ๋ค์ ์ด๋ฌํ ์ต์ ํ๋ฅผ ํตํด ๋ชจ๋ธ์ ํจ์จ์ฑ์ ํฌ๊ฒ ๋์ผ ์ ์์ต๋๋ค.
์ด๋ฌํ ๊ธฐ์ ๋ค์ ํ์ฉํ๋ฉด, ๋ณด๋ค ๋น ๋ฅด๊ณ ์ ํํ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค.
ํฅํ ์ฐ๊ตฌ ๋ฐฉํฅ์ผ๋ก๋ ์ด๋ฌํ ์ต์ ํ ๊ธฐ์ ๋ค์ ๋์ฑ ๋ฐ์ ์ํค๊ณ , ๋ค์ํ ์์ฉ ๋ถ์ผ์ ์ ์ฉํ๋ ๊ฒ์ด ํ์ํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ธ๊ณต์ง๋ฅ์ ์ ์ฌ๋ ฅ์ ์ต๋ํ ๋ฐํํ ์ ์์ ๊ฒ์ ๋๋ค.
#์ธ๊ณต์ง๋ฅ #๋ฅ๋ฌ๋ #PyTorch #๋ชจ๋ธ์ต์ ํ #FlashAttention #VLLM
NuuNStation์ FirstSation์ผ๋ก ์์ฑ๋์์ต๋๋ค.
'๐ฑIT' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
DSPy: ํ์ ์ ์ธ ์ธ์ด ๋ชจ๋ธ ์ต์ ํ ํ๋ ์์ํฌ (1) | 2024.06.11 |
---|---|
LangChain, RAG Fusion, GPT-4o๋ฅผ ํ์ฉํ ๊ฐ๋ ฅํ ์ฑ๋ด ๋ง๋ค๊ธฐ (0) | 2024.06.11 |
์ด๋๋น์ ํ์ ์ ์ธ ์๋ ํ๋ ์ ํ ์ด์ ์์ฑ ๊ธฐ์ (0) | 2024.06.08 |
ํธ๋์คํฌ๋จธ ๋์ฝ๋ ์๋ฒ ๋ฉ ๋ณํ์ ๋น๋ฐ์ ํํค์น๋ค! ๐ค (1) | 2024.06.08 |
OpenRLHF: AI ํ๋ จ์ ์๋ก์ด ํ์ (0) | 2024.06.08 |