Skip to content

Stats

StatsTool dataclass

Tool responsible for executing raw analytical queries against the DuckDB database.

It returns the result in Markdown format, allowing the LLM to read tables directly.

Source code in api/src/tools/stats.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@dataclass
class StatsTool:
    """
    Callable tool that runs an analytical SQL query through the agent's DuckDB connection.

    Results are rendered as a Markdown table so the LLM can read them directly.
    """

    def __call__(self, ctx: RunContext[AgentDeps], sql_query: str) -> str:
        """
        Run `sql_query` and report the outcome as text.

        **Constraints:**

        - **Read-Only:** The connection is requested with `read_only=True`.
        - **Row Limit:** A result with more than 20 rows is not returned; the caller is
          asked to aggregate or limit the query instead, which keeps the reply small.

        Args:
            ctx: Runtime context whose `deps` provide the DB connection.
            sql_query: The SQL text to execute.

        Returns:
            str: A Markdown table of the rows, a "no data" notice, a row-limit notice,
            or an error message if anything raised while running or rendering the query.
        """
        logger.info(f"Received SQL: {sql_query}")

        # The connection comes from the agent's dependencies and is closed below.
        connection = ctx.deps.get_db_connection(read_only=True)

        try:
            frame = connection.execute(sql_query).df()
            row_count = len(frame)

            if row_count > 20:
                outcome = (
                    f"Error: Result contains {row_count} rows. "
                    "Please aggregate your query using GROUP BY or use LIMIT 20."
                )
            elif frame.empty:
                outcome = "Result: No data found for this query."
            else:
                # Rendering stays inside the try so its failures are reported as text too.
                outcome = frame.to_markdown(index=False)
        except Exception as exc:
            logger.error(f"SQL Execution failed: {exc}")
            outcome = f"SQL Error: {exc!s}"
        finally:
            connection.close()

        return outcome

__call__(ctx, sql_query)

Executes a SQL query against the 'srag_analytics' table.

Constraints:

  • Read-Only: The connection is strictly read-only.
  • Row Limit: If the result exceeds 20 rows, it asks the LLM to aggregate data, preventing context window overflow.

Parameters:

Name Type Description Default
ctx RunContext[AgentDeps]

Runtime context containing the DB connection.

required
sql_query str

The executable SQL query.

required

Returns:

Name Type Description
str str

A Markdown table of the results or an error message.

Source code in api/src/tools/stats.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __call__(self, ctx: RunContext[AgentDeps], sql_query: str) -> str:
    """
    Run `sql_query` and report the outcome as text.

    **Constraints:**

    - **Read-Only:** The connection is requested with `read_only=True`.
    - **Row Limit:** A result with more than 20 rows is not returned; the caller is
      asked to aggregate or limit the query instead, which keeps the reply small.

    Args:
        ctx: Runtime context whose `deps` provide the DB connection.
        sql_query: The SQL text to execute.

    Returns:
        str: A Markdown table of the rows, a "no data" notice, a row-limit notice,
        or an error message if anything raised while running or rendering the query.
    """
    logger.info(f"Received SQL: {sql_query}")

    # The connection comes from the agent's dependencies and is closed below.
    connection = ctx.deps.get_db_connection(read_only=True)

    try:
        frame = connection.execute(sql_query).df()
        row_count = len(frame)

        if row_count > 20:
            outcome = (
                f"Error: Result contains {row_count} rows. "
                "Please aggregate your query using GROUP BY or use LIMIT 20."
            )
        elif frame.empty:
            outcome = "Result: No data found for this query."
        else:
            # Rendering stays inside the try so its failures are reported as text too.
            outcome = frame.to_markdown(index=False)
    except Exception as exc:
        logger.error(f"SQL Execution failed: {exc}")
        outcome = f"SQL Error: {exc!s}"
    finally:
        connection.close()

    return outcome

create_stats_tool()

Factory to create the Stats Tool instance.

Returns:

Name Type Description
Tool Tool[AgentDeps]

The Pydantic AI Tool wrapping the StatsTool class.

Source code in api/src/tools/stats.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def create_stats_tool() -> Tool[AgentDeps]:
    """
    Build the `Tool` that exposes a fresh `StatsTool` under the name "stats_tool".

    Returns:
        Tool: A Tool wrapping the `__call__` method of a new StatsTool instance.
    """
    stats_tool = StatsTool()
    tool_description = (
        "Executes a SQL query against the 'srag_analytics' table and returns the results. "
        "Use this to calculate metrics like mortality, counts, and averages."
    )
    # The bound method is registered, so the tool keeps its StatsTool instance alive.
    return Tool(stats_tool.__call__, name="stats_tool", description=tool_description)

validate_sql_safety(args)

Security Validator: Checks for destructive SQL commands.

This function acts as a guardrail before the tool is even called. It inspects the arguments generated by the LLM.

Parameters:

Name Type Description Default
args dict

The dictionary of arguments passed by the LLM (e.g., {'sql_query': '...'}).

required

Returns:

Type Description
str | None

An error message if a violation is detected, or None if the query is considered safe.

Source code in api/src/tools/stats.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def validate_sql_safety(args: dict) -> str | None:
    """
    Security Validator: Checks for destructive SQL commands.

    This function acts as a guardrail *before* the tool is even called.
    It inspects the arguments generated by the LLM.

    Args:
        args (dict): The dictionary of arguments passed by the LLM (e.g., `{'sql_query': '...'}`).

    Returns:
        str | None: An error message if a violation is detected, or None if safe.
    """
    query = args.get("sql_query", "").upper()
    forbidden_keywords = ["DROP", "DELETE", "TRUNCATE", "ALTER", "UPDATE", "INSERT"]

    if any(keyword in query for keyword in forbidden_keywords):
        return f"Security Violation: Destructive SQL commands ({', '.join(forbidden_keywords)}) are strictly prohibited."

    return None