File size: 3,510 Bytes
afefd94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# agent_app.py
import base64
import logging

from smolagents import CodeAgent, TransformersModel
from smolagents import tool
from sheet_tool import (
    fetch_sheet_as_df,
    create_pivot,
    summary_stats,
    plot_dataframe,
    df_to_csv_bytes,
)

# Language model that drives the CodeAgent below.
# NOTE(review): a 135M-parameter instruct model is very small for an agent that
# must write tool-calling code; consider a larger model if answers are poor.
MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct"
model = TransformersModel(model_id=MODEL_ID)

# -------------------------------
# ✅ TOOL DEFINITIONS
# -------------------------------

@tool
def load_sheet() -> dict:
    """Load Google Sheet into a dataframe and return a short summary (not the full sheet)."""
    # The docstring above is read by smolagents as the tool description, so it
    # is kept verbatim. The sheet is fetched on every call of this tool.
    sheet = fetch_sheet_as_df()
    if not sheet.empty:
        # Only a 5-row preview is returned so the agent's context stays small.
        preview_rows = sheet.head(5).to_dict(orient="records")
        return {
            "rows": len(sheet),
            "columns": list(sheet.columns),
            "head": preview_rows,
        }
    return {"error": "Sheet is empty or not found."}


@tool
def pivot(index_cols: str, column_cols: str, value_cols: str, aggfunc: str = "sum") -> dict:
    """
    Create a pivot table from the Google Sheet.

    Args:
        index_cols (str): Comma-separated list of columns to use as the pivot table index.
        column_cols (str): Comma-separated list of columns to use as pivot table columns.
        value_cols (str): Comma-separated list of columns to aggregate.
        aggfunc (str, optional): Aggregation function to apply (e.g., 'sum', 'mean', 'count'). Defaults to 'sum'.

    Returns:
        dict: A preview of the first 10 pivot rows (index keys included) and the full pivot as base64-encoded CSV, or an error message if the sheet is empty.
    """

    def split_cols(spec: str) -> list:
        # Comma-separated names -> stripped names. Blank entries (from "A,,B"
        # or a trailing comma) are dropped so they never become a "" column.
        return [name.strip() for name in spec.split(",") if name.strip()] if spec else []

    df = fetch_sheet_as_df()
    if df.empty:
        return {"error": "Sheet empty"}

    index = split_cols(index_cols)
    columns = split_cols(column_cols)
    values = split_cols(value_cols)
    pivot_df = create_pivot(df, index=index, columns=columns, values=values, aggfunc=aggfunc)
    csv_bytes = df_to_csv_bytes(pivot_df)

    preview = pivot_df.head(10)
    # to_dict(orient="records") discards the row index. In a pivot table the
    # index carries the grouping keys, so move named index levels into columns
    # first. A default unnamed index (e.g. if create_pivot already reset it) is
    # left alone so no spurious "index" column appears.
    if any(level is not None for level in preview.index.names):
        preview = preview.reset_index()

    return {
        "pivot_preview": preview.to_dict(orient="records"),
        "csv_b64": base64.b64encode(csv_bytes).decode("utf-8"),
    }



@tool
def stats() -> dict:
    """Generate summary statistics of the sheet."""
    # Docstring kept verbatim: smolagents exposes it as the tool description.
    frame = fetch_sheet_as_df()
    result = {"error": "Sheet empty"}
    if not frame.empty:
        described = summary_stats(frame)
        result = {"summary": described.to_dict()}
    return result


@tool
def plot(kind: str = "bar", x: str = None, y: str = None, title: str = None) -> dict:
    """
    Create a plot from the Google Sheet data.

    Args:
        kind (str): Type of chart to create. Example values: 'bar', 'line', 'pie', 'scatter'.
        x (str, optional): Column name to use for the X-axis. Example: 'Date'.
        y (str, optional): Comma-separated column names to use for Y-axis. Example: 'Sales,Profit'.
        title (str, optional): Chart title to display at the top.

    Returns:
        dict: A dictionary containing the base64-encoded plot image or an error message if the sheet is empty.
    """
    df = fetch_sheet_as_df()
    if df.empty:
        return {"error": "Sheet empty"}

    # Treat a blank/whitespace x the same as an omitted one instead of looking
    # up a column literally named "".
    x_col = (x.strip() or None) if x else None

    # Drop blank entries ("Sales,,Profit", trailing comma) so they are not
    # treated as a column named "". If nothing usable remains, pass None,
    # exactly as when y is omitted.
    y_list = [c.strip() for c in y.split(",") if c.strip()] if y else []

    img_data_uri = plot_dataframe(df, kind=kind, x=x_col, y=y_list or None, title=title)
    return {"image": img_data_uri}


# -------------------------------
# ✅ AGENT CREATION
# -------------------------------
agent = CodeAgent(
    tools=[load_sheet, pivot, stats, plot],  # the sheet tools defined in this module
    model=model,
    add_base_tools=True,  # also register smolagents' built-in base tools
)


def ask_agent(nl_query: str) -> dict:
    """Send a natural-language query to the agent and return a structured response.

    Args:
        nl_query: The user's question or instruction in plain language.

    Returns:
        ``{"text": <agent answer as str>}`` on success, or
        ``{"error": <message>}`` when the query is blank/not a string or the
        agent run raises.
    """
    # Don't launch a full model/agent run for empty or whitespace-only input.
    if not isinstance(nl_query, str) or not nl_query.strip():
        return {"error": "Query is empty."}
    try:
        resp = agent.run(nl_query)
        return {"text": str(resp)}
    except Exception as e:
        # Top-level boundary: callers get a plain error dict, but the traceback
        # is logged rather than silently discarded.
        logging.getLogger(__name__).exception("Agent run failed for query %r", nl_query)
        return {"error": str(e)}