Skip to content

Plot

PlotTool dataclass

Tool responsible for generating and saving visualization charts based on SRAG data.

Unlike a generic plotting tool, this is highly specialized. It executes pre-defined SQL aggregations to guarantee correct statistical representation for specific chart types (trend_30d and history_12m).

Dual Output Strategy:

  1. Visual: Saves a PNG file to disk.
  2. Semantic: Returns a text summary (growth rates, peaks, totals) to the LLM. This allows the Agent to "see" the data and describe the chart accurately in the report.

Attributes:

Name Type Description
output_dir Path

The directory where PNG files will be saved.

Source code in api/src/tools/plot.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@dataclass
class PlotTool:
    """
    Tool responsible for generating and saving visualization charts based on SRAG data.

    Unlike a generic plotting tool, this is highly specialized. It executes pre-defined
    SQL aggregations to guarantee correct statistical representation for specific
    chart types (`trend_30d` and `history_12m`). Any other chart type is rejected
    with an error message before the database, a figure or the disk is touched.

    **Dual Output Strategy:**

    1.  **Visual:** Saves a PNG file to disk.
    2.  **Semantic:** Returns a text summary (growth rates, peaks, totals) to the LLM.
        This allows the Agent to "see" the data and describe the chart accurately in the report.

    Attributes:
        output_dir (Path): The directory where PNG files will be saved.
    """

    output_dir: Path

    def __post_init__(self):
        # Create the target directory (and any missing parents) up front so
        # saving a chart later cannot fail on a missing folder.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def __call__(
        self,
        ctx: RunContext[AgentDeps],
        chart_type: str,
    ) -> str:
        """
        Generates the requested chart and returns a status message with data insights.

        Args:
            ctx (RunContext): The context containing the database connection.
            chart_type (str): The type of chart to generate ('trend_30d' or 'history_12m').

        Returns:
            str: A system note containing the file path of the generated chart AND
                 a statistical summary of the data plotted (e.g., "Growth: +15%").
                 Failures are reported as a string starting with "Error" instead of
                 being raised: unsupported chart type, empty table, or any exception
                 raised while querying or plotting.
        """
        logger.info(f"Agent requested chart: {chart_type}")

        # chart_type ends up in the output filename, so it must be one of the
        # known values. Without this guard an unknown value produced a blank
        # PNG reported as a success, and a value containing path separators
        # could place the file outside output_dir.
        supported_types = ("trend_30d", "history_12m")
        if chart_type not in supported_types:
            logger.warning(f"Unsupported chart type requested: {chart_type!r}")
            return (
                f"Error: Unsupported chart_type '{chart_type}'. "
                "Supported types: 'trend_30d', 'history_12m'."
            )

        con = ctx.deps.get_db_connection(read_only=True)
        # Sentinel so the cleanup in `finally` knows whether a figure exists.
        fig = None

        try:
            max_date_res = con.execute(
                "SELECT MAX(DT_NOTIFIC) FROM srag_analytics"
            ).fetchone()

            if not max_date_res or not max_date_res[0]:
                return "Error: Database is empty."

            # All windows are anchored on the newest notification date in the
            # table. The value comes from the database itself, not the caller.
            max_date = max_date_res[0]

            sns.set_theme(style="whitegrid")
            fig, ax = plt.subplots(figsize=(10, 6))

            stats_summary = ""

            if chart_type == "trend_30d":
                # The query window (45 days) is wider than the plotted window
                # (30 days, applied further down).
                query = f"""
                    SELECT DT_NOTIFIC, COUNT(*) AS cases
                    FROM srag_analytics
                    WHERE DT_NOTIFIC >= CAST('{max_date}' AS DATE) - INTERVAL 45 DAY
                    GROUP BY 1 ORDER BY 1 ASC
                """
                df = con.execute(query).df()
                df["DT_NOTIFIC"] = pd.to_datetime(df["DT_NOTIFIC"])

                # Re-index to a continuous daily series so days without any
                # notification count as zero instead of being skipped.
                df = (
                    df.set_index("DT_NOTIFIC")
                    .resample("D")
                    .sum()
                    .fillna(0)
                    .reset_index()
                )

                # Week-over-week growth: last 7 daily rows vs. the 7 before them.
                last_7d = df.tail(7)["cases"].sum()
                prev_7d = df.iloc[-14:-7]["cases"].sum()
                # Report 0 when the previous week has no cases (avoids division by zero).
                growth_rate = (
                    ((last_7d - prev_7d) / prev_7d * 100) if prev_7d > 0 else 0
                )
                peak_day = df.loc[df["cases"].idxmax()]

                stats_summary = (
                    f"DATA SUMMARY FOR AGENT: Growth rate: {growth_rate:+.1f}%. "
                    f"Last 7 days total: {last_7d}. "
                    f"Peak: {peak_day['cases']} on {peak_day['DT_NOTIFIC'].strftime('%Y-%m-%d')}."
                )

                # Only the most recent 30 days are drawn.
                cutoff_30d = pd.to_datetime(max_date) - pd.Timedelta(days=30)
                plot_df = df[df["DT_NOTIFIC"] >= cutoff_30d]

                sns.lineplot(
                    data=plot_df,
                    x="DT_NOTIFIC",
                    y="cases",
                    marker="o",
                    color="#d62728",
                    ax=ax,
                )

                trend_icon = "📈" if growth_rate > 0 else "📉"
                ax.set_title(f"30-Day Trend | Growth: {growth_rate:+.1f}% {trend_icon}")
                ax.tick_params(axis="x", rotation=45)

            elif chart_type == "history_12m":
                # Monthly buckets keyed by 'YYYY-MM' over the 12 months before max_date.
                query = f"""
                    SELECT strftime(DT_NOTIFIC, '%Y-%m') AS month_str, COUNT(*) AS cases
                    FROM srag_analytics
                    WHERE DT_NOTIFIC >= CAST('{max_date}' AS DATE) - INTERVAL 12 MONTH
                    GROUP BY 1 ORDER BY 1 ASC
                """
                df = con.execute(query).df()

                total_cases = df["cases"].sum()
                peak_month = df.loc[df["cases"].idxmax()]
                avg_cases = df["cases"].mean()

                stats_summary = (
                    f"DATA SUMMARY FOR AGENT: 12 months total: {total_cases}. "
                    f"Avg: {avg_cases:.1f}. Peak: {peak_month['month_str']}."
                )

                sns.barplot(data=df, x="month_str", y="cases", color="#1f77b4", ax=ax)
                ax.set_title("12-Month History")
                ax.tick_params(axis="x", rotation=45)

            fig.tight_layout()

            # Timestamped name so successive charts of the same type do not
            # overwrite each other (resolution: one second).
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{chart_type}_{timestamp}.png"
            filepath = self.output_dir / filename

            fig.savefig(
                filepath, format="png", dpi=100, facecolor="white", bbox_inches="tight"
            )

            logger.info(f"Plot saved to {filepath}")
            return f"**System Note:** Chart generated at {filepath}.\n\n{stats_summary}"

        except Exception as e:
            # Tool boundary: report the failure to the agent as text instead of raising.
            logger.error(f"Plot generation failed: {e}", exc_info=True)
            return f"Error generating chart: {str(e)}"

        finally:
            # Release the figure on every exit path, then the connection.
            if fig is not None:
                plt.close(fig)
            con.close()

