본문 바로가기
python/딥러닝

langchain에서 VLLM 사용하기 (with lora)

by _avocado_ 2024. 8. 12.

langchain 라이브러리와 VLLM을 함께 이용하는 방법을 소개한다.

현재 langchain_community 0.2.6 버전 기준으로 lora adapter 적용 및 stop token이 작용하지 않는 것으로 확인된다.
이를 위한 class 처리 방법도 함께 소개한다.


 langchain VLLM class 사용법

1. 모델 선언

langchain_community 에서 VLLM class 를 불러와 사용할 수 있다.

일반 적인 파라미터들을 바로 설정 할 수도 있고 vllm_kwargs 에 dict 형태로 vllm에서 사용하던 파라미터들을 모두 사용할 수 있다.

import torch
from langchain_community.llms import VLLM

model_path = "model_path"

vllm_kwargs = {'kv_cache_dtype':"auto",
               'gpu_memory_utilization':0.5,
               'enable_lora':True,
               'max_lora_rank':64,
               'enforce_eager':True,}

llm = VLLM(
    model=model_path,
    stop=["<|eot_id|>"],
    max_new_tokens=512,
    temperature=0.5,
    vllm_kwargs=vllm_kwargs,
    tensor_parallel_size=1 # GPU 사용 개수
    )

 

2. chain 구성

langchain으로 불러온 LLM 모델로 아래와 같이 chain을 구성하여 사용할 수 있다.

from langchain.prompts import PromptTemplate

prompt = f'''<|begin_of_text|><|start_header_id|>system<|end_header_id|>
			다음 한국어 입력을 영어로 번역해서 말하세요.
			<|eot_id|><|start_header_id|>user<|end_header_id|>
			{text}
			<|eot_id|><|start_header_id|>assistant<|end_header_id|>'''
            
prompt = PromptTemplate.from_template(prompt)

chain = prompt | llm

text = '''나는 너를 좋아해'''

chain.invoke({'text':text})

 lora adapter 적용 VLLM class 구성

langchain VLLM class는 lora adapter를 적용하는 방법이 현재는 존재하지 않는다.

따라서 class를 새로 구성해줘야 한다. chain에서 langchain llm은 _generate 함수를 사용한다.
따라서  _generate 함수를 재구성 하여 lora adapter를 적용한다.

 

기존 _generation 코드

def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""

        from vllm import SamplingParams

        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": stop}
        sampling_params = SamplingParams(**params)
        # call the model
        outputs = self.client.generate(prompts, sampling_params)

        generations = []
        for output in outputs:
            text = output.outputs[0].text
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

 

커스텀 _generation 함수

 

lora_adapter 변수에 adapter path를 입력하고

generation 시에 lora_adapter 가 존재 한다면 adapter를 적용한 결과를, 그렇지 않으면 일반 결과를 반환하는 코드

from typing import Any, List, Optional
from langchain_core.outputs import Generation, LLMResult
from langchain_core.callbacks import CallbackManagerForLLMRun
from vllm.lora.request import LoRARequest
from vllm import SamplingParams

class VLLM_LORA(VLLM): # 상속
    lora_adapter : str = None # lora_adapter 변수명 선언
    def set_lora_adapter(self, lora_adapter): # lora_adapter 변수명 선언 함수
        if lora_adapter:
            self.lora_adapter = lora_adapter

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""

        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": ["<|eot_id|>","<|start_header_id|>","<eos>","<|end_of_text|>"]}
        sampling_params = SamplingParams(**params)
        # call the model
        if self.lora_adapter:
            outputs = self.client.generate(
                                prompts,
                                sampling_params,
                                lora_request=LoRARequest("adapter", 1, self.lora_adapter) # adapter 적용
                            )
        else:
            outputs = self.client.generate(prompts,
                                           sampling_params)

        generations = []
        for output in outputs:
            text = output.outputs[0].text
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

 

사용방법

모델 선언시는 위와 같은 방법으로 진행하고 set_lora_adapter 함수로 adapter의 path를 입력하여 사용한다.

llm = VLLM_LORA(
    model=model_path,
    quantization="awq",
    tensor_parallel_size=2,
    enable_lora=True,
    max_lora_rank=64,
    kv_cache_dtype="auto",
    gpu_memory_utilization=0.7, 
    enforce_eager=True
    )
    
adapter_path = "adapter_path"

llm.set_lora_adapter(adapter_path) # adapter 넣기

chain = prompt | llm

text = '''나는 너를 좋아해'''

chain.invoke({'text':text})

 

* 추가적으로 langchain VLLM에서 stops에 토큰을 넣어도 적용되지 않는 것으로 확인되며 위에 lora_adapter 적용 class에서 generation 시에 stops를 직

접 넣어주어 해결 할 수 있다.

댓글