Path: blob/main/sagemaker/24_train_bloom_peft_lora/scripts/inference.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def model_fn(model_dir):
    # load 8-bit quantized model and tokenizer from model_dir
    model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", load_in_8bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer

    # process input
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    # preprocess: tokenize and move input ids to the model's device
    input_ids = tokenizer(inputs, return_tensors="pt").input_ids.to(model.device)

    # generate, forwarding any generation kwargs from the request
    if parameters is not None:
        outputs = model.generate(input_ids, **parameters)
    else:
        outputs = model.generate(input_ids)

    # postprocess: decode generated tokens into text
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return [{"generated_text": prediction}]
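

# --- Usage sketch (not part of the SageMaker entrypoint) ---
# A minimal local smoke test for the two handler functions above.
# In a deployed endpoint, the SageMaker inference toolkit calls
# model_fn and predict_fn itself; the model directory path and
# generation parameters below are assumptions for illustration only.
if __name__ == "__main__":
    model_and_tokenizer = model_fn("/opt/ml/model")  # hypothetical local model dir
    payload = {
        "inputs": "Summarize: Amazon SageMaker makes it easy to train and deploy models.",
        "parameters": {"max_new_tokens": 50, "do_sample": True, "top_p": 0.9},
    }
    print(predict_fn(payload, model_and_tokenizer))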