Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,6 +14,15 @@ import torch
|
|
| 14 |
import yfinance as yf
|
| 15 |
from datetime import datetime, timedelta
|
| 16 |
from math import sqrt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Set up logging
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -48,21 +57,20 @@ try:
|
|
| 48 |
from PIL import Image
|
| 49 |
import io
|
| 50 |
except ModuleNotFoundError:
|
| 51 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
|
| 52 |
import matplotlib.pyplot as plt
|
| 53 |
from PIL import Image
|
| 54 |
import io
|
|
|
|
| 55 |
|
| 56 |
MAX_MAX_NEW_TOKENS = 512
|
| 57 |
DEFAULT_MAX_NEW_TOKENS = 512
|
| 58 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
|
| 59 |
|
| 60 |
DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
|
| 61 |
-
This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.
|
| 62 |
-
<p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
|
| 63 |
|
| 64 |
-
LICENSE = """<p/>
|
| 65 |
-
---
|
| 66 |
This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
|
| 67 |
|
| 68 |
# Load the model (skip fine-tuning for faster startup)
|
|
@@ -75,12 +83,12 @@ try:
|
|
| 75 |
llm = Llama(
|
| 76 |
model_path=model_path,
|
| 77 |
n_ctx=1024,
|
| 78 |
-
n_batch=1024,
|
| 79 |
n_threads=multiprocessing.cpu_count(),
|
| 80 |
n_gpu_layers=n_gpu_layers,
|
| 81 |
-
chat_format="chatml"
|
| 82 |
)
|
| 83 |
-
logger.info(f"Model loaded successfully with n_gpu_layers=
|
| 84 |
# Warm up the model for faster initial inference
|
| 85 |
llm("Warm-up prompt", max_tokens=1, echo=False)
|
| 86 |
logger.info("Model warm-up completed.")
|
|
@@ -101,6 +109,8 @@ Assistant:
|
|
| 101 |
- Represents average annual return with compounding
|
| 102 |
- Past performance is not indicative of future results."""
|
| 103 |
|
|
|
|
|
|
|
| 104 |
# Function to calculate CAGR using yfinance
|
| 105 |
def calculate_cagr(ticker, start_date, end_date):
|
| 106 |
try:
|
|
@@ -116,7 +126,7 @@ def calculate_cagr(ticker, start_date, end_date):
|
|
| 116 |
logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
|
| 117 |
return None
|
| 118 |
|
| 119 |
-
#
|
| 120 |
def calculate_risk_metrics(ticker, years=5):
|
| 121 |
try:
|
| 122 |
end_date = datetime.now().strftime('%Y-%m-%d')
|
|
@@ -134,6 +144,68 @@ def calculate_risk_metrics(ticker, years=5):
|
|
| 134 |
logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
|
| 135 |
return None, None
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# Assuming the generate function handles the chat logic (extended to include risk comparison)
|
| 138 |
def generate(
|
| 139 |
message: str,
|
|
@@ -143,10 +215,13 @@ def generate(
|
|
| 143 |
temperature: float,
|
| 144 |
top_p: float,
|
| 145 |
top_k: int,
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
if not system_prompt:
|
| 148 |
system_prompt = DEFAULT_SYSTEM_PROMPT
|
| 149 |
|
|
|
|
| 150 |
# Detect CAGR query
|
| 151 |
cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
|
| 152 |
if cagr_match:
|
|
@@ -157,11 +232,23 @@ def generate(
|
|
| 157 |
end_date = f"{end_year}-12-31"
|
| 158 |
cagr = calculate_cagr(ticker, start_date, end_date)
|
| 159 |
if cagr is not None:
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
else:
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# Detect risk comparison query
|
| 167 |
risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
|
|
@@ -171,29 +258,120 @@ def generate(
|
|
| 171 |
vol1, sharpe1 = calculate_risk_metrics(ticker1)
|
| 172 |
vol2, sharpe2 = calculate_risk_metrics(ticker2)
|
| 173 |
if vol1 is None or vol2 is None:
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
riskier = ticker1
|
| 178 |
-
less_risky = ticker2
|
| 179 |
-
higher_vol = vol1
|
| 180 |
-
lower_vol = vol2
|
| 181 |
-
riskier_sharpe = sharpe1
|
| 182 |
-
less_sharpe = sharpe2
|
| 183 |
else:
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
# For other queries, fall back to LLM generation
|
| 194 |
conversation = [{"role": "system", "content": system_prompt}]
|
| 195 |
for user, assistant in history:
|
| 196 |
-
conversation.extend([
|
|
|
|
|
|
|
|
|
|
| 197 |
conversation.append({"role": "user", "content": message})
|
| 198 |
|
| 199 |
# Generate response using LLM (streamed)
|
|
@@ -205,35 +383,73 @@ def generate(
|
|
| 205 |
top_k=top_k,
|
| 206 |
stream=True
|
| 207 |
)
|
| 208 |
-
|
| 209 |
partial_text = ""
|
| 210 |
for chunk in response:
|
| 211 |
if "content" in chunk["choices"][0]["delta"]:
|
| 212 |
partial_text += chunk["choices"][0]["delta"]["content"]
|
| 213 |
yield partial_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
with gr.Blocks(theme=themes.Default()) as demo:
|
| 217 |
gr.Markdown(DESCRIPTION)
|
| 218 |
gr.Markdown(LICENSE)
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
)
|
| 237 |
-
|
|
|
|
| 238 |
|
| 239 |
-
|
|
|
|
| 14 |
import yfinance as yf
|
| 15 |
from datetime import datetime, timedelta
|
| 16 |
from math import sqrt
|
| 17 |
+
import time
|
| 18 |
+
import base64
|
| 19 |
+
import io
|
| 20 |
+
import numpy as np
|
| 21 |
+
try:
|
| 22 |
+
import scipy.optimize as opt
|
| 23 |
+
except ModuleNotFoundError:
|
| 24 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
|
| 25 |
+
import scipy.optimize as opt
|
| 26 |
|
| 27 |
# Set up logging
|
| 28 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 57 |
from PIL import Image
|
| 58 |
import io
|
| 59 |
except ModuleNotFoundError:
|
| 60 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow", "numpy"])
|
| 61 |
import matplotlib.pyplot as plt
|
| 62 |
from PIL import Image
|
| 63 |
import io
|
| 64 |
+
import numpy as np
|
| 65 |
|
| 66 |
MAX_MAX_NEW_TOKENS = 512
|
| 67 |
DEFAULT_MAX_NEW_TOKENS = 512
|
| 68 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
|
| 69 |
|
| 70 |
DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
|
| 71 |
+
This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.<p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
|
|
|
|
| 72 |
|
| 73 |
+
LICENSE = """<p/>---
|
|
|
|
| 74 |
This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
|
| 75 |
|
| 76 |
# Load the model (skip fine-tuning for faster startup)
|
|
|
|
| 83 |
llm = Llama(
|
| 84 |
model_path=model_path,
|
| 85 |
n_ctx=1024,
|
| 86 |
+
n_batch=1024, # Increased for faster processing
|
| 87 |
n_threads=multiprocessing.cpu_count(),
|
| 88 |
n_gpu_layers=n_gpu_layers,
|
| 89 |
+
chat_format="chatml" # Phi-2 uses ChatML format in llama.cpp
|
| 90 |
)
|
| 91 |
+
logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
|
| 92 |
# Warm up the model for faster initial inference
|
| 93 |
llm("Warm-up prompt", max_tokens=1, echo=False)
|
| 94 |
logger.info("Model warm-up completed.")
|
|
|
|
| 109 |
- Represents average annual return with compounding
|
| 110 |
- Past performance is not indicative of future results."""
|
| 111 |
|
| 112 |
+
logs = []
|
| 113 |
+
|
| 114 |
# Function to calculate CAGR using yfinance
|
| 115 |
def calculate_cagr(ticker, start_date, end_date):
|
| 116 |
try:
|
|
|
|
| 126 |
logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
|
| 127 |
return None
|
| 128 |
|
| 129 |
+
# Function to calculate risk metrics using yfinance
|
| 130 |
def calculate_risk_metrics(ticker, years=5):
|
| 131 |
try:
|
| 132 |
end_date = datetime.now().strftime('%Y-%m-%d')
|
|
|
|
| 144 |
logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
|
| 145 |
return None, None
|
| 146 |
|
| 147 |
+
# Function for inline plot
|
| 148 |
+
def generate_plot(ticker, period='5y'):
|
| 149 |
+
try:
|
| 150 |
+
data = yf.download(ticker, period=period)
|
| 151 |
+
if data.empty:
|
| 152 |
+
return "Unable to fetch data for plotting."
|
| 153 |
+
plt.figure(figsize=(10, 5))
|
| 154 |
+
plt.plot(data['Adj Close'], label='Adjusted Close')
|
| 155 |
+
plt.title(f'{ticker} Price History ({period})')
|
| 156 |
+
plt.xlabel('Date')
|
| 157 |
+
plt.ylabel('Price (USD)')
|
| 158 |
+
plt.legend()
|
| 159 |
+
plt.grid(True)
|
| 160 |
+
buf = io.BytesIO()
|
| 161 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 162 |
+
buf.seek(0)
|
| 163 |
+
b64 = base64.b64encode(buf.read()).decode('utf-8')
|
| 164 |
+
plt.close()
|
| 165 |
+
return f""
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.error(f"Error generating plot for {ticker}: {str(e)}")
|
| 168 |
+
return "Error generating plot."
|
| 169 |
+
|
| 170 |
+
# Function for portfolio optimization using scipy
|
| 171 |
+
def portfolio_optimization(tickers, target_return=None):
|
| 172 |
+
try:
|
| 173 |
+
data = yf.download(tickers, period='5y')['Adj Close']
|
| 174 |
+
returns = data.pct_change().dropna()
|
| 175 |
+
mean_returns = returns.mean() * 252
|
| 176 |
+
cov_matrix = returns.cov() * 252
|
| 177 |
+
num_assets = len(tickers)
|
| 178 |
+
|
| 179 |
+
def portfolio_volatility(weights):
|
| 180 |
+
return np.sqrt(np.dot(weights.T, np.dot(cov_matrix, weights)))
|
| 181 |
+
|
| 182 |
+
constraints = ({'type': 'eq', 'fun': lambda x: np.sum(x) - 1})
|
| 183 |
+
bounds = tuple((0, 1) for _ in range(num_assets))
|
| 184 |
+
initial_guess = np.array(num_assets * [1. / num_assets])
|
| 185 |
+
|
| 186 |
+
if target_return:
|
| 187 |
+
# Maximize Sharpe or min vol for target return
|
| 188 |
+
def objective(weights):
|
| 189 |
+
ret = np.sum(mean_returns * weights)
|
| 190 |
+
vol = portfolio_volatility(weights)
|
| 191 |
+
return - (ret - 0.02) / vol if vol != 0 else np.inf # Neg Sharpe
|
| 192 |
+
cons = [{'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
|
| 193 |
+
{'type': 'eq', 'fun': lambda x: np.sum(mean_returns * x) - target_return}]
|
| 194 |
+
result = opt.minimize(objective, initial_guess, method='SLSQP', bounds=bounds, constraints=cons)
|
| 195 |
+
else:
|
| 196 |
+
# Minimize volatility
|
| 197 |
+
result = opt.minimize(portfolio_volatility, initial_guess, method='SLSQP',
|
| 198 |
+
bounds=bounds, constraints=constraints)
|
| 199 |
+
|
| 200 |
+
if result.success:
|
| 201 |
+
weights = dict(zip(tickers, result.x))
|
| 202 |
+
return weights
|
| 203 |
+
else:
|
| 204 |
+
return {ticker: 1/len(tickers) for ticker in tickers} # Fallback equal weights
|
| 205 |
+
except Exception as e:
|
| 206 |
+
logger.error(f"Error in portfolio optimization: {str(e)}")
|
| 207 |
+
return {ticker: 1/len(tickers) for ticker in tickers}
|
| 208 |
+
|
| 209 |
# Assuming the generate function handles the chat logic (extended to include risk comparison)
|
| 210 |
def generate(
|
| 211 |
message: str,
|
|
|
|
| 215 |
temperature: float,
|
| 216 |
top_p: float,
|
| 217 |
top_k: int,
|
| 218 |
+
logs_state: list
|
| 219 |
+
) -> tuple[Iterator[str], list]:
|
| 220 |
+
start_time = time.time()
|
| 221 |
if not system_prompt:
|
| 222 |
system_prompt = DEFAULT_SYSTEM_PROMPT
|
| 223 |
|
| 224 |
+
full_response = ""
|
| 225 |
# Detect CAGR query
|
| 226 |
cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
|
| 227 |
if cagr_match:
|
|
|
|
| 232 |
end_date = f"{end_year}-12-31"
|
| 233 |
cagr = calculate_cagr(ticker, start_date, end_date)
|
| 234 |
if cagr is not None:
|
| 235 |
+
response = f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
|
| 236 |
+
yield response
|
| 237 |
+
full_response = response
|
| 238 |
else:
|
| 239 |
+
response = "Unable to calculate CAGR for the specified period."
|
| 240 |
+
yield response
|
| 241 |
+
full_response = response
|
| 242 |
+
end_time = time.time()
|
| 243 |
+
logs_state.append({
|
| 244 |
+
'timestamp': datetime.now().isoformat(),
|
| 245 |
+
'query': message,
|
| 246 |
+
'response': full_response,
|
| 247 |
+
'response_length': len(full_response.split()),
|
| 248 |
+
'generation_time': end_time - start_time,
|
| 249 |
+
'token_efficiency': len(full_response.split()) / max_new_tokens
|
| 250 |
+
})
|
| 251 |
+
return iter([]), logs_state # No more yield
|
| 252 |
|
| 253 |
# Detect risk comparison query
|
| 254 |
risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
|
|
|
|
| 258 |
vol1, sharpe1 = calculate_risk_metrics(ticker1)
|
| 259 |
vol2, sharpe2 = calculate_risk_metrics(ticker2)
|
| 260 |
if vol1 is None or vol2 is None:
|
| 261 |
+
response = "Unable to fetch risk metrics for one or both tickers."
|
| 262 |
+
yield response
|
| 263 |
+
full_response = response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
else:
|
| 265 |
+
if vol1 > vol2:
|
| 266 |
+
riskier = ticker1
|
| 267 |
+
less_risky = ticker2
|
| 268 |
+
higher_vol = vol1
|
| 269 |
+
lower_vol = vol2
|
| 270 |
+
riskier_sharpe = sharpe1
|
| 271 |
+
less_sharpe = sharpe2
|
| 272 |
+
else:
|
| 273 |
+
riskier = ticker2
|
| 274 |
+
less_risky = ticker1
|
| 275 |
+
higher_vol = vol2
|
| 276 |
+
lower_vol = vol1
|
| 277 |
+
riskier_sharpe = sharpe2
|
| 278 |
+
less_sharpe = sharpe1
|
| 279 |
+
response = f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
|
| 280 |
+
yield response
|
| 281 |
+
full_response = response
|
| 282 |
+
end_time = time.time()
|
| 283 |
+
logs_state.append({
|
| 284 |
+
'timestamp': datetime.now().isoformat(),
|
| 285 |
+
'query': message,
|
| 286 |
+
'response': full_response,
|
| 287 |
+
'response_length': len(full_response.split()),
|
| 288 |
+
'generation_time': end_time - start_time,
|
| 289 |
+
'token_efficiency': len(full_response.split()) / max_new_tokens
|
| 290 |
+
})
|
| 291 |
+
return iter([]), logs_state
|
| 292 |
+
|
| 293 |
+
# Detect plot/chart query
|
| 294 |
+
plot_match = re.search(r'(plot|chart)\s+(\w+)(?:\s+(historical|price|volatility))?', message.lower())
|
| 295 |
+
if plot_match:
|
| 296 |
+
ticker = plot_match.group(2).upper()
|
| 297 |
+
plot_type = plot_match.group(3) if plot_match.group(3) else 'price'
|
| 298 |
+
if plot_type == 'volatility':
|
| 299 |
+
# Simple volatility plot (returns histogram)
|
| 300 |
+
try:
|
| 301 |
+
data = yf.download(ticker, period='1y')
|
| 302 |
+
returns = data['Adj Close'].pct_change().dropna()
|
| 303 |
+
plt.figure(figsize=(10, 5))
|
| 304 |
+
plt.hist(returns, bins=50, alpha=0.7)
|
| 305 |
+
plt.title(f'{ticker} Daily Returns Distribution (1Y)')
|
| 306 |
+
plt.xlabel('Return')
|
| 307 |
+
plt.ylabel('Frequency')
|
| 308 |
+
except:
|
| 309 |
+
plot_type = 'price' # Fallback
|
| 310 |
+
if plot_type != 'volatility':
|
| 311 |
+
plot_md = generate_plot(ticker)
|
| 312 |
+
response = f"Price chart for {ticker}:\n{plot_md}\n- This visualizes the historical adjusted close prices.\n- Past performance is not indicative of future results. Consult a financial advisor."
|
| 313 |
+
yield response
|
| 314 |
+
full_response = response
|
| 315 |
+
else:
|
| 316 |
+
# For volatility, similar
|
| 317 |
+
buf = io.BytesIO()
|
| 318 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 319 |
+
buf.seek(0)
|
| 320 |
+
b64 = base64.b64encode(buf.read()).decode('utf-8')
|
| 321 |
+
plt.close()
|
| 322 |
+
plot_md = f""
|
| 323 |
+
response = f"Volatility chart for {ticker}:\n{plot_md}\n- Histogram of daily returns over the past year."
|
| 324 |
+
yield response
|
| 325 |
+
full_response = response
|
| 326 |
+
end_time = time.time()
|
| 327 |
+
logs_state.append({
|
| 328 |
+
'timestamp': datetime.now().isoformat(),
|
| 329 |
+
'query': message,
|
| 330 |
+
'response': full_response,
|
| 331 |
+
'response_length': len(full_response.split()),
|
| 332 |
+
'generation_time': end_time - start_time,
|
| 333 |
+
'token_efficiency': len(full_response.split()) / max_new_tokens
|
| 334 |
+
})
|
| 335 |
+
return iter([]), logs_state
|
| 336 |
+
|
| 337 |
+
# Detect portfolio optimization query
|
| 338 |
+
port_match = re.search(r'optimize\s+portfolio\s+for\s+([\w,\s]+)(?:\s+with\s+(risk|return)\s+tolerance\s+([\d.]+))?', message.lower())
|
| 339 |
+
if port_match:
|
| 340 |
+
tickers_str = port_match.group(1).strip()
|
| 341 |
+
tickers = [t.strip().upper() for t in re.split(r'[,;]', tickers_str) if t.strip()]
|
| 342 |
+
target = None
|
| 343 |
+
if port_match.group(3):
|
| 344 |
+
target = float(port_match.group(3))
|
| 345 |
+
if port_match.group(2) == 'risk':
|
| 346 |
+
# For risk tolerance, min vol with vol <= target (but simplify to min vol)
|
| 347 |
+
pass # Use default min vol
|
| 348 |
+
else:
|
| 349 |
+
target_return = target
|
| 350 |
+
weights = portfolio_optimization(tickers, target_return=target if 'return' in (port_match.group(2) or '') else None)
|
| 351 |
+
df = pd.DataFrame(list(weights.items()), columns=['Ticker', 'Weight'])
|
| 352 |
+
df['Weight'] = df['Weight'].round(4)
|
| 353 |
+
table_md = df.to_markdown(index=False)
|
| 354 |
+
response = f"- Suggested portfolio weights for {', '.join(tickers)}:\n{table_md}\n- Based on minimum variance optimization (or target return if specified).\n- Assumes 5-year historical data for means and covariances.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
|
| 355 |
+
yield response
|
| 356 |
+
full_response = response
|
| 357 |
+
end_time = time.time()
|
| 358 |
+
logs_state.append({
|
| 359 |
+
'timestamp': datetime.now().isoformat(),
|
| 360 |
+
'query': message,
|
| 361 |
+
'response': full_response,
|
| 362 |
+
'response_length': len(full_response.split()),
|
| 363 |
+
'generation_time': end_time - start_time,
|
| 364 |
+
'token_efficiency': len(full_response.split()) / max_new_tokens
|
| 365 |
+
})
|
| 366 |
+
return iter([]), logs_state
|
| 367 |
|
| 368 |
# For other queries, fall back to LLM generation
|
| 369 |
conversation = [{"role": "system", "content": system_prompt}]
|
| 370 |
for user, assistant in history:
|
| 371 |
+
conversation.extend([
|
| 372 |
+
{"role": "user", "content": user},
|
| 373 |
+
{"role": "assistant", "content": assistant}
|
| 374 |
+
])
|
| 375 |
conversation.append({"role": "user", "content": message})
|
| 376 |
|
| 377 |
# Generate response using LLM (streamed)
|
|
|
|
| 383 |
top_k=top_k,
|
| 384 |
stream=True
|
| 385 |
)
|
|
|
|
| 386 |
partial_text = ""
|
| 387 |
for chunk in response:
|
| 388 |
if "content" in chunk["choices"][0]["delta"]:
|
| 389 |
partial_text += chunk["choices"][0]["delta"]["content"]
|
| 390 |
yield partial_text
|
| 391 |
+
full_response = partial_text
|
| 392 |
+
end_time = time.time()
|
| 393 |
+
logs_state.append({
|
| 394 |
+
'timestamp': datetime.now().isoformat(),
|
| 395 |
+
'query': message,
|
| 396 |
+
'response': full_response,
|
| 397 |
+
'response_length': len(full_response.split()),
|
| 398 |
+
'generation_time': end_time - start_time,
|
| 399 |
+
'token_efficiency': len(full_response.split()) / max_new_tokens
|
| 400 |
+
})
|
| 401 |
+
return iter([]), logs_state
|
| 402 |
|
| 403 |
+
def update_logs(logs_state):
|
| 404 |
+
if logs_state:
|
| 405 |
+
df = pd.DataFrame(logs_state)
|
| 406 |
+
return df
|
| 407 |
+
return pd.DataFrame()
|
| 408 |
+
|
| 409 |
+
# Gradio interface setup
|
| 410 |
with gr.Blocks(theme=themes.Default()) as demo:
|
| 411 |
gr.Markdown(DESCRIPTION)
|
| 412 |
gr.Markdown(LICENSE)
|
| 413 |
+
|
| 414 |
+
with gr.Tabs():
|
| 415 |
+
with gr.TabItem("Chat"):
|
| 416 |
+
chatbot = gr.Chatbot()
|
| 417 |
+
msg = gr.Textbox(label="Enter your question")
|
| 418 |
+
with gr.Row():
|
| 419 |
+
submit = gr.Button("Submit")
|
| 420 |
+
clear = gr.Button("Clear")
|
| 421 |
+
advanced = gr.Accordion("Advanced Settings", open=False)
|
| 422 |
+
with advanced:
|
| 423 |
+
system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
|
| 424 |
+
max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
|
| 425 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
|
| 426 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
|
| 427 |
+
top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")
|
| 428 |
+
|
| 429 |
+
with gr.TabItem("Metrics"):
|
| 430 |
+
metrics_df = gr.Dataframe(headers=['timestamp', 'query', 'response', 'response_length', 'generation_time', 'token_efficiency'])
|
| 431 |
+
|
| 432 |
+
logs_state = gr.State(logs)
|
| 433 |
+
|
| 434 |
+
def submit_fn(msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state):
|
| 435 |
+
gen, new_logs = generate(msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state)
|
| 436 |
+
history.append((msg, ""))
|
| 437 |
+
for partial in gen:
|
| 438 |
+
history[-1] = (history[-1][0], partial)
|
| 439 |
+
yield history, "", new_logs
|
| 440 |
+
return history, "", new_logs
|
| 441 |
+
|
| 442 |
+
submit.click(
|
| 443 |
+
submit_fn,
|
| 444 |
+
inputs=[msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state],
|
| 445 |
+
outputs=[chatbot, msg, logs_state],
|
| 446 |
+
queue=False
|
| 447 |
+
).then(
|
| 448 |
+
update_logs,
|
| 449 |
+
inputs=[logs_state],
|
| 450 |
+
outputs=[metrics_df]
|
| 451 |
)
|
| 452 |
+
|
| 453 |
+
clear.click(lambda: ([], []), None, (chatbot, logs_state))
|
| 454 |
|
| 455 |
+
demo.launch()
|