phi15-js-api / app.py
misalsathsara's picture
Update app.py
93228bf verified
raw
history blame
2.99 kB
import os
# Send all Hugging Face caches to /tmp, which is writable here (the default
# cache location presumably is not in this deployment). This sits above the
# `transformers` import so the library sees these values when it loads.
os.environ.update(
    {
        "HF_HOME": "/tmp",
        "TRANSFORMERS_CACHE": "/tmp/hf-cache",
    }
)
import re
from typing import List, Optional

import torch
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# ASGI application; the /generate route is registered on it further down.
app = FastAPI()

# Hugging Face Hub repository of the fine-tuned code-generation model. The
# tokenizer and weights are loaded once, when this module is imported.
model_id = "misalsathsara/phi1.5-js-codegen"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() and .eval() both return the module itself, so they chain;
# eval mode disables training-only behaviour such as dropout.
model = AutoModelForCausalLM.from_pretrained(model_id).to(device).eval()
# System prompt placed in front of every request's instruction. It shows one
# worked example of the required output (a single JavaScript `transform(row)`
# function) and then lists the coding rules the model must follow.
# NOTE: `${fieldList}` is JavaScript template-literal syntax. Python does not
# interpolate it, so the model sees it verbatim unless the caller replaces it
# with the row's field names before building the final prompt.
system_prompt = """
You are a smart javascript assistant that only generates the best simple javascript functions without any comments like this:
function transform(row) {
row['Latitude'] = row['Location'].split(',')[0];
row['Longitude'] = row['Location'].split(',')[1];
return row;
}
when user gives a prompt like "convert the location field into separate latitude and longitude fields".
Generate simple javascript functions that should take a single row of data as input and the generated function name is always transform.
The user may use the words column, item or field to mean each column.
Guard against null and undefined for items in the row.
${fieldList}
Field names are case sensitive.
For parsing something into a date, assume a function called parseAnyDate is available.
If the code requires some numeric calculation - ensure the value is converted to a number first. Don't assume it's always the correct data type.
When doing any string comparison, make it case insensitive.
When replacing characters in a string, make sure to use the correct replacement literal. For example, to replace hyphens with spaces, use: .replace(/-/g, ' ')
The function should not include a single comment before or after the function.
Don't add any text except for the function code.
Don't add any markdown block markers either.
Every function must end with return row;
"""
# Matches the first complete `function transform(...) { ... return row; }` in
# the model output. The non-greedy `.*?` stops at the first `return row;` that
# is followed by a closing brace; DOTALL lets `.` span the function's newlines.
_TRANSFORM_FN_RE = re.compile(
    r"function\s+transform\(.*?\)\s*{.*?return row;\s*}", re.DOTALL
)


class RequestData(BaseModel):
    """Request body for ``POST /generate``.

    ``instruction`` is the natural-language description of the row
    transformation to generate. ``fields`` optionally lists the row's column
    names. When given, they fill the ``${fieldList}`` slot of the system prompt
    so the model knows which (case-sensitive) names exist.
    """

    instruction: str
    # Optional column names. Defaults to None so requests that send only
    # `instruction` keep working unchanged.
    fields: Optional[List[str]] = None


def _render_field_list(fields: Optional[List[str]]) -> str:
    """Return the text that replaces ``${fieldList}`` in the system prompt."""
    if not fields:
        # Drop the placeholder instead of leaking `${fieldList}` to the model.
        return ""
    return "The row has these fields: " + ", ".join(fields) + "."


@app.post("/generate")
def generate_code(data: RequestData):
    """Generate a JavaScript ``transform(row)`` function for ``data.instruction``.

    The prompt is the module-level ``system_prompt`` (with ``${fieldList}``
    filled from ``data.fields``) followed by an Instruction/Response template.
    Up to 200 new tokens are sampled from the model. The first complete
    ``function transform(...) {... return row; }`` found in the output is
    returned as ``text/plain``. If none is found, the stripped raw generated
    text is returned instead.
    """
    prompt_head = system_prompt.replace("${fieldList}", _render_field_list(data.fields))
    full_prompt = prompt_head + f"\n### Instruction:\n{data.instruction}\n\n### Response:\n"

    encoded = tokenizer(full_prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    # pad_token_id below equals the EOS id, so generate() cannot infer the
    # attention mask from the input. Pass the tokenizer's mask explicitly.
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            temperature=0.3,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() echoes the prompt, so decode only the newly generated tokens.
    generated_text = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    )

    # Only return the JavaScript function - no extra text around it.
    match = _TRANSFORM_FN_RE.search(generated_text)
    clean_output = match.group(0).strip() if match else generated_text.strip()
    return PlainTextResponse(clean_output)