File size: 3,037 Bytes
be7d660
6a4eb93
be7d660
234a50d
 
 
 
 
 
6a4eb93
 
 
 
 
234a50d
6a4eb93
234a50d
 
 
6a4eb93
 
234a50d
 
 
 
6a4eb93
234a50d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4eb93
234a50d
 
 
6a4eb93
 
234a50d
 
6a4eb93
 
 
 
234a50d
6a4eb93
 
234a50d
 
 
 
 
 
 
 
6a4eb93
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
# Redirect the Hugging Face home directory to /tmp/hf; prevents write errors on
# Hugging Face Spaces.
# NOTE(review): this assignment sits above the `transformers` import on purpose --
# presumably the library reads HF_HOME when it is imported, so keep this ordering
# unless verified otherwise.
os.environ["HF_HOME"] = "/tmp/hf"

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re

# Application object that the route decorator further down registers against.
# The keyword arguments are descriptive metadata only (name, blurb, version).
app = FastAPI(
    version="1.0",
    title="JavaScript Code Generator API",
    description=(
        "Generate simple JavaScript functions "
        "from natural language instructions"
    ),
)

# One checkpoint id is shared by the tokenizer and the causal language model.
model_id = "misalsathsara/phi1.5-js-codegen"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Use CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place the model on the selected device and switch it to evaluation mode.
model.to(device)
model.eval()

# Prompt template: fixed instructions placed in front of every user instruction
# by generate_code().
#
# NOTE(review): this is a plain (non-f) string and generate_code() interpolates it
# verbatim, so the `${fieldList}` line below is sent to the model as literal text.
# It looks like a leftover JavaScript template placeholder -- confirm whether it
# should be substituted with the actual field names or removed.
system_prompt = """
You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
function transform(row) {
    row['Latitude'] = row['Location'].split(',')[0];
    row['Longitude'] = row['Location'].split(',')[1];
    return row;
}
when user gives a prompt like "convert the location field into separate latitude and longitude fields".
Generate simple javascript functions that should take a single row of data as input and the generated function name is always transform.
The user may use the words column, item or field to mean each column.
Guard against null and undefined for items in the row.
${fieldList}
Field names are case sensitive.
For parsing something into a date, assume a function called parseAnyDate is available.
If the code requires some numeric calculation - ensure the value is converted to a number first. Don't assume its always the correct data type.
When doing any string comparison, make it case insensitive.
When replacing characters in a string, make sure to use the correct replacement literal. For example, to replace hyphens with spaces, use: .replace(/-/g, ' ')
The function should not include a single comment before or after the function.
Don't add any text except for the function code.
Don't add any markdown block markers either.
Every function must end with return row;
"""

# Input schema: request body accepted by the /generate route.
class RequestData(BaseModel):
    # Natural-language description of the transformation to generate; it is
    # inserted into the prompt under the "### Instruction:" heading.
    instruction: str

# Main route

# Matches a JavaScript function declaration -- named (`function transform(row)`)
# or anonymous (`function (row)`) -- up to the first `return row;` that is
# followed by a closing brace. The previous pattern required `(` directly after
# `function`, so it never matched the named `transform` function the prompt asks
# for. DOTALL lets the lazy wildcards span multiple lines.
_TRANSFORM_FUNCTION_RE = re.compile(
    r"\bfunction\b\s*[\w$]*\s*\(.*?\)\s*\{.*?return row;\s*\}",
    re.DOTALL,
)


def _extract_transform_function(text: str) -> str:
    """Return the first JavaScript function found in *text*, else *text* stripped."""
    found = _TRANSFORM_FUNCTION_RE.search(text)
    return (found.group(0) if found else text).strip()


@app.post("/generate", summary="Generate JavaScript code", tags=["Code Generation"])
def generate_code(data: RequestData):
    """Generate a JavaScript function from a natural-language instruction.

    The instruction is appended to the fixed system prompt, the model samples a
    completion, and the first function ending in `return row;` is returned under
    the `result` key. If no such function is found, the whole completion is
    returned with surrounding whitespace removed.
    """
    instruction = data.instruction
    full_prompt = f"{system_prompt}\n### Instruction:\n{instruction}\n\n### Response:\n"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.3,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The generated sequence starts with the prompt tokens; decode only what
    # follows them.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return {"result": _extract_transform_function(completion)}