1from typing import List, Dict, Optional, Tuple, Any
2from enum import Enum
3from openai import OpenAI
4import uuid
5import time
6from dataclasses import dataclass
7from datetime import datetime
8import wikipedia
9import psycopg2
10from psycopg2.extras import Json, UUID_adapter, register_uuid
11import json
12
13## Tools
14
def get_database_schema() -> str:
    """Return a human-readable listing of the public database schema.

    Queries ``information_schema`` for every base table in the ``public``
    schema and renders each table with its columns, data types, and
    NOT NULL markers. Failures are reported as an error string rather than
    raised, so tool callers always receive text.
    """
    schema_query = """
    SELECT
        t.table_name,
        array_agg(
            c.column_name || ' ' ||
            c.data_type ||
            CASE
                WHEN c.is_nullable = 'NO' THEN ' NOT NULL'
                ELSE ''
            END
        ) as columns
    FROM information_schema.tables t
    JOIN information_schema.columns c
        ON c.table_name = t.table_name
    WHERE t.table_schema = 'public'
        AND t.table_type = 'BASE TABLE'
    GROUP BY t.table_name;
    """

    try:
        # Fetch everything first, then format after the connection is released.
        with psycopg2.connect(
            "dbname=demodb user=postgres password=postgrespw host=127.0.0.1"
        ) as conn, conn.cursor() as cur:
            cur.execute(schema_query)
            rows = cur.fetchall()

        # One section per table, one indented line per column.
        schema_str = "Database Schema:\n"
        for table_name, columns in rows:
            schema_str += f"\n{table_name}\n"
            for col in columns:
                schema_str += f"  - {col}\n"
        return schema_str
    except Exception as e:
        return f"Error fetching schema: {str(e)}"
54
55# First, we define how the LLM should understand our tools
# OpenAI function-calling tool specifications. Each entry's description is
# what the model reads when deciding whether and how to call the tool.
tools = [
    # Tool 1: read-only SQL access. The table schemas are embedded in the
    # description so the model can write valid queries without first asking
    # the database for its schema.
    {
        "type": "function",
        "function": {
            "name": "query_database",
            "description": """Execute a PostgreSQL SELECT query and return the results.
            Available tables and their schemas:

            users
            - id SERIAL PRIMARY KEY
            - email VARCHAR(255) NOT NULL
            - age INTEGER
            - location VARCHAR(100)
            - signup_date DATE NOT NULL
            - last_login TIMESTAMP WITH TIME ZONE
            - job_industry VARCHAR(100)

            user_activity
            - id SERIAL PRIMARY KEY
            - user_id INTEGER REFERENCES users(id)
            - activity_date DATE NOT NULL
            - activity_type VARCHAR(50) NOT NULL
            """,
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The SQL SELECT query to execute. Must start with SELECT for security.",
                    }
                },
                "required": ["query"],
                # strict mode requires additionalProperties to be disallowed
                "additionalProperties": False,
            },
            # strict=True makes the API validate arguments against the schema
            "strict": True,
        },
    },
    # Tool 2: general-knowledge lookups via Wikipedia summaries.
    {
        "type": "function",
        "function": {
            "name": "search_wikipedia",
            "description": """Search Wikipedia and return a concise summary.
            Returns the first three sentences of the most
            relevant article.""",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The topic to search for on Wikipedia",
                    }
                },
                "required": ["query"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
]
115
116# Now implement the actual tool functions
117# Updated database query function
def query_database(query: str) -> str:
    """
    Execute a PostgreSQL query with schema awareness.

    Only a single SELECT statement is allowed for security: the query must
    start with SELECT and may not chain further statements with ';'.

    Args:
        query: The SQL SELECT statement to run.

    Returns:
        A JSON string: on success {"success", "data", "row_count", "columns"};
        on failure {"error", "schema"} so the calling LLM can self-correct.
    """
    stripped = query.strip()
    # Reject anything that is not a single SELECT statement. Checking for an
    # embedded ';' blocks statement chaining such as "SELECT 1; DROP TABLE
    # users" (conservative: also rejects semicolons inside string literals).
    if not stripped.lower().startswith("select") or ";" in stripped.rstrip("; \t\r\n"):
        return json.dumps(
            {
                "error": "Only SELECT queries are allowed for security reasons.",
                "schema": get_database_schema(),  # Return schema for reference
            }
        )

    try:
        with psycopg2.connect(
            "dbname=demodb user=postgres password=postgrespw host=127.0.0.1"
        ) as conn:
            with conn.cursor() as cur:
                cur.execute(stripped)

                # Get column names from cursor description
                columns = [desc[0] for desc in cur.description]

                # Fetch results and convert to a list of dictionaries
                results = [dict(zip(columns, row)) for row in cur.fetchall()]

        # default=str serializes DATE/TIMESTAMP/Decimal cells (e.g.
        # users.signup_date) that json.dumps cannot encode natively.
        return json.dumps(
            {
                "success": True,
                "data": results,
                "row_count": len(results),
                "columns": columns,
            },
            default=str,
        )

    except Exception as e:
        return json.dumps(
            {
                "error": str(e),
                "schema": get_database_schema(),  # Return schema on error for help
            }
        )
