Akshit Chaturvedi commited on
Commit
0f1ac1c
ยท
1 Parent(s): 3e27552

Made outputs look better

Browse files
Files changed (1) hide show
  1. app.py +94 -38
app.py CHANGED
@@ -36,18 +36,16 @@ def predict_stock(ticker):
36
  if not ticker:
37
  return "โš ๏ธ Please enter a ticker symbol.", None, None
38
 
39
- # Status update for the logs
40
  print(f"Processing {ticker}...")
41
 
42
  try:
43
  # 1. Get Data
44
  data = yf.download(ticker, period="3y", interval="1d", progress=False)
45
 
46
- # Handle cases where yfinance returns empty dataframe or multi-index columns
47
  if data.empty:
48
  return f"โŒ Could not find data for ticker '{ticker}'. Please check the symbol.", None, None
49
 
50
- # Flatten MultiIndex if present (yfinance update quirk)
51
  if isinstance(data.columns, pd.MultiIndex):
52
  try:
53
  # Attempt to extract just the Close column for the specific ticker
@@ -60,7 +58,6 @@ def predict_stock(ticker):
60
  # Brute force flatten
61
  df = data.copy()
62
  df.columns = ['_'.join(col).strip() for col in df.columns.values]
63
- # Look for a column containing "Close"
64
  close_col = [c for c in df.columns if "Close" in c][0]
65
  df = df[[close_col]].reset_index()
66
  else:
@@ -68,15 +65,12 @@ def predict_stock(ticker):
68
 
69
  # Rename for NeuralProphet
70
  df.columns = ['ds', 'y']
71
-
72
- # Ensure dates are timezone-naive
73
  df['ds'] = df['ds'].dt.tz_localize(None)
74
 
75
  if len(df) < 100:
76
  return f"โŒ Not enough historical data found for {ticker} (Need > 100 days).", None, None
77
 
78
  # 2. Train Model
