import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CHECKPOINT = "mistralai/Mistral-7B-v0.1"


class Model:
    def __init__(self, **kwargs) -> None:
        # Tokenizer and model are loaded lazily in load()
        self.tokenizer = None
        self.model = None
    def load(self):
        # Load the weights in fp16 and let accelerate place them on the
        # available device(s).
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        # Mistral's tokenizer ships without a pad token; fall back to EOS so
        # generate() receives a valid pad_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
    def predict(self, request: dict):
        prompt = request.pop("prompt")
        # Sampling parameters, overridable per request.
        generate_args = {
            "max_new_tokens": request.get("max_new_tokens", 128),
            "temperature": request.get("temperature", 1.0),
            "top_p": request.get("top_p", 0.95),
            "top_k": request.get("top_k", 50),
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Move inputs to the same device as the model weights.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.model.device
        )
        with torch.no_grad():
            output = self.model.generate(inputs=input_ids, **generate_args)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
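

# A minimal local smoke test; a sketch for trying the class directly, not part
# of any serving framework's lifecycle. Assumes a CUDA-capable GPU with enough
# memory for the fp16 weights (roughly 15 GB); the prompt is an arbitrary example.
if __name__ == "__main__":
    model = Model()
    model.load()
    print(model.predict({"prompt": "The capital of France is"}))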