anderson-ufrj commited on
Commit
43cf505
·
1 Parent(s): 68d8151

test(memory): implement memory system tests

Browse files

- Test episodic memory storage and retrieval
- Test semantic memory and knowledge graphs
- Test conversational memory management
- Test memory consolidation process
- Test cross-referencing between memory types
- Add importance calculation tests

Files changed (1) hide show
  1. tests/unit/test_memory_system.py +595 -0
tests/unit/test_memory_system.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for memory system components."""
2
+ import pytest
3
+ import asyncio
4
+ from datetime import datetime, timedelta
5
+ from unittest.mock import MagicMock, patch, AsyncMock
6
+ import numpy as np
7
+ import json
8
+
9
+ from src.memory.episodic import (
10
+ EpisodicMemory,
11
+ Episode,
12
+ EpisodeType,
13
+ MemoryConsolidation
14
+ )
15
+ from src.memory.semantic import (
16
+ SemanticMemory,
17
+ Concept,
18
+ ConceptRelation,
19
+ KnowledgeGraph
20
+ )
21
+ from src.memory.conversational import (
22
+ ConversationalMemory,
23
+ DialogTurn,
24
+ ConversationContext,
25
+ IntentMemory
26
+ )
27
+ from src.memory.base import MemoryStore, MemoryEntry, ImportanceCalculator
28
+
29
+
30
+ class TestMemoryEntry:
31
+ """Test base memory entry."""
32
+
33
+ def test_memory_entry_creation(self):
34
+ """Test creating memory entry."""
35
+ entry = MemoryEntry(
36
+ id="mem_123",
37
+ content={"data": "test memory"},
38
+ timestamp=datetime.now(),
39
+ importance=0.8,
40
+ access_count=0
41
+ )
42
+
43
+ assert entry.id == "mem_123"
44
+ assert entry.content["data"] == "test memory"
45
+ assert entry.importance == 0.8
46
+ assert entry.access_count == 0
47
+
48
+ def test_memory_entry_decay(self):
49
+ """Test memory importance decay over time."""
50
+ # Create old memory
51
+ old_timestamp = datetime.now() - timedelta(days=7)
52
+ entry = MemoryEntry(
53
+ id="old_mem",
54
+ content="old data",
55
+ timestamp=old_timestamp,
56
+ importance=1.0
57
+ )
58
+
59
+ # Calculate decayed importance
60
+ decayed = entry.get_decayed_importance(decay_rate=0.1)
61
+
62
+ # Should be less than original
63
+ assert decayed < 1.0
64
+ assert decayed > 0.0
65
+
66
+ def test_memory_entry_access_tracking(self):
67
+ """Test memory access tracking."""
68
+ entry = MemoryEntry(
69
+ id="tracked_mem",
70
+ content="data",
71
+ importance=0.5
72
+ )
73
+
74
+ # Track accesses
75
+ entry.record_access()
76
+ entry.record_access()
77
+ entry.record_access()
78
+
79
+ assert entry.access_count == 3
80
+ assert entry.last_accessed is not None
81
+
82
+
83
+ class TestImportanceCalculator:
84
+ """Test importance calculation strategies."""
85
+
86
+ def test_recency_importance(self):
87
+ """Test recency-based importance."""
88
+ calculator = ImportanceCalculator(strategy="recency")
89
+
90
+ # Recent memory should have high importance
91
+ recent = datetime.now() - timedelta(minutes=10)
92
+ importance = calculator.calculate(
93
+ content="recent data",
94
+ metadata={"timestamp": recent}
95
+ )
96
+ assert importance > 0.8
97
+
98
+ # Old memory should have lower importance
99
+ old = datetime.now() - timedelta(days=30)
100
+ importance = calculator.calculate(
101
+ content="old data",
102
+ metadata={"timestamp": old}
103
+ )
104
+ assert importance < 0.3
105
+
106
+ def test_frequency_importance(self):
107
+ """Test frequency-based importance."""
108
+ calculator = ImportanceCalculator(strategy="frequency")
109
+
110
+ # High access count = high importance
111
+ importance = calculator.calculate(
112
+ content="popular data",
113
+ metadata={"access_count": 100}
114
+ )
115
+ assert importance > 0.7
116
+
117
+ # Low access count = low importance
118
+ importance = calculator.calculate(
119
+ content="unpopular data",
120
+ metadata={"access_count": 1}
121
+ )
122
+ assert importance < 0.3
123
+
124
+ def test_combined_importance(self):
125
+ """Test combined importance calculation."""
126
+ calculator = ImportanceCalculator(strategy="combined")
127
+
128
+ # Recent and frequently accessed
129
+ importance = calculator.calculate(
130
+ content="important data",
131
+ metadata={
132
+ "timestamp": datetime.now() - timedelta(hours=1),
133
+ "access_count": 50,
134
+ "user_rating": 0.9
135
+ }
136
+ )
137
+ assert importance > 0.8
138
+
139
+
140
+ class TestEpisodicMemory:
141
+ """Test episodic memory system."""
142
+
143
+ @pytest.fixture
144
+ def episodic_memory(self):
145
+ """Create episodic memory instance."""
146
+ return EpisodicMemory(max_episodes=100)
147
+
148
+ @pytest.mark.asyncio
149
+ async def test_store_episode(self, episodic_memory):
150
+ """Test storing investigation episode."""
151
+ episode = Episode(
152
+ id="ep_123",
153
+ type=EpisodeType.INVESTIGATION,
154
+ content={
155
+ "investigation_id": "inv_456",
156
+ "anomalies_found": 5,
157
+ "confidence": 0.85
158
+ },
159
+ participants=["zumbi", "anita"],
160
+ outcome="success"
161
+ )
162
+
163
+ await episodic_memory.store_episode(episode)
164
+
165
+ # Retrieve episode
166
+ retrieved = await episodic_memory.get_episode("ep_123")
167
+ assert retrieved is not None
168
+ assert retrieved.content["anomalies_found"] == 5
169
+ assert "zumbi" in retrieved.participants
170
+
171
+ @pytest.mark.asyncio
172
+ async def test_retrieve_similar_episodes(self, episodic_memory):
173
+ """Test retrieving similar episodes."""
174
+ # Store multiple episodes
175
+ episodes = [
176
+ Episode(
177
+ id=f"ep_{i}",
178
+ type=EpisodeType.INVESTIGATION,
179
+ content={
180
+ "target_entity": "Ministry of Health",
181
+ "anomaly_type": "price" if i % 2 == 0 else "vendor",
182
+ "severity": 0.7 + (i * 0.05)
183
+ }
184
+ )
185
+ for i in range(5)
186
+ ]
187
+
188
+ for episode in episodes:
189
+ await episodic_memory.store_episode(episode)
190
+
191
+ # Query similar episodes
192
+ query = {
193
+ "target_entity": "Ministry of Health",
194
+ "anomaly_type": "price"
195
+ }
196
+
197
+ similar = await episodic_memory.retrieve_similar(query, top_k=3)
198
+
199
+ assert len(similar) <= 3
200
+ # Should prioritize episodes with price anomalies
201
+ assert all(ep.content.get("anomaly_type") == "price"
202
+ for ep in similar[:2] if "anomaly_type" in ep.content)
203
+
204
+ @pytest.mark.asyncio
205
+ async def test_episode_consolidation(self, episodic_memory):
206
+ """Test episode consolidation process."""
207
+ # Create related episodes
208
+ episodes = []
209
+ base_time = datetime.now() - timedelta(days=7)
210
+
211
+ for i in range(10):
212
+ episode = Episode(
213
+ id=f"consolidate_{i}",
214
+ type=EpisodeType.INVESTIGATION,
215
+ content={
216
+ "entity": "Entity_A",
217
+ "pattern": "suspicious_spending",
218
+ "value": 100000 + (i * 10000)
219
+ },
220
+ timestamp=base_time + timedelta(hours=i)
221
+ )
222
+ episodes.append(episode)
223
+ await episodic_memory.store_episode(episode)
224
+
225
+ # Consolidate episodes
226
+ consolidator = MemoryConsolidation()
227
+ consolidated = await consolidator.consolidate_episodes(episodes)
228
+
229
+ assert consolidated is not None
230
+ assert consolidated.type == EpisodeType.PATTERN
231
+ assert "Entity_A" in consolidated.content.get("entities", [])
232
+ assert consolidated.content.get("pattern_type") == "suspicious_spending"
233
+
234
+ @pytest.mark.asyncio
235
+ async def test_episode_temporal_retrieval(self, episodic_memory):
236
+ """Test temporal-based episode retrieval."""
237
+ # Store episodes at different times
238
+ now = datetime.now()
239
+ time_points = [
240
+ now - timedelta(days=30), # Old
241
+ now - timedelta(days=7), # Week ago
242
+ now - timedelta(days=1), # Yesterday
243
+ now - timedelta(hours=1) # Recent
244
+ ]
245
+
246
+ for i, timestamp in enumerate(time_points):
247
+ episode = Episode(
248
+ id=f"temporal_{i}",
249
+ type=EpisodeType.ANALYSIS,
250
+ content={"data": f"event_{i}"},
251
+ timestamp=timestamp
252
+ )
253
+ await episodic_memory.store_episode(episode)
254
+
255
+ # Retrieve recent episodes
256
+ recent = await episodic_memory.get_recent_episodes(days=3)
257
+
258
+ assert len(recent) == 2 # Yesterday and 1 hour ago
259
+ assert all(ep.id in ["temporal_2", "temporal_3"] for ep in recent)
260
+
261
+
262
+ class TestSemanticMemory:
263
+ """Test semantic memory and knowledge graph."""
264
+
265
+ @pytest.fixture
266
+ def semantic_memory(self):
267
+ """Create semantic memory instance."""
268
+ return SemanticMemory()
269
+
270
+ @pytest.mark.asyncio
271
+ async def test_store_concept(self, semantic_memory):
272
+ """Test storing concepts in semantic memory."""
273
+ concept = Concept(
274
+ id="concept_price_anomaly",
275
+ name="Price Anomaly",
276
+ category="anomaly_type",
277
+ properties={
278
+ "detection_method": "statistical",
279
+ "severity_range": [0.5, 1.0],
280
+ "common_causes": ["overpricing", "emergency_purchase"]
281
+ },
282
+ embeddings=np.random.rand(384).tolist() # Mock embedding
283
+ )
284
+
285
+ await semantic_memory.store_concept(concept)
286
+
287
+ # Retrieve concept
288
+ retrieved = await semantic_memory.get_concept("concept_price_anomaly")
289
+ assert retrieved is not None
290
+ assert retrieved.name == "Price Anomaly"
291
+ assert "statistical" in retrieved.properties["detection_method"]
292
+
293
+ @pytest.mark.asyncio
294
+ async def test_concept_relations(self, semantic_memory):
295
+ """Test concept relationships in knowledge graph."""
296
+ # Create related concepts
297
+ anomaly = Concept(
298
+ id="anomaly",
299
+ name="Anomaly",
300
+ category="root"
301
+ )
302
+ price_anomaly = Concept(
303
+ id="price_anomaly",
304
+ name="Price Anomaly",
305
+ category="anomaly_type"
306
+ )
307
+ overpricing = Concept(
308
+ id="overpricing",
309
+ name="Overpricing",
310
+ category="anomaly_subtype"
311
+ )
312
+
313
+ # Store concepts
314
+ for concept in [anomaly, price_anomaly, overpricing]:
315
+ await semantic_memory.store_concept(concept)
316
+
317
+ # Create relations
318
+ relations = [
319
+ ConceptRelation(
320
+ source_id="anomaly",
321
+ target_id="price_anomaly",
322
+ relation_type="has_subtype",
323
+ strength=1.0
324
+ ),
325
+ ConceptRelation(
326
+ source_id="price_anomaly",
327
+ target_id="overpricing",
328
+ relation_type="includes",
329
+ strength=0.9
330
+ )
331
+ ]
332
+
333
+ for relation in relations:
334
+ await semantic_memory.add_relation(relation)
335
+
336
+ # Query related concepts
337
+ related = await semantic_memory.get_related_concepts(
338
+ "anomaly",
339
+ relation_type="has_subtype"
340
+ )
341
+
342
+ assert len(related) >= 1
343
+ assert any(c.id == "price_anomaly" for c in related)
344
+
345
+ @pytest.mark.asyncio
346
+ async def test_semantic_search(self, semantic_memory):
347
+ """Test semantic similarity search."""
348
+ # Create concepts with embeddings
349
+ concepts = [
350
+ Concept(
351
+ id=f"concept_{i}",
352
+ name=f"Concept {i}",
353
+ category="test",
354
+ embeddings=np.random.rand(384).tolist()
355
+ )
356
+ for i in range(5)
357
+ ]
358
+
359
+ for concept in concepts:
360
+ await semantic_memory.store_concept(concept)
361
+
362
+ # Search with query embedding
363
+ query_embedding = np.random.rand(384).tolist()
364
+ similar = await semantic_memory.search_similar(
365
+ query_embedding,
366
+ top_k=3
367
+ )
368
+
369
+ assert len(similar) <= 3
370
+ assert all(isinstance(c, Concept) for c in similar)
371
+
372
+ @pytest.mark.asyncio
373
+ async def test_knowledge_graph_traversal(self, semantic_memory):
374
+ """Test knowledge graph traversal."""
375
+ # Build a simple knowledge graph
376
+ kg = KnowledgeGraph()
377
+
378
+ # Add nodes
379
+ nodes = ["government", "ministry", "health_ministry", "contracts"]
380
+ for node in nodes:
381
+ kg.add_node(node, {"type": "entity"})
382
+
383
+ # Add edges
384
+ kg.add_edge("government", "ministry", "contains")
385
+ kg.add_edge("ministry", "health_ministry", "instance_of")
386
+ kg.add_edge("health_ministry", "contracts", "manages")
387
+
388
+ # Find path
389
+ path = kg.find_path("government", "contracts")
390
+
391
+ assert path is not None
392
+ assert len(path) == 4 # government -> ministry -> health_ministry -> contracts
393
+
394
+
395
+ class TestConversationalMemory:
396
+ """Test conversational memory system."""
397
+
398
+ @pytest.fixture
399
+ def conv_memory(self):
400
+ """Create conversational memory instance."""
401
+ return ConversationalMemory(max_turns=50)
402
+
403
+ @pytest.mark.asyncio
404
+ async def test_store_dialog_turn(self, conv_memory):
405
+ """Test storing dialog turns."""
406
+ turn = DialogTurn(
407
+ id="turn_1",
408
+ conversation_id="conv_123",
409
+ speaker="user",
410
+ utterance="Find anomalies in health ministry contracts",
411
+ intent="investigate_anomalies",
412
+ entities=["health_ministry", "contracts"]
413
+ )
414
+
415
+ await conv_memory.add_turn(turn)
416
+
417
+ # Retrieve conversation
418
+ conversation = await conv_memory.get_conversation("conv_123")
419
+ assert len(conversation) == 1
420
+ assert conversation[0].speaker == "user"
421
+ assert "health_ministry" in conversation[0].entities
422
+
423
+ @pytest.mark.asyncio
424
+ async def test_conversation_context(self, conv_memory):
425
+ """Test maintaining conversation context."""
426
+ conv_id = "context_test"
427
+
428
+ # Multi-turn conversation
429
+ turns = [
430
+ DialogTurn(
431
+ id="t1",
432
+ conversation_id=conv_id,
433
+ speaker="user",
434
+ utterance="Analyze ministry of health",
435
+ entities=["ministry_of_health"]
436
+ ),
437
+ DialogTurn(
438
+ id="t2",
439
+ conversation_id=conv_id,
440
+ speaker="agent",
441
+ utterance="Found 5 anomalies in contracts",
442
+ entities=["anomalies", "contracts"]
443
+ ),
444
+ DialogTurn(
445
+ id="t3",
446
+ conversation_id=conv_id,
447
+ speaker="user",
448
+ utterance="Show me the price anomalies",
449
+ intent="filter_results",
450
+ entities=["price_anomalies"]
451
+ )
452
+ ]
453
+
454
+ for turn in turns:
455
+ await conv_memory.add_turn(turn)
456
+
457
+ # Get context
458
+ context = await conv_memory.get_context(conv_id)
459
+
460
+ assert context is not None
461
+ assert len(context.entities) >= 3
462
+ assert "ministry_of_health" in context.entities
463
+ assert context.current_topic is not None
464
+
465
+ @pytest.mark.asyncio
466
+ async def test_intent_memory(self, conv_memory):
467
+ """Test intent pattern memory."""
468
+ # Store intent patterns
469
+ intents = [
470
+ ("Find anomalies in {entity}", "investigate_anomalies"),
471
+ ("Show me {anomaly_type} anomalies", "filter_anomalies"),
472
+ ("Generate report for {investigation}", "generate_report")
473
+ ]
474
+
475
+ intent_memory = IntentMemory()
476
+ for pattern, intent in intents:
477
+ await intent_memory.store_pattern(pattern, intent)
478
+
479
+ # Match new utterance
480
+ utterance = "Find anomalies in education ministry"
481
+ matched_intent = await intent_memory.match_intent(utterance)
482
+
483
+ assert matched_intent is not None
484
+ assert matched_intent["intent"] == "investigate_anomalies"
485
+ assert matched_intent["entities"]["entity"] == "education ministry"
486
+
487
+ @pytest.mark.asyncio
488
+ async def test_conversation_summarization(self, conv_memory):
489
+ """Test conversation summarization."""
490
+ conv_id = "long_conv"
491
+
492
+ # Create long conversation
493
+ for i in range(20):
494
+ turn = DialogTurn(
495
+ id=f"turn_{i}",
496
+ conversation_id=conv_id,
497
+ speaker="user" if i % 2 == 0 else "agent",
498
+ utterance=f"Message {i} about topic {i // 5}"
499
+ )
500
+ await conv_memory.add_turn(turn)
501
+
502
+ # Summarize conversation
503
+ summary = await conv_memory.summarize_conversation(conv_id)
504
+
505
+ assert summary is not None
506
+ assert "topics" in summary
507
+ assert "key_points" in summary
508
+ assert len(summary["key_points"]) < 20 # Condensed
509
+
510
+
511
+ class TestMemoryIntegration:
512
+ """Test integration between memory systems."""
513
+
514
+ @pytest.mark.asyncio
515
+ async def test_episodic_to_semantic_transfer(self):
516
+ """Test transferring episodic memories to semantic knowledge."""
517
+ episodic = EpisodicMemory()
518
+ semantic = SemanticMemory()
519
+
520
+ # Create multiple similar episodes
521
+ for i in range(10):
522
+ episode = Episode(
523
+ id=f"pattern_{i}",
524
+ type=EpisodeType.INVESTIGATION,
525
+ content={
526
+ "entity": "Ministry X",
527
+ "pattern": "end_of_year_spending_spike",
528
+ "severity": 0.8 + (i * 0.01)
529
+ }
530
+ )
531
+ await episodic.store_episode(episode)
532
+
533
+ # Consolidate into semantic knowledge
534
+ pattern_concept = Concept(
535
+ id="end_year_spike",
536
+ name="End of Year Spending Spike",
537
+ category="spending_pattern",
538
+ properties={
539
+ "frequency": "annual",
540
+ "typical_months": [11, 12],
541
+ "average_severity": 0.85
542
+ }
543
+ )
544
+
545
+ await semantic.store_concept(pattern_concept)
546
+
547
+ # Verify knowledge transfer
548
+ retrieved = await semantic.get_concept("end_year_spike")
549
+ assert retrieved is not None
550
+ assert retrieved.properties["frequency"] == "annual"
551
+
552
+ @pytest.mark.asyncio
553
+ async def test_memory_cross_referencing(self):
554
+ """Test cross-referencing between memory types."""
555
+ episodic = EpisodicMemory()
556
+ semantic = SemanticMemory()
557
+ conversational = ConversationalMemory()
558
+
559
+ # Create related memories
560
+ episode = Episode(
561
+ id="cross_ref_ep",
562
+ type=EpisodeType.DISCOVERY,
563
+ content={
564
+ "discovery": "New fraud pattern",
565
+ "concept_id": "fraud_pattern_123"
566
+ }
567
+ )
568
+
569
+ concept = Concept(
570
+ id="fraud_pattern_123",
571
+ name="Invoice Splitting Fraud",
572
+ category="fraud_type"
573
+ )
574
+
575
+ turn = DialogTurn(
576
+ id="turn_cross",
577
+ conversation_id="conv_cross",
578
+ speaker="agent",
579
+ utterance="Discovered new invoice splitting fraud pattern",
580
+ entities=["fraud_pattern_123"]
581
+ )
582
+
583
+ # Store all
584
+ await episodic.store_episode(episode)
585
+ await semantic.store_concept(concept)
586
+ await conversational.add_turn(turn)
587
+
588
+ # Cross-reference
589
+ episode_ref = await episodic.get_episode("cross_ref_ep")
590
+ concept_ref = await semantic.get_concept(
591
+ episode_ref.content["concept_id"]
592
+ )
593
+
594
+ assert concept_ref is not None
595
+ assert concept_ref.name == "Invoice Splitting Fraud"