162
def search_wikipedia(query: str) -> str:
    """
    Search Wikipedia and return a concise summary plus the article URL.

    The page is resolved once so that the summary and the URL are guaranteed
    to describe the same article (previously the page was fetched twice and
    auto-suggest could resolve each fetch to a different article, or raise
    after the summary had already succeeded).
    Handles disambiguation and missing pages gracefully.
    """
    try:
        # Resolve the most relevant page a single time.
        page = wikipedia.page(query, auto_suggest=True, redirect=True)
        # Summarize the resolved title with auto_suggest off so the summary
        # cannot drift to a different article than page.url.
        summary = wikipedia.summary(page.title, sentences=3, auto_suggest=False)

        return json.dumps({"success": True, "summary": summary, "url": page.url})

    except wikipedia.DisambiguationError as e:
        # Handle multiple matching pages
        return json.dumps(
            {
                "error": "Disambiguation error",
                "options": e.options[:5],  # List first 5 options
                "message": "Topic is ambiguous. Please be more specific.",
            }
        )
    except wikipedia.PageError:
        return json.dumps(
            {
                "error": "Page not found",
                "message": f"No Wikipedia article found for: {query}",
            }
        )
    except Exception as e:
        return json.dumps({"error": "Unexpected error", "message": str(e)})
196
## Memory config
register_uuid()  # Register psycopg2's UUID adapter so uuid.UUID session ids can be bound as query parameters
199
@dataclass
class MemoryConfig:
    """Configuration for memory management"""
    # Number of unsummarized messages that triggers an LLM summarization pass.
    max_messages: int = 20  # When to summarize
    # Upper bound (in words) requested from the LLM for each summary.
    summary_length: int = 2000  # Max summary length in words
    # psycopg2 DSN for the backing PostgreSQL database.
    # NOTE(review): credentials are hard-coded for the demo — move to
    # environment/config before any real deployment.
    db_connection: str = (
        "dbname=demodb user=postgres password=postgrespw host=127.0.0.1"
    )