__call__(ctx, chart_type)

Generates the requested chart and returns a status message with data insights.

Parameters:

Name Type Description Default
ctx RunContext

The context containing the database connection.

required
chart_type str

The type of chart to generate ('trend_30d' or 'history_12m').

required

Returns:

Name Type Description
str str

A system note containing the file path of the generated chart AND a statistical summary of the data plotted (e.g., "Growth: +15%").

Source code in api/src/tools/plot.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def __call__(
    self,
    ctx: RunContext[AgentDeps],
    chart_type: str,
) -> str:
    """
    Generates the requested chart and returns a status message with data insights.

    Args:
        ctx (RunContext): The context containing the database connection.
        chart_type (str): The type of chart to generate ('trend_30d' or 'history_12m').

    Returns:
        str: A system note containing the file path of the generated chart AND
             a statistical summary of the data plotted (e.g., "Growth: +15%").
             Failures are reported as a string starting with "Error" instead of
             being raised: unsupported chart type, empty table, or any exception
             raised while querying or plotting.
    """
    logger.info(f"Agent requested chart: {chart_type}")

    # chart_type ends up in the output filename, so it must be one of the
    # known values. Without this guard an unknown value produced a blank
    # PNG reported as a success, and a value containing path separators
    # could place the file outside output_dir.
    supported_types = ("trend_30d", "history_12m")
    if chart_type not in supported_types:
        logger.warning(f"Unsupported chart type requested: {chart_type!r}")
        return (
            f"Error: Unsupported chart_type '{chart_type}'. "
            "Supported types: 'trend_30d', 'history_12m'."
        )

    con = ctx.deps.get_db_connection(read_only=True)
    # Sentinel so the cleanup in `finally` knows whether a figure exists.
    fig = None

    try:
        max_date_res = con.execute(
            "SELECT MAX(DT_NOTIFIC) FROM srag_analytics"
        ).fetchone()

        if not max_date_res or not max_date_res[0]:
            return "Error: Database is empty."

        # All windows are anchored on the newest notification date in the
        # table. The value comes from the database itself, not the caller.
        max_date = max_date_res[0]

        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(figsize=(10, 6))

        stats_summary = ""

        if chart_type == "trend_30d":
            # The query window (45 days) is wider than the plotted window
            # (30 days, applied further down).
            query = f"""
                SELECT DT_NOTIFIC, COUNT(*) AS cases
                FROM srag_analytics
                WHERE DT_NOTIFIC >= CAST('{max_date}' AS DATE) - INTERVAL 45 DAY
                GROUP BY 1 ORDER BY 1 ASC
            """
            df = con.execute(query).df()
            df["DT_NOTIFIC"] = pd.to_datetime(df["DT_NOTIFIC"])

            # Re-index to a continuous daily series so days without any
            # notification count as zero instead of being skipped.
            df = (
                df.set_index("DT_NOTIFIC")
                .resample("D")
                .sum()
                .fillna(0)
                .reset_index()
            )

            # Week-over-week growth: last 7 daily rows vs. the 7 before them.
            last_7d = df.tail(7)["cases"].sum()
            prev_7d = df.iloc[-14:-7]["cases"].sum()
            # Report 0 when the previous week has no cases (avoids division by zero).
            growth_rate = (
                ((last_7d - prev_7d) / prev_7d * 100) if prev_7d > 0 else 0
            )
            peak_day = df.loc[df["cases"].idxmax()]

            stats_summary = (
                f"DATA SUMMARY FOR AGENT: Growth rate: {growth_rate:+.1f}%. "
                f"Last 7 days total: {last_7d}. "
                f"Peak: {peak_day['cases']} on {peak_day['DT_NOTIFIC'].strftime('%Y-%m-%d')}."
            )

            # Only the most recent 30 days are drawn.
            cutoff_30d = pd.to_datetime(max_date) - pd.Timedelta(days=30)
            plot_df = df[df["DT_NOTIFIC"] >= cutoff_30d]

            sns.lineplot(
                data=plot_df,
                x="DT_NOTIFIC",
                y="cases",
                marker="o",
                color="#d62728",
                ax=ax,
            )

            trend_icon = "📈" if growth_rate > 0 else "📉"
            ax.set_title(f"30-Day Trend | Growth: {growth_rate:+.1f}% {trend_icon}")
            ax.tick_params(axis="x", rotation=45)

        elif chart_type == "history_12m":
            # Monthly buckets keyed by 'YYYY-MM' over the 12 months before max_date.
            query = f"""
                SELECT strftime(DT_NOTIFIC, '%Y-%m') AS month_str, COUNT(*) AS cases
                FROM srag_analytics
                WHERE DT_NOTIFIC >= CAST('{max_date}' AS DATE) - INTERVAL 12 MONTH
                GROUP BY 1 ORDER BY 1 ASC
            """
            df = con.execute(query).df()

            total_cases = df["cases"].sum()
            peak_month = df.loc[df["cases"].idxmax()]
            avg_cases = df["cases"].mean()

            stats_summary = (
                f"DATA SUMMARY FOR AGENT: 12 months total: {total_cases}. "
                f"Avg: {avg_cases:.1f}. Peak: {peak_month['month_str']}."
            )

            sns.barplot(data=df, x="month_str", y="cases", color="#1f77b4", ax=ax)
            ax.set_title("12-Month History")
            ax.tick_params(axis="x", rotation=45)

        fig.tight_layout()

        # Timestamped name so successive charts of the same type do not
        # overwrite each other (resolution: one second).
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{chart_type}_{timestamp}.png"
        filepath = self.output_dir / filename

        fig.savefig(
            filepath, format="png", dpi=100, facecolor="white", bbox_inches="tight"
        )

        logger.info(f"Plot saved to {filepath}")
        return f"**System Note:** Chart generated at {filepath}.\n\n{stats_summary}"

    except Exception as e:
        # Tool boundary: report the failure to the agent as text instead of raising.
        logger.error(f"Plot generation failed: {e}", exc_info=True)
        return f"Error generating chart: {str(e)}"

    finally:
        # Release the figure on every exit path, then the connection.
        if fig is not None:
            plt.close(fig)
        con.close()

create_plot_tool(output_dir)

Factory to create the Plot Tool instance.

Parameters:

Name Type Description Default
output_dir Path

The target directory for saving chart images.

required

Returns:

Name Type Description
Tool Tool[AgentDeps]

The Pydantic AI Tool wrapping the PlotTool class.

Source code in api/src/tools/plot.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def create_plot_tool(output_dir: Path) -> Tool[AgentDeps]:
    """
    Build the agent-facing plotting tool bound to a specific output directory.

    A `PlotTool` is instantiated for `output_dir` and its `__call__` method is
    registered with the `Tool` wrapper under the name "plot_tool".

    Args:
        output_dir (Path): Directory in which generated chart images are stored.

    Returns:
        Tool: The Pydantic AI Tool wrapping the PlotTool instance's `__call__`.
    """
    plotter = PlotTool(output_dir=output_dir)

    # Text shown to the agent so it knows what calling the tool yields.
    tool_description = (
        "Generates a chart and saves it to disk. "
        "Returns a statistical summary to help interpret the visual data."
    )

    return Tool(
        plotter.__call__,
        name="plot_tool",
        description=tool_description,
    )