from threading import Thread
from typing import Dict, Iterator

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95


class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        # Falcon's tokenizer ships without a pad token; reuse the EOS token.
        # Note: pad_token expects the token string, not the integer id.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def predict(self, request: Dict) -> Iterator[str]:
        prompt = request.pop("prompt")
        inputs = self.tokenizer(
            prompt, return_tensors="pt", max_length=512, truncation=True, padding=True
        )
        # Place inputs on the model's device rather than hardcoding "cuda",
        # since device_map="auto" decides placement at load time.
        input_ids = inputs["input_ids"].to(self.model.device)

        # The streamer receives tokens from generate() and re-emits them as
        # decoded text chunks that can be iterated from another thread.
        streamer = TextIteratorStreamer(self.tokenizer)

        # do_sample must be enabled, otherwise top_p and top_k are ignored
        # and generation silently falls back to greedy decoding.
        generation_config = GenerationConfig(
            do_sample=True,
            temperature=1.0,
            top_p=DEFAULT_TOP_P,
            top_k=40,
        )
        generation_kwargs = {
            "input_ids": input_ids,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "output_scores": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
            "streamer": streamer,
        }

        # Run generation on a background thread and consume the streamer on
        # this one. The torch.no_grad() that previously wrapped this block
        # did nothing: grad mode is thread-local, so it never reached the
        # worker thread, and generate() already disables gradients itself.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        def inner():
            for text in streamer:
                yield text
            thread.join()

        return inner()
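

# A minimal local smoke test (a sketch, not part of the serving contract):
# in a real deployment the model server calls load() once and then predict()
# per request; the {"prompt": ...} shape simply mirrors the handler above.
if __name__ == "__main__":
    model = Model()
    model.load()
    for chunk in model.predict({"prompt": "Write a haiku about open-source LLMs."}):
        print(chunk, end="", flush=True)
    print()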