import html
import logging
import os
import traceback
import warnings

import gradio as gr
import pandas as pd
import plotly.graph_objs as go
import torch
import yfinance as yf
from neuralprophet import NeuralProphet

# Raise the threshold of the two chatty library loggers so only errors are shown.
logging.getLogger("neuralprophet").setLevel(logging.ERROR)
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

# Blanket-ignore every Python warning for the whole process.
# NOTE(review): this also hides deprecation warnings from our own code.
warnings.filterwarnings("ignore")

# NOTE(review): presumably meant to quiet TensorFlow's native logging if a
# dependency pulls TensorFlow in; it is set after the imports above, so
# confirm it still takes effect where it is needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


# Keep a handle on the real loader so the wrapper below can delegate to it.
original_load = torch.load


def patched_load(*args, **kwargs):
    """Delegate to the original ``torch.load``, defaulting ``weights_only`` to False.

    All positional and keyword arguments are forwarded unchanged. If the
    caller passes ``weights_only`` explicitly, that value is respected;
    otherwise ``weights_only=False`` is injected.

    Returns:
        Whatever the original ``torch.load`` returns.
    """
    # SECURITY NOTE: weights_only=False permits full pickle deserialisation,
    # which can execute arbitrary code. Only acceptable while every file that
    # reaches torch.load is produced locally by this process.
    # NOTE(review): presumably needed so NeuralProphet/Lightning checkpoints
    # load on PyTorch versions whose default is weights_only=True - confirm.
    kwargs.setdefault('weights_only', False)
    return original_load(*args, **kwargs)


# Install the wrapper process-wide; every later torch.load call goes through it.
torch.load = patched_load


def _close_price_frame(data, ticker):
    """Reduce a yfinance download to a frame of (date index, closing price).

    Handles both flat columns and MultiIndex columns. For MultiIndex columns
    it first selects ``ticker`` on level 1; if that lookup fails it falls
    back to flattening the column tuples and picking the first column whose
    name contains "Close".

    Raises:
        ValueError: if the fallback path finds no column containing "Close".
    """
    if not isinstance(data.columns, pd.MultiIndex):
        return data[['Close']].reset_index()

    try:
        per_ticker = data.xs(ticker, axis=1, level=1)
        if 'Close' in per_ticker.columns:
            return per_ticker[['Close']].reset_index()
        return data['Close'].reset_index()
    except (KeyError, IndexError, ValueError):
        # Best-effort fallback: join the column tuples into flat names.
        flat = data.copy()
        flat.columns = ['_'.join(col).strip() for col in flat.columns.values]
        close_cols = [c for c in flat.columns if "Close" in c]
        if not close_cols:
            raise ValueError("downloaded data has no 'Close' column")
        return flat[[close_cols[0]]].reset_index()


def _verdict_for_roi(roi):
    """Map a projected ROI percentage to ``(verdict, color, bg_color)``."""
    if roi > 10:
        return "STRONG BUY ๐", "#10B981", "#D1FAE5"
    if roi > 2:
        return "BUY ๐ข", "#10B981", "#D1FAE5"
    if roi > -5:
        return "HOLD ๐ก", "#F59E0B", "#FEF3C7"
    return "SELL ๐ด", "#EF4444", "#FEE2E2"


def _render_report(verdict, color, bg_color, current_price, predicted_price, roi):
    """Build the HTML summary card shown above the plots."""
    return f"""
    <div style="border: 2px solid {color}; border-radius: 10px; padding: 20px; background-color: {bg_color}; color: #1F2937; text-align: center; margin-bottom: 20px;">
        <h2 style="margin: 0; font-size: 1.5rem; text-transform: uppercase; color: {color};">{verdict}</h2>
        <p style="margin-top: 5px; font-size: 0.9rem; opacity: 0.8;">Forecast Horizon: 90 Days</p>

        <div style="display: flex; justify-content: space-around; margin-top: 20px;">
            <div>
                <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">Current</div>
                <div style="font-size: 1.5rem; font-weight: bold;">{current_price:.2f}</div>
            </div>
            <div>
                <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">Target</div>
                <div style="font-size: 1.5rem; font-weight: bold;">{predicted_price:.2f}</div>
            </div>
            <div>
                <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">ROI</div>
                <div style="font-size: 1.5rem; font-weight: bold; color: {color};">{roi:+.2f}%</div>
            </div>
        </div>
    </div>
    """