208
class AgentMemory:
    """Persistent, PostgreSQL-backed conversation memory.

    Every interaction is stored in the ``conversations`` table. Once the
    number of unsummarized messages reaches ``MemoryConfig.max_messages``,
    they are collapsed into an LLM-generated summary row in
    ``conversation_summaries`` so the prompt context stays bounded.
    """

    def __init__(self, config: Optional[MemoryConfig] = None):
        self.config = config or MemoryConfig()
        self.session_id = uuid.uuid4()  # One session per AgentMemory instance
        self.setup_database()

    def setup_database(self):
        """Create necessary database tables if they don't exist"""
        queries = [
            """
            CREATE TABLE IF NOT EXISTS conversations (
                id SERIAL PRIMARY KEY,
                session_id UUID NOT NULL,
                user_input TEXT NOT NULL,
                agent_response TEXT NOT NULL,
                tool_calls JSONB,
                timestamp TIMESTAMPTZ DEFAULT NOW()
            );
            """,
            """
            CREATE TABLE IF NOT EXISTS conversation_summaries (
                id SERIAL PRIMARY KEY,
                session_id UUID NOT NULL,
                summary TEXT NOT NULL,
                start_time TIMESTAMPTZ NOT NULL,
                end_time TIMESTAMPTZ NOT NULL,
                message_count INTEGER NOT NULL
            );
            """,
        ]
        # The connection context manager commits the DDL on exit.
        with psycopg2.connect(self.config.db_connection) as conn:
            with conn.cursor() as cur:
                for query in queries:
                    cur.execute(query)

    def store_interaction(
        self,
        user_input: str,
        agent_response: str,
        tool_calls: Optional[List[Dict]] = None,
    ):
        """Store a single interaction (and any tool calls) for this session."""
        query = """
        INSERT INTO conversations
        (session_id, user_input, agent_response, tool_calls)
        VALUES (%s, %s, %s, %s)
        """
        with psycopg2.connect(self.config.db_connection) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    query,
                    (
                        self.session_id,
                        user_input,
                        agent_response,
                        # Json() adapts the Python list/dict structure to JSONB.
                        Json(tool_calls) if tool_calls else None,
                    ),
                )

    def create_summary(self, messages: List[Dict]) -> str:
        """Create a summary of messages using the LLM"""
        client = OpenAI()
        # Prepare messages for summarization
        summary_prompt = f"""
        Summarize the following conversation in less than {self.config.summary_length} words.
        Focus on key points, decisions, and important information discovered through tool usage.

        Conversation:
        {messages}
        """
        response = client.chat.completions.create(
            model="gpt-4o", messages=[{"role": "user", "content": summary_prompt}]
        )
        return response.choices[0].message.content

    def store_summary(
        self, summary: str, start_time: datetime, end_time: datetime, message_count: int
    ):
        """Persist a conversation summary covering [start_time, end_time]."""
        query = """
        INSERT INTO conversation_summaries
        (session_id, summary, start_time, end_time, message_count)
        VALUES (%s, %s, %s, %s, %s)
        """
        with psycopg2.connect(self.config.db_connection) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    query,
                    (self.session_id, summary, start_time, end_time, message_count),
                )

    def get_recent_context(self) -> str:
        """Build a textual context block: latest summary plus conversations after it.

        Both branches now yield conversations in chronological order (the
        no-summary branch previously returned them newest-first, producing a
        reversed transcript).
        """
        # First, get the most recent summary
        summary_query = """
        SELECT summary, end_time
        FROM conversation_summaries
        WHERE session_id = %s
        ORDER BY end_time DESC
        LIMIT 1
        """
        # Then get conversations after the summary
        conversations_query = """
        SELECT user_input, agent_response, tool_calls, timestamp
        FROM conversations
        WHERE session_id = %s
        AND timestamp > %s
        ORDER BY timestamp ASC
        """
        with psycopg2.connect(self.config.db_connection) as conn:
            with conn.cursor() as cur:
                # Get latest summary
                cur.execute(summary_query, (self.session_id,))
                summary_row = cur.fetchone()
                if summary_row:
                    summary, last_summary_time = summary_row
                    # Get conversations after the summary (already chronological)
                    cur.execute(
                        conversations_query, (self.session_id, last_summary_time)
                    )
                    conversations = cur.fetchall()
                else:
                    # No summary yet: take the most recent messages, then
                    # reverse so the transcript reads oldest-to-newest.
                    cur.execute(
                        """
                        SELECT user_input, agent_response, tool_calls, timestamp
                        FROM conversations
                        WHERE session_id = %s
                        ORDER BY timestamp DESC
                        LIMIT %s
                        """,
                        (self.session_id, self.config.max_messages),
                    )
                    conversations = list(reversed(cur.fetchall()))
        # Format context
        context = []
        if summary_row:
            context.append(f"Previous conversation summary: {summary}")
        for user_input, agent_response, tool_calls, _ in conversations:
            context.append(f"User: {user_input}")
            if tool_calls:
                context.append(f"Tool Usage: {tool_calls}")
            context.append(f"Assistant: {agent_response}")
        return "\n".join(context)

    def check_and_summarize(self):
        """Summarize the unsummarized messages once max_messages accumulate.

        Bug fix: the COUNT only considers messages newer than the last
        summary, but the original fetch pulled the oldest messages of the
        whole session, re-summarizing messages already covered by earlier
        summaries. The fetch now uses the same timestamp filter as the COUNT.
        """
        query = """
        SELECT COUNT(*)
        FROM conversations
        WHERE session_id = %s
        AND timestamp > (
            SELECT COALESCE(MAX(end_time), '1970-01-01'::timestamptz)
            FROM conversation_summaries
            WHERE session_id = %s
        )
        """
        with psycopg2.connect(self.config.db_connection) as conn:
            with conn.cursor() as cur:
                cur.execute(query, (self.session_id, self.session_id))
                count = cur.fetchone()[0]
                if count >= self.config.max_messages:
                    # Fetch only messages newer than the last summary.
                    cur.execute(
                        """
                        SELECT user_input, agent_response, tool_calls, timestamp
                        FROM conversations
                        WHERE session_id = %s
                        AND timestamp > (
                            SELECT COALESCE(MAX(end_time), '1970-01-01'::timestamptz)
                            FROM conversation_summaries
                            WHERE session_id = %s
                        )
                        ORDER BY timestamp ASC
                        LIMIT %s
                        """,
                        (self.session_id, self.session_id, count),
                    )
                    messages = cur.fetchall()
                    if messages:
                        # Create and store the summary for this batch.
                        summary = self.create_summary(messages)
                        self.store_summary(
                            summary,
                            messages[0][3],  # start_time of the batch
                            messages[-1][3],  # end_time of the batch
                            len(messages),
                        )
