misalsathsara commited on
Commit
0d621b0
·
verified ·
1 Parent(s): 81dff5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -1,24 +1,33 @@
1
  import os
 
 
 
 
 
 
2
 
3
- # Redirect Hugging Face cache to a writable directory
4
  os.environ["HF_HOME"] = "/tmp"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
6
 
7
- from fastapi import FastAPI
8
- from pydantic import BaseModel
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
- import torch
11
- import re
12
 
13
  app = FastAPI()
14
 
 
15
  model_id = "misalsathsara/phi1.5-js-codegen"
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
17
  model = AutoModelForCausalLM.from_pretrained(model_id)
 
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
20
  model.eval()
21
 
 
22
  system_prompt = """
23
  You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
24
  function transform(row) {
@@ -49,13 +58,13 @@ class RequestData(BaseModel):
49
  def generate_code(data: RequestData):
50
  instruction = data.instruction
51
  full_prompt = system_prompt + f"\n### Instruction:\n{instruction}\n\n### Response:\n"
52
-
53
  input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
54
-
55
  with torch.no_grad():
56
  output_ids = model.generate(
57
  input_ids,
58
- max_new_tokens=200,
59
  temperature=0.3,
60
  top_k=50,
61
  top_p=0.95,
@@ -65,17 +74,12 @@ def generate_code(data: RequestData):
65
 
66
  generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
67
 
68
- # Only return JavaScript function — no extra text
69
  # Extract only the JavaScript function that ends with return row;
70
  match = re.search(r"function\s+transform\s*\([^)]*\)\s*{[^}]*return row;\s*}", generated_text, re.DOTALL)
71
  if match:
72
  clean_output = match.group(0).strip()
73
  else:
74
- # fallback: try to grab only up to "return row;"
75
  fallback = generated_text.split("return row;")[0] + "return row;"
76
  clean_output = fallback.strip()
77
-
78
- from fastapi.responses import PlainTextResponse
79
- return PlainTextResponse(clean_output)
80
-
81
 
 
 
1
  import os
2
+ import torch
3
+ import re
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import PlainTextResponse
6
+ from pydantic import BaseModel
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
+ # Set cache directory for HF Spaces
10
  os.environ["HF_HOME"] = "/tmp"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
12
 
13
+ # Optional: speed up inference on CPU
14
+ torch.set_num_threads(1)
 
 
 
15
 
16
  app = FastAPI()
17
 
18
+ # Load model + tokenizer
19
  model_id = "misalsathsara/phi1.5-js-codegen"
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
+
22
  model = AutoModelForCausalLM.from_pretrained(model_id)
23
+ # Optional: Compile model if using PyTorch >= 2 (comment out if error)
24
+ # model = torch.compile(model)
25
+
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  model.to(device)
28
  model.eval()
29
 
30
+ # JS assistant system prompt
31
  system_prompt = """
32
  You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
33
  function transform(row) {
 
58
  def generate_code(data: RequestData):
59
  instruction = data.instruction
60
  full_prompt = system_prompt + f"\n### Instruction:\n{instruction}\n\n### Response:\n"
61
+
62
  input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
63
+
64
  with torch.no_grad():
65
  output_ids = model.generate(
66
  input_ids,
67
+ max_new_tokens=100, # Faster
68
  temperature=0.3,
69
  top_k=50,
70
  top_p=0.95,
 
74
 
75
  generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
76
 
 
77
  # Extract only the JavaScript function that ends with return row;
78
  match = re.search(r"function\s+transform\s*\([^)]*\)\s*{[^}]*return row;\s*}", generated_text, re.DOTALL)
79
  if match:
80
  clean_output = match.group(0).strip()
81
  else:
 
82
  fallback = generated_text.split("return row;")[0] + "return row;"
83
  clean_output = fallback.strip()
 
 
 
 
84
 
85
+ return PlainTextResponse(clean_output)