# phi15-js-api / app.py
# Author: misalsathsara · last commit: "Update app.py" (be7d660, verified) · 2.93 kB
import os
os.environ["HF_HOME"] = "/app/cache"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
app = FastAPI()

# The fine-tuned checkpoint and its tokenizer are loaded once, at import time,
# and the module-level objects below are shared by every request handler.
model_id = "misalsathsara/phi1.5-js-codegen"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = AutoModelForCausalLM.from_pretrained(model_id)
model.to(device)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
# System prompt that generate_code prepends to every instruction. It tells the
# model to emit a single uncommented `function transform(row) { ... return row; }`.
# NOTE(review): `${fieldList}` below looks like an unfilled JavaScript
# template-literal placeholder. Nothing in this file substitutes it, so the model
# receives the literal text "${fieldList}". Either fill it with the dataset's
# field names or remove it, after confirming how the model was fine-tuned.
system_prompt = """
You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
function transform(row) {
row['Latitude'] = row['Location'].split(',')[0];
row['Longitude'] = row['Location'].split(',')[1];
return row;
}
when user gives a prompt like "convert the location field into separate latitude and longitude fields".
Generate simple javascript functions that should take a single row of data as input and the generated function name is always transform.
The user may use the words column, item or field to mean each column.
Guard against null and undefined for items in the row.
${fieldList}
Field names are case sensitive.
For parsing something into a date, assume a function called parseAnyDate is available.
If the code requires some numeric calculation - ensure the value is converted to a number first. Don't assume its always the correct data type.
When doing any string comparison, make it case insensitive.
When replacing characters in a string, make sure to use the correct replacement literal. For example, to replace hyphens with spaces, use: .replace(/-/g, ' ')
The function should not include a single comment before or after the function.
Don't add any text except for the function code.
Don't add any markdown block markers either.
Every function must end with return row;
"""
class RequestData(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    # Free-text description of the row transformation to generate, e.g.
    # "convert the location field into separate latitude and longitude fields".
    instruction: str
# Matches a JavaScript function header up to and including its opening brace.
# It accepts both a named `function transform(row) {` and an anonymous
# `function (row) {`. The earlier pattern `function\s*\(` rejected any name
# between `function` and `(`, so the `transform` function the prompt asks for
# never matched and the raw completion was returned instead.
_JS_FUNCTION_START = re.compile(r"\bfunction(?:\s+[\w$]+)?\s*\([^)]*\)\s*\{")


def _extract_js_function(text: str) -> str:
    """Return the first complete JavaScript function found in *text*.

    The function body is delimited by counting braces rather than by stopping at
    the first ``return row; }``. That way guard clauses such as
    ``if (row == null) { return row; }`` do not cut the function short. Brace
    counting does not skip string or regex literals, so a literal ``{`` or ``}``
    inside them can still mislead it.

    Falls back to the whole stripped *text* when no function header is found, or
    when its braces never close (e.g. generation stopped at ``max_new_tokens``).
    """
    header = _JS_FUNCTION_START.search(text)
    if header is None:
        return text.strip()
    depth = 0
    # header.end() - 1 is the opening brace consumed by the pattern.
    for index in range(header.end() - 1, len(text)):
        char = text[index]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[header.start():index + 1]
    return text.strip()


# POST endpoint.
@app.post("/generate")
def generate_code(data: RequestData):
    """Generate a JavaScript ``transform(row)`` function for ``data.instruction``.

    The instruction is appended to ``system_prompt`` using an
    ``### Instruction:`` / ``### Response:`` template. The model then samples up
    to 200 new tokens.

    Returns:
        ``{"result": code}``, where ``code`` is the first complete JavaScript
        function in the completion. If no complete function is found, ``code``
        is the whole stripped completion.
    """
    instruction = data.instruction
    full_prompt = system_prompt + f"\n### Instruction:\n{instruction}\n\n### Response:\n"
    encoded = tokenizer(full_prompt, return_tensors="pt").to(device)
    input_ids = encoded["input_ids"]
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            # Pass the mask explicitly. pad_token_id is set to eos_token_id below,
            # so generate() cannot infer which positions are real tokens.
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=200,
            temperature=0.3,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The output repeats the prompt tokens first; decode only the new completion.
    generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return {"result": _extract_js_function(generated_text)}