392
class AgentState(Enum):
    """States the agent loop can report after one think/act iteration."""
    THINKING = "thinking"  # Tools were executed; another iteration is needed
    DONE = "done"  # A final answer was produced
    ERROR = "error"  # The model's output could not be mapped to an action
    NEED_MORE_INFO = "need_more_info"  # The agent must ask the user a question
398
class Agent:
    """Tool-using agent that runs think/act cycles until it produces a final
    answer, needs user clarification, errors out, or hits max_iterations."""

    def __init__(self, max_iterations: int = 5, think_time: float = 0.5):
        self.client = OpenAI()
        self.max_iterations = max_iterations
        self.think_time = think_time  # Time between iterations (seconds)
        self.messages = []  # Conversation history sent to the model
        self.iteration_count = 0

    def process_with_loop(self, user_input: str) -> Dict:
        """
        Process user input with multiple iterations if needed.
        Returns both final answer and execution trace.
        """
        self.iteration_count = 0
        trace = []
        # Guard: referenced in the max-iterations return below; without this
        # initialization a zero-iteration run would raise NameError.
        response = None
        # Initial prompt
        self.messages.append({"role": "user", "content": user_input})
        while self.iteration_count < self.max_iterations:
            self.iteration_count += 1
            try:
                # Get agent's thoughts and next action
                state, response = self._think_and_act()
                # Record this iteration
                trace.append(
                    {
                        "iteration": self.iteration_count,
                        "state": state.value,
                        "response": response,
                    }
                )
                # Handle terminal states
                if state == AgentState.DONE:
                    return {
                        "status": "success",
                        "answer": response,
                        "iterations": trace,
                        "iteration_count": self.iteration_count,
                    }
                elif state == AgentState.ERROR:
                    return {
                        "status": "error",
                        "error": response,
                        "iterations": trace,
                        "iteration_count": self.iteration_count,
                    }
                elif state == AgentState.NEED_MORE_INFO:
                    return {
                        "status": "need_more_info",
                        "question": response,
                        "iterations": trace,
                        "iteration_count": self.iteration_count,
                    }
                # Add thinking time between iterations
                time.sleep(self.think_time)
            except Exception as e:
                return {
                    "status": "error",
                    "error": str(e),
                    "iterations": trace,
                    "iteration_count": self.iteration_count,
                }
        return {
            "status": "max_iterations_reached",
            "iterations": trace,
            "iteration_count": self.iteration_count,
            "final_state": response,
        }

    def _think_and_act(self) -> Tuple[AgentState, str]:
        """
        Single iteration of thinking and acting.

        Returns:
            (state, response): the resulting AgentState and the text to
            surface for that state (tool summary, question, or answer).
        """
        completion = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                *self.messages,
                {
                    "role": "system",
                    "content": f"""
                    This is iteration {self.iteration_count} of {self.max_iterations}.
                    Determine if you:
                    1. Need to use tools to gather more information
                    2. Need to ask the user for clarification
                    3. Have enough information to provide final answer

                    Format your response as:
                    THOUGHT: your reasoning process
                    ACTION: TOOL_CALL or ASK_USER or FINAL_ANSWER
                    CONTENT: your tool call, question, or final answer
                    """,
                },
            ],
            tools=tools,
        )
        response = completion.choices[0].message
        self.messages.append(response)
        # content is None when the model responds with tool calls only; the
        # original crashed on `"..." in None` in that (common) case.
        content = response.content or ""
        # Dispatch on the presence of tool calls first: the API-level signal
        # is authoritative regardless of whether the ACTION marker was emitted.
        if response.tool_calls:
            tool_results = []
            for tool_call in response.tool_calls:
                result = self.execute_tool(tool_call)
                tool_results.append(result)
                self.messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": result,
                    }
                )
            return AgentState.THINKING, "Executed tools: " + ", ".join(tool_results)
        elif "ACTION: ASK_USER" in content:
            # Extract question from CONTENT section
            question = content.split("CONTENT:")[1].strip()
            return AgentState.NEED_MORE_INFO, question
        elif "ACTION: FINAL_ANSWER" in content:
            # Extract final answer from CONTENT section
            answer = content.split("CONTENT:")[1].strip()
            return AgentState.DONE, answer
        return AgentState.ERROR, "Could not determine next action"

    def execute_tool(self, tool_call: Any) -> str:
        """
        Execute a tool based on the LLM's decision.
        Args:
            tool_call: The function call object from OpenAI's API
        Returns:
            str: JSON-formatted result of the tool execution
        """
        try:
            # Extract function details
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)
            # Log tool usage (helpful for debugging)
            print(f"Executing tool: {function_name} with args: {function_args}")
            # Execute the appropriate tool
            if function_name == "query_database":
                result = query_database(function_args["query"])
            elif function_name == "search_wikipedia":
                result = search_wikipedia(function_args["query"])
            else:
                result = json.dumps({"error": f"Unknown tool: {function_name}"})
            # Log tool result (helpful for debugging)
            print(f"Tool result: {result}")
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Failed to parse tool arguments"})
        except Exception as e:
            return json.dumps({"error": f"Tool execution failed: {str(e)}"})