79
- # FIX: Removed 'trainer_config' to prevent PyTorch Lightning crash
80
  m = NeuralProphet(
81
  yearly_seasonality=True,
82
  weekly_seasonality=True,
@@ -97,55 +91,117 @@ def predict_stock(ticker):
97
  # Calculate ROI
98
  roi = ((predicted_price - current_price) / current_price) * 100
99
 
100
- # Generate Verdict
101
- if roi > 10: verdict = "STRONG BUY ๐ŸŸข"
102
- elif roi > 2: verdict = "BUY ๐ŸŸข"
103
- elif roi > -5: verdict = "HOLD ๐ŸŸก"
104
- else: verdict = "SELL ๐Ÿ”ด"
105
-
106
- # 5. formatting Output Text
107
- report = f"""
108
- ### ๐Ÿ“Š Analysis Report: {ticker}
109
- **Current Price:** {current_price:.2f}
110
- **90-Day Target:** {predicted_price:.2f}
111
- **Projected ROI:** {roi:.2f}%
112
- **Verdict:** {verdict}
113
-
114
- *Disclaimer: Apply due diligence before investing.*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  """
116
 
117
  # 6. Generate Plots
118
- # Note: We rely on standard m.plot() which returns a plotly figure
119
  fig_forecast = m.plot(forecast)
 
 
120
  fig_components = m.plot_components(forecast)
 
121
 
122
- return report, fig_forecast, fig_components
123
 
124
  except Exception as e:
125
  import traceback
126
  traceback.print_exc()
127
- return f"โŒ An error occurred while processing {ticker}: {str(e)}", None, None
128
 
129
  # --- STEP 3: GRADIO INTERFACE ---
130
 
131
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
132
- gr.Markdown("# ๐Ÿ“ˆ NeuralProphet Stock Predictor")
133
- gr.Markdown("Enter a stock ticker (e.g., `AAPL`, `TSLA`, `AZN.L`) to generate a 90-day forecast.")
134
-
135
- with gr.Row():
136
- ticker_input = gr.Textbox(label="Ticker Symbol", placeholder="e.g. AZN.L", value="AZN.L")
137
- submit_btn = gr.Button("Analyze Stock", variant="primary")
138
-
139
- result_text = gr.Markdown(label="Verdict")
140
 
141
- with gr.Row():
142
- plot1 = gr.Plot(label="Price Forecast")
143
- plot2 = gr.Plot(label="Seasonality Components")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  submit_btn.click(
146
  fn=predict_stock,
147
  inputs=ticker_input,
148
- outputs=[result_text, plot1, plot2]
149
  )
150
 
151
  if __name__ == "__main__":
 
36
  if not ticker:
37
  return "โš ๏ธ Please enter a ticker symbol.", None, None
38
 
 
39
  print(f"Processing {ticker}...")
40
 
41
  try:
42
  # 1. Get Data
43
  data = yf.download(ticker, period="3y", interval="1d", progress=False)
44
 
 
45
  if data.empty:
46
  return f"โŒ Could not find data for ticker '{ticker}'. Please check the symbol.", None, None
47
 
48
+ # Flatten MultiIndex if present
49
  if isinstance(data.columns, pd.MultiIndex):
50
  try:
51
  # Attempt to extract just the Close column for the specific ticker
 
58
  # Brute force flatten
59
  df = data.copy()
60
  df.columns = ['_'.join(col).strip() for col in df.columns.values]
 
61
  close_col = [c for c in df.columns if "Close" in c][0]
62
  df = df[[close_col]].reset_index()
63
  else:
 
65
 
66
  # Rename for NeuralProphet
67
  df.columns = ['ds', 'y']
 
 
68
  df['ds'] = df['ds'].dt.tz_localize(None)
69
 
70
  if len(df) < 100:
71
  return f"โŒ Not enough historical data found for {ticker} (Need > 100 days).", None, None
72
 
73
  # 2. Train Model
 
74
  m = NeuralProphet(
75
  yearly_seasonality=True,
76
  weekly_seasonality=True,
 
91
  # Calculate ROI
92
  roi = ((predicted_price - current_price) / current_price) * 100
93
 
94
+ # Generate Verdict & Colors
95
+ if roi > 10:
96
+ verdict = "STRONG BUY ๐Ÿš€"
97
+ color = "#10B981" # Green
98
+ bg_color = "#D1FAE5"
99
+ elif roi > 2:
100
+ verdict = "BUY ๐ŸŸข"
101
+ color = "#10B981" # Green
102
+ bg_color = "#D1FAE5"
103
+ elif roi > -5:
104
+ verdict = "HOLD ๐ŸŸก"
105
+ color = "#F59E0B" # Yellow
106
+ bg_color = "#FEF3C7"
107
+ else:
108
+ verdict = "SELL ๐Ÿ”ด"
109
+ color = "#EF4444" # Red
110
+ bg_color = "#FEE2E2"
111
+
112
+ # 5. Format Output HTML (Pretty Dashboard)
113
+ # Using inline CSS to ensure it looks good in Gradio
114
+ html_report = f"""
115
+ <div style="border: 2px solid {color}; border-radius: 10px; padding: 20px; background-color: {bg_color}; color: #1F2937; text-align: center; margin-bottom: 20px;">
116
+ <h2 style="margin: 0; font-size: 1.5rem; text-transform: uppercase; color: {color};">{verdict}</h2>
117
+ <p style="margin-top: 5px; font-size: 0.9rem; opacity: 0.8;">Forecast Horizon: 90 Days</p>
118
+
119
+ <div style="display: flex; justify-content: space-around; margin-top: 20px;">
120
+ <div>
121
+ <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">Current</div>
122
+ <div style="font-size: 1.5rem; font-weight: bold;">{current_price:.2f}</div>
123
+ </div>
124
+ <div>
125
+ <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">Target</div>
126
+ <div style="font-size: 1.5rem; font-weight: bold;">{predicted_price:.2f}</div>
127
+ </div>
128
+ <div>
129
+ <div style="font-size: 0.8rem; text-transform: uppercase; letter-spacing: 1px;">ROI</div>
130
+ <div style="font-size: 1.5rem; font-weight: bold; color: {color};">{roi:+.2f}%</div>
131
+ </div>
132
+ </div>
133
+ </div>
134
  """
135
 
136
  # 6. Generate Plots
 
137
  fig_forecast = m.plot(forecast)
138
+ fig_forecast.update_layout(title_text="Price Forecast (Blue = Prediction)", title_x=0.5)
139
+
140
  fig_components = m.plot_components(forecast)
141
+ fig_components.update_layout(title_text="Seasonality & Trend Analysis", title_x=0.5)
142
 
143
+ return html_report, fig_forecast, fig_components
144
 
145
  except Exception as e:
146
  import traceback
147
  traceback.print_exc()
148
+ return f"<h3 style='color: red'>โŒ Error: {str(e)}</h3>", None, None
149
 
150
  # --- STEP 3: GRADIO INTERFACE ---
151
 
152
+ # Custom CSS for a cleaner look
153
+ custom_css = """
154
+ .container { max-width: 900px; margin: auto; }
155
+ .footer { text-align: center; font-size: 0.8em; margin-top: 20px; }
156
+ """
157
+
158
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
 
 
159
 
160
+ with gr.Column(elem_classes="container"):
161
+ gr.Markdown(
162
+ """
163
+ # ๐Ÿ”ฎ NeuralProphet Stock Predictor
164
+ **AI-Powered 90-Day Price Forecasts**
165
+ """
166
+ )
167
+
168
+ with gr.Row():
169
+ with gr.Column(scale=3):
170
+ ticker_input = gr.Textbox(
171
+ label="Stock Ticker",
172
+ placeholder="e.g. AZN.L, AAPL, TSLA",
173
+ value="AZN.L",
174
+ show_label=False,
175
+ container=False
176
+ )
177
+ with gr.Column(scale=1):
178
+ submit_btn = gr.Button("๐Ÿš€ Analyze", variant="primary")
179
+
180
+ # HTML Result Dashboard
181
+ result_html = gr.HTML(label="Analysis Results")
182
+
183
+ with gr.Row():
184
+ plot1 = gr.Plot(label="Forecast")
185
+ plot2 = gr.Plot(label="Seasonality")
186
+
187
+ with gr.Accordion("โ„น๏ธ Disclaimer & Info", open=False):
188
+ gr.Markdown("""
189
+ **How it works:** This app downloads 3 years of daily data and trains a NeuralProphet model on-the-fly.
190
+ It detects yearly and weekly seasonality to project price action 90 days out.
191
+
192
+ **Disclaimer:** This tool is for educational purposes only. It is not financial advice.
193
+ AI models can hallucinate trends. Always do your own research.
194
+ """)
195
+
196
+ gr.Examples(
197
+ examples=["AZN.L", "AAPL", "NVDA", "TSCO.L", "BTC-USD"],
198
+ inputs=ticker_input
199
+ )
200
 
201
  submit_btn.click(
202
  fn=predict_stock,
203
  inputs=ticker_input,
204
+ outputs=[result_html, plot1, plot2]
205
  )
206
 
207
  if __name__ == "__main__":