def predict_stock(ticker):
    """Download price history for ``ticker``, fit NeuralProphet, and report a 90-day outlook.

    Args:
        ticker: Ticker symbol as typed by the user. Surrounding whitespace is
            stripped and the symbol is upper-cased; ``None`` is treated as empty.

    Returns:
        A 3-tuple ``(html, forecast_figure, components_figure)``. On any
        failure (blank input, no data, too little data, unusable prices, or
        an exception) the first item is a message and both figures are
        ``None``. The figures come from ``NeuralProphet.plot`` and
        ``plot_components``; ``update_layout`` is called on them, so they are
        expected to be Plotly figures.
    """
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return "โ ๏ธ Please enter a ticker symbol.", None, None

    # The ticker is user input that ends up inside HTML; escape it for display.
    safe_ticker = html.escape(ticker)

    print(f"Processing {ticker}...")

    try:
        data = yf.download(ticker, period="3y", interval="1d", progress=False)
        if data is None or data.empty:
            return f"โ Could not find data for ticker '{safe_ticker}'. Please check the symbol.", None, None

        df = _close_price_frame(data, ticker)
        df.columns = ['ds', 'y']
        df['ds'] = df['ds'].dt.tz_localize(None)
        # Rows without a close price would make the "current price" (and so
        # the ROI) NaN, which silently falls through to the SELL verdict.
        df = df.dropna(subset=['y']).reset_index(drop=True)

        if len(df) < 100:
            return f"โ Not enough historical data found for {safe_ticker} (Need > 100 days).", None, None

        m = NeuralProphet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            daily_seasonality=False,
            learning_rate=0.01,
        )
        m.fit(df, freq="D")

        future = m.make_future_dataframe(df, periods=90)
        forecast = m.predict(future)

        # Compare the last known close with the last forecast row.
        current_price = df['y'].iloc[-1]
        predicted_price = forecast['yhat1'].iloc[-1]
        roi = ((predicted_price - current_price) / current_price) * 100

        # A zero current price or a missing prediction yields inf/NaN; do not
        # turn that into a buy/sell recommendation.
        if pd.isna(roi) or abs(roi) == float("inf"):
            return (
                f"<h3 style='color: red'>Could not compute a return for {safe_ticker}: "
                "price data is missing or zero.</h3>",
                None,
                None,
            )

        verdict, color, bg_color = _verdict_for_roi(roi)
        html_report = _render_report(
            verdict, color, bg_color, current_price, predicted_price, roi
        )

        fig_forecast = m.plot(forecast)
        fig_forecast.update_layout(title_text="Price Forecast (Blue = Prediction)", title_x=0.5)

        fig_components = m.plot_components(forecast)
        fig_components.update_layout(title_text="Seasonality & Trend Analysis", title_x=0.5)

        return html_report, fig_forecast, fig_components

    except Exception as e:
        # Top-level boundary for the UI callback: log the traceback to the
        # console and show the (escaped) message instead of crashing the app.
        traceback.print_exc()
        return f"<h3 style='color: red'>โ Error: {html.escape(str(e))}</h3>", None, None


# Extra CSS handed to gr.Blocks: a centred, width-limited container class
# and a small centred footer class.
custom_css = """
.container { max-width: 900px; margin: auto; }
.footer { text-align: center; font-size: 0.8em; margin-top: 20px; }
"""

# Gradio UI: ticker input + button on top, HTML verdict card, two plots,
# a collapsible disclaimer, and clickable example tickers.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:

    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            # ๐ฎ NeuralProphet Stock Predictor
            **AI-Powered 90-Day Price Forecasts**
            """
        )

        # Input row: wide textbox next to a narrow submit button.
        with gr.Row():
            with gr.Column(scale=3):
                ticker_input = gr.Textbox(
                    label="Stock Ticker",
                    placeholder="e.g. AZN.L, AAPL, TSLA",
                    value="AZN.L",
                    show_label=False,
                    container=False,
                )
            with gr.Column(scale=1):
                submit_btn = gr.Button("๐ Analyze", variant="primary")

        # Output widgets, filled in the same order predict_stock returns them.
        result_html = gr.HTML(label="Analysis Results")

        with gr.Row():
            plot1 = gr.Plot(label="Forecast")
            plot2 = gr.Plot(label="Seasonality")

        with gr.Accordion("โน๏ธ Disclaimer & Info", open=False):
            gr.Markdown("""
            **How it works:** This app downloads 3 years of daily data and trains a NeuralProphet model on-the-fly.
            It detects yearly and weekly seasonality to project price action 90 days out.

            **Disclaimer:**
            AI models can hallucinate trends. Always do your own research before investing.
            """)

        gr.Examples(
            examples=["AZN.L", "AAPL", "NVDA", "TSCO.L", "BTC-USD"],
            inputs=ticker_input,
        )

    # Clicking the button runs the forecast and fills the three outputs.
    submit_btn.click(
        fn=predict_stock,
        inputs=ticker_input,
        outputs=[result_html, plot1, plot2],
    )


if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()