๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ

๐Ÿ“ฑIT

Torch.compile๊ณผ FlashAttention์„ ํ†ตํ•œ ์ตœ์ ํ™” ๋ฐฉ๋ฒ•

์•ˆ๋…•ํ•˜์„ธ์š”! ์—ฌ๋Ÿฌ๋ถ„, ์˜ค๋Š˜์€ ์ธ๊ณต์ง€๋Šฅ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ๊ทน๋Œ€ํ™”ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์ด์•ผ๊ธฐํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.


์„œ๋ก : ์ธ๊ณต์ง€๋Šฅ์˜ ๋ฐœ์ „๊ณผ ์ตœ์ ํ™”

 

์ธ๊ณต์ง€๋Šฅ(AI)์€ ์šฐ๋ฆฌ์˜ ์‚ถ์— ํฐ ๋ณ€ํ™”๋ฅผ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค.

ํŠนํžˆ ๋”ฅ๋Ÿฌ๋‹๊ณผ ๊ฐ™์€ ๊ธฐ์ˆ ์€ ๋‹ค์–‘ํ•œ ๋ถ„์•ผ์—์„œ ํ˜์‹ ์„ ์ด๋ฃจ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด๋Ÿฌํ•œ ๊ธฐ์ˆ ์„ ์ตœ๋Œ€ํ•œ ํ™œ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ตœ์ ํ™”ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋Š˜์€ ๊ทธ ์ค‘์—์„œ๋„ Torch.compile๊ณผ FlashAttention์— ๋Œ€ํ•ด ์•Œ์•„๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.


Torch.compile: Optimizing PyTorch Code

 

Torch.compile์€ PyTorch 2.0์—์„œ ์ œ๊ณตํ•˜๋Š” ๊ธฐ๋Šฅ์œผ๋กœ, PyTorch ์ฝ”๋“œ๋ฅผ ์ตœ์ ํ™”๋œ ์ปค๋„๋กœ ์ปดํŒŒ์ผํ•˜์—ฌ ์‹คํ–‰ ์†๋„๋ฅผ ํฌ๊ฒŒ ํ–ฅ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ ๋Œ€๋ถ€๋ถ„์˜ ๊ณผ์ •์—์„œ ๋‹จ ํ•œ ์ค„์˜ ์ฝ”๋“œ ์ˆ˜์ •๋งŒ์œผ๋กœ๋„ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

model = torch.compile(model)

์ด ๊ฐ„๋‹จํ•œ ์ฝ”๋“œ ํ•œ ์ค„๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ํฌ๊ฒŒ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์‹ค์ œ๋กœ, PyTorch 2.0์„ ์‚ฌ์šฉํ•˜๋Š” ๋งŽ์€ ์—ฐ๊ตฌ์ž๋“ค์ด ์ด ๊ธฐ๋Šฅ์„ ํ†ตํ•ด ๋ชจ๋ธ์˜ ์ฒ˜๋ฆฌ ์†๋„๋ฅผ ํš๊ธฐ์ ์œผ๋กœ ๊ฐœ์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

์ถœ์ฒ˜: QuickAITutorial


FlashAttention: ํšจ์œจ์ ์ธ ์ฃผ์˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜

 

FlashAttention์€ Transformer ๋ชจ๋ธ์—์„œ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•˜๋Š” Scaled Dot-Product Attention(SDPA)์„ ๋”์šฑ ํšจ์œจ์ ์œผ๋กœ ๊ตฌํ˜„ํ•œ ๊ธฐ๋Šฅ์ž…๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ ์ฟผ๋ฆฌ, ํ‚ค, ๊ฐ’ ๋ฒกํ„ฐ ๊ฐ„์˜ ์ฃผ์˜ ์ ์ˆ˜๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๊ณผ์ •์—์„œ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๊ทธ๋ž˜๋””์–ธํŠธ ์†Œ์‹ค์ด๋‚˜ ํญ๋ฐœ์„ ๋ฐฉ์ง€ํ•ฉ๋‹ˆ๋‹ค.

import torch

# Force SDPA to use the FlashAttention kernel only
# (model, tokenizer, and token_ids are assumed to be defined already)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            max_new_tokens=256,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.10,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