551
552# Usage example for non-memory agent:
def interact_with_agent():
    """Console REPL driving the stateless Agent until the user types 'quit'."""
    agent = Agent(max_iterations=5)
    while True:
        user_input = input("\nYour question (or 'quit' to exit): ")
        if user_input.lower() == "quit":
            break
        result = agent.process_with_loop(user_input)
        status = result["status"]
        if status == "success":
            print(f"\nAnswer: {result['answer']}")
        elif status == "need_more_info":
            print(f"\nNeed more information: {result['question']}")
        else:
            print(f"\nError or max iterations reached: {status}")
        # Always dump the per-iteration trace for transparency/debugging.
        print("\nExecution trace:")
        for step in result["iterations"]:
            print(f"\nIteration {step['iteration']} ({step['state']}):")
            print(step["response"])
571
572### Memory agent
573
class MemoryAgent(Agent):
    """Agent subclass with persistent, PostgreSQL-backed conversation memory."""

    def __init__(
        self,
        max_iterations: int = 5,
        think_time: float = 0.5,
        memory_config: Optional[MemoryConfig] = None,
    ):
        super().__init__(max_iterations, think_time)
        self.memory = AgentMemory(memory_config)
        # Initialize with system prompt and context
        self.messages = [
            {
                "role": "system",
                "content": """You are a helpful AI assistant with memory of past interactions.
                You can access previous context to provide more relevant and consistent responses.
                Use this context wisely to maintain conversation continuity.""",
            }
        ]

    def process_with_loop(self, user_input: str) -> Dict:
        """Enhanced process_with_loop with memory integration.

        Summarizes old messages if needed, injects stored context, runs the
        base loop, and persists the resulting interaction.
        """
        try:
            # Check if we need to summarize before processing
            self.memory.check_and_summarize()
            # Get context from memory
            context = self.memory.get_recent_context()
            # Add context to messages if it exists
            if context:
                self.messages.append(
                    {
                        "role": "system",
                        "content": f"Previous conversation context:\n{context}",
                    }
                )
            # Process the query using parent class method
            result = super().process_with_loop(user_input)
            # Store successful interactions (with any tool usage) in memory
            if result["status"] == "success":
                tool_calls = []
                for iteration in result["iterations"]:
                    if "Executed tools:" in iteration["response"]:
                        tool_calls.append(
                            {
                                "iteration": iteration["iteration"],
                                "tools": iteration["response"],
                            }
                        )
                self.memory.store_interaction(
                    user_input=user_input,
                    agent_response=result["answer"],
                    tool_calls=tool_calls if tool_calls else None,
                )
            return result
        except Exception as e:
            # Record the failure too, so memory reflects what the user saw.
            error_message = f"Error processing query: {str(e)}"
            self.memory.store_interaction(
                user_input=user_input, agent_response=error_message
            )
            return {
                "status": "error",
                "error": error_message,
                "iterations": [],
                "iteration_count": 0,
            }

    def _think_and_act(self) -> Tuple[AgentState, str]:
        """Inject memory context, then delegate to the base implementation.

        Bug fix: the original issued a full chat-completion request here,
        discarded the result, and then called super()._think_and_act(),
        which issued a second request — doubling cost and latency while the
        memory context never reached the model. We now append the context
        as a system message so the single request made by the parent sees it.
        """
        memory_context = self.memory.get_recent_context()
        if memory_context:
            self.messages.append(
                {
                    "role": "system",
                    "content": (
                        f"Previous context: {memory_context}\n\n"
                        "Consider previous context when making decisions."
                    ),
                }
            )
        return super()._think_and_act()
670
def interact_with_memory_agent():
    """Console REPL for the memory-backed agent; type 'quit' to exit."""
    agent = MemoryAgent(max_iterations=5)
    while True:
        user_input = input("\nYour question (or 'quit' to exit): ")
        if user_input.lower() == "quit":
            break
        result = agent.process_with_loop(user_input)
        status = result["status"]
        if status == "success":
            print(f"\nAnswer: {result['answer']}")
        elif status == "need_more_info":
            print(f"\nNeed more information: {result['question']}")
        else:
            print(f"\nError or max iterations reached: {status}")
        # Dump the per-iteration trace after every turn.
        print("\nExecution trace:")
        for step in result["iterations"]:
            print(f"\nIteration {step['iteration']} ({step['state']}):")
            print(step["response"])
688
# Entry point: start the interactive loop with the memory-backed agent.
if __name__ == "__main__":
    interact_with_memory_agent()
    # interact_with_agent()  # Use this for non-memory agent