"""State classes for OpenChatBI graph execution."""
from typing import Annotated, Any
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import MessagesState
from langgraph.types import Send
[docs]
def add_history_messages(left: list, right: list):
if left:
total_messages = left + right
else:
total_messages = right
return total_messages
[docs]
class AgentState(MessagesState):
    """State for the main agent graph execution.

    Extends MessagesState with additional fields for routing and responses.
    ``history_messages`` is merged across updates by the
    ``add_history_messages`` reducer, which concatenates the existing and new
    lists.
    """

    # Conversation history; each update is concatenated onto the existing list
    # by ``add_history_messages``.
    history_messages: Annotated[list[HumanMessage | AIMessage], add_history_messages]
    # Presumably the name of the node the agent routes to next — confirm
    # against the router that reads it.
    agent_next_node: str
    # LangGraph ``Send`` objects; presumably used for fan-out dispatch to other
    # nodes — TODO confirm.
    sends: list[Send]
    # SQL text for the current turn (assumed to be produced by the SQL
    # subgraph — verify).
    sql: str
    # Final answer text; presumably what is returned to the user — confirm.
    final_answer: str
[docs]
class SQLGraphState(MessagesState):
    """State for SQL generation subgraph.

    Contains rewritten question, table selection, extracted entities, and
    generated SQL, along with the additional fields declared below.
    """

    # Rewritten form of the user's question.
    rewrite_question: str
    # Selected tables; each entry is a dict whose key layout is not defined
    # here — confirm against the table-selection node.
    tables: list[dict[str, Any]]
    # Extracted entities; key layout not defined here — confirm.
    info_entities: dict[str, Any]
    # Generated SQL.
    sql: str
    # Presumably the number of SQL generation/execution retries so far —
    # confirm.
    sql_retry_count: int
    # Presumably the textual result of executing ``sql`` — confirm format.
    sql_execution_result: str
    schema_info: dict[str, Any] # Data schema analysis results
    data: str # CSV data for display
    # Presumably details of errors from earlier SQL attempts, one dict per
    # attempt — verify.
    previous_sql_errors: list[dict[str, Any]]
    # Presumably a visualization/chart specification; its structure is not
    # defined here — confirm.
    visualization_dsl: dict[str, Any]
[docs]
class OutputState(MessagesState):
    """Output state schema for the main graph.

    Declares no fields of its own; every key comes from ``MessagesState``.
    """
[docs]
class SQLOutputState(MessagesState):
    """Output state schema for the SQL generation subgraph.

    Declares the fields exposed as output. Each one has the same name and
    type as a field on ``SQLGraphState``.
    """

    # Rewritten form of the user's question.
    rewrite_question: str
    # Selected tables; dict key layout not defined here — confirm.
    tables: list[dict[str, Any]]
    # Generated SQL.
    sql: str
    schema_info: dict[str, Any] # Data schema analysis results
    data: str # CSV data for display
    # Presumably a visualization/chart specification — confirm structure.
    visualization_dsl: dict[str, Any]