์ด ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด FlashAttention์„ ์ ์šฉํ•˜๋ฉด, ๋ชจ๋ธ์˜ ์—ฐ์‚ฐ ํšจ์œจ์„ฑ์„ ํฌ๊ฒŒ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ตœ์ ํ™”๋Š” ํŠนํžˆ ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์…‹์„ ๋‹ค๋ฃฐ ๋•Œ ๋งค์šฐ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

์ถœ์ฒ˜: QuickAITutorial

 

Reference: "Five Technique: VLLM + Torch + Flash_Attention = Super Local LLM", quickaitutorial.com

 


VLLM: ์ƒˆ๋กœ์šด ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ฐฉ๋ฒ•

 

VLLM์€ ํ…์ŠคํŠธ ์ƒ์„ฑ์„ ์œ„ํ•œ ์ƒˆ๋กœ์šด ์ ‘๊ทผ ๋ฐฉ์‹์œผ๋กœ, SamplingParams ํด๋ž˜์Šค๋ฅผ ํ†ตํ•ด ํ…์ŠคํŠธ ์ƒ์„ฑ ๊ณผ์ •์—์„œ์˜ ๋ฌด์ž‘์œ„์„ฑ๊ณผ ์„ ํƒ์„ ์ œ์–ดํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ๋‹ค์–‘ํ•œ ํ…์ŠคํŠธ ์ƒ์„ฑ ์‹œ๋‚˜๋ฆฌ์˜ค์—์„œ ๋งค์šฐ ์œ ์šฉํ•˜๊ฒŒ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

from vllm import LLM, SamplingParams

# model_name, tokenizer_name, and prompt are assumed to be defined
model = LLM(model=model_name, tokenizer=tokenizer_name, dtype='float16')
# SamplingParams controls randomness and token selection during generation
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=50)
outputs = model.generate(prompt, sampling_params)

์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ํ…์ŠคํŠธ ์ƒ์„ฑ์˜ ๋‹ค์–‘์„ฑ์„ ๋†’์ด๋ฉฐ, ๋ณด๋‹ค ์ž์—ฐ์Šค๋Ÿฌ์šด ๊ฒฐ๊ณผ๋ฌผ์„ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.

์ถœ์ฒ˜: QuickAITutorial


๊ฒฐ๋ก : ์ตœ์ ํ™”์˜ ์ค‘์š”์„ฑ

 

์ธ๊ณต์ง€๋Šฅ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ตœ์ ํ™”ํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. Torch.compile๊ณผ FlashAttention, VLLM๊ณผ ๊ฐ™์€ ๊ธฐ์ˆ ๋“ค์€ ์ด๋Ÿฌํ•œ ์ตœ์ ํ™”๋ฅผ ํ†ตํ•ด ๋ชจ๋ธ์˜ ํšจ์œจ์„ฑ์„ ํฌ๊ฒŒ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด๋Ÿฌํ•œ ๊ธฐ์ˆ ๋“ค์„ ํ™œ์šฉํ•˜๋ฉด, ๋ณด๋‹ค ๋น ๋ฅด๊ณ  ์ •ํ™•ํ•œ ์ธ๊ณต์ง€๋Šฅ ๋ชจ๋ธ์„ ๊ตฌ์ถ•ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ–ฅํ›„ ์—ฐ๊ตฌ ๋ฐฉํ–ฅ์œผ๋กœ๋Š” ์ด๋Ÿฌํ•œ ์ตœ์ ํ™” ๊ธฐ์ˆ ๋“ค์„ ๋”์šฑ ๋ฐœ์ „์‹œํ‚ค๊ณ , ๋‹ค์–‘ํ•œ ์‘์šฉ ๋ถ„์•ผ์— ์ ์šฉํ•˜๋Š” ๊ฒƒ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ์ธ๊ณต์ง€๋Šฅ์˜ ์ž ์žฌ๋ ฅ์„ ์ตœ๋Œ€ํ•œ ๋ฐœํœ˜ํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค.


#์ธ๊ณต์ง€๋Šฅ #๋”ฅ๋Ÿฌ๋‹ #PyTorch #๋ชจ๋ธ์ตœ์ ํ™” #FlashAttention #VLLM

NuuNStation์˜ FirstSation์œผ๋กœ ์ž‘์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.